//! mistralrs_core/pipeline/loaders/mod.rs — loader abstractions for model
//! paths, token sources, model kinds, and device-mapped loading.

1pub(crate) mod auto_device_map;
2mod diffusion_loaders;
3mod embedding_loaders;
4mod multimodal_loaders;
5mod normal_loaders;
6pub use auto_device_map::AutoDeviceMapParams;
7use auto_device_map::NonMappedSubModel;
8
9use std::{
10    fmt::{self, Debug},
11    path::PathBuf,
12    str::FromStr,
13    sync::Arc,
14};
15
16use anyhow::Result;
17use as_any::AsAny;
18use candle_core::{DType, Device};
19use mistralrs_quant::{IsqType, QuantizedConfig};
20use serde::Deserialize;
21use tokio::sync::Mutex;
22
23pub use normal_loaders::{
24    AutoNormalLoader, DeepSeekV2Loader, DeepSeekV3Loader, GLM4Loader, GLM4MoeLiteLoader,
25    GLM4MoeLoader, Gemma2Loader, GemmaLoader, GptOssLoader, GraniteMoeHybridLoader, LlamaLoader,
26    MistralLoader, MixtralLoader, NormalLoaderType, NormalLoadingMetadata, NormalModel,
27    NormalModelLoader, Phi2Loader, Phi3Loader, Phi3_5MoELoader, Qwen2Loader, Qwen3Loader,
28    Qwen3MoELoader, Qwen3NextLoader, SmolLm3Loader, Starcoder2Loader,
29};
30
31pub use multimodal_loaders::{
32    AutoMultimodalLoader, Gemma3Loader, Gemma3nLoader, Gemma4Loader, Idefics2Loader,
33    Idefics3Loader, LLaVALoader, LLaVANextLoader, MiniCpmOLoader, Mistral3Loader,
34    MultimodalLoaderType, MultimodalModel, MultimodalModelLoader, Phi3VLoader, Phi4MMLoader,
35    Qwen2VLLoader, Qwen2_5VLLoader, Qwen3VLLoader, Qwen3VLMoELoader, Qwen3_5Loader,
36    Qwen3_5MoeLoader, VLlama4Loader, VLlamaLoader, VoxtralLoader,
37};
38
39pub use embedding_loaders::{
40    AutoEmbeddingLoader, EmbeddingGemmaLoader, EmbeddingLoaderType, EmbeddingModel,
41    EmbeddingModelLoader, EmbeddingModule, EmbeddingModulePaths, EmbeddingModuleType,
42    Qwen3EmbeddingLoader,
43};
44
45pub use diffusion_loaders::{
46    DiffusionLoaderType, DiffusionModel, DiffusionModelLoader, DiffusionModelPaths,
47    DiffusionModelPathsInner, FluxLoader,
48};
49
50use crate::{
51    matformer::MatformerSliceConfig, paged_attention::ModelConfigLike, DeviceMapMetadata,
52    DeviceMapSetting, PagedAttentionConfig, TryIntoDType,
53};
54
55use super::{paths::AdapterPaths, Pipeline};
56
/// `ModelPaths` abstracts the mechanism to get all necessary files for running a model. For
/// example `LocalModelPaths` implements `ModelPaths` when all files are in the local file system.
pub trait ModelPaths: AsAny + Debug + Send + Sync {
    /// Model weights files (multiple files supported).
    fn get_weight_filenames(&self) -> &[PathBuf];

    /// Retrieve the [`PretrainedConfig`] file.
    ///
    /// [`PretrainedConfig`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/configuration#transformers.PretrainedConfig
    fn get_config_filename(&self) -> &PathBuf;

    /// A serialised [`tokenizers.Tokenizer`] HuggingFace object.
    ///
    /// [`tokenizers.Tokenizer`]: https://huggingface.co/docs/transformers/v4.40.2/en/main_classes/tokenizer
    fn get_tokenizer_filename(&self) -> &PathBuf;

    /// File where the content is expected to deserialize to [`ChatTemplate`].
    /// `None` when the model ships no chat template file.
    ///
    /// [`ChatTemplate`]: crate::ChatTemplate
    fn get_template_filename(&self) -> &Option<PathBuf>;

