1#![deny(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2use candle_core::Device;
3use engine::Engine;
4pub use engine::{
5 get_engine_terminate_flag, reset_engine_terminate_flag, should_terminate_engine_sequences,
6 EngineInstruction, IntervalLogger, SearchEmbeddingModel, ENGINE_INSTRUCTIONS,
7 TERMINATE_ALL_NEXT_STEP,
8};
9use hf_hub::Cache;
10pub use lora::Ordering;
11pub use pipeline::ModelCategory;
12pub use pipeline::Pipeline;
13#[cfg(feature = "pyo3_macros")]
14use pyo3::exceptions::PyValueError;
15use std::collections::{HashMap, HashSet};
16use std::sync::OnceLock;
17use std::time::{Duration, Instant};
18use std::{
19 cell::RefCell,
20 error::Error,
21 fs::OpenOptions,
22 io::Write,
23 sync::{atomic::AtomicBool, Arc, Mutex, RwLock},
24 thread::{self, JoinHandle},
25 time::{SystemTime, UNIX_EPOCH},
26};
27use tokio::sync::mpsc::{channel, Sender};
28use tracing::info;
29use tracing::warn;
30
/// Git revision this build was compiled from, captured at compile time from the
/// `MISTRALRS_GIT_REVISION` environment variable; `"unknown"` when unset.
pub const MISTRALRS_GIT_REVISION: &str =
    if let Some(revision) = option_env!("MISTRALRS_GIT_REVISION") {
        revision
    } else {
        "unknown"
    };
