Skip to main content

onde_mistralrs_core/
lib.rs

1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5    get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6    EngineInstruction, IntervalLogger, SearchEmbeddingModel, ENGINE_INSTRUCTIONS,
7    TERMINATE_ALL_NEXT_STEP,
8};
9use hf_hub::Cache;
10pub use lora::Ordering;
11pub use pipeline::ModelCategory;
12pub use pipeline::Pipeline;
13#[cfg(feature = "pyo3_macros")]
14use pyo3::exceptions::PyValueError;
15use std::collections::{HashMap, HashSet};
16use std::sync::OnceLock;
17use std::time::{Duration, Instant};
18use std::{
19    cell::RefCell,
20    error::Error,
21    fs::OpenOptions,
22    io::Write,
23    sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24    thread::{self, JoinHandle},
25    time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
/// Git revision this crate was built from.
///
/// Read at compile time from the `MISTRALRS_GIT_REVISION` environment variable;
/// falls back to `"unknown"` when that variable was not set during the build.
pub const MISTRALRS_GIT_REVISION: &str = match option_env!("MISTRALRS_GIT_REVISION") {
    Some(value) => value,
    None => "unknown",
};
35
36mod cuda;
37mod device_map;
38mod engine;
39mod lora;
40mod metal;
41mod model_loader;
42mod moe;
43mod ops;
44mod video_input;
45pub use model_loader::{
46    get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
47};
48pub use video_input::{sample_frame_indices, VideoInput};
49mod embedding_models;
50mod kv_cache;
51mod search;
52
53mod model_selected;
54pub use model_selected::ModelSelected;
55pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
56
57mod amoe;
58mod attention;
59mod diagnostics;
60mod diffusion_models;
61pub mod distributed;
62mod gguf;
63pub mod layers;
64mod layers_masker;
65mod layers_utils;
66pub mod matformer;
67mod mla;
68mod models;
69mod paged_attention;
70mod pipeline;
71mod prefix_cacher;
72pub mod reasoning_parsers;
73mod request;
74mod response;
75mod sampler;
76mod scheduler;
77mod sequence;
78mod speech_models;
79mod toml_selector;
80mod tools;
81mod topology;
82mod utils;
83mod vision_models;
84mod xlora_models;
85
86pub use diagnostics::{
87    check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
88    DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
89};
90mod tuning;
91pub use tuning::{
92    auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
93};
94
95pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
96pub use device_map::{
97    DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
98};
99pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
100pub use mistralrs_audio::AudioInput;
101pub use mistralrs_mcp::{
102    CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
103};
104pub use mistralrs_mcp::{
105    McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
106};
107pub use mistralrs_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
108pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
109pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
110pub use pipeline::{
111    chat_template::ChatTemplate, expand_isq_value, parse_isq_value, AdapterPaths, AnyMoeLoader,
112    AnyMoePipeline, AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams,
113    DiffusionLoader, DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader,
114    EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig,
115    GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder,
116    GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader,
117    LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader,
118    Modalities, ModelKind, ModelPaths, MultimodalLoader, MultimodalLoaderBuilder,
119    MultimodalLoaderType, MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader,
120    NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
121    Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
122    SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
123    UQFF_MULTI_FILE_DELIMITER,
124};
125pub use request::{
126    ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
127    LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
128    SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
129};
130pub use response::*;
131pub use sampler::{
132    CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
133    TopLogprob,
134};
135pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
136pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
137use serde::Serialize;
138pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
139use tokio::runtime::Runtime;
140use toml_selector::{TomlLoaderArgs, TomlSelector};
141pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
142pub use topology::{LayerTopology, Topology};
143pub use utils::debug::initialize_logging;
144pub use utils::memory_usage::MemoryUsage;
145pub use utils::normal::{ModelDType, TryIntoDType};
146pub use utils::{paged_attn_supported, using_flash_attn};
147
148// re-export llguidance for easier LlguidanceGrammar construction
149pub use llguidance;
150
/// `true` if `MISTRALRS_DEBUG=1`
///
/// Starts out `false`.
/// NOTE(review): presumably flipped while logging is initialized — confirm where it is stored.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache handle. Being a `OnceLock`, it can be
/// initialized at most once and is empty until then.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
154
/// Configuration for creating an engine instance.
///
/// Every field is forwarded to `Engine::new` when the engine thread is started.
/// The `Default` implementation documents the baseline values.
#[derive(Clone)]
pub struct EngineConfig {
    /// Forwarded as the engine's `no_kv_cache` flag. Default: `false`.
    pub no_kv_cache: bool,
    /// Forwarded as the engine's `no_prefix_cache` flag. Default: `false`.
    pub no_prefix_cache: bool,
    /// Forwarded as the engine's `prefix_cache_n` value. Default: `16`.
    /// NOTE(review): presumably the number of prefixes kept by the prefix cache — confirm in the engine.
    pub prefix_cache_n: usize,
    /// Forwarded as the engine's `disable_eos_stop` flag. Default: `false`.
    pub disable_eos_stop: bool,
    /// Whether the engine should log throughput. Default: `true`.
    pub throughput_logging_enabled: bool,
    /// Optional embedding model used for web search. Default: `None`.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional custom callback used to gather search results. Default: `None`.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks (with their tool definitions) keyed by tool name. Default: empty.
    pub tool_callbacks: tools::ToolCallbacksWithTools,
}
167
168impl Default for EngineConfig {
169    fn default() -> Self {
170        Self {
171            no_kv_cache: false,
172            no_prefix_cache: false,
173            prefix_cache_n: 16,
174            disable_eos_stop: false,
175            throughput_logging_enabled: true,
176            search_embedding_model: None,
177            search_callback: None,
178            tool_callbacks: HashMap::new(),
179        }
180    }
181}
182
/// Configuration for adding a model to MistralRs
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine settings used for the new model's engine.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration for this model; `None` means no MCP client.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config for enabling model unload/reload support.
    /// Without this, models cannot be unloaded and reloaded.
    pub loader_config: Option<ModelLoaderConfig>,
}
192
193impl AddModelConfig {
194    pub fn new(engine_config: EngineConfig) -> Self {
195        Self {
196            engine_config,
197            mcp_client_config: None,
198            loader_config: None,
199        }
200    }
201
202    pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
203        self.mcp_client_config = Some(mcp_config);
204        self
205    }
206
207    /// Set the loader config for enabling model unload/reload support.
208    /// Without this, models cannot be unloaded and reloaded.
209    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
210        self.loader_config = Some(loader_config);
211        self
212    }
213}
214
/// Metadata describing a running model, captured from its pipeline when the
/// engine instance is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, taken from the pipeline metadata.
    pub kind: ModelKind,
    /// Device reported by the pipeline.
    pub device: Device,
    /// Model category reported by the pipeline.
    pub category: ModelCategory,
    /// Input/output modalities, taken from the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length from the pipeline metadata; `None` for the
    /// `Diffusion` and `Speech` categories.
    pub max_seq_len: Option<usize>,
    /// Generation defaults reported by the pipeline, if any.
    pub generation_defaults: Option<ModelGenerationDefaults>,
}
224
/// Configuration for recreating a model loader when reloading an unloaded model.
/// This captures the essential parameters needed to reconstruct a loader.
///
/// Supplied through `with_loader_config` on `MistralRsBuilder` or `AddModelConfig`;
/// without it a model cannot be unloaded and reloaded.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// The model selection configuration (Plain, GGUF, Multimodal, etc.)
    pub model_selected: ModelSelected,
    /// Source of the HF token
    pub token_source: TokenSource,
    /// Optional HF revision
    pub hf_revision: Option<String>,
    /// Model data type
    pub dtype: ModelDType,
    /// Device to load the model on
    pub device: Device,
    /// Device mapping setting
    pub device_map_setting: DeviceMapSetting,
    /// In-situ quantization type; `None` when no ISQ type is set
    pub isq: Option<IsqType>,
    /// Paged attention configuration; `None` when not configured
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Whether to suppress logging during loading
    pub silent: bool,
    /// Chat template override
    pub chat_template: Option<String>,
    /// Explicit Jinja template path
    pub jinja_explicit: Option<String>,
}
252
/// State preserved when a model is unloaded.
/// This contains all the information needed to reload the model on demand.
///
/// Stored per model ID in the `unloaded_models` map of `MistralRs`.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Configuration to recreate the loader
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration
    pub scheduler_config: SchedulerConfig,
    /// Engine configuration
    pub engine_config: EngineConfig,
    /// MCP client configuration
    pub mcp_client_config: Option<McpClientConfig>,
    /// Model category (Text, Multimodal, etc.)
    pub category: ModelCategory,
    /// Model metadata configuration
    pub mistralrs_config: MistralRsConfig,
}
270
/// Internal structure to hold per-engine state
struct EngineInstance {
    /// Sending half of the request channel whose receiver is owned by the engine.
    sender: Sender<Request>,
    /// Join handle of the thread running the engine; checked with `is_finished`
    /// to decide whether the engine needs a reboot.
    engine_handler: JoinHandle<()>,
    /// Settings kept around so a replacement engine can be created for the same pipeline.
    reboot_state: RebootState,
    /// Model metadata captured from the pipeline at creation time.
    config: MistralRsConfig,
    /// Model category captured from the pipeline at creation time.
    category: ModelCategory,
    /// Interval logger shared with the engine.
    logger: Arc<IntervalLogger>,
}
280
/// The MistralRs struct handles sending requests to multiple engines.
/// It is the core multi-threaded component of mistral.rs, and uses `mpsc`
/// `Sender` and `Receiver` primitives to send and receive requests to the
/// appropriate engine based on model ID.
///
/// ## Lock Ordering Convention
///
/// This struct uses multiple `RwLock`s. To prevent deadlocks, locks must be
/// acquired in this order:
/// 1. `reloading_models`
/// 2. `engines`
/// 3. `unloaded_models`
/// 4. `default_engine_id`
/// 5. `model_aliases`
///
/// Use scope-based lock management and explicit `drop()` calls.
pub struct MistralRs {
    /// Running engines keyed by model ID.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models that have been unloaded but can be reloaded on demand
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Models currently being reloaded (to prevent concurrent reloads)
    reloading_models: RwLock<HashSet<String>>,
    /// Model ID used when none is specified; initialized to the first engine's ID.
    default_engine_id: RwLock<Option<String>>,
    /// Alternate IDs that resolve to primary model IDs.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log setting supplied through the builder.
    /// NOTE(review): presumably a log file path — confirm against the code that reads it.
    log: Option<String>,
    /// ID of the model this instance was created with (override or pipeline name).
    id: String,
    /// Construction time, in whole seconds since the UNIX epoch.
    creation_time: u64,
    /// Request ID counter; initialized to 1 (the warm-up dummy request uses 0).
    next_request_id: Mutex<RefCell<usize>>,
}
311
/// Snapshot of everything needed to create a replacement engine for the same
/// pipeline: the pipeline itself, the scheduler method, and the engine settings.
#[derive(Clone)]
struct RebootState {
    /// Pipeline shared with the engine.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration used for the engine.
    method: SchedulerConfig,
    // The following eight fields mirror `EngineConfig` one-to-one.
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// MCP client configuration the engine was created with, if any.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config for reloading after unload
    loader_config: Option<ModelLoaderConfig>,
}
328
/// Model status for loaded/unloaded state
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// The model has a running engine.
    Loaded,
    /// The model has been unloaded.
    Unloaded,
    /// The model is in the middle of being reloaded.
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Writes the lowercase status name (`loaded`, `unloaded`, `reloading`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            ModelStatus::Loaded => "loaded",
            ModelStatus::Unloaded => "unloaded",
            ModelStatus::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
346
/// Errors returned by `MistralRs` operations.
#[derive(Debug)]
pub enum MistralRsError {
    /// A lock guarding engine state was poisoned.
    EnginePoisoned,
    /// A request sender could not be used.
    SenderPoisoned,
    /// The requested model was not found (neither loaded nor unloaded)
    ModelNotFound(String),
    /// The model is currently being reloaded
    ModelReloading(String),
    /// Failed to reload the model
    ReloadFailed(String),
    /// Model does not have loader config for reloading
    NoLoaderConfig(String),
    /// Model is already loaded
    ModelAlreadyLoaded(String),
    /// Model is already unloaded
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    /// Displays the error using its `Debug` representation.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fresh format arguments on purpose: flags on `f` (width, `#`, ...) are not
        // forwarded to the Debug rendering.
        f.write_fmt(format_args!("{:?}", self))
    }
}