    /// Filepath for general model configuration (`None` when absent).
    fn get_gen_conf_filename(&self) -> Option<&PathBuf>;

    /// Get the preprocessor config (for the multimodal models). This is used to pre process images.
    fn get_preprocessor_config(&self) -> &Option<PathBuf>;

    /// Get the processor config (for the multimodal models). This is primarily used for the chat template.
    fn get_processor_config(&self) -> &Option<PathBuf>;

    /// Get the explicit chat template (a standalone chat-template JSON file, if any).
    fn get_chat_template_explicit(&self) -> &Option<PathBuf>;

    /// Get adapter paths.
    fn get_adapter_paths(&self) -> &AdapterPaths;

    /// Get embedding model `modules.json` compatible with sentence-transformers.
    /// `None` for non-embedding models.
    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]>;
}
96
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load a model.
pub struct LocalModelPaths<P: Debug> {
    /// Serialized tokenizer (e.g. `tokenizer.json`).
    pub tokenizer_filename: P,
    /// Model configuration file (e.g. `config.json`).
    pub config_filename: P,
    /// Chat-template-bearing file; `None` when the model has none.
    pub template_filename: Option<P>,
    /// Weight files (multiple shards supported).
    pub filenames: Vec<P>,
    /// (X-)LoRA adapter paths, if any.
    pub adapter_paths: AdapterPaths,
    /// General generation configuration — presumably `generation_config.json`; TODO confirm.
    pub gen_conf: Option<P>,
    /// Preprocessor config for multimodal models (image preprocessing).
    pub preprocessor_config: Option<P>,
    /// Processor config for multimodal models (chat template related).
    pub processor_config: Option<P>,
    /// Explicit chat-template JSON file, if present.
    pub chat_template_json_filename: Option<P>,
}
110
111impl<P: Debug> LocalModelPaths<P> {
112    #[allow(clippy::too_many_arguments)]
113    pub fn new(
114        tokenizer_filename: P,
115        config_filename: P,
116        template_filename: P,
117        filenames: Vec<P>,
118        adapter_paths: AdapterPaths,
119        gen_conf: Option<P>,
120        preprocessor_config: Option<P>,
121        processor_config: Option<P>,
122        chat_template_json_filename: Option<P>,
123    ) -> Self {
124        Self {
125            tokenizer_filename,
126            config_filename,
127            template_filename: Some(template_filename),
128            filenames,
129            adapter_paths,
130            gen_conf,
131            preprocessor_config,
132            processor_config,
133            chat_template_json_filename,
134        }
135    }
136}
137
138impl ModelPaths for LocalModelPaths<PathBuf> {
139    fn get_config_filename(&self) -> &PathBuf {
140        &self.config_filename
141    }
142    fn get_tokenizer_filename(&self) -> &PathBuf {
143        &self.tokenizer_filename
144    }
145    fn get_weight_filenames(&self) -> &[PathBuf] {
146        &self.filenames
147    }
148    fn get_template_filename(&self) -> &Option<PathBuf> {
149        &self.template_filename
150    }
151    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
152        self.gen_conf.as_ref()
153    }
154    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
155        &self.preprocessor_config
156    }
157    fn get_processor_config(&self) -> &Option<PathBuf> {
158        &self.processor_config
159    }
160    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
161        &self.chat_template_json_filename
162    }
163    fn get_adapter_paths(&self) -> &AdapterPaths {
164        &self.adapter_paths
165    }
166    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
167        None
168    }
169}
170
#[derive(Clone, Debug)]
/// All local paths and metadata necessary to load an embedding model.
pub struct EmbeddingModelPaths<P: Debug> {
    /// Serialized tokenizer (e.g. `tokenizer.json`).
    pub tokenizer_filename: P,
    /// Model configuration file (e.g. `config.json`).
    pub config_filename: P,
    /// sentence-transformers-style module pipeline (`modules.json` entries).
    pub modules: Vec<EmbeddingModulePaths>,
    /// Weight files (multiple shards supported).
    pub filenames: Vec<P>,
    /// (X-)LoRA adapter paths, if any.
    pub adapter_paths: AdapterPaths,
}
180
181impl<P: Debug> EmbeddingModelPaths<P> {
182    #[allow(clippy::too_many_arguments)]
183    pub fn new(
184        tokenizer_filename: P,
185        config_filename: P,
186        filenames: Vec<P>,
187        adapter_paths: AdapterPaths,
188        modules: Vec<EmbeddingModulePaths>,
189    ) -> Self {
190        Self {
191            tokenizer_filename,
192            config_filename,
193            filenames,
194            adapter_paths,
195            modules,
196        }
197    }
198}
199
200impl ModelPaths for EmbeddingModelPaths<PathBuf> {
201    fn get_config_filename(&self) -> &PathBuf {
202        &self.config_filename
203    }
204    fn get_tokenizer_filename(&self) -> &PathBuf {
205        &self.tokenizer_filename
206    }
207    fn get_weight_filenames(&self) -> &[PathBuf] {
208        &self.filenames
209    }
210    fn get_template_filename(&self) -> &Option<PathBuf> {
211        &None
212    }
213    fn get_gen_conf_filename(&self) -> Option<&PathBuf> {
214        None
215    }
216    fn get_preprocessor_config(&self) -> &Option<PathBuf> {
217        &None
218    }
219    fn get_processor_config(&self) -> &Option<PathBuf> {
220        &None
221    }
222    fn get_chat_template_explicit(&self) -> &Option<PathBuf> {
223        &None
224    }
225    fn get_adapter_paths(&self) -> &AdapterPaths {
226        &self.adapter_paths
227    }
228    fn get_modules(&self) -> Option<&[EmbeddingModulePaths]> {
229        Some(&self.modules)
230    }
231}
232
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
/// The source of the HF token.
pub enum TokenSource {
    /// Token supplied inline as a string.
    Literal(String),
    /// Token read from the named environment variable.
    EnvVar(String),
    /// Token read from a file at the given path.
    Path(String),
    /// Use the token from the local HF cache — presumably `~/.cache/huggingface/token`; confirm.
    CacheToken,
    /// Do not use any token.
    None,
}
242
243impl FromStr for TokenSource {
244    type Err = String;
245
246    fn from_str(s: &str) -> Result<Self, Self::Err> {
247        let parts: Vec<&str> = s.splitn(2, ':').collect();
248        match parts[0] {
249            "literal" => parts
250                .get(1)
251                .map(|&value| TokenSource::Literal(value.to_string()))
252                .ok_or_else(|| "Expected a value for 'literal'".to_string()),
253            "env" => Ok(TokenSource::EnvVar(
254                parts
255                    .get(1)
256                    .unwrap_or(&"HUGGING_FACE_HUB_TOKEN")
257                    .to_string(),
258            )),
259            "path" => parts
260                .get(1)
261                .map(|&value| TokenSource::Path(value.to_string()))
262                .ok_or_else(|| "Expected a value for 'path'".to_string()),
263            "cache" => Ok(TokenSource::CacheToken),
264            "none" => Ok(TokenSource::None),
265            _ => Err("Invalid token source format".to_string()),
266        }
267    }
268}
269
270impl fmt::Display for TokenSource {
271    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
272        match self {
273            TokenSource::Literal(value) => write!(f, "literal:{value}"),
274            TokenSource::EnvVar(value) => write!(f, "env:{value}"),
275            TokenSource::Path(value) => write!(f, "path:{value}"),
276            TokenSource::CacheToken => write!(f, "cache"),
277            TokenSource::None => write!(f, "none"),
278        }
279    }
280}
281
/// The kind of model to build.
// NOTE: the `#[strum(to_string = ...)]` attributes drive the `Display` impl;
// their text is user-visible output and must not be edited casually.
#[derive(Clone, Default, derive_more::From, strum::Display)]
pub enum ModelKind {
    /// Plain model: no quantization, no adapters.
    #[default]
    #[strum(to_string = "normal (no adapters)")]
    Normal,