35
36mod cuda;
37mod device_map;
38mod engine;
39mod lora;
40mod metal;
41mod model_loader;
42mod moe;
43mod ops;
44mod video_input;
45pub use model_loader::{
46 get_auto_device_map_params, get_model_dtype, get_tgt_non_granular_index, LoaderBuilder,
47};
48pub use video_input::{sample_frame_indices, VideoInput};
49mod embedding_models;
50mod kv_cache;
51mod search;
52
53mod model_selected;
54pub use model_selected::ModelSelected;
55pub use toml_selector::{get_toml_selected_model_device_map_params, get_toml_selected_model_dtype};
56
57mod amoe;
58mod attention;
59mod diagnostics;
60mod diffusion_models;
61pub mod distributed;
62mod gguf;
63pub mod layers;
64mod layers_masker;
65mod layers_utils;
66pub mod matformer;
67mod mla;
68mod models;
69mod paged_attention;
70mod pipeline;
71mod prefix_cacher;
72pub mod reasoning_parsers;
73mod request;
74mod response;
75mod sampler;
76mod scheduler;
77mod sequence;
78mod speech_models;
79mod toml_selector;
80mod tools;
81mod topology;
82mod utils;
83mod vision_models;
84mod xlora_models;
85
86pub use diagnostics::{
87 check_hf_gated_access, collect_system_info, run_doctor, BuildInfo, CpuInfo, DeviceInfo,
88 DoctorCheck, DoctorReport, DoctorStatus, HfConnectivityInfo, MemoryInfo, SystemInfo,
89};
90mod tuning;
91pub use tuning::{
92 auto_tune, AutoTuneRequest, AutoTuneResult, FitStatus, QualityTier, TuneCandidate, TuneProfile,
93};
94
95pub use amoe::{AnyMoeConfig, AnyMoeExpertType};
96pub use device_map::{
97 DeviceLayerMapMetadata, DeviceMapMetadata, DeviceMapSetting, LayerDeviceMapper,
98};
99pub use gguf::{GGUFArchitecture, GGUF_MULTI_FILE_DELIMITER};
100pub use mistralrs_audio::AudioInput;
101pub use mistralrs_mcp::{
102 CalledFunction, Function, Tool, ToolCallback, ToolCallbackWithTool, ToolType,
103};
104pub use mistralrs_mcp::{
105 McpClient, McpClientConfig, McpServerConfig, McpServerSource, McpToolInfo,
106};
107pub use mistralrs_quant::{IsqBits, IsqType, MULTI_LORA_DELIMITER};
108pub use paged_attention::{MemoryGpuConfig, PagedAttentionConfig, PagedCacheType};
109pub use pipeline::hf::{hf_home_dir, hf_hub_cache_dir, hf_token_path};
110pub use pipeline::{
111 chat_template::ChatTemplate, expand_isq_value, parse_isq_value, AdapterPaths, AnyMoeLoader,
112 AnyMoePipeline, AutoDeviceMapParams, AutoLoader, AutoLoaderBuilder, DiffusionGenerationParams,
113 DiffusionLoader, DiffusionLoaderBuilder, DiffusionLoaderType, EmbeddingLoader,
114 EmbeddingLoaderBuilder, EmbeddingLoaderType, EmbeddingModelPaths, EmbeddingSpecificConfig,
115 GGMLLoader, GGMLLoaderBuilder, GGMLSpecificConfig, GGUFLoader, GGUFLoaderBuilder,
116 GGUFSpecificConfig, GemmaLoader, Idefics2Loader, IsqOrganization, LLaVALoader, LLaVANextLoader,
117 LlamaLoader, Loader, LocalModelPaths, LoraAdapterPaths, MistralLoader, MixtralLoader,
118 Modalities, ModelKind, ModelPaths, MultimodalLoader, MultimodalLoaderBuilder,
119 MultimodalLoaderType, MultimodalPromptPrefixer, MultimodalSpecificConfig, NormalLoader,
120 NormalLoaderBuilder, NormalLoaderType, NormalSpecificConfig, Phi2Loader, Phi3Loader,
121 Phi3VLoader, Qwen2Loader, SpeculativeConfig, SpeculativeLoader, SpeculativePipeline,
122 SpeechLoader, SpeechPipeline, Starcoder2Loader, SupportedModality, TokenSource,
123 UQFF_MULTI_FILE_DELIMITER,
124};
125pub use request::{
126 ApproximateUserLocation, Constraint, DetokenizationRequest, ImageGenerationResponseFormat,
127 LlguidanceGrammar, MessageContent, NormalRequest, ReasoningEffort, Request, RequestMessage,
128 SearchContextSize, TokenizationRequest, WebSearchOptions, WebSearchUserLocation,
129};
130pub use response::*;
131pub use sampler::{
132 CustomLogitsProcessor, DrySamplingParams, ModelGenerationDefaults, SamplingParams, StopTokens,
133 TopLogprob,
134};
135pub use scheduler::{DefaultSchedulerMethod, SchedulerConfig};
136pub use search::{SearchCallback, SearchFunctionParameters, SearchResult};
137use serde::Serialize;
138pub use speech_models::{utils as speech_utils, SpeechGenerationConfig, SpeechLoaderType};
139use tokio::runtime::Runtime;
140use toml_selector::{TomlLoaderArgs, TomlSelector};
141pub use tools::{ToolCallResponse, ToolCallType, ToolCallbacks, ToolChoice};
142pub use topology::{LayerTopology, Topology};
143pub use utils::debug::initialize_logging;
144pub use utils::memory_usage::MemoryUsage;
145pub use utils::normal::{ModelDType, TryIntoDType};
146pub use utils::{paged_attn_supported, using_flash_attn};
147
148pub use llguidance;
150
// Crate-wide debug flag; defaults to off. Stored as an atomic so it can be
// flipped at runtime without locking.
pub(crate) static DEBUG: AtomicBool = AtomicBool::new(false);
// Process-wide Hugging Face hub cache handle, set at most once.
pub static GLOBAL_HF_CACHE: OnceLock<Cache> = OnceLock::new();
154
/// Configuration for a single engine instance.
#[derive(Clone)]
pub struct EngineConfig {
    /// Disable the KV cache entirely.
    pub no_kv_cache: bool,
    /// Disable prefix caching.
    pub no_prefix_cache: bool,
    /// Capacity of the prefix cache (number of entries).
    pub prefix_cache_n: usize,
    /// Do not stop generation at the EOS token.
    pub disable_eos_stop: bool,
    /// Emit periodic throughput logs.
    pub throughput_logging_enabled: bool,
    /// Embedding model used by the built-in search support, if enabled.
    pub search_embedding_model: Option<SearchEmbeddingModel>,
    /// Custom callback used to perform searches, if provided.
    pub search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks, keyed by tool name.
    pub tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks bundled with their `Tool` definitions, keyed by name.
    pub tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
}
168
169impl Default for EngineConfig {
170 fn default() -> Self {
171 Self {
172 no_kv_cache: false,
173 no_prefix_cache: false,
174 prefix_cache_n: 16,
175 disable_eos_stop: false,
176 throughput_logging_enabled: true,
177 search_embedding_model: None,
178 search_callback: None,
179 tool_callbacks: HashMap::new(),
180 tool_callbacks_with_tools: HashMap::new(),
181 }
182 }
183}
184
/// Options passed to [`MistralRs::add_model`].
#[derive(Clone)]
pub struct AddModelConfig {
    /// Engine configuration for the model being added.
    pub engine_config: EngineConfig,
    /// Optional MCP client configuration to associate with the model.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration, kept in the reboot state for the engine.
    pub loader_config: Option<ModelLoaderConfig>,
}
194
195impl AddModelConfig {
196 pub fn new(engine_config: EngineConfig) -> Self {
197 Self {
198 engine_config,
199 mcp_client_config: None,
200 loader_config: None,
201 }
202 }
203
204 pub fn with_mcp_config(mut self, mcp_config: McpClientConfig) -> Self {
205 self.mcp_client_config = Some(mcp_config);
206 self
207 }
208
209 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
212 self.loader_config = Some(loader_config);
213 self
214 }
215}
216
/// Static description of a loaded model, captured when its engine is created.
#[derive(Clone)]
pub struct MistralRsConfig {
    /// Model kind, as reported by the pipeline metadata.
    pub kind: ModelKind,
    /// Device the pipeline runs on.
    pub device: Device,
    /// High-level model category (text, multimodal, diffusion, ...).
    pub category: ModelCategory,
    /// Input/output modalities the pipeline supports.
    pub modalities: Modalities,
    /// Maximum sequence length; `None` for diffusion and speech models.
    pub max_seq_len: Option<usize>,
    /// Model-provided default generation parameters, if any.
    pub generation_defaults: Option<ModelGenerationDefaults>,
}
226
/// Everything required to load (or later reload) a model from scratch.
#[derive(Clone)]
pub struct ModelLoaderConfig {
    /// Which model to load.
    pub model_selected: ModelSelected,
    /// Source of the Hugging Face access token.
    pub token_source: TokenSource,
    /// Optional Hugging Face revision (branch, tag, or commit).
    pub hf_revision: Option<String>,
    /// Data type to load the model in.
    pub dtype: ModelDType,
    /// Target device.
    pub device: Device,
    /// Device mapping configuration.
    pub device_map_setting: DeviceMapSetting,
    /// Optional in-situ quantization type.
    pub isq: Option<IsqType>,
    /// Optional PagedAttention configuration.
    pub paged_attn_config: Option<PagedAttentionConfig>,
    /// Suppress progress output during loading.
    pub silent: bool,
    /// Optional chat template override.
    pub chat_template: Option<String>,
    /// Optional explicit Jinja template override.
    pub jinja_explicit: Option<String>,
}
254
/// Saved state for a model that has been unloaded but may be reloaded later.
#[derive(Clone)]
pub struct UnloadedModelState {
    /// Loader configuration used to rebuild the pipeline on reload.
    pub loader_config: ModelLoaderConfig,
    /// Scheduler configuration to restore.
    pub scheduler_config: SchedulerConfig,
    /// Engine configuration to restore.
    pub engine_config: EngineConfig,
    /// MCP client configuration to restore, if one was set.
    pub mcp_client_config: Option<McpClientConfig>,
    /// Category of the unloaded model.
    pub category: ModelCategory,
    /// Full model configuration snapshot.
    pub mistralrs_config: MistralRsConfig,
}
272
/// A running engine: its request channel, thread handle, and bookkeeping.
struct EngineInstance {
    /// Channel used to submit [`Request`]s to the engine thread.
    sender: Sender<Request>,
    /// Join handle for the engine's dedicated OS thread.
    engine_handler: JoinHandle<()>,
    /// State required to respawn the engine if its thread dies.
    reboot_state: RebootState,
    /// Static model configuration captured at engine creation.
    config: MistralRsConfig,
    /// Model category (same value as `config.category`, kept for direct access).
    category: ModelCategory,
    /// Periodic throughput logger owned by this engine.
    logger: Arc<IntervalLogger>,
}
282
/// Top-level handle owning all engines and model bookkeeping.
pub struct MistralRs {
    /// Loaded engines, keyed by model id.
    engines: RwLock<HashMap<String, EngineInstance>>,
    /// Models that were unloaded but can be reloaded on demand.
    unloaded_models: RwLock<HashMap<String, UnloadedModelState>>,
    /// Model ids currently in the middle of a reload.
    reloading_models: RwLock<HashSet<String>>,
    /// Model id targeted when a request does not name a model.
    default_engine_id: RwLock<Option<String>>,
    /// Alias -> primary model id mapping.
    model_aliases: RwLock<HashMap<String, String>>,
    /// Optional log destination. NOTE(review): not consumed in this chunk —
    /// confirm where it is used.
    log: Option<String>,
    /// Identifier of this instance (the initial model id at construction).
    id: String,
    /// Unix timestamp (seconds) at which this instance was created.
    creation_time: u64,
    /// Monotonically increasing request id counter (starts at 1).
    next_request_id: Mutex<RefCell<usize>>,
}
313
/// Snapshot of everything needed to respawn an engine with identical settings.
#[derive(Clone)]
struct RebootState {
    /// Shared pipeline the engine runs; reused as-is on reboot.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration.
    method: SchedulerConfig,
    // The fields below mirror `EngineConfig`; see that struct for details.
    no_kv_cache: bool,
    no_prefix_cache: bool,
    prefix_cache_n: usize,
    disable_eos_stop: bool,
    throughput_logging_enabled: bool,
    search_embedding_model: Option<SearchEmbeddingModel>,
    search_callback: Option<Arc<search::SearchCallback>>,
    tool_callbacks: tools::ToolCallbacks,
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// MCP client configuration, if one was used.
    mcp_client_config: Option<McpClientConfig>,
    /// Loader configuration, if a from-scratch reload is supported.
    loader_config: Option<ModelLoaderConfig>,
}
331
/// Lifecycle state of a managed model.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelStatus {
    Loaded,
    Unloaded,
    Reloading,
}

impl std::fmt::Display for ModelStatus {
    /// Lower-case, human-readable state name ("loaded" / "unloaded" / "reloading").
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Loaded => "loaded",
            Self::Unloaded => "unloaded",
            Self::Reloading => "reloading",
        };
        f.write_str(label)
    }
}
349
/// Errors surfaced by [`MistralRs`] model-management operations.
#[derive(Debug)]
pub enum MistralRsError {
    EnginePoisoned,
    SenderPoisoned,
    ModelNotFound(String),
    ModelReloading(String),
    ReloadFailed(String),
    NoLoaderConfig(String),
    ModelAlreadyLoaded(String),
    ModelAlreadyUnloaded(String),
}

impl std::fmt::Display for MistralRsError {
    // Display intentionally reuses the Debug rendering of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{self:?}")
    }
}

