Skip to main content

hanzo_engine/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use engine::Engine;
3pub use engine::{
4    agentic_session::{AgenticSessionStore, SerializedSession, SerializedVideo},
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, IntervalLogger, SearchEmbeddingModel, DEFAULT_MAX_TOOL_ROUNDS,
7    ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
8};
9use hanzo_ml::Device;
10use hf_hub::Cache;
11pub use lora::Ordering;
12pub use pipeline::ModelCategory;
13pub use pipeline::Pipeline;
14#[cfg(feature = "pyo3_macros")]
15use pyo3::exceptions::PyValueError;
16use std::collections::{HashMap, HashSet};
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::OnceLock;
20use std::time::{Duration, Instant};
21use std::{
22    cell::RefCell,
23    error::Error,
24    fs::OpenOptions,
25    io::Write,
26    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
27    thread::{self, JoinHandle},
28    time::{SystemTime, UNIX_EPOCH},
29};
30use tokio::sync::mpsc::{channel, Sender};
31use tracing::{debug, info, warn};
32
/// Git revision this crate was built from.
///
/// Taken from the `HANZO_GIT_REVISION` environment variable at compile time; `"unknown"`
/// when that variable was not set during the build.
pub const HANZO_GIT_REVISION: &str = if let Some(revision) = option_env!("HANZO_GIT_REVISION") {
    revision
} else {
    "unknown"
};
37
38mod cuda;
39mod device_map;
40mod engine;
41mod lora;
42mod metal;
43mod model_loader;
44mod moe;
45mod ops;
46mod video_input;
47mod vulkan;
48pub use model_loader::{
49    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
50};
51pub use video_input::{sample_frame_indices, VideoInput};
52pub mod disk_kv_cache;
53mod embedding_models;
54mod kv_cache;
55mod search;
56
57mod model_selected;
58pub use model_selected::ModelSelected;
59pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
60
61mod amoe;
62mod attention;
63mod diagnostics;
64mod diffusion_models;
65pub mod distributed;
66pub mod files;
67mod gguf;
68pub mod layers;
69mod layers_masker;
70mod layers_utils;
71pub mod matformer;
72mod mla;
73mod models;
74mod paged_attention;
75mod pipeline;
76mod prefix_cacher;
77pub mod reasoning_parsers;
78mod request;
79mod response;
80mod sampler;
81mod scheduler;
82mod sequence;
83pub mod speculative;
84mod speech_models;
85mod toml_selector;
86mod tools;
87mod topology;
88mod utils;
89mod vision_models;
90mod xlora_models;
91
92pub use diagnostics::{
93    check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
94    DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
95};
96mod tuning;
97pub use tuning::{
98    auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
99};
100
101pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
102pub use device_map::{
103    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
104};
105pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
106pub use hanzo_audio::AudioInput;
107pub use hanzo_llm_mcp::{
108    AgentPermission, AgentToolApprovalNotifier, AgentToolApprovalRequest, AgentToolKind,
109    AgentToolMetadata, AgentToolSource, CalledFunction, CodeExecutionApprovalNotifier,
110    CodeExecutionApprovalRequest, CodeExecutionPermission, Function, MultimodalToolCallback, Tool,
111    ToolCallContext, ToolCallback, ToolCallbackKind, ToolCallbackWithTool, ToolOutput, ToolType,
112};
113pub use hanzo_llm_mcp::{
114    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
115};
116pub use hanzo_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
117pub use hanzo_sandbox::{NetworkMode, SandboxPolicy};
118
/// Python code execution config.
///
/// Can be deserialized from partial input: every field has a serde default, and
/// `approval_callback` is never serialized or deserialized.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct CodeExecutionConfig {
    /// Defaults to `python3` (`python` on Windows).
    #[serde(default = "default_python_path")]
    pub python_path: std::path::PathBuf,
    /// Per-execution timeout. Defaults to 30s.
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    /// If `None`, a temp dir is created. Otherwise this is the cwd for the model's code.
    #[serde(default)]
    pub working_directory: Option<std::path::PathBuf>,
    /// OS-level sandbox policy. `Some(policy)` enables the platform sandbox
    /// (Linux/macOS) with the given limits; `None` disables it entirely.
    /// The CLI/server layer is responsible for choosing.
    #[serde(default)]
    pub sandbox_policy: Option<hanzo_sandbox::SandboxPolicy>,
    /// Permission setting for code execution. `CodeExecutionConfig::default()` sets
    /// `CodeExecutionPermission::Auto`; when the field is missing during deserialization,
    /// serde uses `CodeExecutionPermission::default()` instead.
    // NOTE(review): confirm `CodeExecutionPermission::default()` is `Auto` so both paths agree.
    #[serde(default)]
    pub permission: CodeExecutionPermission,
    /// Optional approval hook, called with a [`CodeExecutionApproval`]. Presumably consulted
    /// when `permission` requires approval — confirm at the call site. Skipped by serde, so
    /// deserialized configs always start with `None`. `Debug` reports only whether it is set.
    #[serde(skip)]
    pub approval_callback: Option<CodeExecutionApprovalCallback>,
}
141
142impl std::fmt::Debug for CodeExecutionConfig {
143    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144        f.debug_struct("CodeExecutionConfig")
145            .field("python_path", &self.python_path)
146            .field("timeout_secs", &self.timeout_secs)
147            .field("working_directory", &self.working_directory)
148            .field("sandbox_policy", &self.sandbox_policy)
149            .field("permission", &self.permission)
150            .field("approval_callback", &self.approval_callback.is_some())
151            .finish()
152    }
153}
154
/// A pending agent tool call that is waiting for an approval decision.
///
/// Passed by reference to an [`AgentToolApprovalCallback`] and by value to an
/// [`AgentToolApprovalAsyncCallback`].
#[derive(Clone, Debug)]
pub struct AgentToolApproval {
    /// Identifier of this approval request.
    pub approval_id: String,
    /// Session the tool call belongs to.
    pub session_id: String,
    /// Round counter within the session. Presumably the agentic tool-call round; confirm
    /// with the producer.
    pub round: usize,
    /// Metadata describing the tool being invoked.
    pub tool: AgentToolMetadata,
    /// Arguments of the tool call, as JSON.
    pub arguments: serde_json::Value,
}
163
/// Decision returned by an agent tool approval handler.
#[derive(Clone, Debug)]
pub struct AgentToolApprovalDecision {
    /// `true` approves the tool call; `false` denies it.
    pub approve: bool,
    /// Asks for the decision to be remembered for the rest of the session. How it is
    /// honored is up to the consumer of the decision.
    pub remember_for_session: bool,
    /// Optional explanatory message, e.g. a reason for denial.
    pub message: Option<String>,
}

impl AgentToolApprovalDecision {
    /// Approve this one call, without remembering it and without a message.
    pub fn approve() -> Self {
        Self {
            approve: true,
            remember_for_session: false,
            message: None,
        }
    }

    /// Approve and ask for the approval to be remembered for the session.
    pub fn approve_for_session() -> Self {
        Self::approve().with_remember_for_session(true)
    }

    /// Deny the call, optionally explaining why.
    pub fn deny(message: Option<String>) -> Self {
        Self {
            approve: false,
            remember_for_session: false,
            message,
        }
    }

    /// Deny the call with an explanatory message.
    pub fn deny_with_message(message: impl Into<String>) -> Self {
        Self::deny(Some(message.into()))
    }

    /// Return the decision with `remember_for_session` replaced.
    pub fn with_remember_for_session(self, remember_for_session: bool) -> Self {
        Self {
            remember_for_session,
            ..self
        }
    }
}
209
/// Synchronous approval hook: inspects a pending agent tool call and returns a decision.
pub type AgentToolApprovalCallback =
    Arc<dyn Fn(&AgentToolApproval) -> AgentToolApprovalDecision + Send + Sync + 'static>;

/// Boxed, `Send` future that resolves to an [`AgentToolApprovalDecision`].
pub type AgentToolApprovalFuture =
    Pin<Box<dyn Future<Output = AgentToolApprovalDecision> + Send + 'static>>;
/// Asynchronous approval hook: takes the pending call by value and returns a future decision.
pub type AgentToolApprovalAsyncCallback =
    Arc<dyn Fn(AgentToolApproval) -> AgentToolApprovalFuture + Send + Sync + 'static>;
217
218#[derive(Clone)]
219pub enum AgentToolApprovalHandler {
220    Sync(AgentToolApprovalCallback),
221    Async(AgentToolApprovalAsyncCallback),
222}
223
224impl AgentToolApprovalHandler {
225    pub fn from_sync(callback: AgentToolApprovalCallback) -> Self {
226        Self::Sync(callback)
227    }
228
229    pub fn from_async(callback: AgentToolApprovalAsyncCallback) -> Self {
230        Self::Async(callback)
231    }
232}
233
/// A pending Python code execution that is waiting for approval.
#[derive(Clone, Debug)]
pub struct CodeExecutionApproval {
    /// Identifier of this approval request.
    pub approval_id: String,
    /// Session the execution belongs to.
    pub session_id: String,
    /// Source code that would be executed.
    pub code: String,
    /// Presumably outputs associated with the execution or session — confirm with the
    /// producer.
    pub outputs: Vec<String>,
    /// Working directory for the execution, if one is known.
    pub working_directory: Option<std::path::PathBuf>,
}

/// Synchronous approval hook for code execution. Presumably `true` allows the execution
/// and `false` rejects it — confirm at the call site.
pub type CodeExecutionApprovalCallback =
    Arc<dyn Fn(&CodeExecutionApproval) -> bool + Send + Sync + 'static>;
245
/// Serde default for `CodeExecutionConfig::python_path`: `python` on Windows, `python3`
/// everywhere else.
fn default_python_path() -> std::path::PathBuf {
    let interpreter = if cfg!(windows) { "python" } else { "python3" };
    std::path::PathBuf::from(interpreter)
}

