1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, SearchEmbeddingModel, ENGINE_INSTRUCTIONS, TERMINATE_ALL_NEXT_STEP,
7};
8use hf_hub::Cache;
9pub use lora::Ordering;
10pub use pipeline::ModelCategory;
11pub use pipeline::Pipeline;
12#[cfg(feature = "pyo3_macros")]
13use pyo3::exceptions::PyValueError;
14use std::collections::{HashMap, HashSet};
15use std::num::NonZeroUsize;
16use std::sync::OnceLock;
17use std::time::Instant;
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
/// Git revision baked in at build time via the `MISTRALRS_GIT_REVISION`
/// environment variable; falls back to `"unknown"` when it was not set.
pub const MISTRALRS_GIT_REVISION: &str = match option_env!("MISTRALRS_GIT_REVISION") {
    Some(value) => value,
    None => "unknown",
};
35
36mod cuda;
37mod device_map;
38mod engine;
39mod lora;
40mod model_loader;
41mod moe;
42mod ops;
43pub use model_loader::{
44 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
45};
46mod embedding_models;
47mod kv_cache;
48mod search;
49
50mod model_selected;
51pub use model_selected::ModelSelected;
52pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
53
54mod amoe;
55mod attention;
56mod diagnostics;
57mod diffusion_models;
58pub mod distributed;
59mod gguf;
60pub mod harmony;
61pub mod layers;
62mod layers_masker;
63mod layers_utils;
64pub mod matformer;
65mod mla;
66mod models;
67mod paged_attention;
68mod pipeline;
69mod prefix_cacher;
70mod request;
71mod response;
72mod sampler;
73mod scheduler;
74mod sequence;
75mod speech_models;
76pub mod think_tags;
77mod toml_selector;
78mod tools;
79mod topology;
80mod utils;
81mod vision_models;
82mod xlora_models;
83
84pub use diagnostics::{
85 check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
86 DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
87};
88mod tuning;
89pub use tuning::{
90 auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
91};
92
93pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
94pub use device_map::{
95 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
96};
97pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
98pub use mistralrs_audio::AudioInput;
99pub use mistralrs_mcp::{
100 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
101};
102pub use mistralrs_mcp::{
103 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
104};
105pub use mistralrs_quant::{IsqType, MULTI_LORA_DELIMITER};
106pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
107pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
108pub use pipeline::{
109 chat_template::ChatTemplate, parse_isq_value, AdapterPaths, AnyMoeLoader, AnyMoePipeline,
110 AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams, DiffusionLoader,
111 DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader, EmbeddingLoaderBuilder,
112 EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig, GGMLLoader,
113 GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder, GGUFSpecificConfig,
114 GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader, LlamaLoader,
115 Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader, Modalities, ModelKind,
116 ModelPaths, MultimodalPromptPrefixer, NormalLoader, NormalLoaderBuilder, NormalLoaderType,
117 NormalSpecificConfig, Phi2Loader, Phi3Loader, Phi3VLoader, Qwen2Loader, SpeculativeConfig,
118 SpeculativeLoader, SpeculativePipeline, SpeechLoader, SpeechPipeline, Starcoder2Loader,
119 SupportedModality, TokenSource, VisionLoader, VisionLoaderBuilder, VisionLoaderType,
120 VisionSpecificConfig, UQFF_MULTI_FILE_DELIMITER,
121};
122pub use request::{
123 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
124 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
125 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
126};
127pub use response::*;
128pub use sampler::{
129 CustomLogitsProcessor, DrySamplingParams, SamplingParams, StopTokens, TopLogprob,
130};
131pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
132pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
133use serde::Serialize;
134pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
135use tokio::runtime::Runtime;
136use toml_selector::{TomlLoaderArgs, TomlSelector};
137pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
138pub use topology::{LayerTopology, Topology};
139pub use utils::debug::initialize_logging;
140pub use utils::memory_usage::MemoryUsage;
141pub use utils::normal::{ModelDType, TryIntoDType};
142pub use utils::{paged_attn_supported, using_flash_attn};
143
144pub use llguidance;
146
// Crate-wide debug flag. NOTE(review): it is not set anywhere in this file —
// presumably toggled by logging initialization elsewhere in the crate; confirm.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache handle, set at most once (`OnceLock`).
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
150
/// Runtime settings applied when spawning an [`Engine`] for a model.
#[derive(Clone)]
pub struct EngineConfig {
    /// Disable the KV cache.
    pub no_kv_cache: bool,
    /// Disable the prefix cache.
    pub no_prefix_cache: bool,
    /// Capacity of the prefix cache (number of entries).
    pub prefix_cache_n: usize,
    /// Do not stop generation when the EOS token is produced.
    pub disable_eos_stop: bool,
    /// Emit throughput logging from the engine.
    pub throughput_logging_enabled: bool,
    /// Optional embedding model used by the search integration.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional custom callback implementing web search.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Named tool callbacks without attached tool schemas.
    pub tool_callbacks: tools::ToolCallbacks,
    /// Named tool callbacks paired with their `Tool` definitions.
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}
164
impl Default for EngineConfig {
    /// Defaults: both caches enabled, prefix cache of 16 entries, EOS stopping
    /// enabled, throughput logging on, and no search or tool callbacks.
    fn default() -> Self {
        Self {
            no_kv_cache: false,
            no_prefix_cache: false,
            prefix_cache_n: 16,
            disable_eos_stop: false,
            throughput_logging_enabled: true,
            search_embedding_model: None,
            search_callback: None,
            tool_callbacks: HashMap::new(),
            tool_callbacks_with_tools: HashMap::new(),
        }
    }
}
180
/// Configuration bundle for registering an additional model via
/// [`MistralRs::add_model`].
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine runtime settings for the new model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration, retained in the reboot/unload state.
    pub loader_config: Option<ModelLoaderConfig>,
}
190
191impl AddModelConfig {
192 pub fn new(engine_config: EngineConfig) -> Self {
193 Self {
194 engine_config,
195 mcp_client_config: None,
196 loader_config: None,
197 }
198 }
199
200 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
201 self.mcp_client_config = Some(mcp_config);
202 self
203 }
204
205 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
208 self.loader_config = Some(loader_config);
209 self
210 }
211}
212
/// Immutable facts about a loaded model, captured when its engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind as reported by the pipeline metadata.
    pub kind: ModelKind,
    /// Device the pipeline reported at engine creation.
    pub device: Device,
    /// High-level category (text, vision, diffusion, speech, ...).
    pub category: ModelCategory,
    /// Supported input/output modalities.
    pub modalities: Modalities,
    /// Maximum sequence length; `None` for diffusion and speech models.
    pub max_seq_len: Option<usize>,
}
221
/// Complete recipe for loading a model, retained so an unloaded model can be
/// loaded again on demand.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// Which model was selected.
    pub model_selected: ModelSelected,
    /// Where to obtain the Hugging Face access token from.
    pub token_source: TokenSource,
    /// Optional Hub revision (branch/tag/commit).
    pub hf_revision: Option<String>,
    /// Requested model dtype.
    pub dtype: ModelDType,
    /// Target device.
    pub device: Device,
    /// Device-mapping strategy.
    pub device_map_setting: DeviceMapSetting,
    /// Optional in-situ quantization type.
    pub isq: Option<IsqType>,
    /// Paged-attention settings, when enabled.
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Suppress loader progress output.
    pub silent: bool,
    /// Optional chat template override.
    pub chat_template: Option<String>,
    /// Optional explicit Jinja template.
    pub jinja_explicit: Option<String>,
}
249
/// Snapshot of a model's full configuration kept while it is unloaded, so the
/// auto-reload path in `get_sender` can bring it back.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Recipe for reloading the model from scratch.
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration to restore.
    pub scheduler_config: SchedulerConfig,
    /// Engine runtime settings to restore.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration to restore.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Model category recorded before unloading.
    pub category: ModelCategory,
    /// Model facts recorded before unloading.
    pub mistralrs_config: MistralRsConfig,
}
267
/// A running engine: its request channel, the thread driving it, and the
/// state required to reboot it if that thread dies.
struct EngineInstance {
    /// Channel for submitting `Request`s to the engine loop.
    sender: Sender<Request>,
    /// Handle of the OS thread running the engine's tokio runtime.
    engine_handler: JoinHandle<()>,
    /// Saved construction state used by `reboot_engine`.
    reboot_state: RebootState,
    /// Immutable model facts captured at creation.
    config: MistralRsConfig,
    /// Model category (also available as `config.category`).
    category: ModelCategory,
}
276
/// Top-level handle of mistral.rs: a registry of named engines plus
/// bookkeeping for aliases, unloaded models, and request IDs.
pub struct MistralRs {
    /// Loaded engines keyed by model ID.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models whose engines were torn down but can be reloaded.
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Model IDs currently in the middle of a reload.
    reloading_models: RwLock<HashSet<String>>,
    /// Model ID used when a request does not specify one.
    default_engine_id: RwLock<Option<String>>,
    /// Maps alias -> primary model ID.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log target. NOTE(review): not read in this file — presumably a
    /// log file path used elsewhere; confirm.
    log: Option<String>,
    /// Primary model ID of this instance.
    id: String,
    /// Unix timestamp (seconds) at which this instance was created.
    creation_time: u64,
    /// Monotonic counter backing `next_request_id`.
    next_request_id: Mutex<RefCell<usize>>,
}
307
/// Everything needed to respawn an engine thread with identical behavior
/// (see `reboot_engine`); mirrors the fields of [`EngineConfig`] plus the
/// pipeline and scheduler it was built with.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline; reused as-is when the engine is rebooted.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration the engine was started with.
    method: SchedulerConfig,
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
    /// Loader recipe, when available, for full reloads.
    loader_config: Option<ModelLoaderConfig>,
}
325
/// Lifecycle state of a model managed by [`MistralRs`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    /// The model has a running engine and can serve requests.
    Loaded,
    /// The model's engine was torn down; saved state allows reloading.
    Unloaded,
    /// The model is currently being reloaded.
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Formats the status as a lowercase keyword:
    /// `"loaded"`, `"unloaded"`, or `"reloading"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let text = match self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(text)
    }
}
343
/// Errors surfaced by the model-management and request-routing APIs of
/// [`MistralRs`].
#[derive(Debug)]
pub enum MistralRsError {
    /// A lock guarding engine state was poisoned (or an engine lookup failed).
    EnginePoisoned,
    /// A lock needed to obtain a request sender was poisoned.
    SenderPoisoned,
    /// No loaded or unloaded model matches the given ID.
    ModelNotFound(String),
    /// The model is mid-reload and cannot be used yet.
    ModelReloading(String),
    /// Reloading the model did not succeed.
    ReloadFailed(String),
    /// The model has no saved loader configuration.
    NoLoaderConfig(String),
    /// The model is already loaded.
    ModelAlreadyLoaded(String),
    /// The model is already unloaded.
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Delegate to the derived Debug representation — identical output to
        // the historical `write!(f, "{:?}", &self)`.
        std::fmt::Debug::fmt(self, f)
    }
}