impl std::error::Error for MistralRsError {}
375
/// Map management errors onto Python `ValueError`s when the pyo3 bindings are
/// enabled; the message is the Debug rendering of the error.
#[cfg(feature = "pyo3_macros")]
impl From<MistralRsError> for pyo3::PyErr {
    fn from(value: MistralRsError) -> Self {
        PyValueError::new_err(format!("{value:?}"))
    }
}
382
/// Builder for [`MistralRs`]: construct with [`MistralRsBuilder::new`], chain
/// `with_*` setters, then call [`MistralRsBuilder::build`].
pub struct MistralRsBuilder {
    /// Pipeline the initial engine will run.
    pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
    /// Scheduler configuration for the engine.
    method: SchedulerConfig,
    /// Register the engine under this id instead of the pipeline name.
    model_id_override: Option<String>,
    /// Optional log destination.
    log: Option<String>,
    /// Disable the KV cache (`None` = default `false`).
    no_kv_cache: Option<bool>,
    /// Disable prefix caching (`None` = default `false`).
    no_prefix_cache: Option<bool>,
    /// Prefix cache capacity (`None` = default `16`).
    prefix_cache_n: Option<usize>,
    /// Do not stop at EOS (`None` = default `false`).
    disable_eos_stop: Option<bool>,
    /// Emit periodic throughput logs.
    throughput_logging_enabled: bool,
    /// Embedding model used by the built-in search support, if enabled.
    search_embedding_model: Option<SearchEmbeddingModel>,
    /// Custom search callback, if provided.
    search_callback: Option<Arc<SearchCallback>>,
    /// Tool callbacks, keyed by tool name.
    tool_callbacks: tools::ToolCallbacks,
    /// Tool callbacks bundled with their tool definitions.
    tool_callbacks_with_tools: tools::ToolCallbacksWithTools,
    /// Optional MCP client configuration.
    mcp_client_config: Option<McpClientConfig>,
    /// Optional loader configuration kept for engine reboot.
    loader_config: Option<ModelLoaderConfig>,
}
403
404impl MistralRsBuilder {
405 pub fn new(
409 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
410 method: SchedulerConfig,
411 throughput_logging: bool,
412 search_embedding_model: Option<SearchEmbeddingModel>,
413 ) -> Self {
414 Self {
415 pipeline,
416 method,
417 model_id_override: None,
418 log: None,
419 no_kv_cache: None,
420 no_prefix_cache: None,
421 prefix_cache_n: None,
422 disable_eos_stop: None,
423 throughput_logging_enabled: throughput_logging,
424 search_embedding_model,
425 search_callback: None,
426 tool_callbacks: HashMap::new(),
427 tool_callbacks_with_tools: HashMap::new(),
428 mcp_client_config: None,
429 loader_config: None,
430 }
431 }
432
433 pub fn with_model_id(mut self, model_id: impl Into<String>) -> Self {
435 self.model_id_override = Some(model_id.into());
436 self
437 }
438
439 pub fn with_loader_config(mut self, loader_config: ModelLoaderConfig) -> Self {
442 self.loader_config = Some(loader_config);
443 self
444 }
445 pub fn with_log(mut self, log: String) -> Self {
446 self.log = Some(log);
447 self
448 }
449 pub fn with_opt_log(mut self, log: Option<String>) -> Self {
450 self.log = log;
451 self
452 }
453 pub fn with_no_kv_cache(mut self, no_kv_cache: bool) -> Self {
454 self.no_kv_cache = Some(no_kv_cache);
455 self
456 }
457 pub fn with_no_prefix_cache(mut self, no_prefix_cache: bool) -> Self {
458 self.no_prefix_cache = Some(no_prefix_cache);
459 self
460 }
461 pub fn with_prefix_cache_n(mut self, prefix_cache_n: usize) -> Self {
462 self.prefix_cache_n = Some(prefix_cache_n);
463 self
464 }
465 pub fn with_disable_eos_stop(mut self, disable_eos_stop: bool) -> Self {
466 self.disable_eos_stop = Some(disable_eos_stop);
467 self
468 }
469
470 pub fn with_search_callback(mut self, search_callback: Arc<SearchCallback>) -> Self {
472 self.search_callback = Some(search_callback);
473 self
474 }
475
476 pub fn with_tool_callback(
478 mut self,
479 name: impl Into<String>,
480 tool_callback: Arc<ToolCallback>,
481 ) -> Self {
482 self.tool_callbacks.insert(name.into(), tool_callback);
483 self
484 }
485
486 pub fn with_tool_callback_and_tool(
489 mut self,
490 name: impl Into<String>,
491 tool_callback: Arc<ToolCallback>,
492 tool: Tool,
493 ) -> Self {
494 let name = name.into();
495 self.tool_callbacks_with_tools.insert(
496 name,
497 ToolCallbackWithTool {
498 callback: tool_callback,
499 tool,
500 },
501 );
502 self
503 }
504
505 pub fn with_mcp_client(mut self, config: McpClientConfig) -> Self {
507 self.mcp_client_config = Some(config);
508 self
509 }
510
511 pub async fn build(self) -> Arc<MistralRs> {
512 MistralRs::new(self).await
513 }
514}
515
516impl Drop for MistralRs {
517 fn drop(&mut self) {
518 if let Ok(engines) = self.engines.read() {
520 for (_, engine) in engines.iter() {
521 let _ = engine.sender.try_send(Request::Terminate);
523 }
524 }
525 }
526}
527
528impl MistralRs {
529 fn create_engine_instance(
531 pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
532 method: SchedulerConfig,
533 config: EngineConfig,
534 reboot_state: RebootState,
535 ) -> Result<EngineInstance, String> {
536 let (tx, rx) = channel(10_000);
537
538 let pipeline_guard = pipeline.try_lock().unwrap();
539 let category = pipeline_guard.category();
540 let metadata = pipeline_guard.get_metadata();
541 let kind = metadata.kind.clone();
542 let device = pipeline_guard.device();
543 let modalities = metadata.modalities.clone();
544 let max_seq_len = match &category {
545 ModelCategory::Diffusion | ModelCategory::Speech => None,
546 _ => Some(metadata.max_seq_len),
547 };
548 let generation_defaults = pipeline_guard.generation_defaults();
549 let encoder_cache_counters = pipeline_guard.encoder_cache_counters();
550 drop(pipeline_guard);
551
552 let logger = Arc::new(IntervalLogger::new(
553 Duration::from_secs(5),
554 encoder_cache_counters,
555 ));
556 let logger_for_engine = logger.clone();
557
558 info!("Pipeline input modalities are {:?}", &modalities.input);
559 info!("Pipeline output modalities are {:?}", &modalities.output);
560
561 let mistralrs_config = MistralRsConfig {
562 kind,
563 device,
564 category: category.clone(),
565 modalities,
566 max_seq_len,
567 generation_defaults,
568 };
569
570 let tx_for_engine = tx.clone();
571 let engine_handler = thread::spawn(move || {
572 #[cfg(feature = "metal")]
573 objc::rc::autoreleasepool(move || {
574 let rt = Runtime::new().unwrap();
575 rt.block_on(async move {
576 let engine = Engine::new(
577 tx_for_engine,
578 rx,
579 pipeline,
580 method,
581 config.no_kv_cache,
582 config.no_prefix_cache,
583 config.prefix_cache_n,
584 config.disable_eos_stop,
585 config.throughput_logging_enabled,
586 config.search_embedding_model,
587 config.search_callback.clone(),
588 config.tool_callbacks.clone(),
589 config.tool_callbacks_with_tools.clone(),
590 logger_for_engine,
591 )
592 .expect("Engine creation failed.");
593 Arc::new(engine).run().await;
594 })
595 });
596
597 #[cfg(not(feature = "metal"))]
598 {
599 let rt = Runtime::new().unwrap();
600 rt.block_on(async move {
601 let engine = Engine::new(
602 tx_for_engine,
603 rx,
604 pipeline,
605 method,
606 config.no_kv_cache,
607 config.no_prefix_cache,
608 config.prefix_cache_n,
609 config.disable_eos_stop,
610 config.throughput_logging_enabled,
611 config.search_embedding_model,
612 config.search_callback.clone(),
613 config.tool_callbacks.clone(),
614 config.tool_callbacks_with_tools.clone(),
615 logger_for_engine,
616 )
617 .expect("Engine creation failed.");
618 Arc::new(engine).run().await;
619 })
620 }
621 });
622
623 Ok(EngineInstance {
624 sender: tx,
625 engine_handler,
626 reboot_state,
627 config: mistralrs_config,
628 category,
629 logger,
630 })
631 }
632
    /// Build a [`MistralRs`] from a [`MistralRsBuilder`]: initializes the
    /// optional MCP client, spawns the first engine, performs a one-token
    /// warm-up run where possible, and registers the engine under its model id.
    async fn new(config: MistralRsBuilder) -> Arc<Self> {
        info!("git revision: {MISTRALRS_GIT_REVISION}");
        let MistralRsBuilder {
            pipeline,
            method,
            model_id_override,
            log,
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            mut tool_callbacks_with_tools,
            mcp_client_config,
            loader_config,
        } = config;

        mistralrs_quant::cublaslt::maybe_init_cublas_lt_wrapper(
            get_mut_arcmutex!(pipeline).device(),
        );

        // Unset builder options fall back to the same defaults as
        // `EngineConfig::default()`.
        let no_kv_cache = no_kv_cache.unwrap_or(false);
        let no_prefix_cache = no_prefix_cache.unwrap_or(false);
        let prefix_cache_n = prefix_cache_n.unwrap_or(16);
        let disable_eos_stop = disable_eos_stop.unwrap_or(false);

        // Initialize the MCP client (if configured) and merge its tools into
        // the tool callback map. Failure is non-fatal: log and continue.
        if let Some(config) = &mcp_client_config {
            let mut mcp_client = McpClient::new(config.clone());
            let total_servers = config.servers.len();

            match mcp_client.initialize().await {
                Ok(()) => {
                    let mcp_callbacks_with_tools = mcp_client.get_tool_callbacks_with_tools();
                    let tools_count = mcp_callbacks_with_tools.len();

                    for (name, callback_with_tool) in mcp_callbacks_with_tools {
                        tool_callbacks_with_tools.insert(name.clone(), callback_with_tool.clone());
                    }

                    if tools_count == 0 {
                        warn!(
                            "MCP client initialized but no tools were registered from {} servers",
                            total_servers
                        );
                    } else {
                        info!(
                            "MCP client initialized successfully with {} tools from {} servers",
                            tools_count, total_servers
                        );
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to initialize MCP client with {} configured servers: {}",
                        total_servers, e
                    );
                    warn!("Continuing without MCP functionality. Check your MCP configuration and server availability.");
                }
            }
        }

        // State needed to respawn the engine if its thread ever dies.
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback: search_callback.clone(),
            tool_callbacks: tool_callbacks.clone(),
            tool_callbacks_with_tools: tool_callbacks_with_tools.clone(),
            mcp_client_config: mcp_client_config.clone(),
            loader_config,
        };

        let engine_config = EngineConfig {
            no_kv_cache,
            no_prefix_cache,
            prefix_cache_n,
            disable_eos_stop,
            throughput_logging_enabled,
            search_embedding_model,
            search_callback,
            tool_callbacks,
            tool_callbacks_with_tools,
        };

        let engine_instance =
            Self::create_engine_instance(pipeline.clone(), method, engine_config, reboot_state)
                .expect("Failed to create engine instance");

        // When an override id is given the engine is registered under it, and
        // the pipeline's own name becomes an alias for that id.
        let pipeline_name = pipeline.try_lock().unwrap().name();
        let (id, alias_map) = match model_id_override {
            Some(override_id) => {
                let mut alias_map = HashMap::new();
                if override_id != pipeline_name {
                    alias_map.insert(pipeline_name.clone(), override_id.clone());
                }
                (override_id, alias_map)
            }
            None => (pipeline_name.clone(), HashMap::new()),
        };

        // Distributed daemon processes only replicate requests and never
        // return from here (note the intentional infinite loop).
        if distributed::is_daemon() {
            let request_sender = engine_instance.sender.clone();

            if cfg!(feature = "ring") {
                distributed::ring_daemon_replicator(request_sender);
            } else {
                distributed::nccl_daemon_replicator(request_sender);
            }

            #[allow(clippy::empty_loop)]
            loop {}
        }

        // The warm-up below uses `block_in_place`, which requires a
        // multi-threaded tokio runtime.
        let is_multi_threaded = tokio::runtime::Handle::try_current()
            .is_ok_and(|h| h.runtime_flavor() != tokio::runtime::RuntimeFlavor::CurrentThread);

        // Warm up text-capable models with a single-token dummy completion.
        if !distributed::is_daemon()
            && is_multi_threaded
            && matches!(
                engine_instance.category,
                ModelCategory::Text | ModelCategory::Multimodal { .. }
            )
        {
            let clone_sender = engine_instance.sender.clone();
            tokio::task::block_in_place(|| {
                let (tx, mut rx) = channel(1);
                let req = Request::Normal(Box::new(NormalRequest {
                    id: 0,
                    messages: RequestMessage::Completion {
                        text: "hello".to_string(),
                        echo_prompt: false,
                        best_of: None,
                    },
                    sampling_params: SamplingParams {
                        max_len: Some(1),
                        ..SamplingParams::deterministic()
                    },
                    response: tx,
                    return_logprobs: false,
                    is_streaming: false,
                    constraint: Constraint::None,
                    suffix: None,
                    tool_choice: None,
                    tools: None,
                    logits_processors: None,
                    return_raw_logits: false,
                    web_search_options: None,
                    model_id: None,
                    truncate_sequence: false,
                }));
                info!("Beginning dummy run.");
                let start = Instant::now();
                clone_sender.blocking_send(req).unwrap();

                // Drain every response; receiving anything counts as success.
                let mut received_any = false;
                while let Some(_resp) = rx.blocking_recv() {
                    received_any = true;
                }

                if received_any {
                    let end = Instant::now();
                    info!(
                        "Dummy run completed in {}s.",
                        end.duration_since(start).as_secs_f64()
                    );
                } else {
                    warn!("Dummy run failed!");
                }
            });

            // Don't let warm-up traffic skew the throughput statistics.
            engine_instance.logger.reset();
        }

        let mut engines = HashMap::new();
        engines.insert(id.clone(), engine_instance);

        Arc::new(Self {
            engines: RwLock::new(engines),
            unloaded_models: RwLock::new(HashMap::new()),
            reloading_models: RwLock::new(HashSet::new()),
            default_engine_id: RwLock::new(Some(id.clone())),
            model_aliases: RwLock::new(alias_map),
            log,
            id,
            creation_time: SystemTime::now()
                .duration_since(UNIX_EPOCH)
                .expect("Time travel has occurred!")
                .as_secs(),
            next_request_id: Mutex::new(RefCell::new(1)),
        })
    }
