#pragma once
#include <cstdint>
#include <list>
#include <memory>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>

#include "constants.h"
#include "filesystem/api.h"
#include "server_message.h"
#include "status.h"
#include "triton/common/model_config.h"
#include "tritonserver_apis.h"
namespace triton { namespace core {
// A backend shared library together with the entry points resolved from it,
// its configuration message and the attributes it reports. Instances can only
// be obtained through Create(); the constructor is private.
class TritonBackend {
 public:
  // Backend-level properties. The constructor sets the defaults: a blocking
  // execution policy and no parallel instance loading.
  struct Attribute {
    Attribute()
        : exec_policy_(TRITONBACKEND_EXECUTION_BLOCKING),
          parallel_instance_loading_(false)
    {
    }
    TRITONBACKEND_ExecutionPolicy exec_policy_;
    // Empty by default. NOTE(review): presumably populated from the backend's
    // attribute callback by UpdateAttributes() — confirm in the .cc.
    std::vector<inference::ModelInstanceGroup> preferred_groups_;
    bool parallel_instance_loading_;
  };

  // Function-pointer types for the model and model-instance entry points of
  // the backend library (presumably resolved by LoadBackendLibrary()).
  using TritonModelInitFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_Model* model);
  using TritonModelFiniFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_Model* model);
  using TritonModelInstanceInitFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_ModelInstance* instance);
  using TritonModelInstanceFiniFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_ModelInstance* instance);
  using TritonModelInstanceExecFn_t = TRITONSERVER_Error* (*)(
      TRITONBACKEND_ModelInstance* instance, TRITONBACKEND_Request** requests,
      const uint32_t request_cnt);

  // Factory for a backend called 'name' whose library is 'libpath' and whose
  // directory is 'dir'; the result is handed back through 'backend'.
  // NOTE(review): exactly how 'backend_cmdline_config' is turned into the
  // backend configuration lives in the .cc — confirm there.
  static Status Create(
      const std::string& name, const std::string& dir,
      const std::string& libpath,
      const triton::common::BackendCmdlineConfig& backend_cmdline_config,
      std::shared_ptr<TritonBackend>* backend);

  // Defined out of line. NOTE(review): presumably finalizes the backend and
  // releases 'dlhandle_' — confirm in the .cc.
  ~TritonBackend();

  // Non-copyable. The object carries a raw library handle and an opaque state
  // pointer alongside a user-defined destructor, so a member-wise copy could
  // leave two objects both tearing down the same resources.
  TritonBackend(const TritonBackend&) = delete;
  TritonBackend& operator=(const TritonBackend&) = delete;

  const std::string& Name() const { return name_; }
  const std::string& Directory() const { return dir_; }
  const std::string& LibPath() const { return libpath_; }
  const TritonServerMessage& BackendConfig() const { return backend_config_; }
  const Attribute& BackendAttributes() const { return attributes_; }

  TRITONBACKEND_ExecutionPolicy ExecutionPolicy() const
  {
    return attributes_.exec_policy_;
  }
  void SetExecutionPolicy(const TRITONBACKEND_ExecutionPolicy policy)
  {
    attributes_.exec_policy_ = policy;
  }

  // Opaque, backend-owned state pointer; this class only stores it.
  void* State() { return state_; }
  void SetState(void* state) { state_ = state; }

  // Pure accessor, so it is const and usable through const references.
  bool IsPythonBackendBased() const { return is_python_based_backend_; }
  void SetPythonBasedBackendFlag(bool is_python_based_backend)
  {
    is_python_based_backend_ = is_python_based_backend;
  }

  // Accessors for the resolved model / model-instance entry points.
  TritonModelInitFn_t ModelInitFn() const { return model_init_fn_; }
  TritonModelFiniFn_t ModelFiniFn() const { return model_fini_fn_; }
  TritonModelInstanceInitFn_t ModelInstanceInitFn() const
  {
    return inst_init_fn_;
  }
  TritonModelInstanceFiniFn_t ModelInstanceFiniFn() const
  {
    return inst_fini_fn_;
  }
  TritonModelInstanceExecFn_t ModelInstanceExecFn() const
  {
    return inst_exec_fn_;
  }

 private:
  // Function-pointer types for the backend-level entry points.
  using TritonBackendInitFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_Backend* backend);
  using TritonBackendFiniFn_t =
      TRITONSERVER_Error* (*)(TRITONBACKEND_Backend* backend);
  using TritonBackendAttriFn_t = TRITONSERVER_Error* (*)(
      TRITONBACKEND_Backend* backend,
      TRITONBACKEND_BackendAttribute* backend_attributes);

  TritonBackend(
      const std::string& name, const std::string& dir,
      const std::string& libpath, const TritonServerMessage& backend_config);

  void ClearHandles();
  Status LoadBackendLibrary();
  Status UpdateAttributes();

  const std::string name_;
  const std::string dir_;
  const std::string libpath_;
  // Defaulted in-class so the flag is never read uninitialized, even if the
  // constructor does not set it. A mem-initializer still takes precedence.
  bool is_python_based_backend_{false};
  TritonServerMessage backend_config_;
  Attribute attributes_;

  // Library handle and resolved entry points. All start out null; the
  // in-class defaults are overridden by any constructor mem-initializer.
  void* dlhandle_{nullptr};
  TritonBackendInitFn_t backend_init_fn_{nullptr};
  TritonBackendFiniFn_t backend_fini_fn_{nullptr};
  TritonBackendAttriFn_t backend_attri_fn_{nullptr};
  TritonModelInitFn_t model_init_fn_{nullptr};
  TritonModelFiniFn_t model_fini_fn_{nullptr};
  TritonModelInstanceInitFn_t inst_init_fn_{nullptr};
  TritonModelInstanceFiniFn_t inst_fini_fn_{nullptr};
  TritonModelInstanceExecFn_t inst_exec_fn_{nullptr};
  void* state_{nullptr};
};
class TritonBackendManager {
public:
static Status Create(std::shared_ptr<TritonBackendManager>* manager);
Status CreateBackend(
const std::string& name, const std::string& dir,
const std::string& libpath,
const triton::common::BackendCmdlineConfig& backend_cmdline_config,
bool is_python_based_backend, std::shared_ptr<TritonBackend>* backend);
Status BackendState(
std::unique_ptr<
std::unordered_map<std::string, std::vector<std::string>>>*
backend_state);
private:
DISALLOW_COPY_AND_ASSIGN(TritonBackendManager);
TritonBackendManager() = default;
std::unordered_map<std::string, std::shared_ptr<TritonBackend>> backend_map_;
};
}}