Skip to main content

mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, IntervalLogger, SearchEmbeddingModel, ENGINE_INSTRUCTIONS,
7    TERMINATE_ALL_NEXT_STEP,
8};
9use hf_hub::Cache;
10pub use lora::Ordering;
11pub use pipeline::ModelCategory;
12pub use pipeline::Pipeline;
13#[cfg(feature = "pyo3_macros")]
14use pyo3::exceptions::PyValueError;
15use std::collections::{HashMap, HashSet};
16use std::sync::OnceLock;
17use std::time::{Duration, Instant};
18use std::{
19    cell::RefCell,
20    error::Error,
21    fs::OpenOptions,
22    io::Write,
23    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24    thread::{self, JoinHandle},
25    time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
/// Git revision this crate was built from.
///
/// Read from the `MISTRALRS_GIT_REVISION` environment variable at compile time;
/// falls back to `"unknown"` when the variable was not set during the build.
pub const MISTRALRS_GIT_REVISION: &str =
    if let Some(revision) = option_env!("MISTRALRS_GIT_REVISION") {
        revision
    } else {
        "unknown"
    };
35
36mod cuda;
37mod device_map;
38mod engine;
39mod lora;
40mod metal;
41mod model_loader;
42mod moe;
43mod ops;
44mod video_input;
45pub use model_loader::{
46    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
47};
48pub use video_input::{sample_frame_indices, VideoInput};
49mod embedding_models;
50mod kv_cache;
51mod search;
52
53mod model_selected;
54pub use model_selected::ModelSelected;
55pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
56
57mod amoe;
58mod attention;
59mod diagnostics;
60mod diffusion_models;
61pub mod distributed;
62mod gguf;
63pub mod layers;
64mod layers_masker;
65mod layers_utils;
66pub mod matformer;
67mod mla;
68mod models;
69mod paged_attention;
70mod pipeline;
71mod prefix_cacher;
72pub mod reasoning_parsers;
73mod request;
74mod response;
75mod sampler;
76mod scheduler;
77mod sequence;
78mod speech_models;
79mod toml_selector;
80mod tools;
81mod topology;
82mod utils;
83mod vision_models;
84mod xlora_models;
85
86pub use diagnostics::{
87    check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
88    DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
89};
90mod tuning;
91pub use tuning::{
92    auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
93};
94
95pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
96pub use device_map::{
97    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
98};
99pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
100pub use mistralrs_audio::AudioInput;
101pub use mistralrs_mcp::{
102    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
103};
104pub use mistralrs_mcp::{
105    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
106};
107pub use mistralrs_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
108pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
109pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
110pub use pipeline::{
111    chat_template::ChatTemplate, expand_isq_value, parse_isq_value, AdapterPaths, AnyMoeLoader,
112    AnyMoePipeline, AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams,
113    DiffusionLoader, DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader,
114    EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig,
115    GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder,
116    GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader,
117    LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader,
118    Modalities, ModelKind, ModelPaths, MultimodalLoader, MultimodalLoaderBuilder,
119    MultimodalLoaderType, MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader,
120    NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
121    Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
122    SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
123    UQFF_MULTI_FILE_DELIMITER,
124};
125pub use request::{
126    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
127    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
128    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
129};
130pub use response::*;
131pub use sampler::{
132    CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
133    TopLogprob,
134};
135pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
136pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
137use serde::Serialize;
138pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
139use tokio::runtime::Runtime;
140use toml_selector::{TomlLoaderArgs, TomlSelector};
141pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
142pub use topology::{LayerTopology, Topology};
143pub use utils::debug::initialize_logging;
144pub use utils::memory_usage::MemoryUsage;
145pub use utils::normal::{ModelDType, TryIntoDType};
146pub use utils::{paged_attn_supported, using_flash_attn};
147
148// re-export llguidance for easier LlguidanceGrammar construction
149pub use llguidance;
150
151/// `true` if `MISTRALRS_DEBUG=1`
152pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
153pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
154
155/// Configuration for creating an engine instance
156#[derive(Clone)]
157pub struct EngineConfig {
158    pub no_kv_cache: bool,
159    pub no_prefix_cache: bool,
160    pub prefix_cache_n: usize,
161    pub disable_eos_stop: bool,
162    pub throughput_logging_enabled: bool,
163    pub search_embedding_model: Option<SearchEmbeddingModel>,
164    pub search_callback: Option<Arc<SearchCallback>>,
165    pub tool_callbacks: tools::ToolCallbacks,
166    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
167}
168
169impl Default for EngineConfig {
170    fn default() -> Self {
171        Self {
172            no_kv_cache: false,
173            no_prefix_cache: false,
174            prefix_cache_n: 16,
175            disable_eos_stop: false,
176            throughput_logging_enabled: true,
177            search_embedding_model: None,
178            search_callback: None,
179            tool_callbacks: HashMap::new(),
180            tool_callbacks_with_tools: HashMap::new(),
181        }
182    }
183}
184
185/// Configuration for adding a model to MistralRs
186#[derive(Clone)]
187pub struct AddModelConfig {
188    pub engine_config: EngineConfig,
189    pub mcp_client_config: Option<McpClientConfig>,
190    /// Optional loader config for enabling model unload/reload support.
191    /// Without this, models cannot be unloaded and reloaded.
192    pub loader_config: Option<ModelLoaderConfig>,
193}
194
195impl AddModelConfig {
196    pub fn new(engine_config: EngineConfig) -> Self {
197        Self {
198            engine_config,
199            mcp_client_config: None,
200            loader_config: None,
201        }
202    }
203
204    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
205        self.mcp_client_config = Some(mcp_config);
206        self
207    }
208
209    /// Set the loader config for enabling model unload/reload support.
210    /// Without this, models cannot be unloaded and reloaded.
211    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
212        self.loader_config = Some(loader_config);
213        self
214    }
215}
216
217#[derive(Clone)]
218pub struct MistralRsConfig {
219    pub kind: ModelKind,
220    pub device: Device,
221    pub category: ModelCategory,
222    pub modalities: Modalities,
223    pub max_seq_len: Option<usize>,
224    pub generation_defaults: Option<ModelGenerationDefaults>,
225}
226
227/// Configuration for recreating a model loader when reloading an unloaded model.
228/// This captures the essential parameters needed to reconstruct a loader.
229#[derive(Clone)]
230pub struct ModelLoaderConfig {
231    /// The model selection configuration (Plain, GGUF, Multimodal, etc.)
232    pub model_selected: ModelSelected,
233    /// Source of the HF token
234    pub token_source: TokenSource,
235    /// Optional HF revision
236    pub hf_revision: Option<String>,
237    /// Model data type
238    pub dtype: ModelDType,
239    /// Device to load the model on
240    pub device: Device,
241    /// Device mapping setting
242    pub device_map_setting: DeviceMapSetting,
243    /// In-situ quantization type
244    pub isq: Option<IsqType>,
245    /// Paged attention configuration
246    pub paged_attn_config: Option<PagedAttentionConfig>,
247    /// Whether to suppress logging during loading
248    pub silent: bool,
249    /// Chat template override
250    pub chat_template: Option<String>,
251    /// Explicit Jinja template path
252    pub jinja_explicit: Option<String>,
253}
254
255/// State preserved when a model is unloaded.
256/// This contains all the information needed to reload the model on demand.
257#[derive(Clone)]
258pub struct UnloadedModelState {
259    /// Configuration to recreate the loader
260    pub loader_config: ModelLoaderConfig,
261    /// Scheduler configuration
262    pub scheduler_config: SchedulerConfig,
263    /// Engine configuration
264    pub engine_config: EngineConfig,
265    /// MCP client configuration
266    pub mcp_client_config: Option<McpClientConfig>,
267    /// Model category (Text, Multimodal, etc.)
268    pub category: ModelCategory,
269    /// Model metadata configuration
270    pub mistralrs_config: MistralRsConfig,
271}
272
273/// Internal structure to hold per-engine state
274struct EngineInstance {
275    sender: Sender<Request>,
276    engine_handler: JoinHandle<()>,
277    reboot_state: RebootState,
278    config: MistralRsConfig,
279    category: ModelCategory,
280    logger: Arc<IntervalLogger>,
281}
282
283/// The MistralRs struct handles sending requests to multiple engines.
284/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
285/// `Sender` and `Receiver` primitives to send and receive requests to the
286/// appropriate engine based on model ID.
287///
288/// ## Lock Ordering Convention
289///
290/// This struct uses multiple `RwLock`s. To prevent deadlocks, locks must be
291/// acquired in this order:
292/// 1. `reloading_models`
293/// 2. `engines`
294/// 3. `unloaded_models`
295/// 4. `default_engine_id`
296/// 5. `model_aliases`
297///
298/// Use scope-based lock management and explicit `drop()` calls.
299pub struct MistralRs {
300    engines: RwLock<HashMap<String, EngineInstance>>,
301    /// Models that have been unloaded but can be reloaded on demand
302    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
303    /// Models currently being reloaded (to prevent concurrent reloads)
304    reloading_models: RwLock<HashSet<String>>,
305    default_engine_id: RwLock<Option<String>>,
306    /// Alternate IDs that resolve to primary model IDs.
307    model_aliases: RwLock<HashMap<String, String>>,
308    log: Option<String>,
309    id: String,
310    creation_time: u64,
311    next_request_id: Mutex<RefCell<usize>>,
312}
313
314#[derive(Clone)]
315struct RebootState {
316    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
317    method: SchedulerConfig,
318    no_kv_cache: bool,
319    no_prefix_cache: bool,
320    prefix_cache_n: usize,
321    disable_eos_stop: bool,
322    throughput_logging_enabled: bool,
323    search_embedding_model: Option<SearchEmbeddingModel>,
324    search_callback: Option<Arc<search::SearchCallback>>,
325    tool_callbacks: tools::ToolCallbacks,
326    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
327    mcp_client_config: Option<McpClientConfig>,
328    /// Optional loader config for reloading after unload
329    loader_config: Option<ModelLoaderConfig>,
330}
331
/// Lifecycle state of a model: loaded, unloaded, or in the middle of a reload.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    Loaded,
    Unloaded,
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Writes the status as a single lowercase word.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
349
/// Errors reported by [`MistralRs`] operations.
#[derive(Debug)]
pub enum MistralRsError {
    /// The `engines` lock could not be acquired because it was poisoned.
    EnginePoisoned,
    /// Presumably a poisoned lock around an engine's request sender.
    SenderPoisoned,
    /// No loaded or unloaded model has the given ID.
    ModelNotFound(String),
    /// The model is in the middle of being reloaded.
    ModelReloading(String),
    /// Reloading the model failed.
    ReloadFailed(String),
    /// The model has no loader configuration, so it cannot be reloaded.
    NoLoaderConfig(String),
    /// The model is already loaded.
    ModelAlreadyLoaded(String),
    /// The model is already unloaded.
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    /// Uses the `Debug` representation (variant name plus payload) as the message.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for MistralRsError {}

/// Surfaces every `MistralRsError` to Python as a `ValueError` carrying its
/// `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}
382
383/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
384/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
385/// instance stays on the calling thread.
386pub struct MistralRsBuilder {
387    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
388    method: SchedulerConfig,
389    model_id_override: Option<String>,
390    log: Option<String>,
391    no_kv_cache: Option<bool>,
392    no_prefix_cache: Option<bool>,
393    prefix_cache_n: Option<usize>,
394    disable_eos_stop: Option<bool>,
395    throughput_logging_enabled: bool,
396    search_embedding_model: Option<SearchEmbeddingModel>,
397    search_callback: Option<Arc<SearchCallback>>,
398    tool_callbacks: tools::ToolCallbacks,
399    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
400    mcp_client_config: Option<McpClientConfig>,
401    loader_config: Option<ModelLoaderConfig>,
402}
403
404impl MistralRsBuilder {
405    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
406    /// and optional embedding model for web search. To override the search callback,
407    /// use `.with_search_callback(...)` on the builder.
408    pub fn new(
409        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
410        method: SchedulerConfig,
411        throughput_logging: bool,
412        search_embedding_model: Option<SearchEmbeddingModel>,
413    ) -> Self {
414        Self {
415            pipeline,
416            method,
417            model_id_override: None,
418            log: None,
419            no_kv_cache: None,
420            no_prefix_cache: None,
421            prefix_cache_n: None,
422            disable_eos_stop: None,
423            throughput_logging_enabled: throughput_logging,
424            search_embedding_model,
425            search_callback: None,
426            tool_callbacks: HashMap::new(),
427            tool_callbacks_with_tools: HashMap::new(),
428            mcp_client_config: None,
429            loader_config: None,
430        }
431    }
432
433    /// Override the model ID used by MistralRs. Defaults to the pipeline name.
434    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
435        self.model_id_override = Some(model_id.into());
436        self
437    }
438
439    /// Set the loader config for enabling model unload/reload support.
440    /// Without this, models cannot be unloaded and reloaded.
441    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
442        self.loader_config = Some(loader_config);
443        self
444    }
445    pub fn with_log(mut self, log: String) -> Self {
446        self.log = Some(log);
447        self
448    }
449    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
450        self.log = log;
451        self
452    }
453    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
454        self.no_kv_cache = Some(no_kv_cache);
455        self
456    }
457    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
458        self.no_prefix_cache = Some(no_prefix_cache);
459        self
460    }
461    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
462        self.prefix_cache_n = Some(prefix_cache_n);
463        self
464    }
465    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
466        self.disable_eos_stop = Some(disable_eos_stop);
467        self
468    }
469
470    /// Use a custom callback to gather search results.
471    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
472        self.search_callback = Some(search_callback);
473        self
474    }
475
476    /// Register a custom callback for the specified tool name.
477    pub fn with_tool_callback(
478        mut self,
479        name: impl Into<String>,
480        tool_callback: Arc<ToolCallback>,
481    ) -> Self {
482        self.tool_callbacks.insert(name.into(), tool_callback);
483        self
484    }
485
486    /// Register a custom callback with its associated Tool definition. The Tool will be
487    /// automatically added to requests when tool callbacks are active.
488    pub fn with_tool_callback_and_tool(
489        mut self,
490        name: impl Into<String>,
491        tool_callback: Arc<ToolCallback>,
492        tool: Tool,
493    ) -> Self {
494        let name = name.into();
495        self.tool_callbacks_with_tools.insert(
496            name,
497            ToolCallbackWithTool {
498                callback: tool_callback,
499                tool,
500            },
501        );
502        self
503    }
504
505    /// Configure MCP client to connect to external MCP servers.
506    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
507        self.mcp_client_config = Some(config);
508        self
509    }
510
511    pub async fn build(self) -> Arc<MistralRs> {
512        MistralRs::new(self).await
513    }
514}
515
516impl Drop for MistralRs {
517    fn drop(&mut self) {
518        // Terminate all engines
519        if let Ok(engines) = self.engines.read() {
520            for (_, engine) in engines.iter() {
521                // Use try_send instead of blocking_send to avoid runtime panics
522                let _ = engine.sender.try_send(Request::Terminate);
523            }
524        }
525    }
526}
527
528impl MistralRs {
529    /// Create an engine instance with the given configuration
530    fn create_engine_instance(
531        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
532        method: SchedulerConfig,
533        config: EngineConfig,
534        reboot_state: RebootState,
535    ) -> Result<EngineInstance, String> {
536        let (tx, rx) = channel(10_000);
537
538        let pipeline_guard = pipeline.try_lock().unwrap();
539        let category = pipeline_guard.category();
540        let metadata = pipeline_guard.get_metadata();
541        let kind = metadata.kind.clone();
542        let device = pipeline_guard.device();
543        let modalities = metadata.modalities.clone();
544        let max_seq_len = match &category {
545            ModelCategory::Diffusion | ModelCategory::Speech => None,
546            _ => Some(metadata.max_seq_len),
547        };
548        let generation_defaults = pipeline_guard.generation_defaults();
549        let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
550        drop(pipeline_guard);
551
552        let logger = Arc::new(IntervalLogger::new(
553            Duration::from_secs(5),
554            encoder_cache_counters,
555        ));
556        let logger_for_engine = logger.clone();
557
558        info!("Pipeline input modalities are {:?}", &modalities.input);
559        info!("Pipeline output modalities are {:?}", &modalities.output);
560
561        let mistralrs_config = MistralRsConfig {
562            kind,
563            device,
564            category: category.clone(),
565            modalities,
566            max_seq_len,
567            generation_defaults,
568        };
569
570        let tx_for_engine = tx.clone();
571        let engine_handler = thread::spawn(move || {
572            #[cfg(feature = "metal")]
573            objc::rc::autoreleasepool(move || {
574                let rt = Runtime::new().unwrap();
575                rt.block_on(async move {
576                    let engine = Engine::new(
577                        tx_for_engine,
578                        rx,
579                        pipeline,
580                        method,
581                        config.no_kv_cache,
582                        config.no_prefix_cache,
583                        config.prefix_cache_n,
584                        config.disable_eos_stop,
585                        config.throughput_logging_enabled,
586                        config.search_embedding_model,
587                        config.search_callback.clone(),
588                        config.tool_callbacks.clone(),
589                        config.tool_callbacks_with_tools.clone(),
590                        logger_for_engine,
591                    )
592                    .expect("Engine creation failed.");
593                    Arc::new(engine).run().await;
594                })
595            });
596
597            #[cfg(not(feature = "metal"))]
598            {
599                let rt = Runtime::new().unwrap();
600                rt.block_on(async move {
601                    let engine = Engine::new(
602                        tx_for_engine,
603                        rx,
604                        pipeline,
605                        method,
606                        config.no_kv_cache,
607                        config.no_prefix_cache,
608                        config.prefix_cache_n,
609                        config.disable_eos_stop,
610                        config.throughput_logging_enabled,
611                        config.search_embedding_model,
612                        config.search_callback.clone(),
613                        config.tool_callbacks.clone(),
614                        config.tool_callbacks_with_tools.clone(),
615                        logger_for_engine,
616                    )
617                    .expect("Engine creation failed.");
618                    Arc::new(engine).run().await;
619                })
620            }
621        });
622
623        Ok(EngineInstance {
624            sender: tx,
625            engine_handler,
626            reboot_state,
627            config: mistralrs_config,
628            category,
629            logger,
630        })
631    }
632
633    async fn new(config: MistralRsBuilder) -> Arc<Self> {
634        info!("git revision: {MISTRALRS_GIT_REVISION}");
635        let MistralRsBuilder {
636            pipeline,
637            method,
638            model_id_override,
639            log,
640            no_kv_cache,
641            no_prefix_cache,
642            prefix_cache_n,
643            disable_eos_stop,
644            throughput_logging_enabled,
645            search_embedding_model,
646            search_callback,
647            tool_callbacks,
648            mut tool_callbacks_with_tools,
649            mcp_client_config,
650            loader_config,
651        } = config;
652
653        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
654            get_mut_arcmutex!(pipeline).device(),
655        );
656
657        let no_kv_cache = no_kv_cache.unwrap_or(false);
658        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
659        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
660        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
661
662        // Initialize MCP client if configured
663        if let Some(config) = &mcp_client_config {
664            let mut mcp_client = McpClient::new(config.clone());
665            let total_servers = config.servers.len();
666
667            match mcp_client.initialize().await {
668                Ok(()) => {
669                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
670                    let tools_count = mcp_callbacks_with_tools.len();
671
672                    // Merge MCP tool callbacks with tools into the new collection
673                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
674                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
675                    }
676
677                    if tools_count == 0 {
678                        warn!(
679                            "MCP client initialized but no tools were registered from {} servers",
680                            total_servers
681                        );
682                    } else {
683                        info!(
684                            "MCP client initialized successfully with {} tools from {} servers",
685                            tools_count, total_servers
686                        );
687                    }
688                }
689                Err(e) => {
690                    warn!(
691                        "Failed to initialize MCP client with {} configured servers: {}",
692                        total_servers, e
693                    );
694                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
695                }
696            }
697        }
698
699        let reboot_state = RebootState {
700            pipeline: pipeline.clone(),
701            method: method.clone(),
702            no_kv_cache,
703            no_prefix_cache,
704            prefix_cache_n,
705            disable_eos_stop,
706            throughput_logging_enabled,
707            search_embedding_model,
708            search_callback: search_callback.clone(),
709            tool_callbacks: tool_callbacks.clone(),
710            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
711            mcp_client_config: mcp_client_config.clone(),
712            loader_config,
713        };
714
715        // Create the engine configuration
716        let engine_config = EngineConfig {
717            no_kv_cache,
718            no_prefix_cache,
719            prefix_cache_n,
720            disable_eos_stop,
721            throughput_logging_enabled,
722            search_embedding_model,
723            search_callback,
724            tool_callbacks,
725            tool_callbacks_with_tools,
726        };
727
728        // Create the engine instance
729        let engine_instance =
730            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
731                .expect("Failed to create engine instance");
732
733        let pipeline_name = pipeline.try_lock().unwrap().name();
734        let (id, alias_map) = match model_id_override {
735            Some(override_id) => {
736                let mut alias_map = HashMap::new();
737                if override_id != pipeline_name {
738                    alias_map.insert(pipeline_name.clone(), override_id.clone());
739                }
740                (override_id, alias_map)
741            }
742            None => (pipeline_name.clone(), HashMap::new()),
743        };
744
745        if distributed::is_daemon() {
746            let request_sender = engine_instance.sender.clone();
747
748            if cfg!(feature = "ring") {
749                // Ring daemon replicator
750                distributed::ring_daemon_replicator(request_sender);
751            } else {
752                // NCCL daemon replicator
753                distributed::nccl_daemon_replicator(request_sender);
754            }
755
756            #[allow(clippy::empty_loop)]
757            loop {}
758        }
759
760        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
761        let is_multi_threaded = tokio::runtime::Handle::try_current()
762            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
763
764        // Do a dummy run
765        if !distributed::is_daemon()
766            && is_multi_threaded
767            && matches!(
768                engine_instance.category,
769                ModelCategory::Text | ModelCategory::Multimodal { .. }
770            )
771        {
772            let clone_sender = engine_instance.sender.clone();
773            tokio::task::block_in_place(|| {
774                let (tx, mut rx) = channel(1);
775                let req = Request::Normal(Box::new(NormalRequest {
776                    id: 0,
777                    messages: RequestMessage::Completion {
778                        text: "hello".to_string(),
779                        echo_prompt: false,
780                        best_of: None,
781                    },
782                    sampling_params: SamplingParams {
783                        max_len: Some(1),
784                        ..SamplingParams::deterministic()
785                    },
786                    response: tx,
787                    return_logprobs: false,
788                    is_streaming: false,
789                    constraint: Constraint::None,
790                    suffix: None,
791                    tool_choice: None,
792                    tools: None,
793                    logits_processors: None,
794                    return_raw_logits: false,
795                    web_search_options: None,
796                    model_id: None,
797                    truncate_sequence: false,
798                }));
799                info!("Beginning dummy run.");
800                let start = Instant::now();
801                clone_sender.blocking_send(req).unwrap();
802
803                // Drain all responses from the channel until it's closed
804                let mut received_any = false;
805                while let Some(_resp) = rx.blocking_recv() {
806                    received_any = true;
807                }
808
809                if received_any {
810                    let end = Instant::now();
811                    info!(
812                        "Dummy run completed in {}s.",
813                        end.duration_since(start).as_secs_f64()
814                    );
815                } else {
816                    warn!("Dummy run failed!");
817                }
818            });
819
820            // Reset logger counters so the dummy run doesn't pollute stats
821            engine_instance.logger.reset();
822        }
823
824        // Create engines map with the first engine
825        let mut engines = HashMap::new();
826        engines.insert(id.clone(), engine_instance);
827
828        Arc::new(Self {
829            engines: RwLock::new(engines),
830            unloaded_models: RwLock::new(HashMap::new()),
831            reloading_models: RwLock::new(HashSet::new()),
832            default_engine_id: RwLock::new(Some(id.clone())),
833            model_aliases: RwLock::new(alias_map),
834            log,
835            id,
836            creation_time: SystemTime::now()
837                .duration_since(UNIX_EPOCH)
838                .expect("Time travel has occurred!")
839                .as_secs(),
840            next_request_id: Mutex::new(RefCell::new(1)),
841        })
842    }
843
844    /// Attempts to reboot a specific engine by model_id
845    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
846        let mut engines = self.engines.write().map_err(|_| {
847            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
848            MistralRsError::EnginePoisoned
849        })?;
850
851        if let Some(engine_instance) = engines.get(model_id) {
852            if !engine_instance.engine_handler.is_finished() {
853                tracing::info!("Engine {} already running, returning ok", model_id);
854                return Ok(());
855            }
856
857            let reboot_state = engine_instance.reboot_state.clone();
858            let engine_config = EngineConfig {
859                no_kv_cache: reboot_state.no_kv_cache,
860                no_prefix_cache: reboot_state.no_prefix_cache,
861                prefix_cache_n: reboot_state.prefix_cache_n,
862                disable_eos_stop: reboot_state.disable_eos_stop,
863                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
864                search_embedding_model: reboot_state.search_embedding_model,
865                search_callback: reboot_state.search_callback.clone(),
866                tool_callbacks: reboot_state.tool_callbacks.clone(),
867                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
868            };
869            let new_engine_instance = Self::create_engine_instance(
870                reboot_state.pipeline.clone(),
871                reboot_state.method.clone(),
872                engine_config,
873                reboot_state,
874            )
875            .map_err(|e| {
876                tracing::error!("Failed to create new engine instance: {}", e);
877                MistralRsError::EnginePoisoned
878            })?;
879
880            engines.insert(model_id.to_string(), new_engine_instance);
881            tracing::info!("Successfully rebooted engine {}", model_id);
882            Ok(())
883        } else {
884            Err(MistralRsError::EnginePoisoned)
885        }
886    }
887
888    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
889        let engines = self.engines.read().map_err(|_| {
890            tracing::warn!("Couldn't get read lock on engines!");
891            MistralRsError::EnginePoisoned
892        })?;
893
894        if let Some(engine_instance) = engines.get(model_id) {
895            Ok(engine_instance.engine_handler.is_finished())
896        } else {
897            Err(MistralRsError::EnginePoisoned)
898        }
899    }
900
    /// Get sender for a specific model. If model_id is None, uses default engine.
    /// If the model is unloaded, it will be automatically reloaded before returning the sender.
    ///
    /// Resolution order:
    /// 1. `model_id` is resolved through the alias table (or the default model is used).
    /// 2. If the model is loaded and its engine has finished, the engine is rebooted first.
    /// 3. Otherwise, if the model is in the unloaded set, `reload_model_blocking` is called
    ///    and the sender of the freshly inserted engine is returned.
    ///
    /// Each registry is inspected under its own short-lived lock, so the model's state can
    /// change between steps; in that case the lookup falls through to the next step.
    ///
    /// # Errors
    /// - `ModelNotFound` if the model is neither loaded nor unloaded (or disappeared
    ///   between checks).
    /// - `SenderPoisoned` / `EnginePoisoned` if a lock is poisoned, or the errors returned
    ///   by `engine_dead`, `reboot_engine` and `reload_model_blocking`.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        // Check if model is loaded
        let is_loaded = {
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            engines.contains_key(&resolved_model_id)
        };

        if is_loaded {
            // Check if engine is dead and needs reboot. The read lock above has been
            // released; `reboot_engine` takes the write lock itself.
            if self.engine_dead(&resolved_model_id)? {
                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
                self.reboot_engine(&resolved_model_id)?
            }

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
            // The model was removed/unloaded concurrently; fall through to the
            // unloaded-model path below.
        }

        // Check if model is unloaded - trigger auto-reload
        let is_unloaded = {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.contains_key(&resolved_model_id)
        };

        if is_unloaded {
            tracing::info!(
                "Model {} is unloaded, triggering auto-reload",
                resolved_model_id
            );
            // Blocks the calling thread until the reload finishes (helper defined
            // elsewhere in this impl).
            self.reload_model_blocking(&resolved_model_id)?;

            // After reload, get the sender
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        Err(MistralRsError::ModelNotFound(resolved_model_id))
    }