843
    /// Respawn the engine thread for `model_id` if it has exited.
    ///
    /// A no-op when the engine thread is still running. Fails with
    /// `EnginePoisoned` when the lock is poisoned, the model is unknown, or
    /// the replacement engine cannot be created.
    fn reboot_engine(&self, model_id: &str) -> Result<(), MistralRsError> {
        let mut engines = self.engines.write().map_err(|_| {
            tracing::warn!("Couldn't get write lock on engines during reboot attempt");
            MistralRsError::EnginePoisoned
        })?;

        if let Some(engine_instance) = engines.get(model_id) {
            // Thread still alive: nothing to do.
            if !engine_instance.engine_handler.is_finished() {
                tracing::info!("Engine {} already running, returning ok", model_id);
                return Ok(());
            }

            // Rebuild the EngineConfig from the saved reboot state and spawn a
            // replacement engine over the same shared pipeline.
            let reboot_state = engine_instance.reboot_state.clone();
            let engine_config = EngineConfig {
                no_kv_cache: reboot_state.no_kv_cache,
                no_prefix_cache: reboot_state.no_prefix_cache,
                prefix_cache_n: reboot_state.prefix_cache_n,
                disable_eos_stop: reboot_state.disable_eos_stop,
                throughput_logging_enabled: reboot_state.throughput_logging_enabled,
                search_embedding_model: reboot_state.search_embedding_model,
                search_callback: reboot_state.search_callback.clone(),
                tool_callbacks: reboot_state.tool_callbacks.clone(),
                tool_callbacks_with_tools: reboot_state.tool_callbacks_with_tools.clone(),
            };
            let new_engine_instance = Self::create_engine_instance(
                reboot_state.pipeline.clone(),
                reboot_state.method.clone(),
                engine_config,
                reboot_state,
            )
            .map_err(|e| {
                tracing::error!("Failed to create new engine instance: {}", e);
                MistralRsError::EnginePoisoned
            })?;

            // Replace the dead instance; its old sender/handle are dropped.
            engines.insert(model_id.to_string(), new_engine_instance);
            tracing::info!("Successfully rebooted engine {}", model_id);
            Ok(())
        } else {
            Err(MistralRsError::EnginePoisoned)
        }
    }
