1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use engine::Engine;
3pub use engine::{
4 agentic_session::{AgenticSessionStore, SerializedSession, SerializedVideo},
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, IntervalLogger, SearchEmbeddingModel, DEFAULT_MAX_TOOL_ROUNDS,
7 ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
8};
9use hanzo_ml::Device;
10use hf_hub::Cache;
11pub use lora::Ordering;
12pub use pipeline::ModelCategory;
13pub use pipeline::Pipeline;
14#[cfg(feature = "pyo3_macros")]
15use pyo3::exceptions::PyValueError;
16use std::collections::{HashMap, HashSet};
17use std::future::Future;
18use std::pin::Pin;
19use std::sync::OnceLock;
20use std::time::{Duration, Instant};
21use std::{
22 cell::RefCell,
23 error::Error,
24 fs::OpenOptions,
25 io::Write,
26 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
27 thread::{self, JoinHandle},
28 time::{SystemTime, UNIX_EPOCH},
29};
30use tokio::sync::mpsc::{channel, Sender};
31use tracing::{debug, info, warn};
32
/// Git revision this crate was built from.
///
/// Read from the `HANZO_GIT_REVISION` environment variable at compile time;
/// falls back to `"unknown"` when the variable was not set for the build.
pub const HANZO_GIT_REVISION: &str = if let Some(revision) = option_env!("HANZO_GIT_REVISION") {
    revision
} else {
    "unknown"
};
37
38mod cuda;
39mod device_map;
40mod engine;
41mod lora;
42mod metal;
43mod model_loader;
44mod moe;
45mod ops;
46mod video_input;
47mod vulkan;
48pub use model_loader::{
49 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
50};
51pub use video_input::{sample_frame_indices, VideoInput};
52pub mod disk_kv_cache;
53mod embedding_models;
54mod kv_cache;
55mod search;
56
57mod model_selected;
58pub use model_selected::ModelSelected;
59pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
60
61mod amoe;
62mod attention;
63mod diagnostics;
64mod diffusion_models;
65pub mod distributed;
66pub mod files;
67mod gguf;
68pub mod layers;
69mod layers_masker;
70mod layers_utils;
71pub mod matformer;
72mod mla;
73mod models;
74mod paged_attention;
75mod pipeline;
76mod prefix_cacher;
77pub mod reasoning_parsers;
78mod request;
79mod response;
80mod sampler;
81mod scheduler;
82mod sequence;
83pub mod speculative;
84mod speech_models;
85mod toml_selector;
86mod tools;
87mod topology;
88mod utils;
89mod vision_models;
90mod xlora_models;
91
92pub use diagnostics::{
93 check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
94 DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
95};
96mod tuning;
97pub use tuning::{
98 auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
99};
100
101pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
102pub use device_map::{
103 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
104};
105pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
106pub use hanzo_audio::AudioInput;
107pub use hanzo_llm_mcp::{
108 AgentPermission, AgentToolApprovalNotifier, AgentToolApprovalRequest, AgentToolKind,
109 AgentToolMetadata, AgentToolSource, CalledFunction, CodeExecutionApprovalNotifier,
110 CodeExecutionApprovalRequest, CodeExecutionPermission, Function, MultimodalToolCallback, Tool,
111 ToolCallContext, ToolCallback, ToolCallbackKind, ToolCallbackWithTool, ToolOutput, ToolType,
112};
113pub use hanzo_llm_mcp::{
114 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
115};
116pub use hanzo_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
117pub use hanzo_sandbox::{NetworkMode, SandboxPolicy};
118
/// Settings for the code-execution tool.
///
/// Every serialized field carries a serde default, so all of them may be
/// omitted when deserializing. `approval_callback` is never (de)serialized.
#[derive(Clone, serde::Serialize, serde::Deserialize)]
pub struct CodeExecutionConfig {
    /// Python interpreter to invoke. Serde default: `default_python_path()`.
    #[serde(default = "default_python_path")]
    pub python_path: std::path::PathBuf,
    /// Timeout in seconds. Serde default: `default_timeout_secs()`.
    // NOTE(review): presumably applies per code execution — confirm in hanzo_code_exec.
    #[serde(default = "default_timeout_secs")]
    pub timeout_secs: u64,
    /// Optional working directory; `None` when omitted.
    #[serde(default)]
    pub working_directory: Option<std::path::PathBuf>,
    /// Optional sandbox policy; `None` when omitted.
    #[serde(default)]
    pub sandbox_policy: Option<hanzo_sandbox::SandboxPolicy>,
    /// Permission mode (`Auto`, `Ask` or `Deny`). When omitted during
    /// deserialization this is `CodeExecutionPermission::default()`.
    #[serde(default)]
    pub permission: CodeExecutionPermission,
    /// Optional in-process approval callback. Skipped by serde, so it is
    /// always `None` after deserialization.
    #[serde(skip)]
    pub approval_callback: Option<CodeExecutionApprovalCallback>,
}
141
142impl std::fmt::Debug for CodeExecutionConfig {
143 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
144 f.debug_struct("CodeExecutionConfig")
145 .field("python_path", &self.python_path)
146 .field("timeout_secs", &self.timeout_secs)
147 .field("working_directory", &self.working_directory)
148 .field("sandbox_policy", &self.sandbox_policy)
149 .field("permission", &self.permission)
150 .field("approval_callback", &self.approval_callback.is_some())
151 .finish()
152 }
153}
154
/// A request for approval of a single agent tool call, as handed to an
/// [`AgentToolApprovalCallback`] or [`AgentToolApprovalAsyncCallback`].
#[derive(Clone, Debug)]
pub struct AgentToolApproval {
    /// Identifier of this approval request.
    pub approval_id: String,
    /// Identifier of the session the tool call belongs to.
    pub session_id: String,
    /// Round number of the tool call.
    // NOTE(review): presumably the index of the tool round within the session — confirm.
    pub round: usize,
    /// Metadata describing the tool to be called.
    pub tool: AgentToolMetadata,
    /// Arguments of the tool call, as a JSON value.
    pub arguments: serde_json::Value,
}
163
/// The answer to an [`AgentToolApproval`] request.
#[derive(Clone, Debug)]
pub struct AgentToolApprovalDecision {
    /// `true` when the tool call is approved, `false` when it is denied.
    pub approve: bool,
    /// Flag asking for this decision to be remembered for the session.
    pub remember_for_session: bool,
    /// Optional message attached to the decision.
    pub message: Option<String>,
}

impl AgentToolApprovalDecision {
    /// Approval without a message that is not remembered for the session.
    pub fn approve() -> Self {
        Self {
            approve: true,
            remember_for_session: false,
            message: None,
        }
    }

    /// Approval without a message that is remembered for the session.
    pub fn approve_for_session() -> Self {
        Self::approve().with_remember_for_session(true)
    }

    /// Denial with an optional message; not remembered for the session.
    pub fn deny(message: Option<String>) -> Self {
        Self {
            approve: false,
            remember_for_session: false,
            message,
        }
    }

    /// Denial carrying `message`; not remembered for the session.
    pub fn deny_with_message(message: impl Into<String>) -> Self {
        Self::deny(Some(message.into()))
    }

    /// Returns the decision with `remember_for_session` replaced.
    pub fn with_remember_for_session(self, remember_for_session: bool) -> Self {
        Self {
            remember_for_session,
            ..self
        }
    }
}
209
/// Shared, thread-safe synchronous approval callback: receives a borrowed
/// [`AgentToolApproval`] and returns the decision directly.
pub type AgentToolApprovalCallback =
    Arc<dyn Fn(&AgentToolApproval) -> AgentToolApprovalDecision + Send + Sync + 'static>;

/// Boxed, pinned, `Send` future resolving to an [`AgentToolApprovalDecision`].
pub type AgentToolApprovalFuture =
    Pin<Box<dyn Future<Output = AgentToolApprovalDecision> + Send + 'static>>;
/// Shared, thread-safe asynchronous approval callback: takes ownership of the
/// [`AgentToolApproval`] and returns an [`AgentToolApprovalFuture`].
pub type AgentToolApprovalAsyncCallback =
    Arc<dyn Fn(AgentToolApproval) -> AgentToolApprovalFuture + Send + Sync + 'static>;
217
/// An approval handler in either synchronous or asynchronous form.
#[derive(Clone)]
pub enum AgentToolApprovalHandler {
    /// Synchronous callback returning the decision directly.
    Sync(AgentToolApprovalCallback),
    /// Asynchronous callback returning a future of the decision.
    Async(AgentToolApprovalAsyncCallback),
}
223
224impl AgentToolApprovalHandler {
225 pub fn from_sync(callback: AgentToolApprovalCallback) -> Self {
226 Self::Sync(callback)
227 }
228
229 pub fn from_async(callback: AgentToolApprovalAsyncCallback) -> Self {
230 Self::Async(callback)
231 }
232}
233
/// A request for approval of a code execution, as handed to a
/// [`CodeExecutionApprovalCallback`].
#[derive(Clone, Debug)]
pub struct CodeExecutionApproval {
    /// Identifier of this approval request.
    pub approval_id: String,
    /// Identifier of the session the execution belongs to.
    pub session_id: String,
    /// The code awaiting approval.
    pub code: String,
    /// Output strings attached to the request.
    // NOTE(review): exact meaning of `outputs` is defined by hanzo_code_exec — confirm.
    pub outputs: Vec<String>,
    /// Working directory attached to the request, if any.
    pub working_directory: Option<std::path::PathBuf>,
}
242
/// Shared, thread-safe callback deciding on a [`CodeExecutionApproval`].
// NOTE(review): presumably `true` means "approve" — confirm against hanzo_code_exec.
pub type CodeExecutionApprovalCallback =
    Arc<dyn Fn(&CodeExecutionApproval) -> bool + Send + Sync + 'static>;