    /// GGUF-quantized model without adapters.
    #[strum(to_string = "gguf quantized from {quant} (no adapters)")]
    GgufQuantized { quant: QuantizationKind },

    /// Unquantized model with a (X-)LoRA adapter.
    #[strum(to_string = "{adapter}")]
    Adapter { adapter: AdapterKind },

    /// GGUF-quantized model with a (X-)LoRA adapter.
    #[strum(to_string = "{adapter}, gguf quantized from {quant}")]
    GgufAdapter {
        adapter: AdapterKind,
        quant: QuantizationKind,
    },

    /// Speculative decoding: a target model with a draft model.
    #[strum(to_string = "speculative: target: `{target}`, draft: `{draft}`")]
    Speculative {
        target: Box<ModelKind>,
        draft: Box<ModelKind>,
    },

    /// AnyMoE wrapper around a target model.
    #[strum(to_string = "anymoe: target: `{target}`")]
    AnyMoe { target: Box<ModelKind> },
}
310
// The `///` variant docs below are load-bearing: `PrettyName::pretty_name`
// surfaces them via `strum::EnumMessage::get_documentation`. Do not reword
// them without intending to change user-visible output.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum QuantizationKind {
    /// GGML
    Ggml,
    /// GGUF
    Gguf,
    /// GPTQ
    Gptq,
}
321
// As with `QuantizationKind`, the `///` variant docs feed
// `PrettyName::pretty_name` through `strum::EnumMessage` — they are
// user-visible strings, not mere documentation.
#[derive(Clone, Copy, strum::Display, strum::EnumIs, strum::EnumMessage)]
#[strum(serialize_all = "kebab-case")]
pub enum AdapterKind {
    /// LoRA
    Lora,
    /// X-LoRA
    XLora,
}
330
331// For the proper name as formatted via doc comment for a variant
332pub trait PrettyName: strum::EnumMessage + ToString {
333    fn pretty_name(&self) -> String {
334        match self.get_documentation() {
335            Some(s) => s.to_string(),
336            // Instead of panic via expect(),
337            // fallback to default kebab-case:
338            None => self.to_string(),
339        }
340    }
341}
342
343impl PrettyName for AdapterKind {}
344impl PrettyName for QuantizationKind {}
345
346impl ModelKind {
347    // Quantized helpers:
348    pub fn is_quantized(&self) -> bool {
349        self.quantized_kind().iter().any(|q| q.is_some())
350    }
351
352    pub fn is_quantized_and(&self, mut f: impl FnMut(QuantizationKind) -> bool) -> bool {
353        self.quantized_kind().iter().any(|q| q.is_some_and(&mut f))
354    }
355
356    pub fn quantized_kind(&self) -> Vec<Option<QuantizationKind>> {
357        use ModelKind::*;
358
359        match self {
360            Normal | Adapter { .. } => vec![None],
361            GgufQuantized { quant } | GgufAdapter { quant, .. } => vec![Some(*quant)],
362            Speculative { target, draft } => {
363                let t = *target.clone();
364                let d = *draft.clone();
365
366                [t.quantized_kind(), d.quantized_kind()].concat()
367            }
368            AnyMoe { target } => target.quantized_kind(),
369        }
370    }
371
372    // Adapter helpers:
373    pub fn is_adapted(&self) -> bool {
374        self.adapted_kind().iter().any(|a| a.is_some())
375    }
376
377    pub fn is_adapted_and(&self, mut f: impl FnMut(AdapterKind) -> bool) -> bool {
378        self.adapted_kind().iter().any(|a| a.is_some_and(&mut f))
379    }
380
381    pub fn adapted_kind(&self) -> Vec<Option<AdapterKind>> {
382        use ModelKind::*;
383
384        match self {
385            Normal | GgufQuantized { .. } => vec![None],
386            Adapter { adapter } | GgufAdapter { adapter, .. } => vec![Some(*adapter)],
387            Speculative { target, draft } => {
388                let t = *target.clone();
389                let d = *draft.clone();
390
391                [t.adapted_kind(), d.adapted_kind()].concat()
392            }
393            AnyMoe { target } => target.adapted_kind(),
394        }
395    }
396}
397
#[derive(Deserialize)]
/// Minimal deserialization shim that extracts only the `quantization_config`
/// section from a full model `config.json`.
pub struct QuantizationConfigShim {
    // `None` when the config has no `quantization_config` key (unquantized model).
    quantization_config: Option<QuantizedConfig>,
}
402
403impl QuantizationConfigShim {
404    pub fn get_quant_config_pack_factor(config: &str, dtype: DType) -> Result<usize> {
405        let QuantizationConfigShim {
406            quantization_config,
407        } = serde_json::from_str(config)?;
408
409        if let Some(quantization_config) = quantization_config {
410            Ok(quantization_config.pack_factor(dtype))
411        } else {
412            Ok(1)
413        }
414    }
415}
416
/// Size/shape estimation hooks that let a loader participate in automatic
/// device mapping (distributing layers across the available devices).
pub trait DeviceMappedModelLoader {
    /// Maximum activation size of non-mapped parts of this model.
    /// Useful for the multimodal models which may prefer to keep the vision components on the GPU.
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Maximum activation size of mapped parts of the model.
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize>;
    /// Size in bytes of the non-mapped parts of the model.
    /// weight_pack_factor only applies to quantized weights.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize>;
    /// Per-layer sizes in bytes for the mapped layers.
    /// weight_pack_factor only applies to quantized weights.
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>>;
    /// Sub-models (e.g. vision components) excluded from device mapping.
    /// Default: none.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        None
    }
    /// Number of device-mappable layers declared in `config`.
    fn num_layers(&self, config: &str) -> Result<usize>;
    /// Parse `config` into a [`ModelConfigLike`] for paged-attention sizing.
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>>;

    /// Compute the layer→device assignment; default implementation delegates
    /// to [`auto_device_map::get_device_layers`].
    #[allow(clippy::too_many_arguments)]
    fn get_device_layers(
        &self,
        config: &str,
        num_layers: usize,
        layer_sizes_in_bytes: Vec<usize>,
        non_mapped_size_in_bytes: usize,
        total_model_size_in_bytes: usize,
        devices: &[Device],
        dtype: DType,
        params: &AutoDeviceMapParams,
        paged_attn_config: Option<&PagedAttentionConfig>,
    ) -> Result<DeviceMapMetadata>
    where
        Self: Sized,
    {
        auto_device_map::get_device_layers(
            self,
            config,
            num_layers,
            layer_sizes_in_bytes,
            non_mapped_size_in_bytes,
            total_model_size_in_bytes,
            devices,
            dtype,
            params,
            paged_attn_config,
        )
    }
}
483
/// The `Loader` trait abstracts the loading process. The primary entrypoint is the
/// `load_model` method.
///
/// # Example
/// ```no_run
/// use mistralrs_core::{Loader, TokenSource, DeviceMapSetting, AutoDeviceMapParams, ModelDType};
/// use candle_core::Device;
///
/// let loader: Box<dyn Loader> = todo!();
/// let pipeline = loader.load_model_from_hf(
///     None,
///     TokenSource::CacheToken,
///     &ModelDType::Auto,
///     &Device::cuda_if_available(0).unwrap(),
///     false,
///     DeviceMapSetting::Auto(AutoDeviceMapParams::default_text()),
///     None,
///     None,
/// ).unwrap();
/// ```
pub trait Loader: Send + Sync {
    /// Download (or locate) and load the model, returning a shared pipeline.
    ///
    /// If `revision` is None, then it defaults to `main`.
    /// If `dtype` is None, then it defaults to the model default (usually BF16).
    /// If model is not found on HF, will attempt to resolve locally.
    #[allow(clippy::type_complexity, clippy::too_many_arguments)]
    fn load_model_from_hf(
        &self,
        revision: Option<String>,
        token_source: TokenSource,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Load a model from the specified paths.
    /// Also initializes `DEBUG`.
    // `&Box<dyn ModelPaths>` is kept (with the clippy allow) for backward
    // compatibility with existing callers.
    #[allow(
        clippy::type_complexity,
        clippy::too_many_arguments,
        clippy::borrowed_box
    )]
    fn load_model_from_path(
        &self,
        paths: &Box<dyn ModelPaths>,
        dtype: &dyn TryIntoDType,
        device: &Device,
        silent: bool,
        mapper: DeviceMapSetting,
        in_situ_quant: Option<IsqType>,
        paged_attn_config: Option<PagedAttentionConfig>,
    ) -> Result<Arc<Mutex<dyn Pipeline + Send + Sync>>>;

    /// Identifier of the model this loader targets (e.g. the HF model id).
    fn get_id(&self) -> String;
    /// The [`ModelKind`] this loader produces.
    fn get_kind(&self) -> ModelKind;
}