/// Serde default for `CodeExecutionConfig::timeout_secs`: 30 seconds.
fn default_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
256
257impl Default for CodeExecutionConfig {
258    fn default() -> Self {
259        Self {
260            python_path: default_python_path(),
261            timeout_secs: default_timeout_secs(),
262            working_directory: None,
263            sandbox_policy: None,
264            permission: CodeExecutionPermission::Auto,
265            approval_callback: None,
266        }
267    }
268}
269pub use files::{
270    format_from_name, is_text_mime, mime_for_format, File, FileContent, FileSource, FileStore,
271    RequestedFile, MODEL_INLINE_BYTES, WIRE_EMBED_LIMIT_BYTES,
272};
273pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
274pub use pipeline::hf::{
275    hf_home_dir, hf_hub_cache_dir, hf_token_path, is_hf_hub_offline, probe_hf_repo_files,
276    HF_HUB_OFFLINE_ENV,
277};
278pub use pipeline::{
279    chat_template::ChatTemplate, expand_isq_value, parse_isq_value, parse_uqff_shard,
280    resolve_uqff_shorthand, AdapterPaths, AnyMoeLoader, AnyMoePipeline, AutoDeviceMapParams,
281    AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
282    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
283    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
284    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
285    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
286    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
287    ModelPaths, MultimodalLoader, MultimodalLoaderBuilder, MultimodalLoaderType,
288    MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader, NormalLoaderBuilder,
289    NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
290    SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
291    UQFF_MULTI_FILE_DELIMITER,
292};
293pub use request::{
294    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
295    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
296    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
297};
298pub use response::*;
299pub use sampler::{
300    CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
301    TopLogprob,
302};
303pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
304pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
305use serde::Serialize;
306pub use speculative::{MtpConfig, SpeculativeConfig};
307pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
308use tokio::runtime::Runtime;
309use toml_selector::{TomlLoaderArgs, TomlSelector};
310pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
311pub use topology::{LayerTopology, Topology};
312pub use utils::debug::initialize_logging;
313pub use utils::memory_usage::MemoryUsage;
314pub use utils::normal::{ModelDType, TryIntoDType};
315pub use utils::{paged_attn_supported, using_flash_attn};
316
317// re-export llguidance for easier LlguidanceGrammar construction
318pub use llguidance;
319
/// `true` if `HANZO_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache handle. Being a `OnceLock`, it can be set at most
/// once; where it is initialized is not visible in this section.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
323
/// Configuration for creating an engine instance
///
/// These values are forwarded to `Engine::new` when the engine thread starts.
#[derive(Clone)]
pub struct EngineConfig {
    /// Engine flag for disabling the KV cache (default `false`).
    pub no_kv_cache: bool,
    /// Engine flag for disabling the prefix cache (default `false`).
    pub no_prefix_cache: bool,
    /// Prefix-cache size parameter (default 16); its exact unit is defined by the engine.
    pub prefix_cache_n: usize,
    /// Presumably disables stopping generation at EOS tokens (default `false`).
    pub disable_eos_stop: bool,
    /// Enables periodic throughput logging (default `true`).
    pub throughput_logging_enabled: bool,
    /// Embedding model used for web search, if any.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Custom callback used to gather search results, if any.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Registered tool callbacks together with their tool definitions, keyed by tool name.
    pub tool_callbacks: tools::ToolCallbacksWithTools,
}
336
337impl Default for EngineConfig {
338    fn default() -> Self {
339        Self {
340            no_kv_cache: false,
341            no_prefix_cache: false,
342            prefix_cache_n: 16,
343            disable_eos_stop: false,
344            throughput_logging_enabled: true,
345            search_embedding_model: None,
346            search_callback: None,
347            tool_callbacks: HashMap::new(),
348        }
349    }
350}
351
352/// Configuration for adding a model to Hanzo
353#[derive(Clone)]
354pub struct AddModelConfig {
355    pub engine_config: EngineConfig,
356    pub mcp_client_config: Option<McpClientConfig>,
357    /// Optional loader config for enabling model unload/reload support.
358    /// Without this, models cannot be unloaded and reloaded.
359    pub loader_config: Option<ModelLoaderConfig>,
360    pub code_exec_config: Option<CodeExecutionConfig>,
361}
362
363impl AddModelConfig {
364    pub fn new(engine_config: EngineConfig) -> Self {
365        Self {
366            engine_config,
367            mcp_client_config: None,
368            loader_config: None,
369            code_exec_config: None,
370        }
371    }
372
373    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
374        self.mcp_client_config = Some(mcp_config);
375        self
376    }
377
378    pub fn with_code_execution(mut self, config: CodeExecutionConfig) -> Self {
379        self.code_exec_config = Some(config);
380        self
381    }
382
383    /// Set the loader config for enabling model unload/reload support.
384    /// Without this, models cannot be unloaded and reloaded.
385    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
386        self.loader_config = Some(loader_config);
387        self
388    }
389}
390
/// Snapshot of a loaded model's metadata.
///
/// Captured from the pipeline when its engine instance is created.
#[derive(Clone)]
pub struct HanzoConfig {
    /// Model kind, taken from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category (Text, Multimodal, Diffusion, Speech, ...).
    pub category: ModelCategory,
    /// Input and output modalities of the pipeline.
    pub modalities: Modalities,
    /// Maximum sequence length from the metadata; `None` for diffusion and speech models.
    pub max_seq_len: Option<usize>,
    /// Generation defaults provided by the model, if any.
    pub generation_defaults: Option<ModelGenerationDefaults>,
}
400
/// Configuration for recreating a model loader when reloading an unloaded model.
/// This captures the essential parameters needed to reconstruct a loader.
///
/// Kept inside [`UnloadedModelState`] while a model is unloaded.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// The model selection configuration (Plain, GGUF, Multimodal, etc.)
    pub model_selected: ModelSelected,
    /// Source of the HF token
    pub token_source: TokenSource,
    /// Optional HF revision
    pub hf_revision: Option<String>,
    /// Model data type
    pub dtype: ModelDType,
    /// Device to load the model on
    pub device: Device,
    /// Device mapping setting
    pub device_map_setting: DeviceMapSetting,
    /// In-situ quantization type
    pub isq: Option<IsqType>,
    /// Paged attention configuration
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Whether to suppress logging during loading
    pub silent: bool,
    /// Chat template override
    pub chat_template: Option<String>,
    /// Explicit Jinja template path
    pub jinja_explicit: Option<String>,
    /// Optional speculative decoding attachment to recreate after reload.
    pub mtp_config: Option<MtpConfig>,
}
430
/// State preserved when a model is unloaded.
/// This contains all the information needed to reload the model on demand.
///
/// Stored in `Hanzo::unloaded_models`, keyed by model ID.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Configuration to recreate the loader
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration
    pub scheduler_config: SchedulerConfig,
    /// Engine configuration
    pub engine_config: EngineConfig,
    /// MCP client configuration
    pub mcp_client_config: Option<McpClientConfig>,
    /// Model category (Text, Multimodal, etc.)
    pub category: ModelCategory,
    /// Model metadata configuration
    pub hanzo_config: HanzoConfig,
}
448
/// Internal structure to hold per-engine state
///
/// Built by `Hanzo::create_engine_instance`.
struct EngineInstance {
    /// Request channel into the engine thread.
    sender: Sender<Request>,
    /// Join handle of the OS thread running the engine.
    engine_handler: JoinHandle<()>,
    /// Settings the engine was built with. Presumably used to rebuild it; confirm at use
    /// sites.
    reboot_state: RebootState,
    /// Metadata snapshot taken from the pipeline at creation time.
    config: HanzoConfig,
    /// Same value as `config.category`, kept for direct access.
    category: ModelCategory,
    /// Interval logger (5 s interval) shared with the engine.
    logger: Arc<IntervalLogger>,
    /// Shared with the engine so the SDK/HTTP layer can read/write sessions out of band.
    session_store: Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>,
    /// Shared with the engine for fetch-by-id from the SDK/HTTP layer.
    pub(crate) file_store: files::FileStore,
}
462
/// The Hanzo struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of hanzo, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
/// appropriate engine based on model ID.
///
/// ## Lock Ordering Convention
///
/// This struct uses multiple `RwLock`s. To prevent deadlocks, locks must be
/// acquired in this order:
/// 1. `reloading_models`
/// 2. `engines`
/// 3. `unloaded_models`
/// 4. `default_engine_id`
/// 5. `model_aliases`
///
/// Use scope-based lock management and explicit `drop()` calls.
pub struct Hanzo {
    /// Loaded engines keyed by model ID.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models that have been unloaded but can be reloaded on demand
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Models currently being reloaded (to prevent concurrent reloads)
    reloading_models: RwLock<HashSet<String>>,
    /// Presumably the model ID used when a request names none; confirm at use sites.
    default_engine_id: RwLock<Option<String>>,
    /// Alternate IDs that resolve to primary model IDs.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log target; presumably the value set via `HanzoBuilder::with_log`.
    log: Option<String>,
    /// Identifier of this instance.
    id: String,
    /// Creation timestamp; presumably seconds since the Unix epoch, so confirm where it is set.
    creation_time: u64,
    /// Counter used to hand out request IDs.
    // NOTE(review): the `RefCell` inside a `Mutex` looks redundant; `Mutex<usize>` or an
    // `AtomicUsize` would suffice. Confirm no caller relies on the `RefCell` API.
    next_request_id: Mutex<RefCell<usize>>,
}
493
/// Settings an engine was built with: pipeline, scheduler and engine flags, plus the MCP and
/// loader configs. Presumably used to restart or rebuild the engine; confirm at use sites.
#[derive(Clone)]
struct RebootState {
    /// Pipeline served by the engine.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Tool callbacks with their definitions, keyed by tool name.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// MCP client configuration, if any.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config for reloading after unload
    loader_config: Option<ModelLoaderConfig>,
}
510
/// Model status for loaded/unloaded state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// The model is loaded.
    Loaded,
    /// The model was unloaded and can be reloaded on demand.
    Unloaded,
    /// A reload of the model is in progress.
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Writes the lower-case status name: `loaded`, `unloaded` or `reloading`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
528
/// Errors returned by [`Hanzo`] operations.
#[derive(Debug)]
pub enum HanzoError {
    /// Presumably reported when the engine state lock is poisoned; confirm at use sites.
    EnginePoisoned,
    /// Presumably reported when a request sender lock is poisoned; confirm at use sites.
    SenderPoisoned,
    /// The requested model was not found (neither loaded nor unloaded)
    ModelNotFound(String),
    /// The model is currently being reloaded
    ModelReloading(String),
    /// Failed to reload the model
    ReloadFailed(String),
    /// Model does not have loader config for reloading
    NoLoaderConfig(String),
    /// Model is already loaded
    ModelAlreadyLoaded(String),
    /// Model is already unloaded
    ModelAlreadyUnloaded(String),
    /// Other error with a message.
    Other(String),
}

