mistralrs_core/pipeline/loaders/normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    device_map::DeviceMapper,
13    lora::{LoraConfig, Ordering},
14    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15    pipeline::{
16        isq::IsqModelLoader,
17        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18        EitherCache, IsqModel,
19    },
20    utils::varbuilder_utils::DeviceForLoadTensor,
21    xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36    models,
37    xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
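/// Core interface implemented by every "normal" (text decoder) model: running prefill/decode
/// forward passes (with optional X-LoRA and paged-attention metadata), exposing the KV cache,
/// and reporting the model configuration.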
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43    #[allow(clippy::too_many_arguments)]
44    fn forward(
45        &self,
46        input_ids: &Tensor,
47        seqlen_offsets: &[usize],
48        context_lens: Vec<(usize, usize)>,
49        position_ids: Vec<usize>,
50        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51        flash_params: &FlashParams,
52    ) -> candle_core::Result<Tensor>;
53    #[allow(clippy::too_many_arguments)]
54    fn xlora_forward(
55        &self,
56        input_ids: &Tensor,
57        input_ids_full: &Tensor,
58        seqlen_offsets: &[usize],
59        seqlen_offsets_full: &[usize],
60        no_kv_cache: bool,
61        non_granular_state: &Option<NonGranularState>,
62        context_lens: Vec<(usize, usize)>,
63        position_ids: Vec<usize>,
64        flash_params: &FlashParams,
65        flash_params_full: &FlashParams,
66    ) -> candle_core::Result<Tensor>;
67    fn is_xlora(&self) -> bool;
68    fn device(&self) -> &Device;
69    fn cache(&self) -> &EitherCache;
70    fn cache_mut(&mut self) -> &mut EitherCache;
71    fn max_seq_len(&self) -> usize;
72    fn config(&self) -> &ModelConfigMetadata;
73    fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
74        Arc::new(self.config().clone())
75    }
76}
77
78/// Metadata for loading a model with ISQ or device mapping.
79pub struct NormalLoadingMetadata {
80    // Device mapping metadata which can be used to construct a concrete device mapper
81    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
82    // Whether the model is currently being loaded with in-situ quantization (ISQ)
83    pub loading_isq: bool,
84    // Device mapping target device (the one that is not the CPU)
85    pub real_device: Device,
86    // MultiProgress support for parallelized loading
87    pub multi_progress: Arc<MultiProgress>,
88    // Optional Matryoshka Transformer slicing configuration
89    pub matformer_slicing_config: Option<MatformerSliceConfig>,
90}
91
92pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
93    fn load(
94        &self,
95        config: &str,
96        vb: ShardedVarBuilder,
97        normal_loading_metadata: NormalLoadingMetadata,
98        attention_mechanism: AttentionImplementation,
99    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
100    #[allow(clippy::too_many_arguments)]
101    fn load_xlora(
102        &self,
103        config: &str,
104        vb: ShardedVarBuilder,
105        lora_config: &[((String, String), LoraConfig)],
106        xlora_config: Option<XLoraConfig>,
107        xlora_ordering: Ordering,
108        normal_loading_metadata: NormalLoadingMetadata,
109        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
110    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
111    fn is_gptx(&self, config: &str) -> Result<bool>;
112    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
113        Ok(true)
114    }
115    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
116    fn get_device_for_tensor(
117        &self,
118        config: &str,
119        _mapper: &dyn DeviceMapper,
120        loading_isq: bool,
121    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
122        if loading_isq {
123            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
124        } else {
125            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
126            let num_layers = self.model_config(config)?.num_layers();
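            // Tensor names like "model.layers.17.self_attn.q_proj.weight" resolve to
            // DeviceForLoadTensor::Idx(17) (clamped to num_layers); names without a layer index,
            // e.g. "model.embed_tokens.weight", fall back to DeviceForLoadTensor::Base.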
127            let closure = move |name: String| {
128                if let Some(captures) = re.captures(&name) {
129                    captures
130                        .get(1)
131                        .and_then(|m| m.as_str().parse::<usize>().ok())
132                        .map(|l| l.min(num_layers))
133                        .map(DeviceForLoadTensor::Idx)
134                        .unwrap_or(DeviceForLoadTensor::Base)
135                } else {
136                    DeviceForLoadTensor::Base
137                }
138            };
139
140            Ok(Arc::new(closure))
141        }
142    }
143}
144
145#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
146#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
147/// The architecture to load the normal model as.
148pub enum NormalLoaderType {
149    #[serde(rename = "mistral")]
150    Mistral,
151    #[serde(rename = "gemma")]
152    Gemma,
153    #[serde(rename = "mixtral")]
154    Mixtral,
155    #[serde(rename = "llama")]
156    Llama,
157    #[serde(rename = "phi2")]
158    Phi2,
159    #[serde(rename = "phi3")]
160    Phi3,
161    #[serde(rename = "qwen2")]
162    Qwen2,
163    #[serde(rename = "gemma2")]
164    Gemma2,
165    #[serde(rename = "starcoder2")]
166    Starcoder2,
167    #[serde(rename = "phi3.5moe")]
168    Phi3_5MoE,
169    #[serde(rename = "deepseekv2")]
170    DeepSeekV2,
171    #[serde(rename = "deepseekv3")]
172    DeepSeekV3,
173    #[serde(rename = "qwen3")]
174    Qwen3,
175    #[serde(rename = "glm4")]
176    GLM4,
177    #[serde(rename = "glm4moelite")]
178    GLM4MoeLite,
179    #[serde(rename = "glm4moe")]
180    GLM4Moe,
181    #[serde(rename = "qwen3moe")]
182    Qwen3Moe,
183    #[serde(rename = "smollm3")]
184    SmolLm3,
185    #[serde(rename = "granitemoehybrid")]
186    GraniteMoeHybrid,
187    #[serde(rename = "gpt_oss")]
188    GptOss,
189    #[serde(rename = "qwen3next")]
190    Qwen3Next,
191}
192
193// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
194impl NormalLoaderType {
195    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
196        match name {
197            "MistralForCausalLM" => Ok(Self::Mistral),
198            "MixtralForCausalLM" => Ok(Self::Mixtral),
199            "GemmaForCausalLM" => Ok(Self::Gemma),
200            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
201            "PhiForCausalLM" => Ok(Self::Phi2),
202            "Phi3ForCausalLM" => Ok(Self::Phi3),
203            "LlamaForCausalLM" => Ok(Self::Llama),
204            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
205            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
206            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
207            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
208            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
209            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
210            "Glm4ForCausalLM" => Ok(Self::GLM4),
211            "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
212            "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
213            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
214            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
215            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
216            "GptOssForCausalLM" => Ok(Self::GptOss),
217            "Qwen3NextForCausalLM" => Ok(Self::Qwen3Next),
218            other => anyhow::bail!(
219                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
220            ),
221        }
222    }
223}
224
225impl FromStr for NormalLoaderType {
226    type Err = String;
227    fn from_str(s: &str) -> Result<Self, Self::Err> {
228        match s {
229            "mistral" => Ok(Self::Mistral),
230            "gemma" => Ok(Self::Gemma),
231            "mixtral" => Ok(Self::Mixtral),
232            "llama" => Ok(Self::Llama),
233            "phi2" => Ok(Self::Phi2),
234            "phi3" => Ok(Self::Phi3),
235            "qwen2" => Ok(Self::Qwen2),
236            "gemma2" => Ok(Self::Gemma2),
237            "starcoder2" => Ok(Self::Starcoder2),
238            "phi3.5moe" => Ok(Self::Phi3_5MoE),
239            "deepseekv2" => Ok(Self::DeepSeekV2),
240            "deepseekv3" => Ok(Self::DeepSeekV3),
241            "qwen3" => Ok(Self::Qwen3),
242            "glm4" => Ok(Self::GLM4),
243            "glm4moelite" => Ok(Self::GLM4MoeLite),
244            "glm4moe" => Ok(Self::GLM4Moe),
245            "qwen3moe" => Ok(Self::Qwen3Moe),
246            "smollm3" => Ok(Self::SmolLm3),
247            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
248            "gpt_oss" => Ok(Self::GptOss),
249            "qwen3next" => Ok(Self::Qwen3Next),
250            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`, `qwen3next`.")),
251        }
252    }
253}
254
255impl Display for NormalLoaderType {
256    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257        match self {
258            Self::Gemma => write!(f, "gemma"),
259            Self::Gemma2 => write!(f, "gemma2"),
260            Self::Llama => write!(f, "llama"),
261            Self::Mistral => write!(f, "mistral"),
262            Self::Mixtral => write!(f, "mixtral"),
263            Self::Phi2 => write!(f, "phi2"),
264            Self::Phi3 => write!(f, "phi3"),
265            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
266            Self::Qwen2 => write!(f, "qwen2"),
267            Self::Starcoder2 => write!(f, "starcoder2"),
268            Self::DeepSeekV2 => write!(f, "deepseekv2"),
269            Self::DeepSeekV3 => write!(f, "deepseekv3"),
270            Self::Qwen3 => write!(f, "qwen3"),
271            Self::GLM4 => write!(f, "glm4"),
272            Self::GLM4MoeLite => write!(f, "glm4moelite"),
273            Self::GLM4Moe => write!(f, "glm4moe"),
274            Self::Qwen3Moe => write!(f, "qwen3moe"),
275            Self::SmolLm3 => write!(f, "smollm3"),
276            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
277            Self::GptOss => write!(f, "gpt_oss"),
278            Self::Qwen3Next => write!(f, "qwen3next"),
279        }
280    }
281}
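
// A minimal usage sketch (not part of the original file): the loader type can be resolved either
// from an architecture string or from a `config.json` `architectures` entry, and round-trips
// through `Display`:
//
//     use std::str::FromStr;
//     let from_flag = NormalLoaderType::from_str("qwen3moe").unwrap();
//     let from_cfg = NormalLoaderType::from_causal_lm_name("Qwen3MoeForCausalLM").unwrap();
//     assert_eq!(from_flag, from_cfg);
//     assert_eq!(from_flag.to_string(), "qwen3moe");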
282
283macro_rules! bias_if {
284    ($cond:expr, $size:expr) => {
285        if $cond {
286            $size
287        } else {
288            0
289        }
290    };
291}
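
// For example, a projection of shape (size_q, size_in) with an optional bias contributes
// `size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q)` elements,
// as used in the per-layer size estimates below.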
292
293/// Load a model automatically based on its Hugging Face Transformers `*ForCausalLM` model class.
294pub struct AutoNormalLoader;
295
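/// Minimal view of a Hugging Face `config.json`: only the `architectures` field is needed to pick
/// the concrete loader, e.g. `{"architectures": ["MistralForCausalLM"], ...}`.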
296#[derive(Deserialize)]
297struct AutoNormalLoaderConfig {
298    architectures: Vec<String>,
299}
300
301impl AutoNormalLoader {
302    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
303        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
304        if auto_cfg.architectures.len() != 1 {
305            anyhow::bail!("Expected exactly one name in the `architectures` config field.")
306        }
307
308        let name = &auto_cfg.architectures[0];
309
310        let tp = NormalLoaderType::from_causal_lm_name(name)?;
311
312        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
313
314        match tp {
315            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
316            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
317            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
318            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
319            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
320            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
321            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
322            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
323            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
324            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
325            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
326            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
327            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
328            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
329            NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
330            NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
331            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
332            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
333            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
334            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
335            NormalLoaderType::Qwen3Next => Ok(Box::new(Qwen3NextLoader)),
336        }
337    }
338}
339
340impl NormalModelLoader for AutoNormalLoader {
341    fn load(
342        &self,
343        config: &str,
344        vb: ShardedVarBuilder,
345        normal_loading_metadata: NormalLoadingMetadata,
346        attention_mechanism: AttentionImplementation,
347    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
348        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
349    }
350    fn load_xlora(
351        &self,
352        config: &str,
353        vb: ShardedVarBuilder,
354        lora_config: &[((String, String), LoraConfig)],
355        xlora_config: Option<XLoraConfig>,
356        xlora_ordering: Ordering,
357        normal_loading_metadata: NormalLoadingMetadata,
358        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
359    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
360        Self::get_loader(config)?.load_xlora(
361            config,
362            vb,
363            lora_config,
364            xlora_config,
365            xlora_ordering,
366            normal_loading_metadata,
367            preload_adapters,
368        )
369    }
370    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
371        Self::get_loader(config)?.get_config_repr(config)
372    }
373    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
374        Self::get_loader(config)?.supports_paged_attention(config)
375    }
376    fn is_gptx(&self, config: &str) -> Result<bool> {
377        Self::get_loader(config)?.is_gptx(config)
378    }
379}
380
381impl IsqModelLoader for AutoNormalLoader {
382    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
383        Self::get_loader(config)?.immediate_isq_predicates(config)
384    }
385    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
386        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
387    }
388    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
389        Self::get_loader(config)?.isq_layer_regexes(config)
390    }
391    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
392        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
393    }
394}
395
396impl DeviceMappedModelLoader for AutoNormalLoader {
397    fn non_mapped_size_in_bytes(
398        &self,
399        config: &str,
400        dtype: DType,
401        weight_pack_factor: usize,
402        _matformer_config: Option<&MatformerSliceConfig>,
403    ) -> Result<usize> {
404        Self::get_loader(config)?.non_mapped_size_in_bytes(
405            config,
406            dtype,
407            weight_pack_factor,
408            _matformer_config,
409        )
410    }
411    fn num_layers(&self, config: &str) -> Result<usize> {
412        Self::get_loader(config)?.num_layers(config)
413    }
414    fn layer_sizes_in_bytes(
415        &self,
416        config: &str,
417        dtype: DType,
418        weight_pack_factor: usize,
419        _matformer_config: Option<&MatformerSliceConfig>,
420    ) -> Result<Vec<usize>> {
421        Self::get_loader(config)?.layer_sizes_in_bytes(
422            config,
423            dtype,
424            weight_pack_factor,
425            _matformer_config,
426        )
427    }
428    fn mapped_max_act_size_elems(
429        &self,
430        config: &str,
431        params: &super::AutoDeviceMapParams,
432    ) -> Result<usize> {
433        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
434    }
435    fn non_mapped_max_act_size_elems(
436        &self,
437        _config: &str,
438        _params: &AutoDeviceMapParams,
439    ) -> Result<usize> {
440        Ok(0)
441    }
442    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
443        Self::get_loader(config)?.model_config(config)
444    }
445}
446
447// ======================== Mistral loader
448
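/// [`NormalLoader`] for a Mistral model.
///
/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html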
449pub struct MistralLoader;
450
451impl NormalModelLoader for MistralLoader {
452    fn load(
453        &self,
454        config: &str,
455        vb: ShardedVarBuilder,
456        normal_loading_metadata: NormalLoadingMetadata,
457        attention_mechanism: AttentionImplementation,
458    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
459        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
460        Ok(Box::new(models::mistral::Model::new(
461            &cfg,
462            vb,
463            self.is_gptx(config)?,
464            normal_loading_metadata,
465            attention_mechanism,
466        )?))
467    }
468    fn load_xlora(
469        &self,
470        config: &str,
471        vb: ShardedVarBuilder,
472        lora_config: &[((String, String), LoraConfig)],
473        xlora_config: Option<XLoraConfig>,
474        xlora_ordering: Ordering,
475        normal_loading_metadata: NormalLoadingMetadata,
476        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
477    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
478        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
479        Ok(Box::new(xlora_models::XLoraMistral::new(
480            &cfg,
481            vb,
482            lora_config,
483            xlora_config,
484            xlora_ordering,
485            self.is_gptx(config)?,
486            normal_loading_metadata,
487            preload_adapters,
488        )?))
489    }
490    fn is_gptx(&self, _: &str) -> Result<bool> {
491        Ok(true)
492    }
493    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
494        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
495        Ok(Box::new(cfg))
496    }
497}
498
499impl IsqModelLoader for MistralLoader {
500    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
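        // Each pattern matches fully-qualified tensor names such as
        // "model.layers.0.self_attn.q_proj.weight" or "lm_head.weight".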
501        Ok(vec![
502            Regex::new(r"lm_head\.(weight|bias)$")?,
503            // Attention
504            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
505            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
506            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
507            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
508            // MLP
509            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
510            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
511            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
512        ])
513    }
514    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
515        self.isq_layer_regexes(config)
516    }
517}
518
519impl DeviceMappedModelLoader for MistralLoader {
520    fn mapped_max_act_size_elems(
521        &self,
522        config: &str,
523        params: &AutoDeviceMapParams,
524    ) -> Result<usize> {
525        let AutoDeviceMapParams::Text {
526            max_seq_len,
527            max_batch_size,
528        } = params
529        else {
530            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
531        };
532
533        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
534
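        // Peak mapped activation: the attention score matrix, one (seq x seq) tile per head and
        // per batch element, with seq capped at ATTENTION_CHUNK_SIZE.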
535        Ok(
536            max_batch_size
537                * cfg.num_attention_heads
538                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
539        )
540    }
541    fn non_mapped_max_act_size_elems(
542        &self,
543        _config: &str,
544        _params: &AutoDeviceMapParams,
545    ) -> Result<usize> {
546        Ok(0)
547    }
548
549    fn non_mapped_size_in_bytes(
550        &self,
551        config: &str,
552        dtype: DType,
553        weight_pack_factor: usize,
554        _matformer_config: Option<&MatformerSliceConfig>,
555    ) -> Result<usize> {
556        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
557
558        let elems = {
559            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
560            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
561            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
562                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
563            } else {
564                0
565            };
566            let norm = cfg.hidden_size;
567            embed_tokens + lm_head + norm
568        };
569        Ok(elems * dtype.size_in_bytes())
570    }
571
572    fn layer_sizes_in_bytes(
573        &self,
574        config: &str,
575        dtype: DType,
576        weight_pack_factor: usize,
577        _matformer_config: Option<&MatformerSliceConfig>,
578    ) -> Result<Vec<usize>> {
579        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
580
581        let per_layer_elems = {
582            let input_layernorm = cfg.hidden_size;
583            let post_attention_layernorm = cfg.hidden_size;
584
585            let size_in = cfg.hidden_size;
586            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
587            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
588            let q_proj = size_in * size_q / weight_pack_factor;
589            let k_proj = size_in * size_kv / weight_pack_factor;
590            let v_proj = size_in * size_kv / weight_pack_factor;
591            let o_proj = size_q * size_in / weight_pack_factor;
592
593            let h_size = cfg.hidden_size;
594            let i_size = cfg.intermediate_size;
595            let gate_proj = h_size * i_size / weight_pack_factor;
596            let up_proj = h_size * i_size / weight_pack_factor;
597            let down_proj = i_size * h_size / weight_pack_factor;
598
599            input_layernorm
600                + post_attention_layernorm
601                + q_proj
602                + k_proj
603                + v_proj
604                + o_proj
605                + gate_proj
606                + up_proj
607                + down_proj
608        };
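        // Rough worked example (assumed Mistral-7B-style dims, not taken from this file:
        // hidden=4096, heads=32, kv_heads=8, intermediate=14336, weight_pack_factor=1):
        // per_layer_elems ≈ 218.1M elements, i.e. roughly 436 MB per layer at 2 bytes/element.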
609        Ok(vec![
610            per_layer_elems * dtype.size_in_bytes();
611            cfg.num_hidden_layers
612        ])
613    }
614
615    fn num_layers(&self, config: &str) -> Result<usize> {
616        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
617        Ok(cfg.num_hidden_layers)
618    }
619
620    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
621        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
622
623        let cfg = ModelConfigMetadata {
624            max_seq_len: cfg.max_position_embeddings,
625            num_layers: cfg.num_hidden_layers,
626            hidden_size: cfg.hidden_size,
627            num_kv_heads: cfg.num_key_value_heads,
628            num_attn_heads: cfg.num_attention_heads,
629            sliding_window: cfg.sliding_window,
630            k_head_dim: cfg.head_dim(),
631            v_head_dim: cfg.head_dim(),
632            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
633        };
634
635        Ok(Box::new(cfg))
636    }
637}
638
639// ======================== Gemma loader
640
641/// [`NormalLoader`] for a Gemma model.
642///
643/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
644pub struct GemmaLoader;
645
646impl NormalModelLoader for GemmaLoader {
647    fn load(
648        &self,
649        config: &str,
650        vb: ShardedVarBuilder,
651        normal_loading_metadata: NormalLoadingMetadata,
652        attention_mechanism: AttentionImplementation,
653    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
654        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
655
656        Ok(Box::new(models::gemma::Model::new(
657            &cfg,
658            vb,
659            self.is_gptx(config)?,
660            normal_loading_metadata,
661            attention_mechanism,
662        )?))
663    }
664    fn load_xlora(
665        &self,
666        config: &str,
667        vb: ShardedVarBuilder,
668        lora_config: &[((String, String), LoraConfig)],
669        xlora_config: Option<XLoraConfig>,
670        xlora_ordering: Ordering,
671        normal_loading_metadata: NormalLoadingMetadata,
672        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
673    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
674        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
675
676        Ok(Box::new(xlora_models::XLoraGemma::new(
677            &cfg,
678            vb,
679            lora_config,
680            xlora_config,
681            xlora_ordering,
682            self.is_gptx(config)?,
683            normal_loading_metadata,
684            preload_adapters,
685        )?))
686    }
687    fn is_gptx(&self, _: &str) -> Result<bool> {
688        Ok(true)
689    }
690    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
691        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
692        Ok(Box::new(cfg))
693    }
694}
695
696impl IsqModelLoader for GemmaLoader {
697    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
698        Ok(vec![
699            Regex::new(r"lm_head\.(weight|bias)$")?,
700            // Attention
701            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
702            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
703            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
704            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
705            // MLP
706            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
707            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
708            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
709        ])
710    }
711    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
712        self.isq_layer_regexes(config)
713    }
714}
715
716impl DeviceMappedModelLoader for GemmaLoader {
717    fn mapped_max_act_size_elems(
718        &self,
719        config: &str,
720        params: &AutoDeviceMapParams,
721    ) -> Result<usize> {
722        let AutoDeviceMapParams::Text {
723            max_seq_len,
724            max_batch_size,
725        } = params
726        else {
727            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
728        };
729
730        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
731
732        Ok(
733            max_batch_size
734                * cfg.num_attention_heads
735                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
736        )
737    }
738    fn non_mapped_max_act_size_elems(
739        &self,
740        _config: &str,
741        _params: &AutoDeviceMapParams,
742    ) -> Result<usize> {
743        Ok(0)
744    }
745
746    fn non_mapped_size_in_bytes(
747        &self,
748        config: &str,
749        dtype: DType,
750        weight_pack_factor: usize,
751        _matformer_config: Option<&MatformerSliceConfig>,
752    ) -> Result<usize> {
753        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
754
755        let elems = {
756            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
757            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
758            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
759                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
760            } else {
761                0
762            };
763            let norm = cfg.hidden_size;
764            embed_tokens + lm_head + norm
765        };
766        Ok(elems * dtype.size_in_bytes())
767    }
768
769    fn layer_sizes_in_bytes(
770        &self,
771        config: &str,
772        dtype: DType,
773        weight_pack_factor: usize,
774        _matformer_config: Option<&MatformerSliceConfig>,
775    ) -> Result<Vec<usize>> {
776        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
777
778        let per_layer_elems = {
779            let input_layernorm = cfg.hidden_size;
780            let post_attention_layernorm = cfg.hidden_size;
781
782            let size_in = cfg.hidden_size;
783            let size_q = cfg.head_dim * cfg.num_attention_heads;
784            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
785            let q_proj =
786                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
787            let k_proj =
788                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
789            let v_proj =
790                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
791            let o_proj =
792                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
793
794            let h_size = cfg.hidden_size;
795            let i_size = cfg.intermediate_size;
796            let gate_proj = h_size * i_size / weight_pack_factor;
797            let up_proj = h_size * i_size / weight_pack_factor;
798            let down_proj = i_size * h_size / weight_pack_factor;
799
800            input_layernorm
801                + post_attention_layernorm
802                + q_proj
803                + k_proj
804                + v_proj
805                + o_proj
806                + gate_proj
807                + up_proj
808                + down_proj
809        };
810        Ok(vec![
811            per_layer_elems * dtype.size_in_bytes();
812            cfg.num_hidden_layers
813        ])
814    }
815
816    fn num_layers(&self, config: &str) -> Result<usize> {
817        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
818        Ok(cfg.num_hidden_layers)
819    }
820
821    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
822        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
823
824        let cfg = ModelConfigMetadata {
825            max_seq_len: cfg.max_position_embeddings,
826            num_layers: cfg.num_hidden_layers,
827            hidden_size: cfg.hidden_size,
828            num_kv_heads: cfg.num_key_value_heads,
829            num_attn_heads: cfg.num_attention_heads,
830            sliding_window: None,
831            k_head_dim: cfg.head_dim,
832            v_head_dim: cfg.head_dim,
833            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
834        };
835
836        Ok(Box::new(cfg))
837    }
838}
839
840// ======================== Llama loader
841
842/// [`NormalLoader`] for a Llama model.
843///
844/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
845pub struct LlamaLoader;
846
847impl NormalModelLoader for LlamaLoader {
848    fn load(
849        &self,
850        config: &str,
851        vb: ShardedVarBuilder,
852        normal_loading_metadata: NormalLoadingMetadata,
853        attention_mechanism: AttentionImplementation,
854    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
855        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
856
857        Ok(Box::new(models::llama::Llama::new(
858            &cfg,
859            vb,
860            self.is_gptx(config)?,
861            normal_loading_metadata,
862            attention_mechanism,
863        )?))
864    }
865    fn load_xlora(
866        &self,
867        config: &str,
868        vb: ShardedVarBuilder,
869        lora_config: &[((String, String), LoraConfig)],
870        xlora_config: Option<XLoraConfig>,
871        xlora_ordering: Ordering,
872        normal_loading_metadata: NormalLoadingMetadata,
873        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
874    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
875        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
876
877        Ok(Box::new(xlora_models::XLoraLlama::new(
878            &cfg,
879            vb,
880            lora_config,
881            xlora_config,
882            xlora_ordering,
883            self.is_gptx(config)?,
884            normal_loading_metadata,
885            preload_adapters,
886        )?))
887    }
888    fn is_gptx(&self, _: &str) -> Result<bool> {
889        Ok(true)
890    }
891    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
892        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
893        Ok(Box::new(cfg))
894    }
895}
896
897impl IsqModelLoader for LlamaLoader {
898    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
899        Ok(vec![
900            Regex::new(r"lm_head\.(weight|bias)$")?,
901            // Attention
902            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
903            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
904            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
905            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
906            // MLP
907            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
908            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
909            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
910        ])
911    }
912    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
913        self.isq_layer_regexes(config)
914    }
915}
916
917impl DeviceMappedModelLoader for LlamaLoader {
918    fn mapped_max_act_size_elems(
919        &self,
920        config: &str,
921        params: &AutoDeviceMapParams,
922    ) -> Result<usize> {
923        let AutoDeviceMapParams::Text {
924            max_seq_len,
925            max_batch_size,
926        } = params
927        else {
928            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
929        };
930
931        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
932
933        Ok(
934            max_batch_size
935                * cfg.num_attention_heads
936                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
937        )
938    }
939    fn non_mapped_max_act_size_elems(
940        &self,
941        _config: &str,
942        _params: &AutoDeviceMapParams,
943    ) -> Result<usize> {
944        Ok(0)
945    }
946
947    fn non_mapped_size_in_bytes(
948        &self,
949        config: &str,
950        dtype: DType,
951        weight_pack_factor: usize,
952        _matformer_config: Option<&MatformerSliceConfig>,
953    ) -> Result<usize> {
954        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
955
956        let elems = {
957            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
958            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
959            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
960                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
961            } else {
962                0
963            };
964            let norm = cfg.hidden_size;
965            embed_tokens + lm_head + norm
966        };
967        Ok(elems * dtype.size_in_bytes())
968    }
969
970    fn layer_sizes_in_bytes(
971        &self,
972        config: &str,
973        dtype: DType,
974        weight_pack_factor: usize,
975        _matformer_config: Option<&MatformerSliceConfig>,
976    ) -> Result<Vec<usize>> {
977        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
978
979        let per_layer_elems = {
980            let input_layernorm = cfg.hidden_size;
981            let post_attention_layernorm = cfg.hidden_size;
982
983            let size_in = cfg.hidden_size;
984            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
985            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
986            let q_proj = size_in * size_q / weight_pack_factor;
987            let k_proj = size_in * size_kv / weight_pack_factor;
988            let v_proj = size_in * size_kv / weight_pack_factor;
989            let o_proj = size_q * size_in / weight_pack_factor;
990
991            let h_size = cfg.hidden_size;
992            let i_size = cfg.intermediate_size;
993            let gate_proj = h_size * i_size / weight_pack_factor;
994            let up_proj = h_size * i_size / weight_pack_factor;
995            let down_proj = i_size * h_size / weight_pack_factor;
996
997            input_layernorm
998                + post_attention_layernorm
999                + q_proj
1000                + k_proj
1001                + v_proj
1002                + o_proj
1003                + gate_proj
1004                + up_proj
1005                + down_proj
1006        };
1007        Ok(vec![
1008            per_layer_elems * dtype.size_in_bytes();
1009            cfg.num_hidden_layers
1010        ])
1011    }
1012
1013    fn num_layers(&self, config: &str) -> Result<usize> {
1014        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1015
1016        Ok(cfg.num_hidden_layers)
1017    }
1018    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1019        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1020
1021        let cfg = ModelConfigMetadata {
1022            max_seq_len: cfg.max_position_embeddings,
1023            num_layers: cfg.num_hidden_layers,
1024            hidden_size: cfg.hidden_size,
1025            num_kv_heads: cfg.num_key_value_heads,
1026            num_attn_heads: cfg.num_attention_heads,
1027            sliding_window: None,
1028            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1029            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1030            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1031        };
1032
1033        Ok(Box::new(cfg))
1034    }
1035}
1036
1037// ======================== Mixtral loader
1038
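/// [`NormalLoader`] for a Mixtral model.
///
/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html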
1039pub struct MixtralLoader;
1040
1041impl NormalModelLoader for MixtralLoader {
1042    fn load(
1043        &self,
1044        config: &str,
1045        vb: ShardedVarBuilder,
1046        normal_loading_metadata: NormalLoadingMetadata,
1047        attention_mechanism: AttentionImplementation,
1048    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1049        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1050
1051        Ok(Box::new(models::mixtral::Model::new(
1052            &cfg,
1053            vb,
1054            self.is_gptx(config)?,
1055            normal_loading_metadata,
1056            attention_mechanism,
1057        )?))
1058    }
1059    fn load_xlora(
1060        &self,
1061        config: &str,
1062        vb: ShardedVarBuilder,
1063        lora_config: &[((String, String), LoraConfig)],
1064        xlora_config: Option<XLoraConfig>,
1065        xlora_ordering: Ordering,
1066        normal_loading_metadata: NormalLoadingMetadata,
1067        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1068    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1069        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1070
1071        Ok(Box::new(xlora_models::XLoraMixtral::new(
1072            &cfg,
1073            vb,
1074            lora_config,
1075            xlora_config,
1076            xlora_ordering,
1077            self.is_gptx(config)?,
1078            normal_loading_metadata,
1079            preload_adapters,
1080        )?))
1081    }
1082    fn is_gptx(&self, _: &str) -> Result<bool> {
1083        Ok(true)
1084    }
1085    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1086        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1087
1088        Ok(Box::new(cfg))
1089    }
1090}
1091
1092impl IsqModelLoader for MixtralLoader {
1093    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1094        Ok(vec![
1095            Regex::new(r"lm_head\.(weight|bias)$")?,
1096            // Attention
1097            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1098            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1099            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1100            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1101            // Experts
1102            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1103            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1104            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1105            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1106        ])
1107    }
1108    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1109        self.isq_layer_regexes(config)
1110    }
1111}
1112
1113impl DeviceMappedModelLoader for MixtralLoader {
1114    fn mapped_max_act_size_elems(
1115        &self,
1116        config: &str,
1117        params: &AutoDeviceMapParams,
1118    ) -> Result<usize> {
1119        let AutoDeviceMapParams::Text {
1120            max_seq_len,
1121            max_batch_size,
1122        } = params
1123        else {
1124            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1125        };
1126
1127        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1128
1129        Ok(
1130            max_batch_size
1131                * cfg.num_attention_heads
1132                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1133        )
1134    }
1135    fn non_mapped_max_act_size_elems(
1136        &self,
1137        _config: &str,
1138        _params: &AutoDeviceMapParams,
1139    ) -> Result<usize> {
1140        Ok(0)
1141    }
1142
1143    fn non_mapped_size_in_bytes(
1144        &self,
1145        config: &str,
1146        dtype: DType,
1147        weight_pack_factor: usize,
1148        _matformer_config: Option<&MatformerSliceConfig>,
1149    ) -> Result<usize> {
1150        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1151
1152        let elems = {
1153            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1154            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1155            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1156                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1157            } else {
1158                0
1159            };
1160            let norm = cfg.hidden_size;
1161            embed_tokens + lm_head + norm
1162        };
1163        Ok(elems * dtype.size_in_bytes())
1164    }
1165
1166    fn layer_sizes_in_bytes(
1167        &self,
1168        config: &str,
1169        dtype: DType,
1170        weight_pack_factor: usize,
1171        _matformer_config: Option<&MatformerSliceConfig>,
1172    ) -> Result<Vec<usize>> {
1173        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1174
1175        let per_layer_elems = {
1176            let input_layernorm = cfg.hidden_size;
1177            let post_attention_layernorm = cfg.hidden_size;
1178
1179            let size_in = cfg.hidden_size;
1180            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1181            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1182            let q_proj = size_in * size_q / weight_pack_factor;
1183            let k_proj = size_in * size_kv / weight_pack_factor;
1184            let v_proj = size_in * size_kv / weight_pack_factor;
1185            let o_proj = size_q * size_in / weight_pack_factor;
1186
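            // Sparse MoE block: one router gate (hidden_size x num_local_experts) plus
            // num_local_experts copies of the (w1, w2, w3) expert MLP, each divided by
            // weight_pack_factor.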
1187            let moe_block = {
1188                let gate = cfg.hidden_size * cfg.num_local_experts;
1189                // Assume each expert's weights are packed by the same weight_pack_factor
1190                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1191                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1192                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1193                gate + cfg.num_local_experts * w1
1194                    + cfg.num_local_experts * w2
1195                    + cfg.num_local_experts * w3
1196            };
1197
1198            input_layernorm
1199                + post_attention_layernorm
1200                + q_proj
1201                + k_proj
1202                + v_proj
1203                + o_proj
1204                + moe_block
1205        };
1206        Ok(vec![
1207            per_layer_elems * dtype.size_in_bytes();
1208            cfg.num_hidden_layers
1209        ])
1210    }
1211
1212    fn num_layers(&self, config: &str) -> Result<usize> {
1213        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1214
1215        Ok(cfg.num_hidden_layers)
1216    }
1217
1218    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1219        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1220
1221        let cfg = ModelConfigMetadata {
1222            max_seq_len: cfg.max_position_embeddings,
1223            num_layers: cfg.num_hidden_layers,
1224            hidden_size: cfg.hidden_size,
1225            num_kv_heads: cfg.num_key_value_heads,
1226            num_attn_heads: cfg.num_attention_heads,
1227            sliding_window: cfg.sliding_window,
1228            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1229            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1230            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1231        };
1232
1233        Ok(Box::new(cfg))
1234    }
1235}
1236
1237// ======================== Phi2 loader
1238
1239/// [`NormalLoader`] for a Phi 2 model.
1240///
1241/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1242pub struct Phi2Loader;
1243
1244impl NormalModelLoader for Phi2Loader {
1245    fn load(
1246        &self,
1247        config: &str,
1248        vb: ShardedVarBuilder,
1249        normal_loading_metadata: NormalLoadingMetadata,
1250        attention_mechanism: AttentionImplementation,
1251    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1252        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1253
1254        Ok(Box::new(models::phi2::Model::new(
1255            &cfg,
1256            vb,
1257            self.is_gptx(config)?,
1258            normal_loading_metadata,
1259            attention_mechanism,
1260        )?))
1261    }
1262    fn load_xlora(
1263        &self,
1264        config: &str,
1265        vb: ShardedVarBuilder,
1266        lora_config: &[((String, String), LoraConfig)],
1267        xlora_config: Option<XLoraConfig>,
1268        xlora_ordering: Ordering,
1269        normal_loading_metadata: NormalLoadingMetadata,
1270        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1271    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1272        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1273
1274        Ok(Box::new(xlora_models::XLoraPhi2::new(
1275            &cfg,
1276            vb,
1277            lora_config,
1278            xlora_config,
1279            xlora_ordering,
1280            self.is_gptx(config)?,
1281            normal_loading_metadata,
1282            preload_adapters,
1283        )?))
1284    }
1285    fn is_gptx(&self, _: &str) -> Result<bool> {
1286        Ok(true)
1287    }
1288    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1289        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1290
1291        Ok(Box::new(cfg))
1292    }
1293}
1294
1295impl IsqModelLoader for Phi2Loader {
1296    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1297        Ok(vec![
1298            Regex::new(r"lm_head\.(weight|bias)$")?,
1299            // Attention
1300            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1301            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1302            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1303            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1304            // MLP
1305            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1306            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1307        ])
1308    }
1309    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1310        self.isq_layer_regexes(config)
1311    }
1312}
1313
1314impl DeviceMappedModelLoader for Phi2Loader {
1315    fn mapped_max_act_size_elems(
1316        &self,
1317        config: &str,
1318        params: &AutoDeviceMapParams,
1319    ) -> Result<usize> {
1320        let AutoDeviceMapParams::Text {
1321            max_seq_len,
1322            max_batch_size,
1323        } = params
1324        else {
1325            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1326        };
1327
1328        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1329
1330        Ok(
1331            max_batch_size
1332                * cfg.num_attention_heads
1333                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1334        )
1335    }
1336    fn non_mapped_max_act_size_elems(
1337        &self,
1338        _config: &str,
1339        _params: &AutoDeviceMapParams,
1340    ) -> Result<usize> {
1341        Ok(0)
1342    }
1343
1344    fn non_mapped_size_in_bytes(
1345        &self,
1346        config: &str,
1347        dtype: DType,
1348        weight_pack_factor: usize,
1349        _matformer_config: Option<&MatformerSliceConfig>,
1350    ) -> Result<usize> {
1351        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1352
1353        let elems = {
1354            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1355            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1356            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1357                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1358            } else {
1359                0
1360            };
1361            let norm = cfg.hidden_size;
1362            embed_tokens + lm_head + norm
1363        };
1364        Ok(elems * dtype.size_in_bytes())
1365    }
1366
1367    fn layer_sizes_in_bytes(
1368        &self,
1369        config: &str,
1370        dtype: DType,
1371        weight_pack_factor: usize,
1372        _matformer_config: Option<&MatformerSliceConfig>,
1373    ) -> Result<Vec<usize>> {
1374        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1375
1376        let per_layer_elems = {
1377            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1378
1379            let size_in = cfg.hidden_size;
1380            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1381            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1382            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1383            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1384            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1385            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1386            let (q_norm, k_norm) = if cfg.qk_layernorm {
1387                (cfg.head_dim(), cfg.head_dim())
1388            } else {
1389                (0, 0)
1390            };
1391
1392            let h_size = cfg.hidden_size;
1393            let i_size = cfg.intermediate_size;
1394            let fc1 = h_size * i_size / weight_pack_factor;
1395            let fc2 = h_size * i_size / weight_pack_factor;
1396
1397            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1398        };
1399        Ok(vec![
1400            per_layer_elems * dtype.size_in_bytes();
1401            cfg.num_hidden_layers
1402        ])
1403    }
1404
1405    fn num_layers(&self, config: &str) -> Result<usize> {
1406        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1407
1408        Ok(cfg.num_hidden_layers)
1409    }
1410
1411    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1412        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1413
1414        let cfg = ModelConfigMetadata {
1415            max_seq_len: cfg.max_position_embeddings,
1416            num_layers: cfg.num_hidden_layers,
1417            hidden_size: cfg.hidden_size,
1418            num_kv_heads: cfg.num_key_value_heads(),
1419            num_attn_heads: cfg.num_attention_heads,
1420            sliding_window: None,
1421            k_head_dim: cfg.head_dim(),
1422            v_head_dim: cfg.head_dim(),
1423            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1424        };
1425
1426        Ok(Box::new(cfg))
1427    }
1428}
1429
1430// ======================== Phi3 loader
1431
1432/// [`NormalLoader`] for a Phi 3 model.
1433///
1434/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1435pub struct Phi3Loader;
1436
1437impl NormalModelLoader for Phi3Loader {
1438    fn load(
1439        &self,
1440        config: &str,
1441        vb: ShardedVarBuilder,
1442        normal_loading_metadata: NormalLoadingMetadata,
1443        attention_mechanism: AttentionImplementation,
1444    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1445        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1446
1447        Ok(Box::new(models::phi3::Model::new(
1448            &cfg,
1449            vb,
1450            self.is_gptx(config)?,
1451            normal_loading_metadata,
1452            attention_mechanism,
1453        )?))
1454    }
1455    fn load_xlora(
1456        &self,
1457        config: &str,
1458        vb: ShardedVarBuilder,
1459        lora_config: &[((String, String), LoraConfig)],
1460        xlora_config: Option<XLoraConfig>,
1461        xlora_ordering: Ordering,
1462        normal_loading_metadata: NormalLoadingMetadata,
1463        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1464    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1465        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1466
1467        Ok(Box::new(xlora_models::XLoraPhi3::new(
1468            &cfg,
1469            vb,
1470            lora_config,
1471            xlora_config,
1472            xlora_ordering,
1473            self.is_gptx(config)?,
1474            normal_loading_metadata,
1475            preload_adapters,
1476        )?))
1477    }
1478    fn is_gptx(&self, _: &str) -> Result<bool> {
1479        Ok(true)
1480    }
1481    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1482        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1483
1484        Ok(Box::new(cfg))
1485    }
1486}
1487
1488impl IsqModelLoader for Phi3Loader {
1489    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1490        Ok(vec![
1491            Regex::new(r"lm_head\.(weight|bias)$")?,
1492            // Attention
1493            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1494            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1495            // MLP
1496            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1497            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1498            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1499        ])
1500    }
1501    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1502        self.isq_layer_regexes(config)
1503    }
1504}
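
// A minimal sanity-check sketch (not part of the loader itself): the ISQ
// regexes above are matched against fully qualified tensor names. The concrete
// name used below is an illustrative assumption, not taken from a real
// checkpoint.
#[cfg(test)]
mod phi3_isq_regex_sketch {
    use super::*;

    #[test]
    fn qkv_proj_pattern_matches_a_typical_tensor_name() {
        // Same pattern as returned by `Phi3Loader::isq_layer_regexes`.
        let re = Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$").unwrap();
        assert!(re.is_match("model.layers.12.self_attn.qkv_proj.weight"));
        // The trailing `$` rejects derived names such as per-tensor scales.
        assert!(!re.is_match("model.layers.12.self_attn.qkv_proj.weight_scale"));
    }
}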
1505
1506impl DeviceMappedModelLoader for Phi3Loader {
1507    fn mapped_max_act_size_elems(
1508        &self,
1509        config: &str,
1510        params: &AutoDeviceMapParams,
1511    ) -> Result<usize> {
1512        let AutoDeviceMapParams::Text {
1513            max_seq_len,
1514            max_batch_size,
1515        } = params
1516        else {
1517            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1518        };
1519
1520        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1521
1522        Ok(
1523            max_batch_size
1524                * cfg.num_attention_heads
1525                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1526        )
1527    }
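
    // The estimate above sizes the largest mapped attention activation: the
    // attention-score matrix over one chunk for every head in the batch, i.e.
    // max_batch_size * num_attention_heads * min(max_seq_len, ATTENTION_CHUNK_SIZE)^2
    // elements. As a purely hypothetical example, batch 1, 32 heads and a
    // 1024-token chunk give 1 * 32 * 1024 * 1024 = 33_554_432 elements,
    // about 64 MiB in F16.
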
1528    fn non_mapped_max_act_size_elems(
1529        &self,
1530        _config: &str,
1531        _params: &AutoDeviceMapParams,
1532    ) -> Result<usize> {
1533        Ok(0)
1534    }
1535
1536    fn non_mapped_size_in_bytes(
1537        &self,
1538        config: &str,
1539        dtype: DType,
1540        weight_pack_factor: usize,
1541        _matformer_config: Option<&MatformerSliceConfig>,
1542    ) -> Result<usize> {
1543        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1544
1545        let elems = {
1546            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1547            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1548            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1549                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1550            } else {
1551                0
1552            };
1553            let norm = cfg.hidden_size;
1554            embed_tokens + lm_head + norm
1555        };
1556        Ok(elems * dtype.size_in_bytes())
1557    }
1558
1559    fn layer_sizes_in_bytes(
1560        &self,
1561        config: &str,
1562        dtype: DType,
1563        weight_pack_factor: usize,
1564        _matformer_config: Option<&MatformerSliceConfig>,
1565    ) -> Result<Vec<usize>> {
1566        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1567
1568        let per_layer_elems = {
1569            let input_layernorm = cfg.hidden_size;
1570            let post_attention_layernorm = cfg.hidden_size;
1571
1572            let size_in = cfg.hidden_size;
1573            let head_dim = cfg.head_dim();
1574            let op_size =
1575                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1576            let qkv_proj = size_in * op_size / weight_pack_factor;
1577            let o_proj =
1578                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1579
1580            let h_size = cfg.hidden_size;
1581            let i_size = cfg.intermediate_size;
1582            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1583            let down_proj = h_size * i_size / weight_pack_factor;
1584
1585            input_layernorm
1586                + post_attention_layernorm
1587                + qkv_proj
1588                + o_proj
1589                + gate_up_proj
1590                + down_proj
1591        };
1592        Ok(vec![
1593            per_layer_elems * dtype.size_in_bytes();
1594            cfg.num_hidden_layers
1595        ])
1596    }
1597
1598    fn num_layers(&self, config: &str) -> Result<usize> {
1599        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1600
1601        Ok(cfg.num_hidden_layers)
1602    }
1603
1604    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1605        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1606
1607        let cfg = ModelConfigMetadata {
1608            max_seq_len: cfg.max_position_embeddings,
1609            num_layers: cfg.num_hidden_layers,
1610            hidden_size: cfg.hidden_size,
1611            num_kv_heads: cfg.num_key_value_heads,
1612            num_attn_heads: cfg.num_attention_heads,
1613            sliding_window: cfg.sliding_window,
1614            k_head_dim: cfg.head_dim(),
1615            v_head_dim: cfg.head_dim(),
1616            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1617        };
1618
1619        Ok(Box::new(cfg))
1620    }
1621}
1622
1623// ======================== Qwen2 loader
1624
1625/// [`NormalLoader`] for a Qwen 2 model.
1626///
1627/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1628pub struct Qwen2Loader;
1629
1630impl NormalModelLoader for Qwen2Loader {
1631    fn load(
1632        &self,
1633        config: &str,
1634        vb: ShardedVarBuilder,
1635        normal_loading_metadata: NormalLoadingMetadata,
1636        attention_mechanism: AttentionImplementation,
1637    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1638        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1639
1640        Ok(Box::new(models::qwen2::Model::new(
1641            &cfg,
1642            vb,
1643            self.is_gptx(config)?,
1644            normal_loading_metadata,
1645            attention_mechanism,
1646        )?))
1647    }
1648    fn load_xlora(
1649        &self,
1650        _config: &str,
1651        _vb: ShardedVarBuilder,
1652        _lora_config: &[((String, String), LoraConfig)],
1653        _xlora_config: Option<XLoraConfig>,
1654        _xlora_ordering: Ordering,
1655        _normal_loading_metadata: NormalLoadingMetadata,
1656        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1657    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1658        todo!()
1659    }
1660    fn is_gptx(&self, _: &str) -> Result<bool> {
1661        Ok(true)
1662    }
1663    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1664        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1665
1666        Ok(Box::new(cfg))
1667    }
1668}
1669
1670impl IsqModelLoader for Qwen2Loader {
1671    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1672        Ok(vec![
1673            Regex::new(r"lm_head\.(weight|bias)$")?,
1674            // Attention
1675            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1676            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1677            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1678            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1679            // MLP
1680            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1681            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1682            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1683            // MLP MoE
1684            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
1685            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
1686            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
1687        ])
1688    }
1689    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1690        self.isq_layer_regexes(config)
1691    }
1692}
1693
1694impl DeviceMappedModelLoader for Qwen2Loader {
1695    fn mapped_max_act_size_elems(
1696        &self,
1697        config: &str,
1698        params: &AutoDeviceMapParams,
1699    ) -> Result<usize> {
1700        let AutoDeviceMapParams::Text {
1701            max_seq_len,
1702            max_batch_size,
1703        } = params
1704        else {
1705            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1706        };
1707
1708        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1709
1710        Ok(
1711            max_batch_size
1712                * cfg.num_attention_heads
1713                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1714        )
1715    }
1716    fn non_mapped_max_act_size_elems(
1717        &self,
1718        _config: &str,
1719        _params: &AutoDeviceMapParams,
1720    ) -> Result<usize> {
1721        Ok(0)
1722    }
1723
1724    fn non_mapped_size_in_bytes(
1725        &self,
1726        config: &str,
1727        dtype: DType,
1728        weight_pack_factor: usize,
1729        _matformer_config: Option<&MatformerSliceConfig>,
1730    ) -> Result<usize> {
1731        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1732
1733        let elems = {
1734            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1735            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1736            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1737                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1738            } else {
1739                0
1740            };
1741            let norm = cfg.hidden_size;
1742            embed_tokens + lm_head + norm
1743        };
1744        Ok(elems * dtype.size_in_bytes())
1745    }
1746
1747    fn layer_sizes_in_bytes(
1748        &self,
1749        config: &str,
1750        dtype: DType,
1751        weight_pack_factor: usize,
1752        _matformer_config: Option<&MatformerSliceConfig>,
1753    ) -> Result<Vec<usize>> {
1754        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1755
1756        let per_layer_elems = {
1757            let input_layernorm = cfg.hidden_size;
1758            let post_attention_layernorm = cfg.hidden_size;
1759
1760            let size_in = cfg.hidden_size;
1761            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1762            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1763            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1764            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1765            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1766            let o_proj = size_q * size_in / weight_pack_factor;
1767
1768            let h_size = cfg.hidden_size;
1769            let i_size = cfg.intermediate_size;
1770            let gate_proj = h_size * i_size / weight_pack_factor;
1771            let up_proj = h_size * i_size / weight_pack_factor;
1772            let down_proj = i_size * h_size / weight_pack_factor;
1773
1774            input_layernorm
1775                + post_attention_layernorm
1776                + q_proj
1777                + k_proj
1778                + v_proj
1779                + o_proj
1780                + gate_proj
1781                + up_proj
1782                + down_proj
1783        };
1784        Ok(vec![
1785            per_layer_elems * dtype.size_in_bytes();
1786            cfg.num_hidden_layers
1787        ])
1788    }
1789
1790    fn num_layers(&self, config: &str) -> Result<usize> {
1791        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1792
1793        Ok(cfg.num_hidden_layers)
1794    }
1795
1796    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1797        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1798
1799        let cfg = ModelConfigMetadata {
1800            max_seq_len: cfg.max_position_embeddings,
1801            num_layers: cfg.num_hidden_layers,
1802            hidden_size: cfg.hidden_size,
1803            num_kv_heads: cfg.num_key_value_heads,
1804            num_attn_heads: cfg.num_attention_heads,
1805            sliding_window: cfg.sliding_window,
1806            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1807            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1808            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1809        };
1810
1811        Ok(Box::new(cfg))
1812    }
1813}
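
// A self-contained sanity check (hypothetical dimensions, not taken from any
// real Qwen2 config) that the per-layer element count computed in
// `Qwen2Loader::layer_sizes_in_bytes` above reduces to the expected closed
// form when nothing is packed: 2*h for the norms, 4*h*h + 3*h for attention
// (q/k/v biased, o_proj unbiased) and 3*h*i for the MLP, assuming equal
// attention and KV head counts.
#[cfg(test)]
mod qwen2_layer_size_sketch {
    #[test]
    fn per_layer_elems_matches_closed_form_when_unpacked() {
        let (h, i, heads, kv_heads, pack) = (1024usize, 2816usize, 16usize, 16usize, 1usize);
        let head_dim = h / heads;
        let size_q = head_dim * heads;
        let size_kv = head_dim * kv_heads;
        // Mirrors the arithmetic in `layer_sizes_in_bytes` above.
        let attn = (h * size_q / pack + size_q)
            + 2 * (h * size_kv / pack + size_kv)
            + size_q * h / pack;
        let mlp = 3 * h * i / pack;
        let per_layer = 2 * h + attn + mlp;
        assert_eq!(per_layer, 2 * h + 4 * h * h + 3 * h + 3 * h * i);
    }
}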
1814
1815// ======================== Gemma2 loader
1816
1817/// [`NormalLoader`] for a Gemma2 model.
1818///
1819/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1820pub struct Gemma2Loader;
1821
1822impl NormalModelLoader for Gemma2Loader {
1823    fn load(
1824        &self,
1825        config: &str,
1826        vb: ShardedVarBuilder,
1827        normal_loading_metadata: NormalLoadingMetadata,
1828        attention_mechanism: AttentionImplementation,
1829    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1830        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1831
1832        Ok(Box::new(models::gemma2::Model::new(
1833            &cfg,
1834            vb,
1835            self.is_gptx(config)?,
1836            normal_loading_metadata,
1837            attention_mechanism,
1838        )?))
1839    }
1840    fn load_xlora(
1841        &self,
1842        config: &str,
1843        vb: ShardedVarBuilder,
1844        lora_config: &[((String, String), LoraConfig)],
1845        xlora_config: Option<XLoraConfig>,
1846        xlora_ordering: Ordering,
1847        normal_loading_metadata: NormalLoadingMetadata,
1848        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1849    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1850        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1851
1852        Ok(Box::new(xlora_models::XLoraGemma2::new(
1853            &cfg,
1854            vb,
1855            lora_config,
1856            xlora_config,
1857            xlora_ordering,
1858            self.is_gptx(config)?,
1859            normal_loading_metadata,
1860            preload_adapters,
1861        )?))
1862    }
1863    fn is_gptx(&self, _: &str) -> Result<bool> {
1864        Ok(true)
1865    }
1866    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1867        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1868
1869        Ok(Box::new(cfg))
1870    }
1871}
1872
1873impl IsqModelLoader for Gemma2Loader {
1874    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1875        Ok(vec![
1876            Regex::new(r"lm_head\.(weight|bias)$")?,
1877            // Attention
1878            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1879            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1880            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1881            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1882            // MLP
1883            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1884            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1885            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1886        ])
1887    }
1888    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1889        self.isq_layer_regexes(config)
1890    }
1891}
1892
1893impl DeviceMappedModelLoader for Gemma2Loader {
1894    fn mapped_max_act_size_elems(
1895        &self,
1896        config: &str,
1897        params: &AutoDeviceMapParams,
1898    ) -> Result<usize> {
1899        let AutoDeviceMapParams::Text {
1900            max_seq_len,
1901            max_batch_size,
1902        } = params
1903        else {
1904            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1905        };
1906
1907        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1908
1909        Ok(
1910            max_batch_size
1911                * cfg.num_attention_heads
1912                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1913        )
1914    }
1915    fn non_mapped_max_act_size_elems(
1916        &self,
1917        _config: &str,
1918        _params: &AutoDeviceMapParams,
1919    ) -> Result<usize> {
1920        Ok(0)
1921    }
1922
1923    fn non_mapped_size_in_bytes(
1924        &self,
1925        config: &str,
1926        dtype: DType,
1927        weight_pack_factor: usize,
1928        _matformer_config: Option<&MatformerSliceConfig>,
1929    ) -> Result<usize> {
1930        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1931
1932        let elems = {
1933            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1934            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1935            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1936                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1937            } else {
1938                0
1939            };
1940            let norm = cfg.hidden_size;
1941            embed_tokens + lm_head + norm
1942        };
1943        Ok(elems * dtype.size_in_bytes())
1944    }
1945
1946    fn layer_sizes_in_bytes(
1947        &self,
1948        config: &str,
1949        dtype: DType,
1950        weight_pack_factor: usize,
1951        _matformer_config: Option<&MatformerSliceConfig>,
1952    ) -> Result<Vec<usize>> {
1953        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1954
1955        let per_layer_elems = {
1956            let input_layernorm = cfg.hidden_size;
1957            let post_attention_layernorm = cfg.hidden_size;
1958
1959            let size_in = cfg.hidden_size;
1960            let size_q = cfg.head_dim * cfg.num_attention_heads;
1961            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1962            let q_proj =
1963                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1964            let k_proj =
1965                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1966            let v_proj =
1967                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1968            let o_proj =
1969                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1970
1971            let h_size = cfg.hidden_size;
1972            let i_size = cfg.intermediate_size;
1973            let gate_proj = h_size * i_size / weight_pack_factor;
1974            let up_proj = h_size * i_size / weight_pack_factor;
1975            let down_proj = i_size * h_size / weight_pack_factor;
1976
1977            input_layernorm
1978                + post_attention_layernorm
1979                + q_proj
1980                + k_proj
1981                + v_proj
1982                + o_proj
1983                + gate_proj
1984                + up_proj
1985                + down_proj
1986        };
1987        Ok(vec![
1988            per_layer_elems * dtype.size_in_bytes();
1989            cfg.num_hidden_layers
1990        ])
1991    }
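
    // Note: `bias_if!(cond, size)` used in the attention terms above is a
    // small helper macro defined elsewhere in this file; conceptually it adds
    // `size` bias elements only when `cond` holds, roughly
    // `if cond { size } else { 0 }`, so configs with `attention_bias = false`
    // simply drop those terms from the estimate.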
1992
1993    fn num_layers(&self, config: &str) -> Result<usize> {
1994        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1995
1996        Ok(cfg.num_hidden_layers)
1997    }
1998    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1999        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
2000
2001        let cfg = ModelConfigMetadata {
2002            max_seq_len: cfg.max_position_embeddings,
2003            num_layers: cfg.num_hidden_layers,
2004            hidden_size: cfg.hidden_size,
2005            num_kv_heads: cfg.num_key_value_heads,
2006            num_attn_heads: cfg.num_attention_heads,
2007            sliding_window: None, // None to be more forgiving; some Gemma2 configs do not set sliding_window
2008            k_head_dim: cfg.head_dim,
2009            v_head_dim: cfg.head_dim,
2010            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2011        };
2012
2013        Ok(Box::new(cfg))
2014    }
2015}
2016
2017// ======================== Starcoder2 loader
2018
2019/// [`NormalLoader`] for a Starcoder2 model.
2020///
2021/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2022pub struct Starcoder2Loader;
2023
2024impl NormalModelLoader for Starcoder2Loader {
2025    fn load(
2026        &self,
2027        config: &str,
2028        vb: ShardedVarBuilder,
2029        normal_loading_metadata: NormalLoadingMetadata,
2030        attention_mechanism: AttentionImplementation,
2031    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2032        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2033
2034        Ok(Box::new(models::starcoder2::Model::new(
2035            &cfg,
2036            vb,
2037            self.is_gptx(config)?,
2038            normal_loading_metadata,
2039            attention_mechanism,
2040        )?))
2041    }
2042    fn load_xlora(
2043        &self,
2044        config: &str,
2045        vb: ShardedVarBuilder,
2046        lora_config: &[((String, String), LoraConfig)],
2047        xlora_config: Option<XLoraConfig>,
2048        xlora_ordering: Ordering,
2049        normal_loading_metadata: NormalLoadingMetadata,
2050        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2051    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2052        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2053
2054        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2055            &cfg,
2056            vb,
2057            lora_config,
2058            xlora_config,
2059            xlora_ordering,
2060            self.is_gptx(config)?,
2061            normal_loading_metadata,
2062            preload_adapters,
2063        )?))
2064    }
2065    fn is_gptx(&self, _: &str) -> Result<bool> {
2066        Ok(true)
2067    }
2068    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2069        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071        Ok(Box::new(cfg))
2072    }
2073}
2074
2075impl IsqModelLoader for Starcoder2Loader {
2076    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2077        Ok(vec![
2078            Regex::new(r"lm_head\.(weight|bias)$")?,
2079            // Attention
2080            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2081            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2082            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2083            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2084            // MLP
2085            Regex::new(r"layers\.(\d+)\.mlp\.c_fc\.(weight|bias)$")?,
2086            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2087        ])
2088    }
2089    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2090        self.isq_layer_regexes(config)
2091    }
2092}
2093
2094impl DeviceMappedModelLoader for Starcoder2Loader {
2095    fn mapped_max_act_size_elems(
2096        &self,
2097        config: &str,
2098        params: &AutoDeviceMapParams,
2099    ) -> Result<usize> {
2100        let AutoDeviceMapParams::Text {
2101            max_seq_len,
2102            max_batch_size,
2103        } = params
2104        else {
2105            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2106        };
2107
2108        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2109
2110        Ok(
2111            max_batch_size
2112                * cfg.num_attention_heads
2113                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2114        )
2115    }
2116    fn non_mapped_max_act_size_elems(
2117        &self,
2118        _config: &str,
2119        _params: &AutoDeviceMapParams,
2120    ) -> Result<usize> {
2121        Ok(0)
2122    }
2123
2124    fn non_mapped_size_in_bytes(
2125        &self,
2126        config: &str,
2127        dtype: DType,
2128        weight_pack_factor: usize,
2129        _matformer_config: Option<&MatformerSliceConfig>,
2130    ) -> Result<usize> {
2131        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2132
2133        let elems = {
2134            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2135            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2136            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2137                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2138            } else {
2139                0
2140            };
2141            let norm = cfg.hidden_size + cfg.hidden_size;
2142            embed_tokens + lm_head + norm
2143        };
2144        Ok(elems * dtype.size_in_bytes())
2145    }
2146
2147    fn layer_sizes_in_bytes(
2148        &self,
2149        config: &str,
2150        dtype: DType,
2151        weight_pack_factor: usize,
2152        _matformer_config: Option<&MatformerSliceConfig>,
2153    ) -> Result<Vec<usize>> {
2154        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2155
2156        let per_layer_elems = {
2157            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2158            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2159
2160            let size_in = cfg.hidden_size;
2161            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2162            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2163            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2164            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2165            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2166            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2167
2168            let h_size = cfg.hidden_size;
2169            let i_size = cfg.intermediate_size;
2170            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2171            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2172
2173            input_layernorm
2174                + post_attention_layernorm
2175                + q_proj
2176                + k_proj
2177                + v_proj
2178                + o_proj
2179                + fc1
2180                + fc2
2181        };
2182        Ok(vec![
2183            per_layer_elems * dtype.size_in_bytes();
2184            cfg.num_hidden_layers
2185        ])
2186    }
2187
2188    fn num_layers(&self, config: &str) -> Result<usize> {
2189        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2190
2191        Ok(cfg.num_hidden_layers)
2192    }
2193
2194    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2195        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2196
2197        let cfg = ModelConfigMetadata {
2198            max_seq_len: cfg.max_position_embeddings,
2199            num_layers: cfg.num_hidden_layers,
2200            hidden_size: cfg.hidden_size,
2201            num_kv_heads: cfg.num_key_value_heads,
2202            num_attn_heads: cfg.num_attention_heads,
2203            sliding_window: cfg.sliding_window,
2204            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2205            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2206            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2207        };
2208
2209        Ok(Box::new(cfg))
2210    }
2211}
2212
2213// ======================== Phi3.5 MoE loader
2214
2215/// [`NormalLoader`] for a Phi 3.5 MoE model.
2216///
2217/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2218pub struct Phi3_5MoELoader;
2219
2220impl NormalModelLoader for Phi3_5MoELoader {
2221    fn load(
2222        &self,
2223        config: &str,
2224        vb: ShardedVarBuilder,
2225        normal_loading_metadata: NormalLoadingMetadata,
2226        attention_mechanism: AttentionImplementation,
2227    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2228        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2229
2230        Ok(Box::new(models::phi3_5_moe::Model::new(
2231            &cfg,
2232            vb,
2233            self.is_gptx(config)?,
2234            normal_loading_metadata,
2235            attention_mechanism,
2236        )?))
2237    }
2238    fn load_xlora(
2239        &self,
2240        config: &str,
2241        vb: ShardedVarBuilder,
2242        lora_config: &[((String, String), LoraConfig)],
2243        xlora_config: Option<XLoraConfig>,
2244        xlora_ordering: Ordering,
2245        normal_loading_metadata: NormalLoadingMetadata,
2246        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2247    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2248        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2249
2250        Ok(Box::new(xlora_models::XLoraPhi3::new(
2251            &cfg,
2252            vb,
2253            lora_config,
2254            xlora_config,
2255            xlora_ordering,
2256            self.is_gptx(config)?,
2257            normal_loading_metadata,
2258            preload_adapters,
2259        )?))
2260    }
2261    fn is_gptx(&self, _: &str) -> Result<bool> {
2262        Ok(true)
2263    }
2264    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2265        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2266
2267        Ok(Box::new(cfg))
2268    }
2269}
2270
2271impl IsqModelLoader for Phi3_5MoELoader {
2272    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2273        Ok(vec![
2274            Regex::new(r"lm_head\.(weight|bias)$")?,
2275            // Attention
2276            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2277            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2278            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2279            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2280            // MLP
2281            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2282            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2283            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2284        ])
2285    }
2286    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2287        self.isq_layer_regexes(config)
2288    }
2289
2290    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2291        Ok(vec![
2292            Regex::new(r"lm_head\.(weight|bias)$")?,
2293            // MLP
2294            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2295            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2296            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2297        ])
2298    }
2299    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2300        self.isq_layer_regexes_moqe(config)
2301    }
2302}
2303
2304impl DeviceMappedModelLoader for Phi3_5MoELoader {
2305    fn mapped_max_act_size_elems(
2306        &self,
2307        config: &str,
2308        params: &AutoDeviceMapParams,
2309    ) -> Result<usize> {
2310        let AutoDeviceMapParams::Text {
2311            max_seq_len,
2312            max_batch_size,
2313        } = params
2314        else {
2315            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2316        };
2317
2318        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2319
2320        Ok(
2321            max_batch_size
2322                * cfg.num_attention_heads
2323                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2324        )
2325    }
2326    fn non_mapped_max_act_size_elems(
2327        &self,
2328        _config: &str,
2329        _params: &AutoDeviceMapParams,
2330    ) -> Result<usize> {
2331        Ok(0)
2332    }
2333
2334    fn non_mapped_size_in_bytes(
2335        &self,
2336        config: &str,
2337        dtype: DType,
2338        weight_pack_factor: usize,
2339        _matformer_config: Option<&MatformerSliceConfig>,
2340    ) -> Result<usize> {
2341        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2342
2343        let elems = {
2344            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2345            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2346            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2347                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2348            } else {
2349                0
2350            };
2351            let norm = cfg.hidden_size;
2352            embed_tokens + lm_head + norm
2353        };
2354        Ok(elems * dtype.size_in_bytes())
2355    }
2356
2357    fn layer_sizes_in_bytes(
2358        &self,
2359        config: &str,
2360        dtype: DType,
2361        weight_pack_factor: usize,
2362        _matformer_config: Option<&MatformerSliceConfig>,
2363    ) -> Result<Vec<usize>> {
2364        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2365
2366        let per_layer_elems = {
2367            let input_layernorm = cfg.hidden_size;
2368            let post_attention_layernorm = cfg.hidden_size;
2369
2370            let size_in = cfg.hidden_size;
2371            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2372            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2373            let q_proj =
2374                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2375            let k_proj =
2376                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2377            let v_proj =
2378                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2379            let o_proj =
2380                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2381
2382            let moe_block = {
2383                let gate = cfg.hidden_size * cfg.num_local_experts;
2384                // Expert projection weights are assumed to be packed by the same weight_pack_factor
2385                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2386                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2387                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2388                gate + cfg.num_local_experts * w1
2389                    + cfg.num_local_experts * w2
2390                    + cfg.num_local_experts * w3
2391            };
2392
2393            input_layernorm
2394                + post_attention_layernorm
2395                + q_proj
2396                + k_proj
2397                + v_proj
2398                + o_proj
2399                + moe_block
2400        };
2401        Ok(vec![
2402            per_layer_elems * dtype.size_in_bytes();
2403            cfg.num_hidden_layers
2404        ])
2405    }
2406
2407    fn num_layers(&self, config: &str) -> Result<usize> {
2408        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2409
2410        Ok(cfg.num_hidden_layers)
2411    }
2412
2413    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2414        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2415
2416        let cfg = ModelConfigMetadata {
2417            max_seq_len: cfg.max_position_embeddings,
2418            num_layers: cfg.num_hidden_layers,
2419            hidden_size: cfg.hidden_size,
2420            num_kv_heads: cfg.num_key_value_heads,
2421            num_attn_heads: cfg.num_attention_heads,
2422            sliding_window: cfg.sliding_window,
2423            k_head_dim: cfg.head_dim(),
2424            v_head_dim: cfg.head_dim(),
2425            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2426        };
2427
2428        Ok(Box::new(cfg))
2429    }
2430}
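
// For the MoE block estimate above, each routed expert contributes three
// hidden_size x intermediate_size projections (w1/w2/w3) plus a single shared
// hidden_size x num_local_experts routing gate. As a purely hypothetical
// example (not a real Phi-3.5-MoE config): hidden_size = 4096,
// intermediate_size = 6400 and 16 experts give
// 16 * 3 * 4096 * 6400 + 4096 * 16 = 1_258_356_736 unpacked elements per layer
// for the MoE block alone, which is why weight_pack_factor dominates the
// device-mapping decision for MoE models.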
2431
2432/// [`NormalLoader`] for a DeepSeekV2 model.
2433///
2434/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2435pub struct DeepSeekV2Loader;
2436
2437impl NormalModelLoader for DeepSeekV2Loader {
2438    fn load(
2439        &self,
2440        config: &str,
2441        vb: ShardedVarBuilder,
2442        normal_loading_metadata: NormalLoadingMetadata,
2443        attention_mechanism: AttentionImplementation,
2444    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2445        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2446
2447        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2448            &cfg,
2449            vb,
2450            self.is_gptx(config)?,
2451            normal_loading_metadata,
2452            attention_mechanism,
2453        )?))
2454    }
2455    fn load_xlora(
2456        &self,
2457        _config: &str,
2458        _vb: ShardedVarBuilder,
2459        _lora_config: &[((String, String), LoraConfig)],
2460        _xlora_config: Option<XLoraConfig>,
2461        _xlora_ordering: Ordering,
2462        _normal_loading_metadata: NormalLoadingMetadata,
2463        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2464    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2465        todo!()
2466    }
2467    fn is_gptx(&self, _: &str) -> Result<bool> {
2468        Ok(true)
2469    }
2470    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2471        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2472        Ok(Box::new(cfg))
2473    }
2474}
2475
2476impl IsqModelLoader for DeepSeekV2Loader {
2477    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2478        let mut data = vec![
2479            Regex::new(r"lm_head\.(weight|bias)$")?,
2480            // Attention
2481            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2482            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2483            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2484        ];
2485        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2486        if cfg.q_lora_rank.is_some() {
2487            data.extend(vec![
2488                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2489                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2490            ]);
2491        } else {
2492            data.push(Regex::new(
2493                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2494            )?);
2495        }
2496        for layer_idx in 0..cfg.num_hidden_layers {
2497            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2498                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2499            }) {
2500                for i in 0..n_routed_experts {
2501                    data.extend(vec![
2502                        Regex::new(&format!(
2503                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2504                        ))?,
2505                        Regex::new(&format!(
2506                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2507                        ))?,
2508                        Regex::new(&format!(
2509                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2510                        ))?,
2511                    ]);
2512                }
2513                if cfg.n_shared_experts.is_some() {
2514                    data.extend(vec![
2515                        Regex::new(&format!(
2516                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2517                        ))?,
2518                        Regex::new(&format!(
2519                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2520                        ))?,
2521                        Regex::new(&format!(
2522                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2523                        ))?,
2524                    ]);
2525                }
2526            } else {
2527                data.extend(vec![
2528                    Regex::new(&format!(
2529                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2530                    ))?,
2531                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2532                    Regex::new(&format!(
2533                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2534                    ))?,
2535                ]);
2536            };
2537        }
2538        Ok(data)
2539    }
2540    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2541        self.isq_layer_regexes(config)
2542    }
2543
2544    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2545        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2546        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2547        for layer_idx in 0..cfg.num_hidden_layers {
2548            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2549                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2550            }) {
2551                for i in 0..n_routed_experts {
2552                    data.extend(vec![
2553                        Regex::new(&format!(
2554                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2555                        ))?,
2556                        Regex::new(&format!(
2557                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2558                        ))?,
2559                        Regex::new(&format!(
2560                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2561                        ))?,
2562                    ]);
2563                }
2564                if cfg.n_shared_experts.is_some() {
2565                    data.extend(vec![
2566                        Regex::new(&format!(
2567                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2568                        ))?,
2569                        Regex::new(&format!(
2570                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2571                        ))?,
2572                        Regex::new(&format!(
2573                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2574                        ))?,
2575                    ]);
2576                }
2577            } else {
2578                data.extend(vec![
2579                    Regex::new(&format!(
2580                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2581                    ))?,
2582                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2583                    Regex::new(&format!(
2584                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2585                    ))?,
2586                ]);
2587            };
2588        }
2589        Ok(data)
2590    }
2591    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2592        self.isq_layer_regexes_moqe(config)
2593    }
2594}
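
// Unlike the dense loaders above, the DeepSeek V2 ISQ patterns are generated
// per layer and per routed expert, since only layers at or beyond
// `first_k_dense_replace` (on the `moe_layer_freq` cadence) carry expert
// weights. For illustration, `layer_idx = 3` and expert `i = 7` yield the
// pattern `layers\.3\.mlp\.experts\.7\.gate_proj\.(weight|bias)$`.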
2595
2596impl DeviceMappedModelLoader for DeepSeekV2Loader {
2597    fn mapped_max_act_size_elems(
2598        &self,
2599        config: &str,
2600        params: &AutoDeviceMapParams,
2601    ) -> Result<usize> {
2602        let AutoDeviceMapParams::Text {
2603            max_seq_len,
2604            max_batch_size,
2605        } = params
2606        else {
2607            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2608        };
2609
2610        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2611
2612        Ok(
2613            max_batch_size
2614                * cfg.num_attention_heads
2615                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2616        )
2617    }
2618    fn non_mapped_max_act_size_elems(
2619        &self,
2620        _config: &str,
2621        _params: &AutoDeviceMapParams,
2622    ) -> Result<usize> {
2623        Ok(0)
2624    }
2625
2626    fn non_mapped_size_in_bytes(
2627        &self,
2628        config: &str,
2629        dtype: DType,
2630        weight_pack_factor: usize,
2631        _matformer_config: Option<&MatformerSliceConfig>,
2632    ) -> Result<usize> {
2633        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2634        let elems = {
2635            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2636            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2637            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2638                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2639            } else {
2640                0
2641            };
2642            let norm = cfg.hidden_size;
2643            embed_tokens + lm_head + norm
2644        };
2645        Ok(elems * dtype.size_in_bytes())
2646    }
2647
2648    fn layer_sizes_in_bytes(
2649        &self,
2650        config: &str,
2651        dtype: DType,
2652        weight_pack_factor: usize,
2653        _matformer_config: Option<&MatformerSliceConfig>,
2654    ) -> Result<Vec<usize>> {
2655        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2656        let mut per_layer_elems = Vec::new();
2657
2658        for layer_idx in 0..cfg.num_hidden_layers {
2659            let input_layernorm = cfg.hidden_size;
2660            let post_attention_layernorm = cfg.hidden_size;
2661
2662            let q_proj = match cfg.q_lora_rank {
2663                Some(lora_rank) => {
2664                    let a = cfg.hidden_size * lora_rank;
2665                    let norm = lora_rank;
2666                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2667                    a + norm + b
2668                }
2669                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2670            };
2671            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2672                / weight_pack_factor
2673                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2674            let kv_a_layernorm = cfg.kv_lora_rank;
2675            let kv_b_proj = cfg.kv_lora_rank
2676                * cfg.num_attention_heads
2677                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2678                / weight_pack_factor;
2679            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2680                / weight_pack_factor
2681                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2682
2683            let moe_block = {
2684                let mut sum = 0;
2685                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2686                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2687                }) {
2688                    let h_size = cfg.hidden_size;
2689                    let gate_proj =
2690                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2691                    let up_proj =
2692                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2693                    let down_proj =
2694                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2695                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2696                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2697                            / weight_pack_factor;
2698                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2699                            / weight_pack_factor;
2700                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2701                            / weight_pack_factor;
2702                        gate_proj + up_proj + down_proj
2703                    } else {
2704                        0
2705                    };
2706                    let gate_weight = n_routed_experts * cfg.hidden_size;
2707                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2708                } else {
2709                    let h_size = cfg.hidden_size;
2710                    let i_size = cfg.intermediate_size;
2711                    let gate_proj = h_size * i_size / weight_pack_factor;
2712                    let up_proj = h_size * i_size / weight_pack_factor;
2713                    let down_proj = i_size * h_size / weight_pack_factor;
2714                    sum += gate_proj + up_proj + down_proj;
2715                }
2716                sum
2717            };
2718
2719            per_layer_elems.push(
2720                input_layernorm
2721                    + post_attention_layernorm
2722                    + q_proj
2723                    + kv_a_layernorm
2724                    + kv_a_proj_with_mqa
2725                    + kv_b_proj
2726                    + o_proj
2727                    + moe_block,
2728            );
2729        }
2730
2731        Ok(per_layer_elems
2732            .into_iter()
2733            .map(|x| x * dtype.size_in_bytes())
2734            .collect())
2735    }
2736
2737    fn num_layers(&self, config: &str) -> Result<usize> {
2738        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2739        Ok(cfg.num_hidden_layers)
2740    }
2741
2742    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2743        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2744
2745        let cfg = ModelConfigMetadata {
2746            max_seq_len: cfg.max_position_embeddings,
2747            num_layers: cfg.num_hidden_layers,
2748            hidden_size: cfg.hidden_size,
2749            num_kv_heads: cfg.num_attention_heads,
2750            num_attn_heads: cfg.num_attention_heads,
2751            sliding_window: None,
2752            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2753            v_head_dim: cfg.v_head_dim,
2754            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2755        };
2756
2757        Ok(Box::new(cfg))
2758    }
2759}
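
// In the metadata above, the key head dimension is reported as
// `qk_rope_head_dim + qk_nope_head_dim` while values use `v_head_dim`,
// reflecting DeepSeek V2's multi-head latent attention, where the rotary and
// non-rotary key parts are concatenated; the paged KV cache is therefore
// sized with asymmetric K and V head dimensions.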
2760
2761/// [`NormalLoader`] for a DeepSeekV3 model.
2762///
2763/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2764pub struct DeepSeekV3Loader;
2765
2766impl NormalModelLoader for DeepSeekV3Loader {
2767    fn load(
2768        &self,
2769        config: &str,
2770        vb: ShardedVarBuilder,
2771        normal_loading_metadata: NormalLoadingMetadata,
2772        attention_mechanism: AttentionImplementation,
2773    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2774        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2775        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2776            &cfg,
2777            vb,
2778            self.is_gptx(config)?,
2779            normal_loading_metadata,
2780            attention_mechanism,
2781        )?))
2782    }
2783    fn load_xlora(
2784        &self,
2785        _config: &str,
2786        _vb: ShardedVarBuilder,
2787        _lora_config: &[((String, String), LoraConfig)],
2788        _xlora_config: Option<XLoraConfig>,
2789        _xlora_ordering: Ordering,
2790        _normal_loading_metadata: NormalLoadingMetadata,
2791        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2792    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2793        todo!()
2794    }
2795    fn is_gptx(&self, _: &str) -> Result<bool> {
2796        Ok(true)
2797    }
2798    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2799        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2800        Ok(Box::new(cfg))
2801    }
2802}
2803
2804impl IsqModelLoader for DeepSeekV3Loader {
2805    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2806        let mut data = vec![
2807            Regex::new(r"lm_head\.(weight|bias)$")?,
2808            // Attention
2809            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2810            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2811            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2812        ];
2813        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2814        if cfg.q_lora_rank.is_some() {
2815            data.extend(vec![
2816                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2817                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2818            ]);
2819        } else {
2820            data.push(Regex::new(
2821                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2822            )?);
2823        }
2824        for layer_idx in 0..cfg.num_hidden_layers {
2825            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2826                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2827            }) {
2828                for i in 0..n_routed_experts {
2829                    data.extend(vec![
2830                        Regex::new(&format!(
2831                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2832                        ))?,
2833                        Regex::new(&format!(
2834                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2835                        ))?,
2836                        Regex::new(&format!(
2837                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2838                        ))?,
2839                    ]);
2840                }
2841                if cfg.n_shared_experts.is_some() {
2842                    data.extend(vec![
2843                        Regex::new(&format!(
2844                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2845                        ))?,
2846                        Regex::new(&format!(
2847                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2848                        ))?,
2849                        Regex::new(&format!(
2850                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2851                        ))?,
2852                    ]);
2853                }
2854            } else {
2855                data.extend(vec![
2856                    Regex::new(&format!(
2857                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2858                    ))?,
2859                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2860                    Regex::new(&format!(
2861                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2862                    ))?,
2863                ]);
2864            };
2865        }
2866        Ok(data)
2867    }
2868    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2869        self.isq_layer_regexes(config)
2870    }
2871
2872    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2873        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2874        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2875        for layer_idx in 0..cfg.num_hidden_layers {
2876            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2877                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2878            }) {
2879                for i in 0..n_routed_experts {
2880                    data.extend(vec![
2881                        Regex::new(&format!(
2882                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2883                        ))?,
2884                        Regex::new(&format!(
2885                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2886                        ))?,
2887                        Regex::new(&format!(
2888                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2889                        ))?,
2890                    ]);
2891                }
2892                if cfg.n_shared_experts.is_some() {
2893                    data.extend(vec![
2894                        Regex::new(&format!(
2895                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2896                        ))?,
2897                        Regex::new(&format!(
2898                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2899                        ))?,
2900                        Regex::new(&format!(
2901                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2902                        ))?,
2903                    ]);
2904                }
2905            } else {
2906                data.extend(vec![
2907                    Regex::new(&format!(
2908                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2909                    ))?,
2910                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2911                    Regex::new(&format!(
2912                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2913                    ))?,
2914                ]);
2915            };
2916        }
2917        Ok(data)
2918    }
2919    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2920        self.isq_layer_regexes_moqe(config)
2921    }
2922}
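
// A minimal sketch of how the ISQ regexes above are meant to be used: they are matched against
// fully qualified tensor names from the checkpoint. The tensor names below are illustrative
// examples, not taken from any specific model file.
#[cfg(test)]
mod deepseek3_isq_regex_examples {
    use regex::Regex;

    #[test]
    fn mla_attention_and_q_lora_regexes_match_expected_names() {
        // MLA attention projection, emitted regardless of `q_lora_rank`.
        let kv_a =
            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$").unwrap();
        assert!(kv_a.is_match("model.layers.7.self_attn.kv_a_proj_with_mqa.weight"));

        // Q low-rank projections, only emitted when `q_lora_rank` is `Some(_)`.
        let q_a = Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$").unwrap();
        assert!(q_a.is_match("model.layers.7.self_attn.q_a_proj.weight"));
        assert!(!q_a.is_match("model.layers.7.self_attn.q_proj.weight"));
    }
}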
2923
2924impl DeviceMappedModelLoader for DeepSeekV3Loader {
2925    fn mapped_max_act_size_elems(
2926        &self,
2927        config: &str,
2928        params: &AutoDeviceMapParams,
2929    ) -> Result<usize> {
2930        let AutoDeviceMapParams::Text {
2931            max_seq_len,
2932            max_batch_size,
2933        } = params
2934        else {
2935            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2936        };
2937
2938        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2939
2940        Ok(
2941            max_batch_size
2942                * cfg.num_attention_heads
2943                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2944        )
2945    }
2946    fn non_mapped_max_act_size_elems(
2947        &self,
2948        _config: &str,
2949        _params: &AutoDeviceMapParams,
2950    ) -> Result<usize> {
2951        Ok(0)
2952    }
2953
2954    fn non_mapped_size_in_bytes(
2955        &self,
2956        config: &str,
2957        dtype: DType,
2958        weight_pack_factor: usize,
2959        _matformer_config: Option<&MatformerSliceConfig>,
2960    ) -> Result<usize> {
2961        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2962        let elems = {
2963            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2964            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2965            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2966                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2967            } else {
2968                0
2969            };
2970            let norm = cfg.hidden_size;
2971            embed_tokens + lm_head + norm
2972        };
2973        Ok(elems * dtype.size_in_bytes())
2974    }
2975
2976    fn layer_sizes_in_bytes(
2977        &self,
2978        config: &str,
2979        dtype: DType,
2980        weight_pack_factor: usize,
2981        _matformer_config: Option<&MatformerSliceConfig>,
2982    ) -> Result<Vec<usize>> {
2983        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2984        let mut per_layer_elems = Vec::new();
2985
2986        for layer_idx in 0..cfg.num_hidden_layers {
2987            let input_layernorm = cfg.hidden_size;
2988            let post_attention_layernorm = cfg.hidden_size;
2989
2990            let q_proj = match cfg.q_lora_rank {
2991                Some(lora_rank) => {
2992                    let a = cfg.hidden_size * lora_rank;
2993                    let norm = lora_rank;
2994                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2995                    a + norm + b
2996                }
2997                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2998            };
2999            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3000                / weight_pack_factor
3001                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3002            let kv_a_layernorm = cfg.kv_lora_rank;
3003            let kv_b_proj = cfg.kv_lora_rank
3004                * cfg.num_attention_heads
3005                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3006                / weight_pack_factor;
3007            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3008                / weight_pack_factor
3009                + bias_if!(cfg.attention_bias, cfg.hidden_size);
3010
3011            let moe_block = {
3012                let mut sum = 0;
3013                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3014                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3015                }) {
3016                    let h_size = cfg.hidden_size;
3017                    let gate_proj =
3018                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3019                    let up_proj =
3020                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3021                    let down_proj =
3022                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3023                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3024                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3025                            / weight_pack_factor;
3026                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3027                            / weight_pack_factor;
3028                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3029                            / weight_pack_factor;
3030                        gate_proj + up_proj + down_proj
3031                    } else {
3032                        0
3033                    };
3034                    let gate_weight = n_routed_experts * cfg.hidden_size;
3035                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3036                } else {
3037                    let h_size = cfg.hidden_size;
3038                    let i_size = cfg.intermediate_size;
3039                    let gate_proj = h_size * i_size / weight_pack_factor;
3040                    let up_proj = h_size * i_size / weight_pack_factor;
3041                    let down_proj = i_size * h_size / weight_pack_factor;
3042                    sum += gate_proj + up_proj + down_proj;
3043                }
3044                sum
3045            };
3046
3047            per_layer_elems.push(
3048                input_layernorm
3049                    + post_attention_layernorm
3050                    + q_proj
3051                    + kv_a_layernorm
3052                    + kv_a_proj_with_mqa
3053                    + kv_b_proj
3054                    + o_proj
3055                    + moe_block,
3056            );
3057        }
3058
3059        Ok(per_layer_elems
3060            .into_iter()
3061            .map(|x| x * dtype.size_in_bytes())
3062            .collect())
3063    }
3064
3065    fn num_layers(&self, config: &str) -> Result<usize> {
3066        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3067        Ok(cfg.num_hidden_layers)
3068    }
3069
3070    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3071        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3072
3073        let cfg = ModelConfigMetadata {
3074            max_seq_len: cfg.max_position_embeddings,
3075            num_layers: cfg.num_hidden_layers,
3076            hidden_size: cfg.hidden_size,
3077            num_kv_heads: cfg.num_attention_heads,
3078            num_attn_heads: cfg.num_attention_heads,
3079            sliding_window: None,
3080            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3081            v_head_dim: cfg.v_head_dim,
3082            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3083        };
3084
3085        Ok(Box::new(cfg))
3086    }
3087}
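
// A small worked example of the sizing heuristics above, using made-up numbers that do not
// correspond to any released DeepSeek V3 checkpoint. It only exercises the arithmetic, not the
// config parsing or device mapping itself.
#[cfg(test)]
mod deepseek3_sizing_examples {
    #[test]
    fn activation_and_embedding_estimates() {
        // mapped_max_act_size_elems: batch * heads * min(seq_len, attention_chunk)^2.
        let (max_batch_size, num_attention_heads) = (2usize, 16usize);
        let (max_seq_len, attention_chunk_size) = (8192usize, 1024usize);
        let act_elems =
            max_batch_size * num_attention_heads * max_seq_len.min(attention_chunk_size).pow(2);
        assert_eq!(act_elems, 2 * 16 * 1024 * 1024);

        // non_mapped_size_in_bytes: with tied embeddings and no weight packing the lm_head
        // contributes nothing, otherwise it duplicates the embedding table.
        let (hidden_size, vocab_size, weight_pack_factor) = (1024usize, 32000usize, 1usize);
        let tie_word_embeddings = true;
        let embed_tokens = hidden_size * vocab_size / weight_pack_factor;
        let lm_head = if !tie_word_embeddings || weight_pack_factor != 1 {
            hidden_size * vocab_size / weight_pack_factor
        } else {
            0
        };
        assert_eq!(embed_tokens, 32_768_000);
        assert_eq!(lm_head, 0);
    }
}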
3088
3089/// [`NormalLoader`] for a Qwen 3 model.
3090///
3091/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3092pub struct Qwen3Loader;
3093
3094impl NormalModelLoader for Qwen3Loader {
3095    fn load(
3096        &self,
3097        config: &str,
3098        vb: ShardedVarBuilder,
3099        normal_loading_metadata: NormalLoadingMetadata,
3100        attention_mechanism: AttentionImplementation,
3101    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3102        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3103
3104        Ok(Box::new(models::qwen3::Model::new(
3105            &cfg,
3106            vb,
3107            self.is_gptx(config)?,
3108            normal_loading_metadata,
3109            attention_mechanism,
3110        )?))
3111    }
3112    fn load_xlora(
3113        &self,
3114        _config: &str,
3115        _vb: ShardedVarBuilder,
3116        _lora_config: &[((String, String), LoraConfig)],
3117        _xlora_config: Option<XLoraConfig>,
3118        _xlora_ordering: Ordering,
3119        _normal_loading_metadata: NormalLoadingMetadata,
3120        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3121    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3122        todo!()
3123    }
3124    fn is_gptx(&self, _: &str) -> Result<bool> {
3125        Ok(true)
3126    }
3127    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3128        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3129
3130        Ok(Box::new(cfg))
3131    }
3132}
3133
3134impl IsqModelLoader for Qwen3Loader {
3135    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3136        Ok(vec![
3137            Regex::new(r"lm_head\.(weight|bias)$")?,
3138            // Attention
3139            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3140            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3141            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3142            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3143            // MLP
3144            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3145            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3146            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3147        ])
3148    }
3149    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3150        self.isq_layer_regexes(config)
3151    }
3152}
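
// Illustrative check (with made-up tensor paths) that the Qwen 3 ISQ regexes above cover both
// the weight and bias of each attention and MLP projection, while leaving embeddings alone.
#[cfg(test)]
mod qwen3_isq_regex_examples {
    use regex::Regex;

    #[test]
    fn attention_and_mlp_regexes_match_weight_and_bias() {
        let q_proj = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$").unwrap();
        assert!(q_proj.is_match("model.layers.0.self_attn.q_proj.weight"));
        assert!(q_proj.is_match("model.layers.0.self_attn.q_proj.bias"));

        let down_proj = Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$").unwrap();
        assert!(down_proj.is_match("model.layers.27.mlp.down_proj.weight"));
        // Embedding tables are intentionally not covered by these patterns.
        assert!(!down_proj.is_match("model.embed_tokens.weight"));
    }
}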
3153
3154impl DeviceMappedModelLoader for Qwen3Loader {
3155    fn mapped_max_act_size_elems(
3156        &self,
3157        config: &str,
3158        params: &AutoDeviceMapParams,
3159    ) -> Result<usize> {
3160        let AutoDeviceMapParams::Text {
3161            max_seq_len,
3162            max_batch_size,
3163        } = params
3164        else {
3165            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3166        };
3167
3168        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3169
3170        Ok(
3171            max_batch_size
3172                * cfg.num_attention_heads
3173                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3174        )
3175    }
3176    fn non_mapped_max_act_size_elems(
3177        &self,
3178        _config: &str,
3179        _params: &AutoDeviceMapParams,
3180    ) -> Result<usize> {
3181        Ok(0)
3182    }
3183
3184    fn non_mapped_size_in_bytes(
3185        &self,
3186        config: &str,
3187        dtype: DType,
3188        weight_pack_factor: usize,
3189        _matformer_config: Option<&MatformerSliceConfig>,
3190    ) -> Result<usize> {
3191        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3192        let elems = {
3193            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3194            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3195            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3196                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3197            } else {
3198                0
3199            };
3200            let norm = cfg.hidden_size;
3201            embed_tokens + lm_head + norm
3202        };
3203        Ok(elems * dtype.size_in_bytes())
3204    }
3205
3206    fn layer_sizes_in_bytes(
3207        &self,
3208        config: &str,
3209        dtype: DType,
3210        weight_pack_factor: usize,
3211        _matformer_config: Option<&MatformerSliceConfig>,
3212    ) -> Result<Vec<usize>> {
3213        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3214        let per_layer_elems = {
3215            let input_layernorm = cfg.hidden_size;
3216            let post_attention_layernorm = cfg.hidden_size;
3217
3218            let size_in = cfg.hidden_size;
3219            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3220            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3221            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3222            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3223            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3224            let o_proj = size_q * size_in / weight_pack_factor;
3225
3226            let h_size = cfg.hidden_size;
3227            let i_size = cfg.intermediate_size;
3228            let gate_proj = h_size * i_size / weight_pack_factor;
3229            let up_proj = h_size * i_size / weight_pack_factor;
3230            let down_proj = i_size * h_size / weight_pack_factor;
3231
3232            let q_norm = cfg.head_dim();
3233            let k_norm = cfg.head_dim();
3234
3235            input_layernorm
3236                + post_attention_layernorm
3237                + q_proj
3238                + k_proj
3239                + v_proj
3240                + o_proj
3241                + gate_proj
3242                + up_proj
3243                + down_proj
3244                + q_norm
3245                + k_norm
3246        };
3247        Ok(vec![
3248            per_layer_elems * dtype.size_in_bytes();
3249            cfg.num_hidden_layers
3250        ])
3251    }
3252
3253    fn num_layers(&self, config: &str) -> Result<usize> {
3254        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3255        Ok(cfg.num_hidden_layers)
3256    }
3257
3258    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3259        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3260
3261        let cfg = ModelConfigMetadata {
3262            max_seq_len: cfg.max_position_embeddings,
3263            num_layers: cfg.num_hidden_layers,
3264            hidden_size: cfg.hidden_size,
3265            num_kv_heads: cfg.num_key_value_heads,
3266            num_attn_heads: cfg.num_attention_heads,
3267            sliding_window: cfg.sliding_window,
3268            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3269            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3270            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3271        };
3272
3273        Ok(Box::new(cfg))
3274    }
3275}
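
// A self-contained sketch of the per-layer element count computed in `layer_sizes_in_bytes`
// above, using toy dimensions rather than a real Qwen 3 configuration. `weight_pack_factor`
// stands in for quantized weight packing.
#[cfg(test)]
mod qwen3_layer_size_example {
    #[test]
    fn per_layer_elems_for_toy_dims() {
        let (hidden_size, intermediate_size) = (64usize, 128usize);
        let (num_attention_heads, num_key_value_heads, head_dim) = (8usize, 2usize, 8usize);
        let weight_pack_factor = 1usize;

        let size_q = head_dim * num_attention_heads; // 64
        let size_kv = head_dim * num_key_value_heads; // 16
        let q_proj = hidden_size * size_q / weight_pack_factor + size_q;
        let k_proj = hidden_size * size_kv / weight_pack_factor + size_kv;
        let v_proj = hidden_size * size_kv / weight_pack_factor + size_kv;
        let o_proj = size_q * hidden_size / weight_pack_factor;
        let mlp = 3 * hidden_size * intermediate_size / weight_pack_factor;
        let norms = 2 * hidden_size + 2 * head_dim; // input/post-attn norms plus q/k norms

        let per_layer = q_proj + k_proj + v_proj + o_proj + mlp + norms;
        assert_eq!(per_layer, 4160 + 1040 + 1040 + 4096 + 24576 + 144);
    }
}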
3276
3277/// [`NormalLoader`] for a GLM 4 model.
3278///
3279/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3280pub struct GLM4Loader;
3281
3282impl NormalModelLoader for GLM4Loader {
3283    fn load(
3284        &self,
3285        config: &str,
3286        vb: ShardedVarBuilder,
3287        normal_loading_metadata: NormalLoadingMetadata,
3288        attention_mechanism: AttentionImplementation,
3289    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3290        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3291
3292        Ok(Box::new(models::glm4::Model::new(
3293            &cfg,
3294            vb,
3295            self.is_gptx(config)?,
3296            normal_loading_metadata,
3297            attention_mechanism,
3298        )?))
3299    }
3300    fn load_xlora(
3301        &self,
3302        _config: &str,
3303        _vb: ShardedVarBuilder,
3304        _lora_config: &[((String, String), LoraConfig)],
3305        _xlora_config: Option<XLoraConfig>,
3306        _xlora_ordering: Ordering,
3307        _normal_loading_metadata: NormalLoadingMetadata,
3308        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3309    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3310        todo!()
3311    }
3312    fn is_gptx(&self, _: &str) -> Result<bool> {
3313        Ok(true)
3314    }
3315    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3316        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3317
3318        Ok(Box::new(cfg))
3319    }
3320}
3321
3322impl IsqModelLoader for GLM4Loader {
3323    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3324        Ok(vec![
3325            Regex::new(r"lm_head\.(weight|bias)$")?,
3326            // Attention
3327            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3328            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3329            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3330            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3331            // MLP
3332            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3333            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3334            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3335        ])
3336    }
3337    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3338        self.isq_layer_regexes(config)
3339    }
3340}
3341
3342impl DeviceMappedModelLoader for GLM4Loader {
3343    fn mapped_max_act_size_elems(
3344        &self,
3345        config: &str,
3346        params: &AutoDeviceMapParams,
3347    ) -> Result<usize> {
3348        let AutoDeviceMapParams::Text {
3349            max_seq_len,
3350            max_batch_size,
3351        } = params
3352        else {
3353            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3354        };
3355
3356        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3357
3358        Ok(
3359            max_batch_size
3360                * cfg.num_attention_heads
3361                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3362        )
3363    }
3364    fn non_mapped_max_act_size_elems(
3365        &self,
3366        _config: &str,
3367        _params: &AutoDeviceMapParams,
3368    ) -> Result<usize> {
3369        Ok(0)
3370    }
3371
3372    fn non_mapped_size_in_bytes(
3373        &self,
3374        config: &str,
3375        dtype: DType,
3376        weight_pack_factor: usize,
3377        _matformer_config: Option<&MatformerSliceConfig>,
3378    ) -> Result<usize> {
3379        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3380        let elems = {
3381            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3382            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3383            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3384                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3385            } else {
3386                0
3387            };
3388            let norm = cfg.hidden_size;
3389            embed_tokens + lm_head + norm
3390        };
3391        Ok(elems * dtype.size_in_bytes())
3392    }
3393
3394    fn layer_sizes_in_bytes(
3395        &self,
3396        config: &str,
3397        dtype: DType,
3398        weight_pack_factor: usize,
3399        _matformer_config: Option<&MatformerSliceConfig>,
3400    ) -> Result<Vec<usize>> {
3401        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3402        let per_layer_elems = {
3403            let input_layernorm = cfg.hidden_size;
3404            let post_attention_layernorm = cfg.hidden_size * 3; // includes post_self_attn_layernorm and post_mlp_layernorm
3405
3406            let size_in = cfg.hidden_size;
3407            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3408            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3409            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3410            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3411            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3412            let o_proj = size_q * size_in / weight_pack_factor;
3413
3414            let h_size = cfg.hidden_size;
3415            let i_size = cfg.intermediate_size;
3416            let gate_proj = h_size * i_size / weight_pack_factor;
3417            let up_proj = h_size * i_size / weight_pack_factor;
3418            let down_proj = i_size * h_size / weight_pack_factor;
3419
3420            input_layernorm
3421                + post_attention_layernorm
3422                + q_proj
3423                + k_proj
3424                + v_proj
3425                + o_proj
3426                + gate_proj
3427                + up_proj
3428                + down_proj
3429        };
3430        Ok(vec![
3431            per_layer_elems * dtype.size_in_bytes();
3432            cfg.num_hidden_layers
3433        ])
3434    }
3435
3436    fn num_layers(&self, config: &str) -> Result<usize> {
3437        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3438        Ok(cfg.num_hidden_layers)
3439    }
3440
3441    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3442        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3443
3444        let cfg = ModelConfigMetadata {
3445            max_seq_len: cfg.max_position_embeddings,
3446            num_layers: cfg.num_hidden_layers,
3447            hidden_size: cfg.hidden_size,
3448            num_kv_heads: cfg.num_key_value_heads,
3449            num_attn_heads: cfg.num_attention_heads,
3450            sliding_window: cfg.sliding_window,
3451            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3452            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3453            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3454        };
3455
3456        Ok(Box::new(cfg))
3457    }
3458}
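
// Quick arithmetic note (toy number) on the norm accounting above: GLM 4 layers carry four RMS
// norms per layer (input, post-attention, post-self-attn and post-mlp), which is why
// `post_attention_layernorm` is estimated as three times the hidden size.
#[cfg(test)]
mod glm4_norm_accounting_example {
    #[test]
    fn four_norms_per_layer() {
        let hidden_size = 4096usize; // illustrative only
        let input_layernorm = hidden_size;
        let post_attention_layernorm = hidden_size * 3;
        assert_eq!(input_layernorm + post_attention_layernorm, 4 * hidden_size);
    }
}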
3459
3460/// [`NormalLoader`] for a GLM 4 MoE Lite model (GLM-4.7-Flash).
3461///
3462/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3463pub struct GLM4MoeLiteLoader;
3464
3465impl NormalModelLoader for GLM4MoeLiteLoader {
3466    fn load(
3467        &self,
3468        config: &str,
3469        vb: ShardedVarBuilder,
3470        normal_loading_metadata: NormalLoadingMetadata,
3471        attention_mechanism: AttentionImplementation,
3472    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3473        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3474        Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3475            &cfg,
3476            vb,
3477            self.is_gptx(config)?,
3478            normal_loading_metadata,
3479            attention_mechanism,
3480        )?))
3481    }
3482    fn load_xlora(
3483        &self,
3484        _config: &str,
3485        _vb: ShardedVarBuilder,
3486        _lora_config: &[((String, String), LoraConfig)],
3487        _xlora_config: Option<XLoraConfig>,
3488        _xlora_ordering: Ordering,
3489        _normal_loading_metadata: NormalLoadingMetadata,
3490        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3491    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3492        todo!()
3493    }
3494    fn is_gptx(&self, _: &str) -> Result<bool> {
3495        Ok(true)
3496    }
3497    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3498        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3499        Ok(Box::new(cfg))
3500    }
3501}
3502
3503impl IsqModelLoader for GLM4MoeLiteLoader {
3504    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3505        let mut data = vec![
3506            Regex::new(r"lm_head\.(weight|bias)$")?,
3507            // Attention (MLA)
3508            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3509            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3510            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3511            // Q LoRA projections
3512            Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3513            Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3514        ];
3515        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3516        for layer_idx in 0..cfg.num_hidden_layers {
3517            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3518                // MoE layer
3519                for i in 0..cfg.n_routed_experts {
3520                    data.extend(vec![
3521                        Regex::new(&format!(
3522                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3523                        ))?,
3524                        Regex::new(&format!(
3525                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3526                        ))?,
3527                        Regex::new(&format!(
3528                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3529                        ))?,
3530                    ]);
3531                }
3532                if cfg.n_shared_experts > 0 {
3533                    data.extend(vec![
3534                        Regex::new(&format!(
3535                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3536                        ))?,
3537                        Regex::new(&format!(
3538                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3539                        ))?,
3540                        Regex::new(&format!(
3541                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3542                        ))?,
3543                    ]);
3544                }
3545            } else {
3546                // Dense MLP layer
3547                data.extend(vec![
3548                    Regex::new(&format!(
3549                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3550                    ))?,
3551                    Regex::new(&format!(
3552                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3553                    ))?,
3554                    Regex::new(&format!(
3555                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3556                    ))?,
3557                ]);
3558            };
3559        }
3560        Ok(data)
3561    }
3562    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3563        self.isq_layer_regexes(config)
3564    }
3565
3566    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3567        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3568        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3569        for layer_idx in 0..cfg.num_hidden_layers {
3570            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3571                // MoE layer
3572                for i in 0..cfg.n_routed_experts {
3573                    data.extend(vec![
3574                        Regex::new(&format!(
3575                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3576                        ))?,
3577                        Regex::new(&format!(
3578                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3579                        ))?,
3580                        Regex::new(&format!(
3581                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3582                        ))?,
3583                    ]);
3584                }
3585                if cfg.n_shared_experts > 0 {
3586                    data.extend(vec![
3587                        Regex::new(&format!(
3588                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3589                        ))?,
3590                        Regex::new(&format!(
3591                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3592                        ))?,
3593                        Regex::new(&format!(
3594                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3595                        ))?,
3596                    ]);
3597                }
3598            } else {
3599                // Dense MLP layer
3600                data.extend(vec![
3601                    Regex::new(&format!(
3602                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3603                    ))?,
3604                    Regex::new(&format!(
3605                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3606                    ))?,
3607                    Regex::new(&format!(
3608                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3609                    ))?,
3610                ]);
3611            };
3612        }
3613        Ok(data)
3614    }
3615    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3616        self.isq_layer_regexes_moqe(config)
3617    }
3618}
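
// Sketch of the per-layer, per-expert patterns generated above: `layer_idx` and the expert
// index are formatted directly into the regex, so each pattern only matches tensors belonging
// to that specific expert. The tensor names are illustrative.
#[cfg(test)]
mod glm4_moe_lite_isq_regex_examples {
    use regex::Regex;

    #[test]
    fn expert_regex_is_specific_to_layer_and_expert() {
        let layer_idx = 5usize;
        let i = 3usize;
        let re = Regex::new(&format!(
            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
        ))
        .unwrap();
        assert!(re.is_match("model.layers.5.mlp.experts.3.gate_proj.weight"));
        // A different expert (or layer) does not match this pattern.
        assert!(!re.is_match("model.layers.5.mlp.experts.4.gate_proj.weight"));
    }
}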
3619
3620impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3621    fn mapped_max_act_size_elems(
3622        &self,
3623        config: &str,
3624        params: &AutoDeviceMapParams,
3625    ) -> Result<usize> {
3626        let AutoDeviceMapParams::Text {
3627            max_seq_len,
3628            max_batch_size,
3629        } = params
3630        else {
3631            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3632        };
3633
3634        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3635
3636        Ok(
3637            max_batch_size
3638                * cfg.num_attention_heads
3639                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3640        )
3641    }
3642    fn non_mapped_max_act_size_elems(
3643        &self,
3644        _config: &str,
3645        _params: &AutoDeviceMapParams,
3646    ) -> Result<usize> {
3647        Ok(0)
3648    }
3649
3650    fn non_mapped_size_in_bytes(
3651        &self,
3652        config: &str,
3653        dtype: DType,
3654        weight_pack_factor: usize,
3655        _matformer_config: Option<&MatformerSliceConfig>,
3656    ) -> Result<usize> {
3657        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3658        let elems = {
3659            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3660            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3661            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3662                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3663            } else {
3664                0
3665            };
3666            let norm = cfg.hidden_size;
3667            embed_tokens + lm_head + norm
3668        };
3669        Ok(elems * dtype.size_in_bytes())
3670    }
3671
3672    fn layer_sizes_in_bytes(
3673        &self,
3674        config: &str,
3675        dtype: DType,
3676        weight_pack_factor: usize,
3677        _matformer_config: Option<&MatformerSliceConfig>,
3678    ) -> Result<Vec<usize>> {
3679        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3680        let mut per_layer_elems = Vec::new();
3681
3682        for layer_idx in 0..cfg.num_hidden_layers {
3683            let input_layernorm = cfg.hidden_size;
3684            let post_attention_layernorm = cfg.hidden_size;
3685
3686            // Q LoRA projection
3687            let q_proj = {
3688                let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3689                let norm = cfg.q_lora_rank;
3690                let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3691                    / weight_pack_factor;
3692                a + norm + b
3693            };
3694            let kv_a_proj_with_mqa =
3695                cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3696            let kv_a_layernorm = cfg.kv_lora_rank;
3697            let kv_b_proj = cfg.kv_lora_rank
3698                * cfg.num_attention_heads
3699                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3700                / weight_pack_factor;
3701            let o_proj =
3702                cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3703
3704            let moe_block = {
3705                let mut sum = 0;
3706                if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3707                    // MoE layer
3708                    let h_size = cfg.hidden_size;
3709                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3710                        * cfg.n_routed_experts;
3711                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3712                        * cfg.n_routed_experts;
3713                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3714                        * cfg.n_routed_experts;
3715                    let shared_experts = if cfg.n_shared_experts > 0 {
3716                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3717                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3718                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3719                        gate_proj + up_proj + down_proj
3720                    } else {
3721                        0
3722                    };
3723                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3724                    let e_score_correction_bias = cfg.n_routed_experts;
3725                    sum += gate_proj
3726                        + up_proj
3727                        + down_proj
3728                        + shared_experts
3729                        + gate_weight
3730                        + e_score_correction_bias;
3731                } else {
3732                    // Dense MLP layer
3733                    let h_size = cfg.hidden_size;
3734                    let i_size = cfg.intermediate_size;
3735                    let gate_proj = h_size * i_size / weight_pack_factor;
3736                    let up_proj = h_size * i_size / weight_pack_factor;
3737                    let down_proj = i_size * h_size / weight_pack_factor;
3738                    sum += gate_proj + up_proj + down_proj;
3739                }
3740                sum
3741            };
3742
3743            per_layer_elems.push(
3744                input_layernorm
3745                    + post_attention_layernorm
3746                    + q_proj
3747                    + kv_a_layernorm
3748                    + kv_a_proj_with_mqa
3749                    + kv_b_proj
3750                    + o_proj
3751                    + moe_block,
3752            );
3753        }
3754
3755        Ok(per_layer_elems
3756            .into_iter()
3757            .map(|x| x * dtype.size_in_bytes())
3758            .collect())
3759    }
3760
3761    fn num_layers(&self, config: &str) -> Result<usize> {
3762        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3763        Ok(cfg.num_hidden_layers)
3764    }
3765
3766    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3767        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3768
3769        let cfg = ModelConfigMetadata {
3770            max_seq_len: cfg.max_position_embeddings,
3771            num_layers: cfg.num_hidden_layers,
3772            hidden_size: cfg.hidden_size,
3773            num_kv_heads: cfg.num_attention_heads,
3774            num_attn_heads: cfg.num_attention_heads,
3775            sliding_window: None,
3776            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3777            v_head_dim: cfg.v_head_dim,
3778            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3779        };
3780
3781        Ok(Box::new(cfg))
3782    }
3783}
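
// Worked example (made-up dimensions) of the MoE block accounting above: every routed expert
// contributes its own gate/up/down projections, shared experts add one more set, and the router
// adds a gate weight plus a per-expert score-correction bias.
#[cfg(test)]
mod glm4_moe_lite_moe_sizing_example {
    #[test]
    fn moe_block_elem_count() {
        let (hidden_size, moe_intermediate_size) = (64usize, 32usize);
        let (n_routed_experts, n_shared_experts) = (4usize, 1usize);
        let weight_pack_factor = 1usize;

        let per_expert = 3 * hidden_size * moe_intermediate_size / weight_pack_factor;
        let routed = per_expert * n_routed_experts;
        let shared = if n_shared_experts > 0 { per_expert } else { 0 };
        let gate_weight = n_routed_experts * hidden_size;
        let e_score_correction_bias = n_routed_experts;

        let moe_block = routed + shared + gate_weight + e_score_correction_bias;
        assert_eq!(moe_block, 4 * 6144 + 6144 + 256 + 4);
    }
}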
3784
3785/// [`NormalLoader`] for a GLM 4 MoE model (GLM-4.5).
3786///
3787/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3788pub struct GLM4MoeLoader;
3789
3790impl NormalModelLoader for GLM4MoeLoader {
3791    fn load(
3792        &self,
3793        config: &str,
3794        vb: ShardedVarBuilder,
3795        normal_loading_metadata: NormalLoadingMetadata,
3796        attention_mechanism: AttentionImplementation,
3797    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3798        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3799        Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3800            &cfg,
3801            vb,
3802            self.is_gptx(config)?,
3803            normal_loading_metadata,
3804            attention_mechanism,
3805        )?))
3806    }
3807    fn load_xlora(
3808        &self,
3809        _config: &str,
3810        _vb: ShardedVarBuilder,
3811        _lora_config: &[((String, String), LoraConfig)],
3812        _xlora_config: Option<XLoraConfig>,
3813        _xlora_ordering: Ordering,
3814        _normal_loading_metadata: NormalLoadingMetadata,
3815        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3816    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3817        todo!()
3818    }
3819    fn is_gptx(&self, _: &str) -> Result<bool> {
3820        Ok(true)
3821    }
3822    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3823        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3824        Ok(Box::new(cfg))
3825    }
3826}
3827
3828impl IsqModelLoader for GLM4MoeLoader {
3829    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3830        let mut data = vec![
3831            Regex::new(r"lm_head\.(weight|bias)$")?,
3832            // Attention (standard GQA)
3833            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3834            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3835            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3836            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3837        ];
3838        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3839        for layer_idx in 0..cfg.num_hidden_layers {
3840            if layer_idx >= cfg.first_k_dense_replace {
3841                // MoE layer
3842                for i in 0..cfg.n_routed_experts {
3843                    data.extend(vec![
3844                        Regex::new(&format!(
3845                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3846                        ))?,
3847                        Regex::new(&format!(
3848                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3849                        ))?,
3850                        Regex::new(&format!(
3851                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3852                        ))?,
3853                    ]);
3854                }
3855                if cfg.n_shared_experts > 0 {
3856                    data.extend(vec![
3857                        Regex::new(&format!(
3858                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3859                        ))?,
3860                        Regex::new(&format!(
3861                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3862                        ))?,
3863                        Regex::new(&format!(
3864                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3865                        ))?,
3866                    ]);
3867                }
3868            } else {
3869                // Dense MLP layer
3870                data.extend(vec![
3871                    Regex::new(&format!(
3872                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3873                    ))?,
3874                    Regex::new(&format!(
3875                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3876                    ))?,
3877                    Regex::new(&format!(
3878                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3879                    ))?,
3880                ]);
3881            };
3882        }
3883        Ok(data)
3884    }
3885    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3886        self.isq_layer_regexes(config)
3887    }
3888
3889    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3890        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3891        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3892        for layer_idx in 0..cfg.num_hidden_layers {
3893            if layer_idx >= cfg.first_k_dense_replace {
3894                // MoE layer
3895                for i in 0..cfg.n_routed_experts {
3896                    data.extend(vec![
3897                        Regex::new(&format!(
3898                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3899                        ))?,
3900                        Regex::new(&format!(
3901                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3902                        ))?,
3903                        Regex::new(&format!(
3904                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3905                        ))?,
3906                    ]);
3907                }
3908                if cfg.n_shared_experts > 0 {
3909                    data.extend(vec![
3910                        Regex::new(&format!(
3911                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3912                        ))?,
3913                        Regex::new(&format!(
3914                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3915                        ))?,
3916                        Regex::new(&format!(
3917                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3918                        ))?,
3919                    ]);
3920                }
3921            } else {
3922                // Dense MLP layer
3923                data.extend(vec![
3924                    Regex::new(&format!(
3925                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3926                    ))?,
3927                    Regex::new(&format!(
3928                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3929                    ))?,
3930                    Regex::new(&format!(
3931                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3932                    ))?,
3933                ]);
3934            };
3935        }
3936        Ok(data)
3937    }
3938    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3939        self.isq_layer_regexes_moqe(config)
3940    }
3941}
3942
3943impl DeviceMappedModelLoader for GLM4MoeLoader {
3944    fn mapped_max_act_size_elems(
3945        &self,
3946        config: &str,
3947        params: &AutoDeviceMapParams,
3948    ) -> Result<usize> {
3949        let AutoDeviceMapParams::Text {
3950            max_seq_len,
3951            max_batch_size,
3952        } = params
3953        else {
3954            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3955        };
3956
3957        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3958
3959        Ok(
3960            max_batch_size
3961                * cfg.num_attention_heads
3962                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3963        )
3964    }
3965    fn non_mapped_max_act_size_elems(
3966        &self,
3967        _config: &str,
3968        _params: &AutoDeviceMapParams,
3969    ) -> Result<usize> {
3970        Ok(0)
3971    }
3972
3973    fn non_mapped_size_in_bytes(
3974        &self,
3975        config: &str,
3976        dtype: DType,
3977        weight_pack_factor: usize,
3978        _matformer_config: Option<&MatformerSliceConfig>,
3979    ) -> Result<usize> {
3980        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3981        let elems = {
3982            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3983            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3984                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3985            } else {
3986                0
3987            };
3988            let norm = cfg.hidden_size;
3989            embed_tokens + lm_head + norm
3990        };
3991        Ok(elems * dtype.size_in_bytes())
3992    }
3993
3994    fn layer_sizes_in_bytes(
3995        &self,
3996        config: &str,
3997        dtype: DType,
3998        weight_pack_factor: usize,
3999        _matformer_config: Option<&MatformerSliceConfig>,
4000    ) -> Result<Vec<usize>> {
4001        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4002        let mut per_layer_elems = Vec::new();
4003
4004        let head_dim = cfg.head_dim();
4005        for layer_idx in 0..cfg.num_hidden_layers {
4006            let input_layernorm = cfg.hidden_size;
4007            let post_attention_layernorm = cfg.hidden_size;
4008
4009            // Standard GQA attention
4010            let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
4011                + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
4012            let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4013                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4014            let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4015                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4016            let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4017
4018            // QK norm if enabled
4019            let qk_norm = if cfg.use_qk_norm {
4020                head_dim * 2 // q_norm + k_norm
4021            } else {
4022                0
4023            };
4024
4025            let moe_block = {
4026                let mut sum = 0;
4027                if layer_idx >= cfg.first_k_dense_replace {
4028                    // MoE layer
4029                    let h_size = cfg.hidden_size;
4030                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4031                        * cfg.n_routed_experts;
4032                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4033                        * cfg.n_routed_experts;
4034                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4035                        * cfg.n_routed_experts;
4036                    let shared_experts = if cfg.n_shared_experts > 0 {
4037                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4038                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4039                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4040                        gate_proj + up_proj + down_proj
4041                    } else {
4042                        0
4043                    };
4044                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4045                    let e_score_correction_bias = cfg.n_routed_experts;
4046                    sum += gate_proj
4047                        + up_proj
4048                        + down_proj
4049                        + shared_experts
4050                        + gate_weight
4051                        + e_score_correction_bias;
4052                } else {
4053                    // Dense MLP layer
4054                    let h_size = cfg.hidden_size;
4055                    let i_size = cfg.intermediate_size;
4056                    let gate_proj = h_size * i_size / weight_pack_factor;
4057                    let up_proj = h_size * i_size / weight_pack_factor;
4058                    let down_proj = i_size * h_size / weight_pack_factor;
4059                    sum += gate_proj + up_proj + down_proj;
4060                }
4061                sum
4062            };
4063
4064            per_layer_elems.push(
4065                input_layernorm
4066                    + post_attention_layernorm
4067                    + q_proj
4068                    + k_proj
4069                    + v_proj
4070                    + o_proj
4071                    + qk_norm
4072                    + moe_block,
4073            );
4074        }
4075
4076        Ok(per_layer_elems
4077            .into_iter()
4078            .map(|x| x * dtype.size_in_bytes())
4079            .collect())
4080    }
4081
4082    fn num_layers(&self, config: &str) -> Result<usize> {
4083        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4084        Ok(cfg.num_hidden_layers)
4085    }
4086
4087    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4088        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4089
4090        let head_dim = cfg.head_dim();
4091        let cfg = ModelConfigMetadata {
4092            max_seq_len: cfg.max_position_embeddings,
4093            num_layers: cfg.num_hidden_layers,
4094            hidden_size: cfg.hidden_size,
4095            num_kv_heads: cfg.num_key_value_heads,
4096            num_attn_heads: cfg.num_attention_heads,
4097            sliding_window: None,
4098            k_head_dim: head_dim,
4099            v_head_dim: head_dim,
4100            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4101        };
4102
4103        Ok(Box::new(cfg))
4104    }
4105}
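
// Toy-number sketch of the GQA projection sizing above. The `bias_if!` helper is modelled here
// with a plain closure, and the dimensions are invented rather than taken from a GLM 4 MoE
// release.
#[cfg(test)]
mod glm4_moe_gqa_sizing_example {
    #[test]
    fn qkv_o_elem_counts_with_optional_bias() {
        let (hidden_size, head_dim) = (64usize, 8usize);
        let (num_attention_heads, num_key_value_heads) = (8usize, 2usize);
        let (weight_pack_factor, attention_bias) = (1usize, true);

        let bias = |cond: bool, n: usize| if cond { n } else { 0 };
        let q_proj = hidden_size * num_attention_heads * head_dim / weight_pack_factor
            + bias(attention_bias, num_attention_heads * head_dim);
        let k_proj = hidden_size * num_key_value_heads * head_dim / weight_pack_factor
            + bias(attention_bias, num_key_value_heads * head_dim);
        let o_proj = num_attention_heads * head_dim * hidden_size / weight_pack_factor;

        assert_eq!(q_proj, 64 * 64 + 64);
        assert_eq!(k_proj, 64 * 16 + 16);
        assert_eq!(o_proj, 64 * 64); // o_proj carries no bias in this estimate
    }
}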
4106
4107/// [`NormalLoader`] for a Qwen 3 MoE model.
4108///
4109/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4110pub struct Qwen3MoELoader;
4111
4112impl NormalModelLoader for Qwen3MoELoader {
4113    fn load(
4114        &self,
4115        config: &str,
4116        vb: ShardedVarBuilder,
4117        normal_loading_metadata: NormalLoadingMetadata,
4118        attention_mechanism: AttentionImplementation,
4119    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4120        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4121
4122        Ok(Box::new(models::qwen3_moe::Model::new(
4123            &cfg,
4124            vb,
4125            self.is_gptx(config)?,
4126            normal_loading_metadata,
4127            attention_mechanism,
4128        )?))
4129    }
4130    fn load_xlora(
4131        &self,
4132        _config: &str,
4133        _vb: ShardedVarBuilder,
4134        _lora_config: &[((String, String), LoraConfig)],
4135        _xlora_config: Option<XLoraConfig>,
4136        _xlora_ordering: Ordering,
4137        _normal_loading_metadata: NormalLoadingMetadata,
4138        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4139    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4140        todo!()
4141    }
4142    fn is_gptx(&self, _: &str) -> Result<bool> {
4143        Ok(true)
4144    }
4145    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4146        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4147
4148        Ok(Box::new(cfg))
4149    }
4150}
4151
4152impl IsqModelLoader for Qwen3MoELoader {
4153    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4154        Ok(vec![
4155            Regex::new(r"lm_head\.(weight|bias)$")?,
4156            // Attention
4157            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4158            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4159            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4160            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4161            // MLP
4162            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4163            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4164            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4165            // MLP MoE
4166            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4167            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4168            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4169        ])
4170    }
4171    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4172        self.isq_layer_regexes(config)
4173    }
4174    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4175        self.isq_layer_regexes_moqe(config)
4176    }
4177}
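
// Hedged sanity sketch (not part of the loader API): the ISQ regexes above are
// applied to fully qualified tensor names, so a routed-expert projection like
// `model.layers.3.mlp.experts.17.down_proj.weight` should be picked up while
// the router itself (`mlp.gate.weight`) should not. The tensor names here are
// illustrative assumptions rather than names read from a real checkpoint.
#[cfg(test)]
mod qwen3_moe_isq_regex_sketch {
    use regex::Regex;

    #[test]
    fn expert_projection_matches_but_router_does_not() {
        let re =
            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$").unwrap();
        assert!(re.is_match("model.layers.3.mlp.experts.17.down_proj.weight"));
        assert!(!re.is_match("model.layers.3.mlp.gate.weight"));
    }
}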
4178
4179impl DeviceMappedModelLoader for Qwen3MoELoader {
4180    fn mapped_max_act_size_elems(
4181        &self,
4182        config: &str,
4183        params: &AutoDeviceMapParams,
4184    ) -> Result<usize> {
4185        let AutoDeviceMapParams::Text {
4186            max_seq_len,
4187            max_batch_size,
4188        } = params
4189        else {
4190            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4191        };
4192
4193        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4194
4195        Ok(
4196            max_batch_size
4197                * cfg.num_attention_heads
4198                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4199        )
4200    }
4201    fn non_mapped_max_act_size_elems(
4202        &self,
4203        _config: &str,
4204        _params: &AutoDeviceMapParams,
4205    ) -> Result<usize> {
4206        Ok(0)
4207    }
4208
4209    fn non_mapped_size_in_bytes(
4210        &self,
4211        config: &str,
4212        dtype: DType,
4213        weight_pack_factor: usize,
4214        _matformer_config: Option<&MatformerSliceConfig>,
4215    ) -> Result<usize> {
4216        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4217        let elems = {
4218            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4219            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4220            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4221                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4222            } else {
4223                0
4224            };
4225            let norm = cfg.hidden_size;
4226            embed_tokens + lm_head + norm
4227        };
4228        Ok(elems * dtype.size_in_bytes())
4229    }
4230
4231    fn layer_sizes_in_bytes(
4232        &self,
4233        config: &str,
4234        dtype: DType,
4235        weight_pack_factor: usize,
4236        _matformer_config: Option<&MatformerSliceConfig>,
4237    ) -> Result<Vec<usize>> {
4238        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4239
4240        let mut layer_sizes_in_bytes = Vec::new();
4241        for layer_idx in 0..cfg.num_hidden_layers {
4242            let input_layernorm = cfg.hidden_size;
4243            let post_attention_layernorm = cfg.hidden_size;
4244
4245            let size_in = cfg.hidden_size;
4246            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4247            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4248            let q_proj = size_in * size_q / weight_pack_factor;
4249            let k_proj = size_in * size_kv / weight_pack_factor;
4250            let v_proj = size_in * size_kv / weight_pack_factor;
4251            let o_proj = size_q * size_in / weight_pack_factor;
4252
4253            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4254                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4255            {
4256                let gate_size = cfg.hidden_size * cfg.num_experts;
4257                let expert_size = {
4258                    let h_size = cfg.hidden_size;
4259                    let i_size = cfg.moe_intermediate_size;
4260                    let gate_proj = h_size * i_size / weight_pack_factor;
4261                    let up_proj = h_size * i_size / weight_pack_factor;
4262                    let down_proj = i_size * h_size / weight_pack_factor;
4263                    gate_proj + up_proj + down_proj
4264                };
4265                expert_size * cfg.num_experts + gate_size
4266            } else {
4267                let h_size = cfg.hidden_size;
4268                let i_size = cfg.intermediate_size;
4269                let gate_proj = h_size * i_size / weight_pack_factor;
4270                let up_proj = h_size * i_size / weight_pack_factor;
4271                let down_proj = i_size * h_size / weight_pack_factor;
4272                gate_proj + up_proj + down_proj
4273            };
4274
4275            let q_norm = cfg.head_dim();
4276            let k_norm = cfg.head_dim();
4277
4278            let size_elems = input_layernorm
4279                + post_attention_layernorm
4280                + q_proj
4281                + k_proj
4282                + v_proj
4283                + o_proj
4284                + mlp_size
4285                + q_norm
4286                + k_norm;
4287
4288            let size_in_bytes = size_elems * dtype.size_in_bytes();
4289            layer_sizes_in_bytes.push(size_in_bytes);
4290        }
4291
4292        Ok(layer_sizes_in_bytes)
4293    }
4294
4295    fn num_layers(&self, config: &str) -> Result<usize> {
4296        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4297        Ok(cfg.num_hidden_layers)
4298    }
4299
4300    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4301        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4302
4303        let cfg = ModelConfigMetadata {
4304            max_seq_len: cfg.max_position_embeddings,
4305            num_layers: cfg.num_hidden_layers,
4306            hidden_size: cfg.hidden_size,
4307            num_kv_heads: cfg.num_key_value_heads,
4308            num_attn_heads: cfg.num_attention_heads,
4309            sliding_window: cfg.sliding_window,
4310            k_head_dim: cfg.head_dim(),
4311            v_head_dim: cfg.head_dim(),
4312            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4313        };
4314
4315        Ok(Box::new(cfg))
4316    }
4317}
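
// Two hedged sketches of the sizing logic in the impl above; all concrete
// numbers are illustrative assumptions, not values from a shipped config.
// 1. The mapped activation bound is dominated by the attention score matrix:
//    batch * heads * min(seq_len, chunk)^2 elements.
// 2. A decoder layer gets a MoE block only when it is not listed in
//    `mlp_only_layers`, experts exist, and (layer_idx + 1) is a multiple of
//    `decoder_sparse_step`.
#[cfg(test)]
mod qwen3_moe_device_map_sketch {
    #[test]
    fn attention_activation_upper_bound() {
        let (max_batch_size, num_attention_heads) = (2usize, 32usize);
        let (max_seq_len, attention_chunk) = (4096usize, 1024usize);
        let elems =
            max_batch_size * num_attention_heads * max_seq_len.min(attention_chunk).pow(2);
        assert_eq!(elems, 67_108_864);
    }

    fn is_moe_layer(layer_idx: usize, mlp_only: &[usize], num_experts: usize, step: usize) -> bool {
        !mlp_only.contains(&layer_idx) && num_experts > 0 && (layer_idx + 1) % step == 0
    }

    #[test]
    fn every_second_layer_is_sparse() {
        let moe: Vec<bool> = (0..4).map(|i| is_moe_layer(i, &[], 128, 2)).collect();
        assert_eq!(moe, vec![false, true, false, true]);
    }
}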
4318
4319// ======================== SmolLm3 loader
4320
4321/// [`NormalLoader`] for a SmolLm3 model.
4322///
4323/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4324pub struct SmolLm3Loader;
4325
4326impl NormalModelLoader for SmolLm3Loader {
4327    fn load(
4328        &self,
4329        config: &str,
4330        vb: ShardedVarBuilder,
4331        normal_loading_metadata: NormalLoadingMetadata,
4332        attention_mechanism: AttentionImplementation,
4333    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4334        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4335
4336        Ok(Box::new(models::smollm3::SmolLm3::new(
4337            &cfg,
4338            vb,
4339            self.is_gptx(config)?,
4340            normal_loading_metadata,
4341            attention_mechanism,
4342        )?))
4343    }
4344    fn load_xlora(
4345        &self,
4346        _config: &str,
4347        _vb: ShardedVarBuilder,
4348        _lora_config: &[((String, String), LoraConfig)],
4349        _xlora_config: Option<XLoraConfig>,
4350        _xlora_ordering: Ordering,
4351        _normal_loading_metadata: NormalLoadingMetadata,
4352        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4353    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4354        todo!()
4355    }
4356    fn is_gptx(&self, _: &str) -> Result<bool> {
4357        Ok(true)
4358    }
4359    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4360        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4361        Ok(Box::new(cfg))
4362    }
4363}
4364
4365impl IsqModelLoader for SmolLm3Loader {
4366    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4367        Ok(vec![
4368            Regex::new(r"lm_head\.(weight|bias)$")?,
4369            // Attention
4370            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4371            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4372            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4373            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4374            // MLP
4375            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4376            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4377            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4378        ])
4379    }
4380    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4381        self.isq_layer_regexes(config)
4382    }
4383}
4384
4385impl DeviceMappedModelLoader for SmolLm3Loader {
4386    fn mapped_max_act_size_elems(
4387        &self,
4388        config: &str,
4389        params: &AutoDeviceMapParams,
4390    ) -> Result<usize> {
4391        let AutoDeviceMapParams::Text {
4392            max_seq_len,
4393            max_batch_size,
4394        } = params
4395        else {
4396            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4397        };
4398
4399        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4400
4401        Ok(
4402            max_batch_size
4403                * cfg.num_attention_heads
4404                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4405        )
4406    }
4407    fn non_mapped_max_act_size_elems(
4408        &self,
4409        _config: &str,
4410        _params: &AutoDeviceMapParams,
4411    ) -> Result<usize> {
4412        Ok(0)
4413    }
4414
4415    fn non_mapped_size_in_bytes(
4416        &self,
4417        config: &str,
4418        dtype: DType,
4419        weight_pack_factor: usize,
4420        _matformer_config: Option<&MatformerSliceConfig>,
4421    ) -> Result<usize> {
4422        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4423
4424        let elems = {
4425            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4426            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4427            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4428                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4429            } else {
4430                0
4431            };
4432            let norm = cfg.hidden_size;
4433            embed_tokens + lm_head + norm
4434        };
4435        Ok(elems * dtype.size_in_bytes())
4436    }
4437
4438    fn layer_sizes_in_bytes(
4439        &self,
4440        config: &str,
4441        dtype: DType,
4442        weight_pack_factor: usize,
4443        _matformer_config: Option<&MatformerSliceConfig>,
4444    ) -> Result<Vec<usize>> {
4445        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4446
4447        let per_layer_elems = {
4448            let input_layernorm = cfg.hidden_size;
4449            let post_attention_layernorm = cfg.hidden_size;
4450
4451            let size_in = cfg.hidden_size;
4452            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4453            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4454            let q_proj = size_in * size_q / weight_pack_factor;
4455            let k_proj = size_in * size_kv / weight_pack_factor;
4456            let v_proj = size_in * size_kv / weight_pack_factor;
4457            let o_proj = size_q * size_in / weight_pack_factor;
4458
4459            let h_size = cfg.hidden_size;
4460            let i_size = cfg.intermediate_size;
4461            let gate_proj = h_size * i_size / weight_pack_factor;
4462            let up_proj = h_size * i_size / weight_pack_factor;
4463            let down_proj = i_size * h_size / weight_pack_factor;
4464
4465            input_layernorm
4466                + post_attention_layernorm
4467                + q_proj
4468                + k_proj
4469                + v_proj
4470                + o_proj
4471                + gate_proj
4472                + up_proj
4473                + down_proj
4474        };
4475        Ok(vec![
4476            per_layer_elems * dtype.size_in_bytes();
4477            cfg.num_hidden_layers
4478        ])
4479    }
4480
4481    fn num_layers(&self, config: &str) -> Result<usize> {
4482        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4483
4484        Ok(cfg.num_hidden_layers)
4485    }
4486    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4487        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4488
4489        let cfg = ModelConfigMetadata {
4490            max_seq_len: cfg.max_position_embeddings,
4491            num_layers: cfg.num_hidden_layers,
4492            hidden_size: cfg.hidden_size,
4493            num_kv_heads: cfg.num_key_value_heads,
4494            num_attn_heads: cfg.num_attention_heads,
4495            sliding_window: None,
4496            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4497            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4498            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4499        };
4500
4501        Ok(Box::new(cfg))
4502    }
4503}
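
// Hedged worked example for the dense SmolLm3 sizing above, using assumed
// SmolLM3-like dimensions (hidden 2048, 16 heads, 4 KV heads, intermediate
// 8192, pack factor 1); these are illustrative, not values from a real config.
// It also shows the tied-embedding rule: with tied weights and no packing,
// lm_head contributes nothing extra.
#[cfg(test)]
mod smollm3_sizing_sketch {
    #[test]
    fn dense_layer_element_count() {
        let (hidden, heads, kv_heads, inter, pack) = (2048usize, 16usize, 4usize, 8192usize, 1usize);
        let head_dim = hidden / heads; // 128
        let (size_q, size_kv) = (head_dim * heads, head_dim * kv_heads);
        let attn = (hidden * size_q + 2 * hidden * size_kv + size_q * hidden) / pack;
        let mlp = (2 * hidden * inter + inter * hidden) / pack;
        let norms = 2 * hidden;
        assert_eq!(attn + mlp + norms, 60_821_504);
    }

    #[test]
    fn tied_embeddings_skip_lm_head() {
        let (hidden, vocab) = (2048usize, 128_256usize);
        let lm_head =
            |tied: bool, pack: usize| if !tied || pack != 1 { hidden * vocab / pack } else { 0 };
        assert_eq!(lm_head(true, 1), 0);
        assert_eq!(lm_head(true, 4), hidden * vocab / 4);
    }
}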
4504
4505// ======================== GraniteMoeHybrid loader
4506
4507/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
4508///
4509/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4510pub struct GraniteMoeHybridLoader;
4511
4512impl NormalModelLoader for GraniteMoeHybridLoader {
4513    fn load(
4514        &self,
4515        config: &str,
4516        vb: ShardedVarBuilder,
4517        normal_loading_metadata: NormalLoadingMetadata,
4518        attention_mechanism: AttentionImplementation,
4519    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4520        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4521
4522        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4523            &cfg,
4524            vb,
4525            self.is_gptx(config)?,
4526            normal_loading_metadata,
4527            attention_mechanism,
4528        )?))
4529    }
4530    fn load_xlora(
4531        &self,
4532        _config: &str,
4533        _vb: ShardedVarBuilder,
4534        _lora_config: &[((String, String), LoraConfig)],
4535        _xlora_config: Option<XLoraConfig>,
4536        _xlora_ordering: Ordering,
4537        _normal_loading_metadata: NormalLoadingMetadata,
4538        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4539    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4540        anyhow::bail!("GraniteMoeHybrid does not support X-LoRA")
4541    }
4542    fn is_gptx(&self, _: &str) -> Result<bool> {
4543        Ok(true)
4544    }
4545    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4546        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4547        Ok(Box::new(cfg))
4548    }
4549    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4550        Ok(true)
4551    }
4552}
4553
4554impl IsqModelLoader for GraniteMoeHybridLoader {
4555    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4556        Ok(vec![
4557            Regex::new(r"lm_head\.(weight|bias)$")?,
4558            // Attention
4559            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4560            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4561            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4562            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4563            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
4564            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4565            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4566        ])
4567    }
4568    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4569        self.isq_layer_regexes(config)
4570    }
4571}
4572
4573impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4574    fn mapped_max_act_size_elems(
4575        &self,
4576        config: &str,
4577        params: &AutoDeviceMapParams,
4578    ) -> Result<usize> {
4579        let AutoDeviceMapParams::Text {
4580            max_seq_len,
4581            max_batch_size,
4582        } = params
4583        else {
4584            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4585        };
4586
4587        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4588
4589        Ok(
4590            max_batch_size
4591                * cfg.num_attention_heads
4592                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4593        )
4594    }
4595    fn non_mapped_max_act_size_elems(
4596        &self,
4597        _config: &str,
4598        _params: &AutoDeviceMapParams,
4599    ) -> Result<usize> {
4600        Ok(0)
4601    }
4602
4603    fn non_mapped_size_in_bytes(
4604        &self,
4605        config: &str,
4606        dtype: DType,
4607        weight_pack_factor: usize,
4608        _matformer_config: Option<&MatformerSliceConfig>,
4609    ) -> Result<usize> {
4610        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4611
4612        let elems = {
4613            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4614            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4615            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4616                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4617            } else {
4618                0
4619            };
4620            let norm = cfg.hidden_size;
4621            embed_tokens + lm_head + norm
4622        };
4623        Ok(elems * dtype.size_in_bytes())
4624    }
4625
4626    fn layer_sizes_in_bytes(
4627        &self,
4628        config: &str,
4629        dtype: DType,
4630        weight_pack_factor: usize,
4631        _matformer_config: Option<&MatformerSliceConfig>,
4632    ) -> Result<Vec<usize>> {
4633        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4634
4635        let per_layer_elems = {
4636            let input_layernorm = cfg.hidden_size;
4637            let post_attention_layernorm = cfg.hidden_size;
4638
4639            let size_in = cfg.hidden_size;
4640            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4641            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4642            let q_proj = size_in * size_q / weight_pack_factor;
4643            let k_proj = size_in * size_kv / weight_pack_factor;
4644            let v_proj = size_in * size_kv / weight_pack_factor;
4645            let o_proj = size_q * size_in / weight_pack_factor;
4646
4647            let h_size = cfg.hidden_size;
4648            let shared_i_size = cfg.shared_intermediate_size();
4649            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
4650            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4651            let output_linear = shared_i_size * h_size / weight_pack_factor;
4652
4653            input_layernorm
4654                + post_attention_layernorm
4655                + q_proj
4656                + k_proj
4657                + v_proj
4658                + o_proj
4659                + input_linear
4660                + output_linear
4661        };
4662        Ok(vec![
4663            per_layer_elems * dtype.size_in_bytes();
4664            cfg.num_hidden_layers
4665        ])
4666    }
4667
4668    fn num_layers(&self, config: &str) -> Result<usize> {
4669        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4670
4671        Ok(cfg.num_hidden_layers)
4672    }
4673    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4674        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4675
4676        let cfg = ModelConfigMetadata {
4677            max_seq_len: cfg.max_position_embeddings,
4678            num_layers: cfg.num_hidden_layers,
4679            hidden_size: cfg.hidden_size,
4680            num_kv_heads: cfg.num_key_value_heads(),
4681            num_attn_heads: cfg.num_attention_heads,
4682            sliding_window: None,
4683            k_head_dim: cfg.head_dim(),
4684            v_head_dim: cfg.head_dim(),
4685            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4686        };
4687
4688        Ok(Box::new(cfg))
4689    }
4690}
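
// Hedged sketch of the GraniteMLP accounting above: `input_linear` fuses the
// gate and up projections (hence the factor of 2) while `output_linear` maps
// the shared intermediate size back to the hidden size. Dimensions are
// illustrative assumptions.
#[cfg(test)]
mod granite_shared_mlp_sketch {
    #[test]
    fn fused_input_linear_doubles_columns() {
        let (hidden, shared_inter, pack) = (4096usize, 12800usize, 1usize);
        let input_linear = hidden * shared_inter * 2 / pack; // fused gate + up
        let output_linear = shared_inter * hidden / pack;
        assert_eq!(input_linear, 2 * output_linear);
    }
}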
4691
4692// ======================== GPT-OSS loader
4693
4694/// [`NormalLoader`] for a GPT-OSS model.
4695///
4696/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4697pub struct GptOssLoader;
4698
4699impl NormalModelLoader for GptOssLoader {
4700    fn load(
4701        &self,
4702        config: &str,
4703        vb: ShardedVarBuilder,
4704        normal_loading_metadata: NormalLoadingMetadata,
4705        attention_mechanism: AttentionImplementation,
4706    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4707        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4708
4709        Ok(Box::new(models::gpt_oss::Model::new(
4710            &cfg,
4711            vb,
4712            self.is_gptx(config)?,
4713            normal_loading_metadata,
4714            attention_mechanism,
4715        )?))
4716    }
4717    fn load_xlora(
4718        &self,
4719        _config: &str,
4720        _vb: ShardedVarBuilder,
4721        _lora_config: &[((String, String), LoraConfig)],
4722        _xlora_config: Option<XLoraConfig>,
4723        _xlora_ordering: Ordering,
4724        _normal_loading_metadata: NormalLoadingMetadata,
4725        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4726    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4727        anyhow::bail!("GPT-OSS does not support X-LoRA")
4728    }
4729    fn is_gptx(&self, _: &str) -> Result<bool> {
4730        Ok(true)
4731    }
4732    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4733        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4734        Ok(Box::new(cfg))
4735    }
4736}
4737
4738impl IsqModelLoader for GptOssLoader {
4739    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4740        // Only attention layers are ISQ-able - MoE experts are already MXFP4 quantized
4741        Ok(vec![
4742            Regex::new(r"lm_head\.(weight|bias)$")?,
4743            // Attention
4744            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4745            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4746            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4747            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4748        ])
4749    }
4750    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4751        self.isq_layer_regexes(config)
4752    }
4753}
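
// Hedged illustration of the ISQ restriction above: only attention projections
// match, so an MXFP4 expert tensor (shown here with an assumed, hypothetical
// name) falls through and keeps its pre-quantized format.
#[cfg(test)]
mod gpt_oss_isq_scope_sketch {
    use regex::Regex;

    #[test]
    fn expert_tensors_are_not_isq_targets() {
        let attn = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$").unwrap();
        assert!(attn.is_match("model.layers.0.self_attn.q_proj.weight"));
        // Hypothetical expert tensor name; it matches none of the attention regexes.
        assert!(!attn.is_match("model.layers.0.mlp.experts.gate_up_proj"));
    }
}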
4754
4755impl DeviceMappedModelLoader for GptOssLoader {
4756    fn mapped_max_act_size_elems(
4757        &self,
4758        config: &str,
4759        params: &AutoDeviceMapParams,
4760    ) -> Result<usize> {
4761        let AutoDeviceMapParams::Text {
4762            max_seq_len,
4763            max_batch_size,
4764        } = params
4765        else {
4766            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4767        };
4768
4769        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4770
4771        Ok(
4772            max_batch_size
4773                * cfg.num_attention_heads
4774                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4775        )
4776    }
4777    fn non_mapped_max_act_size_elems(
4778        &self,
4779        _config: &str,
4780        _params: &AutoDeviceMapParams,
4781    ) -> Result<usize> {
4782        Ok(0)
4783    }
4784
4785    fn non_mapped_size_in_bytes(
4786        &self,
4787        config: &str,
4788        dtype: DType,
4789        weight_pack_factor: usize,
4790        _matformer_config: Option<&MatformerSliceConfig>,
4791    ) -> Result<usize> {
4792        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4793
4794        let elems = {
4795            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4796            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4797                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4798            } else {
4799                0
4800            };
4801            let norm = cfg.hidden_size;
4802            embed_tokens + lm_head + norm
4803        };
4804        Ok(elems * dtype.size_in_bytes())
4805    }
4806
4807    fn layer_sizes_in_bytes(
4808        &self,
4809        config: &str,
4810        dtype: DType,
4811        weight_pack_factor: usize,
4812        _matformer_config: Option<&MatformerSliceConfig>,
4813    ) -> Result<Vec<usize>> {
4814        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4815
4816        let per_layer_elems = {
4817            let input_layernorm = cfg.hidden_size;
4818            let post_attention_layernorm = cfg.hidden_size;
4819
4820            let size_in = cfg.hidden_size;
4821            let head_dim = cfg.head_dim();
4822            let size_q = head_dim * cfg.num_attention_heads;
4823            let size_kv = head_dim * cfg.num_key_value_heads;
4824            let q_proj =
4825                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4826            let k_proj =
4827                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4828            let v_proj =
4829                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4830            let o_proj =
4831                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4832
4833            // MoE experts - MXFP4 quantized, so very compact
4834            // gate_up_proj: [num_experts, intermediate_size * 2, hidden_size/2] packed
4835            // down_proj: [num_experts, hidden_size, intermediate_size/2] packed
4836            // At 4 bits per weight, packing factor is 2
4837            let mxfp4_pack = 2;
4838            let gate_up_proj_size =
4839                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4840            let down_proj_size =
4841                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4842            // Plus scales at 1 byte per 32 elements
4843            let gate_up_scales =
4844                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4845            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4846            // Plus biases
4847            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4848            let down_bias = cfg.num_local_experts * cfg.hidden_size;
4849            // Router
4850            let router = cfg.hidden_size * cfg.num_local_experts;
4851            // Sinks per head
4852            let sinks = cfg.num_attention_heads;
4853
4854            input_layernorm
4855                + post_attention_layernorm
4856                + q_proj
4857                + k_proj
4858                + v_proj
4859                + o_proj
4860                + gate_up_proj_size
4861                + down_proj_size
4862                + gate_up_scales
4863                + down_scales
4864                + gate_up_bias
4865                + down_bias
4866                + router
4867                + sinks
4868        };
4869        Ok(vec![
4870            per_layer_elems * dtype.size_in_bytes();
4871            cfg.num_hidden_layers
4872        ])
4873    }
4874
4875    fn num_layers(&self, config: &str) -> Result<usize> {
4876        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4877
4878        Ok(cfg.num_hidden_layers)
4879    }
4880    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4881        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4882
4883        let head_dim = cfg.head_dim();
4884        let cfg = ModelConfigMetadata {
4885            max_seq_len: cfg.max_position_embeddings,
4886            num_layers: cfg.num_hidden_layers,
4887            hidden_size: cfg.hidden_size,
4888            num_kv_heads: cfg.num_key_value_heads,
4889            num_attn_heads: cfg.num_attention_heads,
4890            sliding_window: cfg.sliding_window,
4891            k_head_dim: head_dim,
4892            v_head_dim: head_dim,
4893            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4894        };
4895
4896        Ok(Box::new(cfg))
4897    }
4898}
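
// Worked sketch of the MXFP4 byte accounting above: 4-bit weights pack two per
// byte and block scales add one byte per 32 weights, so packed bytes plus
// scales come to 17/32 of one byte per logical element. Expert count and sizes
// are illustrative assumptions, not GPT-OSS's actual configuration.
#[cfg(test)]
mod gpt_oss_mxfp4_sketch {
    #[test]
    fn packed_bytes_plus_block_scales() {
        let (experts, hidden, inter) = (32usize, 2880usize, 2880usize);
        let elems = experts * inter * 2 * hidden; // logical gate_up_proj elements
        let packed = elems / 2; // two 4-bit weights per byte
        let scales = elems / 32; // one u8 scale per 32-element block
        assert_eq!(packed + scales, elems * 17 / 32);
        assert!(packed + scales < elems); // well under 1 byte per element
    }
}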
4899
4900// ======================== Qwen3Next loader
4901
4902/// [`NormalLoader`] for a Qwen3Next model.
4903///
4904/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4905pub struct Qwen3NextLoader;
4906
4907impl NormalModelLoader for Qwen3NextLoader {
4908    fn load(
4909        &self,
4910        config: &str,
4911        vb: ShardedVarBuilder,
4912        normal_loading_metadata: NormalLoadingMetadata,
4913        attention_mechanism: AttentionImplementation,
4914    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4915        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4916
4917        Ok(Box::new(models::qwen3_next::Model::new(
4918            &cfg,
4919            vb,
4920            self.is_gptx(config)?,
4921            normal_loading_metadata,
4922            attention_mechanism,
4923        )?))
4924    }
4925    fn load_xlora(
4926        &self,
4927        _config: &str,
4928        _vb: ShardedVarBuilder,
4929        _lora_config: &[((String, String), LoraConfig)],
4930        _xlora_config: Option<XLoraConfig>,
4931        _xlora_ordering: Ordering,
4932        _normal_loading_metadata: NormalLoadingMetadata,
4933        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4934    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4935        anyhow::bail!("Qwen3Next does not support X-LoRA")
4936    }
4937    fn is_gptx(&self, _: &str) -> Result<bool> {
4938        Ok(true)
4939    }
4940    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4941        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4942        Ok(Box::new(cfg))
4943    }
4944    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4945        Ok(true)
4946    }
4947}
4948
4949impl IsqModelLoader for Qwen3NextLoader {
4950    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4951        Ok(vec![
4952            Regex::new(r"lm_head\.(weight|bias)$")?,
4953            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4954            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4955            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4956            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4957            Regex::new(r"layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$")?,
4958            Regex::new(
4959                r"layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4960            )?,
4961            Regex::new(
4962                r"layers\.(\d+)\.mlp\.shared_expert\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4963            )?,
4964        ])
4965    }
4966    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4967        self.isq_layer_regexes(config)
4968    }
4969}
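
// Hedged regex sketch for the nested alternations above: one pattern covers
// every routed-expert projection and another covers the shared expert, while
// the shared-expert gate stays unmatched. Tensor names are assumed for
// illustration.
#[cfg(test)]
mod qwen3_next_isq_regex_sketch {
    use regex::Regex;

    #[test]
    fn routed_and_shared_expert_projections_match() {
        let routed = Regex::new(
            r"layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
        )
        .unwrap();
        let shared = Regex::new(
            r"layers\.(\d+)\.mlp\.shared_expert\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
        )
        .unwrap();
        assert!(routed.is_match("model.layers.7.mlp.experts.63.up_proj.weight"));
        assert!(shared.is_match("model.layers.7.mlp.shared_expert.gate_proj.weight"));
        assert!(!shared.is_match("model.layers.7.mlp.shared_expert_gate.weight"));
    }
}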
4970
4971impl DeviceMappedModelLoader for Qwen3NextLoader {
4972    fn mapped_max_act_size_elems(
4973        &self,
4974        config: &str,
4975        params: &AutoDeviceMapParams,
4976    ) -> Result<usize> {
4977        let AutoDeviceMapParams::Text {
4978            max_seq_len,
4979            max_batch_size,
4980        } = params
4981        else {
4982            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4983        };
4984
4985        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4986
4987        Ok(
4988            max_batch_size
4989                * cfg.num_attention_heads
4990                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4991        )
4992    }
4993    fn non_mapped_max_act_size_elems(
4994        &self,
4995        _config: &str,
4996        _params: &AutoDeviceMapParams,
4997    ) -> Result<usize> {
4998        Ok(0)
4999    }
5000
5001    fn non_mapped_size_in_bytes(
5002        &self,
5003        config: &str,
5004        dtype: DType,
5005        weight_pack_factor: usize,
5006        _matformer_config: Option<&MatformerSliceConfig>,
5007    ) -> Result<usize> {
5008        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5009
5010        let elems = {
5011            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5012            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
5013                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5014            } else {
5015                0
5016            };
5017            let norm = cfg.hidden_size;
5018            embed_tokens + lm_head + norm
5019        };
5020        Ok(elems * dtype.size_in_bytes())
5021    }
5022
5023    fn layer_sizes_in_bytes(
5024        &self,
5025        config: &str,
5026        dtype: DType,
5027        weight_pack_factor: usize,
5028        _matformer_config: Option<&MatformerSliceConfig>,
5029    ) -> Result<Vec<usize>> {
5030        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5031        let layer_types = cfg.layer_types();
5032        let mut layer_sizes = Vec::with_capacity(cfg.num_hidden_layers);
5033
5034        for layer_type in &layer_types {
5035            let input_layernorm = cfg.hidden_size;
5036            let post_attention_layernorm = cfg.hidden_size;
5037
5038            let attn_elems = match layer_type {
5039                crate::models::qwen3_next::LayerType::FullAttention => {
5040                    let hidden = cfg.hidden_size;
5041                    let q_dim = cfg.head_dim * cfg.num_attention_heads;
5042                    let kv_dim = cfg.head_dim * cfg.num_key_value_heads;
5043                    let q_proj = hidden * q_dim * 2 / weight_pack_factor;
5044                    let k_proj = hidden * kv_dim / weight_pack_factor;
5045                    let v_proj = hidden * kv_dim / weight_pack_factor;
5046                    let o_proj = q_dim * hidden / weight_pack_factor;
5047                    let q_norm = cfg.head_dim;
5048                    let k_norm = cfg.head_dim;
5049                    q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
5050                }
5051                crate::models::qwen3_next::LayerType::LinearAttention => {
5052                    let hidden = cfg.hidden_size;
5053                    let key_dim = cfg.linear_key_dim();
5054                    let value_dim = cfg.linear_value_dim();
5055                    let conv_dim = cfg.linear_conv_dim();
5056                    // in_proj_qkvz: (2 * key_dim + 2 * value_dim, hidden)
5057                    let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
5058                    // in_proj_ba: (2 * num_v_heads, hidden)
5059                    let in_proj_ba = hidden * (cfg.linear_num_value_heads * 2);
5060                    let out_proj = value_dim * hidden / weight_pack_factor;
5061                    let conv1d = conv_dim * cfg.linear_conv_kernel_dim;
5062                    let dt_bias = cfg.linear_num_value_heads;
5063                    let a_log = cfg.linear_num_value_heads;
5064                    let norm = cfg.linear_value_head_dim;
5065                    in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
5066                }
5067            };
5068
5069            let moe_gate = cfg.hidden_size * cfg.num_experts;
5070            let shared_expert =
5071                3 * cfg.hidden_size * cfg.shared_expert_intermediate_size / weight_pack_factor;
5072            let routed_experts = cfg.num_experts * 3 * cfg.hidden_size * cfg.moe_intermediate_size
5073                / weight_pack_factor;
5074
5075            let per_layer_elems = input_layernorm
5076                + post_attention_layernorm
5077                + attn_elems
5078                + moe_gate
5079                + shared_expert
5080                + routed_experts;
5081
5082            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5083        }
5084
5085        Ok(layer_sizes)
5086    }
5087
5088    fn num_layers(&self, config: &str) -> Result<usize> {
5089        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5090        Ok(cfg.num_hidden_layers)
5091    }
5092    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5093        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5094
5095        let cfg = ModelConfigMetadata {
5096            max_seq_len: cfg.max_position_embeddings,
5097            num_layers: cfg.num_hidden_layers,
5098            hidden_size: cfg.hidden_size,
5099            num_kv_heads: cfg.num_key_value_heads,
5100            num_attn_heads: cfg.num_attention_heads,
5101            sliding_window: None,
5102            k_head_dim: cfg.head_dim,
5103            v_head_dim: cfg.head_dim,
5104            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5105        };
5106
5107        Ok(Box::new(cfg))
5108    }
5109}
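
// Hedged sketch of the hybrid Qwen3Next accounting above. Full-attention
// layers size q_proj with a factor of 2, which presumably covers the output
// gate of the gated attention design; linear-attention layers instead carry
// the fused in_proj_qkvz / in_proj_ba projections plus a short depthwise
// conv. All dimensions below are illustrative assumptions.
#[cfg(test)]
mod qwen3_next_linear_attn_sketch {
    #[test]
    fn fused_qkvz_dominates_linear_attention_layer() {
        let hidden = 2048usize;
        let (key_dim, value_dim, num_v_heads, conv_kernel) =
            (1024usize, 2048usize, 16usize, 4usize);
        let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
        let in_proj_ba = hidden * (num_v_heads * 2);
        let conv1d = (key_dim * 2 + value_dim) * conv_kernel; // assumed conv-dim layout
        assert_eq!(in_proj_qkvz, 12_582_912);
        assert!(in_proj_ba + conv1d < in_proj_qkvz);
    }
}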