mistralrs_core/lib.rs

#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
use candle_core::Device;
use engine::Engine;
pub use engine::{
    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
    EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
};
use hf_hub::Cache;
pub use lora::Ordering;
pub use pipeline::ModelCategory;
pub use pipeline::Pipeline;
#[cfg(feature = "pyo3_macros")]
use pyo3::exceptions::PyValueError;
use std::collections::{HashMap, HashSet};
use std::num::NonZeroUsize;
use std::sync::OnceLock;
use std::time::Instant;
use std::{
    cell::RefCell,
    error::Error,
    fs::OpenOptions,
    io::Write,
    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
    thread::{self, JoinHandle},
    time::{SystemTime, UNIX_EPOCH},
};
use tokio::sync::mpsc::{channel, Sender};
use tracing::info;
use tracing::warn;

pub const MISTRALRS_GIT_REVISION: &str = match option_env!("MISTRALRS_GIT_REVISION") {
    Some(value) => value,
    None => "unknown",
};

mod cuda;
mod device_map;
mod engine;
mod lora;
mod model_loader;
mod moe;
mod ops;
pub use model_loader::{
    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
};
mod embedding_models;
mod kv_cache;
mod search;

mod model_selected;
pub use model_selected::ModelSelected;
pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};

mod amoe;
mod attention;
mod diagnostics;
mod diffusion_models;
pub mod distributed;
mod gguf;
pub mod harmony;
pub mod layers;
mod layers_masker;
mod layers_utils;
pub mod matformer;
mod mla;
mod models;
mod paged_attention;
mod pipeline;
mod prefix_cacher;
mod request;
mod response;
mod sampler;
mod scheduler;
mod sequence;
mod speech_models;
pub mod think_tags;
mod toml_selector;
mod tools;
mod topology;
mod utils;
mod vision_models;
mod xlora_models;

pub use diagnostics::{
    check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
    DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
};
mod tuning;
pub use tuning::{
    auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
};

pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
pub use device_map::{
    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
};
pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
pub use mistralrs_audio::AudioInput;
pub use mistralrs_mcp::{
    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
};
pub use mistralrs_mcp::{
    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
};
pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
pub use pipeline::{
    chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
    AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
    DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
    EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
    GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
    GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
    Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
    ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
    NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
    SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
    SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
    VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
};
pub use request::{
    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
};
pub use response::*;
pub use sampler::{
    CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
};
pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
use serde::Serialize;
pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
use tokio::runtime::Runtime;
use toml_selector::{TomlLoaderArgs, TomlSelector};
pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
pub use topology::{LayerTopology, Topology};
pub use utils::debug::initialize_logging;
pub use utils::memory_usage::MemoryUsage;
pub use utils::normal::{ModelDType, TryIntoDType};
pub use utils::{paged_attn_supported, using_flash_attn};

// re-export llguidance for easier LlguidanceGrammar construction
pub use llguidance;

/// `true` if `MISTRALRS_DEBUG=1`
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();

/// Configuration for creating an engine instance
#[derive(Clone)]
pub struct EngineConfig {
    pub no_kv_cache: bool,
    pub no_prefix_cache: bool,
    pub prefix_cache_n: usize,
    pub disable_eos_stop: bool,
    pub throughput_logging_enabled: bool,
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    pub search_callback: Option<Arc<SearchCallback>>,
    pub tool_callbacks: tools::ToolCallbacks,
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}

impl Default for EngineConfig {
    fn default() -> Self {
        Self {
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}
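
// Illustrative sketch only: overriding a couple of `EngineConfig` fields on top of
// the defaults above via struct-update syntax. The particular values are arbitrary
// examples, not recommendations.
#[allow(dead_code)]
fn example_engine_config() -> EngineConfig {
    EngineConfig {
        // Assumed example values; everything else falls back to `Default`.
        prefix_cache_n: 32,
        throughput_logging_enabled: false,
        ..Default::default()
    }
}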

/// Configuration for adding a model to MistralRs
#[derive(Clone)]
pub struct AddModelConfig {
    pub engine_config: EngineConfig,
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config for enabling model unload/reload support.
    /// Without this, models cannot be unloaded and reloaded.
    pub loader_config: Option<ModelLoaderConfig>,
}

impl AddModelConfig {
    pub fn new(engine_config: EngineConfig) -> Self {
        Self {
            engine_config,
            mcp_client_config: None,
            loader_config: None,
        }
    }

    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(mcp_config);
        self
    }

    /// Set the loader config for enabling model unload/reload support.
    /// Without this, models cannot be unloaded and reloaded.
    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
        self.loader_config = Some(loader_config);
        self
    }
}
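
// Illustrative sketch only: a typical `AddModelConfig` that opts into MCP tools and
// unload/reload support. Both arguments are assumed to be constructed elsewhere.
#[allow(dead_code)]
fn example_add_model_config(
    mcp_config: McpClientConfig,
    loader_config: ModelLoaderConfig,
) -> AddModelConfig {
    AddModelConfig::new(EngineConfig::default())
        .with_mcp_config(mcp_config)
        .with_loader_config(loader_config)
}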

#[derive(Clone)]
pub struct MistralRsConfig {
    pub kind: ModelKind,
    pub device: Device,
    pub category: ModelCategory,
    pub modalities: Modalities,
    pub max_seq_len: Option<usize>,
}

/// Configuration for recreating a model loader when reloading an unloaded model.
/// This captures the essential parameters needed to reconstruct a loader.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// The model selection configuration (Plain, GGUF, Vision, etc.)
    pub model_selected: ModelSelected,
    /// Source of the HF token
    pub token_source: TokenSource,
    /// Optional HF revision
    pub hf_revision: Option<String>,
    /// Model data type
    pub dtype: ModelDType,
    /// Device to load the model on
    pub device: Device,
    /// Device mapping setting
    pub device_map_setting: DeviceMapSetting,
    /// In-situ quantization type
    pub isq: Option<IsqType>,
    /// Paged attention configuration
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Whether to suppress logging during loading
    pub silent: bool,
    /// Chat template override
    pub chat_template: Option<String>,
    /// Explicit Jinja template path
    pub jinja_explicit: Option<String>,
}

/// State preserved when a model is unloaded.
/// This contains all the information needed to reload the model on demand.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Configuration to recreate the loader
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration
    pub scheduler_config: SchedulerConfig,
    /// Engine configuration
    pub engine_config: EngineConfig,
    /// MCP client configuration
    pub mcp_client_config: Option<McpClientConfig>,
    /// Model category (Text, Vision, etc.)
    pub category: ModelCategory,
    /// Model metadata configuration
    pub mistralrs_config: MistralRsConfig,
}

/// Internal structure to hold per-engine state
struct EngineInstance {
    sender: Sender<Request>,
    engine_handler: JoinHandle<()>,
    reboot_state: RebootState,
    config: MistralRsConfig,
    category: ModelCategory,
}

/// The MistralRs struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
/// appropriate engine based on model ID.
///
/// ## Lock Ordering Convention
///
/// This struct uses multiple `RwLock`s. To prevent deadlocks, locks must be
/// acquired in this order:
/// 1. `reloading_models`
/// 2. `engines`
/// 3. `unloaded_models`
/// 4. `default_engine_id`
/// 5. `model_aliases`
///
/// Use scope-based lock management and explicit `drop()` calls.
pub struct MistralRs {
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models that have been unloaded but can be reloaded on demand
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Models currently being reloaded (to prevent concurrent reloads)
    reloading_models: RwLock<HashSet<String>>,
    default_engine_id: RwLock<Option<String>>,
    /// Alternate IDs that resolve to primary model IDs.
    model_aliases: RwLock<HashMap<String, String>>,
    log: Option<String>,
    id: String,
    creation_time: u64,
    next_request_id: Mutex<RefCell<usize>>,
}
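
// Illustrative sketch only: the lock ordering documented above, made concrete. Guards
// are acquired in the stated order and dropped at end of scope; real call sites should
// also drop each guard as early as possible.
#[allow(dead_code)]
fn example_lock_ordering(mistralrs: &MistralRs) {
    let _reloading = mistralrs.reloading_models.read();
    let _engines = mistralrs.engines.read();
    let _unloaded = mistralrs.unloaded_models.read();
    let _default = mistralrs.default_engine_id.read();
    let _aliases = mistralrs.model_aliases.read();
}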

#[derive(Clone)]
struct RebootState {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config for reloading after unload
    loader_config: Option<ModelLoaderConfig>,
}

/// Model status for loaded/unloaded state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    Loaded,
    Unloaded,
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            ModelStatus::Loaded => write!(f, "loaded"),
            ModelStatus::Unloaded => write!(f, "unloaded"),
            ModelStatus::Reloading => write!(f, "reloading"),
        }
    }
}

#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
    /// The requested model was not found (neither loaded nor unloaded)
    ModelNotFound(String),
    /// The model is currently being reloaded
    ModelReloading(String),
    /// Failed to reload the model
    ReloadFailed(String),
    /// Model does not have loader config for reloading
    NoLoaderConfig(String),
    /// Model is already loaded
    ModelAlreadyLoaded(String),
    /// Model is already unloaded
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{:?}", &self)
    }
}

impl std::error::Error for MistralRsError {}

#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}

/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
pub struct MistralRsBuilder {
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
    model_id_override: Option<String>,
    log: Option<String>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
    loader_config: Option<ModelLoaderConfig>,
}

impl MistralRsBuilder {
    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
    /// and optional embedding model for web search. To override the search callback,
    /// use `.with_search_callback(...)` on the builder.
    pub fn new(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        throughput_logging: bool,
        search_embedding_model: Option<SearchEmbeddingModel>,
    ) -> Self {
        Self {
            pipeline,
            method,
            model_id_override: None,
            log: None,
            no_kv_cache: None,
            no_prefix_cache: None,
            prefix_cache_n: None,
            disable_eos_stop: None,
            throughput_logging_enabled: throughput_logging,
            search_embedding_model,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
            mcp_client_config: None,
            loader_config: None,
        }
    }

    /// Override the model ID used by MistralRs. Defaults to the pipeline name.
    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
        self.model_id_override = Some(model_id.into());
        self
    }

    /// Set the loader config for enabling model unload/reload support.
    /// Without this, models cannot be unloaded and reloaded.
    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
        self.loader_config = Some(loader_config);
        self
    }
    pub fn with_log(mut self, log: String) -> Self {
        self.log = Some(log);
        self
    }
    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
        self.log = log;
        self
    }
    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
        self.no_kv_cache = Some(no_kv_cache);
        self
    }
    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
        self.no_prefix_cache = Some(no_prefix_cache);
        self
    }
    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
        self.prefix_cache_n = Some(prefix_cache_n);
        self
    }
    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
        self.disable_eos_stop = Some(disable_eos_stop);
        self
    }

    /// Use a custom callback to gather search results.
    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
        self.search_callback = Some(search_callback);
        self
    }

    /// Register a custom callback for the specified tool name.
    pub fn with_tool_callback(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
    ) -> Self {
        self.tool_callbacks.insert(name.into(), tool_callback);
        self
    }

    /// Register a custom callback with its associated Tool definition. The Tool will be
    /// automatically added to requests when tool callbacks are active.
    pub fn with_tool_callback_and_tool(
        mut self,
        name: impl Into<String>,
        tool_callback: Arc<ToolCallback>,
        tool: Tool,
    ) -> Self {
        let name = name.into();
        self.tool_callbacks_with_tools.insert(
            name,
            ToolCallbackWithTool {
                callback: tool_callback,
                tool,
            },
        );
        self
    }

    /// Configure MCP client to connect to external MCP servers.
    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
        self.mcp_client_config = Some(config);
        self
    }

    pub async fn build(self) -> Arc<MistralRs> {
        MistralRs::new(self).await
    }
}
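
// Illustrative sketch only: a typical builder invocation. The pipeline and scheduler
// config are assumed to be produced by a `Loader` elsewhere; the flag values shown
// here are arbitrary examples.
#[allow(dead_code)]
async fn example_build_mistralrs(
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    method: SchedulerConfig,
) -> Arc<MistralRs> {
    MistralRsBuilder::new(pipeline, method, /* throughput_logging */ true, None)
        .with_no_kv_cache(false)
        .with_prefix_cache_n(16)
        .build()
        .await
}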

impl Drop for MistralRs {
    fn drop(&mut self) {
        // Terminate all engines
        if let Ok(engines) = self.engines.read() {
            for (_, engine) in engines.iter() {
                // Use try_send instead of blocking_send to avoid runtime panics
                let _ = engine.sender.try_send(Request::Terminate);
            }
        }
    }
}

impl MistralRs {
    /// Create an engine instance with the given configuration
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        let (tx, rx) = channel(10_000);

        let pipeline_guard = pipeline.try_lock().unwrap();
        let category = pipeline_guard.category();
        let metadata = pipeline_guard.get_metadata();
        let kind = metadata.kind.clone();
        let device = pipeline_guard.device();
        let modalities = metadata.modalities.clone();
        let max_seq_len = match &category {
            ModelCategory::Diffusion | ModelCategory::Speech => None,
            _ => Some(metadata.max_seq_len),
        };
        drop(pipeline_guard);

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
            max_seq_len,
        };

        let tx_for_engine = tx.clone();
        let engine_handler = thread::spawn(move || {
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        tx_for_engine,
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        tx_for_engine,
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }

    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        info!("git revision: {MISTRALRS_GIT_REVISION}");
        let MistralRsBuilder {
            pipeline,
            method,
            model_id_override,
            log,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
            loader_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
        // Mamba's stateful nature makes batched inference complex; this ensures correctness
        let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
            && get_mut_arcmutex!(pipeline).cache().is_hybrid()
        {
            info!(
                "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
            );
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
            }
        } else {
            method
        };

        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize MCP client if configured
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP tool callbacks with tools into the new collection
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
            loader_config,
        };

        // Create the engine configuration
        let engine_config = EngineConfig {
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        // Create the engine instance
        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        let pipeline_name = pipeline.try_lock().unwrap().name();
        let (id, alias_map) = match model_id_override {
            Some(override_id) => {
                let mut alias_map = HashMap::new();
                if override_id != pipeline_name {
                    alias_map.insert(pipeline_name.clone(), override_id.clone());
                }
                (override_id, alias_map)
            }
            None => (pipeline_name.clone(), HashMap::new()),
        };

        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                // Ring daemon replicator
                distributed::ring_daemon_replicator(request_sender);
            } else {
                // NCCL daemon replicator
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Do a dummy run
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                    truncate_sequence: false,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain all responses from the channel until it's closed
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        // Create engines map with the first engine
        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            unloaded_models: RwLock::new(HashMap::new()),
            reloading_models: RwLock::new(HashSet::new()),
            default_engine_id: RwLock::new(Some(id.clone())),
            model_aliases: RwLock::new(alias_map),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }

    /// Attempts to reboot a specific engine by model_id
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model,
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let engines = self.engines.read().map_err(|_| {
            tracing::warn!("Couldn't get read lock on engines!");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            Ok(engine_instance.engine_handler.is_finished())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get sender for a specific model. If model_id is None, uses default engine.
    /// If the model is unloaded, it will be automatically reloaded before returning the sender.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        // Check if model is loaded
        let is_loaded = {
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            engines.contains_key(&resolved_model_id)
        };

        if is_loaded {
            // Check if engine is dead and needs reboot
            if self.engine_dead(&resolved_model_id)? {
                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
                self.reboot_engine(&resolved_model_id)?
            }

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        // Check if model is unloaded - trigger auto-reload
        let is_unloaded = {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.contains_key(&resolved_model_id)
        };

        if is_unloaded {
            tracing::info!(
                "Model {} is unloaded, triggering auto-reload",
                resolved_model_id
            );
            self.reload_model_blocking(&resolved_model_id)?;

            // After reload, get the sender
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        Err(MistralRsError::ModelNotFound(resolved_model_id))
    }
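
    // Illustrative sketch only (not part of the public API): passing `None` resolves the
    // default engine, while a named (hypothetical) model ID goes through alias resolution
    // and, if the model was unloaded with a loader config, auto-reload.
    #[allow(dead_code)]
    fn example_senders(&self) -> Result<(), MistralRsError> {
        let _default = self.get_sender(None)?;
        let _named = self.get_sender(Some("my-model"))?;
        Ok(())
    }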

    pub fn get_id(&self) -> String {
        self.id.clone()
    }

    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }

    fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
        let aliases = self
            .model_aliases
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(primary_id) = aliases.get(model_id) {
            Ok(primary_id.clone())
        } else {
            Ok(model_id.to_string())
        }
    }

    fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
        match model_id {
            Some(id) => self.resolve_alias(id),
            None => {
                let default_lock = self
                    .default_engine_id
                    .read()
                    .map_err(|_| MistralRsError::SenderPoisoned)?;
                Ok(default_lock
                    .as_ref()
                    .ok_or(MistralRsError::EnginePoisoned)?
                    .clone())
            }
        }
    }

    /// Register an alternate model ID that resolves to an existing model.
    pub fn register_model_alias(
        &self,
        alias: impl Into<String>,
        model_id: &str,
    ) -> Result<(), String> {
        let alias = alias.into();
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

        if alias == resolved_model_id {
            return Ok(());
        }

        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
        let model_reloading = reloading.contains(&resolved_model_id);
        let alias_conflict = reloading.contains(&alias);
        drop(reloading);

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        let model_loaded = engines.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || engines.contains_key(&alias);
        drop(engines);

        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
        let model_unloaded = unloaded.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
        drop(unloaded);

        if !(model_loaded || model_unloaded || model_reloading) {
            return Err(format!("Model {resolved_model_id} not found"));
        }

        if alias_conflict {
            return Err(format!(
                "Alias '{}' conflicts with an existing model ID",
                alias
            ));
        }

        let mut aliases = self
            .model_aliases
            .write()
            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
        if let Some(existing) = aliases.get(&alias) {
            if existing == &resolved_model_id {
                return Ok(());
            }
            return Err(format!(
                "Alias '{}' is already assigned to model '{}'",
                alias, existing
            ));
        }
        aliases.insert(alias, resolved_model_id);
        Ok(())
    }
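
    // Illustrative sketch only (not part of the public API): exposing an existing model
    // under an additional ID. Both IDs here are hypothetical.
    #[allow(dead_code)]
    fn example_register_alias(&self) -> Result<(), String> {
        self.register_model_alias("my-model-alias", "my-model")
    }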

    /// Check if a model is known (loaded, unloaded, or reloading), resolving aliases if needed.
    pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
        let resolved_model_id = self.resolve_alias(model_id)?;

        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| MistralRsError::EnginePoisoned)?;
        if reloading.contains(&resolved_model_id) {
            return Ok(true);
        }
        drop(reloading);

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::EnginePoisoned)?;
        if engines.contains_key(&resolved_model_id) {
            return Ok(true);
        }
        drop(engines);

        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| MistralRsError::EnginePoisoned)?;
        if unloaded.contains_key(&resolved_model_id) {
            return Ok(true);
        }

        Ok(false)
    }

    /// Get model category for a specific model. If model_id is None, uses default engine.
    pub fn get_model_category(
        &self,
        model_id: Option<&str>,
    ) -> Result<ModelCategory, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.category.clone())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    /// Get the maximum supported sequence length for a model, if applicable.
    pub fn max_sequence_length(
        &self,
        model_id: Option<&str>,
    ) -> Result<Option<usize>, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        let engines = self
            .engines
            .read()
            .map_err(|_| MistralRsError::SenderPoisoned)?;
        if let Some(engine_instance) = engines.get(&resolved_model_id) {
            Ok(engine_instance.config.max_seq_len)
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }

    pub fn next_request_id(&self) -> usize {
        let l = self.next_request_id.lock().unwrap();
        let last = &mut *l.borrow_mut();
        let last_v = *last;
        *last += 1;
        last_v
    }

    /// Add a new model engine to the MistralRs instance
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        {
            let reloading = self
                .reloading_models
                .read()
                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
            if reloading.contains(&model_id) {
                return Err(format!("Model {model_id} is currently reloading"));
            }
        }
        {
            let engines = self
                .engines
                .read()
                .map_err(|_| "Failed to acquire read lock on engines")?;
            if engines.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists"));
            }
        }
        {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
            if unloaded.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists (unloaded)"));
            }
        }
        {
            let aliases = self
                .model_aliases
                .read()
                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
            if aliases.contains_key(&model_id) {
                return Err(format!(
                    "Model ID '{}' conflicts with an existing alias",
                    model_id
                ));
            }
        }

        // For hybrid models (Mamba-Attention), force batch_size=1 to prevent state bleeding
        let method = {
            let pipeline_guard = pipeline.try_lock().unwrap();
            if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
                info!(
                    "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
                );
                SchedulerConfig::DefaultScheduler {
                    method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
                }
            } else {
                method
            }
        };

        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model,
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
            loader_config: config.loader_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // If this is the first model, set it as default
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
            info!("First model added, setting '{}' as default", model_id);
        }

        Ok(())
    }
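
    // Illustrative sketch only (not part of the public API): registering an additional
    // engine alongside the existing ones. The pipeline and scheduler config are assumed
    // to come from a `Loader` elsewhere; the model ID is hypothetical.
    #[allow(dead_code)]
    async fn example_add_second_model(
        &self,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
    ) -> Result<(), String> {
        self.add_model(
            "secondary-model".to_string(),
            pipeline,
            method,
            AddModelConfig::new(EngineConfig::default()),
        )
        .await
    }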

    /// Remove a model engine from the MistralRs instance
    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;

        if engines.len() <= 1 {
            return Err("Cannot remove the last model from MistralRs".to_string());
        }

        if let Some(engine_instance) = engines.remove(&resolved_model_id) {
            // Send terminate signal to the engine
            let _ = engine_instance.sender.blocking_send(Request::Terminate);

            // If this was the default engine, set a new default
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            if let Some(ref default_id) = *default_lock {
                if default_id == &resolved_model_id {
                    // Set the first available engine as the new default
                    *default_lock = engines.keys().next().cloned();
                }
            }
            drop(default_lock);
            drop(engines);

            // Remove any aliases pointing to the removed model
            let mut aliases = self
                .model_aliases
                .write()
                .map_err(|_| "Failed to acquire write lock on model_aliases")?;
            aliases.retain(|_, target| target != &resolved_model_id);

            Ok(())
        } else {
            Err(format!("Model {resolved_model_id} not found"))
        }
    }

    /// List all available model IDs
    pub fn list_models(&self) -> Result<Vec<String>, String> {
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        Ok(engines.keys().cloned().collect())
    }

    /// Get the current default model ID
    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
        let default_lock = self
            .default_engine_id
            .read()
            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
        Ok(default_lock.clone())
    }

    /// Set the default model ID
    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        if !engines.contains_key(&resolved_model_id) {
            return Err(format!("Model {resolved_model_id} not found"));
        }
        drop(engines);

        let mut default_lock = self
            .default_engine_id
            .write()
            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
        let old_default = default_lock.clone();
        *default_lock = Some(resolved_model_id.clone());

        // Log the change
        info!(
            "Default model changed: {:?} -> {:?}",
            old_default, resolved_model_id
        );

        Ok(())
    }

    /// Dispatch a request to the appropriate engine based on the model_id in the request
    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
        let model_id = match &mut request {
            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
            _ => None, // Other request types don't specify model_id
        };

        let sender = self.get_sender(model_id)?;
        sender
            .blocking_send(request)
            .map_err(|_| MistralRsError::SenderPoisoned)
    }
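
    // Illustrative sketch only (not part of the public API): building a minimal completion
    // request and routing it by model ID through `send_request`. The field values mirror
    // the dummy run in `new`; the model ID and prompt are hypothetical.
    #[allow(dead_code)]
    fn example_send_completion(&self) -> Result<(), MistralRsError> {
        let (tx, _rx) = channel(1);
        let request = Request::Normal(Box::new(NormalRequest {
            id: self.next_request_id(),
            messages: RequestMessage::Completion {
                text: "hello".to_string(),
                echo_prompt: false,
                best_of: None,
            },
            sampling_params: SamplingParams {
                max_len: Some(16),
                ..SamplingParams::deterministic()
            },
            response: tx,
            return_logprobs: false,
            is_streaming: false,
            constraint: Constraint::None,
            suffix: None,
            tool_choice: None,
            tools: None,
            logits_processors: None,
            return_raw_logits: false,
            web_search_options: None,
            // Route to a specific engine; `None` falls back to the default engine.
            model_id: Some("my-model".to_string()),
            truncate_sequence: false,
        }));
        self.send_request(request)
    }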
1334
1335    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1336        if let Some(file) = &this.log {
1337            let mut f = OpenOptions::new()
1338                .append(true)
1339                .create(true) // Optionally create the file if it doesn't already exist
1340                .open(file)
1341                .expect("Unable to open file");
1342            let time = chrono::offset::Local::now();
1343            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1344                .expect("Unable to write data");
1345        }
1346    }
1347
1348    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1349        if let Some(file) = &this.log {
1350            let mut f = OpenOptions::new()
1351                .append(true)
1352                .create(true) // Optionally create the file if it doesn't already exist
1353                .open(file)
1354                .expect("Unable to open file");
1355            let time = chrono::offset::Local::now();
1356            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1357            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1358                .expect("Unable to write data");
1359        }
1360    }
1361
1362    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1363        if let Some(file) = &this.log {
1364            let mut f = OpenOptions::new()
1365                .append(true)
1366                .create(true) // Optionally create the file if it doesn't already exist
1367                .open(file)
1368                .expect("Unable to open file");
1369            let time = chrono::offset::Local::now();
1370            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1371                .expect("Unable to write data");
1372        }
1373    }
1374
1375    /// Get the number of tools available for a specific model (including MCP tools)
1376    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1377        let resolved_model_id = self
1378            .resolve_alias_or_default(model_id)
1379            .map_err(|e| e.to_string())?;
1380
1381        let engines = self
1382            .engines
1383            .read()
1384            .map_err(|_| "Failed to acquire read lock on engines")?;
1385        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1386            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1387        } else {
1388            Err(format!("Model {resolved_model_id} not found"))
1389        }
1390    }
1391
1392    /// Check if MCP client is configured for a specific model
1393    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1394        let resolved_model_id = self
1395            .resolve_alias_or_default(model_id)
1396            .map_err(|e| e.to_string())?;
1397
1398        let engines = self
1399            .engines
1400            .read()
1401            .map_err(|_| "Failed to acquire read lock on engines")?;
1402        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1403            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1404        } else {
1405            Err(format!("Model {resolved_model_id} not found"))
1406        }
1407    }
1408
1409    /// Get config for a specific model
1410    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1411        let resolved_model_id = self
1412            .resolve_alias_or_default(model_id)
1413            .map_err(|e| e.to_string())?;
1414
1415        let engines = self
1416            .engines
1417            .read()
1418            .map_err(|_| "Failed to acquire read lock on engines")?;
1419        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1420            Ok(engine_instance.config.clone())
1421        } else {
1422            Err(format!("Model {resolved_model_id} not found"))
1423        }
1424    }
1425
1426    /// Unload a model from memory while preserving its configuration for later reload.
1427    /// The model can be reloaded automatically when a request is sent to it, or manually
1428    /// using `reload_model()`.
1429    ///
1430    /// Note: The model must have been added with a `ModelLoaderConfig` for auto-reload to work.
1431    /// Models added via `MistralRsBuilder` without explicit loader config cannot be reloaded.
1432    pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1433        let resolved_model_id = self.resolve_alias(model_id)?;
1434        // Check if already unloaded
1435        {
1436            let unloaded = self
1437                .unloaded_models
1438                .read()
1439                .map_err(|_| MistralRsError::EnginePoisoned)?;
1440            if unloaded.contains_key(&resolved_model_id) {
1441                return Err(MistralRsError::ModelAlreadyUnloaded(
1442                    resolved_model_id.clone(),
1443                ));
1444            }
1445        }
1446
1447        // Get the engine instance and create UnloadedModelState
1448        let mut engines = self
1449            .engines
1450            .write()
1451            .map_err(|_| MistralRsError::EnginePoisoned)?;
1452
1453        let engine_instance = engines
1454            .remove(&resolved_model_id)
1455            .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1456
1457        // Check if we have a loader config for reloading; if not, put the engine
1458        // back so the model is not silently dropped from the loaded set.
1459        let Some(loader_config) = engine_instance.reboot_state.loader_config.clone() else {
1460            engines.insert(resolved_model_id.clone(), engine_instance);
1461            return Err(MistralRsError::NoLoaderConfig(resolved_model_id.clone()));
1462        };
1463
1464        // Create the unloaded state
1465        let unloaded_state = UnloadedModelState {
1466            loader_config,
1467            scheduler_config: engine_instance.reboot_state.method.clone(),
1468            engine_config: EngineConfig {
1469                no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1470                no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1471                prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1472                disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1473                throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1474                search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1475                search_callback: engine_instance.reboot_state.search_callback.clone(),
1476                tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1477                tool_callbacks_with_tools: engine_instance
1478                    .reboot_state
1479                    .tool_callbacks_with_tools
1480                    .clone(),
1481            },
1482            mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1483            category: engine_instance.category.clone(),
1484            mistralrs_config: engine_instance.config.clone(),
1485        };
1486
1487        // Send terminate signal to the engine
1488        let _ = engine_instance.sender.try_send(Request::Terminate);
1489
1490        drop(engines);
1491
1492        // Store the unloaded state
1493        let mut unloaded = self
1494            .unloaded_models
1495            .write()
1496            .map_err(|_| MistralRsError::EnginePoisoned)?;
1497        unloaded.insert(resolved_model_id.clone(), unloaded_state);
1498
1499        // Update default if needed
1500        let mut default_lock = self
1501            .default_engine_id
1502            .write()
1503            .map_err(|_| MistralRsError::EnginePoisoned)?;
1504        if let Some(ref default_id) = *default_lock {
1505            if default_id == &resolved_model_id {
1506                // Set the first available engine as the new default
1507                let engines = self
1508                    .engines
1509                    .read()
1510                    .map_err(|_| MistralRsError::EnginePoisoned)?;
1511                *default_lock = engines.keys().next().cloned();
1512            }
1513        }
1514
1515        info!("Model {resolved_model_id} unloaded successfully");
1516        Ok(())
1517    }
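    // --- Illustrative usage (annotation, not part of the original source) ---
    // A sketch of unloading a model and handling the expected failure modes: the
    // model was already unloaded, or it was added without a loader config and so
    // cannot be restored later. Assumes `mistralrs: &MistralRs` as in the earlier
    // sketches.
    //
    //     fn try_unload(mistralrs: &MistralRs, model_id: &str) {
    //         match mistralrs.unload_model(model_id) {
    //             Ok(()) => println!("{model_id} unloaded; it will reload on demand"),
    //             Err(MistralRsError::ModelAlreadyUnloaded(id)) => {
    //                 println!("{id} was already unloaded");
    //             }
    //             Err(MistralRsError::NoLoaderConfig(id)) => {
    //                 eprintln!("{id} has no stored loader config and cannot be unloaded");
    //             }
    //             Err(e) => eprintln!("unload failed: {e}"),
    //         }
    //     }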
1518
1519    /// Manually reload a previously unloaded model.
1520    /// This is also called automatically by `get_sender()` when a request targets an unloaded model.
1521    pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1522        let resolved_model_id = self.resolve_alias(model_id)?;
1523        // Check if already reloading
1524        {
1525            let reloading = self
1526                .reloading_models
1527                .read()
1528                .map_err(|_| MistralRsError::EnginePoisoned)?;
1529            if reloading.contains(&resolved_model_id) {
1530                return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
1531            }
1532        }
1533
1534        // Mark as reloading
1535        {
1536            let mut reloading = self
1537                .reloading_models
1538                .write()
1539                .map_err(|_| MistralRsError::EnginePoisoned)?;
1540            reloading.insert(resolved_model_id.clone());
1541        }
1542
1543        // Get the unloaded state
1544        let unloaded_state = {
1545            let unloaded = self
1546                .unloaded_models
1547                .read()
1548                .map_err(|_| MistralRsError::EnginePoisoned)?;
1549            unloaded
1550                .get(&resolved_model_id)
1551                .cloned()
1552                .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
1553        };
1554
1555        // Attempt to reload
1556        let result = self
1557            .do_reload_model(&resolved_model_id, unloaded_state)
1558            .await;
1559
1560        // Remove from reloading set
1561        {
1562            let mut reloading = self
1563                .reloading_models
1564                .write()
1565                .map_err(|_| MistralRsError::EnginePoisoned)?;
1566            reloading.remove(&resolved_model_id);
1567        }
1568
1569        result
1570    }
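    // --- Illustrative usage (annotation, not part of the original source) ---
    // Reloading is async because the weights are loaded again through the stored
    // loader config. A sketch for an async context, same assumed handle:
    //
    //     async fn ensure_loaded(mistralrs: &MistralRs, model_id: &str) -> Result<(), MistralRsError> {
    //         if !mistralrs.is_model_loaded(model_id)? {
    //             mistralrs.reload_model(model_id).await?;
    //         }
    //         Ok(())
    //     }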
1571
1572    /// Internal method to perform the actual model reload
1573    async fn do_reload_model(
1574        &self,
1575        model_id: &str,
1576        unloaded_state: UnloadedModelState,
1577    ) -> Result<(), MistralRsError> {
1578        use crate::model_loader::LoaderBuilder;
1579
1580        info!("Reloading model: {model_id}");
1581
1582        let loader_config = &unloaded_state.loader_config;
1583
1584        // Build the loader from the stored config
1585        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
1586            .with_chat_template(loader_config.chat_template.clone())
1587            .with_jinja_explicit(loader_config.jinja_explicit.clone())
1588            .build()
1589            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;
1590
1591        // Load the model
1592        let pipeline = loader
1593            .load_model_from_hf(
1594                None,
1595                loader_config.token_source.clone(),
1596                &loader_config.dtype,
1597                &loader_config.device,
1598                loader_config.silent,
1599                loader_config.device_map_setting.clone(),
1600                loader_config.isq,
1601                loader_config.paged_attn_config,
1602            )
1603            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;
1604
1605        // Create the reboot state
1606        let reboot_state = RebootState {
1607            pipeline: pipeline.clone(),
1608            method: unloaded_state.scheduler_config.clone(),
1609            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
1610            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
1611            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
1612            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
1613            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
1614            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
1615            search_callback: unloaded_state.engine_config.search_callback.clone(),
1616            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
1617            tool_callbacks_with_tools: unloaded_state
1618                .engine_config
1619                .tool_callbacks_with_tools
1620                .clone(),
1621            mcp_client_config: unloaded_state.mcp_client_config.clone(),
1622            loader_config: Some(unloaded_state.loader_config.clone()),
1623        };
1624
1625        // Create the engine instance
1626        let engine_instance = Self::create_engine_instance(
1627            pipeline,
1628            unloaded_state.scheduler_config,
1629            unloaded_state.engine_config,
1630            reboot_state,
1631        )
1632        .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;
1633
1634        // Add to engines map
1635        {
1636            let mut engines = self
1637                .engines
1638                .write()
1639                .map_err(|_| MistralRsError::EnginePoisoned)?;
1640            engines.insert(model_id.to_string(), engine_instance);
1641        }
1642
1643        // Remove from unloaded map
1644        {
1645            let mut unloaded = self
1646                .unloaded_models
1647                .write()
1648                .map_err(|_| MistralRsError::EnginePoisoned)?;
1649            unloaded.remove(model_id);
1650        }
1651
1652        info!("Model {model_id} reloaded successfully");
1653        Ok(())
1654    }
1655
1656    /// Synchronous version of reload_model for use in non-async contexts.
1657    ///
1658    /// This method handles different runtime contexts appropriately:
1659    /// - If called from a multi-threaded tokio runtime, uses `block_in_place`
1660    /// - If called from a single-threaded runtime, returns an error (use `reload_model()` instead)
1661    /// - If called outside any runtime, creates a temporary runtime
1662    pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1663        match tokio::runtime::Handle::try_current() {
1664            Ok(handle) => {
1665                if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1666                    Err(MistralRsError::ReloadFailed(
1667                        "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1668                    ))
1669                } else {
1670                    tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1671                }
1672            }
1673            Err(_) => {
1674                let rt = tokio::runtime::Runtime::new().map_err(|e| {
1675                    MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1676                })?;
1677                rt.block_on(self.reload_model(model_id))
1678            }
1679        }
1680    }
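    // --- Illustrative usage (annotation, not part of the original source) ---
    // The blocking variant is meant for synchronous call sites: inside a
    // multi-threaded tokio runtime it uses `block_in_place`, outside any runtime it
    // spins up a temporary one, and on a current-thread runtime it refuses so the
    // async `reload_model` can be used instead. Sketch:
    //
    //     fn reload_from_sync_code(mistralrs: &MistralRs, model_id: &str) {
    //         if let Err(e) = mistralrs.reload_model_blocking(model_id) {
    //             eprintln!("blocking reload failed: {e}");
    //         }
    //     }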
1681
1682    /// List all unloaded model IDs
1683    pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1684        let unloaded = self
1685            .unloaded_models
1686            .read()
1687            .map_err(|_| MistralRsError::EnginePoisoned)?;
1688        Ok(unloaded.keys().cloned().collect())
1689    }
1690
1691    /// Check if a model is currently loaded (as opposed to unloaded)
1692    pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1693        let resolved_model_id = self.resolve_alias(model_id)?;
1694        let engines = self
1695            .engines
1696            .read()
1697            .map_err(|_| MistralRsError::EnginePoisoned)?;
1698        Ok(engines.contains_key(&resolved_model_id))
1699    }
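    // --- Illustrative usage (annotation, not part of the original source) ---
    // `is_model_loaded` and `list_unloaded_models` together give a quick view of
    // which models are resident versus parked. A sketch, same assumed handle:
    //
    //     fn check_model(mistralrs: &MistralRs, model_id: &str) -> Result<(), MistralRsError> {
    //         if mistralrs.is_model_loaded(model_id)? {
    //             println!("{model_id} is loaded");
    //         } else {
    //             let parked = mistralrs.list_unloaded_models()?;
    //             println!("{model_id} is not loaded; unloaded models: {parked:?}");
    //         }
    //         Ok(())
    //     }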
1700
1701    /// Get the status of a model, or None if not found
1702    pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1703        let resolved_model_id = self.resolve_alias(model_id)?;
1704        // Check if reloading
1705        {
1706            let reloading = self
1707                .reloading_models
1708                .read()
1709                .map_err(|_| MistralRsError::EnginePoisoned)?;
1710            if reloading.contains(&resolved_model_id) {
1711                return Ok(Some(ModelStatus::Reloading));
1712            }
1713        }
1714
1715        // Check if loaded
1716        {
1717            let engines = self
1718                .engines
1719                .read()
1720                .map_err(|_| MistralRsError::EnginePoisoned)?;
1721            if engines.contains_key(&resolved_model_id) {
1722                return Ok(Some(ModelStatus::Loaded));
1723            }
1724        }
1725
1726        // Check if unloaded
1727        {
1728            let unloaded = self
1729                .unloaded_models
1730                .read()
1731                .map_err(|_| MistralRsError::EnginePoisoned)?;
1732            if unloaded.contains_key(&resolved_model_id) {
1733                return Ok(Some(ModelStatus::Unloaded));
1734            }
1735        }
1736
1737        Ok(None)
1738    }
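    // --- Illustrative usage (annotation, not part of the original source) ---
    // `get_model_status` folds the three bookkeeping maps into one answer, checking
    // the reloading set first so an in-flight reload is not reported as unloaded.
    // A sketch matching on the possible states:
    //
    //     fn describe_status(mistralrs: &MistralRs, model_id: &str) -> Result<(), MistralRsError> {
    //         match mistralrs.get_model_status(model_id)? {
    //             Some(ModelStatus::Loaded) => println!("{model_id}: loaded"),
    //             Some(ModelStatus::Reloading) => println!("{model_id}: reloading"),
    //             Some(ModelStatus::Unloaded) => println!("{model_id}: unloaded"),
    //             None => println!("{model_id}: not known to this instance"),
    //         }
    //         Ok(())
    //     }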
1739
1740    /// List all models with their status
1741    pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1742        let mut result = Vec::new();
1743
1744        // Get reloading models
1745        let reloading = self
1746            .reloading_models
1747            .read()
1748            .map_err(|_| MistralRsError::EnginePoisoned)?;
1749        for model_id in reloading.iter() {
1750            result.push((model_id.clone(), ModelStatus::Reloading));
1751        }
1752        drop(reloading);
1753
1754        // Get loaded models
1755        let engines = self
1756            .engines
1757            .read()
1758            .map_err(|_| MistralRsError::EnginePoisoned)?;
1759        for model_id in engines.keys() {
1760            result.push((model_id.clone(), ModelStatus::Loaded));
1761        }
1762        drop(engines);
1763
1764        // Get unloaded models
1765        let unloaded = self
1766            .unloaded_models
1767            .read()
1768            .map_err(|_| MistralRsError::EnginePoisoned)?;
1769        for model_id in unloaded.keys() {
1770            // Skip ids already listed above (e.g. while a reload is in flight)
1771            if !result.iter().any(|(id, _)| id == model_id) {
1772                result.push((model_id.clone(), ModelStatus::Unloaded));
1773            }
1774        }
1775
1776        Ok(result)
1777    }
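    // --- Illustrative usage (annotation, not part of the original source) ---
    // `list_models_with_status` merges the reloading, loaded, and unloaded maps,
    // skipping unloaded entries that are already listed (e.g. mid-reload). A sketch
    // printing an overview, same assumed handle:
    //
    //     fn print_overview(mistralrs: &MistralRs) -> Result<(), MistralRsError> {
    //         for (id, status) in mistralrs.list_models_with_status()? {
    //             println!("{id}: {status:?}");
    //         }
    //         Ok(())
    //     }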
1778}