1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, IntervalLogger, SearchEmbeddingModel, ENGINE_INSTRUCTIONS,
7 TERMINATE_ALL_NEXT_STEP,
8};
9use hf_hub::Cache;
10pub use lora::Ordering;
11pub use pipeline::ModelCategory;
12pub use pipeline::Pipeline;
13#[cfg(feature = "pyo3_macros")]
14use pyo3::exceptions::PyValueError;
15use std::collections::{HashMap, HashSet};
16use std::sync::OnceLock;
17use std::time::{Duration, Instant};
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
/// Git revision this crate was built from, taken from the compile-time
/// `MISTRALRS_GIT_REVISION` environment variable, or `"unknown"` when unset.
pub const MISTRALRS_GIT_REVISION: &str =
    if let Some(revision) = option_env!("MISTRALRS_GIT_REVISION") {
        revision
    } else {
        "unknown"
    };
35
36mod cuda;
37mod device_map;
38mod engine;
39mod lora;
40mod metal;
41mod model_loader;
42mod moe;
43mod ops;
44mod video_input;
45pub use model_loader::{
46 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
47};
48pub use video_input::{sample_frame_indices, VideoInput};
49mod embedding_models;
50mod kv_cache;
51mod search;
52
53mod model_selected;
54pub use model_selected::ModelSelected;
55pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
56
57mod amoe;
58mod attention;
59mod diagnostics;
60mod diffusion_models;
61pub mod distributed;
62mod gguf;
63pub mod layers;
64mod layers_masker;
65mod layers_utils;
66pub mod matformer;
67mod mla;
68mod models;
69mod paged_attention;
70mod pipeline;
71mod prefix_cacher;
72pub mod reasoning_parsers;
73mod request;
74mod response;
75mod sampler;
76mod scheduler;
77mod sequence;
78mod speech_models;
79mod toml_selector;
80mod tools;
81mod topology;
82mod utils;
83mod vision_models;
84mod xlora_models;
85
86pub use diagnostics::{
87 check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
88 DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
89};
90mod tuning;
91pub use tuning::{
92 auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
93};
94
95pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
96pub use device_map::{
97 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
98};
99pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
100pub use mistralrs_audio::AudioInput;
101pub use mistralrs_mcp::{
102 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
103};
104pub use mistralrs_mcp::{
105 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
106};
107pub use mistralrs_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
108pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
109pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
110pub use pipeline::{
111 chat_template::ChatTemplate, expand_isq_value, parse_isq_value, AdapterPaths, AnyMoeLoader,
112 AnyMoePipeline, AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams,
113 DiffusionLoader, DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader,
114 EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig,
115 GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder,
116 GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader,
117 LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader,
118 Modalities, ModelKind, ModelPaths, MultimodalLoader, MultimodalLoaderBuilder,
119 MultimodalLoaderType, MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader,
120 NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
121 Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
122 SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
123 UQFF_MULTI_FILE_DELIMITER,
124};
125pub use request::{
126 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
127 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
128 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
129};
130pub use response::*;
131pub use sampler::{
132 CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
133 TopLogprob,
134};
135pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
136pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
137use serde::Serialize;
138pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
139use tokio::runtime::Runtime;
140use toml_selector::{TomlLoaderArgs, TomlSelector};
141pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
142pub use topology::{LayerTopology, Topology};
143pub use utils::debug::initialize_logging;
144pub use utils::memory_usage::MemoryUsage;
145pub use utils::normal::{ModelDType, TryIntoDType};
146pub use utils::{paged_attn_supported, using_flash_attn};
147
148pub use llguidance;
150
/// Crate-wide debug flag; initially `false`.
/// NOTE(review): set and read elsewhere in the crate — not visible here.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
/// Process-wide Hugging Face hub cache. Being a `OnceLock`, it can be
/// initialized at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
154
/// Engine-level settings used when spawning the engine for a model.
///
/// Each field is forwarded to the engine when it is created. The `Default`
/// impl leaves every `no_*`/`disable_*` flag `false`, keeps 16 prefix-cache
/// entries, enables throughput logging, and registers no search or tool callbacks.
#[derive(Clone)]
pub struct EngineConfig {
    /// Forwarded to the engine as `no_kv_cache` (presumably disables the KV cache).
    pub no_kv_cache: bool,
    /// Forwarded to the engine as `no_prefix_cache` (presumably disables prefix caching).
    pub no_prefix_cache: bool,
    /// Forwarded to the engine as `prefix_cache_n`.
    pub prefix_cache_n: usize,
    /// Forwarded to the engine as `disable_eos_stop`.
    pub disable_eos_stop: bool,
    /// Forwarded to the engine as `throughput_logging_enabled`.
    pub throughput_logging_enabled: bool,
    /// Optional embedding model forwarded to the engine for search.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional user-supplied search callback.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks (and their tool definitions) keyed by tool name.
    pub tool_callbacks: tools::ToolCallbacksWithTools,
}
167
168impl Default for EngineConfig {
169 fn default() -> Self {
170 Self {
171 no_kv_cache: false,
172 no_prefix_cache: false,
173 prefix_cache_n: 16,
174 disable_eos_stop: false,
175 throughput_logging_enabled: true,
176 search_embedding_model: None,
177 search_callback: None,
178 tool_callbacks: HashMap::new(),
179 }
180 }
181}
182
/// Options for [`MistralRs::add_model`].
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine settings for the new model.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration, stored in the model's reboot state.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration, stored in the model's reboot state
    /// (presumably needed to reload the model after unloading — confirm).
    pub loader_config: Option<ModelLoaderConfig>,
}
192
193impl AddModelConfig {
194 pub fn new(engine_config: EngineConfig) -> Self {
195 Self {
196 engine_config,
197 mcp_client_config: None,
198 loader_config: None,
199 }
200 }
201
202 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
203 self.mcp_client_config = Some(mcp_config);
204 self
205 }
206
207 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
210 self.loader_config = Some(loader_config);
211 self
212 }
213}
214
/// Snapshot of a model's metadata, taken from its pipeline when the engine
/// instance is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind reported by the pipeline metadata.
    pub kind: ModelKind,
    /// Device the pipeline reports.
    pub device: Device,
    /// Model category (e.g. text, multimodal, diffusion, speech).
    pub category: ModelCategory,
    /// Input/output modalities reported by the pipeline metadata.
    pub modalities: Modalities,
    /// Maximum sequence length; `None` for diffusion and speech models.
    pub max_seq_len: Option<usize>,
    /// Generation defaults supplied by the pipeline, if any.
    pub generation_defaults: Option<ModelGenerationDefaults>,
}