impl std::error::Error for MistralRsError {}
369
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    /// Converts the error into a Python `ValueError` carrying its Debug text.
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}
376
/// Builder for [`MistralRs`].
///
/// Optional flags left as `None` fall back to defaults inside
/// `MistralRs::new`: `no_kv_cache`/`no_prefix_cache`/`disable_eos_stop`
/// default to `false` and `prefix_cache_n` to 16.
pub struct MistralRsBuilder {
    /// Pipeline driving the first engine.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the first engine.
    method: SchedulerConfig,
    /// Optional model ID overriding the pipeline's own name.
    model_id_override: Option<String>,
    log: Option<String>,
    no_kv_cache: Option<bool>,
    no_prefix_cache: Option<bool>,
    prefix_cache_n: Option<usize>,
    disable_eos_stop: Option<bool>,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    mcp_client_config: Option<McpClientConfig>,
    loader_config: Option<ModelLoaderConfig>,
}
397
398impl MistralRsBuilder {
399 pub fn new(
403 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
404 method: SchedulerConfig,
405 throughput_logging: bool,
406 search_embedding_model: Option<SearchEmbeddingModel>,
407 ) -> Self {
408 Self {
409 pipeline,
410 method,
411 model_id_override: None,
412 log: None,
413 no_kv_cache: None,
414 no_prefix_cache: None,
415 prefix_cache_n: None,
416 disable_eos_stop: None,
417 throughput_logging_enabled: throughput_logging,
418 search_embedding_model,
419 search_callback: None,
420 tool_callbacks: HashMap::new(),
421 tool_callbacks_with_tools: HashMap::new(),
422 mcp_client_config: None,
423 loader_config: None,
424 }
425 }
426
427 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
429 self.model_id_override = Some(model_id.into());
430 self
431 }
432
433 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
436 self.loader_config = Some(loader_config);
437 self
438 }
439 pub fn with_log(mut self, log: String) -> Self {
440 self.log = Some(log);
441 self
442 }
443 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
444 self.log = log;
445 self
446 }
447 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
448 self.no_kv_cache = Some(no_kv_cache);
449 self
450 }
451 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
452 self.no_prefix_cache = Some(no_prefix_cache);
453 self
454 }
455 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
456 self.prefix_cache_n = Some(prefix_cache_n);
457 self
458 }
459 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
460 self.disable_eos_stop = Some(disable_eos_stop);
461 self
462 }
463
464 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
466 self.search_callback = Some(search_callback);
467 self
468 }
469
470 pub fn with_tool_callback(
472 mut self,
473 name: impl Into<String>,
474 tool_callback: Arc<ToolCallback>,
475 ) -> Self {
476 self.tool_callbacks.insert(name.into(), tool_callback);
477 self
478 }
479
480 pub fn with_tool_callback_and_tool(
483 mut self,
484 name: impl Into<String>,
485 tool_callback: Arc<ToolCallback>,
486 tool: Tool,
487 ) -> Self {
488 let name = name.into();
489 self.tool_callbacks_with_tools.insert(
490 name,
491 ToolCallbackWithTool {
492 callback: tool_callback,
493 tool,
494 },
495 );
496 self
497 }
498
499 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
501 self.mcp_client_config = Some(config);
502 self
503 }
504
505 pub async fn build(self) -> Arc<MistralRs> {
506 MistralRs::new(self).await
507 }
508}
509
510impl Drop for MistralRs {
511 fn drop(&mut self) {
512 if let Ok(engines) = self.engines.read() {
514 for (_, engine) in engines.iter() {
515 let _ = engine.sender.try_send(Request::Terminate);
517 }
518 }
519 }
520}
521
522impl MistralRs {
    /// Creates an engine instance: reads model metadata from the pipeline,
    /// then spawns a dedicated OS thread that drives the engine loop on its
    /// own tokio [`Runtime`].
    ///
    /// Panics if `pipeline` is currently locked (`try_lock().unwrap()`); the
    /// spawned thread panics if `Engine::new` fails.
    fn create_engine_instance(
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: EngineConfig,
        reboot_state: RebootState,
    ) -> Result<EngineInstance, String> {
        // Bounded request queue: up to 10k in-flight requests per engine.
        let (tx, rx) = channel(10_000);

        // Capture metadata up front so the pipeline lock is released before
        // the engine thread is spawned.
        let pipeline_guard = pipeline.try_lock().unwrap();
        let category = pipeline_guard.category();
        let metadata = pipeline_guard.get_metadata();
        let kind = metadata.kind.clone();
        let device = pipeline_guard.device();
        let modalities = metadata.modalities.clone();
        // Diffusion/speech models carry no meaningful max sequence length.
        let max_seq_len = match &category {
            ModelCategory::Diffusion | ModelCategory::Speech => None,
            _ => Some(metadata.max_seq_len),
        };
        drop(pipeline_guard);

        info!("Pipeline input modalities are {:?}", &modalities.input);
        info!("Pipeline output modalities are {:?}", &modalities.output);

        let mistralrs_config = MistralRsConfig {
            kind,
            device,
            category: category.clone(),
            modalities,
            max_seq_len,
        };

        let tx_for_engine = tx.clone();
        let engine_handler = thread::spawn(move || {
            // With the Metal feature the engine runs inside an Objective-C
            // autorelease pool; both cfg branches are otherwise identical.
            #[cfg(feature = "metal")]
            objc::rc::autoreleasepool(move || {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        tx_for_engine,
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            });

            #[cfg(not(feature = "metal"))]
            {
                let rt = Runtime::new().unwrap();
                rt.block_on(async move {
                    let engine = Engine::new(
                        tx_for_engine,
                        rx,
                        pipeline,
                        method,
                        config.no_kv_cache,
                        config.no_prefix_cache,
                        config.prefix_cache_n,
                        config.disable_eos_stop,
                        config.throughput_logging_enabled,
                        config.search_embedding_model,
                        config.search_callback.clone(),
                        config.tool_callbacks.clone(),
                        config.tool_callbacks_with_tools.clone(),
                    )
                    .expect("Engine creation failed.");
                    Arc::new(engine).run().await;
                })
            }
        });

        Ok(EngineInstance {
            sender: tx,
            engine_handler,
            reboot_state,
            config: mistralrs_config,
            category,
        })
    }
614
    /// Consumes the builder and assembles a running [`MistralRs`]: optionally
    /// initializes the MCP client, spawns the first engine, and — on a
    /// multi-threaded runtime — performs a one-token dummy completion as a
    /// warm-up before returning.
    ///
    /// In distributed daemon mode this function never returns (it parks the
    /// process in a replication loop).
    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        info!("git revision: {MISTRALRS_GIT_REVISION}");
        let MistralRsBuilder {
            pipeline,
            method,
            model_id_override,
            log,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
            loader_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        // Hybrid (Mamba-Attention) models are forced to batch size 1; the
        // requested scheduler is overridden in that case.
        let method = if !get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache
            && get_mut_arcmutex!(pipeline).cache().is_hybrid()
        {
            info!(
                "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
            );
            SchedulerConfig::DefaultScheduler {
                method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
            }
        } else {
            method
        };

        // Resolve builder options left unset to their defaults.
        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Best-effort MCP setup: initialization failures are logged and the
        // instance continues without MCP tools.
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    // Merge MCP-provided tools into the builder-supplied set.
                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        // Snapshot everything needed to reboot the engine later.
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
            loader_config,
        };

        let engine_config = EngineConfig {
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        // The instance ID is the override when given, otherwise the
        // pipeline's own name; with an override, the pipeline name becomes an
        // alias for it.
        let pipeline_name = pipeline.try_lock().unwrap().name();
        let (id, alias_map) = match model_id_override {
            Some(override_id) => {
                let mut alias_map = HashMap::new();
                if override_id != pipeline_name {
                    alias_map.insert(pipeline_name.clone(), override_id.clone());
                }
                (override_id, alias_map)
            }
            None => (pipeline_name.clone(), HashMap::new()),
        };

        // Distributed daemon processes only replicate requests into the
        // engine; they park here forever and never serve the public API.
        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                distributed::ring_daemon_replicator(request_sender);
            } else {
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // `block_in_place` below is only legal on a multi-threaded runtime.
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Dummy run: push one single-token completion through text/vision
        // engines so startup cost is paid before real traffic arrives.
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Vision { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                    truncate_sequence: false,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain all responses; receiving anything counts as success.
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });
        }

        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            unloaded_models: RwLock::new(HashMap::new()),
            reloading_models: RwLock::new(HashSet::new()),
            default_engine_id: RwLock::new(Some(id.clone())),
            model_aliases: RwLock::new(alias_map),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }
837
    /// Restarts the engine for `model_id` from its saved [`RebootState`] if
    /// its thread has exited; a still-running engine is left untouched.
    ///
    /// Returns `EnginePoisoned` on lock poisoning, unknown ID, or rebuild
    /// failure.
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            // A live engine needs no reboot.
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            // Rebuild the EngineConfig from the snapshot taken at creation.
            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model,
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            // Replace the dead instance; its old sender becomes useless.
            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }
881
882 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
883 let engines = self.engines.read().map_err(|_| {
884 tracing::warn!("Couldn't get read lock on engines!");
885 MistralRsError::EnginePoisoned
886 })?;
887
888 if let Some(engine_instance) = engines.get(model_id) {
889 Ok(engine_instance.engine_handler.is_finished())
890 } else {
891 Err(MistralRsError::EnginePoisoned)
892 }
893 }
894
    /// Returns the request channel for `model_id` (or the default model when
    /// `None`), rebooting a dead engine and auto-reloading an unloaded model
    /// as needed.
    ///
    /// # Errors
    /// `ModelNotFound` if the ID matches neither a loaded nor an unloaded
    /// model; `SenderPoisoned`/`EnginePoisoned` on poisoned locks.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        // Check the loaded engines first, holding the read lock only briefly.
        let is_loaded = {
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            engines.contains_key(&resolved_model_id)
        };

        if is_loaded {
            // Revive the engine if its thread has exited.
            if self.engine_dead(&resolved_model_id)? {
                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
                self.reboot_engine(&resolved_model_id)?
            }

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        // Not loaded — check whether it was unloaded and can be brought back.
        let is_unloaded = {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.contains_key(&resolved_model_id)
        };

        if is_unloaded {
            tracing::info!(
                "Model {} is unloaded, triggering auto-reload",
                resolved_model_id
            );
            self.reload_model_blocking(&resolved_model_id)?;

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        Err(MistralRsError::ModelNotFound(resolved_model_id))
    }
953
954 pub fn get_id(&self) -> String {
955 self.id.clone()
956 }
957
    /// Returns the Unix timestamp (seconds) at which this instance was created.
    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }
961
962 fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
963 let aliases = self
964 .model_aliases
965 .read()
966 .map_err(|_| MistralRsError::SenderPoisoned)?;
967 if let Some(primary_id) = aliases.get(model_id) {
968 Ok(primary_id.clone())
969 } else {
970 Ok(model_id.to_string())
971 }
972 }
973
974 fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
975 match model_id {
976 Some(id) => self.resolve_alias(id),
977 None => {
978 let default_lock = self
979 .default_engine_id
980 .read()
981 .map_err(|_| MistralRsError::SenderPoisoned)?;
982 Ok(default_lock
983 .as_ref()
984 .ok_or(MistralRsError::EnginePoisoned)?
985 .clone())
986 }
987 }
988 }
989
    /// Registers `alias` as an alternative name for `model_id`.
    ///
    /// Registering a model's own ID, or an identical existing mapping, is a
    /// no-op. Fails when the target model is unknown, when the alias collides
    /// with any model ID (loaded, unloaded, or reloading), or when the alias
    /// is already bound to a different model.
    pub fn register_model_alias(
        &self,
        alias: impl Into<String>,
        model_id: &str,
    ) -> Result<(), String> {
        let alias = alias.into();
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

        // Aliasing a model to its own ID is harmless; accept silently.
        if alias == resolved_model_id {
            return Ok(());
        }

        // Probe each registry under a short-lived read lock, accumulating
        // whether the target exists and whether the alias collides.
        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
        let model_reloading = reloading.contains(&resolved_model_id);
        let alias_conflict = reloading.contains(&alias);
        drop(reloading);

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        let model_loaded = engines.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || engines.contains_key(&alias);
        drop(engines);

        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
        let model_unloaded = unloaded.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
        drop(unloaded);

        if !(model_loaded || model_unloaded || model_reloading) {
            return Err(format!("Model {resolved_model_id} not found"));
        }

        if alias_conflict {
            return Err(format!(
                "Alias '{}' conflicts with an existing model ID",
                alias
            ));
        }

        let mut aliases = self
            .model_aliases
            .write()
            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
        // Re-registering the identical mapping is fine; rebinding is not.
        if let Some(existing) = aliases.get(&alias) {
            if existing == &resolved_model_id {
                return Ok(());
            }
            return Err(format!(
                "Alias '{}' is already assigned to model '{}'",
                alias, existing
            ));
        }
        aliases.insert(alias, resolved_model_id);
        Ok(())
    }