959
960    pub fn get_id(&self) -> String {
961        self.id.clone()
962    }
963
964    pub fn get_creation_time(&self) -> u64 {
965        self.creation_time
966    }
967
968    fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
969        let aliases = self
970            .model_aliases
971            .read()
972            .map_err(|_| MistralRsError::SenderPoisoned)?;
973        if let Some(primary_id) = aliases.get(model_id) {
974            Ok(primary_id.clone())
975        } else {
976            Ok(model_id.to_string())
977        }
978    }
979
980    fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
981        match model_id {
982            Some(id) => self.resolve_alias(id),
983            None => {
984                let default_lock = self
985                    .default_engine_id
986                    .read()
987                    .map_err(|_| MistralRsError::SenderPoisoned)?;
988                Ok(default_lock
989                    .as_ref()
990                    .ok_or(MistralRsError::EnginePoisoned)?
991                    .clone())
992            }
993        }
994    }
995
    /// Register an alternate model ID that resolves to an existing model.
    ///
    /// `model_id` may itself be an alias; the new alias is always stored against the
    /// resolved primary ID. Registering an alias equal to that resolved ID, or one that
    /// already maps to the same model, is a no-op returning `Ok(())`.
    ///
    /// # Errors
    /// Returns an error string if the target model is not loaded, unloaded, or reloading;
    /// if `alias` collides with a known model ID; if `alias` is already assigned to a
    /// different model; or if a lock is poisoned.
    pub fn register_model_alias(
        &self,
        alias: impl Into<String>,
        model_id: &str,
    ) -> Result<(), String> {
        let alias = alias.into();
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

        // Aliasing a model to its own ID is meaningless but harmless.
        if alias == resolved_model_id {
            return Ok(());
        }

        // Each registry is checked under its own short-lived read lock.
        // NOTE(review): state can change between these checks and the alias insert below
        // (e.g. the model is removed concurrently), so validation is best-effort.
        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
        let model_reloading = reloading.contains(&resolved_model_id);
        let alias_conflict = reloading.contains(&alias);
        drop(reloading);

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        let model_loaded = engines.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || engines.contains_key(&alias);
        drop(engines);

        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
        let model_unloaded = unloaded.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
        drop(unloaded);

        // The target must be known in at least one of the three registries.
        if !(model_loaded || model_unloaded || model_reloading) {
            return Err(format!("Model {resolved_model_id} not found"));
        }

        // An alias may not shadow a real model ID.
        if alias_conflict {
            return Err(format!(
                "Alias '{}' conflicts with an existing model ID",
                alias
            ));
        }

        let mut aliases = self
            .model_aliases
            .write()
            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
        // Re-registering the same mapping succeeds; remapping an alias is rejected.
        if let Some(existing) = aliases.get(&alias) {
            if existing == &resolved_model_id {
                return Ok(());
            }
            return Err(format!(
                "Alias '{}' is already assigned to model '{}'",
                alias, existing
            ));
        }
        aliases.insert(alias, resolved_model_id);
        Ok(())
    }