impl std::error::Error for MistralRsError {}
372
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Surface a `MistralRsError` to Python as a `ValueError` whose message is
    /// the error's `Debug` representation.
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
379
/// The MistralRsBuilder takes the pipeline and a scheduler method and constructs
/// an Engine and a MistralRs instance. The Engine runs on a separate thread, and the MistralRs
/// instance stays on the calling thread.
pub struct MistralRsBuilder {
    /// Pipeline the engine will run.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the engine.
    method: SchedulerConfig,
    /// Model ID to use instead of the pipeline name, if set.
    model_id_override: Option<String>,
    /// Optional log setting carried over to `MistralRs`.
    log: Option<String>,
    /// `None` falls back to `false` when the instance is built.
    no_kv_cache: Option<bool>,
    /// `None` falls back to `false` when the instance is built.
    no_prefix_cache: Option<bool>,
    /// `None` falls back to `16` when the instance is built.
    prefix_cache_n: Option<usize>,
    /// `None` falls back to `false` when the instance is built.
    disable_eos_stop: Option<bool>,
    /// Whether the engine should log throughput.
    throughput_logging_enabled: bool,
    /// Optional embedding model for web search.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional custom callback used to gather search results.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks (with their tool definitions) keyed by tool name.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader config enabling unload/reload support.
    loader_config: Option<ModelLoaderConfig>,
}
399
400impl MistralRsBuilder {
401    /// Creates a new builder with the given pipeline, scheduler method, logging flag,
402    /// and optional embedding model for web search. To override the search callback,
403    /// use `.with_search_callback(...)` on the builder.
404    pub fn new(
405        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
406        method: SchedulerConfig,
407        throughput_logging: bool,
408        search_embedding_model: Option<SearchEmbeddingModel>,
409    ) -> Self {
410        Self {
411            pipeline,
412            method,
413            model_id_override: None,
414            log: None,
415            no_kv_cache: None,
416            no_prefix_cache: None,
417            prefix_cache_n: None,
418            disable_eos_stop: None,
419            throughput_logging_enabled: throughput_logging,
420            search_embedding_model,
421            search_callback: None,
422            tool_callbacks: HashMap::new(),
423            mcp_client_config: None,
424            loader_config: None,
425        }
426    }
427
428    /// Override the model ID used by MistralRs. Defaults to the pipeline name.
429    pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
430        self.model_id_override = Some(model_id.into());
431        self
432    }
433
434    /// Set the loader config for enabling model unload/reload support.
435    /// Without this, models cannot be unloaded and reloaded.
436    pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
437        self.loader_config = Some(loader_config);
438        self
439    }
440    pub fn with_log(mut self, log: String) -> Self {
441        self.log = Some(log);
442        self
443    }
444    pub fn with_opt_log(mut self, log: Option<String>) -> Self {
445        self.log = log;
446        self
447    }
448    pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
449        self.no_kv_cache = Some(no_kv_cache);
450        self
451    }
452    pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
453        self.no_prefix_cache = Some(no_prefix_cache);
454        self
455    }
456    pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
457        self.prefix_cache_n = Some(prefix_cache_n);
458        self
459    }
460    pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
461        self.disable_eos_stop = Some(disable_eos_stop);
462        self
463    }
464
465    /// Use a custom callback to gather search results.
466    pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
467        self.search_callback = Some(search_callback);
468        self
469    }
470
471    /// Register a custom callback for the specified tool name.
472    pub fn with_tool_callback(
473        mut self,
474        name: impl Into<String>,
475        tool_callback: Arc<ToolCallback>,
476    ) -> Self {
477        let name = name.into();
478        // Wrap bare callback with a minimal tool definition.
479        self.tool_callbacks.insert(
480            name.clone(),
481            ToolCallbackWithTool {
482                callback: tool_callback,
483                tool: Tool {
484                    tp: ToolType::Function,
485                    function: Function {
486                        description: None,
487                        name,
488                        parameters: None,
489                        strict: None,
490                    },
491                },
492            },
493        );
494        self
495    }
496
497    /// Register a custom callback with its associated Tool definition. The Tool will be
498    /// automatically added to requests when tool callbacks are active.
499    pub fn with_tool_callback_and_tool(
500        mut self,
501        name: impl Into<String>,
502        tool_callback: Arc<ToolCallback>,
503        tool: Tool,
504    ) -> Self {
505        let name = name.into();
506        self.tool_callbacks.insert(
507            name,
508            ToolCallbackWithTool {
509                callback: tool_callback,
510                tool,
511            },
512        );
513        self
514    }
515
516    /// Configure MCP client to connect to external MCP servers.
517    pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
518        self.mcp_client_config = Some(config);
519        self
520    }
521
522    pub async fn build(self) -> Arc<MistralRs> {
523        MistralRs::new(self).await
524    }
525}
526
527impl Drop for MistralRs {
528    fn drop(&mut self) {
529        // Terminate all engines
530        if let Ok(engines) = self.engines.read() {
531            for (_, engine) in engines.iter() {
532                // Use try_send instead of blocking_send to avoid runtime panics
533                let _ = engine.sender.try_send(Request::Terminate);
534            }
535        }
536    }
537}
538
539impl MistralRs {
540    /// Create an engine instance with the given configuration
541    fn create_engine_instance(
542        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
543        method: SchedulerConfig,
544        config: EngineConfig,
545        reboot_state: RebootState,
546    ) -> Result<EngineInstance, String> {
547        let (tx, rx) = channel(10_000);
548
549        let pipeline_guard = pipeline.try_lock().unwrap();
550        let category = pipeline_guard.category();
551        let metadata = pipeline_guard.get_metadata();
552        let kind = metadata.kind.clone();
553        let device = pipeline_guard.device();
554        let modalities = metadata.modalities.clone();
555        let max_seq_len = match &category {
556            ModelCategory::Diffusion | ModelCategory::Speech => None,
557            _ => Some(metadata.max_seq_len),
558        };
559        let generation_defaults = pipeline_guard.generation_defaults();
560        let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
561        drop(pipeline_guard);
562
563        let logger = Arc::new(IntervalLogger::new(
564            Duration::from_secs(5),
565            encoder_cache_counters,
566        ));
567        let logger_for_engine = logger.clone();
568
569        info!("Pipeline input modalities are {:?}", &modalities.input);
570        info!("Pipeline output modalities are {:?}", &modalities.output);
571
572        let mistralrs_config = MistralRsConfig {
573            kind,
574            device,
575            category: category.clone(),
576            modalities,
577            max_seq_len,
578            generation_defaults,
579        };
580
581        let tx_for_engine = tx.clone();
582        let engine_handler = thread::spawn(move || {
583            #[cfg(feature = "metal")]
584            objc::rc::autoreleasepool(move || {
585                let rt = Runtime::new().unwrap();
586                rt.block_on(async move {
587                    let engine = Engine::new(
588                        tx_for_engine,
589                        rx,
590                        pipeline,
591                        method,
592                        config.no_kv_cache,
593                        config.no_prefix_cache,
594                        config.prefix_cache_n,
595                        config.disable_eos_stop,
596                        config.throughput_logging_enabled,
597                        config.search_embedding_model,
598                        config.search_callback.clone(),
599                        config.tool_callbacks.clone(),
600                        logger_for_engine,
601                    )
602                    .expect("Engine creation failed.");
603                    Arc::new(engine).run().await;
604                })
605            });
606
607            #[cfg(not(feature = "metal"))]
608            {
609                let rt = Runtime::new().unwrap();
610                rt.block_on(async move {
611                    let engine = Engine::new(
612                        tx_for_engine,
613                        rx,
614                        pipeline,
615                        method,
616                        config.no_kv_cache,
617                        config.no_prefix_cache,
618                        config.prefix_cache_n,
619                        config.disable_eos_stop,
620                        config.throughput_logging_enabled,
621                        config.search_embedding_model,
622                        config.search_callback.clone(),
623                        config.tool_callbacks.clone(),
624                        logger_for_engine,
625                    )
626                    .expect("Engine creation failed.");
627                    Arc::new(engine).run().await;
628                })
629            }
630        });
631
632        Ok(EngineInstance {
633            sender: tx,
634            engine_handler,
635            reboot_state,
636            config: mistralrs_config,
637            category,
638            logger,
639        })
640    }
641
642    async fn new(config: MistralRsBuilder) -> Arc<Self> {
643        info!("git revision: {MISTRALRS_GIT_REVISION}");
644        let MistralRsBuilder {
645            pipeline,
646            method,
647            model_id_override,
648            log,
649            no_kv_cache,
650            no_prefix_cache,
651            prefix_cache_n,
652            disable_eos_stop,
653            throughput_logging_enabled,
654            search_embedding_model,
655            search_callback,
656            mut tool_callbacks,
657            mcp_client_config,
658            loader_config,
659        } = config;
660
661        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
662            get_mut_arcmutex!(pipeline).device(),
663        );
664
665        let no_kv_cache = no_kv_cache.unwrap_or(false);
666        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
667        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
668        let disable_eos_stop = disable_eos_stop.unwrap_or(false);
669
670        // Initialize MCP client if configured
671        if let Some(config) = &mcp_client_config {
672            let mut mcp_client = McpClient::new(config.clone());
673            let total_servers = config.servers.len();
674
675            match mcp_client.initialize().await {
676                Ok(()) => {
677                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
678                    let tools_count = mcp_callbacks_with_tools.len();
679
680                    // Merge MCP tool callbacks with tools into the new collection
681                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
682                        tool_callbacks.insert(name.clone(), callback_with_tool.clone());
683                    }
684
685                    if tools_count == 0 {
686                        warn!(
687                            "MCP client initialized but no tools were registered from {} servers",
688                            total_servers
689                        );
690                    } else {
691                        info!(
692                            "MCP client initialized successfully with {} tools from {} servers",
693                            tools_count, total_servers
694                        );
695                    }
696                }
697                Err(e) => {
698                    warn!(
699                        "Failed to initialize MCP client with {} configured servers: {}",
700                        total_servers, e
701                    );
702                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
703                }
704            }
705        }
706
707        let reboot_state = RebootState {
708            pipeline: pipeline.clone(),
709            method: method.clone(),
710            no_kv_cache,
711            no_prefix_cache,
712            prefix_cache_n,
713            disable_eos_stop,
714            throughput_logging_enabled,
715            search_embedding_model,
716            search_callback: search_callback.clone(),
717            tool_callbacks: tool_callbacks.clone(),
718            mcp_client_config: mcp_client_config.clone(),
719            loader_config,
720        };
721
722        // Create the engine configuration
723        let engine_config = EngineConfig {
724            no_kv_cache,
725            no_prefix_cache,
726            prefix_cache_n,
727            disable_eos_stop,
728            throughput_logging_enabled,
729            search_embedding_model,
730            search_callback,
731            tool_callbacks,
732        };
733
734        // Create the engine instance
735        let engine_instance =
736            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
737                .expect("Failed to create engine instance");
738
739        let pipeline_name = pipeline.try_lock().unwrap().name();
740        let (id, alias_map) = match model_id_override {
741            Some(override_id) => {
742                let mut alias_map = HashMap::new();
743                if override_id != pipeline_name {
744                    alias_map.insert(pipeline_name.clone(), override_id.clone());
745                }
746                (override_id, alias_map)
747            }
748            None => (pipeline_name.clone(), HashMap::new()),
749        };
750
751        if distributed::is_daemon() {
752            let request_sender = engine_instance.sender.clone();
753
754            if cfg!(feature = "ring") {
755                // Ring daemon replicator
756                distributed::ring_daemon_replicator(request_sender);
757            } else {
758                // NCCL daemon replicator
759                distributed::nccl_daemon_replicator(request_sender);
760            }
761
762            #[allow(clippy::empty_loop)]
763            loop {}
764        }
765
766        // Determine if the current runtime is multi-threaded, as blocking operations are not allowed in single-threaded mode
767        let is_multi_threaded = tokio::runtime::Handle::try_current()
768            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
769
770        // Do a dummy run
771        if !distributed::is_daemon()
772            && is_multi_threaded
773            && matches!(
774                engine_instance.category,
775                ModelCategory::Text | ModelCategory::Multimodal { .. }
776            )
777        {
778            let clone_sender = engine_instance.sender.clone();
779            tokio::task::block_in_place(|| {
780                let (tx, mut rx) = channel(1);
781                let req = Request::Normal(Box::new(NormalRequest {
782                    id: 0,
783                    messages: RequestMessage::Completion {
784                        text: "hello".to_string(),
785                        echo_prompt: false,
786                        best_of: None,
787                    },
788                    sampling_params: SamplingParams {
789                        max_len: Some(1),
790                        ..SamplingParams::deterministic()
791                    },
792                    response: tx,
793                    return_logprobs: false,
794                    is_streaming: false,
795                    constraint: Constraint::None,
796                    suffix: None,
797                    tool_choice: None,
798                    tools: None,
799                    logits_processors: None,
800                    return_raw_logits: false,
801                    web_search_options: None,
802                    max_tool_rounds: None,
803                    tool_dispatch_url: None,
804                    model_id: None,
805                    truncate_sequence: false,
806                }));
807                info!("Beginning dummy run.");
808                let start = Instant::now();
809                clone_sender.blocking_send(req).unwrap();
810
811                // Drain all responses from the channel until it's closed
812                let mut received_any = false;
813                while let Some(_resp) = rx.blocking_recv() {
814                    received_any = true;
815                }
816
817                if received_any {
818                    let end = Instant::now();
819                    info!(
820                        "Dummy run completed in {}s.",
821                        end.duration_since(start).as_secs_f64()
822                    );
823                } else {
824                    warn!("Dummy run failed!");
825                }
826            });
827
828            // Reset logger counters so the dummy run doesn't pollute stats
829            engine_instance.logger.reset();
830        }
831
832        // Create engines map with the first engine
833        let mut engines = HashMap::new();
834        engines.insert(id.clone(), engine_instance);
835
836        Arc::new(Self {
837            engines: RwLock::new(engines),
838            unloaded_models: RwLock::new(HashMap::new()),
839            reloading_models: RwLock::new(HashSet::new()),
840            default_engine_id: RwLock::new(Some(id.clone())),
841            model_aliases: RwLock::new(alias_map),
842            log,
843            id,
844            creation_time: SystemTime::now()
845                .duration_since(UNIX_EPOCH)
846                .expect("Time travel has occurred!")
847                .as_secs(),
848            next_request_id: Mutex::new(RefCell::new(1)),
849        })
850    }
851
852    /// Attempts to reboot a specific engine by model_id
853    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
854        let mut engines = self.engines.write().map_err(|_| {
855            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
856            MistralRsError::EnginePoisoned
857        })?;
858
859        if let Some(engine_instance) = engines.get(model_id) {
860            if !engine_instance.engine_handler.is_finished() {
861                tracing::info!("Engine {} already running, returning ok", model_id);
862                return Ok(());
863            }
864
865            let reboot_state = engine_instance.reboot_state.clone();
866            let engine_config = EngineConfig {
867                no_kv_cache: reboot_state.no_kv_cache,
868                no_prefix_cache: reboot_state.no_prefix_cache,
869                prefix_cache_n: reboot_state.prefix_cache_n,
870                disable_eos_stop: reboot_state.disable_eos_stop,
871                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
872                search_embedding_model: reboot_state.search_embedding_model,
873                search_callback: reboot_state.search_callback.clone(),
874                tool_callbacks: reboot_state.tool_callbacks.clone(),
875            };
876            let new_engine_instance = Self::create_engine_instance(
877                reboot_state.pipeline.clone(),
878                reboot_state.method.clone(),
879                engine_config,
880                reboot_state,
881            )
882            .map_err(|e| {
883                tracing::error!("Failed to create new engine instance: {}", e);
884                MistralRsError::EnginePoisoned
885            })?;
886
887            engines.insert(model_id.to_string(), new_engine_instance);
888            tracing::info!("Successfully rebooted engine {}", model_id);
889            Ok(())
890        } else {
891            Err(MistralRsError::EnginePoisoned)
892        }
893    }
894
895    fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
896        let engines = self.engines.read().map_err(|_| {
897            tracing::warn!("Couldn't get read lock on engines!");
898            MistralRsError::EnginePoisoned
899        })?;
900
901        if let Some(engine_instance) = engines.get(model_id) {
902            Ok(engine_instance.engine_handler.is_finished())
903        } else {
904            Err(MistralRsError::EnginePoisoned)
905        }
906    }
907
908    /// Get sender for a specific model. If model_id is None, uses default engine.
909    /// If the model is unloaded, it will be automatically reloaded before returning the sender.
910    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
911        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
912
913        // Check if model is loaded
914        let is_loaded = {
915            let engines = self
916                .engines
917                .read()
918                .map_err(|_| MistralRsError::SenderPoisoned)?;
919            engines.contains_key(&resolved_model_id)
920        };
921
922        if is_loaded {
923            // Check if engine is dead and needs reboot
924            if self.engine_dead(&resolved_model_id)? {
925                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
926                self.reboot_engine(&resolved_model_id)?
927            }
928
929            let engines = self
930                .engines
931                .read()
932                .map_err(|_| MistralRsError::SenderPoisoned)?;
933            if let Some(engine_instance) = engines.get(&resolved_model_id) {
934                return Ok(engine_instance.sender.clone());
935            }
936        }
937
938        // Check if model is unloaded - trigger auto-reload
939        let is_unloaded = {
940            let unloaded = self
941                .unloaded_models
942                .read()
943                .map_err(|_| MistralRsError::EnginePoisoned)?;
944            unloaded.contains_key(&resolved_model_id)
945        };
946
947        if is_unloaded {
948            tracing::info!(
949                "Model {} is unloaded, triggering auto-reload",
950                resolved_model_id
951            );
952            self.reload_model_blocking(&resolved_model_id)?;
953
954            // After reload, get the sender
955            let engines = self
956                .engines
957                .read()
958                .map_err(|_| MistralRsError::SenderPoisoned)?;
959            if let Some(engine_instance) = engines.get(&resolved_model_id) {
960                return Ok(engine_instance.sender.clone());
961            }
962        }
963
964        Err(MistralRsError::ModelNotFound(resolved_model_id))
965    }
966
967    pub fn get_id(&self) -> String {
968        self.id.clone()
969    }
970
971    pub fn get_creation_time(&self) -> u64 {
972        self.creation_time
973    }
974
975    fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
976        let aliases = self
977            .model_aliases
978            .read()
979            .map_err(|_| MistralRsError::SenderPoisoned)?;
980        if let Some(primary_id) = aliases.get(model_id) {
981            Ok(primary_id.clone())
982        } else {
983            Ok(model_id.to_string())
984        }
985    }
986
987    fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
988        match model_id {
989            Some(id) => self.resolve_alias(id),
990            None => {
991                let default_lock = self
992                    .default_engine_id
993                    .read()
994                    .map_err(|_| MistralRsError::SenderPoisoned)?;
995                Ok(default_lock
996                    .as_ref()
997                    .ok_or(MistralRsError::EnginePoisoned)?
998                    .clone())
999            }
1000        }
1001    }
1002
1003    /// Register an alternate model ID that resolves to an existing model.
1004    pub fn register_model_alias(
1005        &self,
1006        alias: impl Into<String>,
1007        model_id: &str,
1008    ) -> Result<(), String> {
1009        let alias = alias.into();
1010        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1011
1012        if alias == resolved_model_id {
1013            return Ok(());
1014        }
1015
1016        let reloading = self
1017            .reloading_models
1018            .read()
1019            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1020        let model_reloading = reloading.contains(&resolved_model_id);
1021        let alias_conflict = reloading.contains(&alias);
1022        drop(reloading);
1023
1024        let engines = self
1025            .engines
1026            .read()
1027            .map_err(|_| "Failed to acquire read lock on engines")?;
1028        let model_loaded = engines.contains_key(&resolved_model_id);
1029        let alias_conflict = alias_conflict || engines.contains_key(&alias);
1030        drop(engines);
1031
1032        let unloaded = self
1033            .unloaded_models
1034            .read()
1035            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1036        let model_unloaded = unloaded.contains_key(&resolved_model_id);
1037        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
1038        drop(unloaded);
1039
1040        if !(model_loaded || model_unloaded || model_reloading) {
1041            return Err(format!("Model {resolved_model_id} not found"));
1042        }
1043
1044        if alias_conflict {
1045            return Err(format!(
1046                "Alias '{}' conflicts with an existing model ID",
1047                alias
1048            ));
1049        }
1050
1051        let mut aliases = self
1052            .model_aliases
1053            .write()
1054            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1055        if let Some(existing) = aliases.get(&alias) {
1056            if existing == &resolved_model_id {
1057                return Ok(());
1058            }
1059            return Err(format!(
1060                "Alias '{}' is already assigned to model '{}'",
1061                alias, existing
1062            ));
1063        }
1064        aliases.insert(alias, resolved_model_id);
1065        Ok(())
1066    }
1067
1068    /// Check if a model is known (loaded, unloaded, or reloading), resolving aliases if needed.
1069    pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
1070        let resolved_model_id = self.resolve_alias(model_id)?;
1071
1072        let reloading = self
1073            .reloading_models
1074            .read()
1075            .map_err(|_| MistralRsError::EnginePoisoned)?;
1076        if reloading.contains(&resolved_model_id) {
1077            return Ok(true);
1078        }
1079        drop(reloading);
1080
1081        let engines = self
1082            .engines
1083            .read()
1084            .map_err(|_| MistralRsError::EnginePoisoned)?;
1085        if engines.contains_key(&resolved_model_id) {
1086            return Ok(true);
1087        }
1088        drop(engines);
1089
1090        let unloaded = self
1091            .unloaded_models
1092            .read()
1093            .map_err(|_| MistralRsError::EnginePoisoned)?;
1094        if unloaded.contains_key(&resolved_model_id) {
1095            return Ok(true);
1096        }
1097
1098        Ok(false)
1099    }
1100
1101    /// Get the interval logger for a specific model. If model_id is None, uses default engine.
1102    pub fn get_logger(
1103        &self,
1104        model_id: Option<&str>,
1105    ) -> Result<Arc<IntervalLogger>, MistralRsError> {
1106        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1107
1108        let engines = self
1109            .engines
1110            .read()
1111            .map_err(|_| MistralRsError::SenderPoisoned)?;
1112        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1113            Ok(engine_instance.logger.clone())
1114        } else {
1115            Err(MistralRsError::EnginePoisoned)
1116        }
1117    }
1118
1119    /// Get model category for a specific model. If model_id is None, uses default engine.
1120    pub fn get_model_category(
1121        &self,
1122        model_id: Option<&str>,
1123    ) -> Result<ModelCategory, MistralRsError> {
1124        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1125
1126        let engines = self
1127            .engines
1128            .read()
1129            .map_err(|_| MistralRsError::SenderPoisoned)?;
1130        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1131            Ok(engine_instance.category.clone())
1132        } else {
1133            Err(MistralRsError::EnginePoisoned)
1134        }
1135    }
1136
1137    /// Get the maximum supported sequence length for a model, if applicable.
1138    pub fn max_sequence_length(
1139        &self,
1140        model_id: Option<&str>,
1141    ) -> Result<Option<usize>, MistralRsError> {
1142        let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1143
1144        let engines = self
1145            .engines
1146            .read()
1147            .map_err(|_| MistralRsError::SenderPoisoned)?;
1148        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1149            Ok(engine_instance.config.max_seq_len)
1150        } else {
1151            Err(MistralRsError::EnginePoisoned)
1152        }
1153    }
1154
1155    pub fn next_request_id(&self) -> usize {
1156        let l = self.next_request_id.lock().unwrap();
1157        let last = &mut *l.borrow_mut();
1158        let last_v = *last;
1159        *last += 1;
1160        last_v
1161    }
1162
1163    /// Add a new model engine to the MistralRs instance
1164    pub async fn add_model(
1165        &self,
1166        model_id: String,
1167        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
1168        method: SchedulerConfig,
1169        config: AddModelConfig,
1170    ) -> Result<(), String> {
1171        {
1172            let reloading = self
1173                .reloading_models
1174                .read()
1175                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1176            if reloading.contains(&model_id) {
1177                return Err(format!("Model {model_id} is currently reloading"));
1178            }
1179        }
1180        {
1181            let engines = self
1182                .engines
1183                .read()
1184                .map_err(|_| "Failed to acquire read lock on engines")?;
1185            if engines.contains_key(&model_id) {
1186                return Err(format!("Model {model_id} already exists"));
1187            }
1188        }
1189        {
1190            let unloaded = self
1191                .unloaded_models
1192                .read()
1193                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1194            if unloaded.contains_key(&model_id) {
1195                return Err(format!("Model {model_id} already exists (unloaded)"));
1196            }
1197        }
1198        {
1199            let aliases = self
1200                .model_aliases
1201                .read()
1202                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
1203            if aliases.contains_key(&model_id) {
1204                return Err(format!(
1205                    "Model ID '{}' conflicts with an existing alias",
1206                    model_id
1207                ));
1208            }
1209        }
1210
1211        let reboot_state = RebootState {
1212            pipeline: pipeline.clone(),
1213            method: method.clone(),
1214            no_kv_cache: config.engine_config.no_kv_cache,
1215            no_prefix_cache: config.engine_config.no_prefix_cache,
1216            prefix_cache_n: config.engine_config.prefix_cache_n,
1217            disable_eos_stop: config.engine_config.disable_eos_stop,
1218            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
1219            search_embedding_model: config.engine_config.search_embedding_model,
1220            search_callback: config.engine_config.search_callback.clone(),
1221            tool_callbacks: config.engine_config.tool_callbacks.clone(),
1222            mcp_client_config: config.mcp_client_config.clone(),
1223            loader_config: config.loader_config.clone(),
1224        };
1225
1226        let engine_instance =
1227            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
1228
1229        let mut engines = self
1230            .engines
1231            .write()
1232            .map_err(|_| "Failed to acquire write lock on engines")?;
1233        engines.insert(model_id.clone(), engine_instance);
1234
1235        // If this is the first model, set it as default
1236        if engines.len() == 1 {
1237            let mut default_lock = self
1238                .default_engine_id
1239                .write()
1240                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1241            *default_lock = Some(model_id.clone());
1242            info!("First model added, setting '{}' as default", model_id);
1243        }
1244
1245        Ok(())
1246    }
1247
1248    /// Remove a model engine from the MistralRs instance
1249    pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1250        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1251        let mut engines = self
1252            .engines
1253            .write()
1254            .map_err(|_| "Failed to acquire write lock on engines")?;
1255
1256        if engines.len() <= 1 {
1257            return Err("Cannot remove the last model from MistralRs".to_string());
1258        }
1259
1260        if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1261            // Send terminate signal to the engine
1262            let _ = engine_instance.sender.blocking_send(Request::Terminate);
1263
1264            // If this was the default engine, set a new default
1265            let mut default_lock = self
1266                .default_engine_id
1267                .write()
1268                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1269            if let Some(ref default_id) = *default_lock {
1270                if default_id == &resolved_model_id {
1271                    // Set the first available engine as the new default
1272                    *default_lock = engines.keys().next().cloned();
1273                }
1274            }
1275            drop(default_lock);
1276            drop(engines);
1277
1278            // Remove any aliases pointing to the removed model
1279            let mut aliases = self
1280                .model_aliases
1281                .write()
1282                .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1283            aliases.retain(|_, target| target != &resolved_model_id);
1284
1285            Ok(())
1286        } else {
1287            Err(format!("Model {resolved_model_id} not found"))
1288        }
1289    }
1290
1291    /// List all available model IDs
1292    pub fn list_models(&self) -> Result<Vec<String>, String> {
1293        let engines = self
1294            .engines
1295            .read()
1296            .map_err(|_| "Failed to acquire read lock on engines")?;
1297        Ok(engines.keys().cloned().collect())
1298    }
1299
1300    /// Get the current default model ID
1301    pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1302        let default_lock = self
1303            .default_engine_id
1304            .read()
1305            .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1306        Ok(default_lock.clone())
1307    }
1308
1309    /// Set the default model ID
1310    pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1311        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1312        let engines = self
1313            .engines
1314            .read()
1315            .map_err(|_| "Failed to acquire read lock on engines")?;
1316        if !engines.contains_key(&resolved_model_id) {
1317            return Err(format!("Model {resolved_model_id} not found"));
1318        }
1319        drop(engines);
1320
1321        let mut default_lock = self
1322            .default_engine_id
1323            .write()
1324            .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1325        let old_default = default_lock.clone();
1326        *default_lock = Some(resolved_model_id.clone());
1327
1328        // Log the change
1329        info!(
1330            "Default model changed: {:?} -> {:?}",
1331            old_default, resolved_model_id
1332        );
1333
1334        Ok(())
1335    }
1336
1337    /// Dispatch a request to the appropriate engine based on the model_id in the request
1338    pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
1339        let model_id = match &mut request {
1340            Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1341            _ => None, // Other request types don't specify model_id
1342        };
1343
1344        let sender = self.get_sender(model_id)?;
1345        sender
1346            .blocking_send(request)
1347            .map_err(|_| MistralRsError::SenderPoisoned)
1348    }
1349
1350    pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1351        if let Some(file) = &this.log {
1352            let mut f = OpenOptions::new()
1353                .append(true)
1354                .create(true) // Optionally create the file if it doesn't already exist
1355                .open(file)
1356                .expect("Unable to open file");
1357            let time = chrono::offset::Local::now();
1358            f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1359                .expect("Unable to write data");
1360        }
1361    }
1362
1363    pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1364        if let Some(file) = &this.log {
1365            let mut f = OpenOptions::new()
1366                .append(true)
1367                .create(true) // Optionally create the file if it doesn't already exist
1368                .open(file)
1369                .expect("Unable to open file");
1370            let time = chrono::offset::Local::now();
1371            let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1372            f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1373                .expect("Unable to write data");
1374        }
1375    }
1376
1377    pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1378        if let Some(file) = &this.log {
1379            let mut f = OpenOptions::new()
1380                .append(true)
1381                .create(true) // Optionally create the file if it doesn't already exist
1382                .open(file)
1383                .expect("Unable to open file");
1384            let time = chrono::offset::Local::now();
1385            f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1386                .expect("Unable to write data");
1387        }
1388    }
1389
1390    /// Get the number of tools available for a specific model (including MCP tools)
1391    pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1392        let resolved_model_id = self
1393            .resolve_alias_or_default(model_id)
1394            .map_err(|e| e.to_string())?;
1395
1396        let engines = self
1397            .engines
1398            .read()
1399            .map_err(|_| "Failed to acquire read lock on engines")?;
1400        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1401            Ok(engine_instance.reboot_state.tool_callbacks.len())
1402        } else {
1403            Err(format!("Model {resolved_model_id} not found"))
1404        }
1405    }
1406
1407    /// Check if MCP client is configured for a specific model
1408    pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1409        let resolved_model_id = self
1410            .resolve_alias_or_default(model_id)
1411            .map_err(|e| e.to_string())?;
1412
1413        let engines = self
1414            .engines
1415            .read()
1416            .map_err(|_| "Failed to acquire read lock on engines")?;
1417        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1418            Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1419        } else {
1420            Err(format!("Model {resolved_model_id} not found"))
1421        }
1422    }
1423
1424    /// Get config for a specific model
1425    pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1426        let resolved_model_id = self
1427            .resolve_alias_or_default(model_id)
1428            .map_err(|e| e.to_string())?;
1429
1430        let engines = self
1431            .engines
1432            .read()
1433            .map_err(|_| "Failed to acquire read lock on engines")?;
1434        if let Some(engine_instance) = engines.get(&resolved_model_id) {
1435            Ok(engine_instance.config.clone())
1436        } else {
1437            Err(format!("Model {resolved_model_id} not found"))
1438        }
1439    }
1440
1441    /// Unload a model from memory while preserving its configuration for later reload.
1442    /// The model can be reloaded automatically when a request is sent to it, or manually
1443    /// using `reload_model()`.
1444    ///
1445    /// Note: The model must have been added with a `ModelLoaderConfig` for auto-reload to work.
1446    /// Models added via `MistralRsBuilder` without explicit loader config cannot be reloaded.
1447    pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1448        let resolved_model_id = self.resolve_alias(model_id)?;
1449        // Check if already unloaded
1450        {
1451            let unloaded = self
1452                .unloaded_models
1453                .read()
1454                .map_err(|_| MistralRsError::EnginePoisoned)?;
1455            if unloaded.contains_key(&resolved_model_id) {
1456                return Err(MistralRsError::ModelAlreadyUnloaded(
1457                    resolved_model_id.clone(),
1458                ));
1459            }
1460        }
1461
1462        // Get the engine instance and create UnloadedModelState
1463        let mut engines = self
1464            .engines
1465            .write()
1466            .map_err(|_| MistralRsError::EnginePoisoned)?;
1467
1468        let engine_instance = engines
1469            .remove(&resolved_model_id)
1470            .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1471
1472        // Check if we have loader config for reloading
1473        let loader_config = engine_instance
1474            .reboot_state
1475            .loader_config
1476            .clone()
1477            .ok_or_else(|| MistralRsError::NoLoaderConfig(resolved_model_id.clone()))?;
1478
1479        // Create the unloaded state
1480        let unloaded_state = UnloadedModelState {
1481            loader_config,
1482            scheduler_config: engine_instance.reboot_state.method.clone(),
1483            engine_config: EngineConfig {
1484                no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1485                no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1486                prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1487                disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1488                throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1489                search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1490                search_callback: engine_instance.reboot_state.search_callback.clone(),
1491                tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1492            },
1493            mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1494            category: engine_instance.category.clone(),
1495            mistralrs_config: engine_instance.config.clone(),
1496        };
1497
1498        // Send terminate signal to the engine
1499        let _ = engine_instance.sender.try_send(Request::Terminate);
1500
1501        drop(engines);
1502
1503        // Store the unloaded state
1504        let mut unloaded = self
1505            .unloaded_models
1506            .write()
1507            .map_err(|_| MistralRsError::EnginePoisoned)?;
1508        unloaded.insert(resolved_model_id.to_string(), unloaded_state);
1509
1510        // Update default if needed
1511        let mut default_lock = self
1512            .default_engine_id
1513            .write()
1514            .map_err(|_| MistralRsError::EnginePoisoned)?;
1515        if let Some(ref default_id) = *default_lock {
1516            if default_id == &resolved_model_id {
1517                // Set the first available engine as the new default
1518                let engines = self
1519                    .engines
1520                    .read()
1521                    .map_err(|_| MistralRsError::EnginePoisoned)?;
1522                *default_lock = engines.keys().next().cloned();
1523            }
1524        }
1525
1526        info!("Model {} unloaded successfully", resolved_model_id);
1527        Ok(())
1528    }
1529
1530    /// Manually reload a previously unloaded model.
1531    /// This is also called automatically by `get_sender()` when a request targets an unloaded model.
1532    pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1533        let resolved_model_id = self.resolve_alias(model_id)?;
1534        // Check if already reloading
1535        {
1536            let reloading = self
1537                .reloading_models
1538                .read()
1539                .map_err(|_| MistralRsError::EnginePoisoned)?;
1540            if reloading.contains(&resolved_model_id) {
1541                return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
1542            }
1543        }
1544
1545        // Mark as reloading
1546        {
1547            let mut reloading = self
1548                .reloading_models
1549                .write()
1550                .map_err(|_| MistralRsError::EnginePoisoned)?;
1551            reloading.insert(resolved_model_id.clone());
1552        }
1553
1554        // Get the unloaded state
1555        let unloaded_state = {
1556            let unloaded = self
1557                .unloaded_models
1558                .read()
1559                .map_err(|_| MistralRsError::EnginePoisoned)?;
1560            unloaded
1561                .get(&resolved_model_id)
1562                .cloned()
1563                .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
1564        };
1565
1566        // Attempt to reload
1567        let result = self
1568            .do_reload_model(&resolved_model_id, unloaded_state)
1569            .await;
1570
1571        // Remove from reloading set
1572        {
1573            let mut reloading = self
1574                .reloading_models
1575                .write()
1576                .map_err(|_| MistralRsError::EnginePoisoned)?;
1577            reloading.remove(&resolved_model_id);
1578        }
1579
1580        result
1581    }
1582
1583    /// Internal method to perform the actual model reload
1584    async fn do_reload_model(
1585        &self,
1586        model_id: &str,
1587        unloaded_state: UnloadedModelState,
1588    ) -> Result<(), MistralRsError> {
1589        use crate::model_loader::LoaderBuilder;
1590
1591        info!("Reloading model: {}", model_id);
1592
1593        let loader_config = &unloaded_state.loader_config;
1594
1595        // Build the loader from the stored config
1596        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
1597            .with_chat_template(loader_config.chat_template.clone())
1598            .with_jinja_explicit(loader_config.jinja_explicit.clone())
1599            .build()
1600            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;
1601
1602        // Load the model
1603        let pipeline = loader
1604            .load_model_from_hf(
1605                None,
1606                loader_config.token_source.clone(),
1607                &loader_config.dtype,
1608                &loader_config.device,
1609                loader_config.silent,
1610                loader_config.device_map_setting.clone(),
1611                loader_config.isq,
1612                loader_config.paged_attn_config,
1613            )
1614            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;
1615
1616        // Create the reboot state
1617        let reboot_state = RebootState {
1618            pipeline: pipeline.clone(),
1619            method: unloaded_state.scheduler_config.clone(),
1620            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
1621            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
1622            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
1623            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
1624            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
1625            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
1626            search_callback: unloaded_state.engine_config.search_callback.clone(),
1627            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
1628            mcp_client_config: unloaded_state.mcp_client_config.clone(),
1629            loader_config: Some(unloaded_state.loader_config.clone()),
1630        };
1631
1632        // Create the engine instance
1633        let engine_instance = Self::create_engine_instance(
1634            pipeline,
1635            unloaded_state.scheduler_config,
1636            unloaded_state.engine_config,
1637            reboot_state,
1638        )
1639        .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;
1640
1641        // Add to engines map
1642        {
1643            let mut engines = self
1644                .engines
1645                .write()
1646                .map_err(|_| MistralRsError::EnginePoisoned)?;
1647            engines.insert(model_id.to_string(), engine_instance);
1648        }
1649
1650        // Remove from unloaded map
1651        {
1652            let mut unloaded = self
1653                .unloaded_models
1654                .write()
1655                .map_err(|_| MistralRsError::EnginePoisoned)?;
1656            unloaded.remove(model_id);
1657        }
1658
1659        info!("Model {} reloaded successfully", model_id);
1660        Ok(())
1661    }
1662
1663    /// Synchronous version of reload_model for use in non-async contexts.
1664    ///
1665    /// This method handles different runtime contexts appropriately:
1666    /// - If called from a multi-threaded tokio runtime, uses `block_in_place`
1667    /// - If called from a single-threaded runtime, returns an error (use `reload_model()` instead)
1668    /// - If called outside any runtime, creates a temporary runtime
1669    pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1670        match tokio::runtime::Handle::try_current() {
1671            Ok(handle) => {
1672                if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1673                    Err(MistralRsError::ReloadFailed(
1674                        "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1675                    ))
1676                } else {
1677                    tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1678                }
1679            }
1680            Err(_) => {
1681                let rt = tokio::runtime::Runtime::new().map_err(|e| {
1682                    MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1683                })?;
1684                rt.block_on(self.reload_model(model_id))
1685            }
1686        }
1687    }
1688
1689    /// List all unloaded model IDs
1690    pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1691        let unloaded = self
1692            .unloaded_models
1693            .read()
1694            .map_err(|_| MistralRsError::EnginePoisoned)?;
1695        Ok(unloaded.keys().cloned().collect())
1696    }
1697
1698    /// Check if a model is currently loaded (as opposed to unloaded)
1699    pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1700        let resolved_model_id = self.resolve_alias(model_id)?;
1701        let engines = self
1702            .engines
1703            .read()
1704            .map_err(|_| MistralRsError::EnginePoisoned)?;
1705        Ok(engines.contains_key(&resolved_model_id))
1706    }
1707
1708    /// Get the status of a model, or None if not found
1709    pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1710        let resolved_model_id = self.resolve_alias(model_id)?;
1711        // Check if reloading
1712        {
1713            let reloading = self
1714                .reloading_models
1715                .read()
1716                .map_err(|_| MistralRsError::EnginePoisoned)?;
1717            if reloading.contains(&resolved_model_id) {
1718                return Ok(Some(ModelStatus::Reloading));
1719            }
1720        }
1721
1722        // Check if loaded
1723        {
1724            let engines = self
1725                .engines
1726                .read()
1727                .map_err(|_| MistralRsError::EnginePoisoned)?;
1728            if engines.contains_key(&resolved_model_id) {
1729                return Ok(Some(ModelStatus::Loaded));
1730            }
1731        }
1732
1733        // Check if unloaded
1734        {
1735            let unloaded = self
1736                .unloaded_models
1737                .read()
1738                .map_err(|_| MistralRsError::EnginePoisoned)?;
1739            if unloaded.contains_key(&resolved_model_id) {
1740                return Ok(Some(ModelStatus::Unloaded));
1741            }
1742        }
1743
1744        Ok(None)
1745    }
1746
1747    /// List all models with their status
1748    pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1749        let mut result = Vec::new();
1750
1751        // Get reloading models
1752        let reloading = self
1753            .reloading_models
1754            .read()
1755            .map_err(|_| MistralRsError::EnginePoisoned)?;
1756        for model_id in reloading.iter() {
1757            result.push((model_id.clone(), ModelStatus::Reloading));
1758        }
1759        drop(reloading);
1760
1761        // Get loaded models
1762        let engines = self
1763            .engines
1764            .read()
1765            .map_err(|_| MistralRsError::EnginePoisoned)?;
1766        for model_id in engines.keys() {
1767            result.push((model_id.clone(), ModelStatus::Loaded));
1768        }
1769        drop(engines);
1770
1771        // Get unloaded models
1772        let unloaded = self
1773            .unloaded_models
1774            .read()
1775            .map_err(|_| MistralRsError::EnginePoisoned)?;
1776        for model_id in unloaded.keys() {
1777            // Skip if already in reloading
1778            if !result.iter().any(|(id, _)| id == model_id) {
1779                result.push((model_id.clone(), ModelStatus::Unloaded));
1780            }
1781        }
1782
1783        Ok(result)
1784    }
1785}