1054
1055 pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
1057 let resolved_model_id = self.resolve_alias(model_id)?;
1058
1059 let reloading = self
1060 .reloading_models
1061 .read()
1062 .map_err(|_| MistralRsError::EnginePoisoned)?;
1063 if reloading.contains(&resolved_model_id) {
1064 return Ok(true);
1065 }
1066 drop(reloading);
1067
1068 let engines = self
1069 .engines
1070 .read()
1071 .map_err(|_| MistralRsError::EnginePoisoned)?;
1072 if engines.contains_key(&resolved_model_id) {
1073 return Ok(true);
1074 }
1075 drop(engines);
1076
1077 let unloaded = self
1078 .unloaded_models
1079 .read()
1080 .map_err(|_| MistralRsError::EnginePoisoned)?;
1081 if unloaded.contains_key(&resolved_model_id) {
1082 return Ok(true);
1083 }
1084
1085 Ok(false)
1086 }
1087
1088 pub fn get_model_category(
1090 &self,
1091 model_id: Option<&str>,
1092 ) -> Result<ModelCategory, MistralRsError> {
1093 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1094
1095 let engines = self
1096 .engines
1097 .read()
1098 .map_err(|_| MistralRsError::SenderPoisoned)?;
1099 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1100 Ok(engine_instance.category.clone())
1101 } else {
1102 Err(MistralRsError::EnginePoisoned)
1103 }
1104 }
1105
1106 pub fn max_sequence_length(
1108 &self,
1109 model_id: Option<&str>,
1110 ) -> Result<Option<usize>, MistralRsError> {
1111 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1112
1113 let engines = self
1114 .engines
1115 .read()
1116 .map_err(|_| MistralRsError::SenderPoisoned)?;
1117 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1118 Ok(engine_instance.config.max_seq_len)
1119 } else {
1120 Err(MistralRsError::EnginePoisoned)
1121 }
1122 }
1123
1124 pub fn next_request_id(&self) -> usize {
1125 let l = self.next_request_id.lock().unwrap();
1126 let last = &mut *l.borrow_mut();
1127 let last_v = *last;
1128 *last += 1;
1129 last_v
1130 }
1131
    /// Registers another model under `model_id` and spawns its engine. The
    /// very first model ever inserted becomes the default engine.
    ///
    /// # Errors
    /// Returns `Err` when the ID is already used by a loaded, unloaded, or
    /// reloading model, conflicts with a registered alias, or when a registry
    /// lock is poisoned.
    ///
    /// Panics if `pipeline` is currently locked (`try_lock().unwrap()`).
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        // Reject IDs colliding with any existing registration; each registry
        // is checked under its own short-lived read lock.
        {
            let reloading = self
                .reloading_models
                .read()
                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
            if reloading.contains(&model_id) {
                return Err(format!("Model {model_id} is currently reloading"));
            }
        }
        {
            let engines = self
                .engines
                .read()
                .map_err(|_| "Failed to acquire read lock on engines")?;
            if engines.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists"));
            }
        }
        {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
            if unloaded.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists (unloaded)"));
            }
        }
        {
            let aliases = self
                .model_aliases
                .read()
                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
            if aliases.contains_key(&model_id) {
                return Err(format!(
                    "Model ID '{}' conflicts with an existing alias",
                    model_id
                ));
            }
        }

        // Hybrid (Mamba-Attention) models are forced to batch size 1,
        // mirroring the check in `MistralRs::new`.
        let method = {
            let pipeline_guard = pipeline.try_lock().unwrap();
            if !pipeline_guard.get_metadata().no_kv_cache && pipeline_guard.cache().is_hybrid() {
                info!(
                    "Hybrid model detected (Mamba-Attention), enforcing batch_size=1 for correctness"
                );
                SchedulerConfig::DefaultScheduler {
                    method: DefaultSchedulerMethod::Fixed(NonZeroUsize::new(1).unwrap()),
                }
            } else {
                method
            }
        };

        // Snapshot everything needed to reboot this engine later.
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model,
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
            loader_config: config.loader_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // The very first model becomes the default routing target.
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
            info!("First model added, setting '{}' as default", model_id);
        }

        Ok(())
    }