245
/// Default interpreter name: `python` on Windows, `python3` everywhere else.
fn default_python_path() -> std::path::PathBuf {
    let interpreter = if cfg!(windows) { "python" } else { "python3" };
    std::path::PathBuf::from(interpreter)
}
/// Default value of `CodeExecutionConfig::timeout_secs`: 30 seconds.
fn default_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 30;
    DEFAULT_TIMEOUT_SECS
}
256
257impl Default for CodeExecutionConfig {
258 fn default() -> Self {
259 Self {
260 python_path: default_python_path(),
261 timeout_secs: default_timeout_secs(),
262 working_directory: None,
263 sandbox_policy: None,
264 permission: CodeExecutionPermission::Auto,
265 approval_callback: None,
266 }
267 }
268}
269pub use files::{
270 format_from_name, is_text_mime, mime_for_format, File, FileContent, FileSource, FileStore,
271 RequestedFile, MODEL_INLINE_BYTES, WIRE_EMBED_LIMIT_BYTES,
272};
273pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
274pub use pipeline::hf::{
275 hf_home_dir, hf_hub_cache_dir, hf_token_path, is_hf_hub_offline, probe_hf_repo_files,
276 HF_HUB_OFFLINE_ENV,
277};
278pub use pipeline::{
279 chat_template::ChatTemplate, expand_isq_value, parse_isq_value, parse_uqff_shard,
280 resolve_uqff_shorthand, AdapterPaths, AnyMoeLoader, AnyMoePipeline, AutoDeviceMapParams,
281 AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
282 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
283 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
284 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
285 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
286 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
287 ModelPaths, MultimodalLoader, MultimodalLoaderBuilder, MultimodalLoaderType,
288 MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader, NormalLoaderBuilder,
289 NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader,
290 SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
291 UQFF_MULTI_FILE_DELIMITER,
292};
293pub use request::{
294 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
295 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
296 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
297};
298pub use response::*;
299pub use sampler::{
300 CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
301 TopLogprob,
302};
303pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
304pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
305use serde::Serialize;
306pub use speculative::{MtpConfig, SpeculativeConfig};
307pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
308use tokio::runtime::Runtime;
309use toml_selector::{TomlLoaderArgs, TomlSelector};
310pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
311pub use topology::{LayerTopology, Topology};
312pub use utils::debug::initialize_logging;
313pub use utils::memory_usage::MemoryUsage;
314pub use utils::normal::{ModelDType, TryIntoDType};
315pub use utils::{paged_attn_supported, using_flash_attn};
316
317pub use llguidance;
319
/// Crate-wide debug flag; starts out `false`.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide `hf_hub` cache handle; can be set at most once (`OnceLock`).
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
323
/// Per-engine settings that are forwarded to `Engine::new` when an engine
/// thread is started.
#[derive(Clone)]
pub struct EngineConfig {
    /// Disables the KV cache when `true`.
    pub no_kv_cache: bool,
    /// Disables the prefix cache when `true`.
    pub no_prefix_cache: bool,
    /// Prefix cache size parameter (default 16).
    // NOTE(review): presumably the number of sequences held in the prefix cache — confirm.
    pub prefix_cache_n: usize,
    /// When `true`, generation is not stopped at the EOS token.
    // NOTE(review): inferred from the field name — confirm in the engine.
    pub disable_eos_stop: bool,
    /// Enables throughput logging.
    pub throughput_logging_enabled: bool,
    /// Optional embedding model used for search.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional user-supplied search callback.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks (with their tool definitions), keyed by tool name.
    pub tool_callbacks: tools::ToolCallbacksWithTools,
}
336
337impl Default for EngineConfig {
338 fn default() -> Self {
339 Self {
340 no_kv_cache: false,
341 no_prefix_cache: false,
342 prefix_cache_n: 16,
343 disable_eos_stop: false,
344 throughput_logging_enabled: true,
345 search_embedding_model: None,
346 search_callback: None,
347 tool_callbacks: HashMap::new(),
348 }
349 }
350}
351
/// Configuration bundle for adding a model: engine settings plus optional
/// MCP, loader and code-execution configuration.
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine settings for the model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration.
    // NOTE(review): presumably what makes the model reloadable after unloading — confirm.
    pub loader_config: Option<ModelLoaderConfig>,
    /// Optional code-execution configuration.
    pub code_exec_config: Option<CodeExecutionConfig>,
}
362
363impl AddModelConfig {
364 pub fn new(engine_config: EngineConfig) -> Self {
365 Self {
366 engine_config,
367 mcp_client_config: None,
368 loader_config: None,
369 code_exec_config: None,
370 }
371 }
372
373 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
374 self.mcp_client_config = Some(mcp_config);
375 self
376 }
377
378 pub fn with_code_execution(mut self, config: CodeExecutionConfig) -> Self {
379 self.code_exec_config = Some(config);
380 self
381 }
382
383 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
386 self.loader_config = Some(loader_config);
387 self
388 }
389}
390
/// Read-only description of a loaded model, captured from its pipeline when
/// the engine instance is created.
#[derive(Clone)]
pub struct HanzoConfig {
    /// Model kind, taken from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities, taken from the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length from the pipeline metadata; `None` for
    /// diffusion and speech models.
    pub max_seq_len: Option<usize>,
    /// Generation defaults reported by the pipeline, if any.
    pub generation_defaults: Option<ModelGenerationDefaults>,
}
400
/// Parameters describing how a model is loaded.
// NOTE(review): field descriptions below are inferred from names and types;
// the code that consumes this struct is not visible here — confirm.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// Which model to load.
    pub model_selected: ModelSelected,
    /// Source of the access token.
    pub token_source: TokenSource,
    /// Optional Hugging Face revision.
    pub hf_revision: Option<String>,
    /// Model data type.
    pub dtype: ModelDType,
    /// Target device.
    pub device: Device,
    /// Device mapping setting.
    pub device_map_setting: DeviceMapSetting,
    /// Optional in-situ quantization type.
    pub isq: Option<IsqType>,
    /// Optional paged-attention configuration.
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Silent flag.
    pub silent: bool,
    /// Optional chat template.
    pub chat_template: Option<String>,
    /// Optional explicit Jinja template.
    pub jinja_explicit: Option<String>,
    /// Optional MTP configuration.
    pub mtp_config: Option<MtpConfig>,
}
430
/// State kept for a model that is currently unloaded (stored in
/// `Hanzo::unloaded_models`).
// NOTE(review): presumably everything needed to reload the model later — confirm
// against the unload/reload code, which is not visible here.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Loader parameters for the model.
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration for the model's engine.
    pub scheduler_config: SchedulerConfig,
    /// Engine settings for the model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Category of the model.
    pub category: ModelCategory,
    /// Description of the model (see [`HanzoConfig`]).
    pub hanzo_config: HanzoConfig,
}
448
/// Handles to one running engine: its request channel, its thread, and the
/// state shared with that thread.
struct EngineInstance {
    /// Sending half of the engine's request channel.
    sender: Sender<Request>,
    /// Join handle of the OS thread running the engine.
    engine_handler: JoinHandle<()>,
    /// Settings retained so the engine can be recreated.
    reboot_state: RebootState,
    /// Description of the loaded model.
    config: HanzoConfig,
    /// Category of the loaded model.
    category: ModelCategory,
    /// Interval logger shared with the engine thread.
    logger: Arc<IntervalLogger>,
    /// Agentic session store shared with the engine thread.
    session_store: Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>,
    /// File store; a clone of it is handed to the engine thread.
    pub(crate) file_store: files::FileStore,
}
462
/// Top-level handle owning the engine instances and their bookkeeping.
///
/// Constructed through [`HanzoBuilder::build`]. On drop, a `Request::Terminate`
/// is sent to every engine.
pub struct Hanzo {
    /// Running engines, keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// State of unloaded models, keyed by name.
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Names of models that are currently being reloaded.
    // NOTE(review): inferred from the field name — confirm against the reload code.
    reloading_models: RwLock<HashSet<String>>,
    /// Id of the default engine; initially the id of the first model.
    default_engine_id: RwLock<Option<String>>,
    /// Alias → model id. Initially maps the pipeline name to the id override
    /// when an override different from the pipeline name was supplied.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log setting supplied through the builder.
    // NOTE(review): presumably a log file path — confirm where `log` is consumed.
    log: Option<String>,
    /// Id of the first model (the override if given, else the pipeline name).
    id: String,
    /// Creation time, in seconds since the Unix epoch.
    creation_time: u64,
    /// Request id counter; starts at 1 (id 0 is used by the warm-up request).
    next_request_id: Mutex<RefCell<usize>>,
}
493
/// Snapshot of everything used to start an engine, kept on the
/// `EngineInstance` so the engine can be recreated after its thread finishes.
#[derive(Clone)]
struct RebootState {
    /// The pipeline the engine runs.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    /// See [`EngineConfig::no_kv_cache`].
    no_kv_cache: bool,
    /// See [`EngineConfig::no_prefix_cache`].
    no_prefix_cache: bool,
    /// See [`EngineConfig::prefix_cache_n`].
    prefix_cache_n: usize,
    /// See [`EngineConfig::disable_eos_stop`].
    disable_eos_stop: bool,
    /// See [`EngineConfig::throughput_logging_enabled`].
    throughput_logging_enabled: bool,
    /// See [`EngineConfig::search_embedding_model`].
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// See [`EngineConfig::search_callback`].
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Tool callbacks, including any registered from MCP or code execution.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// MCP client configuration the engine was created with, if any.
    mcp_client_config: Option<McpClientConfig>,
    /// Loader configuration the engine was created with, if any.
    loader_config: Option<ModelLoaderConfig>,
}
510
/// Lifecycle state of a model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// Rendered as `loaded`.
    Loaded,
    /// Rendered as `unloaded`.
    Unloaded,
    /// Rendered as `reloading`.
    Reloading,
}