887
888 fn engine_dead(&self, model_id: &str) -> Result<bool, MistralRsError> {
889 let engines = self.engines.read().map_err(|_| {
890 tracing::warn!("Couldn't get read lock on engines!");
891 MistralRsError::EnginePoisoned
892 })?;
893
894 if let Some(engine_instance) = engines.get(model_id) {
895 Ok(engine_instance.engine_handler.is_finished())
896 } else {
897 Err(MistralRsError::EnginePoisoned)
898 }
899 }
900
    /// Get the request sender for `model_id` (or the default model when
    /// `None`). Dead-but-loaded engines are rebooted transparently, and
    /// unloaded models trigger a blocking auto-reload. Fails with
    /// `ModelNotFound` when the id resolves to nothing.
    pub fn get_sender(&self, model_id: Option<&str>) -> Result<Sender<Request>, MistralRsError> {
        let resolved_model_id = self.resolve_alias_or_default(model_id)?;

        // Check the loaded engines first; the read guard is dropped before a
        // possible reboot, which needs the write lock.
        let is_loaded = {
            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            engines.contains_key(&resolved_model_id)
        };

        if is_loaded {
            // Transparently respawn an engine whose thread has exited.
            if self.engine_dead(&resolved_model_id)? {
                tracing::warn!("Engine {} is dead, rebooting", resolved_model_id);
                self.reboot_engine(&resolved_model_id)?
            }

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        // Not loaded: see whether it is a known unloaded model to reload.
        let is_unloaded = {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.contains_key(&resolved_model_id)
        };

        if is_unloaded {
            tracing::info!(
                "Model {} is unloaded, triggering auto-reload",
                resolved_model_id
            );
            self.reload_model_blocking(&resolved_model_id)?;

            let engines = self
                .engines
                .read()
                .map_err(|_| MistralRsError::SenderPoisoned)?;
            if let Some(engine_instance) = engines.get(&resolved_model_id) {
                return Ok(engine_instance.sender.clone());
            }
        }

        Err(MistralRsError::ModelNotFound(resolved_model_id))
    }
959
960 pub fn get_id(&self) -> String {
961 self.id.clone()
962 }
963
    /// Unix timestamp (seconds) at which this instance was created.
    pub fn get_creation_time(&self) -> u64 {
        self.creation_time
    }
967
968 fn resolve_alias(&self, model_id: &str) -> Result<String, MistralRsError> {
969 let aliases = self
970 .model_aliases
971 .read()
972 .map_err(|_| MistralRsError::SenderPoisoned)?;
973 if let Some(primary_id) = aliases.get(model_id) {
974 Ok(primary_id.clone())
975 } else {
976 Ok(model_id.to_string())
977 }
978 }
979
980 fn resolve_alias_or_default(&self, model_id: Option<&str>) -> Result<String, MistralRsError> {
981 match model_id {
982 Some(id) => self.resolve_alias(id),
983 None => {
984 let default_lock = self
985 .default_engine_id
986 .read()
987 .map_err(|_| MistralRsError::SenderPoisoned)?;
988 Ok(default_lock
989 .as_ref()
990 .ok_or(MistralRsError::EnginePoisoned)?
991 .clone())
992 }
993 }
994 }
995
    /// Register `alias` as an alternative name for `model_id`.
    ///
    /// Idempotent when the alias already maps to the same model; errors when
    /// the model is unknown or the alias collides with a model id or another
    /// alias. NOTE(review): the existence/conflict checks and the final insert
    /// take separate locks, so a concurrent registration could race — confirm
    /// this is acceptable for the callers.
    pub fn register_model_alias(
        &self,
        alias: impl Into<String>,
        model_id: &str,
    ) -> Result<(), String> {
        let alias = alias.into();
        let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;

        // Aliasing a model to its own id is a no-op.
        if alias == resolved_model_id {
            return Ok(());
        }

        // Gather "does the model exist" / "is the alias taken" across all
        // three model states (reloading, loaded, unloaded), one lock at a time.
        let reloading = self
            .reloading_models
            .read()
            .map_err(|_| "Failed to acquire read lock on reloading_models")?;
        let model_reloading = reloading.contains(&resolved_model_id);
        let alias_conflict = reloading.contains(&alias);
        drop(reloading);

        let engines = self
            .engines
            .read()
            .map_err(|_| "Failed to acquire read lock on engines")?;
        let model_loaded = engines.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || engines.contains_key(&alias);
        drop(engines);

        let unloaded = self
            .unloaded_models
            .read()
            .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
        let model_unloaded = unloaded.contains_key(&resolved_model_id);
        let alias_conflict = alias_conflict || unloaded.contains_key(&alias);
        drop(unloaded);

        if !(model_loaded || model_unloaded || model_reloading) {
            return Err(format!("Model {resolved_model_id} not found"));
        }

        if alias_conflict {
            return Err(format!(
                "Alias '{}' conflicts with an existing model ID",
                alias
            ));
        }

        let mut aliases = self
            .model_aliases
            .write()
            .map_err(|_| "Failed to acquire write lock on model_aliases")?;
        if let Some(existing) = aliases.get(&alias) {
            // Re-registering the same mapping is fine; re-pointing an existing
            // alias at a different model is rejected.
            if existing == &resolved_model_id {
                return Ok(());
            }
            return Err(format!(
                "Alias '{}' is already assigned to model '{}'",
                alias, existing
            ));
        }
        aliases.insert(alias, resolved_model_id);
        Ok(())
    }
1060
1061 pub fn model_exists(&self, model_id: &str) -> Result<bool, MistralRsError> {
1063 let resolved_model_id = self.resolve_alias(model_id)?;
1064
1065 let reloading = self
1066 .reloading_models
1067 .read()
1068 .map_err(|_| MistralRsError::EnginePoisoned)?;
1069 if reloading.contains(&resolved_model_id) {
1070 return Ok(true);
1071 }
1072 drop(reloading);
1073
1074 let engines = self
1075 .engines
1076 .read()
1077 .map_err(|_| MistralRsError::EnginePoisoned)?;
1078 if engines.contains_key(&resolved_model_id) {
1079 return Ok(true);
1080 }
1081 drop(engines);
1082
1083 let unloaded = self
1084 .unloaded_models
1085 .read()
1086 .map_err(|_| MistralRsError::EnginePoisoned)?;
1087 if unloaded.contains_key(&resolved_model_id) {
1088 return Ok(true);
1089 }
1090
1091 Ok(false)
1092 }
1093
1094 pub fn get_logger(
1096 &self,
1097 model_id: Option<&str>,
1098 ) -> Result<Arc<IntervalLogger>, MistralRsError> {
1099 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1100
1101 let engines = self
1102 .engines
1103 .read()
1104 .map_err(|_| MistralRsError::SenderPoisoned)?;
1105 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1106 Ok(engine_instance.logger.clone())
1107 } else {
1108 Err(MistralRsError::EnginePoisoned)
1109 }
1110 }
1111
1112 pub fn get_model_category(
1114 &self,
1115 model_id: Option<&str>,
1116 ) -> Result<ModelCategory, MistralRsError> {
1117 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1118
1119 let engines = self
1120 .engines
1121 .read()
1122 .map_err(|_| MistralRsError::SenderPoisoned)?;
1123 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1124 Ok(engine_instance.category.clone())
1125 } else {
1126 Err(MistralRsError::EnginePoisoned)
1127 }
1128 }
1129
1130 pub fn max_sequence_length(
1132 &self,
1133 model_id: Option<&str>,
1134 ) -> Result<Option<usize>, MistralRsError> {
1135 let resolved_model_id = self.resolve_alias_or_default(model_id)?;
1136
1137 let engines = self
1138 .engines
1139 .read()
1140 .map_err(|_| MistralRsError::SenderPoisoned)?;
1141 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1142 Ok(engine_instance.config.max_seq_len)
1143 } else {
1144 Err(MistralRsError::EnginePoisoned)
1145 }
1146 }
1147
1148 pub fn next_request_id(&self) -> usize {
1149 let l = self.next_request_id.lock().unwrap();
1150 let last = &mut *l.borrow_mut();
1151 let last_v = *last;
1152 *last += 1;
1153 last_v
1154 }
1155
    /// Add a new model/engine under `model_id`.
    ///
    /// Fails when the id is already in use in any state (reloading, loaded,
    /// unloaded) or collides with an alias. The first model added becomes the
    /// default engine.
    pub async fn add_model(
        &self,
        model_id: String,
        pipeline: Arc<tokio::sync::Mutex<dyn Pipeline>>,
        method: SchedulerConfig,
        config: AddModelConfig,
    ) -> Result<(), String> {
        // Uniqueness checks, each in its own scope so every read guard is
        // released before the write lock below is taken.
        {
            let reloading = self
                .reloading_models
                .read()
                .map_err(|_| "Failed to acquire read lock on reloading_models")?;
            if reloading.contains(&model_id) {
                return Err(format!("Model {model_id} is currently reloading"));
            }
        }
        {
            let engines = self
                .engines
                .read()
                .map_err(|_| "Failed to acquire read lock on engines")?;
            if engines.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists"));
            }
        }
        {
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| "Failed to acquire read lock on unloaded_models")?;
            if unloaded.contains_key(&model_id) {
                return Err(format!("Model {model_id} already exists (unloaded)"));
            }
        }
        {
            let aliases = self
                .model_aliases
                .read()
                .map_err(|_| "Failed to acquire read lock on model_aliases")?;
            if aliases.contains_key(&model_id) {
                return Err(format!(
                    "Model ID '{}' conflicts with an existing alias",
                    model_id
                ));
            }
        }

        // Capture everything needed to respawn this engine later.
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: method.clone(),
            no_kv_cache: config.engine_config.no_kv_cache,
            no_prefix_cache: config.engine_config.no_prefix_cache,
            prefix_cache_n: config.engine_config.prefix_cache_n,
            disable_eos_stop: config.engine_config.disable_eos_stop,
            throughput_logging_enabled: config.engine_config.throughput_logging_enabled,
            search_embedding_model: config.engine_config.search_embedding_model,
            search_callback: config.engine_config.search_callback.clone(),
            tool_callbacks: config.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: config.engine_config.tool_callbacks_with_tools.clone(),
            mcp_client_config: config.mcp_client_config.clone(),
            loader_config: config.loader_config.clone(),
        };

        let engine_instance =
            Self::create_engine_instance(pipeline, method, config.engine_config, reboot_state)?;

        let mut engines = self
            .engines
            .write()
            .map_err(|_| "Failed to acquire write lock on engines")?;
        engines.insert(model_id.clone(), engine_instance);

        // The first registered model becomes the default target for requests
        // that do not name a model.
        if engines.len() == 1 {
            let mut default_lock = self
                .default_engine_id
                .write()
                .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
            *default_lock = Some(model_id.clone());
            info!("First model added, setting '{}' as default", model_id);
        }

        Ok(())
    }
