#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use engine::Engine;
pub use engine::{
agentic_session::{AgenticSessionStore, SerializedSession, SerializedVideo},
get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
EngineInstruction, IntervalLogger, SearchEmbeddingModel, DEFAULT_MAX_TOOL_ROUNDS,
ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hanzo_ml::Device;
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::pin::Pin;
use std::sync::OnceLock;
use std::time::{Duration, Instant};
use std::{
cell::RefCell,
error::Error,
fs::OpenOptions,
io::Write,
sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
thread::{self, JoinHandle},
time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::{debug, info, warn};
/// Git revision this crate was built from.
///
/// Read at compile time from the `HANZO_GIT_REVISION` environment variable;
/// falls back to `"unknown"` when the variable is not set during the build.
pub const HANZO_GIT_REVISION: &str = if let Some(revision) = option_env!("HANZO_GIT_REVISION") {
    revision
} else {
    "unknown"
};
mod cuda;
mod device_map;
mod engine;
mod lora;
mod metal;
mod model_loader;
mod moe;
mod ops;
mod video_input;
mod vulkan;
pub use model_loader::{
get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
pub use video_input::{sample_frame_indices, VideoInput};
pub mod disk_kv_cache;
mod embedding_models;
mod kv_cache;
mod search;
mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
mod amoe;
mod attention;
mod diagnostics;
mod diffusion_models;
pub mod distributed;
pub mod files;
mod gguf;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod mla;
mod models;
mod paged_attention;
mod pipeline;
mod prefix_cacher;
pub mod reasoning_parsers;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
pub mod speculative;
mod speech_models;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;
pub use diagnostics::{
check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
};
mod tuning;
pub use tuning::{
auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
};
pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use hanzo_audio::AudioInput;
pub use hanzo_llm_mcp::{
AgentPermission, AgentToolApprovalNotifier, AgentToolApprovalRequest, AgentToolKind,
AgentToolMetadata, AgentToolSource, CalledFunction, CodeExecutionApprovalNotifier,
CodeExecutionApprovalRequest, CodeExecutionPermission, Function, MultimodalToolCallback, Tool,
ToolCallContext, ToolCallback, ToolCallbackKind, ToolCallbackWithTool, ToolOutput, ToolType,
};
pub use hanzo_llm_mcp::{
McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use hanzo_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
pub use hanzo_sandbox::{NetworkMode, SandboxPolicy};
/// Settings for the code-execution tool.
///
/// The struct is (de)serializable with `serde`; every serialized field has a
/// `#[serde(default ...)]`, so any of them may be omitted from the input.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct CodeExecutionConfig {
/// Path of the Python interpreter. Defaults to `default_python_path()`.
#[serde(default = "default_python_path")]
pub python_path: std::path::PathBuf,
/// Timeout in seconds. Defaults to `default_timeout_secs()`.
// NOTE(review): presumably a per-execution limit — confirm against the executor.
#[serde(default = "default_timeout_secs")]
pub timeout_secs: u64,
/// Optional working directory; `None` when omitted.
#[serde(default)]
pub working_directory: Option<std::path::PathBuf>,
/// Optional sandbox policy; `None` when omitted.
#[serde(default)]
pub sandbox_policy: Option<hanzo_sandbox::SandboxPolicy>,
/// Permission mode; falls back to `CodeExecutionPermission::default()` when
/// omitted from serialized input.
#[serde(default)]
pub permission: CodeExecutionPermission,
/// Optional in-process approval callback. Skipped by `serde`, so it is never
/// serialized and is `None` after deserialization.
#[serde(skip)]
pub approval_callback: Option<CodeExecutionApprovalCallback>,
}
impl std::fmt::Debug for CodeExecutionConfig {
    /// Prints every field of the config. The approval callback is a closure
    /// and cannot be printed, so only whether one is set is shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            python_path,
            timeout_secs,
            working_directory,
            sandbox_policy,
            permission,
            approval_callback,
        } = self;
        let has_approval_callback = approval_callback.is_some();

        let mut out = f.debug_struct("CodeExecutionConfig");
        out.field("python_path", python_path);
        out.field("timeout_secs", timeout_secs);
        out.field("working_directory", working_directory);
        out.field("sandbox_policy", sandbox_policy);
        out.field("permission", permission);
        out.field("approval_callback", &has_approval_callback);
        out.finish()
    }
}
/// Data handed to an agent tool approval callback describing one pending
/// tool call.
#[derive(Clone, Debug)]
pub struct AgentToolApproval {
/// Identifier of this approval request.
pub approval_id: String,
/// Identifier of the session the tool call belongs to.
pub session_id: String,
/// Round counter for the tool call.
// NOTE(review): presumably the agentic tool round index — confirm at the call site.
pub round: usize,
/// Metadata of the tool awaiting approval.
pub tool: AgentToolMetadata,
/// Arguments of the tool call, as JSON.
pub arguments: serde_json::Value,
}
/// Answer returned by an agent tool approval callback.
#[derive(Clone, Debug)]
pub struct AgentToolApprovalDecision {
    /// `true` to allow the tool call, `false` to reject it.
    pub approve: bool,
    /// Whether the decision should be remembered for the rest of the session.
    pub remember_for_session: bool,
    /// Optional message attached to the decision.
    pub message: Option<String>,
}

impl AgentToolApprovalDecision {
    /// Shared constructor behind the public factory functions.
    fn decision(approve: bool, remember_for_session: bool, message: Option<String>) -> Self {
        Self {
            approve,
            remember_for_session,
            message,
        }
    }

    /// Approves this single call, with no message and without remembering it.
    pub fn approve() -> Self {
        Self::decision(true, false, None)
    }

    /// Approves the call and marks the decision to be remembered for the session.
    pub fn approve_for_session() -> Self {
        Self::decision(true, true, None)
    }

    /// Rejects the call, with an optional message.
    pub fn deny(message: Option<String>) -> Self {
        Self::decision(false, false, message)
    }

    /// Rejects the call with the given message.
    pub fn deny_with_message(message: impl Into<String>) -> Self {
        Self::deny(Some(message.into()))
    }

    /// Returns the decision with `remember_for_session` replaced.
    pub fn with_remember_for_session(self, remember_for_session: bool) -> Self {
        Self {
            remember_for_session,
            ..self
        }
    }
}
/// Synchronous approval callback: borrows the request and returns the
/// decision directly. Shareable across threads (`Send + Sync`).
pub type AgentToolApprovalCallback =
Arc<dyn Fn(&AgentToolApproval) -> AgentToolApprovalDecision + Send + Sync + 'static>;
/// Boxed, pinned future that resolves to an approval decision.
pub type AgentToolApprovalFuture =
Pin<Box<dyn Future<Output = AgentToolApprovalDecision> + Send + 'static>>;
/// Asynchronous approval callback: takes the request by value and returns an
/// [`AgentToolApprovalFuture`].
pub type AgentToolApprovalAsyncCallback =
Arc<dyn Fn(AgentToolApproval) -> AgentToolApprovalFuture + Send + Sync + 'static>;
/// An agent tool approval callback, either blocking or future-returning.
#[derive(Clone)]
pub enum AgentToolApprovalHandler {
    /// Callback that returns its decision directly.
    Sync(AgentToolApprovalCallback),
    /// Callback that returns a future resolving to the decision.
    Async(AgentToolApprovalAsyncCallback),
}

impl AgentToolApprovalHandler {
    /// Wraps a synchronous callback.
    pub fn from_sync(callback: AgentToolApprovalCallback) -> Self {
        AgentToolApprovalHandler::Sync(callback)
    }

    /// Wraps an asynchronous callback.
    pub fn from_async(callback: AgentToolApprovalAsyncCallback) -> Self {
        AgentToolApprovalHandler::Async(callback)
    }
}
/// Data handed to a code-execution approval callback describing one pending
/// execution.
#[derive(Clone, Debug)]
pub struct CodeExecutionApproval {
/// Identifier of this approval request.
pub approval_id: String,
/// Identifier of the session the execution belongs to.
pub session_id: String,
/// Source code awaiting approval.
pub code: String,
/// Output entries associated with the request.
// NOTE(review): exact meaning (expected output files vs. prior outputs) is
// defined by the code-execution crate — confirm there.
pub outputs: Vec<String>,
/// Working directory for the execution, if one is configured.
pub working_directory: Option<std::path::PathBuf>,
}
/// Callback consulted for a [`CodeExecutionApproval`]; returns a `bool`
/// verdict. Shareable across threads (`Send + Sync`).
// NOTE(review): presumably `true` means "approved" — confirm in the executor.
pub type CodeExecutionApprovalCallback =
Arc<dyn Fn(&CodeExecutionApproval) -> bool + Send + Sync + 'static>;
/// Default interpreter name: `python` on Windows builds, `python3` elsewhere.
fn default_python_path() -> std::path::PathBuf {
    let interpreter = if cfg!(windows) { "python" } else { "python3" };
    std::path::PathBuf::from(interpreter)
}
/// Default value for `CodeExecutionConfig::timeout_secs`, in seconds.
fn default_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
impl Default for CodeExecutionConfig {
    /// Default interpreter and timeout, no working directory, no sandbox
    /// policy, `Auto` permission and no approval callback.
    fn default() -> Self {
        let python_path = default_python_path();
        let timeout_secs = default_timeout_secs();
        Self {
            python_path,
            timeout_secs,
            working_directory: None,
            sandbox_policy: None,
            permission: CodeExecutionPermission::Auto,
            approval_callback: None,
        }
    }
}
pub use files::{
format_from_name, is_text_mime, mime_for_format, File, FileContent, FileSource, FileStore,
RequestedFile, MODEL_INLINE_BYTES, WIRE_EMBED_LIMIT_BYTES,
};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::hf::{
hf_home_dir, hf_hub_cache_dir, hf_token_path, is_hf_hub_offline, probe_hf_repo_files,
HF_HUB_OFFLINE_ENV,
};
pub use pipeline::{
chat_template::ChatTemplate, expand_isq_value, parse_isq_value, parse_uqff_shard,
resolve_uqff_shorthand, AdapterPaths, AnyMoeLoader, AnyMoePipeline, AutoDeviceMapParams,
AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
ModelPaths, MultimodalLoader, MultimodalLoaderBuilder, MultimodalLoaderType,
MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader, NormalLoaderBuilder,
NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speculative::{MtpConfig, SpeculativeConfig};
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};
pub use llguidance;
/// Crate-wide debug flag; starts out `false`.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub [`Cache`] slot; can be initialized at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
/// Per-engine options that are passed on to `Engine::new` when an engine
/// thread is started.
#[derive(Clone)]
pub struct EngineConfig {
/// Disables the KV cache when `true`.
pub no_kv_cache: bool,
/// Disables the prefix cache when `true`.
pub no_prefix_cache: bool,
/// Prefix cache size parameter (default 16).
// NOTE(review): presumably the number of cached prefixes — confirm in the prefix cacher.
pub prefix_cache_n: usize,
/// Disables stopping on EOS when `true`.
pub disable_eos_stop: bool,
/// Enables throughput logging when `true`.
pub throughput_logging_enabled: bool,
/// Optional embedding model used for search.
pub search_embedding_model: Option<SearchEmbeddingModel>,
/// Optional custom search callback.
pub search_callback: Option<Arc<SearchCallback>>,
/// Tool callbacks (with their tool definitions), keyed by name.
pub tool_callbacks: tools::ToolCallbacksWithTools,
}
impl Default for EngineConfig {
    /// Caches enabled, prefix cache size 16, EOS stop enabled, throughput
    /// logging on, no search model or callback, and no tool callbacks.
    fn default() -> Self {
        let tool_callbacks = HashMap::new();
        Self {
            tool_callbacks,
            search_callback: None,
            search_embedding_model: None,
            throughput_logging_enabled: true,
            disable_eos_stop: false,
            prefix_cache_n: 16,
            no_prefix_cache: false,
            no_kv_cache: false,
        }
    }
}
/// Configuration bundle used when adding a model.
#[derive(Clone)]
pub struct AddModelConfig {
/// Engine options for the model.
pub engine_config: EngineConfig,
/// Optional MCP client configuration.
pub mcp_client_config: Option<McpClientConfig>,
/// Optional loader configuration.
// NOTE(review): presumably what makes the model reloadable after unload — confirm.
pub loader_config: Option<ModelLoaderConfig>,
/// Optional code-execution configuration.
pub code_exec_config: Option<CodeExecutionConfig>,
}
impl AddModelConfig {
    /// Creates a config with the given engine options and every optional
    /// part unset.
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
            loader_config: None,
            code_exec_config: None,
        }
    }

    /// Sets the MCP client configuration.
    pub fn with_mcp_config(self, mcp_config: McpClientConfig) -> Self {
        Self {
            mcp_client_config: Some(mcp_config),
            ..self
        }
    }

    /// Sets the code-execution configuration.
    pub fn with_code_execution(self, config: CodeExecutionConfig) -> Self {
        Self {
            code_exec_config: Some(config),
            ..self
        }
    }

    /// Sets the loader configuration.
    pub fn with_loader_config(self, loader_config: ModelLoaderConfig) -> Self {
        Self {
            loader_config: Some(loader_config),
            ..self
        }
    }
}
/// Snapshot of pipeline properties captured when an engine instance is
/// created.
#[derive(Clone)]
pub struct HanzoConfig {
/// Model kind, copied from the pipeline metadata.
pub kind: ModelKind,
/// Device reported by the pipeline.
pub device: Device,
/// Model category reported by the pipeline.
pub category: ModelCategory,
/// Input/output modalities, copied from the pipeline metadata.
pub modalities: Modalities,
/// Maximum sequence length; `None` for diffusion and speech models.
pub max_seq_len: Option<usize>,
/// Generation defaults reported by the pipeline, if any.
pub generation_defaults: Option<ModelGenerationDefaults>,
}
/// Parameters describing how a model is loaded.
// NOTE(review): field semantics below are inferred from names and types; the
// loader code that consumes them is outside this view.
#[derive(Clone)]
pub struct ModelLoaderConfig {
/// Which model to load.
pub model_selected: ModelSelected,
/// Source of the Hugging Face token.
pub token_source: TokenSource,
/// Optional Hugging Face revision.
pub hf_revision: Option<String>,
/// Model data type.
pub dtype: ModelDType,
/// Device to load onto.
pub device: Device,
/// Device mapping setting.
pub device_map_setting: DeviceMapSetting,
/// Optional in-situ quantization type.
pub isq: Option<IsqType>,
/// Optional paged-attention configuration.
pub paged_attn_config: Option<PagedAttentionConfig>,
/// Silent flag.
pub silent: bool,
/// Optional chat template.
pub chat_template: Option<String>,
/// Optional explicit Jinja template.
pub jinja_explicit: Option<String>,
/// Optional MTP configuration.
pub mtp_config: Option<MtpConfig>,
}
/// State kept for a model that is currently unloaded.
// NOTE(review): presumably everything needed to reload it on demand — the
// reload path is outside this view.
#[derive(Clone)]
pub struct UnloadedModelState {
/// Loader parameters of the model.
pub loader_config: ModelLoaderConfig,
/// Scheduler configuration of the model's engine.
pub scheduler_config: SchedulerConfig,
/// Engine options of the model's engine.
pub engine_config: EngineConfig,
/// Optional MCP client configuration.
pub mcp_client_config: Option<McpClientConfig>,
/// Model category.
pub category: ModelCategory,
/// Pipeline property snapshot of the model.
pub hanzo_config: HanzoConfig,
}
/// One running engine together with the handles needed to talk to it.
struct EngineInstance {
/// Sending half of the request channel whose receiver is owned by the engine.
sender: Sender<Request>,
/// Join handle of the thread running the engine.
engine_handler: JoinHandle<()>,
/// Everything needed to start a replacement engine if this one dies.
reboot_state: RebootState,
/// Pipeline property snapshot taken at creation.
config: HanzoConfig,
/// Model category taken from the pipeline at creation.
category: ModelCategory,
/// Interval logger shared with the engine.
logger: Arc<IntervalLogger>,
/// Agentic session store shared with the engine.
session_store: Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>,
/// File store; a clone of it is handed to the engine.
pub(crate) file_store: files::FileStore,
}
/// Top-level handle owning the engines of all loaded models.
pub struct Hanzo {
/// Running engines, keyed by model id.
engines: RwLock<HashMap<String, EngineInstance>>,
/// State of unloaded models, keyed by model id.
unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
/// Set of model ids.
// NOTE(review): presumably the models currently being reloaded — confirm in the reload path.
reloading_models: RwLock<HashSet<String>>,
/// Id of the default engine; set to the first model's id at construction.
default_engine_id: RwLock<Option<String>>,
/// Model aliases. At construction, maps the pipeline name to the overriding
/// model id when an override differing from the pipeline name is given.
model_aliases: RwLock<HashMap<String, String>>,
/// Optional log setting taken from the builder.
// NOTE(review): presumably a log file path — confirm where it is consumed.
log: Option<String>,
/// Id of the first model this instance was built with.
id: String,
/// Creation time, in seconds since the Unix epoch.
creation_time: u64,
/// Request id counter; starts at 1.
next_request_id: Mutex<RefCell<usize>>,
}
/// Everything required to start a fresh engine for a model whose engine
/// thread has finished.
#[derive(Clone)]
struct RebootState {
/// Shared pipeline the replacement engine will run.
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
/// Scheduler configuration.
method: SchedulerConfig,
/// Mirrors `EngineConfig::no_kv_cache`.
no_kv_cache: bool,
/// Mirrors `EngineConfig::no_prefix_cache`.
no_prefix_cache: bool,
/// Mirrors `EngineConfig::prefix_cache_n`.
prefix_cache_n: usize,
/// Mirrors `EngineConfig::disable_eos_stop`.
disable_eos_stop: bool,
/// Mirrors `EngineConfig::throughput_logging_enabled`.
throughput_logging_enabled: bool,
/// Mirrors `EngineConfig::search_embedding_model`.
search_embedding_model: Option<SearchEmbeddingModel>,
/// Mirrors `EngineConfig::search_callback`.
search_callback: Option<Arc<search::SearchCallback>>,
/// Mirrors `EngineConfig::tool_callbacks`.
tool_callbacks: tools::ToolCallbacksWithTools,
/// MCP client configuration the model was created with, if any.
mcp_client_config: Option<McpClientConfig>,
/// Loader configuration the model was created with, if any.
loader_config: Option<ModelLoaderConfig>,
}
/// Load state of a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// The model is loaded.
    Loaded,
    /// The model is unloaded.
    Unloaded,
    /// The model is reloading.
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Writes the lowercase status name (`loaded`, `unloaded`, `reloading`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
/// Errors returned by the top-level [`Hanzo`] API.
///
/// `Display` prints the same text as `Debug`.
#[derive(Debug)]
pub enum HanzoError {
    /// The engine map lock is poisoned, or the requested engine is missing
    /// from it.
    EnginePoisoned,
    /// A lock needed to obtain a sender or store is poisoned.
    SenderPoisoned,
    /// No loaded or unloaded model has this id.
    ModelNotFound(String),
    /// Carries a model id.
    ModelReloading(String),
    /// Carries a message.
    ReloadFailed(String),
    /// Carries a model id.
    NoLoaderConfig(String),
    /// Carries a model id.
    ModelAlreadyLoaded(String),
    /// Carries a model id.
    ModelAlreadyUnloaded(String),
    /// Any other failure, with a message.
    Other(String),
}

impl std::fmt::Display for HanzoError {
    /// Delegates to the `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for HanzoError {}
/// Converts a [`HanzoError`] into a Python `ValueError` carrying the error's
/// `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<HanzoError> for pyo3::PyErr {
    fn from(value: HanzoError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
/// Builder for a [`Hanzo`] instance serving a single initial model.
pub struct HanzoBuilder {
/// Pipeline of the initial model.
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
/// Scheduler configuration.
method: SchedulerConfig,
/// Optional id to register the model under instead of the pipeline name.
model_id_override: Option<String>,
/// Optional log setting, stored on the built instance.
log: Option<String>,
/// KV cache switch; treated as `false` when unset.
no_kv_cache: Option<bool>,
/// Prefix cache switch; treated as `false` when unset.
no_prefix_cache: Option<bool>,
/// Prefix cache size parameter; treated as 16 when unset.
prefix_cache_n: Option<usize>,
/// EOS-stop switch; treated as `false` when unset.
disable_eos_stop: Option<bool>,
/// Whether throughput logging is enabled.
throughput_logging_enabled: bool,
/// Optional embedding model used for search.
search_embedding_model: Option<SearchEmbeddingModel>,
/// Optional custom search callback.
search_callback: Option<Arc<SearchCallback>>,
/// Tool callbacks (with their tool definitions), keyed by name.
tool_callbacks: tools::ToolCallbacksWithTools,
/// Optional MCP client configuration.
mcp_client_config: Option<McpClientConfig>,
/// Optional loader configuration.
loader_config: Option<ModelLoaderConfig>,
/// Optional code-execution configuration.
code_exec_config: Option<CodeExecutionConfig>,
}
impl HanzoBuilder {
    /// Starts a builder from the mandatory pieces; every optional setting is
    /// left unset and no tool callbacks are registered.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            tool_callbacks: HashMap::new(),
            model_id_override: None,
            log: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            search_callback: None,
            mcp_client_config: None,
            loader_config: None,
            code_exec_config: None,
        }
    }

    /// Registers the model under `model_id` instead of the pipeline name.
    pub fn with_model_id(self, model_id: impl Into<String>) -> Self {
        Self {
            model_id_override: Some(model_id.into()),
            ..self
        }
    }

    /// Sets the loader configuration.
    pub fn with_loader_config(self, loader_config: ModelLoaderConfig) -> Self {
        Self {
            loader_config: Some(loader_config),
            ..self
        }
    }

    /// Sets the log setting.
    pub fn with_log(self, log: String) -> Self {
        self.with_opt_log(Some(log))
    }

    /// Sets or clears the log setting.
    pub fn with_opt_log(self, log: Option<String>) -> Self {
        Self { log, ..self }
    }

    /// Sets the KV cache switch.
    pub fn with_no_kv_cache(self, no_kv_cache: bool) -> Self {
        Self {
            no_kv_cache: Some(no_kv_cache),
            ..self
        }
    }

    /// Sets the prefix cache switch.
    pub fn with_no_prefix_cache(self, no_prefix_cache: bool) -> Self {
        Self {
            no_prefix_cache: Some(no_prefix_cache),
            ..self
        }
    }

    /// Sets the prefix cache size parameter.
    pub fn with_prefix_cache_n(self, prefix_cache_n: usize) -> Self {
        Self {
            prefix_cache_n: Some(prefix_cache_n),
            ..self
        }
    }

    /// Sets the EOS-stop switch.
    pub fn with_disable_eos_stop(self, disable_eos_stop: bool) -> Self {
        Self {
            disable_eos_stop: Some(disable_eos_stop),
            ..self
        }
    }

    /// Sets the custom search callback.
    pub fn with_search_callback(self, search_callback: Arc<SearchCallback>) -> Self {
        Self {
            search_callback: Some(search_callback),
            ..self
        }
    }

    /// Registers a text tool callback under `name`, with a minimal function
    /// tool definition (no description, parameters or strictness) that
    /// carries the same name.
    pub fn with_tool_callback(
        self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        let name = name.into();
        let tool = Tool {
            tp: ToolType::Function,
            function: Function {
                description: None,
                name: name.clone(),
                parameters: None,
                strict: None,
            },
        };
        self.with_tool_callback_and_tool(name, tool_callback, tool)
    }

    /// Registers a text tool callback under `name` together with the given
    /// tool definition.
    pub fn with_tool_callback_and_tool(
        self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let callback_with_tool = ToolCallbackWithTool {
            callback: ToolCallbackKind::Text(tool_callback),
            tool,
        };
        self.with_tool_callback_with_tool(name, callback_with_tool)
    }

    /// Registers an already paired callback and tool definition under `name`,
    /// replacing any previous entry with that name.
    pub fn with_tool_callback_with_tool(
        mut self,
        name: impl Into<String>,
        callback_with_tool: ToolCallbackWithTool,
    ) -> Self {
        let key = name.into();
        self.tool_callbacks.insert(key, callback_with_tool);
        self
    }

    /// Sets the MCP client configuration.
    pub fn with_mcp_client(self, config: McpClientConfig) -> Self {
        Self {
            mcp_client_config: Some(config),
            ..self
        }
    }

    /// Sets the code-execution configuration.
    pub fn with_code_execution(self, config: CodeExecutionConfig) -> Self {
        Self {
            code_exec_config: Some(config),
            ..self
        }
    }

    /// Consumes the builder and constructs the [`Hanzo`] instance.
    pub async fn build(self) -> Arc<Hanzo> {
        let built = Hanzo::new(self).await;
        built
    }
}
impl Drop for Hanzo {
    /// Best-effort shutdown: asks every engine to terminate. A poisoned lock
    /// or a failed send is ignored.
    fn drop(&mut self) {
        let Ok(engines) = self.engines.read() else {
            return;
        };
        for instance in engines.values() {
            let _ = instance.sender.try_send(Request::Terminate);
        }
    }
}
impl Hanzo {
/// Spawns a dedicated thread running an `Engine` for `pipeline` and returns
/// the handles needed to talk to it.
///
/// Pipeline properties (category, kind, device, modalities, maximum sequence
/// length, generation defaults, encoder cache counters) are read once up
/// front and stored in the returned instance. The engine itself is created
/// inside the new thread, on a Tokio runtime owned by that thread.
///
/// # Errors
///
/// Returns `Err` when the pipeline mutex is currently held by someone else.
/// This used to panic; callers such as `reboot_engine` hold the `engines`
/// write lock while calling, so a panic here would poison that lock.
///
/// # Panics
///
/// The spawned thread (not the caller) panics if its runtime cannot be
/// created or if `Engine::new` fails.
fn create_engine_instance(
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    config: EngineConfig,
    reboot_state: RebootState,
) -> Result<EngineInstance, String> {
    let (tx, rx) = channel(10_000);

    let pipeline_guard = pipeline
        .try_lock()
        .map_err(|e| format!("Pipeline is locked, cannot create engine instance: {e}"))?;
    let category = pipeline_guard.category();
    let metadata = pipeline_guard.get_metadata();
    let kind = metadata.kind.clone();
    let device = pipeline_guard.device();
    let modalities = metadata.modalities.clone();
    // Diffusion and speech models have no meaningful sequence length.
    let max_seq_len = match &category {
        ModelCategory::Diffusion | ModelCategory::Speech => None,
        _ => Some(metadata.max_seq_len),
    };
    let generation_defaults = pipeline_guard.generation_defaults();
    let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
    // Release the pipeline before the engine thread starts using it.
    drop(pipeline_guard);

    let logger = Arc::new(IntervalLogger::new(
        Duration::from_secs(5),
        encoder_cache_counters,
    ));
    let logger_for_engine = logger.clone();

    info!("Pipeline input modalities are {:?}", &modalities.input);
    info!("Pipeline output modalities are {:?}", &modalities.output);

    let hanzo_config = HanzoConfig {
        kind,
        device,
        category: category.clone(),
        modalities,
        max_seq_len,
        generation_defaults,
    };

    let session_store = Arc::new(std::sync::Mutex::new(
        engine::agentic_session::AgenticSessionStore::new(),
    ));
    let session_store_for_engine = Arc::clone(&session_store);
    let file_store = files::FileStore::new();
    let file_store_for_engine = file_store.clone();
    let tx_for_engine = tx.clone();

    // Single definition of the engine thread body; it is wrapped in an
    // autorelease pool on Metal builds and called directly otherwise.
    let run_engine = move || {
        let rt = Runtime::new().unwrap();
        rt.block_on(async move {
            file_store_for_engine.spawn_cleanup_task();
            let engine = Engine::new(
                tx_for_engine,
                rx,
                pipeline,
                method,
                config.no_kv_cache,
                config.no_prefix_cache,
                config.prefix_cache_n,
                config.disable_eos_stop,
                config.throughput_logging_enabled,
                config.search_embedding_model,
                config.search_callback.clone(),
                config.tool_callbacks.clone(),
                logger_for_engine,
                session_store_for_engine,
                file_store_for_engine,
            )
            .expect("Engine creation failed.");
            Arc::new(engine).run().await;
        })
    };

    let engine_handler = thread::spawn(move || {
        #[cfg(feature = "metal")]
        objc::rc::autoreleasepool(run_engine);
        #[cfg(not(feature = "metal"))]
        run_engine();
    });

    Ok(EngineInstance {
        sender: tx,
        engine_handler,
        reboot_state,
        config: hanzo_config,
        category,
        logger,
        session_store,
        file_store,
    })
}
// Registers tool callbacks that come from external integrations into
// `tool_callbacks`: MCP servers (when `mcp_client_config` is given) and, with
// the `code-execution` feature, the code-execution tools (when
// `code_exec_config` is given). Entries are inserted by name, so a later
// insert replaces an earlier callback with the same name.
//
// Failures are never fatal: each integration logs a warning and is skipped.
async fn init_external_tool_callbacks(
#[cfg_attr(not(feature = "code-execution"), allow(unused_variables))] pipeline: &Arc<
tokio::sync::Mutex<dyn Pipeline>,
>,
tool_callbacks: &mut tools::ToolCallbacksWithTools,
mcp_client_config: Option<&McpClientConfig>,
#[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
code_exec_config: Option<&CodeExecutionConfig>,
) {
// --- MCP tools ---
if let Some(config) = mcp_client_config {
let mut mcp_client = McpClient::new(config.clone());
let total_servers = config.servers.len();
match mcp_client.initialize().await {
Ok(()) => {
let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
let tools_count = mcp_callbacks_with_tools.len();
for (name, callback_with_tool) in mcp_callbacks_with_tools {
tool_callbacks.insert(name.clone(), callback_with_tool.clone());
}
// Successful initialization with zero tools is reported as a warning.
if tools_count == 0 {
warn!(
"MCP client initialized but no tools were registered from {} servers",
total_servers
);
} else {
info!(
"MCP client initialized successfully with {} tools from {} servers",
tools_count, total_servers
);
}
}
Err(e) => {
warn!(
"Failed to initialize MCP client with {} configured servers: {}",
total_servers, e
);
warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
}
}
}
// --- Code-execution tools (only compiled with the `code-execution` feature) ---
#[cfg(feature = "code-execution")]
if let Some(code_exec_cfg) = code_exec_config {
// Adapt this crate's approval callback to the callback type of
// `hanzo_code_exec` by copying the approval request field by field.
let approval_callback = code_exec_cfg.approval_callback.as_ref().map(|callback| {
let callback = Arc::clone(callback);
Arc::new(move |approval: &hanzo_code_exec::CodeExecutionApproval| {
let approval = CodeExecutionApproval {
approval_id: approval.approval_id.clone(),
session_id: approval.session_id.clone(),
code: approval.code.clone(),
outputs: approval.outputs.clone(),
working_directory: approval.working_directory.clone(),
};
callback(&approval)
}) as Arc<hanzo_code_exec::CodeExecutionApprovalCallback>
});
// Translate this crate's config into the `hanzo_code_exec` config.
let exec_config = hanzo_code_exec::CodeExecutionConfig {
python_path: code_exec_cfg.python_path.clone(),
timeout_secs: code_exec_cfg.timeout_secs,
working_directory: code_exec_cfg.working_directory.clone(),
sandbox_policy: code_exec_cfg.sandbox_policy.clone(),
permission: match code_exec_cfg.permission {
CodeExecutionPermission::Auto => hanzo_code_exec::CodeExecutionPermission::Auto,
CodeExecutionPermission::Ask => hanzo_code_exec::CodeExecutionPermission::Ask,
CodeExecutionPermission::Deny => hanzo_code_exec::CodeExecutionPermission::Deny,
},
approval_callback,
};
match hanzo_code_exec::CodeExecutionManager::new(exec_config).await {
Ok(manager) => {
// Map the pipeline's input modalities onto the code-execution
// crate's enum; modalities without a counterpart are dropped.
let input_modalities: Vec<hanzo_code_exec::InputModality> = {
let pipe = get_mut_arcmutex!(pipeline);
pipe.get_metadata()
.modalities
.input
.iter()
.filter_map(|m| match m {
pipeline::SupportedModality::Text => {
Some(hanzo_code_exec::InputModality::Text)
}
pipeline::SupportedModality::Vision => {
Some(hanzo_code_exec::InputModality::Vision)
}
pipeline::SupportedModality::Audio => {
Some(hanzo_code_exec::InputModality::Audio)
}
pipeline::SupportedModality::Video => {
Some(hanzo_code_exec::InputModality::Video)
}
_ => None,
})
.collect()
};
let effective = manager.effective_protection();
let network = manager.network_mode();
let callbacks = manager.get_tool_callbacks(&input_modalities);
let count = callbacks.len();
for (name, cb) in callbacks {
tool_callbacks.insert(name, cb);
}
// Prominent banner so operators notice that code execution is on
// and which sandbox layers are actually in effect.
warn!("============================================================");
warn!(" CODE EXECUTION IS ENABLED");
warn!(" The model can execute arbitrary Python code on this machine.");
if effective.any() {
let fs = if effective.fs_isolated {
"workdir + system libs only"
} else {
"NOT restricted"
};
let net = if effective.network_isolated {
match network {
Some(hanzo_sandbox::NetworkMode::None) => "denied",
Some(hanzo_sandbox::NetworkMode::Loopback) => "loopback only",
_ => "NOT restricted",
}
} else {
"NOT restricted"
};
warn!(
" Sandbox: on. Filesystem: {fs}. Network: {net}. rlimits: {}.",
if effective.rlimits_applied {
"applied"
} else {
"not applied"
}
);
if !effective.fs_isolated || !effective.network_isolated {
warn!(" Some layers are inactive on this host. Use --sandbox on to make missing layers a hard error.");
}
} else {
warn!(" Sandbox: OFF. Network and filesystem are NOT restricted.");
warn!(" Pass a sandbox_policy (or --sandbox on at the CLI) to enable isolation.");
}
warn!(" See: https://hanzoai.github.io/engine/reference/sandbox/");
warn!("============================================================");
info!("Code execution initialized with {count} tools");
}
Err(e) => {
warn!("Failed to initialize code execution: {e}");
warn!("Continuing without code execution functionality.");
}
}
}
}
/// Builds a [`Hanzo`] from a finished builder.
///
/// Steps, in order:
/// 1. Resolve unset builder options (`false` for the switches, 16 for the
///    prefix cache size).
/// 2. Register external tool callbacks (MCP, code execution).
/// 3. Spawn the engine thread for the initial model.
/// 4. In daemon mode, start the replicator and never return.
/// 5. On a multi-threaded Tokio runtime, for text and multimodal models,
///    send a one-token dummy completion and wait for the response channel
///    to close.
///
/// # Panics
///
/// Panics if the engine instance cannot be created, if the dummy request
/// cannot be sent to the engine, or if the system clock is before the Unix
/// epoch.
async fn new(config: HanzoBuilder) -> Arc<Self> {
    info!("git revision: {HANZO_GIT_REVISION}");

    let HanzoBuilder {
        pipeline,
        method,
        model_id_override,
        log,
        no_kv_cache,
        no_prefix_cache,
        prefix_cache_n,
        disable_eos_stop,
        throughput_logging_enabled,
        search_embedding_model,
        search_callback,
        mut tool_callbacks,
        mcp_client_config,
        loader_config,
        #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
        code_exec_config,
    } = config;

    hanzo_quant::cublaslt::maybe_init_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device());

    // Read the pipeline name before the engine thread exists. Previously this
    // was `pipeline.try_lock().unwrap()` *after* spawning the engine, which
    // panics whenever the engine thread happens to hold the pipeline lock.
    let pipeline_name = get_mut_arcmutex!(pipeline).name();

    let no_kv_cache = no_kv_cache.unwrap_or(false);
    let no_prefix_cache = no_prefix_cache.unwrap_or(false);
    let prefix_cache_n = prefix_cache_n.unwrap_or(16);
    let disable_eos_stop = disable_eos_stop.unwrap_or(false);

    Self::init_external_tool_callbacks(
        &pipeline,
        &mut tool_callbacks,
        mcp_client_config.as_ref(),
        code_exec_config.as_ref(),
    )
    .await;

    let reboot_state = RebootState {
        pipeline: pipeline.clone(),
        method: method.clone(),
        no_kv_cache,
        no_prefix_cache,
        prefix_cache_n,
        disable_eos_stop,
        throughput_logging_enabled,
        search_embedding_model,
        search_callback: search_callback.clone(),
        tool_callbacks: tool_callbacks.clone(),
        mcp_client_config: mcp_client_config.clone(),
        loader_config,
    };
    let engine_config = EngineConfig {
        no_kv_cache,
        no_prefix_cache,
        prefix_cache_n,
        disable_eos_stop,
        throughput_logging_enabled,
        search_embedding_model,
        search_callback,
        tool_callbacks,
    };

    let engine_instance =
        Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
            .expect("Failed to create engine instance");

    // With an id override, the pipeline name becomes an alias for it.
    let (id, alias_map) = match model_id_override {
        Some(override_id) => {
            let mut alias_map = HashMap::new();
            if override_id != pipeline_name {
                alias_map.insert(pipeline_name.clone(), override_id.clone());
            }
            (override_id, alias_map)
        }
        None => (pipeline_name.clone(), HashMap::new()),
    };

    if distributed::is_daemon() {
        let request_sender = engine_instance.sender.clone();
        if cfg!(feature = "ring") {
            distributed::ring_daemon_replicator(request_sender);
        } else {
            distributed::nccl_daemon_replicator(request_sender);
        }
        // A daemon never returns from here. Park instead of busy-spinning so
        // this thread does not burn a CPU core; the loop absorbs spurious
        // wake-ups.
        loop {
            thread::park();
        }
    }

    let is_multi_threaded = tokio::runtime::Handle::try_current()
        .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

    // `block_in_place` below requires a multi-threaded runtime.
    if !distributed::is_daemon()
        && is_multi_threaded
        && matches!(
            engine_instance.category,
            ModelCategory::Text | ModelCategory::Multimodal { .. }
        )
    {
        let clone_sender = engine_instance.sender.clone();
        tokio::task::block_in_place(|| {
            let (tx, mut rx) = channel(1);
            let req = Request::Normal(Box::new(NormalRequest {
                id: 0,
                messages: RequestMessage::Completion {
                    text: "hello".to_string(),
                    echo_prompt: false,
                    best_of: None,
                },
                sampling_params: SamplingParams {
                    max_len: Some(1),
                    ..SamplingParams::deterministic()
                },
                response: tx,
                return_logprobs: false,
                is_streaming: false,
                constraint: Constraint::None,
                suffix: None,
                tool_choice: None,
                tools: None,
                logits_processors: None,
                return_raw_logits: false,
                web_search_options: None,
                enable_code_execution: false,
                code_execution_permission: None,
                code_execution_approval_notifier: None,
                agent_permission: None,
                agent_approval_handler: None,
                agent_approval_notifier: None,
                max_tool_rounds: None,
                tool_dispatch_url: None,
                model_id: None,
                truncate_sequence: false,
                session_id: None,
                files: None,
            }));
            debug!("Beginning dummy run.");
            let start = Instant::now();
            clone_sender.blocking_send(req).unwrap();

            // Drain until the engine drops its sender; any response counts
            // as a successful dummy run.
            let mut received_any = false;
            while let Some(_resp) = rx.blocking_recv() {
                received_any = true;
            }

            if received_any {
                debug!(
                    "Dummy run completed in {}s.",
                    start.elapsed().as_secs_f64()
                );
            } else {
                warn!("Dummy run failed!");
            }
        });
        // Keep the dummy run out of the throughput statistics.
        engine_instance.logger.reset();
    }

    let mut engines = HashMap::new();
    engines.insert(id.clone(), engine_instance);

    Arc::new(Self {
        engines: RwLock::new(engines),
        unloaded_models: RwLock::new(HashMap::new()),
        reloading_models: RwLock::new(HashSet::new()),
        default_engine_id: RwLock::new(Some(id.clone())),
        model_aliases: RwLock::new(alias_map),
        log,
        id,
        creation_time: SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("Time travel has occurred!")
            .as_secs(),
        next_request_id: Mutex::new(RefCell::new(1)),
    })
}
/// Replaces the engine of `model_id` with a fresh one if its thread has
/// finished; does nothing when the engine is still running.
///
/// # Errors
///
/// Returns [`HanzoError::EnginePoisoned`] when the engine map lock is
/// poisoned, when `model_id` has no engine, or when the replacement engine
/// cannot be created.
fn reboot_engine(&self, model_id: &str) -> Result<(), HanzoError> {
    let mut engines = match self.engines.write() {
        Ok(guard) => guard,
        Err(_) => {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            return Err(HanzoError::EnginePoisoned);
        }
    };

    let Some(current) = engines.get(model_id) else {
        return Err(HanzoError::EnginePoisoned);
    };
    if !current.engine_handler.is_finished() {
        tracing::info!("Engine {} already running, returning ok", model_id);
        return Ok(());
    }

    // Rebuild the engine options from the state saved at creation time.
    let saved = current.reboot_state.clone();
    let engine_config = EngineConfig {
        no_kv_cache: saved.no_kv_cache,
        no_prefix_cache: saved.no_prefix_cache,
        prefix_cache_n: saved.prefix_cache_n,
        disable_eos_stop: saved.disable_eos_stop,
        throughput_logging_enabled: saved.throughput_logging_enabled,
        search_embedding_model: saved.search_embedding_model,
        search_callback: saved.search_callback.clone(),
        tool_callbacks: saved.tool_callbacks.clone(),
    };
    let pipeline = saved.pipeline.clone();
    let method = saved.method.clone();

    let replacement = match Self::create_engine_instance(pipeline, method, engine_config, saved) {
        Ok(instance) => instance,
        Err(e) => {
            tracing::error!("Failed to create new engine instance: {}", e);
            return Err(HanzoError::EnginePoisoned);
        }
    };

    engines.insert(model_id.to_string(), replacement);
    tracing::info!("Successfully rebooted engine {}", model_id);
    Ok(())
}
/// Reports whether the engine thread of `model_id` has finished.
///
/// # Errors
///
/// Returns [`HanzoError::EnginePoisoned`] when the engine map lock is
/// poisoned or `model_id` has no engine.
fn engine_dead(&self, model_id: &str) -> Result<bool, HanzoError> {
    let engines = match self.engines.read() {
        Ok(guard) => guard,
        Err(_) => {
            tracing::warn!("Couldn't get read lock on engines!");
            return Err(HanzoError::EnginePoisoned);
        }
    };
    match engines.get(model_id) {
        Some(instance) => Ok(instance.engine_handler.is_finished()),
        None => Err(HanzoError::EnginePoisoned),
    }
}
/// Returns a request sender for the given model (or the default model when
/// `model_id` is `None`), after alias resolution.
///
/// A loaded model whose engine thread has died is rebooted first. An
/// unloaded model is reloaded (blocking) before its sender is returned.
///
/// # Errors
///
/// - [`HanzoError::SenderPoisoned`] / [`HanzoError::EnginePoisoned`] when a
///   lock is poisoned.
/// - [`HanzoError::ModelNotFound`] when the model is neither loaded nor
///   unloaded.
/// - Whatever alias resolution, rebooting or reloading returns.
pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, HanzoError> {
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;

    // Fetches the sender under a short-lived read lock.
    let lookup_sender = || -> Result<Option<Sender<Request>>, HanzoError> {
        let engines = self
            .engines
            .read()
            .map_err(|_| HanzoError::SenderPoisoned)?;
        let sender = engines
            .get(&resolved_model_id)
            .map(|instance| instance.sender.clone());
        Ok(sender)
    };

    let is_loaded = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .contains_key(&resolved_model_id);
    if is_loaded {
        if self.engine_dead(&resolved_model_id)? {
            tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
            self.reboot_engine(&resolved_model_id)?
        }
        if let Some(sender) = lookup_sender()? {
            return Ok(sender);
        }
    }

    let is_unloaded = self
        .unloaded_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    if is_unloaded {
        tracing::info!(
            "Model {} is unloaded, triggering auto-reload",
            resolved_model_id
        );
        self.reload_model_blocking(&resolved_model_id)?;
        if let Some(sender) = lookup_sender()? {
            return Ok(sender);
        }
    }

    Err(HanzoError::ModelNotFound(resolved_model_id))
}
/// Looks `id` up in the file store of every loaded engine and returns the
/// first match. Returns `None` when no store has it or the engine map lock
/// is poisoned.
pub fn find_file(&self, id: &str) -> Option<Arc<files::File>> {
    let Ok(engines) = self.engines.read() else {
        return None;
    };
    let found = engines
        .values()
        .find_map(|instance| instance.file_store.get(id));
    found
}
/// Collects the files held by every loaded engine's file store.
///
/// Returns an empty vector when the `engines` lock is poisoned.
pub fn list_files(&self) -> Vec<Arc<files::File>> {
    match self.engines.read() {
        Ok(engines) => engines
            .values()
            .flat_map(|instance| instance.file_store.list_all())
            .collect(),
        Err(_) => Vec::new(),
    }
}
/// Removes the file `id` from the first engine file store that reports
/// having removed it.
///
/// Returns `false` when no store removed the file or the `engines` lock is
/// poisoned.
pub fn remove_file(&self, id: &str) -> bool {
    self.engines
        .read()
        .map(|engines| {
            // `any` stops at the first store that reports a successful removal.
            engines
                .values()
                .any(|instance| instance.file_store.remove(id))
        })
        .unwrap_or(false)
}
/// Returns a shared handle to the agentic session store of a loaded model.
///
/// `model_id` is resolved through the alias table, or to the default model
/// when `None`.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the `engines` lock is poisoned
/// and [`HanzoError::ModelNotFound`] if the resolved model has no engine.
pub fn get_session_store(
    &self,
    model_id: Option<&str>,
) -> Result<Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>, HanzoError>
{
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;
    let engines = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    match engines.get(&resolved_model_id) {
        Some(instance) => Ok(Arc::clone(&instance.session_store)),
        None => Err(HanzoError::ModelNotFound(resolved_model_id)),
    }
}
/// Returns a clone of the file store belonging to a loaded model.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the `engines` lock is poisoned
/// and [`HanzoError::ModelNotFound`] if the resolved model has no engine.
fn get_file_store(&self, model_id: Option<&str>) -> Result<files::FileStore, HanzoError> {
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;
    let engines = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    match engines.get(&resolved_model_id) {
        Some(instance) => Ok(instance.file_store.clone()),
        None => Err(HanzoError::ModelNotFound(resolved_model_id)),
    }
}
/// Exports the session `session_id` of a model together with its files.
///
/// Returns `Ok(None)` when the session store has nothing to export for
/// `session_id`. Otherwise the exported session's `files` field is filled
/// with copies of the files the model's file store lists for that session.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the store mutex is poisoned,
/// [`HanzoError::Other`] if the export itself fails, and whatever the store
/// lookups return.
pub fn export_session(
    &self,
    model_id: Option<&str>,
    session_id: &str,
) -> Result<Option<engine::agentic_session::SerializedSession>, HanzoError> {
    let store = self.get_session_store(model_id)?;
    // The store guard is a temporary here, so it is released before the
    // file store is consulted.
    let exported = store
        .lock()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .export(session_id)
        .map_err(|e| HanzoError::Other(e.to_string()))?;
    let mut session = match exported {
        Some(session) => session,
        None => return Ok(None),
    };
    let file_store = self.get_file_store(model_id)?;
    let session_files = file_store.list_for_session(session_id);
    session.files = session_files
        .into_iter()
        .map(|file| (*file).clone())
        .collect();
    Ok(Some(session))
}
/// Imports a serialized session under `session_id` and re-registers its files.
///
/// The session is handed to the model's session store first; afterwards each
/// file carried by the session is inserted into the model's file store,
/// tagged with `session_id`.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the store mutex is poisoned,
/// [`HanzoError::Other`] if the import fails, and whatever the store lookups
/// return.
pub fn import_session(
    &self,
    model_id: Option<&str>,
    session_id: String,
    session: engine::agentic_session::SerializedSession,
) -> Result<(), HanzoError> {
    // Copy the files out before `session` is moved into the store.
    let files = session.files.clone();
    let store = self.get_session_store(model_id)?;
    store
        .lock()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .import(session_id.clone(), session)
        .map_err(|e| HanzoError::Other(e.to_string()))?;
    let file_store = self.get_file_store(model_id)?;
    files.into_iter().for_each(|file| {
        file_store.insert(file, Some(session_id.clone()));
    });
    Ok(())
}
/// Forks session `src_session_id` into `dest_session_id` in the model's
/// session store, forwarding `num_turns` to the store's `fork`.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the store mutex is poisoned and
/// [`HanzoError::Other`] carrying the store's error text if the fork fails.
pub fn fork_session(
    &self,
    model_id: Option<&str>,
    src_session_id: &str,
    dest_session_id: String,
    num_turns: usize,
) -> Result<(), HanzoError> {
    let store = self.get_session_store(model_id)?;
    let outcome = store
        .lock()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .fork(src_session_id, dest_session_id, num_turns);
    outcome.map_err(|e| HanzoError::Other(e.to_string()))
}
/// Deletes session `session_id` from the model's session store.
///
/// Returns the boolean reported by the store's `delete`.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the store mutex is poisoned.
pub fn delete_session(
    &self,
    model_id: Option<&str>,
    session_id: &str,
) -> Result<bool, HanzoError> {
    let store = self.get_session_store(model_id)?;
    let removed = store
        .lock()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .delete(session_id);
    Ok(removed)
}
/// Lists the session IDs known to the model's session store.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the store mutex is poisoned.
pub fn list_session_ids(&self, model_id: Option<&str>) -> Result<Vec<String>, HanzoError> {
    let store = self.get_session_store(model_id)?;
    let ids = store
        .lock()
        .map_err(|_| HanzoError::SenderPoisoned)?
        .list_ids();
    Ok(ids)
}
/// Returns an owned copy of this instance's `id`.
pub fn get_id(&self) -> String {
    String::clone(&self.id)
}
pub fn get_creation_time(&self) -> u64 {
self.creation_time
}
/// Maps `model_id` to its primary model ID.
///
/// If `model_id` is a registered alias the aliased ID is returned; otherwise
/// `model_id` itself is returned unchanged.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the alias table lock is poisoned.
fn resolve_alias(&self, model_id: &str) -> Result<String, HanzoError> {
    let aliases = self
        .model_aliases
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    let resolved = aliases
        .get(model_id)
        .cloned()
        .unwrap_or_else(|| model_id.to_string());
    Ok(resolved)
}
/// Resolves an optional model ID to a primary model ID.
///
/// `Some(id)` goes through the alias table; `None` yields the current default
/// engine ID.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] on a poisoned lock and
/// [`HanzoError::EnginePoisoned`] when `None` is passed and no default is set.
fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, HanzoError> {
    if let Some(id) = model_id {
        return self.resolve_alias(id);
    }
    let default_lock = self
        .default_engine_id
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    match &*default_lock {
        Some(default_id) => Ok(default_id.clone()),
        None => Err(HanzoError::EnginePoisoned),
    }
}
/// Registers `alias` as an alternative name for `model_id`.
///
/// `model_id` is itself resolved through the alias table first. Registering
/// an alias equal to the resolved ID, or re-registering an identical mapping,
/// is a no-op that succeeds.
///
/// # Errors
/// Returns a message when a lock is poisoned, when the target model is not
/// loaded, unloaded or reloading, when `alias` equals an existing model ID,
/// or when `alias` already points at a different model.
pub fn register_model_alias(
    &self,
    alias: impl Into<String>,
    model_id: &str,
) -> Result<(), String> {
    let alias = alias.into();
    let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
    if alias == resolved_model_id {
        return Ok(());
    }

    // Each table is inspected under its own short-lived read lock, in the
    // order: reloading, engines, unloaded.
    let (model_reloading, alias_is_reloading) = {
        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
        (
            reloading.contains(&resolved_model_id),
            reloading.contains(&alias),
        )
    };
    let (model_loaded, alias_is_loaded) = {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        (
            engines.contains_key(&resolved_model_id),
            engines.contains_key(&alias),
        )
    };
    let (model_unloaded, alias_is_unloaded) = {
        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
        (
            unloaded.contains_key(&resolved_model_id),
            unloaded.contains_key(&alias),
        )
    };

    let model_known = model_loaded || model_unloaded || model_reloading;
    if !model_known {
        return Err(format!("Model {resolved_model_id} not found"));
    }
    if alias_is_reloading || alias_is_loaded || alias_is_unloaded {
        return Err(format!(
            "Alias '{}' conflicts with an existing model ID",
            alias
        ));
    }

    let mut aliases = self
        .model_aliases
        .write()
        .map_err(|_| "Failed to acquire write lock on model_aliases")?;
    match aliases.get(&alias) {
        Some(existing) if existing == &resolved_model_id => Ok(()),
        Some(existing) => Err(format!(
            "Alias '{}' is already assigned to model '{}'",
            alias, existing
        )),
        None => {
            aliases.insert(alias, resolved_model_id);
            Ok(())
        }
    }
}
/// Reports whether `model_id` (after alias resolution) is known in any state:
/// reloading, loaded or unloaded.
///
/// # Errors
/// Returns [`HanzoError::EnginePoisoned`] if one of the consulted locks is
/// poisoned, or the error from alias resolution.
pub fn model_exists(&self, model_id: &str) -> Result<bool, HanzoError> {
    let resolved_model_id = self.resolve_alias(model_id)?;

    let is_reloading = self
        .reloading_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains(&resolved_model_id);
    if is_reloading {
        return Ok(true);
    }

    let is_loaded = self
        .engines
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    if is_loaded {
        return Ok(true);
    }

    let is_unloaded = self
        .unloaded_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    Ok(is_unloaded)
}
/// Returns the interval logger of a loaded model.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the `engines` lock is poisoned
/// and [`HanzoError::EnginePoisoned`] if the resolved model has no engine.
pub fn get_logger(&self, model_id: Option<&str>) -> Result<Arc<IntervalLogger>, HanzoError> {
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;
    let engines = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    engines
        .get(&resolved_model_id)
        .map(|instance| Arc::clone(&instance.logger))
        .ok_or(HanzoError::EnginePoisoned)
}
/// Returns the category recorded for a loaded model.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the `engines` lock is poisoned
/// and [`HanzoError::EnginePoisoned`] if the resolved model has no engine.
pub fn get_model_category(&self, model_id: Option<&str>) -> Result<ModelCategory, HanzoError> {
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;
    let engines = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    engines
        .get(&resolved_model_id)
        .map(|instance| instance.category.clone())
        .ok_or(HanzoError::EnginePoisoned)
}
/// Returns the `max_seq_len` value from a loaded model's config.
///
/// # Errors
/// Returns [`HanzoError::SenderPoisoned`] if the `engines` lock is poisoned
/// and [`HanzoError::EnginePoisoned`] if the resolved model has no engine.
pub fn max_sequence_length(&self, model_id: Option<&str>) -> Result<Option<usize>, HanzoError> {
    let resolved_model_id = self.resolve_alias_or_default(model_id)?;
    let engines = self
        .engines
        .read()
        .map_err(|_| HanzoError::SenderPoisoned)?;
    engines
        .get(&resolved_model_id)
        .map(|instance| instance.config.max_seq_len)
        .ok_or(HanzoError::EnginePoisoned)
}
/// Hands out the next request ID and advances the internal counter.
///
/// Returns the counter value from before the increment, so successive calls
/// yield consecutive IDs.
pub fn next_request_id(&self) -> usize {
    // A poisoned mutex only means some thread panicked while holding the
    // guard; the counter itself is a plain integer and is always valid, so
    // recover the guard instead of propagating the panic to every caller.
    let cell = self
        .next_request_id
        .lock()
        .unwrap_or_else(|poisoned| poisoned.into_inner());
    let mut next = cell.borrow_mut();
    let issued = *next;
    *next += 1;
    issued
}
/// Registers a new model under `model_id` and creates an engine instance for it.
///
/// The ID is rejected if it is currently reloading, already present in
/// `engines`, present in `unloaded_models`, or already used as an alias.
/// External tool callbacks are initialised on the engine config before the
/// engine instance is created. When the inserted model is the only entry in
/// `engines`, it also becomes the default model.
///
/// # Errors
/// Returns a message if any lock is poisoned, if `model_id` conflicts with an
/// existing model or alias, or if creating the engine instance fails.
pub async fn add_model(
&self,
model_id: String,
pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
method: SchedulerConfig,
config: AddModelConfig,
) -> Result<(), String> {
// Each conflict check below takes its own short-lived read lock.
{
let reloading = self
.reloading_models
.read()
.map_err(|_| "Failed to acquire read lock on reloading_models")?;
if reloading.contains(&model_id) {
return Err(format!("Model {model_id} is currently reloading"));
}
}
{
let engines = self
.engines
.read()
.map_err(|_| "Failed to acquire read lock on engines")?;
if engines.contains_key(&model_id) {
return Err(format!("Model {model_id} already exists"));
}
}
{
let unloaded = self
.unloaded_models
.read()
.map_err(|_| "Failed to acquire read lock on unloaded_models")?;
if unloaded.contains_key(&model_id) {
return Err(format!("Model {model_id} already exists (unloaded)"));
}
}
{
let aliases = self
.model_aliases
.read()
.map_err(|_| "Failed to acquire read lock on model_aliases")?;
if aliases.contains_key(&model_id) {
return Err(format!(
"Model ID '{}' conflicts with an existing alias",
model_id
));
}
}
// No lock is held across this await; the callbacks are written into
// `engine_config.tool_callbacks` before the reboot state snapshots them.
let mut engine_config = config.engine_config;
Self::init_external_tool_callbacks(
&pipeline,
&mut engine_config.tool_callbacks,
config.mcp_client_config.as_ref(),
config.code_exec_config.as_ref(),
)
.await;
// Snapshot of the settings stored alongside the engine instance.
let reboot_state = RebootState {
pipeline: pipeline.clone(),
method: method.clone(),
no_kv_cache: engine_config.no_kv_cache,
no_prefix_cache: engine_config.no_prefix_cache,
prefix_cache_n: engine_config.prefix_cache_n,
disable_eos_stop: engine_config.disable_eos_stop,
throughput_logging_enabled: engine_config.throughput_logging_enabled,
search_embedding_model: engine_config.search_embedding_model,
search_callback: engine_config.search_callback.clone(),
tool_callbacks: engine_config.tool_callbacks.clone(),
mcp_client_config: config.mcp_client_config.clone(),
loader_config: config.loader_config.clone(),
};
let engine_instance =
Self::create_engine_instance(pipeline, method, engine_config, reboot_state)?;
// NOTE(review): the conflict checks above released their locks before this
// write lock is taken, so two concurrent calls with the same `model_id`
// could both reach this insert — confirm callers serialise `add_model`.
let mut engines = self
.engines
.write()
.map_err(|_| "Failed to acquire write lock on engines")?;
engines.insert(model_id.clone(), engine_instance);
// The only entry in `engines` becomes the default model.
if engines.len() == 1 {
let mut default_lock = self
.default_engine_id
.write()
.map_err(|_| "Failed to acquire write lock on default_engine_id")?;
*default_lock = Some(model_id.clone());
info!("First model added, setting '{}' as default", model_id);
}
Ok(())
}
/// Removes a loaded model, asks its engine to terminate, and drops every
/// alias that pointed at it.
///
/// If the removed model was the default, the default is reassigned to an
/// arbitrary remaining model.
///
/// # Errors
/// Returns a message if a lock is poisoned, if the model is the last loaded
/// one, or if no engine is registered under the resolved ID.
pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
    let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

    // Do all map bookkeeping first and release the locks before talking to
    // the engine: `blocking_send` can wait on a full request channel, and
    // holding the `engines` write lock during that wait would stall every
    // other caller that needs the map.
    let (engine_instance, default_update) = {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        if engines.len() <= 1 {
            return Err("Cannot remove the last model from Hanzo".to_string());
        }
        let Some(engine_instance) = engines.remove(&resolved_model_id) else {
            return Err(format!("Model {resolved_model_id} not found"));
        };
        // Lock order engines -> default_engine_id, matching `add_model`.
        let default_update = self
            .default_engine_id
            .write()
            .map(|mut default_lock| {
                if default_lock.as_deref() == Some(resolved_model_id.as_str()) {
                    *default_lock = engines.keys().next().cloned();
                }
            })
            .map_err(|_| "Failed to acquire write lock on default_engine_id".to_string());
        (engine_instance, default_update)
    };

    // Best effort: the engine may already have exited, in which case the
    // send fails and there is nothing left to terminate.
    let _ = engine_instance.sender.blocking_send(Request::Terminate);
    default_update?;

    let mut aliases = self
        .model_aliases
        .write()
        .map_err(|_| "Failed to acquire write lock on model_aliases")?;
    aliases.retain(|_, target| target != &resolved_model_id);
    Ok(())
}
/// Lists the IDs of all currently loaded models.
///
/// # Errors
/// Returns a message if the `engines` lock is poisoned.
pub fn list_models(&self) -> Result<Vec<String>, String> {
    match self.engines.read() {
        Ok(engines) => Ok(engines.keys().cloned().collect()),
        Err(_) => Err("Failed to acquire read lock on engines".to_string()),
    }
}
/// Returns the current default model ID, if one is set.
///
/// # Errors
/// Returns a message if the `default_engine_id` lock is poisoned.
pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
    match self.default_engine_id.read() {
        Ok(default_lock) => Ok((*default_lock).clone()),
        Err(_) => Err("Failed to acquire read lock on default_engine_id".to_string()),
    }
}
/// Makes `model_id` (after alias resolution) the default model.
///
/// # Errors
/// Returns a message if a lock is poisoned or if the resolved model is not
/// currently loaded.
pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
    let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

    let is_loaded = self
        .engines
        .read()
        .map_err(|_| "Failed to acquire read lock on engines")?
        .contains_key(&resolved_model_id);
    if !is_loaded {
        return Err(format!("Model {resolved_model_id} not found"));
    }

    let mut default_lock = self
        .default_engine_id
        .write()
        .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
    // Swap in the new default and keep the previous value for the log line.
    let old_default = std::mem::replace(&mut *default_lock, Some(resolved_model_id.clone()));
    info!(
        "Default model changed: {:?} -> {:?}",
        old_default, resolved_model_id
    );
    Ok(())
}
/// Routes `request` to the engine that should handle it.
///
/// `Request::Normal` requests are routed by their `model_id`; every other
/// request kind goes to the default model. The send blocks the calling thread
/// (`blocking_send`).
///
/// # Errors
/// Propagates failures from [`Self::get_sender`] and returns
/// [`HanzoError::SenderPoisoned`] if the send fails.
pub fn send_request(&self, mut request: Request) -> Result<(), HanzoError> {
    let target_model = if let Request::Normal(normal_req) = &mut request {
        normal_req.model_id.as_deref()
    } else {
        None
    };
    let sender = self.get_sender(target_model)?;
    match sender.blocking_send(request) {
        Ok(()) => Ok(()),
        Err(_) => Err(HanzoError::SenderPoisoned),
    }
}
pub fn maybe_log_request(this: Arc<Self>, repr: String) {
if let Some(file) = &this.log {
let mut f = OpenOptions::new()
.append(true)
.create(true) .open(file)
.expect("Unable to open file");
let time = chrono::offset::Local::now();
f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
.expect("Unable to write data");
}
}
pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
if let Some(file) = &this.log {
let mut f = OpenOptions::new()
.append(true)
.create(true) .open(file)
.expect("Unable to open file");
let time = chrono::offset::Local::now();
let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
.expect("Unable to write data");
}
}
pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
if let Some(file) = &this.log {
let mut f = OpenOptions::new()
.append(true)
.create(true) .open(file)
.expect("Unable to open file");
let time = chrono::offset::Local::now();
f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
.expect("Unable to write data");
}
}
/// Returns the number of tool callbacks recorded in a loaded model's reboot
/// state.
///
/// # Errors
/// Returns a message if resolution fails, the `engines` lock is poisoned, or
/// the resolved model is not loaded.
pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
    let resolved_model_id = self
        .resolve_alias_or_default(model_id)
        .map_err(|e| e.to_string())?;
    let engines = self
        .engines
        .read()
        .map_err(|_| "Failed to acquire read lock on engines")?;
    engines
        .get(&resolved_model_id)
        .map(|instance| instance.reboot_state.tool_callbacks.len())
        .ok_or_else(|| format!("Model {resolved_model_id} not found"))
}
/// Lists `(name, description)` pairs for a loaded model's tool callbacks,
/// sorted by name.
///
/// Tools recognised by `search::search_tool_called` are left out, as are
/// tools recognised by `hanzo_code_exec::code_exec_tool_called` when the
/// `code-execution` feature is enabled.
///
/// # Errors
/// Returns a message if resolution fails, the `engines` lock is poisoned, or
/// the resolved model is not loaded.
pub fn list_mcp_tools(
    &self,
    model_id: Option<&str>,
) -> Result<Vec<(String, Option<String>)>, String> {
    let resolved_model_id = self
        .resolve_alias_or_default(model_id)
        .map_err(|e| e.to_string())?;
    let engines = self
        .engines
        .read()
        .map_err(|_| "Failed to acquire read lock on engines")?;
    let Some(engine_instance) = engines.get(&resolved_model_id) else {
        return Err(format!("Model {resolved_model_id} not found"));
    };

    let mut tools: Vec<(String, Option<String>)> = engine_instance
        .reboot_state
        .tool_callbacks
        .values()
        .filter_map(|cb| {
            let function = &cb.tool.function;
            if search::search_tool_called(&function.name) {
                return None;
            }
            #[cfg(feature = "code-execution")]
            {
                if hanzo_code_exec::code_exec_tool_called(&function.name) {
                    return None;
                }
            }
            Some((function.name.clone(), function.description.clone()))
        })
        .collect();
    tools.sort_by(|(left, _), (right, _)| left.cmp(right));
    Ok(tools)
}
/// Reports whether a loaded model's reboot state carries an MCP client
/// configuration.
///
/// # Errors
/// Returns a message if resolution fails, the `engines` lock is poisoned, or
/// the resolved model is not loaded.
pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
    let resolved_model_id = self
        .resolve_alias_or_default(model_id)
        .map_err(|e| e.to_string())?;
    let engines = self
        .engines
        .read()
        .map_err(|_| "Failed to acquire read lock on engines")?;
    engines
        .get(&resolved_model_id)
        .map(|instance| instance.reboot_state.mcp_client_config.is_some())
        .ok_or_else(|| format!("Model {resolved_model_id} not found"))
}
/// Returns a copy of a loaded model's `HanzoConfig`.
///
/// # Errors
/// Returns a message if resolution fails, the `engines` lock is poisoned, or
/// the resolved model is not loaded.
pub fn config(&self, model_id: Option<&str>) -> Result<HanzoConfig, String> {
    let resolved_model_id = self
        .resolve_alias_or_default(model_id)
        .map_err(|e| e.to_string())?;
    let engines = self
        .engines
        .read()
        .map_err(|_| "Failed to acquire read lock on engines")?;
    engines
        .get(&resolved_model_id)
        .map(|instance| instance.config.clone())
        .ok_or_else(|| format!("Model {resolved_model_id} not found"))
}
/// Unloads a model: removes its engine, asks the engine to terminate, and
/// records the state needed to reload it later in `unloaded_models`.
///
/// If the unloaded model was the default, the default is reassigned to an
/// arbitrary remaining loaded model (or cleared when none remain).
///
/// # Errors
/// - [`HanzoError::ModelAlreadyUnloaded`] if the model is already unloaded.
/// - [`HanzoError::ModelNotFound`] if no engine is registered for it.
/// - [`HanzoError::NoLoaderConfig`] if the engine has no loader config; in
///   that case the engine is left loaded and untouched.
/// - [`HanzoError::EnginePoisoned`] if a lock is poisoned.
pub fn unload_model(&self, model_id: &str) -> Result<(), HanzoError> {
    let resolved_model_id = self.resolve_alias(model_id)?;
    {
        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| HanzoError::EnginePoisoned)?;
        if unloaded.contains_key(&resolved_model_id) {
            return Err(HanzoError::ModelAlreadyUnloaded(resolved_model_id));
        }
    }

    let (unloaded_state, default_update) = {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| HanzoError::EnginePoisoned)?;

        // Validate BEFORE removing: a model without a loader config cannot be
        // reloaded, so it must stay in `engines` rather than being dropped.
        let loader_config = engines
            .get(&resolved_model_id)
            .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))?
            .reboot_state
            .loader_config
            .clone()
            .ok_or_else(|| HanzoError::NoLoaderConfig(resolved_model_id.clone()))?;
        let Some(engine_instance) = engines.remove(&resolved_model_id) else {
            return Err(HanzoError::ModelNotFound(resolved_model_id));
        };

        let unloaded_state = UnloadedModelState {
            loader_config,
            scheduler_config: engine_instance.reboot_state.method.clone(),
            engine_config: EngineConfig {
                no_kv_cache: engine_instance.reboot_state.no_kv_cache,
                no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
                prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
                disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
                throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
                search_embedding_model: engine_instance.reboot_state.search_embedding_model,
                search_callback: engine_instance.reboot_state.search_callback.clone(),
                tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
            },
            mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
            category: engine_instance.category.clone(),
            hanzo_config: engine_instance.config.clone(),
        };
        // Best effort, non-blocking: the engine may already have exited.
        let _ = engine_instance.sender.try_send(Request::Terminate);

        // Update the default while still holding `engines`, so the locks are
        // taken in the order engines -> default_engine_id like `add_model`
        // and `remove_model`. Taking them in the opposite order could
        // deadlock against those methods.
        let default_update = self
            .default_engine_id
            .write()
            .map(|mut default_lock| {
                if default_lock.as_deref() == Some(resolved_model_id.as_str()) {
                    *default_lock = engines.keys().next().cloned();
                }
            })
            .map_err(|_| HanzoError::EnginePoisoned);
        (unloaded_state, default_update)
    };

    // Record the unloaded state before surfacing a default-lock failure so
    // the model stays reloadable either way.
    self.unloaded_models
        .write()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .insert(resolved_model_id.clone(), unloaded_state);
    default_update?;

    info!("Model {} unloaded successfully", resolved_model_id);
    Ok(())
}
/// Reloads a previously unloaded model.
///
/// The model is marked as reloading for the duration of the call; the mark is
/// cleared on every exit path after it has been set, including when the model
/// turns out not to be in `unloaded_models`.
///
/// # Errors
/// - [`HanzoError::ModelReloading`] if a reload of this model is already in
///   progress.
/// - [`HanzoError::ModelNotFound`] if the model is not in the unloaded set.
/// - [`HanzoError::EnginePoisoned`] if a lock is poisoned.
/// - Any error produced while rebuilding the model.
pub async fn reload_model(&self, model_id: &str) -> Result<(), HanzoError> {
    let resolved_model_id = self.resolve_alias(model_id)?;

    // Check and mark under a single write lock so two concurrent callers
    // cannot both pass the "already reloading" test.
    {
        let mut reloading = self
            .reloading_models
            .write()
            .map_err(|_| HanzoError::EnginePoisoned)?;
        if !reloading.insert(resolved_model_id.clone()) {
            return Err(HanzoError::ModelReloading(resolved_model_id));
        }
    }

    // From here on, failures are collected into `result` instead of being
    // returned with `?`, so the reloading mark is always removed below.
    let unloaded_state = self
        .unloaded_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)
        .and_then(|unloaded| {
            unloaded
                .get(&resolved_model_id)
                .cloned()
                .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))
        });
    let result = match unloaded_state {
        Ok(state) => self.do_reload_model(&resolved_model_id, state).await,
        Err(e) => Err(e),
    };

    self.reloading_models
        .write()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .remove(&resolved_model_id);
    result
}
/// Rebuilds the pipeline and engine for `model_id` from its saved unloaded
/// state, then moves the model from `unloaded_models` back into `engines`.
///
/// # Errors
/// Returns [`HanzoError::ReloadFailed`] if building the loader, loading the
/// model, attaching MTP speculative decoding, or creating the engine fails,
/// and [`HanzoError::EnginePoisoned`] if a lock is poisoned.
async fn do_reload_model(
    &self,
    model_id: &str,
    unloaded_state: UnloadedModelState,
) -> Result<(), HanzoError> {
    use crate::model_loader::LoaderBuilder;
    info!("Reloading model: {}", model_id);

    // The loader and the borrow of the loader config are confined to this
    // block so neither is held across the `.await` below.
    let (pipeline, mtp_config) = {
        let loader_config = &unloaded_state.loader_config;
        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
            .with_chat_template(loader_config.chat_template.clone())
            .with_jinja_explicit(loader_config.jinja_explicit.clone())
            .build()
            .map_err(|e| HanzoError::ReloadFailed(format!("Failed to build loader: {e}")))?;
        let pipeline = loader
            .load_model_from_hf(
                None,
                loader_config.token_source.clone(),
                &loader_config.dtype,
                &loader_config.device,
                loader_config.silent,
                loader_config.device_map_setting.clone(),
                loader_config.isq,
                loader_config.paged_attn_config,
            )
            .map_err(|e| HanzoError::ReloadFailed(format!("Failed to load model: {e}")))?;
        (pipeline, loader_config.mtp_config.clone())
    };

    if let Some(mtp_config) = mtp_config {
        // `blocking_lock` panics when called from inside an async execution
        // context, which is always the case in this `async fn` when it runs
        // on a Tokio runtime; await the lock instead.
        pipeline
            .lock()
            .await
            .attach_speculative(SpeculativeConfig::Mtp(mtp_config))
            .map_err(|e| {
                HanzoError::ReloadFailed(format!(
                    "Failed to attach MTP speculative decoding: {e}"
                ))
            })?;
    }

    let reboot_state = RebootState {
        pipeline: pipeline.clone(),
        method: unloaded_state.scheduler_config.clone(),
        no_kv_cache: unloaded_state.engine_config.no_kv_cache,
        no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
        prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
        disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
        throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
        search_embedding_model: unloaded_state.engine_config.search_embedding_model,
        search_callback: unloaded_state.engine_config.search_callback.clone(),
        tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
        mcp_client_config: unloaded_state.mcp_client_config.clone(),
        loader_config: Some(unloaded_state.loader_config.clone()),
    };
    let engine_instance = Self::create_engine_instance(
        pipeline,
        unloaded_state.scheduler_config,
        unloaded_state.engine_config,
        reboot_state,
    )
    .map_err(|e| HanzoError::ReloadFailed(format!("Failed to create engine: {e}")))?;

    // Publish the engine first, then drop the unloaded record, each under
    // its own short-lived write lock.
    {
        let mut engines = self
            .engines
            .write()
            .map_err(|_| HanzoError::EnginePoisoned)?;
        engines.insert(model_id.to_string(), engine_instance);
    }
    {
        let mut unloaded = self
            .unloaded_models
            .write()
            .map_err(|_| HanzoError::EnginePoisoned)?;
        unloaded.remove(model_id);
    }
    info!("Model {} reloaded successfully", model_id);
    Ok(())
}
/// Synchronous wrapper around [`Self::reload_model`].
///
/// Outside any Tokio runtime a fresh runtime is created to drive the reload.
/// Inside a runtime, the reload is driven via `block_in_place` on the current
/// handle, unless the runtime flavor is `CurrentThread`, which is rejected.
///
/// # Errors
/// Returns [`HanzoError::ReloadFailed`] when called from a current-thread
/// runtime or when a runtime cannot be created, and otherwise whatever
/// `reload_model` returns.
pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), HanzoError> {
    let Ok(handle) = tokio::runtime::Handle::try_current() else {
        // No ambient runtime: spin one up just for this reload.
        let rt = tokio::runtime::Runtime::new()
            .map_err(|e| HanzoError::ReloadFailed(format!("Failed to create runtime: {e}")))?;
        return rt.block_on(self.reload_model(model_id));
    };

    if matches!(
        handle.runtime_flavor(),
        tokio::runtime::RuntimeFlavor::CurrentThread
    ) {
        return Err(HanzoError::ReloadFailed(
            "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
        ));
    }
    tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
}
/// Lists the IDs of all models currently in the unloaded set.
///
/// # Errors
/// Returns [`HanzoError::EnginePoisoned`] if the `unloaded_models` lock is
/// poisoned.
pub fn list_unloaded_models(&self) -> Result<Vec<String>, HanzoError> {
    match self.unloaded_models.read() {
        Ok(unloaded) => Ok(unloaded.keys().cloned().collect()),
        Err(_) => Err(HanzoError::EnginePoisoned),
    }
}
/// Reports whether `model_id` (after alias resolution) has an engine in
/// `engines`.
///
/// # Errors
/// Returns [`HanzoError::EnginePoisoned`] if the `engines` lock is poisoned,
/// or the error from alias resolution.
pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, HanzoError> {
    let resolved_model_id = self.resolve_alias(model_id)?;
    let loaded = self
        .engines
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    Ok(loaded)
}
/// Returns the status of `model_id` (after alias resolution), or `None` if
/// the model is unknown.
///
/// The states are checked in the order reloading, loaded, unloaded; the first
/// match wins.
///
/// # Errors
/// Returns [`HanzoError::EnginePoisoned`] if one of the consulted locks is
/// poisoned, or the error from alias resolution.
pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, HanzoError> {
    let resolved_model_id = self.resolve_alias(model_id)?;

    let is_reloading = self
        .reloading_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains(&resolved_model_id);
    if is_reloading {
        return Ok(Some(ModelStatus::Reloading));
    }

    let is_loaded = self
        .engines
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    if is_loaded {
        return Ok(Some(ModelStatus::Loaded));
    }

    let is_unloaded = self
        .unloaded_models
        .read()
        .map_err(|_| HanzoError::EnginePoisoned)?
        .contains_key(&resolved_model_id);
    Ok(is_unloaded.then_some(ModelStatus::Unloaded))
}
pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, HanzoError> {
let mut result = Vec::new();
let reloading = self
.reloading_models
.read()
.map_err(|_| HanzoError::EnginePoisoned)?;
for model_id in reloading.iter() {
result.push((model_id.clone(), ModelStatus::Reloading));
}
drop(reloading);
let engines = self
.engines
.read()
.map_err(|_| HanzoError::EnginePoisoned)?;
for model_id in engines.keys() {
result.push((model_id.clone(), ModelStatus::Loaded));
}
drop(engines);
let unloaded = self
.unloaded_models
.read()
.map_err(|_| HanzoError::EnginePoisoned)?;
for model_id in unloaded.keys() {
if !result.iter().any(|(id, _)| id == model_id) {
result.push((model_id.clone(), ModelStatus::Unloaded));
}
}
Ok(result)
}
}