/// Parameters describing how to load a model.
///
/// NOTE(review): presumably consumed when an unloaded model is reloaded (it is
/// kept in `UnloadedModelState` and `RebootState`); the loading code is not visible here.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// Which model to load.
    pub model_selected: ModelSelected,
    /// Where to obtain the Hugging Face token from.
    pub token_source: TokenSource,
    /// Optional Hugging Face revision to load.
    pub hf_revision: Option<String>,
    /// Requested model dtype.
    pub dtype: ModelDType,
    /// Target device.
    pub device: Device,
    /// Device-mapping setting.
    pub device_map_setting: DeviceMapSetting,
    /// Optional ISQ quantization type.
    pub isq: Option<IsqType>,
    /// Optional paged-attention configuration.
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Presumably suppresses loader progress output — confirm.
    pub silent: bool,
    /// Optional chat template override.
    pub chat_template: Option<String>,
    /// Optional explicit Jinja template.
    /// NOTE(review): precedence relative to `chat_template` is not visible here.
    pub jinja_explicit: Option<String>,
}

/// State retained for a model that has been unloaded, so it can be reloaded
/// later (`MistralRs::get_sender` triggers an automatic reload for such models).
#[derive(Clone)]
pub struct UnloadedModelState {
    /// How to load the model again.
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration for the recreated engine.
    pub scheduler_config: SchedulerConfig,
    /// Engine settings for the recreated engine.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Category of the model at the time it was unloaded.
    pub category: ModelCategory,
    /// Metadata snapshot of the model at the time it was unloaded.
    pub mistralrs_config: MistralRsConfig,
}
270
/// A running engine for one model, plus what is needed to restart it.
struct EngineInstance {
    /// Request channel into the engine (capacity 10 000, see `create_engine_instance`).
    sender: Sender<Request>,
    /// Handle of the OS thread running the engine; `is_finished()` marks a dead engine.
    engine_handler: JoinHandle<()>,
    /// Everything needed to recreate the engine (used by `reboot_engine`).
    reboot_state: RebootState,
    /// Metadata snapshot taken from the pipeline at creation time.
    config: MistralRsConfig,
    /// Model category, duplicated from `config` for quick access.
    category: ModelCategory,
    /// Interval logger shared with the engine.
    logger: Arc<IntervalLogger>,
}

/// Multi-model front end: owns one engine per loaded model and routes requests
/// by model ID or alias.
pub struct MistralRs {
    /// Engines of currently loaded models, keyed by primary model ID.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models that were unloaded but can be reloaded on demand, keyed by model ID.
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// IDs of models currently being reloaded.
    reloading_models: RwLock<HashSet<String>>,
    /// Model used when a request does not name one.
    default_engine_id: RwLock<Option<String>>,
    /// Alias -> primary model ID.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log setting from the builder (presumably a log file path — confirm).
    log: Option<String>,
    /// Primary model ID this instance was created with.
    id: String,
    /// Creation time in seconds since the Unix epoch.
    creation_time: u64,
    /// Next request ID to hand out; starts at 1 and increments by one per call.
    next_request_id: Mutex<RefCell<usize>>,
}