1241
1242 pub fn remove_model(&self, model_id: &str) -> Result<(), String> {
1244 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1245 let mut engines = self
1246 .engines
1247 .write()
1248 .map_err(|_| "Failed to acquire write lock on engines")?;
1249
1250 if engines.len() <= 1 {
1251 return Err("Cannot remove the last model from MistralRs".to_string());
1252 }
1253
1254 if let Some(engine_instance) = engines.remove(&resolved_model_id) {
1255 let _ = engine_instance.sender.blocking_send(Request::Terminate);
1257
1258 let mut default_lock = self
1260 .default_engine_id
1261 .write()
1262 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1263 if let Some(ref default_id) = *default_lock {
1264 if default_id == &resolved_model_id {
1265 *default_lock = engines.keys().next().cloned();
1267 }
1268 }
1269 drop(default_lock);
1270 drop(engines);
1271
1272 let mut aliases = self
1274 .model_aliases
1275 .write()
1276 .map_err(|_| "Failed to acquire write lock on model_aliases")?;
1277 aliases.retain(|_, target| target != &resolved_model_id);
1278
1279 Ok(())
1280 } else {
1281 Err(format!("Model {resolved_model_id} not found"))
1282 }
1283 }
1284
1285 pub fn list_models(&self) -> Result<Vec<String>, String> {
1287 let engines = self
1288 .engines
1289 .read()
1290 .map_err(|_| "Failed to acquire read lock on engines")?;
1291 Ok(engines.keys().cloned().collect())
1292 }
1293
1294 pub fn get_default_model_id(&self) -> Result<Option<String>, String> {
1296 let default_lock = self
1297 .default_engine_id
1298 .read()
1299 .map_err(|_| "Failed to acquire read lock on default_engine_id")?;
1300 Ok(default_lock.clone())
1301 }
1302
1303 pub fn set_default_model_id(&self, model_id: &str) -> Result<(), String> {
1305 let resolved_model_id = self.resolve_alias(model_id).map_err(|e| e.to_string())?;
1306 let engines = self
1307 .engines
1308 .read()
1309 .map_err(|_| "Failed to acquire read lock on engines")?;
1310 if !engines.contains_key(&resolved_model_id) {
1311 return Err(format!("Model {resolved_model_id} not found"));
1312 }
1313 drop(engines);
1314
1315 let mut default_lock = self
1316 .default_engine_id
1317 .write()
1318 .map_err(|_| "Failed to acquire write lock on default_engine_id")?;
1319 let old_default = default_lock.clone();
1320 *default_lock = Some(resolved_model_id.clone());
1321
1322 info!(
1324 "Default model changed: {:?} -> {:?}",
1325 old_default, resolved_model_id
1326 );
1327
1328 Ok(())
1329 }
1330
1331 pub fn send_request(&self, mut request: Request) -> Result<(), MistralRsError> {
1333 let model_id = match &mut request {
1334 Request::Normal(normal_req) => normal_req.model_id.as_deref(),
1335 _ => None, };
1337
1338 let sender = self.get_sender(model_id)?;
1339 sender
1340 .blocking_send(request)
1341 .map_err(|_| MistralRsError::SenderPoisoned)
1342 }
1343
1344 pub fn maybe_log_request(this: Arc<Self>, repr: String) {
1345 if let Some(file) = &this.log {
1346 let mut f = OpenOptions::new()
1347 .append(true)
1348 .create(true) .open(file)
1350 .expect("Unable to open file");
1351 let time = chrono::offset::Local::now();
1352 f.write_all(format!("Request at {time}: {repr}\n\n").as_bytes())
1353 .expect("Unable to write data");
1354 }
1355 }
1356
1357 pub fn maybe_log_response<T: Serialize>(this: Arc<Self>, resp: &T) {
1358 if let Some(file) = &this.log {
1359 let mut f = OpenOptions::new()
1360 .append(true)
1361 .create(true) .open(file)
1363 .expect("Unable to open file");
1364 let time = chrono::offset::Local::now();
1365 let repr = serde_json::to_string(resp).expect("Serialization of response failed.");
1366 f.write_all(format!("Response at {time}: {repr}\n\n").as_bytes())
1367 .expect("Unable to write data");
1368 }
1369 }
1370
1371 pub fn maybe_log_error(this: Arc<Self>, err: &dyn Error) {
1372 if let Some(file) = &this.log {
1373 let mut f = OpenOptions::new()
1374 .append(true)
1375 .create(true) .open(file)
1377 .expect("Unable to open file");
1378 let time = chrono::offset::Local::now();
1379 f.write_all(format!("Error response at {time}: {err}\n\n").as_bytes())
1380 .expect("Unable to write data");
1381 }
1382 }
1383
1384 pub fn get_tools_count(&self, model_id: Option<&str>) -> Result<usize, String> {
1386 let resolved_model_id = self
1387 .resolve_alias_or_default(model_id)
1388 .map_err(|e| e.to_string())?;
1389
1390 let engines = self
1391 .engines
1392 .read()
1393 .map_err(|_| "Failed to acquire read lock on engines")?;
1394 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1395 Ok(engine_instance.reboot_state.tool_callbacks_with_tools.len())
1396 } else {
1397 Err(format!("Model {resolved_model_id} not found"))
1398 }
1399 }
1400
1401 pub fn has_mcp_client(&self, model_id: Option<&str>) -> Result<bool, String> {
1403 let resolved_model_id = self
1404 .resolve_alias_or_default(model_id)
1405 .map_err(|e| e.to_string())?;
1406
1407 let engines = self
1408 .engines
1409 .read()
1410 .map_err(|_| "Failed to acquire read lock on engines")?;
1411 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1412 Ok(engine_instance.reboot_state.mcp_client_config.is_some())
1413 } else {
1414 Err(format!("Model {resolved_model_id} not found"))
1415 }
1416 }
1417
1418 pub fn config(&self, model_id: Option<&str>) -> Result<MistralRsConfig, String> {
1420 let resolved_model_id = self
1421 .resolve_alias_or_default(model_id)
1422 .map_err(|e| e.to_string())?;
1423
1424 let engines = self
1425 .engines
1426 .read()
1427 .map_err(|_| "Failed to acquire read lock on engines")?;
1428 if let Some(engine_instance) = engines.get(&resolved_model_id) {
1429 Ok(engine_instance.config.clone())
1430 } else {
1431 Err(format!("Model {resolved_model_id} not found"))
1432 }
1433 }
1434
1435 pub fn unload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
1442 let resolved_model_id = self.resolve_alias(model_id)?;
1443 {
1445 let unloaded = self
1446 .unloaded_models
1447 .read()
1448 .map_err(|_| MistralRsError::EnginePoisoned)?;
1449 if unloaded.contains_key(&resolved_model_id) {
1450 return Err(MistralRsError::ModelAlreadyUnloaded(
1451 resolved_model_id.clone(),
1452 ));
1453 }
1454 }
1455
1456 let mut engines = self
1458 .engines
1459 .write()
1460 .map_err(|_| MistralRsError::EnginePoisoned)?;
1461
1462 let engine_instance = engines
1463 .remove(&resolved_model_id)
1464 .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?;
1465
1466 let loader_config = engine_instance
1468 .reboot_state
1469 .loader_config
1470 .clone()
1471 .ok_or_else(|| MistralRsError::NoLoaderConfig(resolved_model_id.clone()))?;
1472
1473 let unloaded_state = UnloadedModelState {
1475 loader_config,
1476 scheduler_config: engine_instance.reboot_state.method.clone(),
1477 engine_config: EngineConfig {
1478 no_kv_cache: engine_instance.reboot_state.no_kv_cache,
1479 no_prefix_cache: engine_instance.reboot_state.no_prefix_cache,
1480 prefix_cache_n: engine_instance.reboot_state.prefix_cache_n,
1481 disable_eos_stop: engine_instance.reboot_state.disable_eos_stop,
1482 throughput_logging_enabled: engine_instance.reboot_state.throughput_logging_enabled,
1483 search_embedding_model: engine_instance.reboot_state.search_embedding_model,
1484 search_callback: engine_instance.reboot_state.search_callback.clone(),
1485 tool_callbacks: engine_instance.reboot_state.tool_callbacks.clone(),
1486 tool_callbacks_with_tools: engine_instance
1487 .reboot_state
1488 .tool_callbacks_with_tools
1489 .clone(),
1490 },
1491 mcp_client_config: engine_instance.reboot_state.mcp_client_config.clone(),
1492 category: engine_instance.category.clone(),
1493 mistralrs_config: engine_instance.config.clone(),
1494 };
1495
1496 let _ = engine_instance.sender.try_send(Request::Terminate);
1498
1499 drop(engines);
1500
1501 let mut unloaded = self
1503 .unloaded_models
1504 .write()
1505 .map_err(|_| MistralRsError::EnginePoisoned)?;
1506 unloaded.insert(resolved_model_id.to_string(), unloaded_state);
1507
1508 let mut default_lock = self
1510 .default_engine_id
1511 .write()
1512 .map_err(|_| MistralRsError::EnginePoisoned)?;
1513 if let Some(ref default_id) = *default_lock {
1514 if default_id == &resolved_model_id {
1515 let engines = self
1517 .engines
1518 .read()
1519 .map_err(|_| MistralRsError::EnginePoisoned)?;
1520 *default_lock = engines.keys().next().cloned();
1521 }
1522 }
1523
1524 info!("Model {} unloaded successfully", resolved_model_id);
1525 Ok(())
1526 }
1527
    /// Reloads a previously unloaded model identified by `model_id`
    /// (aliases resolved).
    ///
    /// The model ID is placed in `reloading_models` for the duration of the
    /// reload so concurrent callers (and `add_model`) see it as busy; the
    /// flag is cleared again whether the reload succeeds or fails.
    ///
    /// NOTE: each lock guard is confined to its own block so no
    /// `std::sync::RwLock` guard is ever held across the `.await` below.
    ///
    /// # Errors
    /// * `ModelReloading` if a reload for this ID is already in progress.
    /// * `ModelNotFound` if the ID is not in the unloaded set.
    /// * `EnginePoisoned` if a lock is poisoned.
    /// * Whatever `do_reload_model` returns on failure.
    pub async fn reload_model(&self, model_id: &str) -> Result<(), MistralRsError> {
        let resolved_model_id = self.resolve_alias(model_id)?;
        {
            // Reject concurrent reloads of the same model.
            let reloading = self
                .reloading_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            if reloading.contains(&resolved_model_id) {
                return Err(MistralRsError::ModelReloading(resolved_model_id.clone()));
            }
        }

        {
            // Mark the model as reloading. There is a small window between the
            // read check above and this write in which another task could also
            // pass the check; the insert itself is idempotent.
            let mut reloading = self
                .reloading_models
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            reloading.insert(resolved_model_id.clone());
        }

        let unloaded_state = {
            // Clone the stashed state out so the lock is released before the
            // long-running reload.
            let unloaded = self
                .unloaded_models
                .read()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded
                .get(&resolved_model_id)
                .cloned()
                .ok_or_else(|| MistralRsError::ModelNotFound(resolved_model_id.clone()))?
        };

        // Capture the result instead of `?` so the reloading flag below is
        // always cleared, even on failure.
        let result = self
            .do_reload_model(&resolved_model_id, unloaded_state)
            .await;

        {
            let mut reloading = self
                .reloading_models
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            reloading.remove(&resolved_model_id);
        }

        result
    }