1232
1233 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1235 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1236 let mut engines = self
1237 .engines
1238 .write()
1239 .map_err(|_| "Failed to acquire write lock on engines")?;
1240
1241 if engines.len() <= 1 {
1242 return Err("Cannot remove the last model from MistralRs".to_string());
1243 }
1244
1245 if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1246 let _ = engine_instance.sender.blocking_send(Request::Terminate);
1248
1249 let mut default_lock = self
1251 .default_engine_id
1252 .write()
1253 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1254 if let Some(ref default_id) = *default_lock {
1255 if default_id == &resolved_model_id {
1256 *default_lock = engines.keys().next().cloned();
1258 }
1259 }
1260 drop(default_lock);
1261 drop(engines);
1262
1263 let mut aliases = self
1265 .model_aliases
1266 .write()
1267 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1268 aliases.retain(|_, target| target != &resolved_model_id);
1269
1270 Ok(())
1271 } else {
1272 Err(format!("Model {resolved_model_id} not found"))
1273 }
1274 }
1275
1276 pub fn list_models(&self) -> Result<Vec<String>, String> {
1278 let engines = self
1279 .engines
1280 .read()
1281 .map_err(|_| "Failed to acquire read lock on engines")?;
1282 Ok(engines.keys().cloned().collect())
1283 }
1284
1285 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1287 let default_lock = self
1288 .default_engine_id
1289 .read()
1290 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1291 Ok(default_lock.clone())
1292 }
1293
1294 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1296 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1297 let engines = self
1298 .engines
1299 .read()
1300 .map_err(|_| "Failed to acquire read lock on engines")?;
1301 if !engines.contains_key(&resolved_model_id) {
1302 return Err(format!("Model {resolved_model_id} not found"));
1303 }
1304 drop(engines);
1305
1306 let mut default_lock = self
1307 .default_engine_id
1308 .write()
1309 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1310 let old_default = default_lock.clone();
1311 *default_lock = Some(resolved_model_id.clone());
1312
1313 info!(
1315 "Default model changed: {:?} -> {:?}",
1316 old_default, resolved_model_id
1317 );
1318
1319 Ok(())
1320 }
1321
1322 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
1324 let model_id = match &mut request {
1325 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1326 _ => None, };
1328
1329 let sender = self.get_sender(model_id)?;
1330 sender
1331 .blocking_send(request)
1332 .map_err(|_| MistralRsError::SenderPoisoned)
1333 }
1334
1335 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1336 if let Some(file) = &this.log {
1337 let mut f = OpenOptions::new()
1338 .append(true)
1339 .create(true) .open(file)
1341 .expect("Unable to open file");
1342 let time = chrono::offset::Local::now();
1343 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1344 .expect("Unable to write data");
1345 }
1346 }
1347
1348 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1349 if let Some(file) = &this.log {
1350 let mut f = OpenOptions::new()
1351 .append(true)
1352 .create(true) .open(file)
1354 .expect("Unable to open file");
1355 let time = chrono::offset::Local::now();
1356 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1357 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1358 .expect("Unable to write data");
1359 }
1360 }
1361
1362 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1363 if let Some(file) = &this.log {
1364 let mut f = OpenOptions::new()
1365 .append(true)
1366 .create(true) .open(file)
1368 .expect("Unable to open file");
1369 let time = chrono::offset::Local::now();
1370 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1371 .expect("Unable to write data");
1372 }
1373 }
1374
1375 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1377 let resolved_model_id = self
1378 .resolve_alias_or_default(model_id)
1379 .map_err(|e| e.to_string())?;
1380
1381 let engines = self
1382 .engines
1383 .read()
1384 .map_err(|_| "Failed to acquire read lock on engines")?;
1385 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1386 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1387 } else {
1388 Err(format!("Model {resolved_model_id} not found"))
1389 }
1390 }
1391
1392 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1394 let resolved_model_id = self
1395 .resolve_alias_or_default(model_id)
1396 .map_err(|e| e.to_string())?;
1397
1398 let engines = self
1399 .engines
1400 .read()
1401 .map_err(|_| "Failed to acquire read lock on engines")?;
1402 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1403 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1404 } else {
1405 Err(format!("Model {resolved_model_id} not found"))
1406 }
1407 }
1408
1409 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1411 let resolved_model_id = self
1412 .resolve_alias_or_default(model_id)
1413 .map_err(|e| e.to_string())?;
1414
1415 let engines = self
1416 .engines
1417 .read()
1418 .map_err(|_| "Failed to acquire read lock on engines")?;
1419 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1420 Ok(engine_instance.config.clone())
1421 } else {
1422 Err(format!("Model {resolved_model_id} not found"))
1423 }
1424 }
1425
1426 pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1433 let resolved_model_id = self.resolve_alias(model_id)?;
1434 {
1436 let unloaded = self
1437 .unloaded_models
1438 .read()
1439 .map_err(|_| MistralRsError::EnginePoisoned)?;
1440 if unloaded.contains_key(&resolved_model_id) {
1441 return Err(MistralRsError::ModelAlreadyUnloaded(
1442 resolved_model_id.clone(),
1443 ));
1444 }
1445 }
1446
1447 let mut engines = self
1449 .engines
1450 .write()
1451 .map_err(|_| MistralRsError::EnginePoisoned)?;
1452
1453 let engine_instance = engines
1454 .remove(&resolved_model_id)
1455 .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1456
1457 let loader_config = engine_instance
1459 .reboot_state
1460 .loader_config
1461 .clone()
1462 .ok_or_else(|| MistralRsError::NoLoaderConfig(resolved_model_id.clone()))?;
1463
1464 let unloaded_state = UnloadedModelState {
1466 loader_config,
1467 scheduler_config: engine_instance.reboot_state.method.clone(),
1468 engine_config: EngineConfig {
1469 no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1470 no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1471 prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1472 disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1473 throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1474 search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1475 search_callback: engine_instance.reboot_state.search_callback.clone(),
1476 tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1477 tool_callbacks_with_tools: engine_instance
1478 .reboot_state
1479 .tool_callbacks_with_tools
1480 .clone(),
1481 },
1482 mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1483 category: engine_instance.category.clone(),
1484 mistralrs_config: engine_instance.config.clone(),
1485 };
1486
1487 let _ = engine_instance.sender.try_send(Request::Terminate);
1489
1490 drop(engines);
1491
1492 let mut unloaded = self
1494 .unloaded_models
1495 .write()
1496 .map_err(|_| MistralRsError::EnginePoisoned)?;
1497 unloaded.insert(resolved_model_id.to_string(), unloaded_state);
1498
1499 let mut default_lock = self
1501 .default_engine_id
1502 .write()
1503 .map_err(|_| MistralRsError::EnginePoisoned)?;
1504 if let Some(ref default_id) = *default_lock {
1505 if default_id == &resolved_model_id {
1506 let engines = self
1508 .engines
1509 .read()
1510 .map_err(|_| MistralRsError::EnginePoisoned)?;
1511 *default_lock = engines.keys().next().cloned();
1512 }
1513 }
1514
1515 info!("Model {} unloaded successfully", resolved_model_id);
1516 Ok(())
1517 }
1518
1519 pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1522 let resolved_model_id = self.resolve_alias(model_id)?;
1523 {
1525 let reloading = self
1526 .reloading_models
1527 .read()
1528 .map_err(|_| MistralRsError::EnginePoisoned)?;
1529 if reloading.contains(&resolved_model_id) {
1530 return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
1531 }
1532 }
1533
1534 {
1536 let mut reloading = self
1537 .reloading_models
1538 .write()
1539 .map_err(|_| MistralRsError::EnginePoisoned)?;
1540 reloading.insert(resolved_model_id.clone());
1541 }
1542
1543 let unloaded_state = {
1545 let unloaded = self
1546 .unloaded_models
1547 .read()
1548 .map_err(|_| MistralRsError::EnginePoisoned)?;
1549 unloaded
1550 .get(&resolved_model_id)
1551 .cloned()
1552 .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
1553 };
1554
1555 let result = self
1557 .do_reload_model(&resolved_model_id, unloaded_state)
1558 .await;
1559
1560 {
1562 let mut reloading = self
1563 .reloading_models
1564 .write()
1565 .map_err(|_| MistralRsError::EnginePoisoned)?;
1566 reloading.remove(&resolved_model_id);
1567 }
1568
1569 result
1570 }
1571
    /// Rebuild and register the engine for a previously unloaded model.
    ///
    /// Reconstructs the loader from the saved `UnloadedModelState`, reloads
    /// the pipeline weights, creates a fresh engine instance, and moves the
    /// model from `unloaded_models` back into `engines`. All failures are
    /// wrapped in `MistralRsError::ReloadFailed` with context.
    ///
    /// NOTE(review): the engine is inserted into `engines` while the caller
    /// (`reload_model`) still has the id in `reloading_models` — consumers
    /// observing both collections may briefly see the model in both states.
    async fn do_reload_model(
        &self,
        model_id: &str,
        unloaded_state: UnloadedModelState,
    ) -> Result<(), MistralRsError> {
        use crate::model_loader::LoaderBuilder;

        info!("Reloading model: {}", model_id);

        let loader_config = &unloaded_state.loader_config;

        // Rebuild the loader exactly as it was originally configured.
        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
            .with_chat_template(loader_config.chat_template.clone())
            .with_jinja_explicit(loader_config.jinja_explicit.clone())
            .build()
            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;

        // Reload the model weights with the saved device/dtype/quantization
        // settings; this is the expensive step.
        let pipeline = loader
            .load_model_from_hf(
                None,
                loader_config.token_source.clone(),
                &loader_config.dtype,
                &loader_config.device,
                loader_config.silent,
                loader_config.device_map_setting.clone(),
                loader_config.isq,
                loader_config.paged_attn_config,
            )
            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;

        // Snapshot the state needed to rebuild the engine again later
        // (e.g. after another unload/reload cycle or an engine restart).
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: unloaded_state.scheduler_config.clone(),
            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
            search_callback: unloaded_state.engine_config.search_callback.clone(),
            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: unloaded_state
                .engine_config
                .tool_callbacks_with_tools
                .clone(),
            mcp_client_config: unloaded_state.mcp_client_config.clone(),
            loader_config: Some(unloaded_state.loader_config.clone()),
        };

        let engine_instance = Self::create_engine_instance(
            pipeline,
            unloaded_state.scheduler_config,
            unloaded_state.engine_config,
            reboot_state,
        )
        .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;

        // Register the new engine, then clear the unloaded record. Each lock
        // is taken in its own scope so neither is held longer than needed.
        {
            let mut engines = self
                .engines
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            engines.insert(model_id.to_string(), engine_instance);
        }

        {
            let mut unloaded = self
                .unloaded_models
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.remove(model_id);
        }

        info!("Model {} reloaded successfully", model_id);
        Ok(())
    }