/// Snapshot of the arguments an engine was created with, so a dead engine can
/// be recreated identically.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine runs.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    /// See `EngineConfig::no_kv_cache`.
    no_kv_cache: bool,
    /// See `EngineConfig::no_prefix_cache`.
    no_prefix_cache: bool,
    /// See `EngineConfig::prefix_cache_n`.
    prefix_cache_n: usize,
    /// See `EngineConfig::disable_eos_stop`.
    disable_eos_stop: bool,
    /// See `EngineConfig::throughput_logging_enabled`.
    throughput_logging_enabled: bool,
    /// See `EngineConfig::search_embedding_model`.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// See `EngineConfig::search_callback`.
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Tool callbacks, including any registered from MCP servers.
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration.
    loader_config: Option<ModelLoaderConfig>,
}
328
/// Lifecycle state of a model managed by `MistralRs`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    Loaded,
    Unloaded,
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Renders the status as a lowercase word (`loaded`, `unloaded`, `reloading`).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
346
/// Errors returned by `MistralRs` model-management and routing methods.
#[derive(Debug)]
pub enum MistralRsError {
    /// A lock was poisoned, or a lookup failed (also used when no default model is set).
    EnginePoisoned,
    /// A lock guarding engines, aliases, or the default model was poisoned.
    SenderPoisoned,
    /// No model with this (resolved) ID exists.
    ModelNotFound(String),
    /// The model with this ID is being reloaded.
    ModelReloading(String),
    /// Reloading the model with this ID failed.
    ReloadFailed(String),
    /// The model with this ID has no loader configuration.
    NoLoaderConfig(String),
    /// The model with this ID is already loaded.
    ModelAlreadyLoaded(String),
    /// The model with this ID is already unloaded.
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    /// The user-facing text is currently identical to the `Debug` rendering.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for MistralRsError {}
372
/// Exposes `MistralRsError` to Python as a `ValueError` carrying its `Debug` text.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        let message = format!("{value:?}");
        PyValueError::new_err(message)
    }
}
379
/// Builder for [`MistralRs`]. Unset `Option` settings fall back to defaults
/// when `MistralRs::new` runs.
pub struct MistralRsBuilder {
    /// Pipeline for the initial model.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the initial engine.
    method: SchedulerConfig,
    /// Model ID to register instead of the pipeline's own name.
    model_id_override: Option<String>,
    /// Optional log setting, stored on the built `MistralRs`.
    log: Option<String>,
    /// `None` means `false`.
    no_kv_cache: Option<bool>,
    /// `None` means `false`.
    no_prefix_cache: Option<bool>,
    /// `None` means 16.
    prefix_cache_n: Option<usize>,
    /// `None` means `false`.
    disable_eos_stop: Option<bool>,
    /// Forwarded to the engine as `throughput_logging_enabled`.
    throughput_logging_enabled: bool,
    /// Optional search embedding model for the engine.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Optional search callback.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks keyed by tool name (MCP tools are added in `MistralRs::new`).
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration, initialized in `MistralRs::new`.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration kept in the reboot state.
    loader_config: Option<ModelLoaderConfig>,
}
399
400impl MistralRsBuilder {
401 pub fn new(
405 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
406 method: SchedulerConfig,
407 throughput_logging: bool,
408 search_embedding_model: Option<SearchEmbeddingModel>,
409 ) -> Self {
410 Self {
411 pipeline,
412 method,
413 model_id_override: None,
414 log: None,
415 no_kv_cache: None,
416 no_prefix_cache: None,
417 prefix_cache_n: None,
418 disable_eos_stop: None,
419 throughput_logging_enabled: throughput_logging,
420 search_embedding_model,
421 search_callback: None,
422 tool_callbacks: HashMap::new(),
423 mcp_client_config: None,
424 loader_config: None,
425 }
426 }
427
428 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
430 self.model_id_override = Some(model_id.into());
431 self
432 }
433
434 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
437 self.loader_config = Some(loader_config);
438 self
439 }
440 pub fn with_log(mut self, log: String) -> Self {
441 self.log = Some(log);
442 self
443 }
444 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
445 self.log = log;
446 self
447 }
448 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
449 self.no_kv_cache = Some(no_kv_cache);
450 self
451 }
452 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
453 self.no_prefix_cache = Some(no_prefix_cache);
454 self
455 }
456 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
457 self.prefix_cache_n = Some(prefix_cache_n);
458 self
459 }
460 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
461 self.disable_eos_stop = Some(disable_eos_stop);
462 self
463 }
464
465 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
467 self.search_callback = Some(search_callback);
468 self
469 }
470
471 pub fn with_tool_callback(
473 mut self,
474 name: impl Into<String>,
475 tool_callback: Arc<ToolCallback>,
476 ) -> Self {
477 let name = name.into();
478 self.tool_callbacks.insert(
480 name.clone(),
481 ToolCallbackWithTool {
482 callback: tool_callback,
483 tool: Tool {
484 tp: ToolType::Function,
485 function: Function {
486 description: None,
487 name,
488 parameters: None,
489 strict: None,
490 },
491 },
492 },
493 );
494 self
495 }
496
497 pub fn with_tool_callback_and_tool(
500 mut self,
501 name: impl Into<String>,
502 tool_callback: Arc<ToolCallback>,
503 tool: Tool,
504 ) -> Self {
505 let name = name.into();
506 self.tool_callbacks.insert(
507 name,
508 ToolCallbackWithTool {
509 callback: tool_callback,
510 tool,
511 },
512 );
513 self
514 }
515
516 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
518 self.mcp_client_config = Some(config);
519 self
520 }
521
522 pub async fn build(self) -> Arc<MistralRs> {
523 MistralRs::new(self).await
524 }
525}
526
527impl Drop for MistralRs {
528 fn drop(&mut self) {
529 if let Ok(engines) = self.engines.read() {
531 for (_, engine) in engines.iter() {
532 let _ = engine.sender.try_send(Request::Terminate);
534 }
535 }
536 }
537}
538
539impl MistralRs {
540 fn create_engine_instance(
542 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
543 method: SchedulerConfig,
544 config: EngineConfig,
545 reboot_state: RebootState,
546 ) -> Result<EngineInstance, String> {
547 let (tx, rx) = channel(10_000);
548
549 let pipeline_guard = pipeline.try_lock().unwrap();
550 let category = pipeline_guard.category();
551 let metadata = pipeline_guard.get_metadata();
552 let kind = metadata.kind.clone();
553 let device = pipeline_guard.device();
554 let modalities = metadata.modalities.clone();
555 let max_seq_len = match &category {
556 ModelCategory::Diffusion | ModelCategory::Speech => None,
557 _ => Some(metadata.max_seq_len),
558 };
559 let generation_defaults = pipeline_guard.generation_defaults();
560 let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
561 drop(pipeline_guard);
562
563 let logger = Arc::new(IntervalLogger::new(
564 Duration::from_secs(5),
565 encoder_cache_counters,
566 ));
567 let logger_for_engine = logger.clone();
568
569 info!("Pipeline input modalities are {:?}", &modalities.input);
570 info!("Pipeline output modalities are {:?}", &modalities.output);
571
572 let mistralrs_config = MistralRsConfig {
573 kind,
574 device,
575 category: category.clone(),
576 modalities,
577 max_seq_len,
578 generation_defaults,
579 };
580
581 let tx_for_engine = tx.clone();
582 let engine_handler = thread::spawn(move || {
583 #[cfg(feature = "metal")]
584 objc::rc::autoreleasepool(move || {
585 let rt = Runtime::new().unwrap();
586 rt.block_on(async move {
587 let engine = Engine::new(
588 tx_for_engine,
589 rx,
590 pipeline,
591 method,
592 config.no_kv_cache,
593 config.no_prefix_cache,
594 config.prefix_cache_n,
595 config.disable_eos_stop,
596 config.throughput_logging_enabled,
597 config.search_embedding_model,
598 config.search_callback.clone(),
599 config.tool_callbacks.clone(),
600 logger_for_engine,
601 )
602 .expect("Engine creation failed.");
603 Arc::new(engine).run().await;
604 })
605 });
606
607 #[cfg(not(feature = "metal"))]
608 {
609 let rt = Runtime::new().unwrap();
610 rt.block_on(async move {
611 let engine = Engine::new(
612 tx_for_engine,
613 rx,
614 pipeline,
615 method,
616 config.no_kv_cache,
617 config.no_prefix_cache,
618 config.prefix_cache_n,
619 config.disable_eos_stop,
620 config.throughput_logging_enabled,
621 config.search_embedding_model,
622 config.search_callback.clone(),
623 config.tool_callbacks.clone(),
624 logger_for_engine,
625 )
626 .expect("Engine creation failed.");
627 Arc::new(engine).run().await;
628 })
629 }
630 });
631
632 Ok(EngineInstance {
633 sender: tx,
634 engine_handler,
635 reboot_state,
636 config: mistralrs_config,
637 category,
638 logger,
639 })
640 }
641
642 async fn new(config: MistralRsBuilder) -> Arc<Self> {
643 info!("git revision: {MISTRALRS_GIT_REVISION}");
644 let MistralRsBuilder {
645 pipeline,
646 method,
647 model_id_override,
648 log,
649 no_kv_cache,
650 no_prefix_cache,
651 prefix_cache_n,
652 disable_eos_stop,
653 throughput_logging_enabled,
654 search_embedding_model,
655 search_callback,
656 mut tool_callbacks,
657 mcp_client_config,
658 loader_config,
659 } = config;
660
661 mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
662 get_mut_arcmutex!(pipeline).device(),
663 );
664
665 let no_kv_cache = no_kv_cache.unwrap_or(false);
666 let no_prefix_cache = no_prefix_cache.unwrap_or(false);
667 let prefix_cache_n = prefix_cache_n.unwrap_or(16);
668 let disable_eos_stop = disable_eos_stop.unwrap_or(false);
669
670 if let Some(config) = &mcp_client_config {
672 let mut mcp_client = McpClient::new(config.clone());
673 let total_servers = config.servers.len();
674
675 match mcp_client.initialize().await {
676 Ok(()) => {
677 let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
678 let tools_count = mcp_callbacks_with_tools.len();
679
680 for (name, callback_with_tool) in mcp_callbacks_with_tools {
682 tool_callbacks.insert(name.clone(), callback_with_tool.clone());
683 }
684
685 if tools_count == 0 {
686 warn!(
687 "MCP client initialized but no tools were registered from {} servers",
688 total_servers
689 );
690 } else {
691 info!(
692 "MCP client initialized successfully with {} tools from {} servers",
693 tools_count, total_servers
694 );
695 }
696 }
697 Err(e) => {
698 warn!(
699 "Failed to initialize MCP client with {} configured servers: {}",
700 total_servers, e
701 );
702 warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
703 }
704 }
705 }
706
707 let reboot_state = RebootState {
708 pipeline: pipeline.clone(),
709 method: method.clone(),
710 no_kv_cache,
711 no_prefix_cache,
712 prefix_cache_n,
713 disable_eos_stop,
714 throughput_logging_enabled,
715 search_embedding_model,
716 search_callback: search_callback.clone(),
717 tool_callbacks: tool_callbacks.clone(),
718 mcp_client_config: mcp_client_config.clone(),
719 loader_config,
720 };
721
722 let engine_config = EngineConfig {
724 no_kv_cache,
725 no_prefix_cache,
726 prefix_cache_n,
727 disable_eos_stop,
728 throughput_logging_enabled,
729 search_embedding_model,
730 search_callback,
731 tool_callbacks,
732 };
733
734 let engine_instance =
736 Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
737 .expect("Failed to create engine instance");
738
739 let pipeline_name = pipeline.try_lock().unwrap().name();
740 let (id, alias_map) = match model_id_override {
741 Some(override_id) => {
742 let mut alias_map = HashMap::new();
743 if override_id != pipeline_name {
744 alias_map.insert(pipeline_name.clone(), override_id.clone());
745 }
746 (override_id, alias_map)
747 }
748 None => (pipeline_name.clone(), HashMap::new()),
749 };
750
751 if distributed::is_daemon() {
752 let request_sender = engine_instance.sender.clone();
753
754 if cfg!(feature = "ring") {
755 distributed::ring_daemon_replicator(request_sender);
757 } else {
758 distributed::nccl_daemon_replicator(request_sender);
760 }
761
762 #[allow(clippy::empty_loop)]
763 loop {}
764 }
765
766 let is_multi_threaded = tokio::runtime::Handle::try_current()
768 .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);
769
770 if !distributed::is_daemon()
772 && is_multi_threaded
773 && matches!(
774 engine_instance.category,
775 ModelCategory::Text | ModelCategory::Multimodal { .. }
776 )
777 {
778 let clone_sender = engine_instance.sender.clone();
779 tokio::task::block_in_place(|| {
780 let (tx, mut rx) = channel(1);
781 let req = Request::Normal(Box::new(NormalRequest {
782 id: 0,
783 messages: RequestMessage::Completion {
784 text: "hello".to_string(),
785 echo_prompt: false,
786 best_of: None,
787 },
788 sampling_params: SamplingParams {
789 max_len: Some(1),
790 ..SamplingParams::deterministic()
791 },
792 response: tx,
793 return_logprobs: false,
794 is_streaming: false,
795 constraint: Constraint::None,
796 suffix: None,
797 tool_choice: None,
798 tools: None,
799 logits_processors: None,
800 return_raw_logits: false,
801 web_search_options: None,
802 max_tool_rounds: None,
803 tool_dispatch_url: None,
804 model_id: None,
805 truncate_sequence: false,
806 }));
807 info!("Beginning dummy run.");
808 let start = Instant::now();
809 clone_sender.blocking_send(req).unwrap();
810
811 let mut received_any = false;
813 while let Some(_resp) = rx.blocking_recv() {
814 received_any = true;
815 }
816
817 if received_any {
818 let end = Instant::now();
819 info!(
820 "Dummy run completed in {}s.",
821 end.duration_since(start).as_secs_f64()
822 );
823 } else {
824 warn!("Dummy run failed!");
825 }
826 });
827
828 engine_instance.logger.reset();
830 }
831
832 let mut engines = HashMap::new();
834 engines.insert(id.clone(), engine_instance);
835
836 Arc::new(Self {
837 engines: RwLock::new(engines),
838 unloaded_models: RwLock::new(HashMap::new()),
839 reloading_models: RwLock::new(HashSet::new()),
840 default_engine_id: RwLock::new(Some(id.clone())),
841 model_aliases: RwLock::new(alias_map),
842 log,
843 id,
844 creation_time: SystemTime::now()
845 .duration_since(UNIX_EPOCH)
846 .expect("Time travel has occurred!")
847 .as_secs(),
848 next_request_id: Mutex::new(RefCell::new(1)),
849 })
850 }
851
852 fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
854 let mut engines = self.engines.write().map_err(|_| {
855 tracing::warn!("Couldn't get write lock on engines during reboot attempt");
856 MistralRsError::EnginePoisoned
857 })?;
858
859 if let Some(engine_instance) = engines.get(model_id) {
860 if !engine_instance.engine_handler.is_finished() {
861 tracing::info!("Engine {} already running, returning ok", model_id);
862 return Ok(());
863 }
864
865 let reboot_state = engine_instance.reboot_state.clone();
866 let engine_config = EngineConfig {
867 no_kv_cache: reboot_state.no_kv_cache,
868 no_prefix_cache: reboot_state.no_prefix_cache,
869 prefix_cache_n: reboot_state.prefix_cache_n,
870 disable_eos_stop: reboot_state.disable_eos_stop,
871 throughput_logging_enabled: reboot_state.throughput_logging_enabled,
872 search_embedding_model: reboot_state.search_embedding_model,
873 search_callback: reboot_state.search_callback.clone(),
874 tool_callbacks: reboot_state.tool_callbacks.clone(),
875 };
876 let new_engine_instance = Self::create_engine_instance(
877 reboot_state.pipeline.clone(),
878 reboot_state.method.clone(),
879 engine_config,
880 reboot_state,
881 )
882 .map_err(|e| {
883 tracing::error!("Failed to create new engine instance: {}", e);
884 MistralRsError::EnginePoisoned
885 })?;
886
887 engines.insert(model_id.to_string(), new_engine_instance);
888 tracing::info!("Successfully rebooted engine {}", model_id);
889 Ok(())
890 } else {
891 Err(MistralRsError::EnginePoisoned)
892 }
893 }
894
895 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
896 let engines = self.engines.read().map_err(|_| {
897 tracing::warn!("Couldn't get read lock on engines!");
898 MistralRsError::EnginePoisoned
899 })?;
900
901 if let Some(engine_instance) = engines.get(model_id) {
902 Ok(engine_instance.engine_handler.is_finished())
903 } else {
904 Err(MistralRsError::EnginePoisoned)
905 }
906 }
907
908 pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
911 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
912
913 let is_loaded = {
915 let engines = self
916 .engines
917 .read()
918 .map_err(|_| MistralRsError::SenderPoisoned)?;
919 engines.contains_key(&resolved_model_id)
920 };
921
922 if is_loaded {
923 if self.engine_dead(&resolved_model_id)? {
925 tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
926 self.reboot_engine(&resolved_model_id)?
927 }
928
929 let engines = self
930 .engines
931 .read()
932 .map_err(|_| MistralRsError::SenderPoisoned)?;
933 if let Some(engine_instance) = engines.get(&resolved_model_id) {
934 return Ok(engine_instance.sender.clone());
935 }
936 }
937
938 let is_unloaded = {
940 let unloaded = self
941 .unloaded_models
942 .read()
943 .map_err(|_| MistralRsError::EnginePoisoned)?;
944 unloaded.contains_key(&resolved_model_id)
945 };
946
947 if is_unloaded {
948 tracing::info!(
949 "Model {} is unloaded, triggering auto-reload",
950 resolved_model_id
951 );
952 self.reload_model_blocking(&resolved_model_id)?;
953
954 let engines = self
956 .engines
957 .read()
958 .map_err(|_| MistralRsError::SenderPoisoned)?;
959 if let Some(engine_instance) = engines.get(&resolved_model_id) {
960 return Ok(engine_instance.sender.clone());
961 }
962 }
963
964 Err(MistralRsError::ModelNotFound(resolved_model_id))
965 }
966
967 pub fn get_id(&self) -> String {
968 self.id.clone()
969 }
970
971 pub fn get_creation_time(&self) -> u64 {
972 self.creation_time
973 }
974
975 fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
976 let aliases = self
977 .model_aliases
978 .read()
979 .map_err(|_| MistralRsError::SenderPoisoned)?;
980 if let Some(primary_id) = aliases.get(model_id) {
981 Ok(primary_id.clone())
982 } else {
983 Ok(model_id.to_string())
984 }
985 }
986
987 fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
988 match model_id {
989 Some(id) => self.resolve_alias(id),
990 None => {
991 let default_lock = self
992 .default_engine_id
993 .read()
994 .map_err(|_| MistralRsError::SenderPoisoned)?;
995 Ok(default_lock
996 .as_ref()
997 .ok_or(MistralRsError::EnginePoisoned)?
998 .clone())
999 }
1000 }
1001 }
1002
1003 pub fn register_model_alias(
1005 &self,
1006 alias: impl Into<String>,
1007 model_id: &str,
1008 ) -> Result<(), String> {
1009 let alias = alias.into();
1010 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1011
1012 if alias == resolved_model_id {
1013 return Ok(());
1014 }
1015
1016 let reloading = self
1017 .reloading_models
1018 .read()
1019 .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1020 let model_reloading = reloading.contains(&resolved_model_id);
1021 let alias_conflict = reloading.contains(&alias);
1022 drop(reloading);
1023
1024 let engines = self
1025 .engines
1026 .read()
1027 .map_err(|_| "Failed to acquire read lock on engines")?;
1028 let model_loaded = engines.contains_key(&resolved_model_id);
1029 let alias_conflict = alias_conflict || engines.contains_key(&alias);
1030 drop(engines);
1031
1032 let unloaded = self
1033 .unloaded_models
1034 .read()
1035 .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1036 let model_unloaded = unloaded.contains_key(&resolved_model_id);
1037 let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
1038 drop(unloaded);
1039
1040 if !(model_loaded || model_unloaded || model_reloading) {
1041 return Err(format!("Model {resolved_model_id} not found"));
1042 }
1043
1044 if alias_conflict {
1045 return Err(format!(
1046 "Alias '{}' conflicts with an existing model ID",
1047 alias
1048 ));
1049 }
1050
1051 let mut aliases = self
1052 .model_aliases
1053 .write()
1054 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1055 if let Some(existing) = aliases.get(&alias) {
1056 if existing == &resolved_model_id {
1057 return Ok(());
1058 }
1059 return Err(format!(
1060 "Alias '{}' is already assigned to model '{}'",
1061 alias, existing
1062 ));
1063 }
1064 aliases.insert(alias, resolved_model_id);
1065 Ok(())
1066 }
1067
1068 pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
1070 let resolved_model_id = self.resolve_alias(model_id)?;
1071
1072 let reloading = self
1073 .reloading_models
1074 .read()
1075 .map_err(|_| MistralRsError::EnginePoisoned)?;
1076 if reloading.contains(&resolved_model_id) {
1077 return Ok(true);
1078 }
1079 drop(reloading);
1080
1081 let engines = self
1082 .engines
1083 .read()
1084 .map_err(|_| MistralRsError::EnginePoisoned)?;
1085 if engines.contains_key(&resolved_model_id) {
1086 return Ok(true);
1087 }
1088 drop(engines);
1089
1090 let unloaded = self
1091 .unloaded_models
1092 .read()
1093 .map_err(|_| MistralRsError::EnginePoisoned)?;
1094 if unloaded.contains_key(&resolved_model_id) {
1095 return Ok(true);
1096 }
1097
1098 Ok(false)
1099 }
1100
1101 pub fn get_logger(
1103 &self,
1104 model_id: Option<&str>,
1105 ) -> Result<Arc<IntervalLogger>, MistralRsError> {
1106 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1107
1108 let engines = self
1109 .engines
1110 .read()
1111 .map_err(|_| MistralRsError::SenderPoisoned)?;
1112 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1113 Ok(engine_instance.logger.clone())
1114 } else {
1115 Err(MistralRsError::EnginePoisoned)
1116 }
1117 }
1118
1119 pub fn get_model_category(
1121 &self,
1122 model_id: Option<&str>,
1123 ) -> Result<ModelCategory, MistralRsError> {
1124 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1125
1126 let engines = self
1127 .engines
1128 .read()
1129 .map_err(|_| MistralRsError::SenderPoisoned)?;
1130 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1131 Ok(engine_instance.category.clone())
1132 } else {
1133 Err(MistralRsError::EnginePoisoned)
1134 }
1135 }
1136
1137 pub fn max_sequence_length(
1139 &self,
1140 model_id: Option<&str>,
1141 ) -> Result<Option<usize>, MistralRsError> {
1142 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1143
1144 let engines = self
1145 .engines
1146 .read()
1147 .map_err(|_| MistralRsError::SenderPoisoned)?;
1148 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1149 Ok(engine_instance.config.max_seq_len)
1150 } else {
1151 Err(MistralRsError::EnginePoisoned)
1152 }
1153 }
1154
1155 pub fn next_request_id(&self) -> usize {
1156 let l = self.next_request_id.lock().unwrap();
1157 let last = &mut *l.borrow_mut();
1158 let last_v = *last;
1159 *last += 1;
1160 last_v
1161 }
1162
1163 pub async fn add_model(
1165 &self,
1166 model_id: String,
1167 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
1168 method: SchedulerConfig,
1169 config: AddModelConfig,
1170 ) -> Result<(), String> {
1171 {
1172 let reloading = self
1173 .reloading_models
1174 .read()
1175 .map_err(|_| "Failed to acquire read lock on reloading_models")?;
1176 if reloading.contains(&model_id) {
1177 return Err(format!("Model {model_id} is currently reloading"));
1178 }
1179 }
1180 {
1181 let engines = self
1182 .engines
1183 .read()
1184 .map_err(|_| "Failed to acquire read lock on engines")?;
1185 if engines.contains_key(&model_id) {
1186 return Err(format!("Model {model_id} already exists"));
1187 }
1188 }
1189 {
1190 let unloaded = self
1191 .unloaded_models
1192 .read()
1193 .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
1194 if unloaded.contains_key(&model_id) {
1195 return Err(format!("Model {model_id} already exists (unloaded)"));
1196 }
1197 }
1198 {
1199 let aliases = self
1200 .model_aliases
1201 .read()
1202 .map_err(|_| "Failed to acquire read lock on model_aliases")?;
1203 if aliases.contains_key(&model_id) {
1204 return Err(format!(
1205 "Model ID '{}' conflicts with an existing alias",
1206 model_id
1207 ));
1208 }
1209 }
1210
1211 let reboot_state = RebootState {
1212 pipeline: pipeline.clone(),
1213 method: method.clone(),
1214 no_kv_cache: config.engine_config.no_kv_cache,
1215 no_prefix_cache: config.engine_config.no_prefix_cache,
1216 prefix_cache_n: config.engine_config.prefix_cache_n,
1217 disable_eos_stop: config.engine_config.disable_eos_stop,
1218 throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
1219 search_embedding_model: config.engine_config.search_embedding_model,
1220 search_callback: config.engine_config.search_callback.clone(),
1221 tool_callbacks: config.engine_config.tool_callbacks.clone(),
1222 mcp_client_config: config.mcp_client_config.clone(),
1223 loader_config: config.loader_config.clone(),
1224 };
1225
1226 let engine_instance =
1227 Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;
1228
1229 let mut engines = self
1230 .engines
1231 .write()
1232 .map_err(|_| "Failed to acquire write lock on engines")?;
1233 engines.insert(model_id.clone(), engine_instance);
1234
1235 if engines.len() == 1 {
1237 let mut default_lock = self
1238 .default_engine_id
1239 .write()
1240 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1241 *default_lock = Some(model_id.clone());
1242 info!("First model added, setting '{}' as default", model_id);
1243 }
1244
1245 Ok(())
1246 }
1247
1248 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1250 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1251 let mut engines = self
1252 .engines
1253 .write()
1254 .map_err(|_| "Failed to acquire write lock on engines")?;
1255
1256 if engines.len() <= 1 {
1257 return Err("Cannot remove the last model from MistralRs".to_string());
1258 }
1259
1260 if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1261 let _ = engine_instance.sender.blocking_send(Request::Terminate);
1263
1264 let mut default_lock = self
1266 .default_engine_id
1267 .write()
1268 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1269 if let Some(ref default_id) = *default_lock {
1270 if default_id == &resolved_model_id {
1271 *default_lock = engines.keys().next().cloned();
1273 }
1274 }
1275 drop(default_lock);
1276 drop(engines);
1277
1278 let mut aliases = self
1280 .model_aliases
1281 .write()
1282 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1283 aliases.retain(|_, target| target != &resolved_model_id);
1284
1285 Ok(())
1286 } else {
1287 Err(format!("Model {resolved_model_id} not found"))
1288 }
1289 }
1290
1291 pub fn list_models(&self) -> Result<Vec<String>, String> {
1293 let engines = self
1294 .engines
1295 .read()
1296 .map_err(|_| "Failed to acquire read lock on engines")?;
1297 Ok(engines.keys().cloned().collect())
1298 }
1299
1300 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1302 let default_lock = self
1303 .default_engine_id
1304 .read()
1305 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1306 Ok(default_lock.clone())
1307 }
1308
1309 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1311 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1312 let engines = self
1313 .engines
1314 .read()
1315 .map_err(|_| "Failed to acquire read lock on engines")?;
1316 if !engines.contains_key(&resolved_model_id) {
1317 return Err(format!("Model {resolved_model_id} not found"));
1318 }
1319 drop(engines);
1320
1321 let mut default_lock = self
1322 .default_engine_id
1323 .write()
1324 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1325 let old_default = default_lock.clone();
1326 *default_lock = Some(resolved_model_id.clone());
1327
1328 info!(
1330 "Default model changed: {:?} -> {:?}",
1331 old_default, resolved_model_id
1332 );
1333
1334 Ok(())
1335 }
1336
1337 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
1339 let model_id = match &mut request {
1340 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1341 _ => None, };
1343
1344 let sender = self.get_sender(model_id)?;
1345 sender
1346 .blocking_send(request)
1347 .map_err(|_| MistralRsError::SenderPoisoned)
1348 }
1349
1350 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1351 if let Some(file) = &this.log {
1352 let mut f = OpenOptions::new()
1353 .append(true)
1354 .create(true) .open(file)
1356 .expect("Unable to open file");
1357 let time = chrono::offset::Local::now();
1358 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1359 .expect("Unable to write data");
1360 }
1361 }
1362
1363 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1364 if let Some(file) = &this.log {
1365 let mut f = OpenOptions::new()
1366 .append(true)
1367 .create(true) .open(file)
1369 .expect("Unable to open file");
1370 let time = chrono::offset::Local::now();
1371 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1372 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1373 .expect("Unable to write data");
1374 }
1375 }
1376
1377 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1378 if let Some(file) = &this.log {
1379 let mut f = OpenOptions::new()
1380 .append(true)
1381 .create(true) .open(file)
1383 .expect("Unable to open file");
1384 let time = chrono::offset::Local::now();
1385 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1386 .expect("Unable to write data");
1387 }
1388 }
1389
1390 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1392 let resolved_model_id = self
1393 .resolve_alias_or_default(model_id)
1394 .map_err(|e| e.to_string())?;
1395
1396 let engines = self
1397 .engines
1398 .read()
1399 .map_err(|_| "Failed to acquire read lock on engines")?;
1400 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1401 Ok(engine_instance.reboot_state.tool_callbacks.len())
1402 } else {
1403 Err(format!("Model {resolved_model_id} not found"))
1404 }
1405 }
1406
1407 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1409 let resolved_model_id = self
1410 .resolve_alias_or_default(model_id)
1411 .map_err(|e| e.to_string())?;
1412
1413 let engines = self
1414 .engines
1415 .read()
1416 .map_err(|_| "Failed to acquire read lock on engines")?;
1417 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1418 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1419 } else {
1420 Err(format!("Model {resolved_model_id} not found"))
1421 }
1422 }
1423
1424 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1426 let resolved_model_id = self
1427 .resolve_alias_or_default(model_id)
1428 .map_err(|e| e.to_string())?;
1429
1430 let engines = self
1431 .engines
1432 .read()
1433 .map_err(|_| "Failed to acquire read lock on engines")?;
1434 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1435 Ok(engine_instance.config.clone())
1436 } else {
1437 Err(format!("Model {resolved_model_id} not found"))
1438 }
1439 }
1440
1441 pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1448 let resolved_model_id = self.resolve_alias(model_id)?;
1449 {
1451 let unloaded = self
1452 .unloaded_models
1453 .read()
1454 .map_err(|_| MistralRsError::EnginePoisoned)?;
1455 if unloaded.contains_key(&resolved_model_id) {
1456 return Err(MistralRsError::ModelAlreadyUnloaded(
1457 resolved_model_id.clone(),
1458 ));
1459 }
1460 }
1461
1462 let mut engines = self
1464 .engines
1465 .write()
1466 .map_err(|_| MistralRsError::EnginePoisoned)?;
1467
1468 let engine_instance = engines
1469 .remove(&resolved_model_id)
1470 .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1471
1472 let loader_config = engine_instance
1474 .reboot_state
1475 .loader_config
1476 .clone()
1477 .ok_or_else(|| MistralRsError::NoLoaderConfig(resolved_model_id.clone()))?;
1478
1479 let unloaded_state = UnloadedModelState {
1481 loader_config,
1482 scheduler_config: engine_instance.reboot_state.method.clone(),
1483 engine_config: EngineConfig {
1484 no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1485 no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1486 prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1487 disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1488 throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1489 search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1490 search_callback: engine_instance.reboot_state.search_callback.clone(),
1491 tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1492 },
1493 mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1494 category: engine_instance.category.clone(),
1495 mistralrs_config: engine_instance.config.clone(),
1496 };
1497
1498 let _ = engine_instance.sender.try_send(Request::Terminate);
1500
1501 drop(engines);
1502
1503 let mut unloaded = self
1505 .unloaded_models
1506 .write()
1507 .map_err(|_| MistralRsError::EnginePoisoned)?;
1508 unloaded.insert(resolved_model_id.to_string(), unloaded_state);
1509
1510 let mut default_lock = self
1512 .default_engine_id
1513 .write()
1514 .map_err(|_| MistralRsError::EnginePoisoned)?;
1515 if let Some(ref default_id) = *default_lock {
1516 if default_id == &resolved_model_id {
1517 let engines = self
1519 .engines
1520 .read()
1521 .map_err(|_| MistralRsError::EnginePoisoned)?;
1522 *default_lock = engines.keys().next().cloned();
1523 }
1524 }
1525
1526 info!("Model {} unloaded successfully", resolved_model_id);
1527 Ok(())
1528 }
1529
1530 pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1533 let resolved_model_id = self.resolve_alias(model_id)?;
1534 {
1536 let reloading = self
1537 .reloading_models
1538 .read()
1539 .map_err(|_| MistralRsError::EnginePoisoned)?;
1540 if reloading.contains(&resolved_model_id) {
1541 return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
1542 }
1543 }
1544
1545 {
1547 let mut reloading = self
1548 .reloading_models
1549 .write()
1550 .map_err(|_| MistralRsError::EnginePoisoned)?;
1551 reloading.insert(resolved_model_id.clone());
1552 }
1553
1554 let unloaded_state = {
1556 let unloaded = self
1557 .unloaded_models
1558 .read()
1559 .map_err(|_| MistralRsError::EnginePoisoned)?;
1560 unloaded
1561 .get(&resolved_model_id)
1562 .cloned()
1563 .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
1564 };
1565
1566 let result = self
1568 .do_reload_model(&resolved_model_id, unloaded_state)
1569 .await;
1570
1571 {
1573 let mut reloading = self
1574 .reloading_models
1575 .write()
1576 .map_err(|_| MistralRsError::EnginePoisoned)?;
1577 reloading.remove(&resolved_model_id);
1578 }
1579
1580 result
1581 }
1582
1583 async fn do_reload_model(
1585 &self,
1586 model_id: &str,
1587 unloaded_state: UnloadedModelState,
1588 ) -> Result<(), MistralRsError> {
1589 use crate::model_loader::LoaderBuilder;
1590
1591 info!("Reloading model: {}", model_id);
1592
1593 let loader_config = &unloaded_state.loader_config;
1594
1595 let loader = LoaderBuilder::new(loader_config.model_selected.clone())
1597 .with_chat_template(loader_config.chat_template.clone())
1598 .with_jinja_explicit(loader_config.jinja_explicit.clone())
1599 .build()
1600 .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;
1601
1602 let pipeline = loader
1604 .load_model_from_hf(
1605 None,
1606 loader_config.token_source.clone(),
1607 &loader_config.dtype,
1608 &loader_config.device,
1609 loader_config.silent,
1610 loader_config.device_map_setting.clone(),
1611 loader_config.isq,
1612 loader_config.paged_attn_config,
1613 )
1614 .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;
1615
1616 let reboot_state = RebootState {
1618 pipeline: pipeline.clone(),
1619 method: unloaded_state.scheduler_config.clone(),
1620 no_kv_cache: unloaded_state.engine_config.no_kv_cache,
1621 no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
1622 prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
1623 disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
1624 throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
1625 search_embedding_model: unloaded_state.engine_config.search_embedding_model,
1626 search_callback: unloaded_state.engine_config.search_callback.clone(),
1627 tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
1628 mcp_client_config: unloaded_state.mcp_client_config.clone(),
1629 loader_config: Some(unloaded_state.loader_config.clone()),
1630 };
1631
1632 let engine_instance = Self::create_engine_instance(
1634 pipeline,
1635 unloaded_state.scheduler_config,
1636 unloaded_state.engine_config,
1637 reboot_state,
1638 )
1639 .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;
1640
1641 {
1643 let mut engines = self
1644 .engines
1645 .write()
1646 .map_err(|_| MistralRsError::EnginePoisoned)?;
1647 engines.insert(model_id.to_string(), engine_instance);
1648 }
1649
1650 {
1652 let mut unloaded = self
1653 .unloaded_models
1654 .write()
1655 .map_err(|_| MistralRsError::EnginePoisoned)?;
1656 unloaded.remove(model_id);
1657 }
1658
1659 info!("Model {} reloaded successfully", model_id);
1660 Ok(())
1661 }
1662
1663 pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1670 match tokio::runtime::Handle::try_current() {
1671 Ok(handle) => {
1672 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1673 Err(MistralRsError::ReloadFailed(
1674 "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1675 ))
1676 } else {
1677 tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1678 }
1679 }
1680 Err(_) => {
1681 let rt = tokio::runtime::Runtime::new().map_err(|e| {
1682 MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1683 })?;
1684 rt.block_on(self.reload_model(model_id))
1685 }
1686 }
1687 }
1688
1689 pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1691 let unloaded = self
1692 .unloaded_models
1693 .read()
1694 .map_err(|_| MistralRsError::EnginePoisoned)?;
1695 Ok(unloaded.keys().cloned().collect())
1696 }
1697
1698 pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1700 let resolved_model_id = self.resolve_alias(model_id)?;
1701 let engines = self
1702 .engines
1703 .read()
1704 .map_err(|_| MistralRsError::EnginePoisoned)?;
1705 Ok(engines.contains_key(&resolved_model_id))
1706 }
1707
1708 pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1710 let resolved_model_id = self.resolve_alias(model_id)?;
1711 {
1713 let reloading = self
1714 .reloading_models
1715 .read()
1716 .map_err(|_| MistralRsError::EnginePoisoned)?;
1717 if reloading.contains(&resolved_model_id) {
1718 return Ok(Some(ModelStatus::Reloading));
1719 }
1720 }
1721
1722 {
1724 let engines = self
1725 .engines
1726 .read()
1727 .map_err(|_| MistralRsError::EnginePoisoned)?;
1728 if engines.contains_key(&resolved_model_id) {
1729 return Ok(Some(ModelStatus::Loaded));
1730 }
1731 }
1732
1733 {
1735 let unloaded = self
1736 .unloaded_models
1737 .read()
1738 .map_err(|_| MistralRsError::EnginePoisoned)?;
1739 if unloaded.contains_key(&resolved_model_id) {
1740 return Ok(Some(ModelStatus::Unloaded));
1741 }
1742 }
1743
1744 Ok(None)
1745 }
1746
1747 pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1749 let mut result = Vec::new();
1750
1751 let reloading = self
1753 .reloading_models
1754 .read()
1755 .map_err(|_| MistralRsError::EnginePoisoned)?;
1756 for model_id in reloading.iter() {
1757 result.push((model_id.clone(), ModelStatus::Reloading));
1758 }
1759 drop(reloading);
1760
1761 let engines = self
1763 .engines
1764 .read()
1765 .map_err(|_| MistralRsError::EnginePoisoned)?;
1766 for model_id in engines.keys() {
1767 result.push((model_id.clone(), ModelStatus::Loaded));
1768 }
1769 drop(engines);
1770
1771 let unloaded = self
1773 .unloaded_models
1774 .read()
1775 .map_err(|_| MistralRsError::EnginePoisoned)?;
1776 for model_id in unloaded.keys() {
1777 if !result.iter().any(|(id, _)| id == model_id) {
1779 result.push((model_id.clone(), ModelStatus::Unloaded));
1780 }
1781 }
1782
1783 Ok(result)
1784 }
1785}