1580
    /// Performs the actual reload: rebuilds the loader from the stashed
    /// `UnloadedModelState`, loads the pipeline, creates a fresh engine, and
    /// moves the model from the unloaded set back into `engines`.
    ///
    /// Callers are expected to hold the `reloading_models` marker for
    /// `model_id` (see [`Self::reload_model`]).
    ///
    /// # Errors
    /// `ReloadFailed` wrapping the loader-build, model-load, or
    /// engine-creation error; `EnginePoisoned` if a lock is poisoned.
    async fn do_reload_model(
        &self,
        model_id: &str,
        unloaded_state: UnloadedModelState,
    ) -> Result<(), MistralRsError> {
        use crate::model_loader::LoaderBuilder;

        info!("Reloading model: {}", model_id);

        let loader_config = &unloaded_state.loader_config;

        // Rebuild the loader exactly as it was configured at original load time.
        let loader = LoaderBuilder::new(loader_config.model_selected.clone())
            .with_chat_template(loader_config.chat_template.clone())
            .with_jinja_explicit(loader_config.jinja_explicit.clone())
            .build()
            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to build loader: {e}")))?;

        // Load the pipeline with the saved dtype/device/mapping settings.
        let pipeline = loader
            .load_model_from_hf(
                None,
                loader_config.token_source.clone(),
                &loader_config.dtype,
                &loader_config.device,
                loader_config.silent,
                loader_config.device_map_setting.clone(),
                loader_config.isq,
                loader_config.paged_attn_config,
            )
            .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to load model: {e}")))?;

        // Rebuild the reboot snapshot, re-attaching the loader config so the
        // model remains unloadable/reloadable after this reload.
        let reboot_state = RebootState {
            pipeline: pipeline.clone(),
            method: unloaded_state.scheduler_config.clone(),
            no_kv_cache: unloaded_state.engine_config.no_kv_cache,
            no_prefix_cache: unloaded_state.engine_config.no_prefix_cache,
            prefix_cache_n: unloaded_state.engine_config.prefix_cache_n,
            disable_eos_stop: unloaded_state.engine_config.disable_eos_stop,
            throughput_logging_enabled: unloaded_state.engine_config.throughput_logging_enabled,
            search_embedding_model: unloaded_state.engine_config.search_embedding_model,
            search_callback: unloaded_state.engine_config.search_callback.clone(),
            tool_callbacks: unloaded_state.engine_config.tool_callbacks.clone(),
            tool_callbacks_with_tools: unloaded_state
                .engine_config
                .tool_callbacks_with_tools
                .clone(),
            mcp_client_config: unloaded_state.mcp_client_config.clone(),
            loader_config: Some(unloaded_state.loader_config.clone()),
        };

        let engine_instance = Self::create_engine_instance(
            pipeline,
            unloaded_state.scheduler_config,
            unloaded_state.engine_config,
            reboot_state,
        )
        .map_err(|e| MistralRsError::ReloadFailed(format!("Failed to create engine: {e}")))?;

        {
            // Re-register the engine; guard confined to this block.
            let mut engines = self
                .engines
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            engines.insert(model_id.to_string(), engine_instance);
        }

        {
            // Only remove from the unloaded set after the engine is in place.
            let mut unloaded = self
                .unloaded_models
                .write()
                .map_err(|_| MistralRsError::EnginePoisoned)?;
            unloaded.remove(model_id);
        }

        info!("Model {} reloaded successfully", model_id);
        Ok(())
    }
