#pragma once

#include <stddef.h>
#include <stdint.h>

#include <algorithm>
#include <atomic>
#include <map>
#include <memory>
#include <set>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>

#include "backend_manager.h"
#include "cache_manager.h"
#include "infer_parameter.h"
#include "model_config.pb.h"
#include "model_repository_manager/model_repository_manager.h"
#include "rate_limiter.h"
#include "status.h"
#include "triton/common/model_config.h"

namespace triton { namespace core {
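
// Maps a cache implementation name to its configuration.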
using CacheConfigMap = std::unordered_map<std::string, std::string>;
class Model;
class InferenceRequest;
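
// How models are made available for inferencing: loaded only at
// startup (MODE_NONE), loaded and unloaded to track changes in the
// model repository (MODE_POLL), or loaded and unloaded only on
// explicit request (MODE_EXPLICIT).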
enum class ModelControlMode { MODE_NONE, MODE_POLL, MODE_EXPLICIT };
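
// Rate-limiting mode: schedule model execution using the configured
// resources and execution counts (RL_EXEC_COUNT), or issue requests
// to model instances as soon as they are available (RL_OFF).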
enum class RateLimitMode { RL_EXEC_COUNT, RL_OFF };
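
// Readiness status for the inference server.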
enum class ServerReadyState {
SERVER_INVALID,
SERVER_INITIALIZING,
SERVER_READY,
SERVER_EXITING,
SERVER_FAILED_TO_INITIALIZE
};
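
// Inference server information.
//
// A minimal usage sketch (hypothetical caller code; in practice the
// server is configured and driven through the C API implementation):
//
//   InferenceServer server;
//   server.SetModelRepositoryPaths({"/path/to/model/repository"});
//   server.SetModelControlMode(ModelControlMode::MODE_NONE);
//   Status status = server.Init();
//   if (!status.IsOk()) { /* handle initialization failure */ }
//   ...
//   server.Stop();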
class InferenceServer {
public:
InferenceServer();
~InferenceServer();
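
// Initialize the server. Returns an error status on failure.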
Status Init();
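
// Stop the server. Returns an error status if the exit timeout
// expires before all models are unloaded. If 'force' is true,
// attempt to stop the server even if it is not in a ready state.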
Status Stop(const bool force = false);
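
// Check the model repository for changes and update server state
// based on those changes.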
Status PollModelRepository();
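
// Server liveness and readiness probes.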
Status IsLive(bool* live);
Status IsReady(bool* ready);
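
// Model health: check whether a specific model version is ready, and
// enumerate the ready versions of one model or of all models.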
Status ModelIsReady(
const std::string& model_name, const int64_t model_version, bool* ready);
Status ModelReadyVersions(
const std::string& model_name, std::vector<int64_t>* versions);
Status ModelReadyVersions(
std::map<std::string, std::vector<int64_t>>* model_versions);
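
// Get the index of all models in all registered repositories. If
// 'ready_only' is true, return only the models that are ready.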
Status RepositoryIndex(
const bool ready_only,
std::vector<ModelRepositoryManager::ModelIndex>* index);
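
// Perform inference on the given request. If the returned status is
// success the server takes ownership of the request and 'request'
// will be nullptr; otherwise the caller retains ownership.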
Status InferAsync(std::unique_ptr<InferenceRequest>& request);
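
// Load the given models, reloading any that are already loaded. The
// map key is the model name and the value holds load parameters,
// such as an override model configuration.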
Status LoadModel(
const std::unordered_map<
std::string, std::vector<const InferenceParameter*>>& models);
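
// Unload the named model. If 'unload_dependents' is true, also
// unload the models it depends on (e.g. the composing models of an
// ensemble).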
Status UnloadModel(
const std::string& model_name, const bool unload_dependents);
Status PrintBackendAndModelSummary();
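
// Register a new model repository path and, optionally, mappings
// from model names in that repository to the names they are served
// under.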
Status RegisterModelRepository(
const std::string& repository,
const std::unordered_map<std::string, std::string>& model_mapping);
Status UnregisterModelRepository(const std::string& repository);
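
// Accessors and mutators for server properties and configuration.
// The setters are generally expected to be called before Init().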
const std::string& Version() const { return version_; }
const std::vector<const char*>& Extensions() const { return extensions_; }
const std::string& Id() const { return id_; }
void SetId(const std::string& id) { id_ = id; }
const std::set<std::string>& ModelRepositoryPaths() const
{
return model_repository_paths_;
}
void SetModelRepositoryPaths(const std::set<std::string>& p)
{
model_repository_paths_ = p;
}
ModelControlMode GetModelControlMode() const { return model_control_mode_; }
void SetModelControlMode(ModelControlMode m) { model_control_mode_ = m; }
const std::set<std::string>& StartupModels() const { return startup_models_; }
void SetStartupModels(const std::set<std::string>& m) { startup_models_ = m; }
bool StrictModelConfigEnabled() const { return strict_model_config_; }
void SetStrictModelConfigEnabled(bool e) { strict_model_config_ = e; }
std::string ModelConfigName() const { return model_config_name_; }
void SetModelConfigName(const std::string& name)
{
model_config_name_ = name;
}
RateLimitMode RateLimiterMode() const { return rate_limit_mode_; }
void SetRateLimiterMode(RateLimitMode m) { rate_limit_mode_ = m; }
const RateLimiter::ResourceMap& RateLimiterResources() const
{
return rate_limit_resource_map_;
}
void SetRateLimiterResources(const RateLimiter::ResourceMap& rm)
{
rate_limit_resource_map_ = rm;
}
int64_t PinnedMemoryPoolByteSize() const { return pinned_memory_pool_size_; }
void SetPinnedMemoryPoolByteSize(int64_t s)
{
pinned_memory_pool_size_ = std::max<int64_t>(0, s);
}
bool ResponseCacheEnabled()
{
return response_cache_enabled_ && CacheManager() && CacheManager()->Cache();
}
void SetResponseCacheEnabled(bool e) { response_cache_enabled_ = e; }
void SetCacheConfig(CacheConfigMap cfg) { cache_config_map_ = std::move(cfg); }
std::string CacheDir() const { return cache_dir_; }
void SetCacheDir(std::string dir) { cache_dir_ = std::move(dir); }
const std::map<int, uint64_t>& CudaMemoryPoolByteSize() const
{
return cuda_memory_pool_size_;
}
void SetCudaMemoryPoolByteSize(const std::map<int, uint64_t>& s)
{
cuda_memory_pool_size_ = s;
}
const std::map<int, size_t>& CudaVirtualAddressSpaceSize() const
{
return cuda_virtual_address_space_size_;
}
void SetCudaVirtualAddressSpaceSize(const std::map<int, size_t>& s)
{
cuda_virtual_address_space_size_ = s;
}
double MinSupportedComputeCapability() const
{
return min_supported_compute_capability_;
}
void SetMinSupportedComputeCapability(double c)
{
min_supported_compute_capability_ = c;
}
bool StrictReadinessEnabled() const { return strict_readiness_; }
void SetStrictReadinessEnabled(bool e) { strict_readiness_ = e; }
int32_t ExitTimeoutSeconds() const { return exit_timeout_secs_; }
void SetExitTimeoutSeconds(int32_t s) { exit_timeout_secs_ = std::max(0, s); }
void SetBufferManagerThreadCount(unsigned int c)
{
buffer_manager_thread_count_ = c;
}
void SetModelLoadThreadCount(unsigned int c) { model_load_thread_count_ = c; }
void SetModelLoadRetryCount(unsigned int c) { model_load_retry_count_ = c; }
void SetModelNamespacingEnabled(const bool e)
{
enable_model_namespacing_ = e;
}
void SetEnablePeerAccess(const bool e) { enable_peer_access_ = e; }
void SetBackendCmdlineConfig(
const triton::common::BackendCmdlineConfigMap& bc)
{
backend_cmdline_config_map_ = bc;
}
void SetHostPolicyCmdlineConfig(
const triton::common::HostPolicyCmdlineConfigMap& hp)
{
host_policy_map_ = hp;
}
void SetRepoAgentDir(const std::string& d) { repoagent_dir_ = d; }
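
// Return the requested model object. Only available once the server
// is ready (or while it is exiting, so in-flight work can complete).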
Status GetModel(
const std::string& model_name, const int64_t model_version,
std::shared_ptr<Model>* model)
{
if ((ready_state_ != ServerReadyState::SERVER_READY) &&
(ready_state_ != ServerReadyState::SERVER_EXITING)) {
return Status(Status::Code::UNAVAILABLE, "Server not ready");
}
return model_repository_manager_->GetModel(
model_name, model_version, model);
}
Status GetModel(
const ModelIdentifier& model_id, const int64_t model_version,
std::shared_ptr<Model>* model)
{
if ((ready_state_ != ServerReadyState::SERVER_READY) &&
(ready_state_ != ServerReadyState::SERVER_EXITING)) {
return Status(Status::Code::UNAVAILABLE, "Server not ready");
}
return model_repository_manager_->GetModel(model_id, model_version, model);
}
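
// Shared subsystems: backend manager, rate limiter, and response
// cache manager.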
const std::shared_ptr<TritonBackendManager>& BackendManager()
{
return backend_manager_;
}
std::shared_ptr<RateLimiter> GetRateLimiter() { return rate_limiter_; }
const std::shared_ptr<TritonCacheManager>& CacheManager()
{
return cache_manager_;
}
private:
const std::string version_;
std::string id_;
std::vector<const char*> extensions_;
std::set<std::string> model_repository_paths_;
std::set<std::string> startup_models_;
ModelControlMode model_control_mode_;
bool strict_model_config_;
bool strict_readiness_;
std::string model_config_name_;
int32_t exit_timeout_secs_;
uint32_t buffer_manager_thread_count_;
uint32_t model_load_thread_count_;
uint32_t model_load_retry_count_;
bool enable_model_namespacing_;
bool enable_peer_access_;
int64_t pinned_memory_pool_size_;
bool response_cache_enabled_;
CacheConfigMap cache_config_map_;
std::string cache_dir_;
std::map<int, uint64_t> cuda_memory_pool_size_;
std::map<int, size_t> cuda_virtual_address_space_size_;
double min_supported_compute_capability_;
triton::common::BackendCmdlineConfigMap backend_cmdline_config_map_;
triton::common::HostPolicyCmdlineConfigMap host_policy_map_;
std::string repoagent_dir_;
RateLimitMode rate_limit_mode_;
RateLimiter::ResourceMap rate_limit_resource_map_;
ServerReadyState ready_state_;
std::atomic<uint64_t> inflight_request_counter_;
std::shared_ptr<RateLimiter> rate_limiter_;
std::unique_ptr<ModelRepositoryManager> model_repository_manager_;
std::shared_ptr<TritonBackendManager> backend_manager_;
std::shared_ptr<TritonCacheManager> cache_manager_;
};
}}  // namespace triton::core