1060
1061    /// Check if a model is known (loaded, unloaded, or reloading), resolving aliases if needed.
1062    pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
1063        let resolved_model_id = self.resolve_alias(model_id)?;
1064
1065        let reloading = self
1066            .reloading_models
1067            .read()
1068            .map_err(|_| MistralRsError::EnginePoisoned)?;
1069        if reloading.contains(&resolved_model_id) {
1070            return Ok(true);
1071        }
1072        drop(reloading);
1073
1074        let engines = self
1075            .engines
1076            .read()
1077            .map_err(|_| MistralRsError::EnginePoisoned)?;
1078        if engines.contains_key(&resolved_model_id) {
1079            return Ok(true);
1080        }
1081        drop(engines);
1082
1083        let unloaded = self
1084            .unloaded_models
1085            .read()
1086            .map_err(|_| MistralRsError::EnginePoisoned)?;
1087        if unloaded.contains_key(&resolved_model_id) {
1088            return Ok(true);
1089        }
1090
1091        Ok(false)
1092    }
1093
1094    /// Get the interval logger for a specific model. If model_id is None, uses default engine.
1095    pub fn get_logger(
1096        &self,
1097        model_id: Option<&str>,
1098    ) -> Result<Arc<IntervalLogger>, MistralRsError> {
1099        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1100
1101        let engines = self
1102            .engines
1103            .read()
1104            .map_err(|_| MistralRsError::SenderPoisoned)?;
1105        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1106            Ok(engine_instance.logger.clone())
1107        } else {
1108            Err(MistralRsError::EnginePoisoned)
1109        }
1110    }
1111
1112    /// Get model category for a specific model. If model_id is None, uses default engine.
1113    pub fn get_model_category(
1114        &self,
1115        model_id: Option<&str>,
1116    ) -> Result<ModelCategory, MistralRsError> {
1117        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1118
1119        let engines = self
1120            .engines
1121            .read()
1122            .map_err(|_| MistralRsError::SenderPoisoned)?;
1123        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1124            Ok(engine_instance.category.clone())
1125        } else {
1126            Err(MistralRsError::EnginePoisoned)
1127        }
1128    }
1129
1130    /// Get the maximum supported sequence length for a model, if applicable.
1131    pub fn max_sequence_length(
1132        &self,
1133        model_id: Option<&str>,
1134    ) -> Result<Option<usize>, MistralRsError> {
1135        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1136
1137        let engines = self
1138            .engines
1139            .read()
1140            .map_err(|_| MistralRsError::SenderPoisoned)?;
1141        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1142            Ok(engine_instance.config.max_seq_len)
1143        } else {
1144            Err(MistralRsError::EnginePoisoned)
1145        }
1146    }
1147
1148    pub fn next_request_id(&self) -> usize {
1149        let l = self.next_request_id.lock().unwrap();
1150        let last = &mut *l.borrow_mut();
1151        let last_v = *last;
1152        *last += 1;
1153        last_v
1154    }
1155
1156    /// Add a new model engine to the MistralRs instance
1157    pub async fn add_model(
1158        &self,
1159        model_id: String,
1160        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
1161        method: SchedulerConfig,
1162        config: AddModelConfig,
1163    ) -> Result<(), String> {
1164        {
1165            let reloading = self
1166                .reloading_models
1167                .read()
1168                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1169            if reloading.contains(&model_id) {
1170                return Err(format!("Model {model_id} is currently reloading"));
1171            }
1172        }
1173        {
1174            let engines = self
1175                .engines
1176                .read()
1177                .map_err(|_| "Failed to acquire read lock on engines")?;
1178            if engines.contains_key(&model_id) {
1179                return Err(format!("Model {model_id} already exists"));
1180            }
1181        }
1182        {
1183            let unloaded = self
1184                .unloaded_models
1185                .read()
1186                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1187            if unloaded.contains_key(&model_id) {
1188                return Err(format!("Model {model_id} already exists (unloaded)"));
1189            }
1190        }
1191        {
1192            let aliases = self
1193                .model_aliases
1194                .read()
1195                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
1196            if aliases.contains_key(&model_id) {
1197                return Err(format!(
1198                    "Model ID '{}' conflicts with an existing alias",
1199                    model_id
1200                ));
1201            }
1202        }
1203
1204        let reboot_state = RebootState {
1205            pipeline: pipeline.clone(),
1206            method: method.clone(),
1207            no_kv_cache: config.engine_config.no_kv_cache,
1208            no_prefix_cache: config.engine_config.no_prefix_cache,
1209            prefix_cache_n: config.engine_config.prefix_cache_n,
1210            disable_eos_stop: config.engine_config.disable_eos_stop,
1211            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
1212            search_embedding_model: config.engine_config.search_embedding_model,
1213            search_callback: config.engine_config.search_callback.clone(),
1214            tool_callbacks: config.engine_config.tool_callbacks.clone(),
1215            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
1216            mcp_client_config: config.mcp_client_config.clone(),
1217            loader_config: config.loader_config.clone(),
1218        };
1219
1220        let engine_instance =
1221            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
1222
1223        let mut engines = self
1224            .engines
1225            .write()
1226            .map_err(|_| "Failed to acquire write lock on engines")?;
1227        engines.insert(model_id.clone(), engine_instance);
1228
1229        // If this is the first model, set it as default
1230        if engines.len() == 1 {
1231            let mut default_lock = self
1232                .default_engine_id
1233                .write()
1234                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1235            *default_lock = Some(model_id.clone());
1236            info!("First model added, setting '{}' as default", model_id);
1237        }
1238
1239        Ok(())
1240    }
1241
1242    /// Remove a model engine from the MistralRs instance
1243    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1244        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1245        let mut engines = self
1246            .engines
1247            .write()
1248            .map_err(|_| "Failed to acquire write lock on engines")?;
1249
1250        if engines.len() <= 1 {
1251            return Err("Cannot remove the last model from MistralRs".to_string());
1252        }
1253
1254        if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1255            // Send terminate signal to the engine
1256            let _ = engine_instance.sender.blocking_send(Request::Terminate);
1257
1258            // If this was the default engine, set a new default
1259            let mut default_lock = self
1260                .default_engine_id
1261                .write()
1262                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1263            if let Some(ref default_id) = *default_lock {
1264                if default_id == &resolved_model_id {
1265                    // Set the first available engine as the new default
1266                    *default_lock = engines.keys().next().cloned();
1267                }
1268            }
1269            drop(default_lock);
1270            drop(engines);
1271
1272            // Remove any aliases pointing to the removed model
1273            let mut aliases = self
1274                .model_aliases
1275                .write()
1276                .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1277            aliases.retain(|_, target| target != &resolved_model_id);
1278
1279            Ok(())
1280        } else {
1281            Err(format!("Model {resolved_model_id} not found"))
1282        }
1283    }
1284
1285    /// List all available model IDs
1286    pub fn list_models(&self) -> Result<Vec<String>, String> {
1287        let engines = self
1288            .engines
1289            .read()
1290            .map_err(|_| "Failed to acquire read lock on engines")?;
1291        Ok(engines.keys().cloned().collect())
1292    }
1293
1294    /// Get the current default model ID
1295    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1296        let default_lock = self
1297            .default_engine_id
1298            .read()
1299            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1300        Ok(default_lock.clone())
1301    }
1302
1303    /// Set the default model ID
1304    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1305        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1306        let engines = self
1307            .engines
1308            .read()
1309            .map_err(|_| "Failed to acquire read lock on engines")?;
1310        if !engines.contains_key(&resolved_model_id) {
1311            return Err(format!("Model {resolved_model_id} not found"));
1312        }
1313        drop(engines);
1314
1315        let mut default_lock = self
1316            .default_engine_id
1317            .write()
1318            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1319        let old_default = default_lock.clone();
1320        *default_lock = Some(resolved_model_id.clone());
1321
1322        // Log the change
1323        info!(
1324            "Default model changed: {:?} -> {:?}",
1325            old_default, resolved_model_id
1326        );
1327
1328        Ok(())
1329    }
1330
1331    /// Dispatch a request to the appropriate engine based on the model_id in the request
1332    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
1333        let model_id = match &mut request {
1334            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1335            _ => None, // Other request types don't specify model_id
1336        };
1337
1338        let sender = self.get_sender(model_id)?;
1339        sender
1340            .blocking_send(request)
1341            .map_err(|_| MistralRsError::SenderPoisoned)
1342    }
1343
1344    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1345        if let Some(file) = &this.log {
1346            let mut f = OpenOptions::new()
1347                .append(true)
1348                .create(true) // Optionally create the file if it doesn't already exist
1349                .open(file)
1350                .expect("Unable to open file");
1351            let time = chrono::offset::Local::now();
1352            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1353                .expect("Unable to write data");
1354        }
1355    }
1356
1357    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1358        if let Some(file) = &this.log {
1359            let mut f = OpenOptions::new()
1360                .append(true)
1361                .create(true) // Optionally create the file if it doesn't already exist
1362                .open(file)
1363                .expect("Unable to open file");
1364            let time = chrono::offset::Local::now();
1365            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1366            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1367                .expect("Unable to write data");
1368        }
1369    }
1370
1371    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1372        if let Some(file) = &this.log {
1373            let mut f = OpenOptions::new()
1374                .append(true)
1375                .create(true) // Optionally create the file if it doesn't already exist
1376                .open(file)
1377                .expect("Unable to open file");
1378            let time = chrono::offset::Local::now();
1379            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1380                .expect("Unable to write data");
1381        }
1382    }
1383
1384    /// Get the number of tools available for a specific model (including MCP tools)
1385    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1386        let resolved_model_id = self
1387            .resolve_alias_or_default(model_id)
1388            .map_err(|e| e.to_string())?;
1389
1390        let engines = self
1391            .engines
1392            .read()
1393            .map_err(|_| "Failed to acquire read lock on engines")?;
1394        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1395            Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1396        } else {
1397            Err(format!("Model {resolved_model_id} not found"))
1398        }
1399    }
1400
1401    /// Check if MCP client is configured for a specific model
1402    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1403        let resolved_model_id = self
1404            .resolve_alias_or_default(model_id)
1405            .map_err(|e| e.to_string())?;
1406
1407        let engines = self
1408            .engines
1409            .read()
1410            .map_err(|_| "Failed to acquire read lock on engines")?;
1411        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1412            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1413        } else {
1414            Err(format!("Model {resolved_model_id} not found"))
1415        }
1416    }
1417
1418    /// Get config for a specific model
1419    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1420        let resolved_model_id = self
1421            .resolve_alias_or_default(model_id)
1422            .map_err(|e| e.to_string())?;
1423
1424        let engines = self
1425            .engines
1426            .read()
1427            .map_err(|_| "Failed to acquire read lock on engines")?;
1428        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1429            Ok(engine_instance.config.clone())
1430        } else {
1431            Err(format!("Model {resolved_model_id} not found"))
1432        }
1433    }
1434
1435    /// Unload a model from memory while preserving its configuration for later reload.
1436    /// The model can be reloaded automatically when a request is sent to it, or manually
1437    /// using `reload_model()`.
1438    ///
1439    /// Note: The model must have been added with a `ModelLoaderConfig` for auto-reload to work.
1440    /// Models added via `MistralRsBuilder` without explicit loader config cannot be reloaded.
1441    pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1442        let resolved_model_id = self.resolve_alias(model_id)?;
1443        // Check if already unloaded
1444        {
1445            let unloaded = self
1446                .unloaded_models
1447                .read()
1448                .map_err(|_| MistralRsError::EnginePoisoned)?;
1449            if unloaded.contains_key(&resolved_model_id) {
1450                return Err(MistralRsError::ModelAlreadyUnloaded(
1451                    resolved_model_id.clone(),
1452                ));
1453            }
1454        }
1455
1456        // Get the engine instance and create UnloadedModelState
1457        let mut engines = self
1458            .engines
1459            .write()
1460            .map_err(|_| MistralRsError::EnginePoisoned)?;
1461
1462        let engine_instance = engines
1463            .remove(&resolved_model_id)
1464            .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1465
1466        // Check if we have loader config for reloading
1467        let loader_config = engine_instance
1468            .reboot_state
1469            .loader_config
1470            .clone()
1471            .ok_or_else(|| MistralRsError::NoLoaderConfig(resolved_model_id.clone()))?;
1472
1473        // Create the unloaded state
1474        let unloaded_state = UnloadedModelState {
1475            loader_config,
1476            scheduler_config: engine_instance.reboot_state.method.clone(),
1477            engine_config: EngineConfig {
1478                no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1479                no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1480                prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1481                disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1482                throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1483                search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1484                search_callback: engine_instance.reboot_state.search_callback.clone(),
1485                tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1486                tool_callbacks_with_tools: engine_instance
1487                    .reboot_state
1488                    .tool_callbacks_with_tools
1489                    .clone(),
1490            },
1491            mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1492            category: engine_instance.category.clone(),
1493            mistralrs_config: engine_instance.config.clone(),
1494        };
1495
1496        // Send terminate signal to the engine
1497        let _ = engine_instance.sender.try_send(Request::Terminate);
1498
1499        drop(engines);
1500
1501        // Store the unloaded state
1502        let mut unloaded = self
1503            .unloaded_models
1504            .write()
1505            .map_err(|_| MistralRsError::EnginePoisoned)?;
1506        unloaded.insert(resolved_model_id.to_string(), unloaded_state);
1507
1508        // Update default if needed
1509        let mut default_lock = self
1510            .default_engine_id
1511            .write()
1512            .map_err(|_| MistralRsError::EnginePoisoned)?;
1513        if let Some(ref default_id) = *default_lock {
1514            if default_id == &resolved_model_id {
1515                // Set the first available engine as the new default
1516                let engines = self
1517                    .engines
1518                    .read()
1519                    .map_err(|_| MistralRsError::EnginePoisoned)?;
1520                *default_lock = engines.keys().next().cloned();
1521            }
1522        }
1523
1524        info!("Model {} unloaded successfully", resolved_model_id);
1525        Ok(())
1526    }
1527
1528    /// Manually reload a previously unloaded model.
1529    /// This is also called automatically by `get_sender()` when a request targets an unloaded model.
1530    pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1531        let resolved_model_id = self.resolve_alias(model_id)?;
1532        // Check if already reloading
1533        {
1534            let reloading = self
1535                .reloading_models
1536                .read()
1537                .map_err(|_| MistralRsError::EnginePoisoned)?;
1538            if reloading.contains(&resolved_model_id) {
1539                return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
1540            }
1541        }
1542
1543        // Mark as reloading
1544        {
1545            let mut reloading = self
1546                .reloading_models
1547                .write()
1548                .map_err(|_| MistralRsError::EnginePoisoned)?;
1549            reloading.insert(resolved_model_id.clone());
1550        }
1551
1552        // Get the unloaded state
1553        let unloaded_state = {
1554            let unloaded = self
1555                .unloaded_models
1556                .read()
1557                .map_err(|_| MistralRsError::EnginePoisoned)?;
1558            unloaded
1559                .get(&resolved_model_id)
1560                .cloned()
1561                .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
1562        };
1563
1564        // Attempt to reload
1565        let result = self
1566            .do_reload_model(&resolved_model_id, unloaded_state)
1567            .await;
1568
1569        // Remove from reloading set
1570        {
1571            let mut reloading = self
1572                .reloading_models
1573                .write()
1574                .map_err(|_| MistralRsError::EnginePoisoned)?;
1575            reloading.remove(&resolved_model_id);
1576        }
1577
1578        result
1579    }
1580
1581    /// Internal method to perform the actual model reload
1582    async fn do_reload_model(
1583        &self,
1584        model_id: &str,
1585        unloaded_state: UnloadedModelState,
1586    ) -> Result<(), MistralRsError> {
1587        use crate::model_loader::LoaderBuilder;
1588
1589        info!("Reloading model: {}", model_id);
1590
1591        let loader_config = &unloaded_state.loader_config;
1592
1593        // Build the loader from the stored config
1594        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
1595            .with_chat_template(loader_config.chat_template.clone())
1596            .with_jinja_explicit(loader_config.jinja_explicit.clone())
1597            .build()
1598            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;
1599
1600        // Load the model
1601        let pipeline = loader
1602            .load_model_from_hf(
1603                None,
1604                loader_config.token_source.clone(),
1605                &loader_config.dtype,
1606                &loader_config.device,
1607                loader_config.silent,
1608                loader_config.device_map_setting.clone(),
1609                loader_config.isq,
1610                loader_config.paged_attn_config,
1611            )
1612            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;
1613
1614        // Create the reboot state
1615        let reboot_state = RebootState {
1616            pipeline: pipeline.clone(),
1617            method: unloaded_state.scheduler_config.clone(),
1618            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
1619            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
1620            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
1621            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
1622            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
1623            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
1624            search_callback: unloaded_state.engine_config.search_callback.clone(),
1625            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
1626            tool_callbacks_with_tools: unloaded_state
1627                .engine_config
1628                .tool_callbacks_with_tools
1629                .clone(),
1630            mcp_client_config: unloaded_state.mcp_client_config.clone(),
1631            loader_config: Some(unloaded_state.loader_config.clone()),
1632        };
1633
1634        // Create the engine instance
1635        let engine_instance = Self::create_engine_instance(
1636            pipeline,
1637            unloaded_state.scheduler_config,
1638            unloaded_state.engine_config,
1639            reboot_state,
1640        )
1641        .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;
1642
1643        // Add to engines map
1644        {
1645            let mut engines = self
1646                .engines
1647                .write()
1648                .map_err(|_| MistralRsError::EnginePoisoned)?;
1649            engines.insert(model_id.to_string(), engine_instance);
1650        }
1651
1652        // Remove from unloaded map
1653        {
1654            let mut unloaded = self
1655                .unloaded_models
1656                .write()
1657                .map_err(|_| MistralRsError::EnginePoisoned)?;
1658            unloaded.remove(model_id);
1659        }
1660
1661        info!("Model {} reloaded successfully", model_id);
1662        Ok(())
1663    }
1664
1665    /// Synchronous version of reload_model for use in non-async contexts.
1666    ///
1667    /// This method handles different runtime contexts appropriately:
1668    /// - If called from a multi-threaded tokio runtime, uses `block_in_place`
1669    /// - If called from a single-threaded runtime, returns an error (use `reload_model()` instead)
1670    /// - If called outside any runtime, creates a temporary runtime
1671    pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1672        match tokio::runtime::Handle::try_current() {
1673            Ok(handle) => {
1674                if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1675                    Err(MistralRsError::ReloadFailed(
1676                        "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1677                    ))
1678                } else {
1679                    tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1680                }
1681            }
1682            Err(_) => {
1683                let rt = tokio::runtime::Runtime::new().map_err(|e| {
1684                    MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1685                })?;
1686                rt.block_on(self.reload_model(model_id))
1687            }
1688        }
1689    }
1690
1691    /// List all unloaded model IDs
1692    pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1693        let unloaded = self
1694            .unloaded_models
1695            .read()
1696            .map_err(|_| MistralRsError::EnginePoisoned)?;
1697        Ok(unloaded.keys().cloned().collect())
1698    }
1699
1700    /// Check if a model is currently loaded (as opposed to unloaded)
1701    pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1702        let resolved_model_id = self.resolve_alias(model_id)?;
1703        let engines = self
1704            .engines
1705            .read()
1706            .map_err(|_| MistralRsError::EnginePoisoned)?;
1707        Ok(engines.contains_key(&resolved_model_id))
1708    }
1709
1710    /// Get the status of a model, or None if not found
1711    pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1712        let resolved_model_id = self.resolve_alias(model_id)?;
1713        // Check if reloading
1714        {
1715            let reloading = self
1716                .reloading_models
1717                .read()
1718                .map_err(|_| MistralRsError::EnginePoisoned)?;
1719            if reloading.contains(&resolved_model_id) {
1720                return Ok(Some(ModelStatus::Reloading));
1721            }
1722        }
1723
1724        // Check if loaded
1725        {
1726            let engines = self
1727                .engines
1728                .read()
1729                .map_err(|_| MistralRsError::EnginePoisoned)?;
1730            if engines.contains_key(&resolved_model_id) {
1731                return Ok(Some(ModelStatus::Loaded));
1732            }
1733        }
1734
1735        // Check if unloaded
1736        {
1737            let unloaded = self
1738                .unloaded_models
1739                .read()
1740                .map_err(|_| MistralRsError::EnginePoisoned)?;
1741            if unloaded.contains_key(&resolved_model_id) {
1742                return Ok(Some(ModelStatus::Unloaded));
1743            }
1744        }
1745
1746        Ok(None)
1747    }
1748
1749    /// List all models with their status
1750    pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1751        let mut result = Vec::new();
1752
1753        // Get reloading models
1754        let reloading = self
1755            .reloading_models
1756            .read()
1757            .map_err(|_| MistralRsError::EnginePoisoned)?;
1758        for model_id in reloading.iter() {
1759            result.push((model_id.clone(), ModelStatus::Reloading));
1760        }
1761        drop(reloading);
1762
1763        // Get loaded models
1764        let engines = self
1765            .engines
1766            .read()
1767            .map_err(|_| MistralRsError::EnginePoisoned)?;
1768        for model_id in engines.keys() {
1769            result.push((model_id.clone(), ModelStatus::Loaded));
1770        }
1771        drop(engines);
1772
1773        // Get unloaded models
1774        let unloaded = self
1775            .unloaded_models
1776            .read()
1777            .map_err(|_| MistralRsError::EnginePoisoned)?;
1778        for model_id in unloaded.keys() {
1779            // Skip if already in reloading
1780            if !result.iter().any(|(id, _)| id == model_id) {
1781                result.push((model_id.clone(), ModelStatus::Unloaded));
1782            }
1783        }
1784
1785        Ok(result)
1786    }
1787}