impl std::fmt::Display for HanzoError {
    /// Currently identical to the `Debug` output: variant name plus payload.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for HanzoError {}

#[cfg(feature = "pyo3_macros")]
impl From<HanzoError> for pyo3::PyErr {
    /// Surfaces the error to Python as a `ValueError` that carries the `Debug` text.
    fn from(value: HanzoError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
563
/// The HanzoBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a Hanzo instance. The Engine runs on a separate thread, and the Hanzo
/// instance stays on the calling thread.
pub struct HanzoBuilder {
    /// Pipeline the engine will serve.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    /// Model ID to use instead of the pipeline name (`with_model_id`).
    model_id_override: Option<String>,
    /// Optional log target (`with_log` / `with_opt_log`).
    log: Option<String>,
    // The `Option` engine flags below stay `None` unless set through the builder. The
    // fallback used for `None` is chosen outside this section.
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    /// Whether throughput logging is enabled.
    throughput_logging_enabled: bool,
    /// Optional embedding model for web search.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Custom search callback (`with_search_callback`).
    search_callback: Option<Arc<SearchCallback>>,
    /// Registered tool callbacks with their tool definitions, keyed by tool name.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// MCP client configuration (`with_mcp_client`).
    mcp_client_config: Option<McpClientConfig>,
    /// Loader config enabling unload/reload (`with_loader_config`).
    loader_config: Option<ModelLoaderConfig>,
    /// Python code execution configuration (`with_code_execution`).
    code_exec_config: Option<CodeExecutionConfig>,
}
584
585impl HanzoBuilder {
586    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
587    /// and optional embedding model for web search. To override the search callback,
588    /// use `.with_search_callback(...)` on the builder.
589    pub fn new(
590        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
591        method: SchedulerConfig,
592        throughput_logging: bool,
593        search_embedding_model: Option<SearchEmbeddingModel>,
594    ) -> Self {
595        Self {
596            pipeline,
597            method,
598            model_id_override: None,
599            log: None,
600            no_kv_cache: None,
601            no_prefix_cache: None,
602            prefix_cache_n: None,
603            disable_eos_stop: None,
604            throughput_logging_enabled: throughput_logging,
605            search_embedding_model,
606            search_callback: None,
607            tool_callbacks: HashMap::new(),
608            mcp_client_config: None,
609            loader_config: None,
610            code_exec_config: None,
611        }
612    }
613
614    /// Override the model ID used by Hanzo. Defaults to the pipeline name.
615    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
616        self.model_id_override = Some(model_id.into());
617        self
618    }
619
620    /// Set the loader config for enabling model unload/reload support.
621    /// Without this, models cannot be unloaded and reloaded.
622    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
623        self.loader_config = Some(loader_config);
624        self
625    }
626    pub fn with_log(mut self, log: String) -> Self {
627        self.log = Some(log);
628        self
629    }
630    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
631        self.log = log;
632        self
633    }
634    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
635        self.no_kv_cache = Some(no_kv_cache);
636        self
637    }
638    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
639        self.no_prefix_cache = Some(no_prefix_cache);
640        self
641    }
642    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
643        self.prefix_cache_n = Some(prefix_cache_n);
644        self
645    }
646    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
647        self.disable_eos_stop = Some(disable_eos_stop);
648        self
649    }
650
651    /// Use a custom callback to gather search results.
652    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
653        self.search_callback = Some(search_callback);
654        self
655    }
656
657    /// Register a custom callback for the specified tool name.
658    pub fn with_tool_callback(
659        mut self,
660        name: impl Into<String>,
661        tool_callback: Arc<ToolCallback>,
662    ) -> Self {
663        let name = name.into();
664        // Wrap bare callback with a minimal tool definition.
665        self.tool_callbacks.insert(
666            name.clone(),
667            ToolCallbackWithTool {
668                callback: ToolCallbackKind::Text(tool_callback),
669                tool: Tool {
670                    tp: ToolType::Function,
671                    function: Function {
672                        description: None,
673                        name,
674                        parameters: None,
675                        strict: None,
676                    },
677                },
678            },
679        );
680        self
681    }
682
683    /// Register a custom callback with its associated Tool definition. The Tool will be
684    /// automatically added to requests when tool callbacks are active.
685    pub fn with_tool_callback_and_tool(
686        mut self,
687        name: impl Into<String>,
688        tool_callback: Arc<ToolCallback>,
689        tool: Tool,
690    ) -> Self {
691        let name = name.into();
692        self.tool_callbacks.insert(
693            name,
694            ToolCallbackWithTool {
695                callback: ToolCallbackKind::Text(tool_callback),
696                tool,
697            },
698        );
699        self
700    }
701
702    /// Register a pre-built tool callback with its Tool definition.
703    pub fn with_tool_callback_with_tool(
704        mut self,
705        name: impl Into<String>,
706        callback_with_tool: ToolCallbackWithTool,
707    ) -> Self {
708        self.tool_callbacks.insert(name.into(), callback_with_tool);
709        self
710    }
711
712    /// Configure MCP client to connect to external MCP servers.
713    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
714        self.mcp_client_config = Some(config);
715        self
716    }
717
718    /// Enable Python code execution. **Security**: lets the model run arbitrary code on the host with full network and filesystem access.
719    pub fn with_code_execution(mut self, config: CodeExecutionConfig) -> Self {
720        self.code_exec_config = Some(config);
721        self
722    }
723
724    pub async fn build(self) -> Arc<Hanzo> {
725        Hanzo::new(self).await
726    }
727}
728
729impl Drop for Hanzo {
730    fn drop(&mut self) {
731        // Terminate all engines
732        if let Ok(engines) = self.engines.read() {
733            for (_, engine) in engines.iter() {
734                // Use try_send instead of blocking_send to avoid runtime panics
735                let _ = engine.sender.try_send(Request::Terminate);
736            }
737        }
738    }
739}
740
741impl Hanzo {
742    /// Create an engine instance with the given configuration
743    fn create_engine_instance(
744        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
745        method: SchedulerConfig,
746        config: EngineConfig,
747        reboot_state: RebootState,
748    ) -> Result<EngineInstance, String> {
749        let (tx, rx) = channel(10_000);
750
751        let pipeline_guard = pipeline.try_lock().unwrap();
752        let category = pipeline_guard.category();
753        let metadata = pipeline_guard.get_metadata();
754        let kind = metadata.kind.clone();
755        let device = pipeline_guard.device();
756        let modalities = metadata.modalities.clone();
757        let max_seq_len = match &category {
758            ModelCategory::Diffusion | ModelCategory::Speech => None,
759            _ => Some(metadata.max_seq_len),
760        };
761        let generation_defaults = pipeline_guard.generation_defaults();
762        let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
763        drop(pipeline_guard);
764
765        let logger = Arc::new(IntervalLogger::new(
766            Duration::from_secs(5),
767            encoder_cache_counters,
768        ));
769        let logger_for_engine = logger.clone();
770
771        info!("Pipeline input modalities are {:?}", &modalities.input);
772        info!("Pipeline output modalities are {:?}", &modalities.output);
773
774        let hanzo_config = HanzoConfig {
775            kind,
776            device,
777            category: category.clone(),
778            modalities,
779            max_seq_len,
780            generation_defaults,
781        };
782
783        // Shared between engine and EngineInstance so the SDK/HTTP API
784        // can access sessions without going through the request channel.
785        let session_store = Arc::new(std::sync::Mutex::new(
786            engine::agentic_session::AgenticSessionStore::new(),
787        ));
788        let session_store_for_engine = Arc::clone(&session_store);
789        let file_store = files::FileStore::new();
790        let file_store_for_engine = file_store.clone();
791
792        let tx_for_engine = tx.clone();
793        let engine_handler = thread::spawn(move || {
794            #[cfg(feature = "metal")]
795            objc::rc::autoreleasepool(move || {
796                let rt = Runtime::new().unwrap();
797                rt.block_on(async move {
798                    file_store_for_engine.spawn_cleanup_task();
799                    let engine = Engine::new(
800                        tx_for_engine,
801                        rx,
802                        pipeline,
803                        method,
804                        config.no_kv_cache,
805                        config.no_prefix_cache,
806                        config.prefix_cache_n,
807                        config.disable_eos_stop,
808                        config.throughput_logging_enabled,
809                        config.search_embedding_model,
810                        config.search_callback.clone(),
811                        config.tool_callbacks.clone(),
812                        logger_for_engine,
813                        session_store_for_engine,
814                        file_store_for_engine,
815                    )
816                    .expect("Engine creation failed.");
817                    Arc::new(engine).run().await;
818                })
819            });
820
821            #[cfg(not(feature = "metal"))]
822            {
823                let rt = Runtime::new().unwrap();
824                rt.block_on(async move {
825                    file_store_for_engine.spawn_cleanup_task();
826                    let engine = Engine::new(
827                        tx_for_engine,
828                        rx,
829                        pipeline,
830                        method,
831                        config.no_kv_cache,
832                        config.no_prefix_cache,
833                        config.prefix_cache_n,
834                        config.disable_eos_stop,
835                        config.throughput_logging_enabled,
836                        config.search_embedding_model,
837                        config.search_callback.clone(),
838                        config.tool_callbacks.clone(),
839                        logger_for_engine,
840                        session_store_for_engine,
841                        file_store_for_engine,
842                    )
843                    .expect("Engine creation failed.");
844                    Arc::new(engine).run().await;
845                })
846            }
847        });
848
849        Ok(EngineInstance {
850            sender: tx,
851            engine_handler,
852            reboot_state,
853            config: hanzo_config,
854            category,
855            logger,
856            session_store,
857            file_store,
858        })
859    }
860
    /// Initialize MCP and code-execution tool callbacks and merge them into `tool_callbacks`.
    /// Used by both the constructor (`Self::new`, which consumes a `HanzoBuilder`) and
    /// `add_model` so dynamically added models pick up the same external tools as the
    /// boot-time model.
    ///
    /// Both integrations are best-effort: an initialization failure is logged with `warn!`
    /// and the function returns normally, leaving `tool_callbacks` with whatever was
    /// registered so far.
    ///
    /// * `pipeline` - only read (for its input modalities) when the `code-execution`
    ///   feature is enabled.
    /// * `tool_callbacks` - destination map; newly registered tools are `insert`ed by name.
    /// * `mcp_client_config` - when `Some`, an MCP client is created and initialized.
    /// * `code_exec_config` - when `Some` (and the feature is on), a code-execution manager
    ///   is created and its tools are registered.
    async fn init_external_tool_callbacks(
        #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))] pipeline: &Arc<
            tokio::sync::Mutex<dyn Pipeline>,
        >,
        tool_callbacks: &mut tools::ToolCallbacksWithTools,
        mcp_client_config: Option<&McpClientConfig>,
        #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
        code_exec_config: Option<&CodeExecutionConfig>,
    ) {
        // --- MCP: connect to the configured servers and register the tools they expose. ---
        if let Some(config) = mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Register each MCP tool under its name; an existing entry with the same
                    // name is replaced by whatever `insert` does for this collection type.
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    // Non-fatal: the engine keeps running without MCP tools.
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        // --- Code execution (feature-gated): build a manager and register its tools. ---
        #[cfg(feature = "code-execution")]
        if let Some(code_exec_cfg) = code_exec_config {
            // Adapt the public approval callback to hanzo_code_exec's approval type by
            // copying each field of the approval request into this crate's struct.
            let approval_callback = code_exec_cfg.approval_callback.as_ref().map(|callback| {
                let callback = Arc::clone(callback);
                Arc::new(move |approval: &hanzo_code_exec::CodeExecutionApproval| {
                    let approval = CodeExecutionApproval {
                        approval_id: approval.approval_id.clone(),
                        session_id: approval.session_id.clone(),
                        code: approval.code.clone(),
                        outputs: approval.outputs.clone(),
                        working_directory: approval.working_directory.clone(),
                    };
                    callback(&approval)
                }) as Arc<hanzo_code_exec::CodeExecutionApprovalCallback>
            });
            let exec_config = hanzo_code_exec::CodeExecutionConfig {
                python_path: code_exec_cfg.python_path.clone(),
                timeout_secs: code_exec_cfg.timeout_secs,
                working_directory: code_exec_cfg.working_directory.clone(),
                sandbox_policy: code_exec_cfg.sandbox_policy.clone(),
                // One-to-one translation of the public permission enum.
                permission: match code_exec_cfg.permission {
                    CodeExecutionPermission::Auto => hanzo_code_exec::CodeExecutionPermission::Auto,
                    CodeExecutionPermission::Ask => hanzo_code_exec::CodeExecutionPermission::Ask,
                    CodeExecutionPermission::Deny => hanzo_code_exec::CodeExecutionPermission::Deny,
                },
                approval_callback,
            };
            match hanzo_code_exec::CodeExecutionManager::new(exec_config).await {
                Ok(manager) => {
                    // Collect the pipeline's input modalities so the manager only exposes
                    // tools the model can consume; modalities with no counterpart are dropped.
                    // The pipeline lock is held only inside this block.
                    let input_modalities: Vec<hanzo_code_exec::InputModality> = {
                        let pipe = get_mut_arcmutex!(pipeline);
                        pipe.get_metadata()
                            .modalities
                            .input
                            .iter()
                            .filter_map(|m| match m {
                                pipeline::SupportedModality::Text => {
                                    Some(hanzo_code_exec::InputModality::Text)
                                }
                                pipeline::SupportedModality::Vision => {
                                    Some(hanzo_code_exec::InputModality::Vision)
                                }
                                pipeline::SupportedModality::Audio => {
                                    Some(hanzo_code_exec::InputModality::Audio)
                                }
                                pipeline::SupportedModality::Video => {
                                    Some(hanzo_code_exec::InputModality::Video)
                                }
                                _ => None,
                            })
                            .collect()
                    };
                    let effective = manager.effective_protection();
                    let network = manager.network_mode();
                    let callbacks = manager.get_tool_callbacks(&input_modalities);
                    let count = callbacks.len();
                    for (name, cb) in callbacks {
                        tool_callbacks.insert(name, cb);
                    }
                    // Deliberately loud banner: the model can now run arbitrary Python on the
                    // host, so report exactly which sandbox layers are actually active.
                    warn!("============================================================");
                    warn!("  CODE EXECUTION IS ENABLED");
                    warn!("  The model can execute arbitrary Python code on this machine.");
                    if effective.any() {
                        let fs = if effective.fs_isolated {
                            "workdir + system libs only"
                        } else {
                            "NOT restricted"
                        };
                        // Network is only reported as restricted when isolation is in effect
                        // AND the configured mode is one of the restrictive ones.
                        let net = if effective.network_isolated {
                            match network {
                                Some(hanzo_sandbox::NetworkMode::None) => "denied",
                                Some(hanzo_sandbox::NetworkMode::Loopback) => "loopback only",
                                _ => "NOT restricted",
                            }
                        } else {
                            "NOT restricted"
                        };
                        warn!(
                            "  Sandbox: on. Filesystem: {fs}. Network: {net}. rlimits: {}.",
                            if effective.rlimits_applied {
                                "applied"
                            } else {
                                "not applied"
                            }
                        );
                        if !effective.fs_isolated || !effective.network_isolated {
                            warn!("  Some layers are inactive on this host. Use --sandbox on to make missing layers a hard error.");
                        }
                    } else {
                        warn!("  Sandbox: OFF. Network and filesystem are NOT restricted.");
                        warn!("  Pass a sandbox_policy (or --sandbox on at the CLI) to enable isolation.");
                    }
                    warn!("  See: https://hanzoai.github.io/engine/reference/sandbox/");
                    warn!("============================================================");
                    info!("Code execution initialized with {count} tools");
                }
                Err(e) => {
                    // Non-fatal: continue without code-execution tools.
                    warn!("Failed to initialize code execution: {e}");
                    warn!("Continuing without code execution functionality.");
                }
            }
        }
    }
1011
1012    async fn new(config: HanzoBuilder) -> Arc<Self> {
1013        info!("git revision: {HANZO_GIT_REVISION}");
1014        let HanzoBuilder {
1015            pipeline,
1016            method,
1017            model_id_override,
1018            log,
1019            no_kv_cache,
1020            no_prefix_cache,
1021            prefix_cache_n,
1022            disable_eos_stop,
1023            throughput_logging_enabled,
1024            search_embedding_model,
1025            search_callback,
1026            mut tool_callbacks,
1027            mcp_client_config,
1028            loader_config,
1029            #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
1030            code_exec_config,
1031        } = config;
1032
1033        hanzo_quant::cublaslt::maybe_init_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device());
1034
1035        let no_kv_cache = no_kv_cache.unwrap_or(false);
1036        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
1037        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
1038        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
1039
1040        Self::init_external_tool_callbacks(
1041            &pipeline,
1042            &mut tool_callbacks,
1043            mcp_client_config.as_ref(),
1044            code_exec_config.as_ref(),
1045        )
1046        .await;
1047
1048        let reboot_state = RebootState {
1049            pipeline: pipeline.clone(),
1050            method: method.clone(),
1051            no_kv_cache,
1052            no_prefix_cache,
1053            prefix_cache_n,
1054            disable_eos_stop,
1055            throughput_logging_enabled,
1056            search_embedding_model,
1057            search_callback: search_callback.clone(),
1058            tool_callbacks: tool_callbacks.clone(),
1059            mcp_client_config: mcp_client_config.clone(),
1060            loader_config,
1061        };
1062
1063        let engine_config = EngineConfig {
1064            no_kv_cache,
1065            no_prefix_cache,
1066            prefix_cache_n,
1067            disable_eos_stop,
1068            throughput_logging_enabled,
1069            search_embedding_model,
1070            search_callback,
1071            tool_callbacks,
1072        };
1073
1074        let engine_instance =
1075            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
1076                .expect("Failed to create engine instance");
1077
1078        let pipeline_name = pipeline.try_lock().unwrap().name();
1079        let (id, alias_map) = match model_id_override {
1080            Some(override_id) => {
1081                let mut alias_map = HashMap::new();
1082                if override_id != pipeline_name {
1083                    alias_map.insert(pipeline_name.clone(), override_id.clone());
1084                }
1085                (override_id, alias_map)
1086            }
1087            None => (pipeline_name.clone(), HashMap::new()),
1088        };
1089
1090        if distributed::is_daemon() {
1091            let request_sender = engine_instance.sender.clone();
1092
1093            if cfg!(feature = "ring") {
1094                // Ring daemon replicator
1095                distributed::ring_daemon_replicator(request_sender);
1096            } else {
1097                // NCCL daemon replicator
1098                distributed::nccl_daemon_replicator(request_sender);
1099            }
1100
1101            #[allow(clippy::empty_loop)]
1102            loop {}
1103        }
1104
1105        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
1106        let is_multi_threaded = tokio::runtime::Handle::try_current()
1107            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
1108
1109        // Do a dummy run
1110        if !distributed::is_daemon()
1111            && is_multi_threaded
1112            && matches!(
1113                engine_instance.category,
1114                ModelCategory::Text | ModelCategory::Multimodal { .. }
1115            )
1116        {
1117            let clone_sender = engine_instance.sender.clone();
1118            tokio::task::block_in_place(|| {
1119                let (tx, mut rx) = channel(1);
1120                let req = Request::Normal(Box::new(NormalRequest {
1121                    id: 0,
1122                    messages: RequestMessage::Completion {
1123                        text: "hello".to_string(),
1124                        echo_prompt: false,
1125                        best_of: None,
1126                    },
1127                    sampling_params: SamplingParams {
1128                        max_len: Some(1),
1129                        ..SamplingParams::deterministic()
1130                    },
1131                    response: tx,
1132                    return_logprobs: false,
1133                    is_streaming: false,
1134                    constraint: Constraint::None,
1135                    suffix: None,
1136                    tool_choice: None,
1137                    tools: None,
1138                    logits_processors: None,
1139                    return_raw_logits: false,
1140                    web_search_options: None,
1141                    enable_code_execution: false,
1142                    code_execution_permission: None,
1143                    code_execution_approval_notifier: None,
1144                    agent_permission: None,
1145                    agent_approval_handler: None,
1146                    agent_approval_notifier: None,
1147                    max_tool_rounds: None,
1148                    tool_dispatch_url: None,
1149                    model_id: None,
1150                    truncate_sequence: false,
1151                    session_id: None,
1152                    files: None,
1153                }));
1154                debug!("Beginning dummy run.");
1155                let start = Instant::now();
1156                clone_sender.blocking_send(req).unwrap();
1157
1158                // Drain all responses from the channel until it's closed
1159                let mut received_any = false;
1160                while let Some(_resp) = rx.blocking_recv() {
1161                    received_any = true;
1162                }
1163
1164                if received_any {
1165                    let end = Instant::now();
1166                    debug!(
1167                        "Dummy run completed in {}s.",
1168                        end.duration_since(start).as_secs_f64()
1169                    );
1170                } else {
1171                    warn!("Dummy run failed!");
1172                }
1173            });
1174
1175            // Reset logger counters so the dummy run doesn't pollute stats
1176            engine_instance.logger.reset();
1177        }
1178
1179        // Create engines map with the first engine
1180        let mut engines = HashMap::new();
1181        engines.insert(id.clone(), engine_instance);
1182
1183        Arc::new(Self {
1184            engines: RwLock::new(engines),
1185            unloaded_models: RwLock::new(HashMap::new()),
1186            reloading_models: RwLock::new(HashSet::new()),
1187            default_engine_id: RwLock::new(Some(id.clone())),
1188            model_aliases: RwLock::new(alias_map),
1189            log,
1190            id,
1191            creation_time: SystemTime::now()
1192                .duration_since(UNIX_EPOCH)
1193                .expect("Time travel has occurred!")
1194                .as_secs(),
1195            next_request_id: Mutex::new(RefCell::new(1)),
1196        })
1197    }
1198
1199    /// Attempts to reboot a specific engine by model_id
1200    fn reboot_engine(&self, model_id: &str) -> Result<(), HanzoError> {
1201        let mut engines = self.engines.write().map_err(|_| {
1202            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
1203            HanzoError::EnginePoisoned
1204        })?;
1205
1206        if let Some(engine_instance) = engines.get(model_id) {
1207            if !engine_instance.engine_handler.is_finished() {
1208                tracing::info!("Engine {} already running, returning ok", model_id);
1209                return Ok(());
1210            }
1211
1212            let reboot_state = engine_instance.reboot_state.clone();
1213            let engine_config = EngineConfig {
1214                no_kv_cache: reboot_state.no_kv_cache,
1215                no_prefix_cache: reboot_state.no_prefix_cache,
1216                prefix_cache_n: reboot_state.prefix_cache_n,
1217                disable_eos_stop: reboot_state.disable_eos_stop,
1218                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
1219                search_embedding_model: reboot_state.search_embedding_model,
1220                search_callback: reboot_state.search_callback.clone(),
1221                tool_callbacks: reboot_state.tool_callbacks.clone(),
1222            };
1223            let new_engine_instance = Self::create_engine_instance(
1224                reboot_state.pipeline.clone(),
1225                reboot_state.method.clone(),
1226                engine_config,
1227                reboot_state,
1228            )
1229            .map_err(|e| {
1230                tracing::error!("Failed to create new engine instance: {}", e);
1231                HanzoError::EnginePoisoned
1232            })?;
1233
1234            engines.insert(model_id.to_string(), new_engine_instance);
1235            tracing::info!("Successfully rebooted engine {}", model_id);
1236            Ok(())
1237        } else {
1238            Err(HanzoError::EnginePoisoned)
1239        }
1240    }
1241
1242    fn engine_dead(&self, model_id: &str) -> Result<bool, HanzoError> {
1243        let engines = self.engines.read().map_err(|_| {
1244            tracing::warn!("Couldn't get read lock on engines!");
1245            HanzoError::EnginePoisoned
1246        })?;
1247
1248        if let Some(engine_instance) = engines.get(model_id) {
1249            Ok(engine_instance.engine_handler.is_finished())
1250        } else {
1251            Err(HanzoError::EnginePoisoned)
1252        }
1253    }
1254
1255    /// Get sender for a specific model. If model_id is None, uses default engine.
1256    /// If the model is unloaded, it will be automatically reloaded before returning the sender.
1257    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, HanzoError> {
1258        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1259
1260        // Check if model is loaded
1261        let is_loaded = {
1262            let engines = self
1263                .engines
1264                .read()
1265                .map_err(|_| HanzoError::SenderPoisoned)?;
1266            engines.contains_key(&resolved_model_id)
1267        };
1268
1269        if is_loaded {
1270            // Check if engine is dead and needs reboot
1271            if self.engine_dead(&resolved_model_id)? {
1272                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
1273                self.reboot_engine(&resolved_model_id)?
1274            }
1275
1276            let engines = self
1277                .engines
1278                .read()
1279                .map_err(|_| HanzoError::SenderPoisoned)?;
1280            if let Some(engine_instance) = engines.get(&resolved_model_id) {
1281                return Ok(engine_instance.sender.clone());
1282            }
1283        }
1284
1285        // Check if model is unloaded - trigger auto-reload
1286        let is_unloaded = {
1287            let unloaded = self
1288                .unloaded_models
1289                .read()
1290                .map_err(|_| HanzoError::EnginePoisoned)?;
1291            unloaded.contains_key(&resolved_model_id)
1292        };
1293
1294        if is_unloaded {
1295            tracing::info!(
1296                "Model {} is unloaded, triggering auto-reload",
1297                resolved_model_id
1298            );
1299            self.reload_model_blocking(&resolved_model_id)?;
1300
1301            // After reload, get the sender
1302            let engines = self
1303                .engines
1304                .read()
1305                .map_err(|_| HanzoError::SenderPoisoned)?;
1306            if let Some(engine_instance) = engines.get(&resolved_model_id) {
1307                return Ok(engine_instance.sender.clone());
1308            }
1309        }
1310
1311        Err(HanzoError::ModelNotFound(resolved_model_id))
1312    }
1313
1314    /// Look up a file across all loaded engines. `None` if missing or expired.
1315    pub fn find_file(&self, id: &str) -> Option<Arc<files::File>> {
1316        let engines = self.engines.read().ok()?;
1317        for instance in engines.values() {
1318            if let Some(f) = instance.file_store.get(id) {
1319                return Some(f);
1320            }
1321        }
1322        None
1323    }
1324
1325    /// Every non-expired file across all loaded engines, including session-less runs. Order unspecified.
1326    pub fn list_files(&self) -> Vec<Arc<files::File>> {
1327        let mut out = Vec::new();
1328        let Ok(engines) = self.engines.read() else {
1329            return out;
1330        };
1331        for instance in engines.values() {
1332            out.extend(instance.file_store.list_all());
1333        }
1334        out
1335    }
1336
1337    /// Returns whether the file existed.
1338    pub fn remove_file(&self, id: &str) -> bool {
1339        let Ok(engines) = self.engines.read() else {
1340            return false;
1341        };
1342        for instance in engines.values() {
1343            if instance.file_store.remove(id) {
1344                return true;
1345            }
1346        }
1347        false
1348    }
1349
1350    /// Agentic session store for `model_id` (or the default model). Returns an `Arc` to lock for inspect/mutate.
1351    pub fn get_session_store(
1352        &self,
1353        model_id: Option<&str>,
1354    ) -> Result<Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>, HanzoError>
1355    {
1356        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1357        let engines = self
1358            .engines
1359            .read()
1360            .map_err(|_| HanzoError::SenderPoisoned)?;
1361        engines
1362            .get(&resolved_model_id)
1363            .map(|e| Arc::clone(&e.session_store))
1364            .ok_or(HanzoError::ModelNotFound(resolved_model_id))
1365    }
1366
1367    fn get_file_store(&self, model_id: Option<&str>) -> Result<files::FileStore, HanzoError> {
1368        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1369        let engines = self
1370            .engines
1371            .read()
1372            .map_err(|_| HanzoError::SenderPoisoned)?;
1373        engines
1374            .get(&resolved_model_id)
1375            .map(|e| e.file_store.clone())
1376            .ok_or(HanzoError::ModelNotFound(resolved_model_id))
1377    }
1378
1379    /// Export an agentic session by ID. Bundles the session's files (full bodies). `None` if missing.
1380    pub fn export_session(
1381        &self,
1382        model_id: Option<&str>,
1383        session_id: &str,
1384    ) -> Result<Option<engine::agentic_session::SerializedSession>, HanzoError> {
1385        let store = self.get_session_store(model_id)?;
1386        let exported = {
1387            let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1388            guard
1389                .export(session_id)
1390                .map_err(|e| HanzoError::Other(e.to_string()))?
1391        };
1392        let Some(mut session) = exported else {
1393            return Ok(None);
1394        };
1395        let file_store = self.get_file_store(model_id)?;
1396        session.files = file_store
1397            .list_for_session(session_id)
1398            .into_iter()
1399            .map(|arc| (*arc).clone())
1400            .collect();
1401        Ok(Some(session))
1402    }
1403
1404    /// Replaces any existing session with the same ID. Restores its files into the file store.
1405    pub fn import_session(
1406        &self,
1407        model_id: Option<&str>,
1408        session_id: String,
1409        session: engine::agentic_session::SerializedSession,
1410    ) -> Result<(), HanzoError> {
1411        let files = session.files.clone();
1412        let store = self.get_session_store(model_id)?;
1413        {
1414            let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1415            guard
1416                .import(session_id.clone(), session)
1417                .map_err(|e| HanzoError::Other(e.to_string()))?;
1418        }
1419        let file_store = self.get_file_store(model_id)?;
1420        for f in files {
1421            file_store.insert(f, Some(session_id.clone()));
1422        }
1423        Ok(())
1424    }
1425
1426    /// Clone the first `num_turns` complete turns from `src` into `dest`. A turn ends at the
1427    /// first assistant message without `tool_calls`. Used for branching: the new session diverges
1428    /// cleanly from the truncated prefix, so the branch's later edits don't bleed back.
1429    pub fn fork_session(
1430        &self,
1431        model_id: Option<&str>,
1432        src_session_id: &str,
1433        dest_session_id: String,
1434        num_turns: usize,
1435    ) -> Result<(), HanzoError> {
1436        let store = self.get_session_store(model_id)?;
1437        let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1438        guard
1439            .fork(src_session_id, dest_session_id, num_turns)
1440            .map_err(|e| HanzoError::Other(e.to_string()))
1441    }
1442
1443    /// Delete an agentic session. Returns whether the session existed.
1444    pub fn delete_session(
1445        &self,
1446        model_id: Option<&str>,
1447        session_id: &str,
1448    ) -> Result<bool, HanzoError> {
1449        let store = self.get_session_store(model_id)?;
1450        let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1451        Ok(guard.delete(session_id))
1452    }
1453
1454    /// All stored session IDs. SDK-only, not exposed via HTTP.
1455    pub fn list_session_ids(&self, model_id: Option<&str>) -> Result<Vec<String>, HanzoError> {
1456        let store = self.get_session_store(model_id)?;
1457        let guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1458        Ok(guard.list_ids())
1459    }
1460
1461    pub fn get_id(&self) -> String {
1462        self.id.clone()
1463    }
1464
1465    pub fn get_creation_time(&self) -> u64 {
1466        self.creation_time
1467    }
1468
1469    fn resolve_alias(&self, model_id: &str) -> Result<String, HanzoError> {
1470        let aliases = self
1471            .model_aliases
1472            .read()
1473            .map_err(|_| HanzoError::SenderPoisoned)?;
1474        if let Some(primary_id) = aliases.get(model_id) {
1475            Ok(primary_id.clone())
1476        } else {
1477            Ok(model_id.to_string())
1478        }
1479    }
1480
1481    fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, HanzoError> {
1482        match model_id {
1483            Some(id) => self.resolve_alias(id),
1484            None => {
1485                let default_lock = self
1486                    .default_engine_id
1487                    .read()
1488                    .map_err(|_| HanzoError::SenderPoisoned)?;
1489                Ok(default_lock
1490                    .as_ref()
1491                    .ok_or(HanzoError::EnginePoisoned)?
1492                    .clone())
1493            }
1494        }
1495    }
1496
1497    /// Register an alternate model ID that resolves to an existing model.
1498    pub fn register_model_alias(
1499        &self,
1500        alias: impl Into<String>,
1501        model_id: &str,
1502    ) -> Result<(), String> {
1503        let alias = alias.into();
1504        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1505
1506        if alias == resolved_model_id {
1507            return Ok(());
1508        }
1509
1510        let reloading = self
1511            .reloading_models
1512            .read()
1513            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1514        let model_reloading = reloading.contains(&resolved_model_id);
1515        let alias_conflict = reloading.contains(&alias);
1516        drop(reloading);
1517
1518        let engines = self
1519            .engines
1520            .read()
1521            .map_err(|_| "Failed to acquire read lock on engines")?;
1522        let model_loaded = engines.contains_key(&resolved_model_id);
1523        let alias_conflict = alias_conflict || engines.contains_key(&alias);
1524        drop(engines);
1525
1526        let unloaded = self
1527            .unloaded_models
1528            .read()
1529            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1530        let model_unloaded = unloaded.contains_key(&resolved_model_id);
1531        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
1532        drop(unloaded);
1533
1534        if !(model_loaded || model_unloaded || model_reloading) {
1535            return Err(format!("Model {resolved_model_id} not found"));
1536        }
1537
1538        if alias_conflict {
1539            return Err(format!(
1540                "Alias '{}' conflicts with an existing model ID",
1541                alias
1542            ));
1543        }
1544
1545        let mut aliases = self
1546            .model_aliases
1547            .write()
1548            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1549        if let Some(existing) = aliases.get(&alias) {
1550            if existing == &resolved_model_id {
1551                return Ok(());
1552            }
1553            return Err(format!(
1554                "Alias '{}' is already assigned to model '{}'",
1555                alias, existing
1556            ));
1557        }
1558        aliases.insert(alias, resolved_model_id);
1559        Ok(())
1560    }
1561
1562    /// Check if a model is known (loaded, unloaded, or reloading), resolving aliases if needed.
1563    pub fn model_exists(&self, model_id: &str) -> Result<bool, HanzoError> {
1564        let resolved_model_id = self.resolve_alias(model_id)?;
1565
1566        let reloading = self
1567            .reloading_models
1568            .read()
1569            .map_err(|_| HanzoError::EnginePoisoned)?;
1570        if reloading.contains(&resolved_model_id) {
1571            return Ok(true);
1572        }
1573        drop(reloading);
1574
1575        let engines = self
1576            .engines
1577            .read()
1578            .map_err(|_| HanzoError::EnginePoisoned)?;
1579        if engines.contains_key(&resolved_model_id) {
1580            return Ok(true);
1581        }
1582        drop(engines);
1583
1584        let unloaded = self
1585            .unloaded_models
1586            .read()
1587            .map_err(|_| HanzoError::EnginePoisoned)?;
1588        if unloaded.contains_key(&resolved_model_id) {
1589            return Ok(true);
1590        }
1591
1592        Ok(false)
1593    }
1594
1595    /// Get the interval logger for a specific model. If model_id is None, uses default engine.
1596    pub fn get_logger(&self, model_id: Option<&str>) -> Result<Arc<IntervalLogger>, HanzoError> {
1597        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1598
1599        let engines = self
1600            .engines
1601            .read()
1602            .map_err(|_| HanzoError::SenderPoisoned)?;
1603        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1604            Ok(engine_instance.logger.clone())
1605        } else {
1606            Err(HanzoError::EnginePoisoned)
1607        }
1608    }
1609
1610    /// Get model category for a specific model. If model_id is None, uses default engine.
1611    pub fn get_model_category(&self, model_id: Option<&str>) -> Result<ModelCategory, HanzoError> {
1612        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1613
1614        let engines = self
1615            .engines
1616            .read()
1617            .map_err(|_| HanzoError::SenderPoisoned)?;
1618        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1619            Ok(engine_instance.category.clone())
1620        } else {
1621            Err(HanzoError::EnginePoisoned)
1622        }
1623    }
1624
1625    /// Get the maximum supported sequence length for a model, if applicable.
1626    pub fn max_sequence_length(&self, model_id: Option<&str>) -> Result<Option<usize>, HanzoError> {
1627        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1628
1629        let engines = self
1630            .engines
1631            .read()
1632            .map_err(|_| HanzoError::SenderPoisoned)?;
1633        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1634            Ok(engine_instance.config.max_seq_len)
1635        } else {
1636            Err(HanzoError::EnginePoisoned)
1637        }
1638    }
1639
1640    pub fn next_request_id(&self) -> usize {
1641        let l = self.next_request_id.lock().unwrap();
1642        let last = &mut *l.borrow_mut();
1643        let last_v = *last;
1644        *last += 1;
1645        last_v
1646    }
1647
1648    /// Add a new model engine to the Hanzo instance
1649    pub async fn add_model(
1650        &self,
1651        model_id: String,
1652        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
1653        method: SchedulerConfig,
1654        config: AddModelConfig,
1655    ) -> Result<(), String> {
1656        {
1657            let reloading = self
1658                .reloading_models
1659                .read()
1660                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1661            if reloading.contains(&model_id) {
1662                return Err(format!("Model {model_id} is currently reloading"));
1663            }
1664        }
1665        {
1666            let engines = self
1667                .engines
1668                .read()
1669                .map_err(|_| "Failed to acquire read lock on engines")?;
1670            if engines.contains_key(&model_id) {
1671                return Err(format!("Model {model_id} already exists"));
1672            }
1673        }
1674        {
1675            let unloaded = self
1676                .unloaded_models
1677                .read()
1678                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1679            if unloaded.contains_key(&model_id) {
1680                return Err(format!("Model {model_id} already exists (unloaded)"));
1681            }
1682        }
1683        {
1684            let aliases = self
1685                .model_aliases
1686                .read()
1687                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
1688            if aliases.contains_key(&model_id) {
1689                return Err(format!(
1690                    "Model ID '{}' conflicts with an existing alias",
1691                    model_id
1692                ));
1693            }
1694        }
1695
1696        let mut engine_config = config.engine_config;
1697        Self::init_external_tool_callbacks(
1698            &pipeline,
1699            &mut engine_config.tool_callbacks,
1700            config.mcp_client_config.as_ref(),
1701            config.code_exec_config.as_ref(),
1702        )
1703        .await;
1704
1705        let reboot_state = RebootState {
1706            pipeline: pipeline.clone(),
1707            method: method.clone(),
1708            no_kv_cache: engine_config.no_kv_cache,
1709            no_prefix_cache: engine_config.no_prefix_cache,
1710            prefix_cache_n: engine_config.prefix_cache_n,
1711            disable_eos_stop: engine_config.disable_eos_stop,
1712            throughput_logging_enabled: engine_config.throughput_logging_enabled,
1713            search_embedding_model: engine_config.search_embedding_model,
1714            search_callback: engine_config.search_callback.clone(),
1715            tool_callbacks: engine_config.tool_callbacks.clone(),
1716            mcp_client_config: config.mcp_client_config.clone(),
1717            loader_config: config.loader_config.clone(),
1718        };
1719
1720        let engine_instance =
1721            Self::create_engine_instance(pipeline, method, engine_config, reboot_state)?;
1722
1723        let mut engines = self
1724            .engines
1725            .write()
1726            .map_err(|_| "Failed to acquire write lock on engines")?;
1727        engines.insert(model_id.clone(), engine_instance);
1728
1729        // If this is the first model, set it as default
1730        if engines.len() == 1 {
1731            let mut default_lock = self
1732                .default_engine_id
1733                .write()
1734                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1735            *default_lock = Some(model_id.clone());
1736            info!("First model added, setting '{}' as default", model_id);
1737        }
1738
1739        Ok(())
1740    }
1741
1742    /// Remove a model engine from the Hanzo instance
1743    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1744        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1745        let mut engines = self
1746            .engines
1747            .write()
1748            .map_err(|_| "Failed to acquire write lock on engines")?;
1749
1750        if engines.len() <= 1 {
1751            return Err("Cannot remove the last model from Hanzo".to_string());
1752        }
1753
1754        if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1755            // Send terminate signal to the engine
1756            let _ = engine_instance.sender.blocking_send(Request::Terminate);
1757
1758            // If this was the default engine, set a new default
1759            let mut default_lock = self
1760                .default_engine_id
1761                .write()
1762                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1763            if let Some(ref default_id) = *default_lock {
1764                if default_id == &resolved_model_id {
1765                    // Set the first available engine as the new default
1766                    *default_lock = engines.keys().next().cloned();
1767                }
1768            }
1769            drop(default_lock);
1770            drop(engines);
1771
1772            // Remove any aliases pointing to the removed model
1773            let mut aliases = self
1774                .model_aliases
1775                .write()
1776                .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1777            aliases.retain(|_, target| target != &resolved_model_id);
1778
1779            Ok(())
1780        } else {
1781            Err(format!("Model {resolved_model_id} not found"))
1782        }
1783    }
1784
1785    /// List all available model IDs
1786    pub fn list_models(&self) -> Result<Vec<String>, String> {
1787        let engines = self
1788            .engines
1789            .read()
1790            .map_err(|_| "Failed to acquire read lock on engines")?;
1791        Ok(engines.keys().cloned().collect())
1792    }
1793
1794    /// Get the current default model ID
1795    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1796        let default_lock = self
1797            .default_engine_id
1798            .read()
1799            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1800        Ok(default_lock.clone())
1801    }
1802
1803    /// Set the default model ID
1804    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1805        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1806        let engines = self
1807            .engines
1808            .read()
1809            .map_err(|_| "Failed to acquire read lock on engines")?;
1810        if !engines.contains_key(&resolved_model_id) {
1811            return Err(format!("Model {resolved_model_id} not found"));
1812        }
1813        drop(engines);
1814
1815        let mut default_lock = self
1816            .default_engine_id
1817            .write()
1818            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1819        let old_default = default_lock.clone();
1820        *default_lock = Some(resolved_model_id.clone());
1821
1822        // Log the change
1823        info!(
1824            "Default model changed: {:?} -> {:?}",
1825            old_default, resolved_model_id
1826        );
1827
1828        Ok(())
1829    }
1830
1831    /// Dispatch a request to the appropriate engine based on the model_id in the request
1832    pub fn send_request(&self, mut request: Request) -> Result<(), HanzoError> {
1833        let model_id = match &mut request {
1834            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1835            _ => None, // Other request types don't specify model_id
1836        };
1837
1838        let sender = self.get_sender(model_id)?;
1839        sender
1840            .blocking_send(request)
1841            .map_err(|_| HanzoError::SenderPoisoned)
1842    }
1843
1844    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1845        if let Some(file) = &this.log {
1846            let mut f = OpenOptions::new()
1847                .append(true)
1848                .create(true) // Optionally create the file if it doesn't already exist
1849                .open(file)
1850                .expect("Unable to open file");
1851            let time = chrono::offset::Local::now();
1852            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1853                .expect("Unable to write data");
1854        }
1855    }
1856
1857    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1858        if let Some(file) = &this.log {
1859            let mut f = OpenOptions::new()
1860                .append(true)
1861                .create(true) // Optionally create the file if it doesn't already exist
1862                .open(file)
1863                .expect("Unable to open file");
1864            let time = chrono::offset::Local::now();
1865            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1866            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1867                .expect("Unable to write data");
1868        }
1869    }
1870
1871    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1872        if let Some(file) = &this.log {
1873            let mut f = OpenOptions::new()
1874                .append(true)
1875                .create(true) // Optionally create the file if it doesn't already exist
1876                .open(file)
1877                .expect("Unable to open file");
1878            let time = chrono::offset::Local::now();
1879            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1880                .expect("Unable to write data");
1881        }
1882    }
1883
1884    /// Get the number of tools available for a specific model (including MCP tools)
1885    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1886        let resolved_model_id = self
1887            .resolve_alias_or_default(model_id)
1888            .map_err(|e| e.to_string())?;
1889
1890        let engines = self
1891            .engines
1892            .read()
1893            .map_err(|_| "Failed to acquire read lock on engines")?;
1894        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1895            Ok(engine_instance.reboot_state.tool_callbacks.len())
1896        } else {
1897            Err(format!("Model {resolved_model_id} not found"))
1898        }
1899    }
1900
1901    /// MCP-provided tools registered for `model_id`. Excludes built-ins (web search, code exec). Returns `(name, description)` per tool.
1902    pub fn list_mcp_tools(
1903        &self,
1904        model_id: Option<&str>,
1905    ) -> Result<Vec<(String, Option<String>)>, String> {
1906        let resolved_model_id = self
1907            .resolve_alias_or_default(model_id)
1908            .map_err(|e| e.to_string())?;
1909
1910        let engines = self
1911            .engines
1912            .read()
1913            .map_err(|_| "Failed to acquire read lock on engines")?;
1914        let engine_instance = engines
1915            .get(&resolved_model_id)
1916            .ok_or_else(|| format!("Model {resolved_model_id} not found"))?;
1917
1918        let mut tools: Vec<(String, Option<String>)> = engine_instance
1919            .reboot_state
1920            .tool_callbacks
1921            .values()
1922            .filter(|cb| {
1923                let name = &cb.tool.function.name;
1924                // Exclude built-in tools; everything else came from MCP.
1925                !search::search_tool_called(name) && {
1926                    #[cfg(feature = "code-execution")]
1927                    {
1928                        !hanzo_code_exec::code_exec_tool_called(name)
1929                    }
1930                    #[cfg(not(feature = "code-execution"))]
1931                    {
1932                        true
1933                    }
1934                }
1935            })
1936            .map(|cb| {
1937                (
1938                    cb.tool.function.name.clone(),
1939                    cb.tool.function.description.clone(),
1940                )
1941            })
1942            .collect();
1943        tools.sort_by(|a, b| a.0.cmp(&b.0));
1944        Ok(tools)
1945    }
1946
1947    /// Check if MCP client is configured for a specific model
1948    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1949        let resolved_model_id = self
1950            .resolve_alias_or_default(model_id)
1951            .map_err(|e| e.to_string())?;
1952
1953        let engines = self
1954            .engines
1955            .read()
1956            .map_err(|_| "Failed to acquire read lock on engines")?;
1957        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1958            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1959        } else {
1960            Err(format!("Model {resolved_model_id} not found"))
1961        }
1962    }
1963
1964    /// Get config for a specific model
1965    pub fn config(&self, model_id: Option<&str>) -> Result<HanzoConfig, String> {
1966        let resolved_model_id = self
1967            .resolve_alias_or_default(model_id)
1968            .map_err(|e| e.to_string())?;
1969
1970        let engines = self
1971            .engines
1972            .read()
1973            .map_err(|_| "Failed to acquire read lock on engines")?;
1974        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1975            Ok(engine_instance.config.clone())
1976        } else {
1977            Err(format!("Model {resolved_model_id} not found"))
1978        }
1979    }
1980
1981    /// Unload a model from memory while preserving its configuration for later reload.
1982    /// The model can be reloaded automatically when a request is sent to it, or manually
1983    /// using `reload_model()`.
1984    ///
1985    /// Note: The model must have been added with a `ModelLoaderConfig` for auto-reload to work.
1986    /// Models added via `HanzoBuilder` without explicit loader config cannot be reloaded.
1987    pub fn unload_model(&self, model_id: &str) -> Result<(), HanzoError> {
1988        let resolved_model_id = self.resolve_alias(model_id)?;
1989        // Check if already unloaded
1990        {
1991            let unloaded = self
1992                .unloaded_models
1993                .read()
1994                .map_err(|_| HanzoError::EnginePoisoned)?;
1995            if unloaded.contains_key(&resolved_model_id) {
1996                return Err(HanzoError::ModelAlreadyUnloaded(resolved_model_id.clone()));
1997            }
1998        }
1999
2000        // Get the engine instance and create UnloadedModelState
2001        let mut engines = self
2002            .engines
2003            .write()
2004            .map_err(|_| HanzoError::EnginePoisoned)?;
2005
2006        let engine_instance = engines
2007            .remove(&resolved_model_id)
2008            .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))?;
2009
2010        // Check if we have loader config for reloading
2011        let loader_config = engine_instance
2012            .reboot_state
2013            .loader_config
2014            .clone()
2015            .ok_or_else(|| HanzoError::NoLoaderConfig(resolved_model_id.clone()))?;
2016
2017        // Create the unloaded state
2018        let unloaded_state = UnloadedModelState {
2019            loader_config,
2020            scheduler_config: engine_instance.reboot_state.method.clone(),
2021            engine_config: EngineConfig {
2022                no_kv_cache: engine_instance.reboot_state.no_kv_cache,
2023                no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
2024                prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
2025                disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
2026                throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
2027                search_embedding_model: engine_instance.reboot_state.search_embedding_model,
2028                search_callback: engine_instance.reboot_state.search_callback.clone(),
2029                tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
2030            },
2031            mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
2032            category: engine_instance.category.clone(),
2033            hanzo_config: engine_instance.config.clone(),
2034        };
2035
2036        // Send terminate signal to the engine
2037        let _ = engine_instance.sender.try_send(Request::Terminate);
2038
2039        drop(engines);
2040
2041        // Store the unloaded state
2042        let mut unloaded = self
2043            .unloaded_models
2044            .write()
2045            .map_err(|_| HanzoError::EnginePoisoned)?;
2046        unloaded.insert(resolved_model_id.to_string(), unloaded_state);
2047
2048        // Update default if needed
2049        let mut default_lock = self
2050            .default_engine_id
2051            .write()
2052            .map_err(|_| HanzoError::EnginePoisoned)?;
2053        if let Some(ref default_id) = *default_lock {
2054            if default_id == &resolved_model_id {
2055                // Set the first available engine as the new default
2056                let engines = self
2057                    .engines
2058                    .read()
2059                    .map_err(|_| HanzoError::EnginePoisoned)?;
2060                *default_lock = engines.keys().next().cloned();
2061            }
2062        }
2063
2064        info!("Model {} unloaded successfully", resolved_model_id);
2065        Ok(())
2066    }
2067
2068    /// Manually reload a previously unloaded model.
2069    /// This is also called automatically by `get_sender()` when a request targets an unloaded model.
2070    pub async fn reload_model(&self, model_id: &str) -> Result<(), HanzoError> {
2071        let resolved_model_id = self.resolve_alias(model_id)?;
2072        // Check if already reloading
2073        {
2074            let reloading = self
2075                .reloading_models
2076                .read()
2077                .map_err(|_| HanzoError::EnginePoisoned)?;
2078            if reloading.contains(&resolved_model_id) {
2079                return Err(HanzoError::ModelReloading(resolved_model_id.clone()));
2080            }
2081        }
2082
2083        // Mark as reloading
2084        {
2085            let mut reloading = self
2086                .reloading_models
2087                .write()
2088                .map_err(|_| HanzoError::EnginePoisoned)?;
2089            reloading.insert(resolved_model_id.clone());
2090        }
2091
2092        // Get the unloaded state
2093        let unloaded_state = {
2094            let unloaded = self
2095                .unloaded_models
2096                .read()
2097                .map_err(|_| HanzoError::EnginePoisoned)?;
2098            unloaded
2099                .get(&resolved_model_id)
2100                .cloned()
2101                .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))?
2102        };
2103
2104        // Attempt to reload
2105        let result = self
2106            .do_reload_model(&resolved_model_id, unloaded_state)
2107            .await;
2108
2109        // Remove from reloading set
2110        {
2111            let mut reloading = self
2112                .reloading_models
2113                .write()
2114                .map_err(|_| HanzoError::EnginePoisoned)?;
2115            reloading.remove(&resolved_model_id);
2116        }
2117
2118        result
2119    }
2120
2121    /// Internal method to perform the actual model reload
2122    async fn do_reload_model(
2123        &self,
2124        model_id: &str,
2125        unloaded_state: UnloadedModelState,
2126    ) -> Result<(), HanzoError> {
2127        use crate::model_loader::LoaderBuilder;
2128
2129        info!("Reloading model: {}", model_id);
2130
2131        let loader_config = &unloaded_state.loader_config;
2132
2133        // Build the loader from the stored config
2134        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
2135            .with_chat_template(loader_config.chat_template.clone())
2136            .with_jinja_explicit(loader_config.jinja_explicit.clone())
2137            .build()
2138            .map_err(|e| HanzoError::ReloadFailed(format!("Failed to build loader: {e}")))?;
2139
2140        // Load the model
2141        let pipeline = loader
2142            .load_model_from_hf(
2143                None,
2144                loader_config.token_source.clone(),
2145                &loader_config.dtype,
2146                &loader_config.device,
2147                loader_config.silent,
2148                loader_config.device_map_setting.clone(),
2149                loader_config.isq,
2150                loader_config.paged_attn_config,
2151            )
2152            .map_err(|e| HanzoError::ReloadFailed(format!("Failed to load model: {e}")))?;
2153
2154        if let Some(mtp_config) = loader_config.mtp_config.clone() {
2155            pipeline
2156                .blocking_lock()
2157                .attach_speculative(SpeculativeConfig::Mtp(mtp_config))
2158                .map_err(|e| {
2159                    HanzoError::ReloadFailed(format!(
2160                        "Failed to attach MTP speculative decoding: {e}"
2161                    ))
2162                })?;
2163        }
2164
2165        // Create the reboot state
2166        let reboot_state = RebootState {
2167            pipeline: pipeline.clone(),
2168            method: unloaded_state.scheduler_config.clone(),
2169            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
2170            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
2171            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
2172            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
2173            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
2174            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
2175            search_callback: unloaded_state.engine_config.search_callback.clone(),
2176            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
2177            mcp_client_config: unloaded_state.mcp_client_config.clone(),
2178            loader_config: Some(unloaded_state.loader_config.clone()),
2179        };
2180
2181        let engine_instance = Self::create_engine_instance(
2182            pipeline,
2183            unloaded_state.scheduler_config,
2184            unloaded_state.engine_config,
2185            reboot_state,
2186        )
2187        .map_err(|e| HanzoError::ReloadFailed(format!("Failed to create engine: {e}")))?;
2188
2189        // Add to engines map
2190        {
2191            let mut engines = self
2192                .engines
2193                .write()
2194                .map_err(|_| HanzoError::EnginePoisoned)?;
2195            engines.insert(model_id.to_string(), engine_instance);
2196        }
2197
2198        // Remove from unloaded map
2199        {
2200            let mut unloaded = self
2201                .unloaded_models
2202                .write()
2203                .map_err(|_| HanzoError::EnginePoisoned)?;
2204            unloaded.remove(model_id);
2205        }
2206
2207        info!("Model {} reloaded successfully", model_id);
2208        Ok(())
2209    }
2210
2211    /// Synchronous version of reload_model for use in non-async contexts.
2212    ///
2213    /// This method handles different runtime contexts appropriately:
2214    /// - If called from a multi-threaded tokio runtime, uses `block_in_place`
2215    /// - If called from a single-threaded runtime, returns an error (use `reload_model()` instead)
2216    /// - If called outside any runtime, creates a temporary runtime
2217    pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), HanzoError> {
2218        match tokio::runtime::Handle::try_current() {
2219            Ok(handle) => {
2220                if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
2221                    Err(HanzoError::ReloadFailed(
2222                        "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
2223                    ))
2224                } else {
2225                    tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
2226                }
2227            }
2228            Err(_) => {
2229                let rt = tokio::runtime::Runtime::new().map_err(|e| {
2230                    HanzoError::ReloadFailed(format!("Failed to create runtime: {e}"))
2231                })?;
2232                rt.block_on(self.reload_model(model_id))
2233            }
2234        }
2235    }
2236
2237    /// List all unloaded model IDs
2238    pub fn list_unloaded_models(&self) -> Result<Vec<String>, HanzoError> {
2239        let unloaded = self
2240            .unloaded_models
2241            .read()
2242            .map_err(|_| HanzoError::EnginePoisoned)?;
2243        Ok(unloaded.keys().cloned().collect())
2244    }
2245
2246    /// Check if a model is currently loaded (as opposed to unloaded)
2247    pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, HanzoError> {
2248        let resolved_model_id = self.resolve_alias(model_id)?;
2249        let engines = self
2250            .engines
2251            .read()
2252            .map_err(|_| HanzoError::EnginePoisoned)?;
2253        Ok(engines.contains_key(&resolved_model_id))
2254    }
2255
2256    /// Get the status of a model, or None if not found
2257    pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, HanzoError> {
2258        let resolved_model_id = self.resolve_alias(model_id)?;
2259        // Check if reloading
2260        {
2261            let reloading = self
2262                .reloading_models
2263                .read()
2264                .map_err(|_| HanzoError::EnginePoisoned)?;
2265            if reloading.contains(&resolved_model_id) {
2266                return Ok(Some(ModelStatus::Reloading));
2267            }
2268        }
2269
2270        // Check if loaded
2271        {
2272            let engines = self
2273                .engines
2274                .read()
2275                .map_err(|_| HanzoError::EnginePoisoned)?;
2276            if engines.contains_key(&resolved_model_id) {
2277                return Ok(Some(ModelStatus::Loaded));
2278            }
2279        }
2280
2281        // Check if unloaded
2282        {
2283            let unloaded = self
2284                .unloaded_models
2285                .read()
2286                .map_err(|_| HanzoError::EnginePoisoned)?;
2287            if unloaded.contains_key(&resolved_model_id) {
2288                return Ok(Some(ModelStatus::Unloaded));
2289            }
2290        }
2291
2292        Ok(None)
2293    }
2294
2295    /// List all models with their status
2296    pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, HanzoError> {
2297        let mut result = Vec::new();
2298
2299        // Get reloading models
2300        let reloading = self
2301            .reloading_models
2302            .read()
2303            .map_err(|_| HanzoError::EnginePoisoned)?;
2304        for model_id in reloading.iter() {
2305            result.push((model_id.clone(), ModelStatus::Reloading));
2306        }
2307        drop(reloading);
2308
2309        // Get loaded models
2310        let engines = self
2311            .engines
2312            .read()
2313            .map_err(|_| HanzoError::EnginePoisoned)?;
2314        for model_id in engines.keys() {
2315            result.push((model_id.clone(), ModelStatus::Loaded));
2316        }
2317        drop(engines);
2318
2319        // Get unloaded models
2320        let unloaded = self
2321            .unloaded_models
2322            .read()
2323            .map_err(|_| HanzoError::EnginePoisoned)?;
2324        for model_id in unloaded.keys() {
2325            // Skip if already in reloading
2326            if !result.iter().any(|(id, _)| id == model_id) {
2327                result.push((model_id.clone(), ModelStatus::Unloaded));
2328            }
2329        }
2330
2331        Ok(result)
2332    }
2333}