1664
1665 pub fn reload_model_blocking(&self, model_id: &str) -> Result<(), MistralRsError> {
1672 match tokio::runtime::Handle::try_current() {
1673 Ok(handle) => {
1674 if handle.runtime_flavor() == tokio::runtime::RuntimeFlavor::CurrentThread {
1675 Err(MistralRsError::ReloadFailed(
1676 "Cannot reload model blocking from single-threaded runtime. Use reload_model() instead.".to_string()
1677 ))
1678 } else {
1679 tokio::task::block_in_place(|| handle.block_on(self.reload_model(model_id)))
1680 }
1681 }
1682 Err(_) => {
1683 let rt = tokio::runtime::Runtime::new().map_err(|e| {
1684 MistralRsError::ReloadFailed(format!("Failed to create runtime: {e}"))
1685 })?;
1686 rt.block_on(self.reload_model(model_id))
1687 }
1688 }
1689 }
1690
1691 pub fn list_unloaded_models(&self) -> Result<Vec<String>, MistralRsError> {
1693 let unloaded = self
1694 .unloaded_models
1695 .read()
1696 .map_err(|_| MistralRsError::EnginePoisoned)?;
1697 Ok(unloaded.keys().cloned().collect())
1698 }
1699
1700 pub fn is_model_loaded(&self, model_id: &str) -> Result<bool, MistralRsError> {
1702 let resolved_model_id = self.resolve_alias(model_id)?;
1703 let engines = self
1704 .engines
1705 .read()
1706 .map_err(|_| MistralRsError::EnginePoisoned)?;
1707 Ok(engines.contains_key(&resolved_model_id))
1708 }
1709
1710 pub fn get_model_status(&self, model_id: &str) -> Result<Option<ModelStatus>, MistralRsError> {
1712 let resolved_model_id = self.resolve_alias(model_id)?;
1713 {
1715 let reloading = self
1716 .reloading_models
1717 .read()
1718 .map_err(|_| MistralRsError::EnginePoisoned)?;
1719 if reloading.contains(&resolved_model_id) {
1720 return Ok(Some(ModelStatus::Reloading));
1721 }
1722 }
1723
1724 {
1726 let engines = self
1727 .engines
1728 .read()
1729 .map_err(|_| MistralRsError::EnginePoisoned)?;
1730 if engines.contains_key(&resolved_model_id) {
1731 return Ok(Some(ModelStatus::Loaded));
1732 }
1733 }
1734
1735 {
1737 let unloaded = self
1738 .unloaded_models
1739 .read()
1740 .map_err(|_| MistralRsError::EnginePoisoned)?;
1741 if unloaded.contains_key(&resolved_model_id) {
1742 return Ok(Some(ModelStatus::Unloaded));
1743 }
1744 }
1745
1746 Ok(None)
1747 }
1748
1749 pub fn list_models_with_status(&self) -> Result<Vec<(String, ModelStatus)>, MistralRsError> {
1751 let mut result = Vec::new();
1752
1753 let reloading = self
1755 .reloading_models
1756 .read()
1757 .map_err(|_| MistralRsError::EnginePoisoned)?;
1758 for model_id in reloading.iter() {
1759 result.push((model_id.clone(), ModelStatus::Reloading));
1760 }
1761 drop(reloading);
1762
1763 let engines = self
1765 .engines
1766 .read()
1767 .map_err(|_| MistralRsError::EnginePoisoned)?;
1768 for model_id in engines.keys() {
1769 result.push((model_id.clone(), ModelStatus::Loaded));
1770 }
1771 drop(engines);
1772
1773 let unloaded = self
1775 .unloaded_models
1776 .read()
1777 .map_err(|_| MistralRsError::EnginePoisoned)?;
1778 for model_id in unloaded.keys() {
1779 if !result.iter().any(|(id, _)| id == model_id) {
1781 result.push((model_id.clone(), ModelStatus::Unloaded));
1782 }
1783 }
1784
1785 Ok(result)
1786 }
1787}