mistralrs_core/pipeline/loaders/normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11    amoe::AnyMoeBaseModelMixin,
12    device_map::DeviceMapper,
13    lora::{LoraConfig, Ordering},
14    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15    pipeline::{
16        isq::IsqModelLoader,
17        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18        EitherCache, IsqModel,
19    },
20    utils::varbuilder_utils::DeviceForLoadTensor,
21    xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36    models,
37    xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43    #[allow(clippy::too_many_arguments)]
44    fn forward(
45        &self,
46        input_ids: &Tensor,
47        seqlen_offsets: &[usize],
48        context_lens: Vec<(usize, usize)>,
49        position_ids: Vec<usize>,
50        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51        flash_params: &FlashParams,
52    ) -> candle_core::Result<Tensor>;
53    #[allow(clippy::too_many_arguments)]
54    fn xlora_forward(
55        &self,
56        input_ids: &Tensor,
57        input_ids_full: &Tensor,
58        seqlen_offsets: &[usize],
59        seqlen_offsets_full: &[usize],
60        no_kv_cache: bool,
61        non_granular_state: &Option<NonGranularState>,
62        context_lens: Vec<(usize, usize)>,
63        position_ids: Vec<usize>,
64        flash_params: &FlashParams,
65        flash_params_full: &FlashParams,
66    ) -> candle_core::Result<Tensor>;
67    fn is_xlora(&self) -> bool;
68    fn device(&self) -> &Device;
69    fn cache(&self) -> &EitherCache;
70    fn cache_mut(&mut self) -> &mut EitherCache;
71    fn max_seq_len(&self) -> usize;
72    fn config(&self) -> &ModelConfigMetadata;
73}
74
75/// Metadata for loading a model with ISQ or device mapping.
76pub struct NormalLoadingMetadata {
77    // Device mapping metadata which can be used to construct a concrete device mapper
78    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
79    // Whether the model is being loaded with in-situ quantization (ISQ)
80    pub loading_isq: bool,
81    // Device mapping target device (the non-CPU device)
82    pub real_device: Device,
83    // MultiProgress support for parallelized loading
84    pub multi_progress: Arc<MultiProgress>,
85    // Optional Matryoshka Transformer slicing configuration
86    pub matformer_slicing_config: Option<MatformerSliceConfig>,
87}
88
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90    fn load(
91        &self,
92        config: &str,
93        vb: ShardedVarBuilder,
94        normal_loading_metadata: NormalLoadingMetadata,
95        attention_mechanism: AttentionImplementation,
96    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97    #[allow(clippy::too_many_arguments)]
98    fn load_xlora(
99        &self,
100        config: &str,
101        vb: ShardedVarBuilder,
102        lora_config: &[((String, String), LoraConfig)],
103        xlora_config: Option<XLoraConfig>,
104        xlora_ordering: Ordering,
105        normal_loading_metadata: NormalLoadingMetadata,
106        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108    fn is_gptx(&self, config: &str) -> Result<bool>;
109    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110        Ok(true)
111    }
112    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
113    fn get_device_for_tensor(
114        &self,
115        config: &str,
116        _mapper: &dyn DeviceMapper,
117        loading_isq: bool,
118    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119        if loading_isq {
120            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121        } else {
122            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123            let num_layers = self.model_config(config)?.num_layers();
124            let closure = move |name: String| {
125                if let Some(captures) = re.captures(&name) {
126                    captures
127                        .get(1)
128                        .and_then(|m| m.as_str().parse::<usize>().ok())
129                        .map(|l| l.min(num_layers))
130                        .map(DeviceForLoadTensor::Idx)
131                        .unwrap_or(DeviceForLoadTensor::Base)
132                } else {
133                    DeviceForLoadTensor::Base
134                }
135            };
136
137            Ok(Arc::new(closure))
138        }
139    }
140}
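
// Illustrative sketch (hypothetical example): the default `get_device_for_tensor`
// above maps a tensor name to its layer's device by extracting the layer index
// with the `\.layers\.(\d+)\.` regex and falling back to the base device for
// tensors outside the decoder stack.
#[cfg(test)]
mod device_for_tensor_regex_sketch {
    use regex::Regex;

    #[test]
    fn extracts_layer_index_from_tensor_name() {
        let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
        let layer = re
            .captures("model.layers.12.self_attn.q_proj.weight")
            .and_then(|c| c.get(1))
            .and_then(|m| m.as_str().parse::<usize>().ok());
        assert_eq!(layer, Some(12));
        // Embeddings and other non-layer tensors do not match and use the base device.
        assert!(re.captures("model.embed_tokens.weight").is_none());
    }
}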
141
142#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
143#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
144/// The architecture to load the normal model as.
145pub enum NormalLoaderType {
146    #[serde(rename = "mistral")]
147    Mistral,
148    #[serde(rename = "gemma")]
149    Gemma,
150    #[serde(rename = "mixtral")]
151    Mixtral,
152    #[serde(rename = "llama")]
153    Llama,
154    #[serde(rename = "phi2")]
155    Phi2,
156    #[serde(rename = "phi3")]
157    Phi3,
158    #[serde(rename = "qwen2")]
159    Qwen2,
160    #[serde(rename = "gemma2")]
161    Gemma2,
162    #[serde(rename = "starcoder2")]
163    Starcoder2,
164    #[serde(rename = "phi3.5moe")]
165    Phi3_5MoE,
166    #[serde(rename = "deepseekv2")]
167    DeepSeekV2,
168    #[serde(rename = "deepseekv3")]
169    DeepSeekV3,
170    #[serde(rename = "qwen3")]
171    Qwen3,
172    #[serde(rename = "glm4")]
173    GLM4,
174    #[serde(rename = "glm4moelite")]
175    GLM4MoeLite,
176    #[serde(rename = "glm4moe")]
177    GLM4Moe,
178    #[serde(rename = "qwen3moe")]
179    Qwen3Moe,
180    #[serde(rename = "smollm3")]
181    SmolLm3,
182    #[serde(rename = "granitemoehybrid")]
183    GraniteMoeHybrid,
184    #[serde(rename = "gpt_oss")]
185    GptOss,
186}
187
188// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
189impl NormalLoaderType {
190    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
191        match name {
192            "MistralForCausalLM" => Ok(Self::Mistral),
193            "MixtralForCausalLM" => Ok(Self::Mixtral),
194            "GemmaForCausalLM" => Ok(Self::Gemma),
195            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
196            "PhiForCausalLM" => Ok(Self::Phi2),
197            "Phi3ForCausalLM" => Ok(Self::Phi3),
198            "LlamaForCausalLM" => Ok(Self::Llama),
199            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
200            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
201            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
202            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
203            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
204            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
205            "Glm4ForCausalLM" => Ok(Self::GLM4),
206            "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
207            "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
208            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
209            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
210            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
211            "GptOssForCausalLM" => Ok(Self::GptOss),
212            other => anyhow::bail!(
213                "Unsupported Hugging Face Transformers `*ForCausalLM` model class `{other}`. Please raise an issue."
214            ),
215        }
216    }
217}
218
219impl FromStr for NormalLoaderType {
220    type Err = String;
221    fn from_str(s: &str) -> Result<Self, Self::Err> {
222        match s {
223            "mistral" => Ok(Self::Mistral),
224            "gemma" => Ok(Self::Gemma),
225            "mixtral" => Ok(Self::Mixtral),
226            "llama" => Ok(Self::Llama),
227            "phi2" => Ok(Self::Phi2),
228            "phi3" => Ok(Self::Phi3),
229            "qwen2" => Ok(Self::Qwen2),
230            "gemma2" => Ok(Self::Gemma2),
231            "starcoder2" => Ok(Self::Starcoder2),
232            "phi3.5moe" => Ok(Self::Phi3_5MoE),
233            "deepseekv2" => Ok(Self::DeepSeekV2),
234            "deepseekv3" => Ok(Self::DeepSeekV3),
235            "qwen3" => Ok(Self::Qwen3),
236            "glm4" => Ok(Self::GLM4),
237            "glm4moelite" => Ok(Self::GLM4MoeLite),
238            "glm4moe" => Ok(Self::GLM4Moe),
239            "qwen3moe" => Ok(Self::Qwen3Moe),
240            "smollm3" => Ok(Self::SmolLm3),
241            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
242            "gpt_oss" => Ok(Self::GptOss),
243            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
244        }
245    }
246}
247
248impl Display for NormalLoaderType {
249    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250        match self {
251            Self::Gemma => write!(f, "gemma"),
252            Self::Gemma2 => write!(f, "gemma2"),
253            Self::Llama => write!(f, "llama"),
254            Self::Mistral => write!(f, "mistral"),
255            Self::Mixtral => write!(f, "mixtral"),
256            Self::Phi2 => write!(f, "phi2"),
257            Self::Phi3 => write!(f, "phi3"),
258            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
259            Self::Qwen2 => write!(f, "qwen2"),
260            Self::Starcoder2 => write!(f, "starcoder2"),
261            Self::DeepSeekV2 => write!(f, "deepseekv2"),
262            Self::DeepSeekV3 => write!(f, "deepseekv3"),
263            Self::Qwen3 => write!(f, "qwen3"),
264            Self::GLM4 => write!(f, "glm4"),
265            Self::GLM4MoeLite => write!(f, "glm4moelite"),
266            Self::GLM4Moe => write!(f, "glm4moe"),
267            Self::Qwen3Moe => write!(f, "qwen3moe"),
268            Self::SmolLm3 => write!(f, "smollm3"),
269            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
270            Self::GptOss => write!(f, "gpt_oss"),
271        }
272    }
273}
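
// Illustrative sketch (hypothetical example): `NormalLoaderType` can be parsed from
// the architecture string used on the CLI (`FromStr`) or from the Hugging Face
// `*ForCausalLM` class name, and `Display` round-trips the former.
#[cfg(test)]
mod normal_loader_type_sketch {
    use super::NormalLoaderType;
    use std::str::FromStr;

    #[test]
    fn parse_and_round_trip() {
        let tp = NormalLoaderType::from_str("qwen3moe").unwrap();
        assert_eq!(tp, NormalLoaderType::Qwen3Moe);
        assert_eq!(tp.to_string(), "qwen3moe");
        assert_eq!(
            NormalLoaderType::from_causal_lm_name("Qwen3MoeForCausalLM").unwrap(),
            NormalLoaderType::Qwen3Moe
        );
    }
}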
274
275macro_rules! bias_if {
276    ($cond:expr, $size:expr) => {
277        if $cond {
278            $size
279        } else {
280            0
281        }
282    };
283}
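
// Illustrative sketch (hypothetical example): `bias_if!` adds the bias element count
// to a size estimate only when the config enables that bias.
#[cfg(test)]
mod bias_if_sketch {
    #[test]
    fn counts_bias_only_when_enabled() {
        assert_eq!(bias_if!(true, 4096), 4096);
        assert_eq!(bias_if!(false, 4096), 0);
    }
}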
284
285/// Load a model by dispatching on the Hugging Face Transformers `*ForCausalLM` model class in the config's `architectures` field
286pub struct AutoNormalLoader;
287
288#[derive(Deserialize)]
289struct AutoNormalLoaderConfig {
290    architectures: Vec<String>,
291}
292
293impl AutoNormalLoader {
294    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
295        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
296        if auto_cfg.architectures.len() != 1 {
297            anyhow::bail!("Expected exactly one entry in the `architectures` config field.")
298        }
299
300        let name = &auto_cfg.architectures[0];
301
302        let tp = NormalLoaderType::from_causal_lm_name(name)?;
303
304        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
305
306        match tp {
307            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
308            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
309            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
310            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
311            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
312            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
313            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
314            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
315            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
316            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
317            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
318            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
319            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
320            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
321            NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
322            NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
323            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
324            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
325            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
326            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
327        }
328    }
329}
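
// Illustrative sketch (hypothetical example): `AutoNormalLoader::get_loader` reads only
// the `architectures` field of the model's `config.json`, so a minimal JSON snippet is
// enough to drive the dispatch.
#[cfg(test)]
mod auto_loader_dispatch_sketch {
    use super::{AutoNormalLoader, NormalLoaderType};

    #[test]
    fn detects_loader_from_architectures_field() {
        assert_eq!(
            NormalLoaderType::from_causal_lm_name("MistralForCausalLM").unwrap(),
            NormalLoaderType::Mistral
        );
        let config = r#"{ "architectures": ["MistralForCausalLM"] }"#;
        assert!(AutoNormalLoader::get_loader(config).is_ok());
    }
}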
330
331impl NormalModelLoader for AutoNormalLoader {
332    fn load(
333        &self,
334        config: &str,
335        vb: ShardedVarBuilder,
336        normal_loading_metadata: NormalLoadingMetadata,
337        attention_mechanism: AttentionImplementation,
338    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
339        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
340    }
341    fn load_xlora(
342        &self,
343        config: &str,
344        vb: ShardedVarBuilder,
345        lora_config: &[((String, String), LoraConfig)],
346        xlora_config: Option<XLoraConfig>,
347        xlora_ordering: Ordering,
348        normal_loading_metadata: NormalLoadingMetadata,
349        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
350    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
351        Self::get_loader(config)?.load_xlora(
352            config,
353            vb,
354            lora_config,
355            xlora_config,
356            xlora_ordering,
357            normal_loading_metadata,
358            preload_adapters,
359        )
360    }
361    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
362        Self::get_loader(config)?.get_config_repr(config)
363    }
364    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
365        Self::get_loader(config)?.supports_paged_attention(config)
366    }
367    fn is_gptx(&self, config: &str) -> Result<bool> {
368        Self::get_loader(config)?.is_gptx(config)
369    }
370}
371
372impl IsqModelLoader for AutoNormalLoader {
373    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
374        Self::get_loader(config)?.immediate_isq_predicates(config)
375    }
376    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
377        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
378    }
379    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
380        Self::get_loader(config)?.isq_layer_regexes(config)
381    }
382    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
383        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
384    }
385}
386
387impl DeviceMappedModelLoader for AutoNormalLoader {
388    fn non_mapped_size_in_bytes(
389        &self,
390        config: &str,
391        dtype: DType,
392        weight_pack_factor: usize,
393        _matformer_config: Option<&MatformerSliceConfig>,
394    ) -> Result<usize> {
395        Self::get_loader(config)?.non_mapped_size_in_bytes(
396            config,
397            dtype,
398            weight_pack_factor,
399            _matformer_config,
400        )
401    }
402    fn num_layers(&self, config: &str) -> Result<usize> {
403        Self::get_loader(config)?.num_layers(config)
404    }
405    fn layer_sizes_in_bytes(
406        &self,
407        config: &str,
408        dtype: DType,
409        weight_pack_factor: usize,
410        _matformer_config: Option<&MatformerSliceConfig>,
411    ) -> Result<Vec<usize>> {
412        Self::get_loader(config)?.layer_sizes_in_bytes(
413            config,
414            dtype,
415            weight_pack_factor,
416            _matformer_config,
417        )
418    }
419    fn mapped_max_act_size_elems(
420        &self,
421        config: &str,
422        params: &super::AutoDeviceMapParams,
423    ) -> Result<usize> {
424        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
425    }
426    fn non_mapped_max_act_size_elems(
427        &self,
428        _config: &str,
429        _params: &AutoDeviceMapParams,
430    ) -> Result<usize> {
431        Ok(0)
432    }
433    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
434        Self::get_loader(config)?.model_config(config)
435    }
436}
437
438// ======================== Mistral loader
439
440pub struct MistralLoader;
441
442impl NormalModelLoader for MistralLoader {
443    fn load(
444        &self,
445        config: &str,
446        vb: ShardedVarBuilder,
447        normal_loading_metadata: NormalLoadingMetadata,
448        attention_mechanism: AttentionImplementation,
449    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
450        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
451        Ok(Box::new(models::mistral::Model::new(
452            &cfg,
453            vb,
454            self.is_gptx(config)?,
455            normal_loading_metadata,
456            attention_mechanism,
457        )?))
458    }
459    fn load_xlora(
460        &self,
461        config: &str,
462        vb: ShardedVarBuilder,
463        lora_config: &[((String, String), LoraConfig)],
464        xlora_config: Option<XLoraConfig>,
465        xlora_ordering: Ordering,
466        normal_loading_metadata: NormalLoadingMetadata,
467        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
468    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
469        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
470        Ok(Box::new(xlora_models::XLoraMistral::new(
471            &cfg,
472            vb,
473            lora_config,
474            xlora_config,
475            xlora_ordering,
476            self.is_gptx(config)?,
477            normal_loading_metadata,
478            preload_adapters,
479        )?))
480    }
481    fn is_gptx(&self, _: &str) -> Result<bool> {
482        Ok(true)
483    }
484    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
485        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
486        Ok(Box::new(cfg))
487    }
488}
489
490impl IsqModelLoader for MistralLoader {
491    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
492        Ok(vec![
493            Regex::new(r"lm_head\.(weight|bias)$")?,
494            // Attention
495            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
496            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
497            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
498            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
499            // MLP
500            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
501            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
502            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
503        ])
504    }
505    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
506        self.isq_layer_regexes(config)
507    }
508}
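
// Illustrative sketch (hypothetical example): the ISQ regexes above select the
// quantizable linear projections by tensor name while leaving norms and embeddings
// untouched.
#[cfg(test)]
mod mistral_isq_regex_sketch {
    use regex::Regex;

    #[test]
    fn matches_linear_weights_only() {
        let q_proj = Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$").unwrap();
        assert!(q_proj.is_match("model.layers.0.self_attn.q_proj.weight"));
        assert!(!q_proj.is_match("model.layers.0.input_layernorm.weight"));

        let lm_head = Regex::new(r"lm_head\.(weight|bias)$").unwrap();
        assert!(lm_head.is_match("lm_head.weight"));
    }
}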
509
510impl DeviceMappedModelLoader for MistralLoader {
511    fn mapped_max_act_size_elems(
512        &self,
513        config: &str,
514        params: &AutoDeviceMapParams,
515    ) -> Result<usize> {
516        let AutoDeviceMapParams::Text {
517            max_seq_len,
518            max_batch_size,
519        } = params
520        else {
521            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
522        };
523
524        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
525
526        Ok(
527            max_batch_size
528                * cfg.num_attention_heads
529                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
530        )
531    }
532    fn non_mapped_max_act_size_elems(
533        &self,
534        _config: &str,
535        _params: &AutoDeviceMapParams,
536    ) -> Result<usize> {
537        Ok(0)
538    }
539
540    fn non_mapped_size_in_bytes(
541        &self,
542        config: &str,
543        dtype: DType,
544        weight_pack_factor: usize,
545        _matformer_config: Option<&MatformerSliceConfig>,
546    ) -> Result<usize> {
547        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
548
549        let elems = {
550            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
551            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
552            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
553                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
554            } else {
555                0
556            };
557            let norm = cfg.hidden_size;
558            embed_tokens + lm_head + norm
559        };
560        Ok(elems * dtype.size_in_bytes())
561    }
562
563    fn layer_sizes_in_bytes(
564        &self,
565        config: &str,
566        dtype: DType,
567        weight_pack_factor: usize,
568        _matformer_config: Option<&MatformerSliceConfig>,
569    ) -> Result<Vec<usize>> {
570        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
571
572        let per_layer_elems = {
573            let input_layernorm = cfg.hidden_size;
574            let post_attention_layernorm = cfg.hidden_size;
575
576            let size_in = cfg.hidden_size;
577            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
578            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
579            let q_proj = size_in * size_q / weight_pack_factor;
580            let k_proj = size_in * size_kv / weight_pack_factor;
581            let v_proj = size_in * size_kv / weight_pack_factor;
582            let o_proj = size_q * size_in / weight_pack_factor;
583
584            let h_size = cfg.hidden_size;
585            let i_size = cfg.intermediate_size;
586            let gate_proj = h_size * i_size / weight_pack_factor;
587            let up_proj = h_size * i_size / weight_pack_factor;
588            let down_proj = i_size * h_size / weight_pack_factor;
589
590            input_layernorm
591                + post_attention_layernorm
592                + q_proj
593                + k_proj
594                + v_proj
595                + o_proj
596                + gate_proj
597                + up_proj
598                + down_proj
599        };
600        Ok(vec![
601            per_layer_elems * dtype.size_in_bytes();
602            cfg.num_hidden_layers
603        ])
604    }
605
606    fn num_layers(&self, config: &str) -> Result<usize> {
607        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
608        Ok(cfg.num_hidden_layers)
609    }
610
611    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
612        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
613
614        let cfg = ModelConfigMetadata {
615            max_seq_len: cfg.max_position_embeddings,
616            num_layers: cfg.num_hidden_layers,
617            hidden_size: cfg.hidden_size,
618            num_kv_heads: cfg.num_key_value_heads,
619            num_attn_heads: cfg.num_attention_heads,
620            sliding_window: cfg.sliding_window,
621            k_head_dim: cfg.head_dim(),
622            v_head_dim: cfg.head_dim(),
623            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
624        };
625
626        Ok(Box::new(cfg))
627    }
628}
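
// Illustrative sketch (hypothetical example): a worked instance of the per-layer
// estimate above for a hypothetical Mistral-7B-style config (hidden_size = 4096,
// 32 attention heads, 8 KV heads, intermediate_size = 14336, weight_pack_factor = 1,
// F16 weights).
#[cfg(test)]
mod mistral_layer_size_sketch {
    use candle_core::DType;

    #[test]
    fn per_layer_bytes_for_a_7b_style_config() {
        let (hidden, heads, kv_heads, intermediate) = (4096usize, 32usize, 8usize, 14336usize);
        let head_dim = hidden / heads;
        let attn = hidden * (heads * head_dim) * 2 + hidden * (kv_heads * head_dim) * 2;
        let mlp = hidden * intermediate * 3;
        let norms = hidden * 2;
        let per_layer_elems = attn + mlp + norms;
        // Roughly 436 MB per decoder layer before any quantization.
        assert_eq!(per_layer_elems * DType::F16.size_in_bytes(), 436_224_000);
    }
}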
629
630// ======================== Gemma loader
631
632/// [`NormalLoader`] for a Gemma model.
633///
634/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
635pub struct GemmaLoader;
636
637impl NormalModelLoader for GemmaLoader {
638    fn load(
639        &self,
640        config: &str,
641        vb: ShardedVarBuilder,
642        normal_loading_metadata: NormalLoadingMetadata,
643        attention_mechanism: AttentionImplementation,
644    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
645        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
646
647        Ok(Box::new(models::gemma::Model::new(
648            &cfg,
649            vb,
650            self.is_gptx(config)?,
651            normal_loading_metadata,
652            attention_mechanism,
653        )?))
654    }
655    fn load_xlora(
656        &self,
657        config: &str,
658        vb: ShardedVarBuilder,
659        lora_config: &[((String, String), LoraConfig)],
660        xlora_config: Option<XLoraConfig>,
661        xlora_ordering: Ordering,
662        normal_loading_metadata: NormalLoadingMetadata,
663        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
664    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
665        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
666
667        Ok(Box::new(xlora_models::XLoraGemma::new(
668            &cfg,
669            vb,
670            lora_config,
671            xlora_config,
672            xlora_ordering,
673            self.is_gptx(config)?,
674            normal_loading_metadata,
675            preload_adapters,
676        )?))
677    }
678    fn is_gptx(&self, _: &str) -> Result<bool> {
679        Ok(true)
680    }
681    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
682        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
683        Ok(Box::new(cfg))
684    }
685}
686
687impl IsqModelLoader for GemmaLoader {
688    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
689        Ok(vec![
690            Regex::new(r"lm_head\.(weight|bias)$")?,
691            // Attention
692            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
693            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
694            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
695            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
696            // MLP
697            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
698            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
699            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
700        ])
701    }
702    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
703        self.isq_layer_regexes(config)
704    }
705}
706
707impl DeviceMappedModelLoader for GemmaLoader {
708    fn mapped_max_act_size_elems(
709        &self,
710        config: &str,
711        params: &AutoDeviceMapParams,
712    ) -> Result<usize> {
713        let AutoDeviceMapParams::Text {
714            max_seq_len,
715            max_batch_size,
716        } = params
717        else {
718            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
719        };
720
721        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
722
723        Ok(
724            max_batch_size
725                * cfg.num_attention_heads
726                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
727        )
728    }
729    fn non_mapped_max_act_size_elems(
730        &self,
731        _config: &str,
732        _params: &AutoDeviceMapParams,
733    ) -> Result<usize> {
734        Ok(0)
735    }
736
737    fn non_mapped_size_in_bytes(
738        &self,
739        config: &str,
740        dtype: DType,
741        weight_pack_factor: usize,
742        _matformer_config: Option<&MatformerSliceConfig>,
743    ) -> Result<usize> {
744        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
745
746        let elems = {
747            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
748            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
749            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
750                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
751            } else {
752                0
753            };
754            let norm = cfg.hidden_size;
755            embed_tokens + lm_head + norm
756        };
757        Ok(elems * dtype.size_in_bytes())
758    }
759
760    fn layer_sizes_in_bytes(
761        &self,
762        config: &str,
763        dtype: DType,
764        weight_pack_factor: usize,
765        _matformer_config: Option<&MatformerSliceConfig>,
766    ) -> Result<Vec<usize>> {
767        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
768
769        let per_layer_elems = {
770            let input_layernorm = cfg.hidden_size;
771            let post_attention_layernorm = cfg.hidden_size;
772
773            let size_in = cfg.hidden_size;
774            let size_q = cfg.head_dim * cfg.num_attention_heads;
775            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
776            let q_proj =
777                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
778            let k_proj =
779                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
780            let v_proj =
781                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
782            let o_proj =
783                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
784
785            let h_size = cfg.hidden_size;
786            let i_size = cfg.intermediate_size;
787            let gate_proj = h_size * i_size / weight_pack_factor;
788            let up_proj = h_size * i_size / weight_pack_factor;
789            let down_proj = i_size * h_size / weight_pack_factor;
790
791            input_layernorm
792                + post_attention_layernorm
793                + q_proj
794                + k_proj
795                + v_proj
796                + o_proj
797                + gate_proj
798                + up_proj
799                + down_proj
800        };
801        Ok(vec![
802            per_layer_elems * dtype.size_in_bytes();
803            cfg.num_hidden_layers
804        ])
805    }
806
807    fn num_layers(&self, config: &str) -> Result<usize> {
808        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
809        Ok(cfg.num_hidden_layers)
810    }
811
812    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
813        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
814
815        let cfg = ModelConfigMetadata {
816            max_seq_len: cfg.max_position_embeddings,
817            num_layers: cfg.num_hidden_layers,
818            hidden_size: cfg.hidden_size,
819            num_kv_heads: cfg.num_key_value_heads,
820            num_attn_heads: cfg.num_attention_heads,
821            sliding_window: None,
822            k_head_dim: cfg.head_dim,
823            v_head_dim: cfg.head_dim,
824            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
825        };
826
827        Ok(Box::new(cfg))
828    }
829}
830
831// ======================== Llama loader
832
833/// [`NormalLoader`] for a Llama model.
834///
835/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
836pub struct LlamaLoader;
837
838impl NormalModelLoader for LlamaLoader {
839    fn load(
840        &self,
841        config: &str,
842        vb: ShardedVarBuilder,
843        normal_loading_metadata: NormalLoadingMetadata,
844        attention_mechanism: AttentionImplementation,
845    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
846        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
847
848        Ok(Box::new(models::llama::Llama::new(
849            &cfg,
850            vb,
851            self.is_gptx(config)?,
852            normal_loading_metadata,
853            attention_mechanism,
854        )?))
855    }
856    fn load_xlora(
857        &self,
858        config: &str,
859        vb: ShardedVarBuilder,
860        lora_config: &[((String, String), LoraConfig)],
861        xlora_config: Option<XLoraConfig>,
862        xlora_ordering: Ordering,
863        normal_loading_metadata: NormalLoadingMetadata,
864        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
865    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
866        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
867
868        Ok(Box::new(xlora_models::XLoraLlama::new(
869            &cfg,
870            vb,
871            lora_config,
872            xlora_config,
873            xlora_ordering,
874            self.is_gptx(config)?,
875            normal_loading_metadata,
876            preload_adapters,
877        )?))
878    }
879    fn is_gptx(&self, _: &str) -> Result<bool> {
880        Ok(true)
881    }
882    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
883        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
884        Ok(Box::new(cfg))
885    }
886}
887
888impl IsqModelLoader for LlamaLoader {
889    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
890        Ok(vec![
891            Regex::new(r"lm_head\.(weight|bias)$")?,
892            // Attention
893            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
894            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
895            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
896            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
897            // MLP
898            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
899            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
900            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
901        ])
902    }
903    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
904        self.isq_layer_regexes(config)
905    }
906}
907
908impl DeviceMappedModelLoader for LlamaLoader {
909    fn mapped_max_act_size_elems(
910        &self,
911        config: &str,
912        params: &AutoDeviceMapParams,
913    ) -> Result<usize> {
914        let AutoDeviceMapParams::Text {
915            max_seq_len,
916            max_batch_size,
917        } = params
918        else {
919            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
920        };
921
922        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
923
924        Ok(
925            max_batch_size
926                * cfg.num_attention_heads
927                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
928        )
929    }
930    fn non_mapped_max_act_size_elems(
931        &self,
932        _config: &str,
933        _params: &AutoDeviceMapParams,
934    ) -> Result<usize> {
935        Ok(0)
936    }
937
938    fn non_mapped_size_in_bytes(
939        &self,
940        config: &str,
941        dtype: DType,
942        weight_pack_factor: usize,
943        _matformer_config: Option<&MatformerSliceConfig>,
944    ) -> Result<usize> {
945        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
946
947        let elems = {
948            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
949            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
950            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
951                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
952            } else {
953                0
954            };
955            let norm = cfg.hidden_size;
956            embed_tokens + lm_head + norm
957        };
958        Ok(elems * dtype.size_in_bytes())
959    }
960
961    fn layer_sizes_in_bytes(
962        &self,
963        config: &str,
964        dtype: DType,
965        weight_pack_factor: usize,
966        _matformer_config: Option<&MatformerSliceConfig>,
967    ) -> Result<Vec<usize>> {
968        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
969
970        let per_layer_elems = {
971            let input_layernorm = cfg.hidden_size;
972            let post_attention_layernorm = cfg.hidden_size;
973
974            let size_in = cfg.hidden_size;
975            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
976            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
977            let q_proj = size_in * size_q / weight_pack_factor;
978            let k_proj = size_in * size_kv / weight_pack_factor;
979            let v_proj = size_in * size_kv / weight_pack_factor;
980            let o_proj = size_q * size_in / weight_pack_factor;
981
982            let h_size = cfg.hidden_size;
983            let i_size = cfg.intermediate_size;
984            let gate_proj = h_size * i_size / weight_pack_factor;
985            let up_proj = h_size * i_size / weight_pack_factor;
986            let down_proj = i_size * h_size / weight_pack_factor;
987
988            input_layernorm
989                + post_attention_layernorm
990                + q_proj
991                + k_proj
992                + v_proj
993                + o_proj
994                + gate_proj
995                + up_proj
996                + down_proj
997        };
998        Ok(vec![
999            per_layer_elems * dtype.size_in_bytes();
1000            cfg.num_hidden_layers
1001        ])
1002    }
1003
1004    fn num_layers(&self, config: &str) -> Result<usize> {
1005        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1006
1007        Ok(cfg.num_hidden_layers)
1008    }
1009    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1010        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1011
1012        let cfg = ModelConfigMetadata {
1013            max_seq_len: cfg.max_position_embeddings,
1014            num_layers: cfg.num_hidden_layers,
1015            hidden_size: cfg.hidden_size,
1016            num_kv_heads: cfg.num_key_value_heads,
1017            num_attn_heads: cfg.num_attention_heads,
1018            sliding_window: None,
1019            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1020            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1021            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1022        };
1023
1024        Ok(Box::new(cfg))
1025    }
1026}
1027
1028// ======================== Mixtral loader
1029
1030pub struct MixtralLoader;
1031
1032impl NormalModelLoader for MixtralLoader {
1033    fn load(
1034        &self,
1035        config: &str,
1036        vb: ShardedVarBuilder,
1037        normal_loading_metadata: NormalLoadingMetadata,
1038        attention_mechanism: AttentionImplementation,
1039    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1040        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1041
1042        Ok(Box::new(models::mixtral::Model::new(
1043            &cfg,
1044            vb,
1045            self.is_gptx(config)?,
1046            normal_loading_metadata,
1047            attention_mechanism,
1048        )?))
1049    }
1050    fn load_xlora(
1051        &self,
1052        config: &str,
1053        vb: ShardedVarBuilder,
1054        lora_config: &[((String, String), LoraConfig)],
1055        xlora_config: Option<XLoraConfig>,
1056        xlora_ordering: Ordering,
1057        normal_loading_metadata: NormalLoadingMetadata,
1058        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1059    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1060        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1061
1062        Ok(Box::new(xlora_models::XLoraMixtral::new(
1063            &cfg,
1064            vb,
1065            lora_config,
1066            xlora_config,
1067            xlora_ordering,
1068            self.is_gptx(config)?,
1069            normal_loading_metadata,
1070            preload_adapters,
1071        )?))
1072    }
1073    fn is_gptx(&self, _: &str) -> Result<bool> {
1074        Ok(true)
1075    }
1076    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1077        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1078
1079        Ok(Box::new(cfg))
1080    }
1081}
1082
1083impl IsqModelLoader for MixtralLoader {
1084    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1085        Ok(vec![
1086            Regex::new(r"lm_head\.(weight|bias)$")?,
1087            // Attention
1088            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1089            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1090            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1091            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1092            // Experts
1093            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1094            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1095            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1096            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1097        ])
1098    }
1099    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1100        self.isq_layer_regexes(config)
1101    }
1102}
1103
1104impl DeviceMappedModelLoader for MixtralLoader {
1105    fn mapped_max_act_size_elems(
1106        &self,
1107        config: &str,
1108        params: &AutoDeviceMapParams,
1109    ) -> Result<usize> {
1110        let AutoDeviceMapParams::Text {
1111            max_seq_len,
1112            max_batch_size,
1113        } = params
1114        else {
1115            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1116        };
1117
1118        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1119
1120        Ok(
1121            max_batch_size
1122                * cfg.num_attention_heads
1123                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1124        )
1125    }
1126    fn non_mapped_max_act_size_elems(
1127        &self,
1128        _config: &str,
1129        _params: &AutoDeviceMapParams,
1130    ) -> Result<usize> {
1131        Ok(0)
1132    }
1133
1134    fn non_mapped_size_in_bytes(
1135        &self,
1136        config: &str,
1137        dtype: DType,
1138        weight_pack_factor: usize,
1139        _matformer_config: Option<&MatformerSliceConfig>,
1140    ) -> Result<usize> {
1141        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1142
1143        let elems = {
1144            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1145            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1146            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1147                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1148            } else {
1149                0
1150            };
1151            let norm = cfg.hidden_size;
1152            embed_tokens + lm_head + norm
1153        };
1154        Ok(elems * dtype.size_in_bytes())
1155    }
1156
1157    fn layer_sizes_in_bytes(
1158        &self,
1159        config: &str,
1160        dtype: DType,
1161        weight_pack_factor: usize,
1162        _matformer_config: Option<&MatformerSliceConfig>,
1163    ) -> Result<Vec<usize>> {
1164        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1165
1166        let per_layer_elems = {
1167            let input_layernorm = cfg.hidden_size;
1168            let post_attention_layernorm = cfg.hidden_size;
1169
1170            let size_in = cfg.hidden_size;
1171            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1172            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1173            let q_proj = size_in * size_q / weight_pack_factor;
1174            let k_proj = size_in * size_kv / weight_pack_factor;
1175            let v_proj = size_in * size_kv / weight_pack_factor;
1176            let o_proj = size_q * size_in / weight_pack_factor;
1177
1178            let moe_block = {
1179                let gate = cfg.hidden_size * cfg.num_local_experts;
1180                // Assume expert weights share the quantization weight pack factor
1181                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1182                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1183                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1184                gate + cfg.num_local_experts * w1
1185                    + cfg.num_local_experts * w2
1186                    + cfg.num_local_experts * w3
1187            };
1188
1189            input_layernorm
1190                + post_attention_layernorm
1191                + q_proj
1192                + k_proj
1193                + v_proj
1194                + o_proj
1195                + moe_block
1196        };
1197        Ok(vec![
1198            per_layer_elems * dtype.size_in_bytes();
1199            cfg.num_hidden_layers
1200        ])
1201    }
1202
1203    fn num_layers(&self, config: &str) -> Result<usize> {
1204        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1205
1206        Ok(cfg.num_hidden_layers)
1207    }
1208
1209    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1210        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1211
1212        let cfg = ModelConfigMetadata {
1213            max_seq_len: cfg.max_position_embeddings,
1214            num_layers: cfg.num_hidden_layers,
1215            hidden_size: cfg.hidden_size,
1216            num_kv_heads: cfg.num_key_value_heads,
1217            num_attn_heads: cfg.num_attention_heads,
1218            sliding_window: cfg.sliding_window,
1219            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1220            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1221            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1222        };
1223
1224        Ok(Box::new(cfg))
1225    }
1226}
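
// Illustrative sketch (hypothetical example): in the MoE block above the expert weights
// dominate the per-layer estimate, since each of `num_local_experts` experts carries
// three hidden_size x intermediate_size matrices (w1, w2, w3).
#[cfg(test)]
mod mixtral_moe_size_sketch {
    #[test]
    fn expert_weights_dominate_per_layer_elems() {
        // Hypothetical Mixtral-8x7B-style numbers: 8 experts, hidden 4096, intermediate 14336.
        let (experts, hidden, intermediate) = (8usize, 4096usize, 14336usize);
        let expert_elems = experts * 3 * hidden * intermediate;
        assert_eq!(expert_elems, 1_409_286_144); // ~2.8 GB at F16, per layer
    }
}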
1227
1228// ======================== Phi2 loader
1229
1230/// [`NormalLoader`] for a Phi 2 model.
1231///
1232/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1233pub struct Phi2Loader;
1234
1235impl NormalModelLoader for Phi2Loader {
1236    fn load(
1237        &self,
1238        config: &str,
1239        vb: ShardedVarBuilder,
1240        normal_loading_metadata: NormalLoadingMetadata,
1241        attention_mechanism: AttentionImplementation,
1242    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1243        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1244
1245        Ok(Box::new(models::phi2::Model::new(
1246            &cfg,
1247            vb,
1248            self.is_gptx(config)?,
1249            normal_loading_metadata,
1250            attention_mechanism,
1251        )?))
1252    }
1253    fn load_xlora(
1254        &self,
1255        config: &str,
1256        vb: ShardedVarBuilder,
1257        lora_config: &[((String, String), LoraConfig)],
1258        xlora_config: Option<XLoraConfig>,
1259        xlora_ordering: Ordering,
1260        normal_loading_metadata: NormalLoadingMetadata,
1261        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1262    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1263        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1264
1265        Ok(Box::new(xlora_models::XLoraPhi2::new(
1266            &cfg,
1267            vb,
1268            lora_config,
1269            xlora_config,
1270            xlora_ordering,
1271            self.is_gptx(config)?,
1272            normal_loading_metadata,
1273            preload_adapters,
1274        )?))
1275    }
1276    fn is_gptx(&self, _: &str) -> Result<bool> {
1277        Ok(true)
1278    }
1279    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1280        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1281
1282        Ok(Box::new(cfg))
1283    }
1284}
1285
1286impl IsqModelLoader for Phi2Loader {
1287    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1288        Ok(vec![
1289            Regex::new(r"lm_head\.(weight|bias)$")?,
1290            // Attention
1291            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1292            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1293            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1294            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1295            // MLP
1296            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1297            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1298        ])
1299    }
1300    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1301        self.isq_layer_regexes(config)
1302    }
1303}
1304
1305impl DeviceMappedModelLoader for Phi2Loader {
1306    fn mapped_max_act_size_elems(
1307        &self,
1308        config: &str,
1309        params: &AutoDeviceMapParams,
1310    ) -> Result<usize> {
1311        let AutoDeviceMapParams::Text {
1312            max_seq_len,
1313            max_batch_size,
1314        } = params
1315        else {
1316            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1317        };
1318
1319        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1320
1321        Ok(
1322            max_batch_size
1323                * cfg.num_attention_heads
1324                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1325        )
1326    }
1327    fn non_mapped_max_act_size_elems(
1328        &self,
1329        _config: &str,
1330        _params: &AutoDeviceMapParams,
1331    ) -> Result<usize> {
1332        Ok(0)
1333    }
1334
1335    fn non_mapped_size_in_bytes(
1336        &self,
1337        config: &str,
1338        dtype: DType,
1339        weight_pack_factor: usize,
1340        _matformer_config: Option<&MatformerSliceConfig>,
1341    ) -> Result<usize> {
1342        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1343
1344        let elems = {
1345            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1346            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1347            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1348                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1349            } else {
1350                0
1351            };
1352            let norm = cfg.hidden_size;
1353            embed_tokens + lm_head + norm
1354        };
1355        Ok(elems * dtype.size_in_bytes())
1356    }
1357
1358    fn layer_sizes_in_bytes(
1359        &self,
1360        config: &str,
1361        dtype: DType,
1362        weight_pack_factor: usize,
1363        _matformer_config: Option<&MatformerSliceConfig>,
1364    ) -> Result<Vec<usize>> {
1365        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1366
1367        let per_layer_elems = {
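            // Phi-2 uses LayerNorm with both a weight and a bias, hence hidden_size is counted twice.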
1368            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1369
1370            let size_in = cfg.hidden_size;
1371            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1372            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1373            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1374            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1375            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1376            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1377            let (q_norm, k_norm) = if cfg.qk_layernorm {
1378                (cfg.head_dim(), cfg.head_dim())
1379            } else {
1380                (0, 0)
1381            };
1382
1383            let h_size = cfg.hidden_size;
1384            let i_size = cfg.intermediate_size;
1385            let fc1 = h_size * i_size / weight_pack_factor;
1386            let fc2 = h_size * i_size / weight_pack_factor;
1387
1388            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1389        };
1390        Ok(vec![
1391            per_layer_elems * dtype.size_in_bytes();
1392            cfg.num_hidden_layers
1393        ])
1394    }
1395
1396    fn num_layers(&self, config: &str) -> Result<usize> {
1397        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1398
1399        Ok(cfg.num_hidden_layers)
1400    }
1401
1402    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1403        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1404
1405        let cfg = ModelConfigMetadata {
1406            max_seq_len: cfg.max_position_embeddings,
1407            num_layers: cfg.num_hidden_layers,
1408            hidden_size: cfg.hidden_size,
1409            num_kv_heads: cfg.num_key_value_heads(),
1410            num_attn_heads: cfg.num_attention_heads,
1411            sliding_window: None,
1412            k_head_dim: cfg.head_dim(),
1413            v_head_dim: cfg.head_dim(),
1414            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1415        };
1416
1417        Ok(Box::new(cfg))
1418    }
1419}
1420
1421// ======================== Phi3 loader
1422
1423/// [`NormalLoader`] for a Phi 3 model.
1424///
1425/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1426pub struct Phi3Loader;
1427
1428impl NormalModelLoader for Phi3Loader {
1429    fn load(
1430        &self,
1431        config: &str,
1432        vb: ShardedVarBuilder,
1433        normal_loading_metadata: NormalLoadingMetadata,
1434        attention_mechanism: AttentionImplementation,
1435    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1436        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1437
1438        Ok(Box::new(models::phi3::Model::new(
1439            &cfg,
1440            vb,
1441            self.is_gptx(config)?,
1442            normal_loading_metadata,
1443            attention_mechanism,
1444        )?))
1445    }
1446    fn load_xlora(
1447        &self,
1448        config: &str,
1449        vb: ShardedVarBuilder,
1450        lora_config: &[((String, String), LoraConfig)],
1451        xlora_config: Option<XLoraConfig>,
1452        xlora_ordering: Ordering,
1453        normal_loading_metadata: NormalLoadingMetadata,
1454        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1455    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1456        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1457
1458        Ok(Box::new(xlora_models::XLoraPhi3::new(
1459            &cfg,
1460            vb,
1461            lora_config,
1462            xlora_config,
1463            xlora_ordering,
1464            self.is_gptx(config)?,
1465            normal_loading_metadata,
1466            preload_adapters,
1467        )?))
1468    }
1469    fn is_gptx(&self, _: &str) -> Result<bool> {
1470        Ok(true)
1471    }
1472    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1473        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1474
1475        Ok(Box::new(cfg))
1476    }
1477}
1478
1479impl IsqModelLoader for Phi3Loader {
1480    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1481        Ok(vec![
1482            Regex::new(r"lm_head\.(weight|bias)$")?,
1483            // Attention
1484            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1485            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1486            // MLP
1487            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1488            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1489            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1490        ])
1491    }
1492    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1493        self.isq_layer_regexes(config)
1494    }
1495}
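
// A minimal, illustrative sketch (not part of the upstream file) of how the ISQ
// regexes above are expected to behave: projection weights are selected for
// quantization while layer norms are not. The tensor names used here are
// assumptions made for the sake of the example.
#[cfg(test)]
mod phi3_isq_regex_sketch {
    use super::*;

    #[test]
    fn selects_projections_but_not_norms() -> Result<()> {
        // `isq_layer_regexes` ignores the config for Phi3, so an empty JSON object suffices.
        let regexes = Phi3Loader.isq_layer_regexes("{}")?;
        assert!(regexes
            .iter()
            .any(|re| re.is_match("model.layers.0.self_attn.qkv_proj.weight")));
        assert!(!regexes
            .iter()
            .any(|re| re.is_match("model.layers.0.input_layernorm.weight")));
        Ok(())
    }
}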
1496
1497impl DeviceMappedModelLoader for Phi3Loader {
1498    fn mapped_max_act_size_elems(
1499        &self,
1500        config: &str,
1501        params: &AutoDeviceMapParams,
1502    ) -> Result<usize> {
1503        let AutoDeviceMapParams::Text {
1504            max_seq_len,
1505            max_batch_size,
1506        } = params
1507        else {
1508            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1509        };
1510
1511        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1512
1513        Ok(
1514            max_batch_size
1515                * cfg.num_attention_heads
1516                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1517        )
1518    }
1519    fn non_mapped_max_act_size_elems(
1520        &self,
1521        _config: &str,
1522        _params: &AutoDeviceMapParams,
1523    ) -> Result<usize> {
1524        Ok(0)
1525    }
1526
1527    fn non_mapped_size_in_bytes(
1528        &self,
1529        config: &str,
1530        dtype: DType,
1531        weight_pack_factor: usize,
1532        _matformer_config: Option<&MatformerSliceConfig>,
1533    ) -> Result<usize> {
1534        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1535
1536        let elems = {
1537            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1538            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1539            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1540                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1541            } else {
1542                0
1543            };
1544            let norm = cfg.hidden_size;
1545            embed_tokens + lm_head + norm
1546        };
1547        Ok(elems * dtype.size_in_bytes())
1548    }
1549
1550    fn layer_sizes_in_bytes(
1551        &self,
1552        config: &str,
1553        dtype: DType,
1554        weight_pack_factor: usize,
1555        _matformer_config: Option<&MatformerSliceConfig>,
1556    ) -> Result<Vec<usize>> {
1557        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1558
1559        let per_layer_elems = {
1560            let input_layernorm = cfg.hidden_size;
1561            let post_attention_layernorm = cfg.hidden_size;
1562
1563            let size_in = cfg.hidden_size;
1564            let head_dim = cfg.head_dim();
1565            let op_size =
1566                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1567            let qkv_proj = size_in * op_size / weight_pack_factor;
1568            let o_proj =
1569                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1570
1571            let h_size = cfg.hidden_size;
1572            let i_size = cfg.intermediate_size;
1573            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1574            let down_proj = h_size * i_size / weight_pack_factor;
1575
1576            input_layernorm
1577                + post_attention_layernorm
1578                + qkv_proj
1579                + o_proj
1580                + gate_up_proj
1581                + down_proj
1582        };
1583        Ok(vec![
1584            per_layer_elems * dtype.size_in_bytes();
1585            cfg.num_hidden_layers
1586        ])
1587    }
1588
1589    fn num_layers(&self, config: &str) -> Result<usize> {
1590        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1591
1592        Ok(cfg.num_hidden_layers)
1593    }
1594
1595    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1596        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1597
1598        let cfg = ModelConfigMetadata {
1599            max_seq_len: cfg.max_position_embeddings,
1600            num_layers: cfg.num_hidden_layers,
1601            hidden_size: cfg.hidden_size,
1602            num_kv_heads: cfg.num_key_value_heads,
1603            num_attn_heads: cfg.num_attention_heads,
1604            sliding_window: cfg.sliding_window,
1605            k_head_dim: cfg.head_dim(),
1606            v_head_dim: cfg.head_dim(),
1607            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1608        };
1609
1610        Ok(Box::new(cfg))
1611    }
1612}
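
// A small sketch (helper name and defaults are assumptions, not crate API) showing
// how the per-layer and non-mapped size estimates above could be combined into a
// rough total weight footprint for Phi3 when planning device mapping. It assumes
// no ISQ packing (weight_pack_factor = 1) and no Matformer slicing.
#[allow(dead_code)]
fn approx_phi3_weight_bytes(config_json: &str, dtype: DType) -> Result<usize> {
    let loader = Phi3Loader;
    // Sum of the repeated per-layer transformer blocks...
    let mapped: usize = loader
        .layer_sizes_in_bytes(config_json, dtype, 1, None)?
        .into_iter()
        .sum();
    // ...plus embeddings, final norm, and (if untied) the lm_head.
    let non_mapped = loader.non_mapped_size_in_bytes(config_json, dtype, 1, None)?;
    Ok(mapped + non_mapped)
}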
1613
1614// ======================== Qwen2 loader
1615
1616/// [`NormalLoader`] for a Qwen 2 model.
1617///
1618/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1619pub struct Qwen2Loader;
1620
1621impl NormalModelLoader for Qwen2Loader {
1622    fn load(
1623        &self,
1624        config: &str,
1625        vb: ShardedVarBuilder,
1626        normal_loading_metadata: NormalLoadingMetadata,
1627        attention_mechanism: AttentionImplementation,
1628    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1629        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1630
1631        Ok(Box::new(models::qwen2::Model::new(
1632            &cfg,
1633            vb,
1634            self.is_gptx(config)?,
1635            normal_loading_metadata,
1636            attention_mechanism,
1637        )?))
1638    }
1639    fn load_xlora(
1640        &self,
1641        _config: &str,
1642        _vb: ShardedVarBuilder,
1643        _lora_config: &[((String, String), LoraConfig)],
1644        _xlora_config: Option<XLoraConfig>,
1645        _xlora_ordering: Ordering,
1646        _normal_loading_metadata: NormalLoadingMetadata,
1647        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1648    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1649        todo!()
1650    }
1651    fn is_gptx(&self, _: &str) -> Result<bool> {
1652        Ok(true)
1653    }
1654    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1655        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1656
1657        Ok(Box::new(cfg))
1658    }
1659}
1660
1661impl IsqModelLoader for Qwen2Loader {
1662    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1663        Ok(vec![
1664            Regex::new(r"lm_head\.(weight|bias)$")?,
1665            // Attention
1666            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1667            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1668            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1669            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1670            // MLP
1671            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1672            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1673            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1674        ])
1675    }
1676    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1677        self.isq_layer_regexes(config)
1678    }
1679}
1680
1681impl DeviceMappedModelLoader for Qwen2Loader {
1682    fn mapped_max_act_size_elems(
1683        &self,
1684        config: &str,
1685        params: &AutoDeviceMapParams,
1686    ) -> Result<usize> {
1687        let AutoDeviceMapParams::Text {
1688            max_seq_len,
1689            max_batch_size,
1690        } = params
1691        else {
1692            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1693        };
1694
1695        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1696
1697        Ok(
1698            max_batch_size
1699                * cfg.num_attention_heads
1700                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1701        )
1702    }
1703    fn non_mapped_max_act_size_elems(
1704        &self,
1705        _config: &str,
1706        _params: &AutoDeviceMapParams,
1707    ) -> Result<usize> {
1708        Ok(0)
1709    }
1710
1711    fn non_mapped_size_in_bytes(
1712        &self,
1713        config: &str,
1714        dtype: DType,
1715        weight_pack_factor: usize,
1716        _matformer_config: Option<&MatformerSliceConfig>,
1717    ) -> Result<usize> {
1718        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1719
1720        let elems = {
1721            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1722            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1723            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1724                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1725            } else {
1726                0
1727            };
1728            let norm = cfg.hidden_size;
1729            embed_tokens + lm_head + norm
1730        };
1731        Ok(elems * dtype.size_in_bytes())
1732    }
1733
1734    fn layer_sizes_in_bytes(
1735        &self,
1736        config: &str,
1737        dtype: DType,
1738        weight_pack_factor: usize,
1739        _matformer_config: Option<&MatformerSliceConfig>,
1740    ) -> Result<Vec<usize>> {
1741        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1742
1743        let per_layer_elems = {
1744            let input_layernorm = cfg.hidden_size;
1745            let post_attention_layernorm = cfg.hidden_size;
1746
1747            let size_in = cfg.hidden_size;
1748            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1749            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1750            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1751            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1752            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1753            let o_proj = size_q * size_in / weight_pack_factor;
1754
1755            let h_size = cfg.hidden_size;
1756            let i_size = cfg.intermediate_size;
1757            let gate_proj = h_size * i_size / weight_pack_factor;
1758            let up_proj = h_size * i_size / weight_pack_factor;
1759            let down_proj = i_size * h_size / weight_pack_factor;
1760
1761            input_layernorm
1762                + post_attention_layernorm
1763                + q_proj
1764                + k_proj
1765                + v_proj
1766                + o_proj
1767                + gate_proj
1768                + up_proj
1769                + down_proj
1770        };
1771        Ok(vec![
1772            per_layer_elems * dtype.size_in_bytes();
1773            cfg.num_hidden_layers
1774        ])
1775    }
1776
1777    fn num_layers(&self, config: &str) -> Result<usize> {
1778        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1779
1780        Ok(cfg.num_hidden_layers)
1781    }
1782
1783    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1784        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1785
1786        let cfg = ModelConfigMetadata {
1787            max_seq_len: cfg.max_position_embeddings,
1788            num_layers: cfg.num_hidden_layers,
1789            hidden_size: cfg.hidden_size,
1790            num_kv_heads: cfg.num_key_value_heads,
1791            num_attn_heads: cfg.num_attention_heads,
1792            sliding_window: cfg.sliding_window,
1793            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1794            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1795            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1796        };
1797
1798        Ok(Box::new(cfg))
1799    }
1800}
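
// Illustrative sketch only: `mapped_max_act_size_elems` above counts the elements
// of the largest attention-score activation (batch * heads * min(seq_len, chunk)^2),
// so converting it to bytes for a given dtype is a single multiplication. The
// helper name below is an assumption for illustration, not crate API.
#[allow(dead_code)]
fn approx_qwen2_attn_scratch_bytes(
    config_json: &str,
    params: &AutoDeviceMapParams,
    dtype: DType,
) -> Result<usize> {
    let elems = Qwen2Loader.mapped_max_act_size_elems(config_json, params)?;
    Ok(elems * dtype.size_in_bytes())
}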
1801
1802// ======================== Gemma2 loader
1803
1804/// [`NormalLoader`] for a Gemma2 model.
1805///
1806/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
1807pub struct Gemma2Loader;
1808
1809impl NormalModelLoader for Gemma2Loader {
1810    fn load(
1811        &self,
1812        config: &str,
1813        vb: ShardedVarBuilder,
1814        normal_loading_metadata: NormalLoadingMetadata,
1815        attention_mechanism: AttentionImplementation,
1816    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1817        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1818
1819        Ok(Box::new(models::gemma2::Model::new(
1820            &cfg,
1821            vb,
1822            self.is_gptx(config)?,
1823            normal_loading_metadata,
1824            attention_mechanism,
1825        )?))
1826    }
1827    fn load_xlora(
1828        &self,
1829        config: &str,
1830        vb: ShardedVarBuilder,
1831        lora_config: &[((String, String), LoraConfig)],
1832        xlora_config: Option<XLoraConfig>,
1833        xlora_ordering: Ordering,
1834        normal_loading_metadata: NormalLoadingMetadata,
1835        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1836    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1837        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1838
1839        Ok(Box::new(xlora_models::XLoraGemma2::new(
1840            &cfg,
1841            vb,
1842            lora_config,
1843            xlora_config,
1844            xlora_ordering,
1845            self.is_gptx(config)?,
1846            normal_loading_metadata,
1847            preload_adapters,
1848        )?))
1849    }
1850    fn is_gptx(&self, _: &str) -> Result<bool> {
1851        Ok(true)
1852    }
1853    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1854        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1855
1856        Ok(Box::new(cfg))
1857    }
1858}
1859
1860impl IsqModelLoader for Gemma2Loader {
1861    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1862        Ok(vec![
1863            Regex::new(r"lm_head\.(weight|bias)$")?,
1864            // Attention
1865            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1866            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1867            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1868            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1869            // MLP
1870            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1871            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1872            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1873        ])
1874    }
1875    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1876        self.isq_layer_regexes(config)
1877    }
1878}
1879
1880impl DeviceMappedModelLoader for Gemma2Loader {
1881    fn mapped_max_act_size_elems(
1882        &self,
1883        config: &str,
1884        params: &AutoDeviceMapParams,
1885    ) -> Result<usize> {
1886        let AutoDeviceMapParams::Text {
1887            max_seq_len,
1888            max_batch_size,
1889        } = params
1890        else {
1891            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1892        };
1893
1894        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1895
1896        Ok(
1897            max_batch_size
1898                * cfg.num_attention_heads
1899                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1900        )
1901    }
1902    fn non_mapped_max_act_size_elems(
1903        &self,
1904        _config: &str,
1905        _params: &AutoDeviceMapParams,
1906    ) -> Result<usize> {
1907        Ok(0)
1908    }
1909
1910    fn non_mapped_size_in_bytes(
1911        &self,
1912        config: &str,
1913        dtype: DType,
1914        weight_pack_factor: usize,
1915        _matformer_config: Option<&MatformerSliceConfig>,
1916    ) -> Result<usize> {
1917        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1918
1919        let elems = {
1920            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1921            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1922            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1923                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1924            } else {
1925                0
1926            };
1927            let norm = cfg.hidden_size;
1928            embed_tokens + lm_head + norm
1929        };
1930        Ok(elems * dtype.size_in_bytes())
1931    }
1932
1933    fn layer_sizes_in_bytes(
1934        &self,
1935        config: &str,
1936        dtype: DType,
1937        weight_pack_factor: usize,
1938        _matformer_config: Option<&MatformerSliceConfig>,
1939    ) -> Result<Vec<usize>> {
1940        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1941
1942        let per_layer_elems = {
1943            let input_layernorm = cfg.hidden_size;
1944            let post_attention_layernorm = cfg.hidden_size;
1945
1946            let size_in = cfg.hidden_size;
1947            let size_q = cfg.head_dim * cfg.num_attention_heads;
1948            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1949            let q_proj =
1950                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1951            let k_proj =
1952                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1953            let v_proj =
1954                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1955            let o_proj =
1956                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1957
1958            let h_size = cfg.hidden_size;
1959            let i_size = cfg.intermediate_size;
1960            let gate_proj = h_size * i_size / weight_pack_factor;
1961            let up_proj = h_size * i_size / weight_pack_factor;
1962            let down_proj = i_size * h_size / weight_pack_factor;
1963
1964            input_layernorm
1965                + post_attention_layernorm
1966                + q_proj
1967                + k_proj
1968                + v_proj
1969                + o_proj
1970                + gate_proj
1971                + up_proj
1972                + down_proj
1973        };
1974        Ok(vec![
1975            per_layer_elems * dtype.size_in_bytes();
1976            cfg.num_hidden_layers
1977        ])
1978    }
1979
1980    fn num_layers(&self, config: &str) -> Result<usize> {
1981        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1982
1983        Ok(cfg.num_hidden_layers)
1984    }
1985    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1986        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1987
1988        let cfg = ModelConfigMetadata {
1989            max_seq_len: cfg.max_position_embeddings,
1990            num_layers: cfg.num_hidden_layers,
1991            hidden_size: cfg.hidden_size,
1992            num_kv_heads: cfg.num_key_value_heads,
1993            num_attn_heads: cfg.num_attention_heads,
1994            sliding_window: None, // None to be more forgiving; some configs do not set it
1995            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1996            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1997            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1998        };
1999
2000        Ok(Box::new(cfg))
2001    }
2002}
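
// For readers following the element counts above: `bias_if!(cond, n)` is assumed to
// contribute `n` bias elements only when `cond` is true, i.e. roughly the explicit
// form sketched below. This is an illustrative assumption, not the macro's actual
// definition (which lives elsewhere in the crate).
#[allow(dead_code)]
fn bias_elems_sketch(has_bias: bool, out_features: usize) -> usize {
    if has_bias {
        out_features
    } else {
        0
    }
}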
2003
2004// ======================== Starcoder2 loader
2005
2006/// [`NormalLoader`] for a Starcoder2 model.
2007///
2008/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2009pub struct Starcoder2Loader;
2010
2011impl NormalModelLoader for Starcoder2Loader {
2012    fn load(
2013        &self,
2014        config: &str,
2015        vb: ShardedVarBuilder,
2016        normal_loading_metadata: NormalLoadingMetadata,
2017        attention_mechanism: AttentionImplementation,
2018    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2019        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2020
2021        Ok(Box::new(models::starcoder2::Model::new(
2022            &cfg,
2023            vb,
2024            self.is_gptx(config)?,
2025            normal_loading_metadata,
2026            attention_mechanism,
2027        )?))
2028    }
2029    fn load_xlora(
2030        &self,
2031        config: &str,
2032        vb: ShardedVarBuilder,
2033        lora_config: &[((String, String), LoraConfig)],
2034        xlora_config: Option<XLoraConfig>,
2035        xlora_ordering: Ordering,
2036        normal_loading_metadata: NormalLoadingMetadata,
2037        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2038    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2039        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2040
2041        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2042            &cfg,
2043            vb,
2044            lora_config,
2045            xlora_config,
2046            xlora_ordering,
2047            self.is_gptx(config)?,
2048            normal_loading_metadata,
2049            preload_adapters,
2050        )?))
2051    }
2052    fn is_gptx(&self, _: &str) -> Result<bool> {
2053        Ok(true)
2054    }
2055    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2056        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2057
2058        Ok(Box::new(cfg))
2059    }
2060}
2061
2062impl IsqModelLoader for Starcoder2Loader {
2063    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2064        Ok(vec![
2065            Regex::new(r"lm_head\.(weight|bias)$")?,
2066            // Attention
2067            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2068            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2069            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2070            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2071            // MLP
2072            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2073            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2074        ])
2075    }
2076    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2077        self.isq_layer_regexes(config)
2078    }
2079}
2080
2081impl DeviceMappedModelLoader for Starcoder2Loader {
2082    fn mapped_max_act_size_elems(
2083        &self,
2084        config: &str,
2085        params: &AutoDeviceMapParams,
2086    ) -> Result<usize> {
2087        let AutoDeviceMapParams::Text {
2088            max_seq_len,
2089            max_batch_size,
2090        } = params
2091        else {
2092            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2093        };
2094
2095        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2096
2097        Ok(
2098            max_batch_size
2099                * cfg.num_attention_heads
2100                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2101        )
2102    }
2103    fn non_mapped_max_act_size_elems(
2104        &self,
2105        _config: &str,
2106        _params: &AutoDeviceMapParams,
2107    ) -> Result<usize> {
2108        Ok(0)
2109    }
2110
2111    fn non_mapped_size_in_bytes(
2112        &self,
2113        config: &str,
2114        dtype: DType,
2115        weight_pack_factor: usize,
2116        _matformer_config: Option<&MatformerSliceConfig>,
2117    ) -> Result<usize> {
2118        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2119
2120        let elems = {
2121            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2122            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2123            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2124                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2125            } else {
2126                0
2127            };
2128            let norm = cfg.hidden_size + cfg.hidden_size;
2129            embed_tokens + lm_head + norm
2130        };
2131        Ok(elems * dtype.size_in_bytes())
2132    }
2133
2134    fn layer_sizes_in_bytes(
2135        &self,
2136        config: &str,
2137        dtype: DType,
2138        weight_pack_factor: usize,
2139        _matformer_config: Option<&MatformerSliceConfig>,
2140    ) -> Result<Vec<usize>> {
2141        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2142
2143        let per_layer_elems = {
2144            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2145            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2146
2147            let size_in = cfg.hidden_size;
2148            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2149            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2150            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2151            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2152            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2153            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2154
2155            let h_size = cfg.hidden_size;
2156            let i_size = cfg.intermediate_size;
2157            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2158            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2159
2160            input_layernorm
2161                + post_attention_layernorm
2162                + q_proj
2163                + k_proj
2164                + v_proj
2165                + o_proj
2166                + fc1
2167                + fc2
2168        };
2169        Ok(vec![
2170            per_layer_elems * dtype.size_in_bytes();
2171            cfg.num_hidden_layers
2172        ])
2173    }
2174
2175    fn num_layers(&self, config: &str) -> Result<usize> {
2176        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2177
2178        Ok(cfg.num_hidden_layers)
2179    }
2180
2181    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2182        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2183
2184        let cfg = ModelConfigMetadata {
2185            max_seq_len: cfg.max_position_embeddings,
2186            num_layers: cfg.num_hidden_layers,
2187            hidden_size: cfg.hidden_size,
2188            num_kv_heads: cfg.num_key_value_heads,
2189            num_attn_heads: cfg.num_attention_heads,
2190            sliding_window: cfg.sliding_window,
2191            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2192            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2193            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2194        };
2195
2196        Ok(Box::new(cfg))
2197    }
2198}
2199
2200// ======================== Phi3.5 MoE loader
2201
2202/// [`NormalLoader`] for a Phi 3.5 MoE model.
2203///
2204/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2205pub struct Phi3_5MoELoader;
2206
2207impl NormalModelLoader for Phi3_5MoELoader {
2208    fn load(
2209        &self,
2210        config: &str,
2211        vb: ShardedVarBuilder,
2212        normal_loading_metadata: NormalLoadingMetadata,
2213        attention_mechanism: AttentionImplementation,
2214    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2215        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2216
2217        Ok(Box::new(models::phi3_5_moe::Model::new(
2218            &cfg,
2219            vb,
2220            self.is_gptx(config)?,
2221            normal_loading_metadata,
2222            attention_mechanism,
2223        )?))
2224    }
2225    fn load_xlora(
2226        &self,
2227        config: &str,
2228        vb: ShardedVarBuilder,
2229        lora_config: &[((String, String), LoraConfig)],
2230        xlora_config: Option<XLoraConfig>,
2231        xlora_ordering: Ordering,
2232        normal_loading_metadata: NormalLoadingMetadata,
2233        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2234    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2235        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2236
2237        Ok(Box::new(xlora_models::XLoraPhi3::new(
2238            &cfg,
2239            vb,
2240            lora_config,
2241            xlora_config,
2242            xlora_ordering,
2243            self.is_gptx(config)?,
2244            normal_loading_metadata,
2245            preload_adapters,
2246        )?))
2247    }
2248    fn is_gptx(&self, _: &str) -> Result<bool> {
2249        Ok(true)
2250    }
2251    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2252        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2253
2254        Ok(Box::new(cfg))
2255    }
2256}
2257
2258impl IsqModelLoader for Phi3_5MoELoader {
2259    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2260        Ok(vec![
2261            Regex::new(r"lm_head\.(weight|bias)$")?,
2262            // Attention
2263            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2264            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2265            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2266            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2267            // MLP
2268            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2269            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2270            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2271        ])
2272    }
2273    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2274        self.isq_layer_regexes(config)
2275    }
2276
2277    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2278        Ok(vec![
2279            Regex::new(r"lm_head\.(weight|bias)$")?,
2280            // MLP
2281            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2282            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2283            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2284        ])
2285    }
2286    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2287        self.isq_layer_regexes_moqe(config)
2288    }
2289}
2290
2291impl DeviceMappedModelLoader for Phi3_5MoELoader {
2292    fn mapped_max_act_size_elems(
2293        &self,
2294        config: &str,
2295        params: &AutoDeviceMapParams,
2296    ) -> Result<usize> {
2297        let AutoDeviceMapParams::Text {
2298            max_seq_len,
2299            max_batch_size,
2300        } = params
2301        else {
2302            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2303        };
2304
2305        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2306
2307        Ok(
2308            max_batch_size
2309                * cfg.num_attention_heads
2310                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2311        )
2312    }
2313    fn non_mapped_max_act_size_elems(
2314        &self,
2315        _config: &str,
2316        _params: &AutoDeviceMapParams,
2317    ) -> Result<usize> {
2318        Ok(0)
2319    }
2320
2321    fn non_mapped_size_in_bytes(
2322        &self,
2323        config: &str,
2324        dtype: DType,
2325        weight_pack_factor: usize,
2326        _matformer_config: Option<&MatformerSliceConfig>,
2327    ) -> Result<usize> {
2328        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2329
2330        let elems = {
2331            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2332            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2333            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2334                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2335            } else {
2336                0
2337            };
2338            let norm = cfg.hidden_size;
2339            embed_tokens + lm_head + norm
2340        };
2341        Ok(elems * dtype.size_in_bytes())
2342    }
2343
2344    fn layer_sizes_in_bytes(
2345        &self,
2346        config: &str,
2347        dtype: DType,
2348        weight_pack_factor: usize,
2349        _matformer_config: Option<&MatformerSliceConfig>,
2350    ) -> Result<Vec<usize>> {
2351        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2352
2353        let per_layer_elems = {
2354            let input_layernorm = cfg.hidden_size;
2355            let post_attention_layernorm = cfg.hidden_size;
2356
2357            let size_in = cfg.hidden_size;
2358            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2359            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2360            let q_proj =
2361                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2362            let k_proj =
2363                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2364            let v_proj =
2365                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2366            let o_proj =
2367                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2368
2369            let moe_block = {
2370                let gate = cfg.hidden_size * cfg.num_local_experts;
2371                // Assume the expert weights are quantized with the same weight pack factor
2372                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2373                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2374                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2375                gate + cfg.num_local_experts * w1
2376                    + cfg.num_local_experts * w2
2377                    + cfg.num_local_experts * w3
2378            };
2379
2380            input_layernorm
2381                + post_attention_layernorm
2382                + q_proj
2383                + k_proj
2384                + v_proj
2385                + o_proj
2386                + moe_block
2387        };
2388        Ok(vec![
2389            per_layer_elems * dtype.size_in_bytes();
2390            cfg.num_hidden_layers
2391        ])
2392    }
2393
2394    fn num_layers(&self, config: &str) -> Result<usize> {
2395        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2396
2397        Ok(cfg.num_hidden_layers)
2398    }
2399
2400    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2401        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2402
2403        let cfg = ModelConfigMetadata {
2404            max_seq_len: cfg.max_position_embeddings,
2405            num_layers: cfg.num_hidden_layers,
2406            hidden_size: cfg.hidden_size,
2407            num_kv_heads: cfg.num_key_value_heads,
2408            num_attn_heads: cfg.num_attention_heads,
2409            sliding_window: cfg.sliding_window,
2410            k_head_dim: cfg.head_dim(),
2411            v_head_dim: cfg.head_dim(),
2412            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2413        };
2414
2415        Ok(Box::new(cfg))
2416    }
2417}
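
// A worked sketch of the MoE block arithmetic used above, with hypothetical
// dimensions (the numbers are illustrative assumptions, not the actual
// Phi-3.5-MoE configuration): each routed expert holds three
// hidden_size x intermediate_size matrices (w1, w2, w3), plus one routing gate
// of hidden_size x num_local_experts per layer.
#[allow(dead_code)]
fn phi3_5_moe_block_elems_sketch() -> usize {
    let hidden_size = 4096usize; // hypothetical
    let intermediate_size = 6400usize; // hypothetical
    let num_local_experts = 16usize; // hypothetical
    let weight_pack_factor = 1usize; // no ISQ packing assumed

    let gate = hidden_size * num_local_experts;
    let per_expert = 3 * (hidden_size * intermediate_size / weight_pack_factor);
    gate + num_local_experts * per_expert
}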
2418
2419/// [`NormalLoader`] for a DeepSeekV2 model.
2420///
2421/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2422pub struct DeepSeekV2Loader;
2423
2424impl NormalModelLoader for DeepSeekV2Loader {
2425    fn load(
2426        &self,
2427        config: &str,
2428        vb: ShardedVarBuilder,
2429        normal_loading_metadata: NormalLoadingMetadata,
2430        attention_mechanism: AttentionImplementation,
2431    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2432        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2433
2434        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2435            &cfg,
2436            vb,
2437            self.is_gptx(config)?,
2438            normal_loading_metadata,
2439            attention_mechanism,
2440        )?))
2441    }
2442    fn load_xlora(
2443        &self,
2444        _config: &str,
2445        _vb: ShardedVarBuilder,
2446        _lora_config: &[((String, String), LoraConfig)],
2447        _xlora_config: Option<XLoraConfig>,
2448        _xlora_ordering: Ordering,
2449        _normal_loading_metadata: NormalLoadingMetadata,
2450        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2451    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2452        todo!()
2453    }
2454    fn is_gptx(&self, _: &str) -> Result<bool> {
2455        Ok(true)
2456    }
2457    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2458        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2459        Ok(Box::new(cfg))
2460    }
2461}
2462
2463impl IsqModelLoader for DeepSeekV2Loader {
2464    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2465        let mut data = vec![
2466            Regex::new(r"lm_head\.(weight|bias)$")?,
2467            // Attention
2468            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2469            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2470            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2471        ];
2472        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2473        if cfg.q_lora_rank.is_some() {
2474            data.extend(vec![
2475                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2476                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2477            ]);
2478        } else {
2479            data.push(Regex::new(
2480                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2481            )?);
2482        }
2483        for layer_idx in 0..cfg.num_hidden_layers {
2484            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2485                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2486            }) {
2487                for i in 0..n_routed_experts {
2488                    data.extend(vec![
2489                        Regex::new(&format!(
2490                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2491                        ))?,
2492                        Regex::new(&format!(
2493                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2494                        ))?,
2495                        Regex::new(&format!(
2496                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2497                        ))?,
2498                    ]);
2499                }
2500                if cfg.n_shared_experts.is_some() {
2501                    data.extend(vec![
2502                        Regex::new(&format!(
2503                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2504                        ))?,
2505                        Regex::new(&format!(
2506                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2507                        ))?,
2508                        Regex::new(&format!(
2509                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2510                        ))?,
2511                    ]);
2512                }
2513            } else {
2514                data.extend(vec![
2515                    Regex::new(&format!(
2516                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2517                    ))?,
2518                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2519                    Regex::new(&format!(
2520                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2521                    ))?,
2522                ]);
2523            };
2524        }
2525        Ok(data)
2526    }
2527    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2528        self.isq_layer_regexes(config)
2529    }
2530
2531    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2532        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2533        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2534        for layer_idx in 0..cfg.num_hidden_layers {
2535            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2536                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2537            }) {
2538                for i in 0..n_routed_experts {
2539                    data.extend(vec![
2540                        Regex::new(&format!(
2541                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2542                        ))?,
2543                        Regex::new(&format!(
2544                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2545                        ))?,
2546                        Regex::new(&format!(
2547                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2548                        ))?,
2549                    ]);
2550                }
2551                if cfg.n_shared_experts.is_some() {
2552                    data.extend(vec![
2553                        Regex::new(&format!(
2554                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2555                        ))?,
2556                        Regex::new(&format!(
2557                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2558                        ))?,
2559                        Regex::new(&format!(
2560                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2561                        ))?,
2562                    ]);
2563                }
2564            } else {
2565                data.extend(vec![
2566                    Regex::new(&format!(
2567                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2568                    ))?,
2569                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2570                    Regex::new(&format!(
2571                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2572                    ))?,
2573                ]);
2574            };
2575        }
2576        Ok(data)
2577    }
2578    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2579        self.isq_layer_regexes_moqe(config)
2580    }
2581}
2582
2583impl DeviceMappedModelLoader for DeepSeekV2Loader {
2584    fn mapped_max_act_size_elems(
2585        &self,
2586        config: &str,
2587        params: &AutoDeviceMapParams,
2588    ) -> Result<usize> {
2589        let AutoDeviceMapParams::Text {
2590            max_seq_len,
2591            max_batch_size,
2592        } = params
2593        else {
2594            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2595        };
2596
2597        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2598
2599        Ok(
2600            max_batch_size
2601                * cfg.num_attention_heads
2602                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2603        )
2604    }
2605    fn non_mapped_max_act_size_elems(
2606        &self,
2607        _config: &str,
2608        _params: &AutoDeviceMapParams,
2609    ) -> Result<usize> {
2610        Ok(0)
2611    }
2612
2613    fn non_mapped_size_in_bytes(
2614        &self,
2615        config: &str,
2616        dtype: DType,
2617        weight_pack_factor: usize,
2618        _matformer_config: Option<&MatformerSliceConfig>,
2619    ) -> Result<usize> {
2620        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2621        let elems = {
2622            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2623            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2624            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2625                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2626            } else {
2627                0
2628            };
2629            let norm = cfg.hidden_size;
2630            embed_tokens + lm_head + norm
2631        };
2632        Ok(elems * dtype.size_in_bytes())
2633    }
2634
2635    fn layer_sizes_in_bytes(
2636        &self,
2637        config: &str,
2638        dtype: DType,
2639        weight_pack_factor: usize,
2640        _matformer_config: Option<&MatformerSliceConfig>,
2641    ) -> Result<Vec<usize>> {
2642        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2643        let mut per_layer_elems = Vec::new();
2644
2645        for layer_idx in 0..cfg.num_hidden_layers {
2646            let input_layernorm = cfg.hidden_size;
2647            let post_attention_layernorm = cfg.hidden_size;
2648
2649            let q_proj = match cfg.q_lora_rank {
2650                Some(lora_rank) => {
2651                    let a = cfg.hidden_size * lora_rank;
2652                    let norm = lora_rank;
2653                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2654                    a + norm + b
2655                }
2656                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2657            };
2658            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2659                / weight_pack_factor
2660                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2661            let kv_a_layernorm = cfg.kv_lora_rank;
2662            let kv_b_proj = cfg.kv_lora_rank
2663                * cfg.num_attention_heads
2664                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2665                / weight_pack_factor;
2666            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2667                / weight_pack_factor
2668                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2669
2670            let moe_block = {
2671                let mut sum = 0;
2672                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2673                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2674                }) {
2675                    let h_size = cfg.hidden_size;
2676                    let gate_proj =
2677                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2678                    let up_proj =
2679                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2680                    let down_proj =
2681                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2682                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2683                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2684                            / weight_pack_factor;
2685                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2686                            / weight_pack_factor;
2687                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2688                            / weight_pack_factor;
2689                        gate_proj + up_proj + down_proj
2690                    } else {
2691                        0
2692                    };
2693                    let gate_weight = n_routed_experts * cfg.hidden_size;
2694                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2695                } else {
2696                    let h_size = cfg.hidden_size;
2697                    let i_size = cfg.intermediate_size;
2698                    let gate_proj = h_size * i_size / weight_pack_factor;
2699                    let up_proj = h_size * i_size / weight_pack_factor;
2700                    let down_proj = i_size * h_size / weight_pack_factor;
2701                    sum += gate_proj + up_proj + down_proj;
2702                }
2703                sum
2704            };
2705
2706            per_layer_elems.push(
2707                input_layernorm
2708                    + post_attention_layernorm
2709                    + q_proj
2710                    + kv_a_layernorm
2711                    + kv_a_proj_with_mqa
2712                    + kv_b_proj
2713                    + o_proj
2714                    + moe_block,
2715            );
2716        }
2717
2718        Ok(per_layer_elems
2719            .into_iter()
2720            .map(|x| x * dtype.size_in_bytes())
2721            .collect())
2722    }
2723
2724    fn num_layers(&self, config: &str) -> Result<usize> {
2725        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2726        Ok(cfg.num_hidden_layers)
2727    }
2728
2729    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2730        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2731
2732        let cfg = ModelConfigMetadata {
2733            max_seq_len: cfg.max_position_embeddings,
2734            num_layers: cfg.num_hidden_layers,
2735            hidden_size: cfg.hidden_size,
2736            num_kv_heads: cfg.num_attention_heads,
2737            num_attn_heads: cfg.num_attention_heads,
2738            sliding_window: None,
2739            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2740            v_head_dim: cfg.v_head_dim,
2741            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2742        };
2743
2744        Ok(Box::new(cfg))
2745    }
2746}
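
// Illustrative sketch of the two query-projection branches counted above for
// DeepSeek V2: with `q_lora_rank` set, the query path is factored into a
// down-projection, a norm, and an up-projection; without it, it is a single dense
// matrix. Parameter names mirror the config fields; concrete values would come
// from the parsed config. The helper is an illustration, not crate API.
#[allow(dead_code)]
fn deepseek_v2_q_proj_elems_sketch(
    hidden_size: usize,
    num_attention_heads: usize,
    q_head_dim: usize,
    q_lora_rank: Option<usize>,
) -> usize {
    match q_lora_rank {
        Some(rank) => hidden_size * rank + rank + (num_attention_heads * q_head_dim) * rank,
        None => (num_attention_heads * q_head_dim) * hidden_size,
    }
}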
2747
2748/// [`NormalLoader`] for a DeepSeekV3 model.
2749///
2750/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
2751pub struct DeepSeekV3Loader;
2752
2753impl NormalModelLoader for DeepSeekV3Loader {
2754    fn load(
2755        &self,
2756        config: &str,
2757        vb: ShardedVarBuilder,
2758        normal_loading_metadata: NormalLoadingMetadata,
2759        attention_mechanism: AttentionImplementation,
2760    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2761        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2762        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2763            &cfg,
2764            vb,
2765            self.is_gptx(config)?,
2766            normal_loading_metadata,
2767            attention_mechanism,
2768        )?))
2769    }
2770    fn load_xlora(
2771        &self,
2772        _config: &str,
2773        _vb: ShardedVarBuilder,
2774        _lora_config: &[((String, String), LoraConfig)],
2775        _xlora_config: Option<XLoraConfig>,
2776        _xlora_ordering: Ordering,
2777        _normal_loading_metadata: NormalLoadingMetadata,
2778        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2779    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2780        todo!()
2781    }
2782    fn is_gptx(&self, _: &str) -> Result<bool> {
2783        Ok(true)
2784    }
2785    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2786        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2787        Ok(Box::new(cfg))
2788    }
2789}
2790
2791impl IsqModelLoader for DeepSeekV3Loader {
2792    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2793        let mut data = vec![
2794            Regex::new(r"lm_head\.(weight|bias)$")?,
2795            // Attention
2796            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2797            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2798            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2799        ];
2800        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2801        if cfg.q_lora_rank.is_some() {
2802            data.extend(vec![
2803                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2804                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2805            ]);
2806        } else {
2807            data.push(Regex::new(
2808                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2809            )?);
2810        }
2811        for layer_idx in 0..cfg.num_hidden_layers {
2812            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2813                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2814            }) {
2815                for i in 0..n_routed_experts {
2816                    data.extend(vec![
2817                        Regex::new(&format!(
2818                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2819                        ))?,
2820                        Regex::new(&format!(
2821                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2822                        ))?,
2823                        Regex::new(&format!(
2824                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2825                        ))?,
2826                    ]);
2827                }
2828                if cfg.n_shared_experts.is_some() {
2829                    data.extend(vec![
2830                        Regex::new(&format!(
2831                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2832                        ))?,
2833                        Regex::new(&format!(
2834                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2835                        ))?,
2836                        Regex::new(&format!(
2837                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2838                        ))?,
2839                    ]);
2840                }
2841            } else {
2842                data.extend(vec![
2843                    Regex::new(&format!(
2844                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2845                    ))?,
2846                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2847                    Regex::new(&format!(
2848                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2849                    ))?,
2850                ]);
2851            };
2852        }
2853        Ok(data)
2854    }
2855    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2856        self.isq_layer_regexes(config)
2857    }
2858
2859    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2860        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2861        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2862        for layer_idx in 0..cfg.num_hidden_layers {
2863            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2864                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2865            }) {
2866                for i in 0..n_routed_experts {
2867                    data.extend(vec![
2868                        Regex::new(&format!(
2869                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2870                        ))?,
2871                        Regex::new(&format!(
2872                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2873                        ))?,
2874                        Regex::new(&format!(
2875                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2876                        ))?,
2877                    ]);
2878                }
2879                if cfg.n_shared_experts.is_some() {
2880                    data.extend(vec![
2881                        Regex::new(&format!(
2882                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2883                        ))?,
2884                        Regex::new(&format!(
2885                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2886                        ))?,
2887                        Regex::new(&format!(
2888                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2889                        ))?,
2890                    ]);
2891                }
2892            } else {
2893                data.extend(vec![
2894                    Regex::new(&format!(
2895                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2896                    ))?,
2897                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2898                    Regex::new(&format!(
2899                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2900                    ))?,
2901                ]);
2902            };
2903        }
2904        Ok(data)
2905    }
2906    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2907        self.isq_layer_regexes_moqe(config)
2908    }
2909}
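
// A minimal sketch (not part of the loader) of how the per-expert ISQ regexes
// built above behave against flattened tensor names; the sample names and
// indices below are hypothetical.
//
//     let re = regex::Regex::new(
//         r"layers\.3\.mlp\.experts\.17\.gate_proj\.(weight|bias)$",
//     )
//     .unwrap();
//     assert!(re.is_match("model.layers.3.mlp.experts.17.gate_proj.weight"));
//     assert!(!re.is_match("model.layers.3.mlp.gate_proj.weight"));
//     assert!(!re.is_match("model.layers.3.mlp.experts.17.gate_proj.weight_scale"));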
2910
2911impl DeviceMappedModelLoader for DeepSeekV3Loader {
2912    fn mapped_max_act_size_elems(
2913        &self,
2914        config: &str,
2915        params: &AutoDeviceMapParams,
2916    ) -> Result<usize> {
2917        let AutoDeviceMapParams::Text {
2918            max_seq_len,
2919            max_batch_size,
2920        } = params
2921        else {
2922            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2923        };
2924
2925        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2926
2927        Ok(
2928            max_batch_size
2929                * cfg.num_attention_heads
2930                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2931        )
2932    }
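    // Worked example for the estimate above (hypothetical numbers): with
    // max_batch_size = 2, num_attention_heads = 32, and max_seq_len = 8192
    // capped by an assumed ATTENTION_CHUNK_SIZE of 1024, the peak mapped
    // activation comes out to 2 * 32 * 1024 * 1024 = 67_108_864 elements,
    // i.e. one chunk-sized attention score matrix per head per sequence.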
2933    fn non_mapped_max_act_size_elems(
2934        &self,
2935        _config: &str,
2936        _params: &AutoDeviceMapParams,
2937    ) -> Result<usize> {
2938        Ok(0)
2939    }
2940
2941    fn non_mapped_size_in_bytes(
2942        &self,
2943        config: &str,
2944        dtype: DType,
2945        weight_pack_factor: usize,
2946        _matformer_config: Option<&MatformerSliceConfig>,
2947    ) -> Result<usize> {
2948        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2949        let elems = {
2950            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2951            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2952            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2953                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2954            } else {
2955                0
2956            };
2957            let norm = cfg.hidden_size;
2958            embed_tokens + lm_head + norm
2959        };
2960        Ok(elems * dtype.size_in_bytes())
2961    }
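    // Example of the tied-embedding rule above (hypothetical sizes): with
    // hidden_size = 4096, vocab_size = 128_000, tie_word_embeddings = true and
    // weight_pack_factor = 1, only embed_tokens (4096 * 128_000 elements) and
    // the final norm are counted; if the embeddings are untied (or the weights
    // are packed), a separate lm_head of another 4096 * 128_000 / pack_factor
    // elements is added.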
2962
2963    fn layer_sizes_in_bytes(
2964        &self,
2965        config: &str,
2966        dtype: DType,
2967        weight_pack_factor: usize,
2968        _matformer_config: Option<&MatformerSliceConfig>,
2969    ) -> Result<Vec<usize>> {
2970        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2971        let mut per_layer_elems = Vec::new();
2972
2973        for layer_idx in 0..cfg.num_hidden_layers {
2974            let input_layernorm = cfg.hidden_size;
2975            let post_attention_layernorm = cfg.hidden_size;
2976
2977            let q_proj = match cfg.q_lora_rank {
2978                Some(lora_rank) => {
2979                    let a = cfg.hidden_size * lora_rank;
2980                    let norm = lora_rank;
2981                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2982                    a + norm + b
2983                }
2984                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2985            };
2986            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2987                / weight_pack_factor
2988                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2989            let kv_a_layernorm = cfg.kv_lora_rank;
2990            let kv_b_proj = cfg.kv_lora_rank
2991                * cfg.num_attention_heads
2992                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2993                / weight_pack_factor;
2994            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2995                / weight_pack_factor
2996                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2997
2998            let moe_block = {
2999                let mut sum = 0;
3000                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3001                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3002                }) {
3003                    let h_size = cfg.hidden_size;
3004                    let gate_proj =
3005                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3006                    let up_proj =
3007                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3008                    let down_proj =
3009                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3010                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3011                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3012                            / weight_pack_factor;
3013                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3014                            / weight_pack_factor;
3015                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3016                            / weight_pack_factor;
3017                        gate_proj + up_proj + down_proj
3018                    } else {
3019                        0
3020                    };
3021                    let gate_weight = n_routed_experts * cfg.hidden_size;
3022                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3023                } else {
3024                    let h_size = cfg.hidden_size;
3025                    let i_size = cfg.intermediate_size;
3026                    let gate_proj = h_size * i_size / weight_pack_factor;
3027                    let up_proj = h_size * i_size / weight_pack_factor;
3028                    let down_proj = i_size * h_size / weight_pack_factor;
3029                    sum += gate_proj + up_proj + down_proj;
3030                }
3031                sum
3032            };
3033
3034            per_layer_elems.push(
3035                input_layernorm
3036                    + post_attention_layernorm
3037                    + q_proj
3038                    + kv_a_layernorm
3039                    + kv_a_proj_with_mqa
3040                    + kv_b_proj
3041                    + o_proj
3042                    + moe_block,
3043            );
3044        }
3045
3046        Ok(per_layer_elems
3047            .into_iter()
3048            .map(|x| x * dtype.size_in_bytes())
3049            .collect())
3050    }
3051
3052    fn num_layers(&self, config: &str) -> Result<usize> {
3053        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3054        Ok(cfg.num_hidden_layers)
3055    }
3056
3057    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3058        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3059
3060        let cfg = ModelConfigMetadata {
3061            max_seq_len: cfg.max_position_embeddings,
3062            num_layers: cfg.num_hidden_layers,
3063            hidden_size: cfg.hidden_size,
3064            num_kv_heads: cfg.num_attention_heads,
3065            num_attn_heads: cfg.num_attention_heads,
3066            sliding_window: None,
3067            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3068            v_head_dim: cfg.v_head_dim,
3069            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3070        };
3071
3072        Ok(Box::new(cfg))
3073    }
3074}
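
// A minimal usage sketch for the sizing methods above (assumptions, not from the
// source: `config_json` holds a complete DeepSeek-V3 `config.json` string, the
// loader is constructed as a unit struct, and no weight packing is in effect,
// i.e. a pack factor of 1):
//
//     let loader = DeepSeekV3Loader;
//     let per_layer = loader.layer_sizes_in_bytes(&config_json, DType::BF16, 1, None)?;
//     let resident = loader.non_mapped_size_in_bytes(&config_json, DType::BF16, 1, None)?;
//     let total_bytes: usize = per_layer.iter().sum::<usize>() + resident;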
3075
3076/// [`NormalLoader`] for a Qwen 3 model.
3077///
3078/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3079pub struct Qwen3Loader;
3080
3081impl NormalModelLoader for Qwen3Loader {
3082    fn load(
3083        &self,
3084        config: &str,
3085        vb: ShardedVarBuilder,
3086        normal_loading_metadata: NormalLoadingMetadata,
3087        attention_mechanism: AttentionImplementation,
3088    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3089        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3090
3091        Ok(Box::new(models::qwen3::Model::new(
3092            &cfg,
3093            vb,
3094            self.is_gptx(config)?,
3095            normal_loading_metadata,
3096            attention_mechanism,
3097        )?))
3098    }
3099    fn load_xlora(
3100        &self,
3101        _config: &str,
3102        _vb: ShardedVarBuilder,
3103        _lora_config: &[((String, String), LoraConfig)],
3104        _xlora_config: Option<XLoraConfig>,
3105        _xlora_ordering: Ordering,
3106        _normal_loading_metadata: NormalLoadingMetadata,
3107        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3108    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3109        todo!()
3110    }
3111    fn is_gptx(&self, _: &str) -> Result<bool> {
3112        Ok(true)
3113    }
3114    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3115        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3116
3117        Ok(Box::new(cfg))
3118    }
3119}
3120
3121impl IsqModelLoader for Qwen3Loader {
3122    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3123        Ok(vec![
3124            Regex::new(r"lm_head\.(weight|bias)$")?,
3125            // Attention
3126            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3127            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3128            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3129            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3130            // MLP
3131            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3132            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3133            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3134        ])
3135    }
3136    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3137        self.isq_layer_regexes(config)
3138    }
3139}
3140
3141impl DeviceMappedModelLoader for Qwen3Loader {
3142    fn mapped_max_act_size_elems(
3143        &self,
3144        config: &str,
3145        params: &AutoDeviceMapParams,
3146    ) -> Result<usize> {
3147        let AutoDeviceMapParams::Text {
3148            max_seq_len,
3149            max_batch_size,
3150        } = params
3151        else {
3152            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3153        };
3154
3155        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3156
3157        Ok(
3158            max_batch_size
3159                * cfg.num_attention_heads
3160                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3161        )
3162    }
3163    fn non_mapped_max_act_size_elems(
3164        &self,
3165        _config: &str,
3166        _params: &AutoDeviceMapParams,
3167    ) -> Result<usize> {
3168        Ok(0)
3169    }
3170
3171    fn non_mapped_size_in_bytes(
3172        &self,
3173        config: &str,
3174        dtype: DType,
3175        weight_pack_factor: usize,
3176        _matformer_config: Option<&MatformerSliceConfig>,
3177    ) -> Result<usize> {
3178        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3179        let elems = {
3180            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3181            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3182            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3183                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3184            } else {
3185                0
3186            };
3187            let norm = cfg.hidden_size;
3188            embed_tokens + lm_head + norm
3189        };
3190        Ok(elems * dtype.size_in_bytes())
3191    }
3192
3193    fn layer_sizes_in_bytes(
3194        &self,
3195        config: &str,
3196        dtype: DType,
3197        weight_pack_factor: usize,
3198        _matformer_config: Option<&MatformerSliceConfig>,
3199    ) -> Result<Vec<usize>> {
3200        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3201        let per_layer_elems = {
3202            let input_layernorm = cfg.hidden_size;
3203            let post_attention_layernorm = cfg.hidden_size;
3204
3205            let size_in = cfg.hidden_size;
3206            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3207            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3208            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3209            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3210            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3211            let o_proj = size_q * size_in / weight_pack_factor;
3212
3213            let h_size = cfg.hidden_size;
3214            let i_size = cfg.intermediate_size;
3215            let gate_proj = h_size * i_size / weight_pack_factor;
3216            let up_proj = h_size * i_size / weight_pack_factor;
3217            let down_proj = i_size * h_size / weight_pack_factor;
3218
3219            let q_norm = cfg.head_dim();
3220            let k_norm = cfg.head_dim();
3221
3222            input_layernorm
3223                + post_attention_layernorm
3224                + q_proj
3225                + k_proj
3226                + v_proj
3227                + o_proj
3228                + gate_proj
3229                + up_proj
3230                + down_proj
3231                + q_norm
3232                + k_norm
3233        };
3234        Ok(vec![
3235            per_layer_elems * dtype.size_in_bytes();
3236            cfg.num_hidden_layers
3237        ])
3238    }
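    // Worked example (hypothetical Qwen3-style dimensions): with hidden_size =
    // 4096, head_dim = 128, 32 attention heads, 8 KV heads, intermediate_size =
    // 12_288 and weight_pack_factor = 1 the expression above counts roughly
    //   2 * 4096 (norms) + (4096 * 4096 + 4096) (q_proj) + 2 * (4096 * 1024 + 1024) (k/v)
    //   + 4096 * 4096 (o_proj) + 3 * (4096 * 12_288) (mlp) + 2 * 128 (q/k norms)
    // elements per layer, which is then multiplied by the dtype width to get bytes.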
3239
3240    fn num_layers(&self, config: &str) -> Result<usize> {
3241        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3242        Ok(cfg.num_hidden_layers)
3243    }
3244
3245    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3246        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3247
3248        let cfg = ModelConfigMetadata {
3249            max_seq_len: cfg.max_position_embeddings,
3250            num_layers: cfg.num_hidden_layers,
3251            hidden_size: cfg.hidden_size,
3252            num_kv_heads: cfg.num_key_value_heads,
3253            num_attn_heads: cfg.num_attention_heads,
3254            sliding_window: cfg.sliding_window,
3255            k_head_dim: cfg.head_dim(),
3256            v_head_dim: cfg.head_dim(),
3257            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3258        };
3259
3260        Ok(Box::new(cfg))
3261    }
3262}
3263
3264/// [`NormalLoader`] for a GLM 4 model.
3265///
3266/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3267pub struct GLM4Loader;
3268
3269impl NormalModelLoader for GLM4Loader {
3270    fn load(
3271        &self,
3272        config: &str,
3273        vb: ShardedVarBuilder,
3274        normal_loading_metadata: NormalLoadingMetadata,
3275        attention_mechanism: AttentionImplementation,
3276    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3277        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3278
3279        Ok(Box::new(models::glm4::Model::new(
3280            &cfg,
3281            vb,
3282            self.is_gptx(config)?,
3283            normal_loading_metadata,
3284            attention_mechanism,
3285        )?))
3286    }
3287    fn load_xlora(
3288        &self,
3289        _config: &str,
3290        _vb: ShardedVarBuilder,
3291        _lora_config: &[((String, String), LoraConfig)],
3292        _xlora_config: Option<XLoraConfig>,
3293        _xlora_ordering: Ordering,
3294        _normal_loading_metadata: NormalLoadingMetadata,
3295        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3296    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3297        todo!()
3298    }
3299    fn is_gptx(&self, _: &str) -> Result<bool> {
3300        Ok(true)
3301    }
3302    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3303        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3304
3305        Ok(Box::new(cfg))
3306    }
3307}
3308
3309impl IsqModelLoader for GLM4Loader {
3310    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3311        Ok(vec![
3312            Regex::new(r"lm_head\.(weight|bias)$")?,
3313            // Attention
3314            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3315            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3316            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3317            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3318            // MLP
3319            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3320            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3321            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3322        ])
3323    }
3324    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3325        self.isq_layer_regexes(config)
3326    }
3327}
3328
3329impl DeviceMappedModelLoader for GLM4Loader {
3330    fn mapped_max_act_size_elems(
3331        &self,
3332        config: &str,
3333        params: &AutoDeviceMapParams,
3334    ) -> Result<usize> {
3335        let AutoDeviceMapParams::Text {
3336            max_seq_len,
3337            max_batch_size,
3338        } = params
3339        else {
3340            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3341        };
3342
3343        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3344
3345        Ok(
3346            max_batch_size
3347                * cfg.num_attention_heads
3348                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3349        )
3350    }
3351    fn non_mapped_max_act_size_elems(
3352        &self,
3353        _config: &str,
3354        _params: &AutoDeviceMapParams,
3355    ) -> Result<usize> {
3356        Ok(0)
3357    }
3358
3359    fn non_mapped_size_in_bytes(
3360        &self,
3361        config: &str,
3362        dtype: DType,
3363        weight_pack_factor: usize,
3364        _matformer_config: Option<&MatformerSliceConfig>,
3365    ) -> Result<usize> {
3366        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3367        let elems = {
3368            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3369            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3370            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3371                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3372            } else {
3373                0
3374            };
3375            let norm = cfg.hidden_size;
3376            embed_tokens + lm_head + norm
3377        };
3378        Ok(elems * dtype.size_in_bytes())
3379    }
3380
3381    fn layer_sizes_in_bytes(
3382        &self,
3383        config: &str,
3384        dtype: DType,
3385        weight_pack_factor: usize,
3386        _matformer_config: Option<&MatformerSliceConfig>,
3387    ) -> Result<Vec<usize>> {
3388        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3389        let per_layer_elems = {
3390            let input_layernorm = cfg.hidden_size;
3391            let post_attention_layernorm = cfg.hidden_size * 3; // also counts post_self_attn_layernorm and post_mlp_layernorm
3392
3393            let size_in = cfg.hidden_size;
3394            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3395            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3396            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3397            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3398            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3399            let o_proj = size_q * size_in / weight_pack_factor;
3400
3401            let h_size = cfg.hidden_size;
3402            let i_size = cfg.intermediate_size;
3403            let gate_proj = h_size * i_size / weight_pack_factor;
3404            let up_proj = h_size * i_size / weight_pack_factor;
3405            let down_proj = i_size * h_size / weight_pack_factor;
3406
3407            input_layernorm
3408                + post_attention_layernorm
3409                + q_proj
3410                + k_proj
3411                + v_proj
3412                + o_proj
3413                + gate_proj
3414                + up_proj
3415                + down_proj
3416        };
3417        Ok(vec![
3418            per_layer_elems * dtype.size_in_bytes();
3419            cfg.num_hidden_layers
3420        ])
3421    }
3422
3423    fn num_layers(&self, config: &str) -> Result<usize> {
3424        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3425        Ok(cfg.num_hidden_layers)
3426    }
3427
3428    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3429        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3430
3431        let cfg = ModelConfigMetadata {
3432            max_seq_len: cfg.max_position_embeddings,
3433            num_layers: cfg.num_hidden_layers,
3434            hidden_size: cfg.hidden_size,
3435            num_kv_heads: cfg.num_key_value_heads,
3436            num_attn_heads: cfg.num_attention_heads,
3437            sliding_window: cfg.sliding_window,
3438            k_head_dim: cfg.head_dim(),
3439            v_head_dim: cfg.head_dim(),
3440            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3441        };
3442
3443        Ok(Box::new(cfg))
3444    }
3445}
3446
3447/// [`NormalLoader`] for a GLM 4 MoE Lite model (GLM-4.7-Flash).
3448///
3449/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3450pub struct GLM4MoeLiteLoader;
3451
3452impl NormalModelLoader for GLM4MoeLiteLoader {
3453    fn load(
3454        &self,
3455        config: &str,
3456        vb: ShardedVarBuilder,
3457        normal_loading_metadata: NormalLoadingMetadata,
3458        attention_mechanism: AttentionImplementation,
3459    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3460        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3461        Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3462            &cfg,
3463            vb,
3464            self.is_gptx(config)?,
3465            normal_loading_metadata,
3466            attention_mechanism,
3467        )?))
3468    }
3469    fn load_xlora(
3470        &self,
3471        _config: &str,
3472        _vb: ShardedVarBuilder,
3473        _lora_config: &[((String, String), LoraConfig)],
3474        _xlora_config: Option<XLoraConfig>,
3475        _xlora_ordering: Ordering,
3476        _normal_loading_metadata: NormalLoadingMetadata,
3477        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3478    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3479        todo!()
3480    }
3481    fn is_gptx(&self, _: &str) -> Result<bool> {
3482        Ok(true)
3483    }
3484    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3485        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3486        Ok(Box::new(cfg))
3487    }
3488}
3489
3490impl IsqModelLoader for GLM4MoeLiteLoader {
3491    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3492        let mut data = vec![
3493            Regex::new(r"lm_head\.(weight|bias)$")?,
3494            // Attention (MLA)
3495            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3496            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3497            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3498            // Q LoRA projections
3499            Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3500            Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3501        ];
3502        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3503        for layer_idx in 0..cfg.num_hidden_layers {
3504            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3505                // MoE layer
3506                for i in 0..cfg.n_routed_experts {
3507                    data.extend(vec![
3508                        Regex::new(&format!(
3509                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3510                        ))?,
3511                        Regex::new(&format!(
3512                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3513                        ))?,
3514                        Regex::new(&format!(
3515                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3516                        ))?,
3517                    ]);
3518                }
3519                if cfg.n_shared_experts > 0 {
3520                    data.extend(vec![
3521                        Regex::new(&format!(
3522                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3523                        ))?,
3524                        Regex::new(&format!(
3525                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3526                        ))?,
3527                        Regex::new(&format!(
3528                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3529                        ))?,
3530                    ]);
3531                }
3532            } else {
3533                // Dense MLP layer
3534                data.extend(vec![
3535                    Regex::new(&format!(
3536                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3537                    ))?,
3538                    Regex::new(&format!(
3539                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3540                    ))?,
3541                    Regex::new(&format!(
3542                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3543                    ))?,
3544                ]);
3545            };
3546        }
3547        Ok(data)
3548    }
3549    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3550        self.isq_layer_regexes(config)
3551    }
3552
3553    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3554        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3555        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3556        for layer_idx in 0..cfg.num_hidden_layers {
3557            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3558                // MoE layer
3559                for i in 0..cfg.n_routed_experts {
3560                    data.extend(vec![
3561                        Regex::new(&format!(
3562                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3563                        ))?,
3564                        Regex::new(&format!(
3565                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3566                        ))?,
3567                        Regex::new(&format!(
3568                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3569                        ))?,
3570                    ]);
3571                }
3572                if cfg.n_shared_experts > 0 {
3573                    data.extend(vec![
3574                        Regex::new(&format!(
3575                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3576                        ))?,
3577                        Regex::new(&format!(
3578                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3579                        ))?,
3580                        Regex::new(&format!(
3581                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3582                        ))?,
3583                    ]);
3584                }
3585            } else {
3586                // Dense MLP layer
3587                data.extend(vec![
3588                    Regex::new(&format!(
3589                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3590                    ))?,
3591                    Regex::new(&format!(
3592                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3593                    ))?,
3594                    Regex::new(&format!(
3595                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3596                    ))?,
3597                ]);
3598            };
3599        }
3600        Ok(data)
3601    }
3602    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3603        self.isq_layer_regexes_moqe(config)
3604    }
3605}
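
// Example of the MoE/dense split used in the regex builders above (hypothetical
// config values): with first_k_dense_replace = 1, moe_layer_freq = 1 and
// num_hidden_layers = 4, layer 0 receives the dense gate/up/down regexes while
// layers 1..=3 receive the per-expert and shared-expert regexes:
//
//     let is_moe = |layer_idx: usize| {
//         layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
//     };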
3606
3607impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3608    fn mapped_max_act_size_elems(
3609        &self,
3610        config: &str,
3611        params: &AutoDeviceMapParams,
3612    ) -> Result<usize> {
3613        let AutoDeviceMapParams::Text {
3614            max_seq_len,
3615            max_batch_size,
3616        } = params
3617        else {
3618            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3619        };
3620
3621        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3622
3623        Ok(
3624            max_batch_size
3625                * cfg.num_attention_heads
3626                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3627        )
3628    }
3629    fn non_mapped_max_act_size_elems(
3630        &self,
3631        _config: &str,
3632        _params: &AutoDeviceMapParams,
3633    ) -> Result<usize> {
3634        Ok(0)
3635    }
3636
3637    fn non_mapped_size_in_bytes(
3638        &self,
3639        config: &str,
3640        dtype: DType,
3641        weight_pack_factor: usize,
3642        _matformer_config: Option<&MatformerSliceConfig>,
3643    ) -> Result<usize> {
3644        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3645        let elems = {
3646            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3647            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3648            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3649                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3650            } else {
3651                0
3652            };
3653            let norm = cfg.hidden_size;
3654            embed_tokens + lm_head + norm
3655        };
3656        Ok(elems * dtype.size_in_bytes())
3657    }
3658
3659    fn layer_sizes_in_bytes(
3660        &self,
3661        config: &str,
3662        dtype: DType,
3663        weight_pack_factor: usize,
3664        _matformer_config: Option<&MatformerSliceConfig>,
3665    ) -> Result<Vec<usize>> {
3666        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3667        let mut per_layer_elems = Vec::new();
3668
3669        for layer_idx in 0..cfg.num_hidden_layers {
3670            let input_layernorm = cfg.hidden_size;
3671            let post_attention_layernorm = cfg.hidden_size;
3672
3673            // Q LoRA projection
3674            let q_proj = {
3675                let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3676                let norm = cfg.q_lora_rank;
3677                let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3678                    / weight_pack_factor;
3679                a + norm + b
3680            };
3681            let kv_a_proj_with_mqa =
3682                cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3683            let kv_a_layernorm = cfg.kv_lora_rank;
3684            let kv_b_proj = cfg.kv_lora_rank
3685                * cfg.num_attention_heads
3686                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3687                / weight_pack_factor;
3688            let o_proj =
3689                cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3690
3691            let moe_block = {
3692                let mut sum = 0;
3693                if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3694                    // MoE layer
3695                    let h_size = cfg.hidden_size;
3696                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3697                        * cfg.n_routed_experts;
3698                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3699                        * cfg.n_routed_experts;
3700                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3701                        * cfg.n_routed_experts;
3702                    let shared_experts = if cfg.n_shared_experts > 0 {
3703                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3704                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3705                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3706                        gate_proj + up_proj + down_proj
3707                    } else {
3708                        0
3709                    };
3710                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3711                    let e_score_correction_bias = cfg.n_routed_experts;
3712                    sum += gate_proj
3713                        + up_proj
3714                        + down_proj
3715                        + shared_experts
3716                        + gate_weight
3717                        + e_score_correction_bias;
3718                } else {
3719                    // Dense MLP layer
3720                    let h_size = cfg.hidden_size;
3721                    let i_size = cfg.intermediate_size;
3722                    let gate_proj = h_size * i_size / weight_pack_factor;
3723                    let up_proj = h_size * i_size / weight_pack_factor;
3724                    let down_proj = i_size * h_size / weight_pack_factor;
3725                    sum += gate_proj + up_proj + down_proj;
3726                }
3727                sum
3728            };
3729
3730            per_layer_elems.push(
3731                input_layernorm
3732                    + post_attention_layernorm
3733                    + q_proj
3734                    + kv_a_layernorm
3735                    + kv_a_proj_with_mqa
3736                    + kv_b_proj
3737                    + o_proj
3738                    + moe_block,
3739            );
3740        }
3741
3742        Ok(per_layer_elems
3743            .into_iter()
3744            .map(|x| x * dtype.size_in_bytes())
3745            .collect())
3746    }
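    // Note on the MLA terms above: `q_head_dim() - qk_rope_head_dim` is the
    // non-rotary ("nope") part of each query head, so kv_b_proj expands the
    // compressed kv_lora_rank latent to
    // num_attention_heads * (qk_nope_head_dim + v_head_dim) output features,
    // matching the checkpoint's weight shape.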
3747
3748    fn num_layers(&self, config: &str) -> Result<usize> {
3749        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3750        Ok(cfg.num_hidden_layers)
3751    }
3752
3753    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3754        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3755
3756        let cfg = ModelConfigMetadata {
3757            max_seq_len: cfg.max_position_embeddings,
3758            num_layers: cfg.num_hidden_layers,
3759            hidden_size: cfg.hidden_size,
3760            num_kv_heads: cfg.num_attention_heads,
3761            num_attn_heads: cfg.num_attention_heads,
3762            sliding_window: None,
3763            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3764            v_head_dim: cfg.v_head_dim,
3765            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3766        };
3767
3768        Ok(Box::new(cfg))
3769    }
3770}
3771
3772/// [`NormalLoader`] for a GLM 4 MoE model (GLM-4.5).
3773///
3774/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
3775pub struct GLM4MoeLoader;
3776
3777impl NormalModelLoader for GLM4MoeLoader {
3778    fn load(
3779        &self,
3780        config: &str,
3781        vb: ShardedVarBuilder,
3782        normal_loading_metadata: NormalLoadingMetadata,
3783        attention_mechanism: AttentionImplementation,
3784    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3785        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3786        Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3787            &cfg,
3788            vb,
3789            self.is_gptx(config)?,
3790            normal_loading_metadata,
3791            attention_mechanism,
3792        )?))
3793    }
3794    fn load_xlora(
3795        &self,
3796        _config: &str,
3797        _vb: ShardedVarBuilder,
3798        _lora_config: &[((String, String), LoraConfig)],
3799        _xlora_config: Option<XLoraConfig>,
3800        _xlora_ordering: Ordering,
3801        _normal_loading_metadata: NormalLoadingMetadata,
3802        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3803    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3804        todo!()
3805    }
3806    fn is_gptx(&self, _: &str) -> Result<bool> {
3807        Ok(true)
3808    }
3809    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3810        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3811        Ok(Box::new(cfg))
3812    }
3813}
3814
3815impl IsqModelLoader for GLM4MoeLoader {
3816    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3817        let mut data = vec![
3818            Regex::new(r"lm_head\.(weight|bias)$")?,
3819            // Attention (standard GQA)
3820            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3821            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3822            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3823            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3824        ];
3825        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3826        for layer_idx in 0..cfg.num_hidden_layers {
3827            if layer_idx >= cfg.first_k_dense_replace {
3828                // MoE layer
3829                for i in 0..cfg.n_routed_experts {
3830                    data.extend(vec![
3831                        Regex::new(&format!(
3832                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3833                        ))?,
3834                        Regex::new(&format!(
3835                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3836                        ))?,
3837                        Regex::new(&format!(
3838                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3839                        ))?,
3840                    ]);
3841                }
3842                if cfg.n_shared_experts > 0 {
3843                    data.extend(vec![
3844                        Regex::new(&format!(
3845                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3846                        ))?,
3847                        Regex::new(&format!(
3848                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3849                        ))?,
3850                        Regex::new(&format!(
3851                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3852                        ))?,
3853                    ]);
3854                }
3855            } else {
3856                // Dense MLP layer
3857                data.extend(vec![
3858                    Regex::new(&format!(
3859                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3860                    ))?,
3861                    Regex::new(&format!(
3862                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3863                    ))?,
3864                    Regex::new(&format!(
3865                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3866                    ))?,
3867                ]);
3868            };
3869        }
3870        Ok(data)
3871    }
3872    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3873        self.isq_layer_regexes(config)
3874    }
3875
3876    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3877        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3878        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3879        for layer_idx in 0..cfg.num_hidden_layers {
3880            if layer_idx >= cfg.first_k_dense_replace {
3881                // MoE layer
3882                for i in 0..cfg.n_routed_experts {
3883                    data.extend(vec![
3884                        Regex::new(&format!(
3885                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3886                        ))?,
3887                        Regex::new(&format!(
3888                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3889                        ))?,
3890                        Regex::new(&format!(
3891                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3892                        ))?,
3893                    ]);
3894                }
3895                if cfg.n_shared_experts > 0 {
3896                    data.extend(vec![
3897                        Regex::new(&format!(
3898                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3899                        ))?,
3900                        Regex::new(&format!(
3901                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3902                        ))?,
3903                        Regex::new(&format!(
3904                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3905                        ))?,
3906                    ]);
3907                }
3908            } else {
3909                // Dense MLP layer
3910                data.extend(vec![
3911                    Regex::new(&format!(
3912                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3913                    ))?,
3914                    Regex::new(&format!(
3915                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3916                    ))?,
3917                    Regex::new(&format!(
3918                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3919                    ))?,
3920                ]);
3921            };
3922        }
3923        Ok(data)
3924    }
3925    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3926        self.isq_layer_regexes_moqe(config)
3927    }
3928}
3929
3930impl DeviceMappedModelLoader for GLM4MoeLoader {
3931    fn mapped_max_act_size_elems(
3932        &self,
3933        config: &str,
3934        params: &AutoDeviceMapParams,
3935    ) -> Result<usize> {
3936        let AutoDeviceMapParams::Text {
3937            max_seq_len,
3938            max_batch_size,
3939        } = params
3940        else {
3941            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3942        };
3943
3944        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3945
3946        Ok(
3947            max_batch_size
3948                * cfg.num_attention_heads
3949                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3950        )
3951    }
3952    fn non_mapped_max_act_size_elems(
3953        &self,
3954        _config: &str,
3955        _params: &AutoDeviceMapParams,
3956    ) -> Result<usize> {
3957        Ok(0)
3958    }
3959
3960    fn non_mapped_size_in_bytes(
3961        &self,
3962        config: &str,
3963        dtype: DType,
3964        weight_pack_factor: usize,
3965        _matformer_config: Option<&MatformerSliceConfig>,
3966    ) -> Result<usize> {
3967        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3968        let elems = {
3969            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3970            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3971                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3972            } else {
3973                0
3974            };
3975            let norm = cfg.hidden_size;
3976            embed_tokens + lm_head + norm
3977        };
3978        Ok(elems * dtype.size_in_bytes())
3979    }
3980
3981    fn layer_sizes_in_bytes(
3982        &self,
3983        config: &str,
3984        dtype: DType,
3985        weight_pack_factor: usize,
3986        _matformer_config: Option<&MatformerSliceConfig>,
3987    ) -> Result<Vec<usize>> {
3988        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3989        let mut per_layer_elems = Vec::new();
3990
3991        let head_dim = cfg.head_dim();
3992        for layer_idx in 0..cfg.num_hidden_layers {
3993            let input_layernorm = cfg.hidden_size;
3994            let post_attention_layernorm = cfg.hidden_size;
3995
3996            // Standard GQA attention
3997            let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
3998                + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
3999            let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4000                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4001            let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4002                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4003            let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4004
4005            // QK norm if enabled
4006            let qk_norm = if cfg.use_qk_norm {
4007                head_dim * 2 // q_norm + k_norm
4008            } else {
4009                0
4010            };
4011
4012            let moe_block = {
4013                let mut sum = 0;
4014                if layer_idx >= cfg.first_k_dense_replace {
4015                    // MoE layer
4016                    let h_size = cfg.hidden_size;
4017                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4018                        * cfg.n_routed_experts;
4019                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4020                        * cfg.n_routed_experts;
4021                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4022                        * cfg.n_routed_experts;
4023                    let shared_experts = if cfg.n_shared_experts > 0 {
4024                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4025                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4026                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4027                        gate_proj + up_proj + down_proj
4028                    } else {
4029                        0
4030                    };
4031                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4032                    let e_score_correction_bias = cfg.n_routed_experts;
4033                    sum += gate_proj
4034                        + up_proj
4035                        + down_proj
4036                        + shared_experts
4037                        + gate_weight
4038                        + e_score_correction_bias;
4039                } else {
4040                    // Dense MLP layer
4041                    let h_size = cfg.hidden_size;
4042                    let i_size = cfg.intermediate_size;
4043                    let gate_proj = h_size * i_size / weight_pack_factor;
4044                    let up_proj = h_size * i_size / weight_pack_factor;
4045                    let down_proj = i_size * h_size / weight_pack_factor;
4046                    sum += gate_proj + up_proj + down_proj;
4047                }
4048                sum
4049            };
4050
4051            per_layer_elems.push(
4052                input_layernorm
4053                    + post_attention_layernorm
4054                    + q_proj
4055                    + k_proj
4056                    + v_proj
4057                    + o_proj
4058                    + qk_norm
4059                    + moe_block,
4060            );
4061        }
4062
4063        Ok(per_layer_elems
4064            .into_iter()
4065            .map(|x| x * dtype.size_in_bytes())
4066            .collect())
4067    }
4068
4069    fn num_layers(&self, config: &str) -> Result<usize> {
4070        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4071        Ok(cfg.num_hidden_layers)
4072    }
4073
4074    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4075        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4076
4077        let head_dim = cfg.head_dim();
4078        let cfg = ModelConfigMetadata {
4079            max_seq_len: cfg.max_position_embeddings,
4080            num_layers: cfg.num_hidden_layers,
4081            hidden_size: cfg.hidden_size,
4082            num_kv_heads: cfg.num_key_value_heads,
4083            num_attn_heads: cfg.num_attention_heads,
4084            sliding_window: None,
4085            k_head_dim: head_dim,
4086            v_head_dim: head_dim,
4087            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4088        };
4089
4090        Ok(Box::new(cfg))
4091    }
4092}
4093
4094/// [`NormalLoader`] for a Qwen 3 MoE model.
4095///
4096/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4097pub struct Qwen3MoELoader;
4098
4099impl NormalModelLoader for Qwen3MoELoader {
4100    fn load(
4101        &self,
4102        config: &str,
4103        vb: ShardedVarBuilder,
4104        normal_loading_metadata: NormalLoadingMetadata,
4105        attention_mechanism: AttentionImplementation,
4106    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4107        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4108
4109        Ok(Box::new(models::qwen3_moe::Model::new(
4110            &cfg,
4111            vb,
4112            self.is_gptx(config)?,
4113            normal_loading_metadata,
4114            attention_mechanism,
4115        )?))
4116    }
4117    fn load_xlora(
4118        &self,
4119        _config: &str,
4120        _vb: ShardedVarBuilder,
4121        _lora_config: &[((String, String), LoraConfig)],
4122        _xlora_config: Option<XLoraConfig>,
4123        _xlora_ordering: Ordering,
4124        _normal_loading_metadata: NormalLoadingMetadata,
4125        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4126    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4127        todo!()
4128    }
4129    fn is_gptx(&self, _: &str) -> Result<bool> {
4130        Ok(true)
4131    }
4132    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4133        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4134
4135        Ok(Box::new(cfg))
4136    }
4137}
4138
4139impl IsqModelLoader for Qwen3MoELoader {
4140    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4141        Ok(vec![
4142            Regex::new(r"lm_head\.(weight|bias)$")?,
4143            // Attention
4144            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4145            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4146            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4147            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4148            // MLP
4149            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4150            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4151            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4152            // MLP MoE
4153            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4154            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4155            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4156        ])
4157    }
4158    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4159        self.isq_layer_regexes(config)
4160    }
4161    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4162        self.isq_layer_regexes_moqe(config)
4163    }
4164}
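
// Unlike the DeepSeek/GLM MoE loaders above, which format one regex per expert
// index, the Qwen3 MoE patterns use a `(\d+)` capture for the expert index, so a
// single regex covers every expert. A minimal sketch with a hypothetical tensor
// name:
//
//     let re = regex::Regex::new(
//         r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
//     )
//     .unwrap();
//     assert!(re.is_match("model.layers.11.mlp.experts.63.up_proj.weight"));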
4165
4166impl DeviceMappedModelLoader for Qwen3MoELoader {
4167    fn mapped_max_act_size_elems(
4168        &self,
4169        config: &str,
4170        params: &AutoDeviceMapParams,
4171    ) -> Result<usize> {
4172        let AutoDeviceMapParams::Text {
4173            max_seq_len,
4174            max_batch_size,
4175        } = params
4176        else {
4177            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4178        };
4179
4180        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4181
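        // The mapped activation estimate is the attention score matrix:
        // batch * heads * chunked_seq_len^2 elements. Purely for illustration
        // (hypothetical values): a batch of 1 with 32 heads and a 1024-token
        // chunk gives 1 * 32 * 1024 * 1024 = 33_554_432 elements.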
4182        Ok(
4183            max_batch_size
4184                * cfg.num_attention_heads
4185                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4186        )
4187    }
4188    fn non_mapped_max_act_size_elems(
4189        &self,
4190        _config: &str,
4191        _params: &AutoDeviceMapParams,
4192    ) -> Result<usize> {
4193        Ok(0)
4194    }
4195
4196    fn non_mapped_size_in_bytes(
4197        &self,
4198        config: &str,
4199        dtype: DType,
4200        weight_pack_factor: usize,
4201        _matformer_config: Option<&MatformerSliceConfig>,
4202    ) -> Result<usize> {
4203        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4204        let elems = {
4205            let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4206            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4207            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4208                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4209            } else {
4210                0
4211            };
4212            let norm = cfg.hidden_size;
4213            embed_tokens + lm_head + norm
4214        };
4215        Ok(elems * dtype.size_in_bytes())
4216    }
4217
4218    fn layer_sizes_in_bytes(
4219        &self,
4220        config: &str,
4221        dtype: DType,
4222        weight_pack_factor: usize,
4223        _matformer_config: Option<&MatformerSliceConfig>,
4224    ) -> Result<Vec<usize>> {
4225        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4226
4227        let mut layer_sizes_in_bytes = Vec::new();
4228        for layer_idx in 0..cfg.num_hidden_layers {
4229            let input_layernorm = cfg.hidden_size;
4230            let post_attention_layernorm = cfg.hidden_size;
4231
4232            let size_in = cfg.hidden_size;
4233            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4234            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4235            let q_proj = size_in * size_q / weight_pack_factor;
4236            let k_proj = size_in * size_kv / weight_pack_factor;
4237            let v_proj = size_in * size_kv / weight_pack_factor;
4238            let o_proj = size_q * size_in / weight_pack_factor;
4239
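            // A layer gets the expert-sized MLP unless it is listed in
            // `mlp_only_layers` or falls off the sparse step. For example, with
            // an empty `mlp_only_layers` and `decoder_sparse_step == 1`, every
            // decoder layer takes the MoE branch below.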
4240            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4241                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4242            {
4243                let gate_size = cfg.hidden_size * cfg.num_experts;
4244                let expert_size = {
4245                    let h_size = cfg.hidden_size;
4246                    let i_size = cfg.moe_intermediate_size;
4247                    let gate_proj = h_size * i_size / weight_pack_factor;
4248                    let up_proj = h_size * i_size / weight_pack_factor;
4249                    let down_proj = i_size * h_size / weight_pack_factor;
4250                    gate_proj + up_proj + down_proj
4251                };
4252                expert_size * cfg.num_experts + gate_size
4253            } else {
4254                let h_size = cfg.hidden_size;
4255                let i_size = cfg.intermediate_size;
4256                let gate_proj = h_size * i_size / weight_pack_factor;
4257                let up_proj = h_size * i_size / weight_pack_factor;
4258                let down_proj = i_size * h_size / weight_pack_factor;
4259                gate_proj + up_proj + down_proj
4260            };
4261
4262            let q_norm = cfg.head_dim();
4263            let k_norm = cfg.head_dim();
4264
4265            let size_elems = input_layernorm
4266                + post_attention_layernorm
4267                + q_proj
4268                + k_proj
4269                + v_proj
4270                + o_proj
4271                + mlp_size
4272                + q_norm
4273                + k_norm;
4274
4275            let size_in_bytes = size_elems * dtype.size_in_bytes();
4276            layer_sizes_in_bytes.push(size_in_bytes);
4277        }
4278
4279        Ok(layer_sizes_in_bytes)
4280    }
4281
4282    fn num_layers(&self, config: &str) -> Result<usize> {
4283        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4284        Ok(cfg.num_hidden_layers)
4285    }
4286
4287    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4288        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4289
4290        let cfg = ModelConfigMetadata {
4291            max_seq_len: cfg.max_position_embeddings,
4292            num_layers: cfg.num_hidden_layers,
4293            hidden_size: cfg.hidden_size,
4294            num_kv_heads: cfg.num_key_value_heads,
4295            num_attn_heads: cfg.num_attention_heads,
4296            sliding_window: cfg.sliding_window,
4297            k_head_dim: cfg.head_dim(),
4298            v_head_dim: cfg.head_dim(),
4299            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4300        };
4301
4302        Ok(Box::new(cfg))
4303    }
4304}
4305
4306// ======================== SmolLm3 loader
4307
4308/// [`NormalLoader`] for a SmolLm3 model.
4309///
4310/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4311pub struct SmolLm3Loader;
4312
4313impl NormalModelLoader for SmolLm3Loader {
4314    fn load(
4315        &self,
4316        config: &str,
4317        vb: ShardedVarBuilder,
4318        normal_loading_metadata: NormalLoadingMetadata,
4319        attention_mechanism: AttentionImplementation,
4320    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4321        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4322
4323        Ok(Box::new(models::smollm3::SmolLm3::new(
4324            &cfg,
4325            vb,
4326            self.is_gptx(config)?,
4327            normal_loading_metadata,
4328            attention_mechanism,
4329        )?))
4330    }
4331    fn load_xlora(
4332        &self,
4333        _config: &str,
4334        _vb: ShardedVarBuilder,
4335        _lora_config: &[((String, String), LoraConfig)],
4336        _xlora_config: Option<XLoraConfig>,
4337        _xlora_ordering: Ordering,
4338        _normal_loading_metadata: NormalLoadingMetadata,
4339        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4340    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4341        todo!()
4342    }
4343    fn is_gptx(&self, _: &str) -> Result<bool> {
4344        Ok(true)
4345    }
4346    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4347        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4348        Ok(Box::new(cfg))
4349    }
4350}
4351
4352impl IsqModelLoader for SmolLm3Loader {
4353    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4354        Ok(vec![
4355            Regex::new(r"lm_head\.(weight|bias)$")?,
4356            // Attention
4357            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4358            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4359            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4360            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4361            // MLP
4362            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4363            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4364            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4365        ])
4366    }
4367    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4368        self.isq_layer_regexes(config)
4369    }
4370}
4371
4372impl DeviceMappedModelLoader for SmolLm3Loader {
4373    fn mapped_max_act_size_elems(
4374        &self,
4375        config: &str,
4376        params: &AutoDeviceMapParams,
4377    ) -> Result<usize> {
4378        let AutoDeviceMapParams::Text {
4379            max_seq_len,
4380            max_batch_size,
4381        } = params
4382        else {
4383            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4384        };
4385
4386        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4387
4388        Ok(
4389            max_batch_size
4390                * cfg.num_attention_heads
4391                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4392        )
4393    }
4394    fn non_mapped_max_act_size_elems(
4395        &self,
4396        _config: &str,
4397        _params: &AutoDeviceMapParams,
4398    ) -> Result<usize> {
4399        Ok(0)
4400    }
4401
4402    fn non_mapped_size_in_bytes(
4403        &self,
4404        config: &str,
4405        dtype: DType,
4406        weight_pack_factor: usize,
4407        _matformer_config: Option<&MatformerSliceConfig>,
4408    ) -> Result<usize> {
4409        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4410
4411        let elems = {
4412            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4413            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4414            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4415                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4416            } else {
4417                0
4418            };
4419            let norm = cfg.hidden_size;
4420            embed_tokens + lm_head + norm
4421        };
4422        Ok(elems * dtype.size_in_bytes())
4423    }
4424
4425    fn layer_sizes_in_bytes(
4426        &self,
4427        config: &str,
4428        dtype: DType,
4429        weight_pack_factor: usize,
4430        _matformer_config: Option<&MatformerSliceConfig>,
4431    ) -> Result<Vec<usize>> {
4432        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4433
4434        let per_layer_elems = {
4435            let input_layernorm = cfg.hidden_size;
4436            let post_attention_layernorm = cfg.hidden_size;
4437
4438            let size_in = cfg.hidden_size;
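            // The head dimension is taken as hidden_size / num_attention_heads,
            // so the q/k/v projection widths follow directly from the head counts.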
4439            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4440            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4441            let q_proj = size_in * size_q / weight_pack_factor;
4442            let k_proj = size_in * size_kv / weight_pack_factor;
4443            let v_proj = size_in * size_kv / weight_pack_factor;
4444            let o_proj = size_q * size_in / weight_pack_factor;
4445
4446            let h_size = cfg.hidden_size;
4447            let i_size = cfg.intermediate_size;
4448            let gate_proj = h_size * i_size / weight_pack_factor;
4449            let up_proj = h_size * i_size / weight_pack_factor;
4450            let down_proj = i_size * h_size / weight_pack_factor;
4451
4452            input_layernorm
4453                + post_attention_layernorm
4454                + q_proj
4455                + k_proj
4456                + v_proj
4457                + o_proj
4458                + gate_proj
4459                + up_proj
4460                + down_proj
4461        };
4462        Ok(vec![
4463            per_layer_elems * dtype.size_in_bytes();
4464            cfg.num_hidden_layers
4465        ])
4466    }
4467
4468    fn num_layers(&self, config: &str) -> Result<usize> {
4469        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4470
4471        Ok(cfg.num_hidden_layers)
4472    }
4473    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4474        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4475
4476        let cfg = ModelConfigMetadata {
4477            max_seq_len: cfg.max_position_embeddings,
4478            num_layers: cfg.num_hidden_layers,
4479            hidden_size: cfg.hidden_size,
4480            num_kv_heads: cfg.num_key_value_heads,
4481            num_attn_heads: cfg.num_attention_heads,
4482            sliding_window: None,
4483            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4484            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4485            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4486        };
4487
4488        Ok(Box::new(cfg))
4489    }
4490}
4491
4492// ======================== GraniteMoeHybrid loader
4493
4494/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
4495///
4496/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4497pub struct GraniteMoeHybridLoader;
4498
4499impl NormalModelLoader for GraniteMoeHybridLoader {
4500    fn load(
4501        &self,
4502        config: &str,
4503        vb: ShardedVarBuilder,
4504        normal_loading_metadata: NormalLoadingMetadata,
4505        attention_mechanism: AttentionImplementation,
4506    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4507        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4508
4509        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4510            &cfg,
4511            vb,
4512            self.is_gptx(config)?,
4513            normal_loading_metadata,
4514            attention_mechanism,
4515        )?))
4516    }
4517    fn load_xlora(
4518        &self,
4519        _config: &str,
4520        _vb: ShardedVarBuilder,
4521        _lora_config: &[((String, String), LoraConfig)],
4522        _xlora_config: Option<XLoraConfig>,
4523        _xlora_ordering: Ordering,
4524        _normal_loading_metadata: NormalLoadingMetadata,
4525        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4526    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4527        todo!()
4528    }
4529    fn is_gptx(&self, _: &str) -> Result<bool> {
4530        Ok(true)
4531    }
4532    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4533        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4534        Ok(Box::new(cfg))
4535    }
4536}
4537
4538impl IsqModelLoader for GraniteMoeHybridLoader {
4539    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4540        Ok(vec![
4541            Regex::new(r"lm_head\.(weight|bias)$")?,
4542            // Attention
4543            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4544            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4545            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4546            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4547            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
4548            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4549            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4550        ])
4551    }
4552    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4553        self.isq_layer_regexes(config)
4554    }
4555}
4556
4557impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4558    fn mapped_max_act_size_elems(
4559        &self,
4560        config: &str,
4561        params: &AutoDeviceMapParams,
4562    ) -> Result<usize> {
4563        let AutoDeviceMapParams::Text {
4564            max_seq_len,
4565            max_batch_size,
4566        } = params
4567        else {
4568            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4569        };
4570
4571        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4572
4573        Ok(
4574            max_batch_size
4575                * cfg.num_attention_heads
4576                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4577        )
4578    }
4579    fn non_mapped_max_act_size_elems(
4580        &self,
4581        _config: &str,
4582        _params: &AutoDeviceMapParams,
4583    ) -> Result<usize> {
4584        Ok(0)
4585    }
4586
4587    fn non_mapped_size_in_bytes(
4588        &self,
4589        config: &str,
4590        dtype: DType,
4591        weight_pack_factor: usize,
4592        _matformer_config: Option<&MatformerSliceConfig>,
4593    ) -> Result<usize> {
4594        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4595
4596        let elems = {
4597            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4598            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4599            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4600                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4601            } else {
4602                0
4603            };
4604            let norm = cfg.hidden_size;
4605            embed_tokens + lm_head + norm
4606        };
4607        Ok(elems * dtype.size_in_bytes())
4608    }
4609
4610    fn layer_sizes_in_bytes(
4611        &self,
4612        config: &str,
4613        dtype: DType,
4614        weight_pack_factor: usize,
4615        _matformer_config: Option<&MatformerSliceConfig>,
4616    ) -> Result<Vec<usize>> {
4617        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4618
4619        let per_layer_elems = {
4620            let input_layernorm = cfg.hidden_size;
4621            let post_attention_layernorm = cfg.hidden_size;
4622
4623            let size_in = cfg.hidden_size;
4624            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4625            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4626            let q_proj = size_in * size_q / weight_pack_factor;
4627            let k_proj = size_in * size_kv / weight_pack_factor;
4628            let v_proj = size_in * size_kv / weight_pack_factor;
4629            let o_proj = size_q * size_in / weight_pack_factor;
4630
4631            let h_size = cfg.hidden_size;
4632            let shared_i_size = cfg.shared_intermediate_size();
4633            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
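            // The factor of 2 reflects the fused gate/up halves of input_linear,
            // split before the gated activation (an assumption based on the
            // shared-MLP layout named above).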
4634            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4635            let output_linear = shared_i_size * h_size / weight_pack_factor;
4636
4637            input_layernorm
4638                + post_attention_layernorm
4639                + q_proj
4640                + k_proj
4641                + v_proj
4642                + o_proj
4643                + input_linear
4644                + output_linear
4645        };
4646        Ok(vec![
4647            per_layer_elems * dtype.size_in_bytes();
4648            cfg.num_hidden_layers
4649        ])
4650    }
4651
4652    fn num_layers(&self, config: &str) -> Result<usize> {
4653        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4654
4655        Ok(cfg.num_hidden_layers)
4656    }
4657    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4658        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4659
4660        let cfg = ModelConfigMetadata {
4661            max_seq_len: cfg.max_position_embeddings,
4662            num_layers: cfg.num_hidden_layers,
4663            hidden_size: cfg.hidden_size,
4664            num_kv_heads: cfg.num_key_value_heads(),
4665            num_attn_heads: cfg.num_attention_heads,
4666            sliding_window: None,
4667            k_head_dim: cfg.head_dim(),
4668            v_head_dim: cfg.head_dim(),
4669            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4670        };
4671
4672        Ok(Box::new(cfg))
4673    }
4674}
4675
4676// ======================== GPT-OSS loader
4677
4678/// [`NormalLoader`] for a GPT-OSS model.
4679///
4680/// [`NormalLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.NormalLoader.html
4681pub struct GptOssLoader;
4682
4683impl NormalModelLoader for GptOssLoader {
4684    fn load(
4685        &self,
4686        config: &str,
4687        vb: ShardedVarBuilder,
4688        normal_loading_metadata: NormalLoadingMetadata,
4689        attention_mechanism: AttentionImplementation,
4690    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4691        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4692
4693        Ok(Box::new(models::gpt_oss::Model::new(
4694            &cfg,
4695            vb,
4696            self.is_gptx(config)?,
4697            normal_loading_metadata,
4698            attention_mechanism,
4699        )?))
4700    }
4701    fn load_xlora(
4702        &self,
4703        _config: &str,
4704        _vb: ShardedVarBuilder,
4705        _lora_config: &[((String, String), LoraConfig)],
4706        _xlora_config: Option<XLoraConfig>,
4707        _xlora_ordering: Ordering,
4708        _normal_loading_metadata: NormalLoadingMetadata,
4709        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4710    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4711        anyhow::bail!("GPT-OSS does not support X-LoRA")
4712    }
4713    fn is_gptx(&self, _: &str) -> Result<bool> {
4714        Ok(true)
4715    }
4716    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4717        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4718        Ok(Box::new(cfg))
4719    }
4720    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4721        Ok(false)
4722    }
4723}
4724
4725impl IsqModelLoader for GptOssLoader {
4726    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4727        // Only the attention projections are ISQ-able; the MoE experts are already MXFP4-quantized.
4728        Ok(vec![
4729            Regex::new(r"lm_head\.(weight|bias)$")?,
4730            // Attention
4731            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4732            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4733            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4734            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4735        ])
4736    }
4737    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4738        self.isq_layer_regexes(config)
4739    }
4740}
4741
4742impl DeviceMappedModelLoader for GptOssLoader {
4743    fn mapped_max_act_size_elems(
4744        &self,
4745        config: &str,
4746        params: &AutoDeviceMapParams,
4747    ) -> Result<usize> {
4748        let AutoDeviceMapParams::Text {
4749            max_seq_len,
4750            max_batch_size,
4751        } = params
4752        else {
4753            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4754        };
4755
4756        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4757
4758        Ok(
4759            max_batch_size
4760                * cfg.num_attention_heads
4761                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4762        )
4763    }
4764    fn non_mapped_max_act_size_elems(
4765        &self,
4766        _config: &str,
4767        _params: &AutoDeviceMapParams,
4768    ) -> Result<usize> {
4769        Ok(0)
4770    }
4771
4772    fn non_mapped_size_in_bytes(
4773        &self,
4774        config: &str,
4775        dtype: DType,
4776        weight_pack_factor: usize,
4777        _matformer_config: Option<&MatformerSliceConfig>,
4778    ) -> Result<usize> {
4779        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4780
4781        let elems = {
4782            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
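            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed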
4783            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4784                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4785            } else {
4786                0
4787            };
4788            let norm = cfg.hidden_size;
4789            embed_tokens + lm_head + norm
4790        };
4791        Ok(elems * dtype.size_in_bytes())
4792    }
4793
4794    fn layer_sizes_in_bytes(
4795        &self,
4796        config: &str,
4797        dtype: DType,
4798        weight_pack_factor: usize,
4799        _matformer_config: Option<&MatformerSliceConfig>,
4800    ) -> Result<Vec<usize>> {
4801        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4802
4803        let per_layer_elems = {
4804            let input_layernorm = cfg.hidden_size;
4805            let post_attention_layernorm = cfg.hidden_size;
4806
4807            let size_in = cfg.hidden_size;
4808            let head_dim = cfg.head_dim();
4809            let size_q = head_dim * cfg.num_attention_heads;
4810            let size_kv = head_dim * cfg.num_key_value_heads;
4811            let q_proj =
4812                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4813            let k_proj =
4814                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4815            let v_proj =
4816                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4817            let o_proj =
4818                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4819
4820            // MoE experts are MXFP4-quantized, so their weights are very compact
4821            // gate_up_proj: [num_experts, intermediate_size * 2, hidden_size/2] packed
4822            // down_proj: [num_experts, hidden_size, intermediate_size/2] packed
4823            // At 4 bits per weight, packing factor is 2
4824            let mxfp4_pack = 2;
4825            let gate_up_proj_size =
4826                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4827            let down_proj_size =
4828                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4829            // Plus scales at 1 byte per 32 elements
4830            let gate_up_scales =
4831                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4832            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
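            // Net effect: roughly num_weights / 2 + num_weights / 32 counted
            // elements per expert tensor (the 4-bit payload plus one scale per
            // 32-weight block); the dtype multiplication below turns this
            // element count into an approximate byte figure.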
4833            // Plus biases
4834            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4835            let down_bias = cfg.num_local_experts * cfg.hidden_size;
4836            // Router
4837            let router = cfg.hidden_size * cfg.num_local_experts;
4838            // Sinks per head
4839            let sinks = cfg.num_attention_heads;
4840
4841            input_layernorm
4842                + post_attention_layernorm
4843                + q_proj
4844                + k_proj
4845                + v_proj
4846                + o_proj
4847                + gate_up_proj_size
4848                + down_proj_size
4849                + gate_up_scales
4850                + down_scales
4851                + gate_up_bias
4852                + down_bias
4853                + router
4854                + sinks
4855        };
4856        Ok(vec![
4857            per_layer_elems * dtype.size_in_bytes();
4858            cfg.num_hidden_layers
4859        ])
4860    }
4861
4862    fn num_layers(&self, config: &str) -> Result<usize> {
4863        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4864
4865        Ok(cfg.num_hidden_layers)
4866    }
4867    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4868        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4869
4870        let head_dim = cfg.head_dim();
4871        let cfg = ModelConfigMetadata {
4872            max_seq_len: cfg.max_position_embeddings,
4873            num_layers: cfg.num_hidden_layers,
4874            hidden_size: cfg.hidden_size,
4875            num_kv_heads: cfg.num_key_value_heads,
4876            num_attn_heads: cfg.num_attention_heads,
4877            sliding_window: cfg.sliding_window,
4878            k_head_dim: head_dim,
4879            v_head_dim: head_dim,
4880            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4881        };
4882
4883        Ok(Box::new(cfg))
4884    }
4885}