1655
1656 pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1663 match tokio::runtime::Handle::try_current() {
1664 Ok(handle) => {
1665 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1666 Err(MistralRsError::ReloadFailed(
1667 "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1668 ))
1669 } else {
1670 tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1671 }
1672 }
1673 Err(_) => {
1674 let rt = tokio::runtime::Runtime::new().map_err(|e| {
1675 MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1676 })?;
1677 rt.block_on(self.reload_model(model_id))
1678 }
1679 }
1680 }
1681
1682 pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1684 let unloaded = self
1685 .unloaded_models
1686 .read()
1687 .map_err(|_| MistralRsError::EnginePoisoned)?;
1688 Ok(unloaded.keys().cloned().collect())
1689 }
1690
1691 pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1693 let resolved_model_id = self.resolve_alias(model_id)?;
1694 let engines = self
1695 .engines
1696 .read()
1697 .map_err(|_| MistralRsError::EnginePoisoned)?;
1698 Ok(engines.contains_key(&resolved_model_id))
1699 }
1700
1701 pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1703 let resolved_model_id = self.resolve_alias(model_id)?;
1704 {
1706 let reloading = self
1707 .reloading_models
1708 .read()
1709 .map_err(|_| MistralRsError::EnginePoisoned)?;
1710 if reloading.contains(&resolved_model_id) {
1711 return Ok(Some(ModelStatus::Reloading));
1712 }
1713 }
1714
1715 {
1717 let engines = self
1718 .engines
1719 .read()
1720 .map_err(|_| MistralRsError::EnginePoisoned)?;
1721 if engines.contains_key(&resolved_model_id) {
1722 return Ok(Some(ModelStatus::Loaded));
1723 }
1724 }
1725
1726 {
1728 let unloaded = self
1729 .unloaded_models
1730 .read()
1731 .map_err(|_| MistralRsError::EnginePoisoned)?;
1732 if unloaded.contains_key(&resolved_model_id) {
1733 return Ok(Some(ModelStatus::Unloaded));
1734 }
1735 }
1736
1737 Ok(None)
1738 }
1739
1740 pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1742 let mut result = Vec::new();
1743
1744 let reloading = self
1746 .reloading_models
1747 .read()
1748 .map_err(|_| MistralRsError::EnginePoisoned)?;
1749 for model_id in reloading.iter() {
1750 result.push((model_id.clone(), ModelStatus::Reloading));
1751 }
1752 drop(reloading);
1753
1754 let engines = self
1756 .engines
1757 .read()
1758 .map_err(|_| MistralRsError::EnginePoisoned)?;
1759 for model_id in engines.keys() {
1760 result.push((model_id.clone(), ModelStatus::Loaded));
1761 }
1762 drop(engines);
1763
1764 let unloaded = self
1766 .unloaded_models
1767 .read()
1768 .map_err(|_| MistralRsError::EnginePoisoned)?;
1769 for model_id in unloaded.keys() {
1770 if !result.iter().any(|(id, _)| id == model_id) {
1772 result.push((model_id.clone(), ModelStatus::Unloaded));
1773 }
1774 }
1775
1776 Ok(result)
1777 }
1778}