/// Renders the status as its lowercase name.
impl std::fmt::Display for ModelStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            ModelStatus::Loaded => "loaded",
            ModelStatus::Unloaded => "unloaded",
            ModelStatus::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
528
/// Errors reported by this crate's model-management API.
// NOTE(review): the per-variant notes are inferred from the variant names;
// the sites that construct them are not visible here — confirm.
#[derive(Debug)]
pub enum HanzoError {
    /// The engines lock was poisoned.
    EnginePoisoned,
    /// A request sender was poisoned.
    SenderPoisoned,
    /// No model with the given name is known.
    ModelNotFound(String),
    /// The named model is currently being reloaded.
    ModelReloading(String),
    /// Reloading failed; carries a message.
    ReloadFailed(String),
    /// The named model has no loader configuration.
    NoLoaderConfig(String),
    /// The named model is already loaded.
    ModelAlreadyLoaded(String),
    /// The named model is already unloaded.
    ModelAlreadyUnloaded(String),
    /// Any other failure; carries a message.
    Other(String),
}

/// `Display` output is exactly the `Debug` rendering of the error.
impl std::fmt::Display for HanzoError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for HanzoError {}
556
#[cfg(feature = "pyo3_macros")]
impl From<HanzoError> for pyo3::PyErr {
    /// Converts the error into a `PyValueError` whose message is the error's
    /// `Debug` rendering.
    fn from(value: HanzoError) -> Self {
        let message = format!("{:?}", value);
        PyValueError::new_err(message)
    }
}
563
/// Builder for [`Hanzo`].
///
/// Cache and EOS options left as `None` fall back to their defaults when the
/// instance is built (`false`, `false`, `16`, `false` respectively).
pub struct HanzoBuilder {
    /// Pipeline the first engine runs.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    /// Optional id to register the model under instead of the pipeline name.
    model_id_override: Option<String>,
    /// Optional log setting, stored on the built instance.
    log: Option<String>,
    /// Disables the KV cache when `Some(true)`.
    no_kv_cache: Option<bool>,
    /// Disables the prefix cache when `Some(true)`.
    no_prefix_cache: Option<bool>,
    /// Prefix cache size parameter.
    prefix_cache_n: Option<usize>,
    /// See [`EngineConfig::disable_eos_stop`].
    disable_eos_stop: Option<bool>,
    /// Enables throughput logging.
    throughput_logging_enabled: bool,
    /// Optional embedding model used for search.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional user-supplied search callback.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks registered so far, keyed by tool name.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration.
    loader_config: Option<ModelLoaderConfig>,
    /// Optional code-execution configuration.
    code_exec_config: Option<CodeExecutionConfig>,
}
584
585impl HanzoBuilder {
586 pub fn new(
590 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
591 method: SchedulerConfig,
592 throughput_logging: bool,
593 search_embedding_model: Option<SearchEmbeddingModel>,
594 ) -> Self {
595 Self {
596 pipeline,
597 method,
598 model_id_override: None,
599 log: None,
600 no_kv_cache: None,
601 no_prefix_cache: None,
602 prefix_cache_n: None,
603 disable_eos_stop: None,
604 throughput_logging_enabled: throughput_logging,
605 search_embedding_model,
606 search_callback: None,
607 tool_callbacks: HashMap::new(),
608 mcp_client_config: None,
609 loader_config: None,
610 code_exec_config: None,
611 }
612 }
613
614 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
616 self.model_id_override = Some(model_id.into());
617 self
618 }
619
620 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
623 self.loader_config = Some(loader_config);
624 self
625 }
626 pub fn with_log(mut self, log: String) -> Self {
627 self.log = Some(log);
628 self
629 }
630 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
631 self.log = log;
632 self
633 }
634 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
635 self.no_kv_cache = Some(no_kv_cache);
636 self
637 }
638 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
639 self.no_prefix_cache = Some(no_prefix_cache);
640 self
641 }
642 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
643 self.prefix_cache_n = Some(prefix_cache_n);
644 self
645 }
646 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
647 self.disable_eos_stop = Some(disable_eos_stop);
648 self
649 }
650
651 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
653 self.search_callback = Some(search_callback);
654 self
655 }
656
657 pub fn with_tool_callback(
659 mut self,
660 name: impl Into<String>,
661 tool_callback: Arc<ToolCallback>,
662 ) -> Self {
663 let name = name.into();
664 self.tool_callbacks.insert(
666 name.clone(),
667 ToolCallbackWithTool {
668 callback: ToolCallbackKind::Text(tool_callback),
669 tool: Tool {
670 tp: ToolType::Function,
671 function: Function {
672 description: None,
673 name,
674 parameters: None,
675 strict: None,
676 },
677 },
678 },
679 );
680 self
681 }
682
683 pub fn with_tool_callback_and_tool(
686 mut self,
687 name: impl Into<String>,
688 tool_callback: Arc<ToolCallback>,
689 tool: Tool,
690 ) -> Self {
691 let name = name.into();
692 self.tool_callbacks.insert(
693 name,
694 ToolCallbackWithTool {
695 callback: ToolCallbackKind::Text(tool_callback),
696 tool,
697 },
698 );
699 self
700 }
701
702 pub fn with_tool_callback_with_tool(
704 mut self,
705 name: impl Into<String>,
706 callback_with_tool: ToolCallbackWithTool,
707 ) -> Self {
708 self.tool_callbacks.insert(name.into(), callback_with_tool);
709 self
710 }
711
712 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
714 self.mcp_client_config = Some(config);
715 self
716 }
717
718 pub fn with_code_execution(mut self, config: CodeExecutionConfig) -> Self {
720 self.code_exec_config = Some(config);
721 self
722 }
723
724 pub async fn build(self) -> Arc<Hanzo> {
725 Hanzo::new(self).await
726 }
727}
728
729impl Drop for Hanzo {
730 fn drop(&mut self) {
731 if let Ok(engines) = self.engines.read() {
733 for (_, engine) in engines.iter() {
734 let _ = engine.sender.try_send(Request::Terminate);
736 }
737 }
738 }
739}
740
741impl Hanzo {
742 fn create_engine_instance(
744 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
745 method: SchedulerConfig,
746 config: EngineConfig,
747 reboot_state: RebootState,
748 ) -> Result<EngineInstance, String> {
749 let (tx, rx) = channel(10_000);
750
751 let pipeline_guard = pipeline.try_lock().unwrap();
752 let category = pipeline_guard.category();
753 let metadata = pipeline_guard.get_metadata();
754 let kind = metadata.kind.clone();
755 let device = pipeline_guard.device();
756 let modalities = metadata.modalities.clone();
757 let max_seq_len = match &category {
758 ModelCategory::Diffusion | ModelCategory::Speech => None,
759 _ => Some(metadata.max_seq_len),
760 };
761 let generation_defaults = pipeline_guard.generation_defaults();
762 let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
763 drop(pipeline_guard);
764
765 let logger = Arc::new(IntervalLogger::new(
766 Duration::from_secs(5),
767 encoder_cache_counters,
768 ));
769 let logger_for_engine = logger.clone();
770
771 info!("Pipeline input modalities are {:?}", &modalities.input);
772 info!("Pipeline output modalities are {:?}", &modalities.output);
773
774 let hanzo_config = HanzoConfig {
775 kind,
776 device,
777 category: category.clone(),
778 modalities,
779 max_seq_len,
780 generation_defaults,
781 };
782
783 let session_store = Arc::new(std::sync::Mutex::new(
786 engine::agentic_session::AgenticSessionStore::new(),
787 ));
788 let session_store_for_engine = Arc::clone(&session_store);
789 let file_store = files::FileStore::new();
790 let file_store_for_engine = file_store.clone();
791
792 let tx_for_engine = tx.clone();
793 let engine_handler = thread::spawn(move || {
794 #[cfg(feature = "metal")]
795 objc::rc::autoreleasepool(move || {
796 let rt = Runtime::new().unwrap();
797 rt.block_on(async move {
798 file_store_for_engine.spawn_cleanup_task();
799 let engine = Engine::new(
800 tx_for_engine,
801 rx,
802 pipeline,
803 method,
804 config.no_kv_cache,
805 config.no_prefix_cache,
806 config.prefix_cache_n,
807 config.disable_eos_stop,
808 config.throughput_logging_enabled,
809 config.search_embedding_model,
810 config.search_callback.clone(),
811 config.tool_callbacks.clone(),
812 logger_for_engine,
813 session_store_for_engine,
814 file_store_for_engine,
815 )
816 .expect("Engine creation failed.");
817 Arc::new(engine).run().await;
818 })
819 });
820
821 #[cfg(not(feature = "metal"))]
822 {
823 let rt = Runtime::new().unwrap();
824 rt.block_on(async move {
825 file_store_for_engine.spawn_cleanup_task();
826 let engine = Engine::new(
827 tx_for_engine,
828 rx,
829 pipeline,
830 method,
831 config.no_kv_cache,
832 config.no_prefix_cache,
833 config.prefix_cache_n,
834 config.disable_eos_stop,
835 config.throughput_logging_enabled,
836 config.search_embedding_model,
837 config.search_callback.clone(),
838 config.tool_callbacks.clone(),
839 logger_for_engine,
840 session_store_for_engine,
841 file_store_for_engine,
842 )
843 .expect("Engine creation failed.");
844 Arc::new(engine).run().await;
845 })
846 }
847 });
848
849 Ok(EngineInstance {
850 sender: tx,
851 engine_handler,
852 reboot_state,
853 config: hanzo_config,
854 category,
855 logger,
856 session_store,
857 file_store,
858 })
859 }
860
861 async fn init_external_tool_callbacks(
865 #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))] pipeline: &Arc<
866 tokio::sync::Mutex<dyn Pipeline>,
867 >,
868 tool_callbacks: &mut tools::ToolCallbacksWithTools,
869 mcp_client_config: Option<&McpClientConfig>,
870 #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
871 code_exec_config: Option<&CodeExecutionConfig>,
872 ) {
873 if let Some(config) = mcp_client_config {
874 let mut mcp_client = McpClient::new(config.clone());
875 let total_servers = config.servers.len();
876
877 match mcp_client.initialize().await {
878 Ok(()) => {
879 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
880 let tools_count = mcp_callbacks_with_tools.len();
881
882 for (name, callback_with_tool) in mcp_callbacks_with_tools {
883 tool_callbacks.insert(name.clone(), callback_with_tool.clone());
884 }
885
886 if tools_count == 0 {
887 warn!(
888 "MCP client initialized but no tools were registered from {} servers",
889 total_servers
890 );
891 } else {
892 info!(
893 "MCP client initialized successfully with {} tools from {} servers",
894 tools_count, total_servers
895 );
896 }
897 }
898 Err(e) => {
899 warn!(
900 "Failed to initialize MCP client with {} configured servers: {}",
901 total_servers, e
902 );
903 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
904 }
905 }
906 }
907
908 #[cfg(feature = "code-execution")]
909 if let Some(code_exec_cfg) = code_exec_config {
910 let approval_callback = code_exec_cfg.approval_callback.as_ref().map(|callback| {
911 let callback = Arc::clone(callback);
912 Arc::new(move |approval: &hanzo_code_exec::CodeExecutionApproval| {
913 let approval = CodeExecutionApproval {
914 approval_id: approval.approval_id.clone(),
915 session_id: approval.session_id.clone(),
916 code: approval.code.clone(),
917 outputs: approval.outputs.clone(),
918 working_directory: approval.working_directory.clone(),
919 };
920 callback(&approval)
921 }) as Arc<hanzo_code_exec::CodeExecutionApprovalCallback>
922 });
923 let exec_config = hanzo_code_exec::CodeExecutionConfig {
924 python_path: code_exec_cfg.python_path.clone(),
925 timeout_secs: code_exec_cfg.timeout_secs,
926 working_directory: code_exec_cfg.working_directory.clone(),
927 sandbox_policy: code_exec_cfg.sandbox_policy.clone(),
928 permission: match code_exec_cfg.permission {
929 CodeExecutionPermission::Auto => hanzo_code_exec::CodeExecutionPermission::Auto,
930 CodeExecutionPermission::Ask => hanzo_code_exec::CodeExecutionPermission::Ask,
931 CodeExecutionPermission::Deny => hanzo_code_exec::CodeExecutionPermission::Deny,
932 },
933 approval_callback,
934 };
935 match hanzo_code_exec::CodeExecutionManager::new(exec_config).await {
936 Ok(manager) => {
937 let input_modalities: Vec<hanzo_code_exec::InputModality> = {
938 let pipe = get_mut_arcmutex!(pipeline);
939 pipe.get_metadata()
940 .modalities
941 .input
942 .iter()
943 .filter_map(|m| match m {
944 pipeline::SupportedModality::Text => {
945 Some(hanzo_code_exec::InputModality::Text)
946 }
947 pipeline::SupportedModality::Vision => {
948 Some(hanzo_code_exec::InputModality::Vision)
949 }
950 pipeline::SupportedModality::Audio => {
951 Some(hanzo_code_exec::InputModality::Audio)
952 }
953 pipeline::SupportedModality::Video => {
954 Some(hanzo_code_exec::InputModality::Video)
955 }
956 _ => None,
957 })
958 .collect()
959 };
960 let effective = manager.effective_protection();
961 let network = manager.network_mode();
962 let callbacks = manager.get_tool_callbacks(&input_modalities);
963 let count = callbacks.len();
964 for (name, cb) in callbacks {
965 tool_callbacks.insert(name, cb);
966 }
967 warn!("============================================================");
968 warn!(" CODE EXECUTION IS ENABLED");
969 warn!(" The model can execute arbitrary Python code on this machine.");
970 if effective.any() {
971 let fs = if effective.fs_isolated {
972 "workdir + system libs only"
973 } else {
974 "NOT restricted"
975 };
976 let net = if effective.network_isolated {
977 match network {
978 Some(hanzo_sandbox::NetworkMode::None) => "denied",
979 Some(hanzo_sandbox::NetworkMode::Loopback) => "loopback only",
980 _ => "NOT restricted",
981 }
982 } else {
983 "NOT restricted"
984 };
985 warn!(
986 " Sandbox: on. Filesystem: {fs}. Network: {net}. rlimits: {}.",
987 if effective.rlimits_applied {
988 "applied"
989 } else {
990 "not applied"
991 }
992 );
993 if !effective.fs_isolated || !effective.network_isolated {
994 warn!(" Some layers are inactive on this host. Use --sandbox on to make missing layers a hard error.");
995 }
996 } else {
997 warn!(" Sandbox: OFF. Network and filesystem are NOT restricted.");
998 warn!(" Pass a sandbox_policy (or --sandbox on at the CLI) to enable isolation.");
999 }
1000 warn!(" See: https://hanzoai.github.io/engine/reference/sandbox/");
1001 warn!("============================================================");
1002 info!("Code execution initialized with {count} tools");
1003 }
1004 Err(e) => {
1005 warn!("Failed to initialize code execution: {e}");
1006 warn!("Continuing without code execution functionality.");
1007 }
1008 }
1009 }
1010 }
1011
1012 async fn new(config: HanzoBuilder) -> Arc<Self> {
1013 info!("git revision: {HANZO_GIT_REVISION}");
1014 let HanzoBuilder {
1015 pipeline,
1016 method,
1017 model_id_override,
1018 log,
1019 no_kv_cache,
1020 no_prefix_cache,
1021 prefix_cache_n,
1022 disable_eos_stop,
1023 throughput_logging_enabled,
1024 search_embedding_model,
1025 search_callback,
1026 mut tool_callbacks,
1027 mcp_client_config,
1028 loader_config,
1029 #[cfg_attr(not(feature = "code-execution"), allow(unused_variables))]
1030 code_exec_config,
1031 } = config;
1032
1033 hanzo_quant::cublaslt::maybe_init_cublas_lt_wrapper(get_mut_arcmutex!(pipeline).device());
1034
1035 let no_kv_cache = no_kv_cache.unwrap_or(false);
1036 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
1037 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
1038 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
1039
1040 Self::init_external_tool_callbacks(
1041 &pipeline,
1042 &mut tool_callbacks,
1043 mcp_client_config.as_ref(),
1044 code_exec_config.as_ref(),
1045 )
1046 .await;
1047
1048 let reboot_state = RebootState {
1049 pipeline: pipeline.clone(),
1050 method: method.clone(),
1051 no_kv_cache,
1052 no_prefix_cache,
1053 prefix_cache_n,
1054 disable_eos_stop,
1055 throughput_logging_enabled,
1056 search_embedding_model,
1057 search_callback: search_callback.clone(),
1058 tool_callbacks: tool_callbacks.clone(),
1059 mcp_client_config: mcp_client_config.clone(),
1060 loader_config,
1061 };
1062
1063 let engine_config = EngineConfig {
1064 no_kv_cache,
1065 no_prefix_cache,
1066 prefix_cache_n,
1067 disable_eos_stop,
1068 throughput_logging_enabled,
1069 search_embedding_model,
1070 search_callback,
1071 tool_callbacks,
1072 };
1073
1074 let engine_instance =
1075 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
1076 .expect("Failed to create engine instance");
1077
1078 let pipeline_name = pipeline.try_lock().unwrap().name();
1079 let (id, alias_map) = match model_id_override {
1080 Some(override_id) => {
1081 let mut alias_map = HashMap::new();
1082 if override_id != pipeline_name {
1083 alias_map.insert(pipeline_name.clone(), override_id.clone());
1084 }
1085 (override_id, alias_map)
1086 }
1087 None => (pipeline_name.clone(), HashMap::new()),
1088 };
1089
1090 if distributed::is_daemon() {
1091 let request_sender = engine_instance.sender.clone();
1092
1093 if cfg!(feature = "ring") {
1094 distributed::ring_daemon_replicator(request_sender);
1096 } else {
1097 distributed::nccl_daemon_replicator(request_sender);
1099 }
1100
1101 #[allow(clippy::empty_loop)]
1102 loop {}
1103 }
1104
1105 let is_multi_threaded = tokio::runtime::Handle::try_current()
1107 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
1108
1109 if !distributed::is_daemon()
1111 && is_multi_threaded
1112 && matches!(
1113 engine_instance.category,
1114 ModelCategory::Text | ModelCategory::Multimodal { .. }
1115 )
1116 {
1117 let clone_sender = engine_instance.sender.clone();
1118 tokio::task::block_in_place(|| {
1119 let (tx, mut rx) = channel(1);
1120 let req = Request::Normal(Box::new(NormalRequest {
1121 id: 0,
1122 messages: RequestMessage::Completion {
1123 text: "hello".to_string(),
1124 echo_prompt: false,
1125 best_of: None,
1126 },
1127 sampling_params: SamplingParams {
1128 max_len: Some(1),
1129 ..SamplingParams::deterministic()
1130 },
1131 response: tx,
1132 return_logprobs: false,
1133 is_streaming: false,
1134 constraint: Constraint::None,
1135 suffix: None,
1136 tool_choice: None,
1137 tools: None,
1138 logits_processors: None,
1139 return_raw_logits: false,
1140 web_search_options: None,
1141 enable_code_execution: false,
1142 code_execution_permission: None,
1143 code_execution_approval_notifier: None,
1144 agent_permission: None,
1145 agent_approval_handler: None,
1146 agent_approval_notifier: None,
1147 max_tool_rounds: None,
1148 tool_dispatch_url: None,
1149 model_id: None,
1150 truncate_sequence: false,
1151 session_id: None,
1152 files: None,
1153 }));
1154 debug!("Beginning dummy run.");
1155 let start = Instant::now();
1156 clone_sender.blocking_send(req).unwrap();
1157
1158 let mut received_any = false;
1160 while let Some(_resp) = rx.blocking_recv() {
1161 received_any = true;
1162 }
1163
1164 if received_any {
1165 let end = Instant::now();
1166 debug!(
1167 "Dummy run completed in {}s.",
1168 end.duration_since(start).as_secs_f64()
1169 );
1170 } else {
1171 warn!("Dummy run failed!");
1172 }
1173 });
1174
1175 engine_instance.logger.reset();
1177 }
1178
1179 let mut engines = HashMap::new();
1181 engines.insert(id.clone(), engine_instance);
1182
1183 Arc::new(Self {
1184 engines: RwLock::new(engines),
1185 unloaded_models: RwLock::new(HashMap::new()),
1186 reloading_models: RwLock::new(HashSet::new()),
1187 default_engine_id: RwLock::new(Some(id.clone())),
1188 model_aliases: RwLock::new(alias_map),
1189 log,
1190 id,
1191 creation_time: SystemTime::now()
1192 .duration_since(UNIX_EPOCH)
1193 .expect("Time travel has occurred!")
1194 .as_secs(),
1195 next_request_id: Mutex::new(RefCell::new(1)),
1196 })
1197 }
1198
1199 fn reboot_engine(&self, model_id: &str) -> Result<(), HanzoError> {
1201 let mut engines = self.engines.write().map_err(|_| {
1202 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
1203 HanzoError::EnginePoisoned
1204 })?;
1205
1206 if let Some(engine_instance) = engines.get(model_id) {
1207 if !engine_instance.engine_handler.is_finished() {
1208 tracing::info!("Engine {} already running, returning ok", model_id);
1209 return Ok(());
1210 }
1211
1212 let reboot_state = engine_instance.reboot_state.clone();
1213 let engine_config = EngineConfig {
1214 no_kv_cache: reboot_state.no_kv_cache,
1215 no_prefix_cache: reboot_state.no_prefix_cache,
1216 prefix_cache_n: reboot_state.prefix_cache_n,
1217 disable_eos_stop: reboot_state.disable_eos_stop,
1218 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
1219 search_embedding_model: reboot_state.search_embedding_model,
1220 search_callback: reboot_state.search_callback.clone(),
1221 tool_callbacks: reboot_state.tool_callbacks.clone(),
1222 };
1223 let new_engine_instance = Self::create_engine_instance(
1224 reboot_state.pipeline.clone(),
1225 reboot_state.method.clone(),
1226 engine_config,
1227 reboot_state,
1228 )
1229 .map_err(|e| {
1230 tracing::error!("Failed to create new engine instance: {}", e);
1231 HanzoError::EnginePoisoned
1232 })?;
1233
1234 engines.insert(model_id.to_string(), new_engine_instance);
1235 tracing::info!("Successfully rebooted engine {}", model_id);
1236 Ok(())
1237 } else {
1238 Err(HanzoError::EnginePoisoned)
1239 }
1240 }
1241
1242 fn engine_dead(&self, model_id: &str) -> Result<bool, HanzoError> {
1243 let engines = self.engines.read().map_err(|_| {
1244 tracing::warn!("Couldn't get read lock on engines!");
1245 HanzoError::EnginePoisoned
1246 })?;
1247
1248 if let Some(engine_instance) = engines.get(model_id) {
1249 Ok(engine_instance.engine_handler.is_finished())
1250 } else {
1251 Err(HanzoError::EnginePoisoned)
1252 }
1253 }
1254
1255 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, HanzoError> {
1258 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1259
1260 let is_loaded = {
1262 let engines = self
1263 .engines
1264 .read()
1265 .map_err(|_| HanzoError::SenderPoisoned)?;
1266 engines.contains_key(&resolved_model_id)
1267 };
1268
1269 if is_loaded {
1270 if self.engine_dead(&resolved_model_id)? {
1272 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
1273 self.reboot_engine(&resolved_model_id)?
1274 }
1275
1276 let engines = self
1277 .engines
1278 .read()
1279 .map_err(|_| HanzoError::SenderPoisoned)?;
1280 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1281 return Ok(engine_instance.sender.clone());
1282 }
1283 }
1284
1285 let is_unloaded = {
1287 let unloaded = self
1288 .unloaded_models
1289 .read()
1290 .map_err(|_| HanzoError::EnginePoisoned)?;
1291 unloaded.contains_key(&resolved_model_id)
1292 };
1293
1294 if is_unloaded {
1295 tracing::info!(
1296 "Model {} is unloaded, triggering auto-reload",
1297 resolved_model_id
1298 );
1299 self.reload_model_blocking(&resolved_model_id)?;
1300
1301 let engines = self
1303 .engines
1304 .read()
1305 .map_err(|_| HanzoError::SenderPoisoned)?;
1306 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1307 return Ok(engine_instance.sender.clone());
1308 }
1309 }
1310
1311 Err(HanzoError::ModelNotFound(resolved_model_id))
1312 }
1313
1314 pub fn find_file(&self, id: &str) -> Option<Arc<files::File>> {
1316 let engines = self.engines.read().ok()?;
1317 for instance in engines.values() {
1318 if let Some(f) = instance.file_store.get(id) {
1319 return Some(f);
1320 }
1321 }
1322 None
1323 }
1324
1325 pub fn list_files(&self) -> Vec<Arc<files::File>> {
1327 let mut out = Vec::new();
1328 let Ok(engines) = self.engines.read() else {
1329 return out;
1330 };
1331 for instance in engines.values() {
1332 out.extend(instance.file_store.list_all());
1333 }
1334 out
1335 }
1336
1337 pub fn remove_file(&self, id: &str) -> bool {
1339 let Ok(engines) = self.engines.read() else {
1340 return false;
1341 };
1342 for instance in engines.values() {
1343 if instance.file_store.remove(id) {
1344 return true;
1345 }
1346 }
1347 false
1348 }
1349
1350 pub fn get_session_store(
1352 &self,
1353 model_id: Option<&str>,
1354 ) -> Result<Arc<std::sync::Mutex<engine::agentic_session::AgenticSessionStore>>, HanzoError>
1355 {
1356 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1357 let engines = self
1358 .engines
1359 .read()
1360 .map_err(|_| HanzoError::SenderPoisoned)?;
1361 engines
1362 .get(&resolved_model_id)
1363 .map(|e| Arc::clone(&e.session_store))
1364 .ok_or(HanzoError::ModelNotFound(resolved_model_id))
1365 }
1366
1367 fn get_file_store(&self, model_id: Option<&str>) -> Result<files::FileStore, HanzoError> {
1368 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1369 let engines = self
1370 .engines
1371 .read()
1372 .map_err(|_| HanzoError::SenderPoisoned)?;
1373 engines
1374 .get(&resolved_model_id)
1375 .map(|e| e.file_store.clone())
1376 .ok_or(HanzoError::ModelNotFound(resolved_model_id))
1377 }
1378
1379 pub fn export_session(
1381 &self,
1382 model_id: Option<&str>,
1383 session_id: &str,
1384 ) -> Result<Option<engine::agentic_session::SerializedSession>, HanzoError> {
1385 let store = self.get_session_store(model_id)?;
1386 let exported = {
1387 let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1388 guard
1389 .export(session_id)
1390 .map_err(|e| HanzoError::Other(e.to_string()))?
1391 };
1392 let Some(mut session) = exported else {
1393 return Ok(None);
1394 };
1395 let file_store = self.get_file_store(model_id)?;
1396 session.files = file_store
1397 .list_for_session(session_id)
1398 .into_iter()
1399 .map(|arc| (*arc).clone())
1400 .collect();
1401 Ok(Some(session))
1402 }
1403
1404 pub fn import_session(
1406 &self,
1407 model_id: Option<&str>,
1408 session_id: String,
1409 session: engine::agentic_session::SerializedSession,
1410 ) -> Result<(), HanzoError> {
1411 let files = session.files.clone();
1412 let store = self.get_session_store(model_id)?;
1413 {
1414 let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1415 guard
1416 .import(session_id.clone(), session)
1417 .map_err(|e| HanzoError::Other(e.to_string()))?;
1418 }
1419 let file_store = self.get_file_store(model_id)?;
1420 for f in files {
1421 file_store.insert(f, Some(session_id.clone()));
1422 }
1423 Ok(())
1424 }
1425
1426 pub fn fork_session(
1430 &self,
1431 model_id: Option<&str>,
1432 src_session_id: &str,
1433 dest_session_id: String,
1434 num_turns: usize,
1435 ) -> Result<(), HanzoError> {
1436 let store = self.get_session_store(model_id)?;
1437 let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1438 guard
1439 .fork(src_session_id, dest_session_id, num_turns)
1440 .map_err(|e| HanzoError::Other(e.to_string()))
1441 }
1442
1443 pub fn delete_session(
1445 &self,
1446 model_id: Option<&str>,
1447 session_id: &str,
1448 ) -> Result<bool, HanzoError> {
1449 let store = self.get_session_store(model_id)?;
1450 let mut guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1451 Ok(guard.delete(session_id))
1452 }
1453
1454 pub fn list_session_ids(&self, model_id: Option<&str>) -> Result<Vec<String>, HanzoError> {
1456 let store = self.get_session_store(model_id)?;
1457 let guard = store.lock().map_err(|_| HanzoError::SenderPoisoned)?;
1458 Ok(guard.list_ids())
1459 }
1460
1461 pub fn get_id(&self) -> String {
1462 self.id.clone()
1463 }
1464
1465 pub fn get_creation_time(&self) -> u64 {
1466 self.creation_time
1467 }
1468
1469 fn resolve_alias(&self, model_id: &str) -> Result<String, HanzoError> {
1470 let aliases = self
1471 .model_aliases
1472 .read()
1473 .map_err(|_| HanzoError::SenderPoisoned)?;
1474 if let Some(primary_id) = aliases.get(model_id) {
1475 Ok(primary_id.clone())
1476 } else {
1477 Ok(model_id.to_string())
1478 }
1479 }
1480
1481 fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, HanzoError> {
1482 match model_id {
1483 Some(id) => self.resolve_alias(id),
1484 None => {
1485 let default_lock = self
1486 .default_engine_id
1487 .read()
1488 .map_err(|_| HanzoError::SenderPoisoned)?;
1489 Ok(default_lock
1490 .as_ref()
1491 .ok_or(HanzoError::EnginePoisoned)?
1492 .clone())
1493 }
1494 }
1495 }
1496
1497 pub fn register_model_alias(
1499 &self,
1500 alias: impl Into<String>,
1501 model_id: &str,
1502 ) -> Result<(), String> {
1503 let alias = alias.into();
1504 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1505
1506 if alias == resolved_model_id {
1507 return Ok(());
1508 }
1509
1510 let reloading = self
1511 .reloading_models
1512 .read()
1513 .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1514 let model_reloading = reloading.contains(&resolved_model_id);
1515 let alias_conflict = reloading.contains(&alias);
1516 drop(reloading);
1517
1518 let engines = self
1519 .engines
1520 .read()
1521 .map_err(|_| "Failed to acquire read lock on engines")?;
1522 let model_loaded = engines.contains_key(&resolved_model_id);
1523 let alias_conflict = alias_conflict || engines.contains_key(&alias);
1524 drop(engines);
1525
1526 let unloaded = self
1527 .unloaded_models
1528 .read()
1529 .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1530 let model_unloaded = unloaded.contains_key(&resolved_model_id);
1531 let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
1532 drop(unloaded);
1533
1534 if !(model_loaded || model_unloaded || model_reloading) {
1535 return Err(format!("Model {resolved_model_id} not found"));
1536 }
1537
1538 if alias_conflict {
1539 return Err(format!(
1540 "Alias '{}' conflicts with an existing model ID",
1541 alias
1542 ));
1543 }
1544
1545 let mut aliases = self
1546 .model_aliases
1547 .write()
1548 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1549 if let Some(existing) = aliases.get(&alias) {
1550 if existing == &resolved_model_id {
1551 return Ok(());
1552 }
1553 return Err(format!(
1554 "Alias '{}' is already assigned to model '{}'",
1555 alias, existing
1556 ));
1557 }
1558 aliases.insert(alias, resolved_model_id);
1559 Ok(())
1560 }
1561
1562 pub fn model_exists(&self, model_id: &str) -> Result<bool, HanzoError> {
1564 let resolved_model_id = self.resolve_alias(model_id)?;
1565
1566 let reloading = self
1567 .reloading_models
1568 .read()
1569 .map_err(|_| HanzoError::EnginePoisoned)?;
1570 if reloading.contains(&resolved_model_id) {
1571 return Ok(true);
1572 }
1573 drop(reloading);
1574
1575 let engines = self
1576 .engines
1577 .read()
1578 .map_err(|_| HanzoError::EnginePoisoned)?;
1579 if engines.contains_key(&resolved_model_id) {
1580 return Ok(true);
1581 }
1582 drop(engines);
1583
1584 let unloaded = self
1585 .unloaded_models
1586 .read()
1587 .map_err(|_| HanzoError::EnginePoisoned)?;
1588 if unloaded.contains_key(&resolved_model_id) {
1589 return Ok(true);
1590 }
1591
1592 Ok(false)
1593 }
1594
1595 pub fn get_logger(&self, model_id: Option<&str>) -> Result<Arc<IntervalLogger>, HanzoError> {
1597 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1598
1599 let engines = self
1600 .engines
1601 .read()
1602 .map_err(|_| HanzoError::SenderPoisoned)?;
1603 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1604 Ok(engine_instance.logger.clone())
1605 } else {
1606 Err(HanzoError::EnginePoisoned)
1607 }
1608 }
1609
1610 pub fn get_model_category(&self, model_id: Option<&str>) -> Result<ModelCategory, HanzoError> {
1612 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1613
1614 let engines = self
1615 .engines
1616 .read()
1617 .map_err(|_| HanzoError::SenderPoisoned)?;
1618 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1619 Ok(engine_instance.category.clone())
1620 } else {
1621 Err(HanzoError::EnginePoisoned)
1622 }
1623 }
1624
1625 pub fn max_sequence_length(&self, model_id: Option<&str>) -> Result<Option<usize>, HanzoError> {
1627 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1628
1629 let engines = self
1630 .engines
1631 .read()
1632 .map_err(|_| HanzoError::SenderPoisoned)?;
1633 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1634 Ok(engine_instance.config.max_seq_len)
1635 } else {
1636 Err(HanzoError::EnginePoisoned)
1637 }
1638 }
1639
1640 pub fn next_request_id(&self) -> usize {
1641 let l = self.next_request_id.lock().unwrap();
1642 let last = &mut *l.borrow_mut();
1643 let last_v = *last;
1644 *last += 1;
1645 last_v
1646 }
1647
1648 pub async fn add_model(
1650 &self,
1651 model_id: String,
1652 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
1653 method: SchedulerConfig,
1654 config: AddModelConfig,
1655 ) -> Result<(), String> {
1656 {
1657 let reloading = self
1658 .reloading_models
1659 .read()
1660 .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1661 if reloading.contains(&model_id) {
1662 return Err(format!("Model {model_id} is currently reloading"));
1663 }
1664 }
1665 {
1666 let engines = self
1667 .engines
1668 .read()
1669 .map_err(|_| "Failed to acquire read lock on engines")?;
1670 if engines.contains_key(&model_id) {
1671 return Err(format!("Model {model_id} already exists"));
1672 }
1673 }
1674 {
1675 let unloaded = self
1676 .unloaded_models
1677 .read()
1678 .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1679 if unloaded.contains_key(&model_id) {
1680 return Err(format!("Model {model_id} already exists (unloaded)"));
1681 }
1682 }
1683 {
1684 let aliases = self
1685 .model_aliases
1686 .read()
1687 .map_err(|_| "Failed to acquire read lock on model_aliases")?;
1688 if aliases.contains_key(&model_id) {
1689 return Err(format!(
1690 "Model ID '{}' conflicts with an existing alias",
1691 model_id
1692 ));
1693 }
1694 }
1695
1696 let mut engine_config = config.engine_config;
1697 Self::init_external_tool_callbacks(
1698 &pipeline,
1699 &mut engine_config.tool_callbacks,
1700 config.mcp_client_config.as_ref(),
1701 config.code_exec_config.as_ref(),
1702 )
1703 .await;
1704
1705 let reboot_state = RebootState {
1706 pipeline: pipeline.clone(),
1707 method: method.clone(),
1708 no_kv_cache: engine_config.no_kv_cache,
1709 no_prefix_cache: engine_config.no_prefix_cache,
1710 prefix_cache_n: engine_config.prefix_cache_n,
1711 disable_eos_stop: engine_config.disable_eos_stop,
1712 throughput_logging_enabled: engine_config.throughput_logging_enabled,
1713 search_embedding_model: engine_config.search_embedding_model,
1714 search_callback: engine_config.search_callback.clone(),
1715 tool_callbacks: engine_config.tool_callbacks.clone(),
1716 mcp_client_config: config.mcp_client_config.clone(),
1717 loader_config: config.loader_config.clone(),
1718 };
1719
1720 let engine_instance =
1721 Self::create_engine_instance(pipeline, method, engine_config, reboot_state)?;
1722
1723 let mut engines = self
1724 .engines
1725 .write()
1726 .map_err(|_| "Failed to acquire write lock on engines")?;
1727 engines.insert(model_id.clone(), engine_instance);
1728
1729 if engines.len() == 1 {
1731 let mut default_lock = self
1732 .default_engine_id
1733 .write()
1734 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1735 *default_lock = Some(model_id.clone());
1736 info!("First model added, setting '{}' as default", model_id);
1737 }
1738
1739 Ok(())
1740 }
1741
1742 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1744 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1745 let mut engines = self
1746 .engines
1747 .write()
1748 .map_err(|_| "Failed to acquire write lock on engines")?;
1749
1750 if engines.len() <= 1 {
1751 return Err("Cannot remove the last model from Hanzo".to_string());
1752 }
1753
1754 if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1755 let _ = engine_instance.sender.blocking_send(Request::Terminate);
1757
1758 let mut default_lock = self
1760 .default_engine_id
1761 .write()
1762 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1763 if let Some(ref default_id) = *default_lock {
1764 if default_id == &resolved_model_id {
1765 *default_lock = engines.keys().next().cloned();
1767 }
1768 }
1769 drop(default_lock);
1770 drop(engines);
1771
1772 let mut aliases = self
1774 .model_aliases
1775 .write()
1776 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1777 aliases.retain(|_, target| target != &resolved_model_id);
1778
1779 Ok(())
1780 } else {
1781 Err(format!("Model {resolved_model_id} not found"))
1782 }
1783 }
1784
1785 pub fn list_models(&self) -> Result<Vec<String>, String> {
1787 let engines = self
1788 .engines
1789 .read()
1790 .map_err(|_| "Failed to acquire read lock on engines")?;
1791 Ok(engines.keys().cloned().collect())
1792 }
1793
1794 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1796 let default_lock = self
1797 .default_engine_id
1798 .read()
1799 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1800 Ok(default_lock.clone())
1801 }
1802
1803 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1805 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1806 let engines = self
1807 .engines
1808 .read()
1809 .map_err(|_| "Failed to acquire read lock on engines")?;
1810 if !engines.contains_key(&resolved_model_id) {
1811 return Err(format!("Model {resolved_model_id} not found"));
1812 }
1813 drop(engines);
1814
1815 let mut default_lock = self
1816 .default_engine_id
1817 .write()
1818 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1819 let old_default = default_lock.clone();
1820 *default_lock = Some(resolved_model_id.clone());
1821
1822 info!(
1824 "Default model changed: {:?} -> {:?}",
1825 old_default, resolved_model_id
1826 );
1827
1828 Ok(())
1829 }
1830
1831 pub fn send_request(&self, mut request: Request) -> Result<(), HanzoError> {
1833 let model_id = match &mut request {
1834 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1835 _ => None, };
1837
1838 let sender = self.get_sender(model_id)?;
1839 sender
1840 .blocking_send(request)
1841 .map_err(|_| HanzoError::SenderPoisoned)
1842 }
1843
1844 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1845 if let Some(file) = &this.log {
1846 let mut f = OpenOptions::new()
1847 .append(true)
1848 .create(true) .open(file)
1850 .expect("Unable to open file");
1851 let time = chrono::offset::Local::now();
1852 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1853 .expect("Unable to write data");
1854 }
1855 }
1856
1857 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1858 if let Some(file) = &this.log {
1859 let mut f = OpenOptions::new()
1860 .append(true)
1861 .create(true) .open(file)
1863 .expect("Unable to open file");
1864 let time = chrono::offset::Local::now();
1865 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1866 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1867 .expect("Unable to write data");
1868 }
1869 }
1870
1871 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1872 if let Some(file) = &this.log {
1873 let mut f = OpenOptions::new()
1874 .append(true)
1875 .create(true) .open(file)
1877 .expect("Unable to open file");
1878 let time = chrono::offset::Local::now();
1879 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1880 .expect("Unable to write data");
1881 }
1882 }
1883
1884 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1886 let resolved_model_id = self
1887 .resolve_alias_or_default(model_id)
1888 .map_err(|e| e.to_string())?;
1889
1890 let engines = self
1891 .engines
1892 .read()
1893 .map_err(|_| "Failed to acquire read lock on engines")?;
1894 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1895 Ok(engine_instance.reboot_state.tool_callbacks.len())
1896 } else {
1897 Err(format!("Model {resolved_model_id} not found"))
1898 }
1899 }
1900
1901 pub fn list_mcp_tools(
1903 &self,
1904 model_id: Option<&str>,
1905 ) -> Result<Vec<(String, Option<String>)>, String> {
1906 let resolved_model_id = self
1907 .resolve_alias_or_default(model_id)
1908 .map_err(|e| e.to_string())?;
1909
1910 let engines = self
1911 .engines
1912 .read()
1913 .map_err(|_| "Failed to acquire read lock on engines")?;
1914 let engine_instance = engines
1915 .get(&resolved_model_id)
1916 .ok_or_else(|| format!("Model {resolved_model_id} not found"))?;
1917
1918 let mut tools: Vec<(String, Option<String>)> = engine_instance
1919 .reboot_state
1920 .tool_callbacks
1921 .values()
1922 .filter(|cb| {
1923 let name = &cb.tool.function.name;
1924 !search::search_tool_called(name) && {
1926 #[cfg(feature = "code-execution")]
1927 {
1928 !hanzo_code_exec::code_exec_tool_called(name)
1929 }
1930 #[cfg(not(feature = "code-execution"))]
1931 {
1932 true
1933 }
1934 }
1935 })
1936 .map(|cb| {
1937 (
1938 cb.tool.function.name.clone(),
1939 cb.tool.function.description.clone(),
1940 )
1941 })
1942 .collect();
1943 tools.sort_by(|a, b| a.0.cmp(&b.0));
1944 Ok(tools)
1945 }
1946
1947 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1949 let resolved_model_id = self
1950 .resolve_alias_or_default(model_id)
1951 .map_err(|e| e.to_string())?;
1952
1953 let engines = self
1954 .engines
1955 .read()
1956 .map_err(|_| "Failed to acquire read lock on engines")?;
1957 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1958 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1959 } else {
1960 Err(format!("Model {resolved_model_id} not found"))
1961 }
1962 }
1963
1964 pub fn config(&self, model_id: Option<&str>) -> Result<HanzoConfig, String> {
1966 let resolved_model_id = self
1967 .resolve_alias_or_default(model_id)
1968 .map_err(|e| e.to_string())?;
1969
1970 let engines = self
1971 .engines
1972 .read()
1973 .map_err(|_| "Failed to acquire read lock on engines")?;
1974 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1975 Ok(engine_instance.config.clone())
1976 } else {
1977 Err(format!("Model {resolved_model_id} not found"))
1978 }
1979 }
1980
1981 pub fn unload_model(&self, model_id: &str) -> Result<(), HanzoError> {
1988 let resolved_model_id = self.resolve_alias(model_id)?;
1989 {
1991 let unloaded = self
1992 .unloaded_models
1993 .read()
1994 .map_err(|_| HanzoError::EnginePoisoned)?;
1995 if unloaded.contains_key(&resolved_model_id) {
1996 return Err(HanzoError::ModelAlreadyUnloaded(resolved_model_id.clone()));
1997 }
1998 }
1999
2000 let mut engines = self
2002 .engines
2003 .write()
2004 .map_err(|_| HanzoError::EnginePoisoned)?;
2005
2006 let engine_instance = engines
2007 .remove(&resolved_model_id)
2008 .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))?;
2009
2010 let loader_config = engine_instance
2012 .reboot_state
2013 .loader_config
2014 .clone()
2015 .ok_or_else(|| HanzoError::NoLoaderConfig(resolved_model_id.clone()))?;
2016
2017 let unloaded_state = UnloadedModelState {
2019 loader_config,
2020 scheduler_config: engine_instance.reboot_state.method.clone(),
2021 engine_config: EngineConfig {
2022 no_kv_cache: engine_instance.reboot_state.no_kv_cache,
2023 no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
2024 prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
2025 disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
2026 throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
2027 search_embedding_model: engine_instance.reboot_state.search_embedding_model,
2028 search_callback: engine_instance.reboot_state.search_callback.clone(),
2029 tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
2030 },
2031 mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
2032 category: engine_instance.category.clone(),
2033 hanzo_config: engine_instance.config.clone(),
2034 };
2035
2036 let _ = engine_instance.sender.try_send(Request::Terminate);
2038
2039 drop(engines);
2040
2041 let mut unloaded = self
2043 .unloaded_models
2044 .write()
2045 .map_err(|_| HanzoError::EnginePoisoned)?;
2046 unloaded.insert(resolved_model_id.to_string(), unloaded_state);
2047
2048 let mut default_lock = self
2050 .default_engine_id
2051 .write()
2052 .map_err(|_| HanzoError::EnginePoisoned)?;
2053 if let Some(ref default_id) = *default_lock {
2054 if default_id == &resolved_model_id {
2055 let engines = self
2057 .engines
2058 .read()
2059 .map_err(|_| HanzoError::EnginePoisoned)?;
2060 *default_lock = engines.keys().next().cloned();
2061 }
2062 }
2063
2064 info!("Model {} unloaded successfully", resolved_model_id);
2065 Ok(())
2066 }
2067
2068 pub async fn reload_model(&self, model_id: &str) -> Result<(), HanzoError> {
2071 let resolved_model_id = self.resolve_alias(model_id)?;
2072 {
2074 let reloading = self
2075 .reloading_models
2076 .read()
2077 .map_err(|_| HanzoError::EnginePoisoned)?;
2078 if reloading.contains(&resolved_model_id) {
2079 return Err(HanzoError::ModelReloading(resolved_model_id.clone()));
2080 }
2081 }
2082
2083 {
2085 let mut reloading = self
2086 .reloading_models
2087 .write()
2088 .map_err(|_| HanzoError::EnginePoisoned)?;
2089 reloading.insert(resolved_model_id.clone());
2090 }
2091
2092 let unloaded_state = {
2094 let unloaded = self
2095 .unloaded_models
2096 .read()
2097 .map_err(|_| HanzoError::EnginePoisoned)?;
2098 unloaded
2099 .get(&resolved_model_id)
2100 .cloned()
2101 .ok_or_else(|| HanzoError::ModelNotFound(resolved_model_id.clone()))?
2102 };
2103
2104 let result = self
2106 .do_reload_model(&resolved_model_id, unloaded_state)
2107 .await;
2108
2109 {
2111 let mut reloading = self
2112 .reloading_models
2113 .write()
2114 .map_err(|_| HanzoError::EnginePoisoned)?;
2115 reloading.remove(&resolved_model_id);
2116 }
2117
2118 result
2119 }
2120
2121 async fn do_reload_model(
2123 &self,
2124 model_id: &str,
2125 unloaded_state: UnloadedModelState,
2126 ) -> Result<(), HanzoError> {
2127 use crate::model_loader::LoaderBuilder;
2128
2129 info!("Reloading model: {}", model_id);
2130
2131 let loader_config = &unloaded_state.loader_config;
2132
2133 let loader = LoaderBuilder::new(loader_config.model_selected.clone())
2135 .with_chat_template(loader_config.chat_template.clone())
2136 .with_jinja_explicit(loader_config.jinja_explicit.clone())
2137 .build()
2138 .map_err(|e| HanzoError::ReloadFailed(format!("Failed to build loader: {e}")))?;
2139
2140 let pipeline = loader
2142 .load_model_from_hf(
2143 None,
2144 loader_config.token_source.clone(),
2145 &loader_config.dtype,
2146 &loader_config.device,
2147 loader_config.silent,
2148 loader_config.device_map_setting.clone(),
2149 loader_config.isq,
2150 loader_config.paged_attn_config,
2151 )
2152 .map_err(|e| HanzoError::ReloadFailed(format!("Failed to load model: {e}")))?;
2153
2154 if let Some(mtp_config) = loader_config.mtp_config.clone() {
2155 pipeline
2156 .blocking_lock()
2157 .attach_speculative(SpeculativeConfig::Mtp(mtp_config))
2158 .map_err(|e| {
2159 HanzoError::ReloadFailed(format!(
2160 "Failed to attach MTP speculative decoding: {e}"
2161 ))
2162 })?;
2163 }
2164
2165 let reboot_state = RebootState {
2167 pipeline: pipeline.clone(),
2168 method: unloaded_state.scheduler_config.clone(),
2169 no_kv_cache: unloaded_state.engine_config.no_kv_cache,
2170 no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
2171 prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
2172 disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
2173 throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
2174 search_embedding_model: unloaded_state.engine_config.search_embedding_model,
2175 search_callback: unloaded_state.engine_config.search_callback.clone(),
2176 tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
2177 mcp_client_config: unloaded_state.mcp_client_config.clone(),
2178 loader_config: Some(unloaded_state.loader_config.clone()),
2179 };
2180
2181 let engine_instance = Self::create_engine_instance(
2182 pipeline,
2183 unloaded_state.scheduler_config,
2184 unloaded_state.engine_config,
2185 reboot_state,
2186 )
2187 .map_err(|e| HanzoError::ReloadFailed(format!("Failed to create engine: {e}")))?;
2188
2189 {
2191 let mut engines = self
2192 .engines
2193 .write()
2194 .map_err(|_| HanzoError::EnginePoisoned)?;
2195 engines.insert(model_id.to_string(), engine_instance);
2196 }
2197
2198 {
2200 let mut unloaded = self
2201 .unloaded_models
2202 .write()
2203 .map_err(|_| HanzoError::EnginePoisoned)?;
2204 unloaded.remove(model_id);
2205 }
2206
2207 info!("Model {} reloaded successfully", model_id);
2208 Ok(())
2209 }
2210
2211 pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), HanzoError> {
2218 match tokio::runtime::Handle::try_current() {
2219 Ok(handle) => {
2220 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
2221 Err(HanzoError::ReloadFailed(
2222 "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
2223 ))
2224 } else {
2225 tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
2226 }
2227 }
2228 Err(_) => {
2229 let rt = tokio::runtime::Runtime::new().map_err(|e| {
2230 HanzoError::ReloadFailed(format!("Failed to create runtime: {e}"))
2231 })?;
2232 rt.block_on(self.reload_model(model_id))
2233 }
2234 }
2235 }
2236
2237 pub fn list_unloaded_models(&self) -> Result<Vec<String>, HanzoError> {
2239 let unloaded = self
2240 .unloaded_models
2241 .read()
2242 .map_err(|_| HanzoError::EnginePoisoned)?;
2243 Ok(unloaded.keys().cloned().collect())
2244 }
2245
2246 pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, HanzoError> {
2248 let resolved_model_id = self.resolve_alias(model_id)?;
2249 let engines = self
2250 .engines
2251 .read()
2252 .map_err(|_| HanzoError::EnginePoisoned)?;
2253 Ok(engines.contains_key(&resolved_model_id))
2254 }
2255
2256 pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, HanzoError> {
2258 let resolved_model_id = self.resolve_alias(model_id)?;
2259 {
2261 let reloading = self
2262 .reloading_models
2263 .read()
2264 .map_err(|_| HanzoError::EnginePoisoned)?;
2265 if reloading.contains(&resolved_model_id) {
2266 return Ok(Some(ModelStatus::Reloading));
2267 }
2268 }
2269
2270 {
2272 let engines = self
2273 .engines
2274 .read()
2275 .map_err(|_| HanzoError::EnginePoisoned)?;
2276 if engines.contains_key(&resolved_model_id) {
2277 return Ok(Some(ModelStatus::Loaded));
2278 }
2279 }
2280
2281 {
2283 let unloaded = self
2284 .unloaded_models
2285 .read()
2286 .map_err(|_| HanzoError::EnginePoisoned)?;
2287 if unloaded.contains_key(&resolved_model_id) {
2288 return Ok(Some(ModelStatus::Unloaded));
2289 }
2290 }
2291
2292 Ok(None)
2293 }
2294
2295 pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, HanzoError> {
2297 let mut result = Vec::new();
2298
2299 let reloading = self
2301 .reloading_models
2302 .read()
2303 .map_err(|_| HanzoError::EnginePoisoned)?;
2304 for model_id in reloading.iter() {
2305 result.push((model_id.clone(), ModelStatus::Reloading));
2306 }
2307 drop(reloading);
2308
2309 let engines = self
2311 .engines
2312 .read()
2313 .map_err(|_| HanzoError::EnginePoisoned)?;
2314 for model_id in engines.keys() {
2315 result.push((model_id.clone(), ModelStatus::Loaded));
2316 }
2317 drop(engines);
2318
2319 let unloaded = self
2321 .unloaded_models
2322 .read()
2323 .map_err(|_| HanzoError::EnginePoisoned)?;
2324 for model_id in unloaded.keys() {
2325 if !result.iter().any(|(id, _)| id == model_id) {
2327 result.push((model_id.clone(), ModelStatus::Unloaded));
2328 }
2329 }
2330
2331 Ok(result)
2332 }
2333}