Skip to main content

hanzo_engine/pipeline/loaders/
normal_loaders.rs

1use std::{
2    collections::HashMap,
3    fmt::{Debug, Display},
4    str::FromStr,
5    sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::speculative::SpeculativeTargetMixin;
11use crate::{
12    amoe::AnyMoeBaseModelMixin,
13    device_map::DeviceMapper,
14    lora::{LoraConfig, Ordering},
15    paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
16    pipeline::{
17        isq::IsqModelLoader,
18        text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
19        EitherCache, IsqModel,
20    },
21    utils::varbuilder_utils::DeviceForLoadTensor,
22    xlora_models::NonGranularState,
23};
24use anyhow::Result;
25use hanzo_ml::{DType, Device, Tensor};
26use hanzo_quant::log::once_log_debug;
27
28use hanzo_quant::ShardedVarBuilder;
29use indicatif::MultiProgress;
30#[cfg(feature = "pyo3_macros")]
31use pyo3::pyclass;
32
33use regex::Regex;
34use serde::Deserialize;
35
36use crate::{
37    models,
38    xlora_models::{self, XLoraConfig},
39};
40
41use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
42
/// Interface of the models produced by the loaders in this module, which hand them out as
/// `Box<dyn NormalModel + Send + Sync>`.
///
/// Implementors must also provide [`IsqModel`], [`AnyMoeBaseModelMixin`] and
/// [`SpeculativeTargetMixin`].
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin + SpeculativeTargetMixin {
    /// Runs the model on `input_ids` and returns the resulting tensor.
    ///
    /// `metadata`, when present, carries per-layer tensor pairs together with the paged
    /// attention input metadata.
    /// NOTE(review): tensor shapes and the exact meaning of `seqlen_offsets`, `context_lens`
    /// and `position_ids` are defined by the implementors and are not visible here.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// X-LoRA variant of [`Self::forward`], taking both the regular and the `_full` inputs.
    ///
    /// NOTE(review): presumably the `_full` arguments describe the complete sequence used
    /// when `no_kv_cache` is set — confirm against the implementors.
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// Whether this model is an X-LoRA model.
    fn is_xlora(&self) -> bool;
    /// The device associated with this model.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Exclusive access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// Maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// Metadata describing the model configuration.
    fn config(&self) -> &ModelConfigMetadata;
    /// Returns a clone of [`Self::config`] wrapped in an `Arc` as a type-erased
    /// [`ModelConfigLike`].
    fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
        Arc::new(self.config().clone())
    }
}
78
/// Metadata for loading a model with ISQ or device mapping.
///
/// A value of this type is passed by value to [`NormalModelLoader::load`] and
/// [`NormalModelLoader::load_xlora`].
pub struct NormalLoadingMetadata {
    /// Device mapping metadata which can be used to construct a concrete device mapper.
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    /// Flag indicating whether the model is being loaded with ISQ.
    pub loading_isq: bool,
    /// Device mapping target device (the one that is not the CPU).
    pub real_device: Device,
    /// `MultiProgress` support for parallelized loading.
    pub multi_progress: Arc<MultiProgress>,
    /// Optional Matryoshka Transformer slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}
92
93pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
94    fn load(
95        &self,
96        config: &str,
97        vb: ShardedVarBuilder,
98        normal_loading_metadata: NormalLoadingMetadata,
99        attention_mechanism: AttentionImplementation,
100    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
101    #[allow(clippy::too_many_arguments)]
102    fn load_xlora(
103        &self,
104        config: &str,
105        vb: ShardedVarBuilder,
106        lora_config: &[((String, String), LoraConfig)],
107        xlora_config: Option<XLoraConfig>,
108        xlora_ordering: Ordering,
109        normal_loading_metadata: NormalLoadingMetadata,
110        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
111    ) -> Result<Box<dyn NormalModel + Send + Sync>>;
112    fn is_gptx(&self, config: &str) -> Result<bool>;
113    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
114        Ok(true)
115    }
116    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
117    fn get_device_for_tensor(
118        &self,
119        config: &str,
120        _mapper: &dyn DeviceMapper,
121        loading_isq: bool,
122    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
123        if loading_isq {
124            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
125        } else {
126            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
127            let num_layers = self.model_config(config)?.num_layers();
128            let closure = move |name: String| {
129                if let Some(captures) = re.captures(&name) {
130                    captures
131                        .get(1)
132                        .and_then(|m| m.as_str().parse::<usize>().ok())
133                        .map(|l| l.min(num_layers))
134                        .map(DeviceForLoadTensor::Idx)
135                        .unwrap_or(DeviceForLoadTensor::Base)
136                } else {
137                    DeviceForLoadTensor::Base
138                }
139            };
140
141            Ok(Arc::new(closure))
142        }
143    }
144}
145
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
/// The architecture to load the normal model as.
///
/// Each variant's serde name is the same lowercase identifier that the `FromStr` and
/// `Display` implementations of this type accept and produce.
pub enum NormalLoaderType {
    #[serde(rename = "mistral")]
    Mistral,
    #[serde(rename = "gemma")]
    Gemma,
    #[serde(rename = "mixtral")]
    Mixtral,
    #[serde(rename = "llama")]
    Llama,
    #[serde(rename = "phi2")]
    Phi2,
    #[serde(rename = "phi3")]
    Phi3,
    #[serde(rename = "qwen2")]
    Qwen2,
    #[serde(rename = "gemma2")]
    Gemma2,
    #[serde(rename = "starcoder2")]
    Starcoder2,
    // Note the dot in the serialized name: `phi3.5moe`.
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    #[serde(rename = "qwen3")]
    Qwen3,
    #[serde(rename = "glm4")]
    GLM4,
    #[serde(rename = "glm4moelite")]
    GLM4MoeLite,
    #[serde(rename = "glm4moe")]
    GLM4Moe,
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    #[serde(rename = "smollm3")]
    SmolLm3,
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
    // The only serialized name containing an underscore: `gpt_oss`.
    #[serde(rename = "gpt_oss")]
    GptOss,
    #[serde(rename = "qwen3next")]
    Qwen3Next,
}
193
194// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
195impl NormalLoaderType {
196    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
197        match name {
198            "MistralForCausalLM" => Ok(Self::Mistral),
199            "MixtralForCausalLM" => Ok(Self::Mixtral),
200            "GemmaForCausalLM" => Ok(Self::Gemma),
201            "Gemma2ForCausalLM" => Ok(Self::Gemma2),
202            "PhiForCausalLM" => Ok(Self::Phi2),
203            "Phi3ForCausalLM" => Ok(Self::Phi3),
204            "LlamaForCausalLM" => Ok(Self::Llama),
205            "Qwen2ForCausalLM" => Ok(Self::Qwen2),
206            "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
207            "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
208            "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
209            "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
210            "Qwen3ForCausalLM" => Ok(Self::Qwen3),
211            "Glm4ForCausalLM" => Ok(Self::GLM4),
212            "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
213            "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
214            "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
215            "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
216            "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
217            "GptOssForCausalLM" => Ok(Self::GptOss),
218            "Qwen3NextForCausalLM" => Ok(Self::Qwen3Next),
219            other => anyhow::bail!(
220                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
221            ),
222        }
223    }
224}
225
226impl FromStr for NormalLoaderType {
227    type Err = String;
228    fn from_str(s: &str) -> Result<Self, Self::Err> {
229        match s {
230            "mistral" => Ok(Self::Mistral),
231            "gemma" => Ok(Self::Gemma),
232            "mixtral" => Ok(Self::Mixtral),
233            "llama" => Ok(Self::Llama),
234            "phi2" => Ok(Self::Phi2),
235            "phi3" => Ok(Self::Phi3),
236            "qwen2" => Ok(Self::Qwen2),
237            "gemma2" => Ok(Self::Gemma2),
238            "starcoder2" => Ok(Self::Starcoder2),
239            "phi3.5moe" => Ok(Self::Phi3_5MoE),
240            "deepseekv2" => Ok(Self::DeepSeekV2),
241            "deepseekv3" => Ok(Self::DeepSeekV3),
242            "qwen3" => Ok(Self::Qwen3),
243            "glm4" => Ok(Self::GLM4),
244            "glm4moelite" => Ok(Self::GLM4MoeLite),
245            "glm4moe" => Ok(Self::GLM4Moe),
246            "qwen3moe" => Ok(Self::Qwen3Moe),
247            "smollm3" => Ok(Self::SmolLm3),
248            "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
249            "gpt_oss" => Ok(Self::GptOss),
250            "qwen3next" => Ok(Self::Qwen3Next),
251            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`, `qwen3next`.")),
252        }
253    }
254}
255
256impl Display for NormalLoaderType {
257    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258        match self {
259            Self::Gemma => write!(f, "gemma"),
260            Self::Gemma2 => write!(f, "gemma2"),
261            Self::Llama => write!(f, "llama"),
262            Self::Mistral => write!(f, "mistral"),
263            Self::Mixtral => write!(f, "mixtral"),
264            Self::Phi2 => write!(f, "phi2"),
265            Self::Phi3 => write!(f, "phi3"),
266            Self::Phi3_5MoE => write!(f, "phi3.5moe"),
267            Self::Qwen2 => write!(f, "qwen2"),
268            Self::Starcoder2 => write!(f, "starcoder2"),
269            Self::DeepSeekV2 => write!(f, "deepseekv2"),
270            Self::DeepSeekV3 => write!(f, "deepseekv3"),
271            Self::Qwen3 => write!(f, "qwen3"),
272            Self::GLM4 => write!(f, "glm4"),
273            Self::GLM4MoeLite => write!(f, "glm4moelite"),
274            Self::GLM4Moe => write!(f, "glm4moe"),
275            Self::Qwen3Moe => write!(f, "qwen3moe"),
276            Self::SmolLm3 => write!(f, "smollm3"),
277            Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
278            Self::GptOss => write!(f, "gpt_oss"),
279            Self::Qwen3Next => write!(f, "qwen3next"),
280        }
281    }
282}
283
// Expands to `$size` when `$cond` holds and to `0` otherwise; `$size` is only evaluated in
// the `true` case. Used to add the element count of an optional bias term.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
293
/// Load a model based on the Hugging Face Transformers -CausalLM model class.
///
/// The class name is read from the `architectures` field of the config JSON, and every
/// loader method resolves the matching concrete loader from it.
pub struct AutoNormalLoader;
296
/// The subset of the model config that [`AutoNormalLoader`] needs to pick a loader.
#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    /// Model class names from the config; `AutoNormalLoader::get_loader` requires exactly one.
    architectures: Vec<String>,
}
301
302impl AutoNormalLoader {
303    fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
304        let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
305        if auto_cfg.architectures.len() != 1 {
306            anyhow::bail!("Expected to have one name for `architectures` config field.")
307        }
308
309        let name = &auto_cfg.architectures[0];
310
311        let tp = NormalLoaderType::from_causal_lm_name(name)?;
312
313        once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
314
315        match tp {
316            NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
317            NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
318            NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
319            NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
320            NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
321            NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
322            NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
323            NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
324            NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
325            NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
326            NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
327            NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
328            NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
329            NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
330            NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
331            NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
332            NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
333            NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
334            NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
335            NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
336            NormalLoaderType::Qwen3Next => Ok(Box::new(Qwen3NextLoader)),
337        }
338    }
339}
340
341impl NormalModelLoader for AutoNormalLoader {
342    fn load(
343        &self,
344        config: &str,
345        vb: ShardedVarBuilder,
346        normal_loading_metadata: NormalLoadingMetadata,
347        attention_mechanism: AttentionImplementation,
348    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
349        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
350    }
351    fn load_xlora(
352        &self,
353        config: &str,
354        vb: ShardedVarBuilder,
355        lora_config: &[((String, String), LoraConfig)],
356        xlora_config: Option<XLoraConfig>,
357        xlora_ordering: Ordering,
358        normal_loading_metadata: NormalLoadingMetadata,
359        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
360    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
361        Self::get_loader(config)?.load_xlora(
362            config,
363            vb,
364            lora_config,
365            xlora_config,
366            xlora_ordering,
367            normal_loading_metadata,
368            preload_adapters,
369        )
370    }
371    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
372        Self::get_loader(config)?.get_config_repr(config)
373    }
374    fn supports_paged_attention(&self, config: &str) -> Result<bool> {
375        Self::get_loader(config)?.supports_paged_attention(config)
376    }
377    fn is_gptx(&self, config: &str) -> Result<bool> {
378        Self::get_loader(config)?.is_gptx(config)
379    }
380}
381
382impl IsqModelLoader for AutoNormalLoader {
383    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
384        Self::get_loader(config)?.immediate_isq_predicates(config)
385    }
386    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
387        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
388    }
389    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
390        Self::get_loader(config)?.isq_layer_regexes(config)
391    }
392    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
393        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
394    }
395}
396
397impl DeviceMappedModelLoader for AutoNormalLoader {
398    fn non_mapped_size_in_bytes(
399        &self,
400        config: &str,
401        dtype: DType,
402        weight_pack_factor: usize,
403        _matformer_config: Option<&MatformerSliceConfig>,
404    ) -> Result<usize> {
405        Self::get_loader(config)?.non_mapped_size_in_bytes(
406            config,
407            dtype,
408            weight_pack_factor,
409            _matformer_config,
410        )
411    }
412    fn num_layers(&self, config: &str) -> Result<usize> {
413        Self::get_loader(config)?.num_layers(config)
414    }
415    fn layer_sizes_in_bytes(
416        &self,
417        config: &str,
418        dtype: DType,
419        weight_pack_factor: usize,
420        _matformer_config: Option<&MatformerSliceConfig>,
421    ) -> Result<Vec<usize>> {
422        Self::get_loader(config)?.layer_sizes_in_bytes(
423            config,
424            dtype,
425            weight_pack_factor,
426            _matformer_config,
427        )
428    }
429    fn mapped_max_act_size_elems(
430        &self,
431        config: &str,
432        params: &super::AutoDeviceMapParams,
433    ) -> Result<usize> {
434        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
435    }
436    fn non_mapped_max_act_size_elems(
437        &self,
438        _config: &str,
439        _params: &AutoDeviceMapParams,
440    ) -> Result<usize> {
441        Ok(0)
442    }
443    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
444        Self::get_loader(config)?.model_config(config)
445    }
446}
447
448// ======================== Mistral loader
449
/// Loader for Mistral models: parses the config as `models::mistral::Config` and builds
/// either `models::mistral::Model` or `xlora_models::XLoraMistral`.
pub struct MistralLoader;
451
452impl NormalModelLoader for MistralLoader {
453    fn load(
454        &self,
455        config: &str,
456        vb: ShardedVarBuilder,
457        normal_loading_metadata: NormalLoadingMetadata,
458        attention_mechanism: AttentionImplementation,
459    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
460        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
461        Ok(Box::new(models::mistral::Model::new(
462            &cfg,
463            vb,
464            self.is_gptx(config)?,
465            normal_loading_metadata,
466            attention_mechanism,
467        )?))
468    }
469    fn load_xlora(
470        &self,
471        config: &str,
472        vb: ShardedVarBuilder,
473        lora_config: &[((String, String), LoraConfig)],
474        xlora_config: Option<XLoraConfig>,
475        xlora_ordering: Ordering,
476        normal_loading_metadata: NormalLoadingMetadata,
477        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
478    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
479        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
480        Ok(Box::new(xlora_models::XLoraMistral::new(
481            &cfg,
482            vb,
483            lora_config,
484            xlora_config,
485            xlora_ordering,
486            self.is_gptx(config)?,
487            normal_loading_metadata,
488            preload_adapters,
489        )?))
490    }
491    fn is_gptx(&self, _: &str) -> Result<bool> {
492        Ok(true)
493    }
494    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
495        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
496        Ok(Box::new(cfg))
497    }
498}
499
500impl IsqModelLoader for MistralLoader {
501    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
502        Ok(vec![
503            Regex::new(r"lm_head\.(weight|bias)$")?,
504            // Attention
505            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
506            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
507            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
508            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
509            // MLP
510            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
511            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
512            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
513        ])
514    }
515    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
516        self.isq_layer_regexes(config)
517    }
518}
519
520impl DeviceMappedModelLoader for MistralLoader {
521    fn mapped_max_act_size_elems(
522        &self,
523        config: &str,
524        params: &AutoDeviceMapParams,
525    ) -> Result<usize> {
526        let AutoDeviceMapParams::Text {
527            max_seq_len,
528            max_batch_size,
529        } = params
530        else {
531            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
532        };
533
534        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
535
536        Ok(
537            max_batch_size
538                * cfg.num_attention_heads
539                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
540        )
541    }
542    fn non_mapped_max_act_size_elems(
543        &self,
544        _config: &str,
545        _params: &AutoDeviceMapParams,
546    ) -> Result<usize> {
547        Ok(0)
548    }
549
550    fn non_mapped_size_in_bytes(
551        &self,
552        config: &str,
553        dtype: DType,
554        weight_pack_factor: usize,
555        _matformer_config: Option<&MatformerSliceConfig>,
556    ) -> Result<usize> {
557        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
558
559        let elems = {
560            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
561            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
562            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
563                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
564            } else {
565                0
566            };
567            let norm = cfg.hidden_size;
568            embed_tokens + lm_head + norm
569        };
570        Ok(elems * dtype.size_in_bytes())
571    }
572
573    fn layer_sizes_in_bytes(
574        &self,
575        config: &str,
576        dtype: DType,
577        weight_pack_factor: usize,
578        _matformer_config: Option<&MatformerSliceConfig>,
579    ) -> Result<Vec<usize>> {
580        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
581
582        let per_layer_elems = {
583            let input_layernorm = cfg.hidden_size;
584            let post_attention_layernorm = cfg.hidden_size;
585
586            let size_in = cfg.hidden_size;
587            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
588            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
589            let q_proj = size_in * size_q / weight_pack_factor;
590            let k_proj = size_in * size_kv / weight_pack_factor;
591            let v_proj = size_in * size_kv / weight_pack_factor;
592            let o_proj = size_q * size_in / weight_pack_factor;
593
594            let h_size = cfg.hidden_size;
595            let i_size = cfg.intermediate_size;
596            let gate_proj = h_size * i_size / weight_pack_factor;
597            let up_proj = h_size * i_size / weight_pack_factor;
598            let down_proj = i_size * h_size / weight_pack_factor;
599
600            input_layernorm
601                + post_attention_layernorm
602                + q_proj
603                + k_proj
604                + v_proj
605                + o_proj
606                + gate_proj
607                + up_proj
608                + down_proj
609        };
610        Ok(vec![
611            per_layer_elems * dtype.size_in_bytes();
612            cfg.num_hidden_layers
613        ])
614    }
615
616    fn num_layers(&self, config: &str) -> Result<usize> {
617        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
618        Ok(cfg.num_hidden_layers)
619    }
620
621    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
622        let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
623
624        let cfg = ModelConfigMetadata {
625            max_seq_len: cfg.max_position_embeddings,
626            num_layers: cfg.num_hidden_layers,
627            hidden_size: cfg.hidden_size,
628            num_kv_heads: cfg.num_key_value_heads,
629            num_attn_heads: cfg.num_attention_heads,
630            sliding_window: cfg.sliding_window,
631            k_head_dim: cfg.head_dim(),
632            v_head_dim: cfg.head_dim(),
633            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
634        };
635
636        Ok(Box::new(cfg))
637    }
638}
639
640// ======================== Gemma loader
641
/// [`NormalLoader`] for a Gemma model.
///
/// Parses the config as `models::gemma::Config` and builds either `models::gemma::Model`
/// or `xlora_models::XLoraGemma`.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct GemmaLoader;
646
647impl NormalModelLoader for GemmaLoader {
648    fn load(
649        &self,
650        config: &str,
651        vb: ShardedVarBuilder,
652        normal_loading_metadata: NormalLoadingMetadata,
653        attention_mechanism: AttentionImplementation,
654    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
655        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
656
657        Ok(Box::new(models::gemma::Model::new(
658            &cfg,
659            vb,
660            self.is_gptx(config)?,
661            normal_loading_metadata,
662            attention_mechanism,
663        )?))
664    }
665    fn load_xlora(
666        &self,
667        config: &str,
668        vb: ShardedVarBuilder,
669        lora_config: &[((String, String), LoraConfig)],
670        xlora_config: Option<XLoraConfig>,
671        xlora_ordering: Ordering,
672        normal_loading_metadata: NormalLoadingMetadata,
673        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
674    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
675        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
676
677        Ok(Box::new(xlora_models::XLoraGemma::new(
678            &cfg,
679            vb,
680            lora_config,
681            xlora_config,
682            xlora_ordering,
683            self.is_gptx(config)?,
684            normal_loading_metadata,
685            preload_adapters,
686        )?))
687    }
688    fn is_gptx(&self, _: &str) -> Result<bool> {
689        Ok(true)
690    }
691    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
692        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
693        Ok(Box::new(cfg))
694    }
695}
696
697impl IsqModelLoader for GemmaLoader {
698    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
699        Ok(vec![
700            Regex::new(r"lm_head\.(weight|bias)$")?,
701            // Attention
702            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
703            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
704            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
705            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
706            // MLP
707            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
708            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
709            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
710        ])
711    }
712    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
713        self.isq_layer_regexes(config)
714    }
715}
716
717impl DeviceMappedModelLoader for GemmaLoader {
718    fn mapped_max_act_size_elems(
719        &self,
720        config: &str,
721        params: &AutoDeviceMapParams,
722    ) -> Result<usize> {
723        let AutoDeviceMapParams::Text {
724            max_seq_len,
725            max_batch_size,
726        } = params
727        else {
728            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
729        };
730
731        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
732
733        Ok(
734            max_batch_size
735                * cfg.num_attention_heads
736                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
737        )
738    }
739    fn non_mapped_max_act_size_elems(
740        &self,
741        _config: &str,
742        _params: &AutoDeviceMapParams,
743    ) -> Result<usize> {
744        Ok(0)
745    }
746
747    fn non_mapped_size_in_bytes(
748        &self,
749        config: &str,
750        dtype: DType,
751        weight_pack_factor: usize,
752        _matformer_config: Option<&MatformerSliceConfig>,
753    ) -> Result<usize> {
754        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
755
756        let elems = {
757            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
758            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
759            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
760                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
761            } else {
762                0
763            };
764            let norm = cfg.hidden_size;
765            embed_tokens + lm_head + norm
766        };
767        Ok(elems * dtype.size_in_bytes())
768    }
769
770    fn layer_sizes_in_bytes(
771        &self,
772        config: &str,
773        dtype: DType,
774        weight_pack_factor: usize,
775        _matformer_config: Option<&MatformerSliceConfig>,
776    ) -> Result<Vec<usize>> {
777        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
778
779        let per_layer_elems = {
780            let input_layernorm = cfg.hidden_size;
781            let post_attention_layernorm = cfg.hidden_size;
782
783            let size_in = cfg.hidden_size;
784            let size_q = cfg.head_dim * cfg.num_attention_heads;
785            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
786            let q_proj =
787                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
788            let k_proj =
789                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
790            let v_proj =
791                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
792            let o_proj =
793                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
794
795            let h_size = cfg.hidden_size;
796            let i_size = cfg.intermediate_size;
797            let gate_proj = h_size * i_size / weight_pack_factor;
798            let up_proj = h_size * i_size / weight_pack_factor;
799            let down_proj = i_size * h_size / weight_pack_factor;
800
801            input_layernorm
802                + post_attention_layernorm
803                + q_proj
804                + k_proj
805                + v_proj
806                + o_proj
807                + gate_proj
808                + up_proj
809                + down_proj
810        };
811        Ok(vec![
812            per_layer_elems * dtype.size_in_bytes();
813            cfg.num_hidden_layers
814        ])
815    }
816
817    fn num_layers(&self, config: &str) -> Result<usize> {
818        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
819        Ok(cfg.num_hidden_layers)
820    }
821
822    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
823        let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
824
825        let cfg = ModelConfigMetadata {
826            max_seq_len: cfg.max_position_embeddings,
827            num_layers: cfg.num_hidden_layers,
828            hidden_size: cfg.hidden_size,
829            num_kv_heads: cfg.num_key_value_heads,
830            num_attn_heads: cfg.num_attention_heads,
831            sliding_window: None,
832            k_head_dim: cfg.head_dim,
833            v_head_dim: cfg.head_dim,
834            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
835        };
836
837        Ok(Box::new(cfg))
838    }
839}
840
841// ======================== Llama loader
842
843/// [`NormalLoader`] for a Llama model.
844///
845/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
846pub struct LlamaLoader;
847
848impl NormalModelLoader for LlamaLoader {
849    fn load(
850        &self,
851        config: &str,
852        vb: ShardedVarBuilder,
853        normal_loading_metadata: NormalLoadingMetadata,
854        attention_mechanism: AttentionImplementation,
855    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
856        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
857
858        Ok(Box::new(models::llama::Llama::new(
859            &cfg,
860            vb,
861            self.is_gptx(config)?,
862            normal_loading_metadata,
863            attention_mechanism,
864        )?))
865    }
866    fn load_xlora(
867        &self,
868        config: &str,
869        vb: ShardedVarBuilder,
870        lora_config: &[((String, String), LoraConfig)],
871        xlora_config: Option<XLoraConfig>,
872        xlora_ordering: Ordering,
873        normal_loading_metadata: NormalLoadingMetadata,
874        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
875    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
876        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
877
878        Ok(Box::new(xlora_models::XLoraLlama::new(
879            &cfg,
880            vb,
881            lora_config,
882            xlora_config,
883            xlora_ordering,
884            self.is_gptx(config)?,
885            normal_loading_metadata,
886            preload_adapters,
887        )?))
888    }
889    fn is_gptx(&self, _: &str) -> Result<bool> {
890        Ok(true)
891    }
892    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
893        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
894        Ok(Box::new(cfg))
895    }
896}
897
898impl IsqModelLoader for LlamaLoader {
899    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
900        Ok(vec![
901            Regex::new(r"lm_head\.(weight|bias)$")?,
902            // Attention
903            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
904            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
905            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
906            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
907            // MLP
908            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
909            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
910            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
911        ])
912    }
913    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
914        self.isq_layer_regexes(config)
915    }
916}
917
918impl DeviceMappedModelLoader for LlamaLoader {
919    fn mapped_max_act_size_elems(
920        &self,
921        config: &str,
922        params: &AutoDeviceMapParams,
923    ) -> Result<usize> {
924        let AutoDeviceMapParams::Text {
925            max_seq_len,
926            max_batch_size,
927        } = params
928        else {
929            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
930        };
931
932        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
933
934        Ok(
935            max_batch_size
936                * cfg.num_attention_heads
937                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
938        )
939    }
940    fn non_mapped_max_act_size_elems(
941        &self,
942        _config: &str,
943        _params: &AutoDeviceMapParams,
944    ) -> Result<usize> {
945        Ok(0)
946    }
947
948    fn non_mapped_size_in_bytes(
949        &self,
950        config: &str,
951        dtype: DType,
952        weight_pack_factor: usize,
953        _matformer_config: Option<&MatformerSliceConfig>,
954    ) -> Result<usize> {
955        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
956
957        let elems = {
958            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
959            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
960            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
961                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
962            } else {
963                0
964            };
965            let norm = cfg.hidden_size;
966            embed_tokens + lm_head + norm
967        };
968        Ok(elems * dtype.size_in_bytes())
969    }
970
971    fn layer_sizes_in_bytes(
972        &self,
973        config: &str,
974        dtype: DType,
975        weight_pack_factor: usize,
976        _matformer_config: Option<&MatformerSliceConfig>,
977    ) -> Result<Vec<usize>> {
978        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
979
980        let per_layer_elems = {
981            let input_layernorm = cfg.hidden_size;
982            let post_attention_layernorm = cfg.hidden_size;
983
984            let size_in = cfg.hidden_size;
985            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
986            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
987            let q_proj = size_in * size_q / weight_pack_factor;
988            let k_proj = size_in * size_kv / weight_pack_factor;
989            let v_proj = size_in * size_kv / weight_pack_factor;
990            let o_proj = size_q * size_in / weight_pack_factor;
991
992            let h_size = cfg.hidden_size;
993            let i_size = cfg.intermediate_size;
994            let gate_proj = h_size * i_size / weight_pack_factor;
995            let up_proj = h_size * i_size / weight_pack_factor;
996            let down_proj = i_size * h_size / weight_pack_factor;
997
998            input_layernorm
999                + post_attention_layernorm
1000                + q_proj
1001                + k_proj
1002                + v_proj
1003                + o_proj
1004                + gate_proj
1005                + up_proj
1006                + down_proj
1007        };
1008        Ok(vec![
1009            per_layer_elems * dtype.size_in_bytes();
1010            cfg.num_hidden_layers
1011        ])
1012    }
1013
1014    fn num_layers(&self, config: &str) -> Result<usize> {
1015        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1016
1017        Ok(cfg.num_hidden_layers)
1018    }
1019    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1020        let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1021
1022        let cfg = ModelConfigMetadata {
1023            max_seq_len: cfg.max_position_embeddings,
1024            num_layers: cfg.num_hidden_layers,
1025            hidden_size: cfg.hidden_size,
1026            num_kv_heads: cfg.num_key_value_heads,
1027            num_attn_heads: cfg.num_attention_heads,
1028            sliding_window: None,
1029            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1030            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1031            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1032        };
1033
1034        Ok(Box::new(cfg))
1035    }
1036}
1037
1038// ======================== Mixtral loader
1039
1040pub struct MixtralLoader;
1041
1042impl NormalModelLoader for MixtralLoader {
1043    fn load(
1044        &self,
1045        config: &str,
1046        vb: ShardedVarBuilder,
1047        normal_loading_metadata: NormalLoadingMetadata,
1048        attention_mechanism: AttentionImplementation,
1049    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1050        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1051
1052        Ok(Box::new(models::mixtral::Model::new(
1053            &cfg,
1054            vb,
1055            self.is_gptx(config)?,
1056            normal_loading_metadata,
1057            attention_mechanism,
1058        )?))
1059    }
1060    fn load_xlora(
1061        &self,
1062        config: &str,
1063        vb: ShardedVarBuilder,
1064        lora_config: &[((String, String), LoraConfig)],
1065        xlora_config: Option<XLoraConfig>,
1066        xlora_ordering: Ordering,
1067        normal_loading_metadata: NormalLoadingMetadata,
1068        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1069    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1070        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1071
1072        Ok(Box::new(xlora_models::XLoraMixtral::new(
1073            &cfg,
1074            vb,
1075            lora_config,
1076            xlora_config,
1077            xlora_ordering,
1078            self.is_gptx(config)?,
1079            normal_loading_metadata,
1080            preload_adapters,
1081        )?))
1082    }
1083    fn is_gptx(&self, _: &str) -> Result<bool> {
1084        Ok(true)
1085    }
1086    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1087        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1088
1089        Ok(Box::new(cfg))
1090    }
1091}
1092
1093impl IsqModelLoader for MixtralLoader {
1094    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1095        Ok(vec![
1096            Regex::new(r"lm_head\.(weight|bias)$")?,
1097            // Attention
1098            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1099            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1100            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1101            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1102            // Experts
1103            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1104            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1105            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1106            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1107        ])
1108    }
1109    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1110        self.isq_layer_regexes(config)
1111    }
1112}
1113
1114impl DeviceMappedModelLoader for MixtralLoader {
1115    fn mapped_max_act_size_elems(
1116        &self,
1117        config: &str,
1118        params: &AutoDeviceMapParams,
1119    ) -> Result<usize> {
1120        let AutoDeviceMapParams::Text {
1121            max_seq_len,
1122            max_batch_size,
1123        } = params
1124        else {
1125            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1126        };
1127
1128        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1129
1130        Ok(
1131            max_batch_size
1132                * cfg.num_attention_heads
1133                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1134        )
1135    }
1136    fn non_mapped_max_act_size_elems(
1137        &self,
1138        _config: &str,
1139        _params: &AutoDeviceMapParams,
1140    ) -> Result<usize> {
1141        Ok(0)
1142    }
1143
1144    fn non_mapped_size_in_bytes(
1145        &self,
1146        config: &str,
1147        dtype: DType,
1148        weight_pack_factor: usize,
1149        _matformer_config: Option<&MatformerSliceConfig>,
1150    ) -> Result<usize> {
1151        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1152
1153        let elems = {
1154            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1155            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1156            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1157                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1158            } else {
1159                0
1160            };
1161            let norm = cfg.hidden_size;
1162            embed_tokens + lm_head + norm
1163        };
1164        Ok(elems * dtype.size_in_bytes())
1165    }
1166
1167    fn layer_sizes_in_bytes(
1168        &self,
1169        config: &str,
1170        dtype: DType,
1171        weight_pack_factor: usize,
1172        _matformer_config: Option<&MatformerSliceConfig>,
1173    ) -> Result<Vec<usize>> {
1174        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1175
1176        let per_layer_elems = {
1177            let input_layernorm = cfg.hidden_size;
1178            let post_attention_layernorm = cfg.hidden_size;
1179
1180            let size_in = cfg.hidden_size;
1181            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1182            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1183            let q_proj = size_in * size_q / weight_pack_factor;
1184            let k_proj = size_in * size_kv / weight_pack_factor;
1185            let v_proj = size_in * size_kv / weight_pack_factor;
1186            let o_proj = size_q * size_in / weight_pack_factor;
1187
1188            let moe_block = {
1189                let gate = cfg.hidden_size * cfg.num_local_experts;
1190                // Assume quantizing weight pack factor
1191                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1192                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1193                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1194                gate + cfg.num_local_experts * w1
1195                    + cfg.num_local_experts * w2
1196                    + cfg.num_local_experts * w3
1197            };
1198
1199            input_layernorm
1200                + post_attention_layernorm
1201                + q_proj
1202                + k_proj
1203                + v_proj
1204                + o_proj
1205                + moe_block
1206        };
1207        Ok(vec![
1208            per_layer_elems * dtype.size_in_bytes();
1209            cfg.num_hidden_layers
1210        ])
1211    }
1212
1213    fn num_layers(&self, config: &str) -> Result<usize> {
1214        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1215
1216        Ok(cfg.num_hidden_layers)
1217    }
1218
1219    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1220        let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1221
1222        let cfg = ModelConfigMetadata {
1223            max_seq_len: cfg.max_position_embeddings,
1224            num_layers: cfg.num_hidden_layers,
1225            hidden_size: cfg.hidden_size,
1226            num_kv_heads: cfg.num_key_value_heads,
1227            num_attn_heads: cfg.num_attention_heads,
1228            sliding_window: cfg.sliding_window,
1229            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1230            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1231            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1232        };
1233
1234        Ok(Box::new(cfg))
1235    }
1236}
1237
1238// ======================== Phi2 loader
1239
1240/// [`NormalLoader`] for a Phi 2 model.
1241///
1242/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
1243pub struct Phi2Loader;
1244
1245impl NormalModelLoader for Phi2Loader {
1246    fn load(
1247        &self,
1248        config: &str,
1249        vb: ShardedVarBuilder,
1250        normal_loading_metadata: NormalLoadingMetadata,
1251        attention_mechanism: AttentionImplementation,
1252    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1253        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1254
1255        Ok(Box::new(models::phi2::Model::new(
1256            &cfg,
1257            vb,
1258            self.is_gptx(config)?,
1259            normal_loading_metadata,
1260            attention_mechanism,
1261        )?))
1262    }
1263    fn load_xlora(
1264        &self,
1265        config: &str,
1266        vb: ShardedVarBuilder,
1267        lora_config: &[((String, String), LoraConfig)],
1268        xlora_config: Option<XLoraConfig>,
1269        xlora_ordering: Ordering,
1270        normal_loading_metadata: NormalLoadingMetadata,
1271        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1272    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1273        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1274
1275        Ok(Box::new(xlora_models::XLoraPhi2::new(
1276            &cfg,
1277            vb,
1278            lora_config,
1279            xlora_config,
1280            xlora_ordering,
1281            self.is_gptx(config)?,
1282            normal_loading_metadata,
1283            preload_adapters,
1284        )?))
1285    }
1286    fn is_gptx(&self, _: &str) -> Result<bool> {
1287        Ok(true)
1288    }
1289    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1290        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1291
1292        Ok(Box::new(cfg))
1293    }
1294}
1295
1296impl IsqModelLoader for Phi2Loader {
1297    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1298        Ok(vec![
1299            Regex::new(r"lm_head\.(weight|bias)$")?,
1300            // Attention
1301            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1302            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1303            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1304            Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1305            // MLP
1306            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1307            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1308        ])
1309    }
1310    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1311        self.isq_layer_regexes(config)
1312    }
1313}
1314
1315impl DeviceMappedModelLoader for Phi2Loader {
1316    fn mapped_max_act_size_elems(
1317        &self,
1318        config: &str,
1319        params: &AutoDeviceMapParams,
1320    ) -> Result<usize> {
1321        let AutoDeviceMapParams::Text {
1322            max_seq_len,
1323            max_batch_size,
1324        } = params
1325        else {
1326            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1327        };
1328
1329        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1330
1331        Ok(
1332            max_batch_size
1333                * cfg.num_attention_heads
1334                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1335        )
1336    }
1337    fn non_mapped_max_act_size_elems(
1338        &self,
1339        _config: &str,
1340        _params: &AutoDeviceMapParams,
1341    ) -> Result<usize> {
1342        Ok(0)
1343    }
1344
1345    fn non_mapped_size_in_bytes(
1346        &self,
1347        config: &str,
1348        dtype: DType,
1349        weight_pack_factor: usize,
1350        _matformer_config: Option<&MatformerSliceConfig>,
1351    ) -> Result<usize> {
1352        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1353
1354        let elems = {
1355            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1356            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1357            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1358                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1359            } else {
1360                0
1361            };
1362            let norm = cfg.hidden_size;
1363            embed_tokens + lm_head + norm
1364        };
1365        Ok(elems * dtype.size_in_bytes())
1366    }
1367
1368    fn layer_sizes_in_bytes(
1369        &self,
1370        config: &str,
1371        dtype: DType,
1372        weight_pack_factor: usize,
1373        _matformer_config: Option<&MatformerSliceConfig>,
1374    ) -> Result<Vec<usize>> {
1375        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1376
1377        let per_layer_elems = {
1378            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1379
1380            let size_in = cfg.hidden_size;
1381            let size_q = cfg.head_dim() * cfg.num_attention_heads;
1382            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1383            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1384            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1385            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1386            let o_proj = size_q * size_in / weight_pack_factor + size_in;
1387            let (q_norm, k_norm) = if cfg.qk_layernorm {
1388                (cfg.head_dim(), cfg.head_dim())
1389            } else {
1390                (0, 0)
1391            };
1392
1393            let h_size = cfg.hidden_size;
1394            let i_size = cfg.intermediate_size;
1395            let fc1 = h_size * i_size / weight_pack_factor;
1396            let fc2 = h_size * i_size / weight_pack_factor;
1397
1398            input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1399        };
1400        Ok(vec![
1401            per_layer_elems * dtype.size_in_bytes();
1402            cfg.num_hidden_layers
1403        ])
1404    }
1405
1406    fn num_layers(&self, config: &str) -> Result<usize> {
1407        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1408
1409        Ok(cfg.num_hidden_layers)
1410    }
1411
1412    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1413        let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1414
1415        let cfg = ModelConfigMetadata {
1416            max_seq_len: cfg.max_position_embeddings,
1417            num_layers: cfg.num_hidden_layers,
1418            hidden_size: cfg.hidden_size,
1419            num_kv_heads: cfg.num_key_value_heads(),
1420            num_attn_heads: cfg.num_attention_heads,
1421            sliding_window: None,
1422            k_head_dim: cfg.head_dim(),
1423            v_head_dim: cfg.head_dim(),
1424            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1425        };
1426
1427        Ok(Box::new(cfg))
1428    }
1429}
1430
1431// ======================== Phi3 loader
1432
1433/// [`NormalLoader`] for a Phi 3 model.
1434///
1435/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
1436pub struct Phi3Loader;
1437
1438impl NormalModelLoader for Phi3Loader {
1439    fn load(
1440        &self,
1441        config: &str,
1442        vb: ShardedVarBuilder,
1443        normal_loading_metadata: NormalLoadingMetadata,
1444        attention_mechanism: AttentionImplementation,
1445    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1446        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1447
1448        Ok(Box::new(models::phi3::Model::new(
1449            &cfg,
1450            vb,
1451            self.is_gptx(config)?,
1452            normal_loading_metadata,
1453            attention_mechanism,
1454        )?))
1455    }
1456    fn load_xlora(
1457        &self,
1458        config: &str,
1459        vb: ShardedVarBuilder,
1460        lora_config: &[((String, String), LoraConfig)],
1461        xlora_config: Option<XLoraConfig>,
1462        xlora_ordering: Ordering,
1463        normal_loading_metadata: NormalLoadingMetadata,
1464        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1465    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1466        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1467
1468        Ok(Box::new(xlora_models::XLoraPhi3::new(
1469            &cfg,
1470            vb,
1471            lora_config,
1472            xlora_config,
1473            xlora_ordering,
1474            self.is_gptx(config)?,
1475            normal_loading_metadata,
1476            preload_adapters,
1477        )?))
1478    }
1479    fn is_gptx(&self, _: &str) -> Result<bool> {
1480        Ok(true)
1481    }
1482    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1483        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1484
1485        Ok(Box::new(cfg))
1486    }
1487}
1488
1489impl IsqModelLoader for Phi3Loader {
1490    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1491        Ok(vec![
1492            Regex::new(r"lm_head\.(weight|bias)$")?,
1493            // Attention
1494            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1495            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1496            // MLP
1497            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1498            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1499            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1500        ])
1501    }
1502    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1503        self.isq_layer_regexes(config)
1504    }
1505}
1506
1507impl DeviceMappedModelLoader for Phi3Loader {
1508    fn mapped_max_act_size_elems(
1509        &self,
1510        config: &str,
1511        params: &AutoDeviceMapParams,
1512    ) -> Result<usize> {
1513        let AutoDeviceMapParams::Text {
1514            max_seq_len,
1515            max_batch_size,
1516        } = params
1517        else {
1518            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1519        };
1520
1521        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1522
1523        Ok(
1524            max_batch_size
1525                * cfg.num_attention_heads
1526                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1527        )
1528    }
1529    fn non_mapped_max_act_size_elems(
1530        &self,
1531        _config: &str,
1532        _params: &AutoDeviceMapParams,
1533    ) -> Result<usize> {
1534        Ok(0)
1535    }
1536
1537    fn non_mapped_size_in_bytes(
1538        &self,
1539        config: &str,
1540        dtype: DType,
1541        weight_pack_factor: usize,
1542        _matformer_config: Option<&MatformerSliceConfig>,
1543    ) -> Result<usize> {
1544        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1545
1546        let elems = {
1547            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1548            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1549            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1550                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1551            } else {
1552                0
1553            };
1554            let norm = cfg.hidden_size;
1555            embed_tokens + lm_head + norm
1556        };
1557        Ok(elems * dtype.size_in_bytes())
1558    }
1559
1560    fn layer_sizes_in_bytes(
1561        &self,
1562        config: &str,
1563        dtype: DType,
1564        weight_pack_factor: usize,
1565        _matformer_config: Option<&MatformerSliceConfig>,
1566    ) -> Result<Vec<usize>> {
1567        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1568
1569        let per_layer_elems = {
1570            let input_layernorm = cfg.hidden_size;
1571            let post_attention_layernorm = cfg.hidden_size;
1572
1573            let size_in = cfg.hidden_size;
1574            let head_dim = cfg.head_dim();
1575            let op_size =
1576                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1577            let qkv_proj = size_in * op_size / weight_pack_factor;
1578            let o_proj =
1579                (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1580
1581            let h_size = cfg.hidden_size;
1582            let i_size = cfg.intermediate_size;
1583            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1584            let down_proj = h_size * i_size / weight_pack_factor;
1585
1586            input_layernorm
1587                + post_attention_layernorm
1588                + qkv_proj
1589                + o_proj
1590                + gate_up_proj
1591                + down_proj
1592        };
1593        Ok(vec![
1594            per_layer_elems * dtype.size_in_bytes();
1595            cfg.num_hidden_layers
1596        ])
1597    }
1598
1599    fn num_layers(&self, config: &str) -> Result<usize> {
1600        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1601
1602        Ok(cfg.num_hidden_layers)
1603    }
1604
1605    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1606        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1607
1608        let cfg = ModelConfigMetadata {
1609            max_seq_len: cfg.max_position_embeddings,
1610            num_layers: cfg.num_hidden_layers,
1611            hidden_size: cfg.hidden_size,
1612            num_kv_heads: cfg.num_key_value_heads,
1613            num_attn_heads: cfg.num_attention_heads,
1614            sliding_window: cfg.sliding_window,
1615            k_head_dim: cfg.head_dim(),
1616            v_head_dim: cfg.head_dim(),
1617            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1618        };
1619
1620        Ok(Box::new(cfg))
1621    }
1622}
1623
1624// ======================== Qwen2 loader
1625
1626/// [`NormalLoader`] for a Qwen 2 model.
1627///
1628/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
1629pub struct Qwen2Loader;
1630
1631impl NormalModelLoader for Qwen2Loader {
1632    fn load(
1633        &self,
1634        config: &str,
1635        vb: ShardedVarBuilder,
1636        normal_loading_metadata: NormalLoadingMetadata,
1637        attention_mechanism: AttentionImplementation,
1638    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1639        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1640
1641        Ok(Box::new(models::qwen2::Model::new(
1642            &cfg,
1643            vb,
1644            self.is_gptx(config)?,
1645            normal_loading_metadata,
1646            attention_mechanism,
1647        )?))
1648    }
1649    fn load_xlora(
1650        &self,
1651        _config: &str,
1652        _vb: ShardedVarBuilder,
1653        _lora_config: &[((String, String), LoraConfig)],
1654        _xlora_config: Option<XLoraConfig>,
1655        _xlora_ordering: Ordering,
1656        _normal_loading_metadata: NormalLoadingMetadata,
1657        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1658    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1659        todo!()
1660    }
1661    fn is_gptx(&self, _: &str) -> Result<bool> {
1662        Ok(true)
1663    }
1664    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1665        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1666
1667        Ok(Box::new(cfg))
1668    }
1669}
1670
1671impl IsqModelLoader for Qwen2Loader {
1672    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1673        Ok(vec![
1674            Regex::new(r"lm_head\.(weight|bias)$")?,
1675            // Attention
1676            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1677            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1678            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1679            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1680            // MLP
1681            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1682            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1683            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1684            // MLP MoE
1685            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
1686            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
1687            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
1688        ])
1689    }
1690    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1691        self.isq_layer_regexes(config)
1692    }
1693}
1694
1695impl DeviceMappedModelLoader for Qwen2Loader {
1696    fn mapped_max_act_size_elems(
1697        &self,
1698        config: &str,
1699        params: &AutoDeviceMapParams,
1700    ) -> Result<usize> {
1701        let AutoDeviceMapParams::Text {
1702            max_seq_len,
1703            max_batch_size,
1704        } = params
1705        else {
1706            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1707        };
1708
1709        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1710
1711        Ok(
1712            max_batch_size
1713                * cfg.num_attention_heads
1714                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1715        )
1716    }
1717    fn non_mapped_max_act_size_elems(
1718        &self,
1719        _config: &str,
1720        _params: &AutoDeviceMapParams,
1721    ) -> Result<usize> {
1722        Ok(0)
1723    }
1724
1725    fn non_mapped_size_in_bytes(
1726        &self,
1727        config: &str,
1728        dtype: DType,
1729        weight_pack_factor: usize,
1730        _matformer_config: Option<&MatformerSliceConfig>,
1731    ) -> Result<usize> {
1732        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1733
1734        let elems = {
1735            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1736            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1737            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1738                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1739            } else {
1740                0
1741            };
1742            let norm = cfg.hidden_size;
1743            embed_tokens + lm_head + norm
1744        };
1745        Ok(elems * dtype.size_in_bytes())
1746    }
1747
1748    fn layer_sizes_in_bytes(
1749        &self,
1750        config: &str,
1751        dtype: DType,
1752        weight_pack_factor: usize,
1753        _matformer_config: Option<&MatformerSliceConfig>,
1754    ) -> Result<Vec<usize>> {
1755        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1756
1757        let per_layer_elems = {
1758            let input_layernorm = cfg.hidden_size;
1759            let post_attention_layernorm = cfg.hidden_size;
1760
1761            let size_in = cfg.hidden_size;
1762            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1763            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1764            let q_proj = size_in * size_q / weight_pack_factor + size_q;
1765            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1766            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1767            let o_proj = size_q * size_in / weight_pack_factor;
1768
1769            let h_size = cfg.hidden_size;
1770            let i_size = cfg.intermediate_size;
1771            let gate_proj = h_size * i_size / weight_pack_factor;
1772            let up_proj = h_size * i_size / weight_pack_factor;
1773            let down_proj = i_size * h_size / weight_pack_factor;
1774
1775            input_layernorm
1776                + post_attention_layernorm
1777                + q_proj
1778                + k_proj
1779                + v_proj
1780                + o_proj
1781                + gate_proj
1782                + up_proj
1783                + down_proj
1784        };
1785        Ok(vec![
1786            per_layer_elems * dtype.size_in_bytes();
1787            cfg.num_hidden_layers
1788        ])
1789    }
1790
1791    fn num_layers(&self, config: &str) -> Result<usize> {
1792        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1793
1794        Ok(cfg.num_hidden_layers)
1795    }
1796
1797    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1798        let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1799
1800        let cfg = ModelConfigMetadata {
1801            max_seq_len: cfg.max_position_embeddings,
1802            num_layers: cfg.num_hidden_layers,
1803            hidden_size: cfg.hidden_size,
1804            num_kv_heads: cfg.num_key_value_heads,
1805            num_attn_heads: cfg.num_attention_heads,
1806            sliding_window: cfg.sliding_window,
1807            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1808            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1809            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1810        };
1811
1812        Ok(Box::new(cfg))
1813    }
1814}
1815
1816// ======================== Gemma2 loader
1817
/// [`NormalLoader`] for a Gemma2 model.
///
/// This is a stateless unit struct: all Gemma2-specific behaviour lives in the
/// `NormalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader`
/// implementations on this type, each of which parses the JSON config it is handed
/// as a `crate::models::gemma2::Config`.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct Gemma2Loader;
1822
1823impl NormalModelLoader for Gemma2Loader {
1824    fn load(
1825        &self,
1826        config: &str,
1827        vb: ShardedVarBuilder,
1828        normal_loading_metadata: NormalLoadingMetadata,
1829        attention_mechanism: AttentionImplementation,
1830    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1831        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1832
1833        Ok(Box::new(models::gemma2::Model::new(
1834            &cfg,
1835            vb,
1836            self.is_gptx(config)?,
1837            normal_loading_metadata,
1838            attention_mechanism,
1839        )?))
1840    }
1841    fn load_xlora(
1842        &self,
1843        config: &str,
1844        vb: ShardedVarBuilder,
1845        lora_config: &[((String, String), LoraConfig)],
1846        xlora_config: Option<XLoraConfig>,
1847        xlora_ordering: Ordering,
1848        normal_loading_metadata: NormalLoadingMetadata,
1849        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1850    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1851        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1852
1853        Ok(Box::new(xlora_models::XLoraGemma2::new(
1854            &cfg,
1855            vb,
1856            lora_config,
1857            xlora_config,
1858            xlora_ordering,
1859            self.is_gptx(config)?,
1860            normal_loading_metadata,
1861            preload_adapters,
1862        )?))
1863    }
1864    fn is_gptx(&self, _: &str) -> Result<bool> {
1865        Ok(true)
1866    }
1867    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1868        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1869
1870        Ok(Box::new(cfg))
1871    }
1872}
1873
1874impl IsqModelLoader for Gemma2Loader {
1875    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1876        Ok(vec![
1877            Regex::new(r"lm_head\.(weight|bias)$")?,
1878            // Attention
1879            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1880            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1881            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1882            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1883            // MLP
1884            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1885            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1886            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1887        ])
1888    }
1889    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1890        self.isq_layer_regexes(config)
1891    }
1892}
1893
1894impl DeviceMappedModelLoader for Gemma2Loader {
1895    fn mapped_max_act_size_elems(
1896        &self,
1897        config: &str,
1898        params: &AutoDeviceMapParams,
1899    ) -> Result<usize> {
1900        let AutoDeviceMapParams::Text {
1901            max_seq_len,
1902            max_batch_size,
1903        } = params
1904        else {
1905            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1906        };
1907
1908        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1909
1910        Ok(
1911            max_batch_size
1912                * cfg.num_attention_heads
1913                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1914        )
1915    }
1916    fn non_mapped_max_act_size_elems(
1917        &self,
1918        _config: &str,
1919        _params: &AutoDeviceMapParams,
1920    ) -> Result<usize> {
1921        Ok(0)
1922    }
1923
1924    fn non_mapped_size_in_bytes(
1925        &self,
1926        config: &str,
1927        dtype: DType,
1928        weight_pack_factor: usize,
1929        _matformer_config: Option<&MatformerSliceConfig>,
1930    ) -> Result<usize> {
1931        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1932
1933        let elems = {
1934            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1935            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1936            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1937                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1938            } else {
1939                0
1940            };
1941            let norm = cfg.hidden_size;
1942            embed_tokens + lm_head + norm
1943        };
1944        Ok(elems * dtype.size_in_bytes())
1945    }
1946
1947    fn layer_sizes_in_bytes(
1948        &self,
1949        config: &str,
1950        dtype: DType,
1951        weight_pack_factor: usize,
1952        _matformer_config: Option<&MatformerSliceConfig>,
1953    ) -> Result<Vec<usize>> {
1954        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1955
1956        let per_layer_elems = {
1957            let input_layernorm = cfg.hidden_size;
1958            let post_attention_layernorm = cfg.hidden_size;
1959
1960            let size_in = cfg.hidden_size;
1961            let size_q = cfg.head_dim * cfg.num_attention_heads;
1962            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1963            let q_proj =
1964                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1965            let k_proj =
1966                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1967            let v_proj =
1968                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1969            let o_proj =
1970                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1971
1972            let h_size = cfg.hidden_size;
1973            let i_size = cfg.intermediate_size;
1974            let gate_proj = h_size * i_size / weight_pack_factor;
1975            let up_proj = h_size * i_size / weight_pack_factor;
1976            let down_proj = i_size * h_size / weight_pack_factor;
1977
1978            input_layernorm
1979                + post_attention_layernorm
1980                + q_proj
1981                + k_proj
1982                + v_proj
1983                + o_proj
1984                + gate_proj
1985                + up_proj
1986                + down_proj
1987        };
1988        Ok(vec![
1989            per_layer_elems * dtype.size_in_bytes();
1990            cfg.num_hidden_layers
1991        ])
1992    }
1993
1994    fn num_layers(&self, config: &str) -> Result<usize> {
1995        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1996
1997        Ok(cfg.num_hidden_layers)
1998    }
1999    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2000        let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
2001
2002        let cfg = ModelConfigMetadata {
2003            max_seq_len: cfg.max_position_embeddings,
2004            num_layers: cfg.num_hidden_layers,
2005            hidden_size: cfg.hidden_size,
2006            num_kv_heads: cfg.num_key_value_heads,
2007            num_attn_heads: cfg.num_attention_heads,
2008            sliding_window: None, // None to be more forgiving, some do not
2009            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2010            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2011            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2012        };
2013
2014        Ok(Box::new(cfg))
2015    }
2016}
2017
2018// ======================== Starcoder2 loader
2019
/// [`NormalLoader`] for a Starcoder2 model.
///
/// This is a stateless unit struct: all Starcoder2-specific behaviour lives in the
/// `NormalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader`
/// implementations on this type, each of which parses the JSON config it is handed
/// as a `crate::models::starcoder2::Config`.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct Starcoder2Loader;
2024
2025impl NormalModelLoader for Starcoder2Loader {
2026    fn load(
2027        &self,
2028        config: &str,
2029        vb: ShardedVarBuilder,
2030        normal_loading_metadata: NormalLoadingMetadata,
2031        attention_mechanism: AttentionImplementation,
2032    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2033        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2034
2035        Ok(Box::new(models::starcoder2::Model::new(
2036            &cfg,
2037            vb,
2038            self.is_gptx(config)?,
2039            normal_loading_metadata,
2040            attention_mechanism,
2041        )?))
2042    }
2043    fn load_xlora(
2044        &self,
2045        config: &str,
2046        vb: ShardedVarBuilder,
2047        lora_config: &[((String, String), LoraConfig)],
2048        xlora_config: Option<XLoraConfig>,
2049        xlora_ordering: Ordering,
2050        normal_loading_metadata: NormalLoadingMetadata,
2051        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2052    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2053        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2054
2055        Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2056            &cfg,
2057            vb,
2058            lora_config,
2059            xlora_config,
2060            xlora_ordering,
2061            self.is_gptx(config)?,
2062            normal_loading_metadata,
2063            preload_adapters,
2064        )?))
2065    }
2066    fn is_gptx(&self, _: &str) -> Result<bool> {
2067        Ok(true)
2068    }
2069    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2070        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2071
2072        Ok(Box::new(cfg))
2073    }
2074}
2075
2076impl IsqModelLoader for Starcoder2Loader {
2077    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2078        Ok(vec![
2079            Regex::new(r"lm_head\.(weight|bias)$")?,
2080            // Attention
2081            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2082            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2083            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2084            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2085            // MLP
2086            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2087            Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2088        ])
2089    }
2090    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2091        self.isq_layer_regexes(config)
2092    }
2093}
2094
2095impl DeviceMappedModelLoader for Starcoder2Loader {
2096    fn mapped_max_act_size_elems(
2097        &self,
2098        config: &str,
2099        params: &AutoDeviceMapParams,
2100    ) -> Result<usize> {
2101        let AutoDeviceMapParams::Text {
2102            max_seq_len,
2103            max_batch_size,
2104        } = params
2105        else {
2106            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2107        };
2108
2109        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2110
2111        Ok(
2112            max_batch_size
2113                * cfg.num_attention_heads
2114                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2115        )
2116    }
2117    fn non_mapped_max_act_size_elems(
2118        &self,
2119        _config: &str,
2120        _params: &AutoDeviceMapParams,
2121    ) -> Result<usize> {
2122        Ok(0)
2123    }
2124
2125    fn non_mapped_size_in_bytes(
2126        &self,
2127        config: &str,
2128        dtype: DType,
2129        weight_pack_factor: usize,
2130        _matformer_config: Option<&MatformerSliceConfig>,
2131    ) -> Result<usize> {
2132        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2133
2134        let elems = {
2135            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2136            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2137            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2138                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2139            } else {
2140                0
2141            };
2142            let norm = cfg.hidden_size + cfg.hidden_size;
2143            embed_tokens + lm_head + norm
2144        };
2145        Ok(elems * dtype.size_in_bytes())
2146    }
2147
2148    fn layer_sizes_in_bytes(
2149        &self,
2150        config: &str,
2151        dtype: DType,
2152        weight_pack_factor: usize,
2153        _matformer_config: Option<&MatformerSliceConfig>,
2154    ) -> Result<Vec<usize>> {
2155        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2156
2157        let per_layer_elems = {
2158            let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2159            let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2160
2161            let size_in = cfg.hidden_size;
2162            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2163            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2164            let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2165            let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2166            let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2167            let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2168
2169            let h_size = cfg.hidden_size;
2170            let i_size = cfg.intermediate_size;
2171            let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2172            let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2173
2174            input_layernorm
2175                + post_attention_layernorm
2176                + q_proj
2177                + k_proj
2178                + v_proj
2179                + o_proj
2180                + fc1
2181                + fc2
2182        };
2183        Ok(vec![
2184            per_layer_elems * dtype.size_in_bytes();
2185            cfg.num_hidden_layers
2186        ])
2187    }
2188
2189    fn num_layers(&self, config: &str) -> Result<usize> {
2190        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2191
2192        Ok(cfg.num_hidden_layers)
2193    }
2194
2195    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2196        let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2197
2198        let cfg = ModelConfigMetadata {
2199            max_seq_len: cfg.max_position_embeddings,
2200            num_layers: cfg.num_hidden_layers,
2201            hidden_size: cfg.hidden_size,
2202            num_kv_heads: cfg.num_key_value_heads,
2203            num_attn_heads: cfg.num_attention_heads,
2204            sliding_window: cfg.sliding_window,
2205            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2206            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2207            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2208        };
2209
2210        Ok(Box::new(cfg))
2211    }
2212}
2213
// ======================== Phi3.5 MoE loader
2215
/// [`NormalLoader`] for a Phi 3.5 MoE model.
///
/// This is a stateless unit struct: all model-specific behaviour lives in the
/// `NormalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader`
/// implementations on this type.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct Phi3_5MoELoader;
2220
2221impl NormalModelLoader for Phi3_5MoELoader {
2222    fn load(
2223        &self,
2224        config: &str,
2225        vb: ShardedVarBuilder,
2226        normal_loading_metadata: NormalLoadingMetadata,
2227        attention_mechanism: AttentionImplementation,
2228    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2229        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2230
2231        Ok(Box::new(models::phi3_5_moe::Model::new(
2232            &cfg,
2233            vb,
2234            self.is_gptx(config)?,
2235            normal_loading_metadata,
2236            attention_mechanism,
2237        )?))
2238    }
2239    fn load_xlora(
2240        &self,
2241        config: &str,
2242        vb: ShardedVarBuilder,
2243        lora_config: &[((String, String), LoraConfig)],
2244        xlora_config: Option<XLoraConfig>,
2245        xlora_ordering: Ordering,
2246        normal_loading_metadata: NormalLoadingMetadata,
2247        preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2248    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2249        let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2250
2251        Ok(Box::new(xlora_models::XLoraPhi3::new(
2252            &cfg,
2253            vb,
2254            lora_config,
2255            xlora_config,
2256            xlora_ordering,
2257            self.is_gptx(config)?,
2258            normal_loading_metadata,
2259            preload_adapters,
2260        )?))
2261    }
2262    fn is_gptx(&self, _: &str) -> Result<bool> {
2263        Ok(true)
2264    }
2265    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2266        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2267
2268        Ok(Box::new(cfg))
2269    }
2270}
2271
2272impl IsqModelLoader for Phi3_5MoELoader {
2273    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2274        Ok(vec![
2275            Regex::new(r"lm_head\.(weight|bias)$")?,
2276            // Attention
2277            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2278            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2279            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2280            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2281            // MLP
2282            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2283            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2284            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2285        ])
2286    }
2287    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2288        self.isq_layer_regexes(config)
2289    }
2290
2291    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2292        Ok(vec![
2293            Regex::new(r"lm_head\.(weight|bias)$")?,
2294            // MLP
2295            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2296            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2297            Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2298        ])
2299    }
2300    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2301        self.isq_layer_regexes_moqe(config)
2302    }
2303}
2304
2305impl DeviceMappedModelLoader for Phi3_5MoELoader {
2306    fn mapped_max_act_size_elems(
2307        &self,
2308        config: &str,
2309        params: &AutoDeviceMapParams,
2310    ) -> Result<usize> {
2311        let AutoDeviceMapParams::Text {
2312            max_seq_len,
2313            max_batch_size,
2314        } = params
2315        else {
2316            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2317        };
2318
2319        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2320
2321        Ok(
2322            max_batch_size
2323                * cfg.num_attention_heads
2324                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2325        )
2326    }
2327    fn non_mapped_max_act_size_elems(
2328        &self,
2329        _config: &str,
2330        _params: &AutoDeviceMapParams,
2331    ) -> Result<usize> {
2332        Ok(0)
2333    }
2334
2335    fn non_mapped_size_in_bytes(
2336        &self,
2337        config: &str,
2338        dtype: DType,
2339        weight_pack_factor: usize,
2340        _matformer_config: Option<&MatformerSliceConfig>,
2341    ) -> Result<usize> {
2342        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2343
2344        let elems = {
2345            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2346            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2347            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2348                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2349            } else {
2350                0
2351            };
2352            let norm = cfg.hidden_size;
2353            embed_tokens + lm_head + norm
2354        };
2355        Ok(elems * dtype.size_in_bytes())
2356    }
2357
2358    fn layer_sizes_in_bytes(
2359        &self,
2360        config: &str,
2361        dtype: DType,
2362        weight_pack_factor: usize,
2363        _matformer_config: Option<&MatformerSliceConfig>,
2364    ) -> Result<Vec<usize>> {
2365        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2366
2367        let per_layer_elems = {
2368            let input_layernorm = cfg.hidden_size;
2369            let post_attention_layernorm = cfg.hidden_size;
2370
2371            let size_in = cfg.hidden_size;
2372            let size_q = cfg.head_dim() * cfg.num_attention_heads;
2373            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2374            let q_proj =
2375                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2376            let k_proj =
2377                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2378            let v_proj =
2379                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2380            let o_proj =
2381                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2382
2383            let moe_block = {
2384                let gate = cfg.hidden_size * cfg.num_local_experts;
2385                // Assume quantizing weight pack factor
2386                let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2387                let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2388                let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2389                gate + cfg.num_local_experts * w1
2390                    + cfg.num_local_experts * w2
2391                    + cfg.num_local_experts * w3
2392            };
2393
2394            input_layernorm
2395                + post_attention_layernorm
2396                + q_proj
2397                + k_proj
2398                + v_proj
2399                + o_proj
2400                + moe_block
2401        };
2402        Ok(vec![
2403            per_layer_elems * dtype.size_in_bytes();
2404            cfg.num_hidden_layers
2405        ])
2406    }
2407
2408    fn num_layers(&self, config: &str) -> Result<usize> {
2409        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2410
2411        Ok(cfg.num_hidden_layers)
2412    }
2413
2414    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2415        let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2416
2417        let cfg = ModelConfigMetadata {
2418            max_seq_len: cfg.max_position_embeddings,
2419            num_layers: cfg.num_hidden_layers,
2420            hidden_size: cfg.hidden_size,
2421            num_kv_heads: cfg.num_key_value_heads,
2422            num_attn_heads: cfg.num_attention_heads,
2423            sliding_window: cfg.sliding_window,
2424            k_head_dim: cfg.head_dim(),
2425            v_head_dim: cfg.head_dim(),
2426            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2427        };
2428
2429        Ok(Box::new(cfg))
2430    }
2431}
2432
/// [`NormalLoader`] for a DeepSeekV2 model.
///
/// This is a stateless unit struct: the model-specific behaviour lives in the
/// loader trait implementations on this type, which parse the JSON config they
/// are handed as a `crate::models::deepseek2::DeepSeekV2Config`.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct DeepSeekV2Loader;
2437
2438impl NormalModelLoader for DeepSeekV2Loader {
2439    fn load(
2440        &self,
2441        config: &str,
2442        vb: ShardedVarBuilder,
2443        normal_loading_metadata: NormalLoadingMetadata,
2444        attention_mechanism: AttentionImplementation,
2445    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2446        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2447
2448        Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2449            &cfg,
2450            vb,
2451            self.is_gptx(config)?,
2452            normal_loading_metadata,
2453            attention_mechanism,
2454        )?))
2455    }
2456    fn load_xlora(
2457        &self,
2458        _config: &str,
2459        _vb: ShardedVarBuilder,
2460        _lora_config: &[((String, String), LoraConfig)],
2461        _xlora_config: Option<XLoraConfig>,
2462        _xlora_ordering: Ordering,
2463        _normal_loading_metadata: NormalLoadingMetadata,
2464        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2465    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2466        todo!()
2467    }
2468    fn is_gptx(&self, _: &str) -> Result<bool> {
2469        Ok(true)
2470    }
2471    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2472        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2473        Ok(Box::new(cfg))
2474    }
2475}
2476
2477impl IsqModelLoader for DeepSeekV2Loader {
2478    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2479        let mut data = vec![
2480            Regex::new(r"lm_head\.(weight|bias)$")?,
2481            // Attention
2482            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2483            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2484            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2485        ];
2486        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2487        if cfg.q_lora_rank.is_some() {
2488            data.extend(vec![
2489                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2490                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2491            ]);
2492        } else {
2493            data.push(Regex::new(
2494                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2495            )?);
2496        }
2497        for layer_idx in 0..cfg.num_hidden_layers {
2498            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2499                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2500            }) {
2501                for i in 0..n_routed_experts {
2502                    data.extend(vec![
2503                        Regex::new(&format!(
2504                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2505                        ))?,
2506                        Regex::new(&format!(
2507                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2508                        ))?,
2509                        Regex::new(&format!(
2510                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2511                        ))?,
2512                    ]);
2513                }
2514                if cfg.n_shared_experts.is_some() {
2515                    data.extend(vec![
2516                        Regex::new(&format!(
2517                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2518                        ))?,
2519                        Regex::new(&format!(
2520                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2521                        ))?,
2522                        Regex::new(&format!(
2523                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2524                        ))?,
2525                    ]);
2526                }
2527            } else {
2528                data.extend(vec![
2529                    Regex::new(&format!(
2530                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2531                    ))?,
2532                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2533                    Regex::new(&format!(
2534                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2535                    ))?,
2536                ]);
2537            };
2538        }
2539        Ok(data)
2540    }
2541    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2542        self.isq_layer_regexes(config)
2543    }
2544
2545    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2546        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2547        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2548        for layer_idx in 0..cfg.num_hidden_layers {
2549            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2550                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2551            }) {
2552                for i in 0..n_routed_experts {
2553                    data.extend(vec![
2554                        Regex::new(&format!(
2555                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2556                        ))?,
2557                        Regex::new(&format!(
2558                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2559                        ))?,
2560                        Regex::new(&format!(
2561                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2562                        ))?,
2563                    ]);
2564                }
2565                if cfg.n_shared_experts.is_some() {
2566                    data.extend(vec![
2567                        Regex::new(&format!(
2568                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2569                        ))?,
2570                        Regex::new(&format!(
2571                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2572                        ))?,
2573                        Regex::new(&format!(
2574                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2575                        ))?,
2576                    ]);
2577                }
2578            } else {
2579                data.extend(vec![
2580                    Regex::new(&format!(
2581                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2582                    ))?,
2583                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2584                    Regex::new(&format!(
2585                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2586                    ))?,
2587                ]);
2588            };
2589        }
2590        Ok(data)
2591    }
2592    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2593        self.isq_layer_regexes_moqe(config)
2594    }
2595}
2596
2597impl DeviceMappedModelLoader for DeepSeekV2Loader {
2598    fn mapped_max_act_size_elems(
2599        &self,
2600        config: &str,
2601        params: &AutoDeviceMapParams,
2602    ) -> Result<usize> {
2603        let AutoDeviceMapParams::Text {
2604            max_seq_len,
2605            max_batch_size,
2606        } = params
2607        else {
2608            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2609        };
2610
2611        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2612
2613        Ok(
2614            max_batch_size
2615                * cfg.num_attention_heads
2616                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2617        )
2618    }
2619    fn non_mapped_max_act_size_elems(
2620        &self,
2621        _config: &str,
2622        _params: &AutoDeviceMapParams,
2623    ) -> Result<usize> {
2624        Ok(0)
2625    }
2626
2627    fn non_mapped_size_in_bytes(
2628        &self,
2629        config: &str,
2630        dtype: DType,
2631        weight_pack_factor: usize,
2632        _matformer_config: Option<&MatformerSliceConfig>,
2633    ) -> Result<usize> {
2634        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2635        let elems = {
2636            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2637            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2638            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2639                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2640            } else {
2641                0
2642            };
2643            let norm = cfg.hidden_size;
2644            embed_tokens + lm_head + norm
2645        };
2646        Ok(elems * dtype.size_in_bytes())
2647    }
2648
2649    fn layer_sizes_in_bytes(
2650        &self,
2651        config: &str,
2652        dtype: DType,
2653        weight_pack_factor: usize,
2654        _matformer_config: Option<&MatformerSliceConfig>,
2655    ) -> Result<Vec<usize>> {
2656        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2657        let mut per_layer_elems = Vec::new();
2658
2659        for layer_idx in 0..cfg.num_hidden_layers {
2660            let input_layernorm = cfg.hidden_size;
2661            let post_attention_layernorm = cfg.hidden_size;
2662
2663            let q_proj = match cfg.q_lora_rank {
2664                Some(lora_rank) => {
2665                    let a = cfg.hidden_size * lora_rank;
2666                    let norm = lora_rank;
2667                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2668                    a + norm + b
2669                }
2670                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2671            };
2672            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2673                / weight_pack_factor
2674                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2675            let kv_a_layernorm = cfg.kv_lora_rank;
2676            let kv_b_proj = cfg.kv_lora_rank
2677                * cfg.num_attention_heads
2678                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2679                / weight_pack_factor;
2680            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2681                / weight_pack_factor
2682                + bias_if!(cfg.attention_bias, cfg.hidden_size);
2683
2684            let moe_block = {
2685                let mut sum = 0;
2686                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2687                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2688                }) {
2689                    let h_size = cfg.hidden_size;
2690                    let gate_proj =
2691                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2692                    let up_proj =
2693                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2694                    let down_proj =
2695                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2696                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2697                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2698                            / weight_pack_factor;
2699                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2700                            / weight_pack_factor;
2701                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2702                            / weight_pack_factor;
2703                        gate_proj + up_proj + down_proj
2704                    } else {
2705                        0
2706                    };
2707                    let gate_weight = n_routed_experts * cfg.hidden_size;
2708                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2709                } else {
2710                    let h_size = cfg.hidden_size;
2711                    let i_size = cfg.intermediate_size;
2712                    let gate_proj = h_size * i_size / weight_pack_factor;
2713                    let up_proj = h_size * i_size / weight_pack_factor;
2714                    let down_proj = i_size * h_size / weight_pack_factor;
2715                    sum += gate_proj + up_proj + down_proj;
2716                }
2717                sum
2718            };
2719
2720            per_layer_elems.push(
2721                input_layernorm
2722                    + post_attention_layernorm
2723                    + q_proj
2724                    + kv_a_layernorm
2725                    + kv_a_proj_with_mqa
2726                    + kv_b_proj
2727                    + o_proj
2728                    + moe_block,
2729            );
2730        }
2731
2732        Ok(per_layer_elems
2733            .into_iter()
2734            .map(|x| x * dtype.size_in_bytes())
2735            .collect())
2736    }
2737
2738    fn num_layers(&self, config: &str) -> Result<usize> {
2739        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2740        Ok(cfg.num_hidden_layers)
2741    }
2742
2743    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2744        let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2745
2746        let cfg = ModelConfigMetadata {
2747            max_seq_len: cfg.max_position_embeddings,
2748            num_layers: cfg.num_hidden_layers,
2749            hidden_size: cfg.hidden_size,
2750            num_kv_heads: cfg.num_attention_heads,
2751            num_attn_heads: cfg.num_attention_heads,
2752            sliding_window: None,
2753            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2754            v_head_dim: cfg.v_head_dim,
2755            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2756        };
2757
2758        Ok(Box::new(cfg))
2759    }
2760}
2761
/// [`NormalLoader`] for a DeepSeekV3 model.
///
/// Parses the model's JSON config as `DeepSeekV3Config` and supplies model
/// construction, the ISQ layer regexes, and the device-mapping size estimates
/// for that architecture. X-LoRA loading is not implemented for this loader.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct DeepSeekV3Loader;
2766
2767impl NormalModelLoader for DeepSeekV3Loader {
2768    fn load(
2769        &self,
2770        config: &str,
2771        vb: ShardedVarBuilder,
2772        normal_loading_metadata: NormalLoadingMetadata,
2773        attention_mechanism: AttentionImplementation,
2774    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2775        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2776        Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2777            &cfg,
2778            vb,
2779            self.is_gptx(config)?,
2780            normal_loading_metadata,
2781            attention_mechanism,
2782        )?))
2783    }
2784    fn load_xlora(
2785        &self,
2786        _config: &str,
2787        _vb: ShardedVarBuilder,
2788        _lora_config: &[((String, String), LoraConfig)],
2789        _xlora_config: Option<XLoraConfig>,
2790        _xlora_ordering: Ordering,
2791        _normal_loading_metadata: NormalLoadingMetadata,
2792        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2793    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2794        todo!()
2795    }
2796    fn is_gptx(&self, _: &str) -> Result<bool> {
2797        Ok(true)
2798    }
2799    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2800        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2801        Ok(Box::new(cfg))
2802    }
2803}
2804
2805impl IsqModelLoader for DeepSeekV3Loader {
2806    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2807        let mut data = vec![
2808            Regex::new(r"lm_head\.(weight|bias)$")?,
2809            // Attention
2810            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2811            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2812            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2813        ];
2814        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2815        if cfg.q_lora_rank.is_some() {
2816            data.extend(vec![
2817                Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2818                Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2819            ]);
2820        } else {
2821            data.push(Regex::new(
2822                r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2823            )?);
2824        }
2825        for layer_idx in 0..cfg.num_hidden_layers {
2826            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2827                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2828            }) {
2829                for i in 0..n_routed_experts {
2830                    data.extend(vec![
2831                        Regex::new(&format!(
2832                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2833                        ))?,
2834                        Regex::new(&format!(
2835                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2836                        ))?,
2837                        Regex::new(&format!(
2838                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2839                        ))?,
2840                    ]);
2841                }
2842                if cfg.n_shared_experts.is_some() {
2843                    data.extend(vec![
2844                        Regex::new(&format!(
2845                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2846                        ))?,
2847                        Regex::new(&format!(
2848                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2849                        ))?,
2850                        Regex::new(&format!(
2851                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2852                        ))?,
2853                    ]);
2854                }
2855            } else {
2856                data.extend(vec![
2857                    Regex::new(&format!(
2858                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2859                    ))?,
2860                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2861                    Regex::new(&format!(
2862                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2863                    ))?,
2864                ]);
2865            };
2866        }
2867        Ok(data)
2868    }
2869    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2870        self.isq_layer_regexes(config)
2871    }
2872
2873    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2874        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2875        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2876        for layer_idx in 0..cfg.num_hidden_layers {
2877            if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2878                layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2879            }) {
2880                for i in 0..n_routed_experts {
2881                    data.extend(vec![
2882                        Regex::new(&format!(
2883                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2884                        ))?,
2885                        Regex::new(&format!(
2886                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2887                        ))?,
2888                        Regex::new(&format!(
2889                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2890                        ))?,
2891                    ]);
2892                }
2893                if cfg.n_shared_experts.is_some() {
2894                    data.extend(vec![
2895                        Regex::new(&format!(
2896                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2897                        ))?,
2898                        Regex::new(&format!(
2899                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2900                        ))?,
2901                        Regex::new(&format!(
2902                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2903                        ))?,
2904                    ]);
2905                }
2906            } else {
2907                data.extend(vec![
2908                    Regex::new(&format!(
2909                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2910                    ))?,
2911                    Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2912                    Regex::new(&format!(
2913                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2914                    ))?,
2915                ]);
2916            };
2917        }
2918        Ok(data)
2919    }
2920    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2921        self.isq_layer_regexes_moqe(config)
2922    }
2923}
2924
2925impl DeviceMappedModelLoader for DeepSeekV3Loader {
2926    fn mapped_max_act_size_elems(
2927        &self,
2928        config: &str,
2929        params: &AutoDeviceMapParams,
2930    ) -> Result<usize> {
2931        let AutoDeviceMapParams::Text {
2932            max_seq_len,
2933            max_batch_size,
2934        } = params
2935        else {
2936            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2937        };
2938
2939        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2940
2941        Ok(
2942            max_batch_size
2943                * cfg.num_attention_heads
2944                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2945        )
2946    }
2947    fn non_mapped_max_act_size_elems(
2948        &self,
2949        _config: &str,
2950        _params: &AutoDeviceMapParams,
2951    ) -> Result<usize> {
2952        Ok(0)
2953    }
2954
2955    fn non_mapped_size_in_bytes(
2956        &self,
2957        config: &str,
2958        dtype: DType,
2959        weight_pack_factor: usize,
2960        _matformer_config: Option<&MatformerSliceConfig>,
2961    ) -> Result<usize> {
2962        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2963        let elems = {
2964            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2965            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2966            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2967                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2968            } else {
2969                0
2970            };
2971            let norm = cfg.hidden_size;
2972            embed_tokens + lm_head + norm
2973        };
2974        Ok(elems * dtype.size_in_bytes())
2975    }
2976
2977    fn layer_sizes_in_bytes(
2978        &self,
2979        config: &str,
2980        dtype: DType,
2981        weight_pack_factor: usize,
2982        _matformer_config: Option<&MatformerSliceConfig>,
2983    ) -> Result<Vec<usize>> {
2984        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2985        let mut per_layer_elems = Vec::new();
2986
2987        for layer_idx in 0..cfg.num_hidden_layers {
2988            let input_layernorm = cfg.hidden_size;
2989            let post_attention_layernorm = cfg.hidden_size;
2990
2991            let q_proj = match cfg.q_lora_rank {
2992                Some(lora_rank) => {
2993                    let a = cfg.hidden_size * lora_rank;
2994                    let norm = lora_rank;
2995                    let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2996                    a + norm + b
2997                }
2998                None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2999            };
3000            let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3001                / weight_pack_factor
3002                + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3003            let kv_a_layernorm = cfg.kv_lora_rank;
3004            let kv_b_proj = cfg.kv_lora_rank
3005                * cfg.num_attention_heads
3006                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3007                / weight_pack_factor;
3008            let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3009                / weight_pack_factor
3010                + bias_if!(cfg.attention_bias, cfg.hidden_size);
3011
3012            let moe_block = {
3013                let mut sum = 0;
3014                if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3015                    layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3016                }) {
3017                    let h_size = cfg.hidden_size;
3018                    let gate_proj =
3019                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3020                    let up_proj =
3021                        h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3022                    let down_proj =
3023                        cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3024                    let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3025                        let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3026                            / weight_pack_factor;
3027                        let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3028                            / weight_pack_factor;
3029                        let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3030                            / weight_pack_factor;
3031                        gate_proj + up_proj + down_proj
3032                    } else {
3033                        0
3034                    };
3035                    let gate_weight = n_routed_experts * cfg.hidden_size;
3036                    sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3037                } else {
3038                    let h_size = cfg.hidden_size;
3039                    let i_size = cfg.intermediate_size;
3040                    let gate_proj = h_size * i_size / weight_pack_factor;
3041                    let up_proj = h_size * i_size / weight_pack_factor;
3042                    let down_proj = i_size * h_size / weight_pack_factor;
3043                    sum += gate_proj + up_proj + down_proj;
3044                }
3045                sum
3046            };
3047
3048            per_layer_elems.push(
3049                input_layernorm
3050                    + post_attention_layernorm
3051                    + q_proj
3052                    + kv_a_layernorm
3053                    + kv_a_proj_with_mqa
3054                    + kv_b_proj
3055                    + o_proj
3056                    + moe_block,
3057            );
3058        }
3059
3060        Ok(per_layer_elems
3061            .into_iter()
3062            .map(|x| x * dtype.size_in_bytes())
3063            .collect())
3064    }
3065
3066    fn num_layers(&self, config: &str) -> Result<usize> {
3067        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3068        Ok(cfg.num_hidden_layers)
3069    }
3070
3071    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3072        let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3073
3074        let cfg = ModelConfigMetadata {
3075            max_seq_len: cfg.max_position_embeddings,
3076            num_layers: cfg.num_hidden_layers,
3077            hidden_size: cfg.hidden_size,
3078            num_kv_heads: cfg.num_attention_heads,
3079            num_attn_heads: cfg.num_attention_heads,
3080            sliding_window: None,
3081            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3082            v_head_dim: cfg.v_head_dim,
3083            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3084        };
3085
3086        Ok(Box::new(cfg))
3087    }
3088}
3089
/// [`NormalLoader`] for a Qwen 3 model.
///
/// Parses the model's JSON config as `models::qwen3::Config` and supplies
/// model construction, the ISQ layer regexes, and the device-mapping size
/// estimates for that architecture. X-LoRA loading is not implemented for
/// this loader.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct Qwen3Loader;
3094
3095impl NormalModelLoader for Qwen3Loader {
3096    fn load(
3097        &self,
3098        config: &str,
3099        vb: ShardedVarBuilder,
3100        normal_loading_metadata: NormalLoadingMetadata,
3101        attention_mechanism: AttentionImplementation,
3102    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3103        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3104
3105        Ok(Box::new(models::qwen3::Model::new(
3106            &cfg,
3107            vb,
3108            self.is_gptx(config)?,
3109            normal_loading_metadata,
3110            attention_mechanism,
3111        )?))
3112    }
3113    fn load_xlora(
3114        &self,
3115        _config: &str,
3116        _vb: ShardedVarBuilder,
3117        _lora_config: &[((String, String), LoraConfig)],
3118        _xlora_config: Option<XLoraConfig>,
3119        _xlora_ordering: Ordering,
3120        _normal_loading_metadata: NormalLoadingMetadata,
3121        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3122    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3123        todo!()
3124    }
3125    fn is_gptx(&self, _: &str) -> Result<bool> {
3126        Ok(true)
3127    }
3128    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3129        let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3130
3131        Ok(Box::new(cfg))
3132    }
3133}
3134
3135impl IsqModelLoader for Qwen3Loader {
3136    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3137        Ok(vec![
3138            Regex::new(r"lm_head\.(weight|bias)$")?,
3139            // Attention
3140            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3141            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3142            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3143            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3144            // MLP
3145            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3146            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3147            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3148        ])
3149    }
3150    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3151        self.isq_layer_regexes(config)
3152    }
3153}
3154
3155impl DeviceMappedModelLoader for Qwen3Loader {
3156    fn mapped_max_act_size_elems(
3157        &self,
3158        config: &str,
3159        params: &AutoDeviceMapParams,
3160    ) -> Result<usize> {
3161        let AutoDeviceMapParams::Text {
3162            max_seq_len,
3163            max_batch_size,
3164        } = params
3165        else {
3166            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3167        };
3168
3169        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3170
3171        Ok(
3172            max_batch_size
3173                * cfg.num_attention_heads
3174                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3175        )
3176    }
3177    fn non_mapped_max_act_size_elems(
3178        &self,
3179        _config: &str,
3180        _params: &AutoDeviceMapParams,
3181    ) -> Result<usize> {
3182        Ok(0)
3183    }
3184
3185    fn non_mapped_size_in_bytes(
3186        &self,
3187        config: &str,
3188        dtype: DType,
3189        weight_pack_factor: usize,
3190        _matformer_config: Option<&MatformerSliceConfig>,
3191    ) -> Result<usize> {
3192        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3193        let elems = {
3194            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3195            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3196            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3197                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3198            } else {
3199                0
3200            };
3201            let norm = cfg.hidden_size;
3202            embed_tokens + lm_head + norm
3203        };
3204        Ok(elems * dtype.size_in_bytes())
3205    }
3206
3207    fn layer_sizes_in_bytes(
3208        &self,
3209        config: &str,
3210        dtype: DType,
3211        weight_pack_factor: usize,
3212        _matformer_config: Option<&MatformerSliceConfig>,
3213    ) -> Result<Vec<usize>> {
3214        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3215        let per_layer_elems = {
3216            let input_layernorm = cfg.hidden_size;
3217            let post_attention_layernorm = cfg.hidden_size;
3218
3219            let size_in = cfg.hidden_size;
3220            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3221            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3222            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3223            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3224            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3225            let o_proj = size_q * size_in / weight_pack_factor;
3226
3227            let h_size = cfg.hidden_size;
3228            let i_size = cfg.intermediate_size;
3229            let gate_proj = h_size * i_size / weight_pack_factor;
3230            let up_proj = h_size * i_size / weight_pack_factor;
3231            let down_proj = i_size * h_size / weight_pack_factor;
3232
3233            let q_norm = cfg.head_dim();
3234            let k_norm = cfg.head_dim();
3235
3236            input_layernorm
3237                + post_attention_layernorm
3238                + q_proj
3239                + k_proj
3240                + v_proj
3241                + o_proj
3242                + gate_proj
3243                + up_proj
3244                + down_proj
3245                + q_norm
3246                + k_norm
3247        };
3248        Ok(vec![
3249            per_layer_elems * dtype.size_in_bytes();
3250            cfg.num_hidden_layers
3251        ])
3252    }
3253
3254    fn num_layers(&self, config: &str) -> Result<usize> {
3255        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3256        Ok(cfg.num_hidden_layers)
3257    }
3258
3259    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3260        let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3261
3262        let cfg = ModelConfigMetadata {
3263            max_seq_len: cfg.max_position_embeddings,
3264            num_layers: cfg.num_hidden_layers,
3265            hidden_size: cfg.hidden_size,
3266            num_kv_heads: cfg.num_key_value_heads,
3267            num_attn_heads: cfg.num_attention_heads,
3268            sliding_window: cfg.sliding_window,
3269            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3270            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3271            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3272        };
3273
3274        Ok(Box::new(cfg))
3275    }
3276}
3277
3278/// [`NormalLoader`] for a GLM 4 model.
3279///
3280/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
3281pub struct GLM4Loader;
3282
3283impl NormalModelLoader for GLM4Loader {
3284    fn load(
3285        &self,
3286        config: &str,
3287        vb: ShardedVarBuilder,
3288        normal_loading_metadata: NormalLoadingMetadata,
3289        attention_mechanism: AttentionImplementation,
3290    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3291        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3292
3293        Ok(Box::new(models::glm4::Model::new(
3294            &cfg,
3295            vb,
3296            self.is_gptx(config)?,
3297            normal_loading_metadata,
3298            attention_mechanism,
3299        )?))
3300    }
3301    fn load_xlora(
3302        &self,
3303        _config: &str,
3304        _vb: ShardedVarBuilder,
3305        _lora_config: &[((String, String), LoraConfig)],
3306        _xlora_config: Option<XLoraConfig>,
3307        _xlora_ordering: Ordering,
3308        _normal_loading_metadata: NormalLoadingMetadata,
3309        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3310    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3311        todo!()
3312    }
3313    fn is_gptx(&self, _: &str) -> Result<bool> {
3314        Ok(true)
3315    }
3316    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3317        let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3318
3319        Ok(Box::new(cfg))
3320    }
3321}
3322
3323impl IsqModelLoader for GLM4Loader {
3324    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3325        Ok(vec![
3326            Regex::new(r"lm_head\.(weight|bias)$")?,
3327            // Attention
3328            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3329            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3330            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3331            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3332            // MLP
3333            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3334            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3335            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3336        ])
3337    }
3338    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3339        self.isq_layer_regexes(config)
3340    }
3341}
3342
3343impl DeviceMappedModelLoader for GLM4Loader {
3344    fn mapped_max_act_size_elems(
3345        &self,
3346        config: &str,
3347        params: &AutoDeviceMapParams,
3348    ) -> Result<usize> {
3349        let AutoDeviceMapParams::Text {
3350            max_seq_len,
3351            max_batch_size,
3352        } = params
3353        else {
3354            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3355        };
3356
3357        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358
3359        Ok(
3360            max_batch_size
3361                * cfg.num_attention_heads
3362                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3363        )
3364    }
3365    fn non_mapped_max_act_size_elems(
3366        &self,
3367        _config: &str,
3368        _params: &AutoDeviceMapParams,
3369    ) -> Result<usize> {
3370        Ok(0)
3371    }
3372
3373    fn non_mapped_size_in_bytes(
3374        &self,
3375        config: &str,
3376        dtype: DType,
3377        weight_pack_factor: usize,
3378        _matformer_config: Option<&MatformerSliceConfig>,
3379    ) -> Result<usize> {
3380        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3381        let elems = {
3382            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3383            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3384            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3385                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3386            } else {
3387                0
3388            };
3389            let norm = cfg.hidden_size;
3390            embed_tokens + lm_head + norm
3391        };
3392        Ok(elems * dtype.size_in_bytes())
3393    }
3394
3395    fn layer_sizes_in_bytes(
3396        &self,
3397        config: &str,
3398        dtype: DType,
3399        weight_pack_factor: usize,
3400        _matformer_config: Option<&MatformerSliceConfig>,
3401    ) -> Result<Vec<usize>> {
3402        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3403        let per_layer_elems = {
3404            let input_layernorm = cfg.hidden_size;
3405            let post_attention_layernorm = cfg.hidden_size * 3; //+post_self_attn_layernorm and post_mlp_layernorm
3406
3407            let size_in = cfg.hidden_size;
3408            let size_q = cfg.head_dim() * cfg.num_attention_heads;
3409            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3410            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3411            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3412            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3413            let o_proj = size_q * size_in / weight_pack_factor;
3414
3415            let h_size = cfg.hidden_size;
3416            let i_size = cfg.intermediate_size;
3417            let gate_proj = h_size * i_size / weight_pack_factor;
3418            let up_proj = h_size * i_size / weight_pack_factor;
3419            let down_proj = i_size * h_size / weight_pack_factor;
3420
3421            input_layernorm
3422                + post_attention_layernorm
3423                + q_proj
3424                + k_proj
3425                + v_proj
3426                + o_proj
3427                + gate_proj
3428                + up_proj
3429                + down_proj
3430        };
3431        Ok(vec![
3432            per_layer_elems * dtype.size_in_bytes();
3433            cfg.num_hidden_layers
3434        ])
3435    }
3436
3437    fn num_layers(&self, config: &str) -> Result<usize> {
3438        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3439        Ok(cfg.num_hidden_layers)
3440    }
3441
3442    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3443        let cfg: models::glm4::Config = serde_json::from_str(config)?;
3444
3445        let cfg = ModelConfigMetadata {
3446            max_seq_len: cfg.max_position_embeddings,
3447            num_layers: cfg.num_hidden_layers,
3448            hidden_size: cfg.hidden_size,
3449            num_kv_heads: cfg.num_key_value_heads,
3450            num_attn_heads: cfg.num_attention_heads,
3451            sliding_window: cfg.sliding_window,
3452            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3453            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3454            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3455        };
3456
3457        Ok(Box::new(cfg))
3458    }
3459}
3460
3461/// [`NormalLoader`] for a GLM 4 MoE Lite model (GLM-4.7-Flash).
3462///
3463/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
3464pub struct GLM4MoeLiteLoader;
3465
3466impl NormalModelLoader for GLM4MoeLiteLoader {
3467    fn load(
3468        &self,
3469        config: &str,
3470        vb: ShardedVarBuilder,
3471        normal_loading_metadata: NormalLoadingMetadata,
3472        attention_mechanism: AttentionImplementation,
3473    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3474        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3475        Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3476            &cfg,
3477            vb,
3478            self.is_gptx(config)?,
3479            normal_loading_metadata,
3480            attention_mechanism,
3481        )?))
3482    }
3483    fn load_xlora(
3484        &self,
3485        _config: &str,
3486        _vb: ShardedVarBuilder,
3487        _lora_config: &[((String, String), LoraConfig)],
3488        _xlora_config: Option<XLoraConfig>,
3489        _xlora_ordering: Ordering,
3490        _normal_loading_metadata: NormalLoadingMetadata,
3491        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3492    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3493        todo!()
3494    }
3495    fn is_gptx(&self, _: &str) -> Result<bool> {
3496        Ok(true)
3497    }
3498    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3499        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3500        Ok(Box::new(cfg))
3501    }
3502}
3503
3504impl IsqModelLoader for GLM4MoeLiteLoader {
3505    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3506        let mut data = vec![
3507            Regex::new(r"lm_head\.(weight|bias)$")?,
3508            // Attention (MLA)
3509            Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3510            Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3511            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3512            // Q LoRA projections
3513            Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3514            Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3515        ];
3516        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3517        for layer_idx in 0..cfg.num_hidden_layers {
3518            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3519                // MoE layer
3520                for i in 0..cfg.n_routed_experts {
3521                    data.extend(vec![
3522                        Regex::new(&format!(
3523                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3524                        ))?,
3525                        Regex::new(&format!(
3526                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3527                        ))?,
3528                        Regex::new(&format!(
3529                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3530                        ))?,
3531                    ]);
3532                }
3533                if cfg.n_shared_experts > 0 {
3534                    data.extend(vec![
3535                        Regex::new(&format!(
3536                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3537                        ))?,
3538                        Regex::new(&format!(
3539                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3540                        ))?,
3541                        Regex::new(&format!(
3542                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3543                        ))?,
3544                    ]);
3545                }
3546            } else {
3547                // Dense MLP layer
3548                data.extend(vec![
3549                    Regex::new(&format!(
3550                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3551                    ))?,
3552                    Regex::new(&format!(
3553                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3554                    ))?,
3555                    Regex::new(&format!(
3556                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3557                    ))?,
3558                ]);
3559            };
3560        }
3561        Ok(data)
3562    }
3563    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3564        self.isq_layer_regexes(config)
3565    }
3566
3567    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3568        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3569        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3570        for layer_idx in 0..cfg.num_hidden_layers {
3571            if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3572                // MoE layer
3573                for i in 0..cfg.n_routed_experts {
3574                    data.extend(vec![
3575                        Regex::new(&format!(
3576                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3577                        ))?,
3578                        Regex::new(&format!(
3579                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3580                        ))?,
3581                        Regex::new(&format!(
3582                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3583                        ))?,
3584                    ]);
3585                }
3586                if cfg.n_shared_experts > 0 {
3587                    data.extend(vec![
3588                        Regex::new(&format!(
3589                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3590                        ))?,
3591                        Regex::new(&format!(
3592                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3593                        ))?,
3594                        Regex::new(&format!(
3595                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3596                        ))?,
3597                    ]);
3598                }
3599            } else {
3600                // Dense MLP layer
3601                data.extend(vec![
3602                    Regex::new(&format!(
3603                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3604                    ))?,
3605                    Regex::new(&format!(
3606                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3607                    ))?,
3608                    Regex::new(&format!(
3609                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3610                    ))?,
3611                ]);
3612            };
3613        }
3614        Ok(data)
3615    }
3616    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3617        self.isq_layer_regexes_moqe(config)
3618    }
3619}
3620
3621impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3622    fn mapped_max_act_size_elems(
3623        &self,
3624        config: &str,
3625        params: &AutoDeviceMapParams,
3626    ) -> Result<usize> {
3627        let AutoDeviceMapParams::Text {
3628            max_seq_len,
3629            max_batch_size,
3630        } = params
3631        else {
3632            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3633        };
3634
3635        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3636
3637        Ok(
3638            max_batch_size
3639                * cfg.num_attention_heads
3640                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3641        )
3642    }
3643    fn non_mapped_max_act_size_elems(
3644        &self,
3645        _config: &str,
3646        _params: &AutoDeviceMapParams,
3647    ) -> Result<usize> {
3648        Ok(0)
3649    }
3650
3651    fn non_mapped_size_in_bytes(
3652        &self,
3653        config: &str,
3654        dtype: DType,
3655        weight_pack_factor: usize,
3656        _matformer_config: Option<&MatformerSliceConfig>,
3657    ) -> Result<usize> {
3658        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3659        let elems = {
3660            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3661            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3662            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3663                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3664            } else {
3665                0
3666            };
3667            let norm = cfg.hidden_size;
3668            embed_tokens + lm_head + norm
3669        };
3670        Ok(elems * dtype.size_in_bytes())
3671    }
3672
3673    fn layer_sizes_in_bytes(
3674        &self,
3675        config: &str,
3676        dtype: DType,
3677        weight_pack_factor: usize,
3678        _matformer_config: Option<&MatformerSliceConfig>,
3679    ) -> Result<Vec<usize>> {
3680        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3681        let mut per_layer_elems = Vec::new();
3682
3683        for layer_idx in 0..cfg.num_hidden_layers {
3684            let input_layernorm = cfg.hidden_size;
3685            let post_attention_layernorm = cfg.hidden_size;
3686
3687            // Q LoRA projection
3688            let q_proj = {
3689                let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3690                let norm = cfg.q_lora_rank;
3691                let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3692                    / weight_pack_factor;
3693                a + norm + b
3694            };
3695            let kv_a_proj_with_mqa =
3696                cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3697            let kv_a_layernorm = cfg.kv_lora_rank;
3698            let kv_b_proj = cfg.kv_lora_rank
3699                * cfg.num_attention_heads
3700                * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3701                / weight_pack_factor;
3702            let o_proj =
3703                cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3704
3705            let moe_block = {
3706                let mut sum = 0;
3707                if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3708                    // MoE layer
3709                    let h_size = cfg.hidden_size;
3710                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3711                        * cfg.n_routed_experts;
3712                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3713                        * cfg.n_routed_experts;
3714                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3715                        * cfg.n_routed_experts;
3716                    let shared_experts = if cfg.n_shared_experts > 0 {
3717                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3718                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3719                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3720                        gate_proj + up_proj + down_proj
3721                    } else {
3722                        0
3723                    };
3724                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3725                    let e_score_correction_bias = cfg.n_routed_experts;
3726                    sum += gate_proj
3727                        + up_proj
3728                        + down_proj
3729                        + shared_experts
3730                        + gate_weight
3731                        + e_score_correction_bias;
3732                } else {
3733                    // Dense MLP layer
3734                    let h_size = cfg.hidden_size;
3735                    let i_size = cfg.intermediate_size;
3736                    let gate_proj = h_size * i_size / weight_pack_factor;
3737                    let up_proj = h_size * i_size / weight_pack_factor;
3738                    let down_proj = i_size * h_size / weight_pack_factor;
3739                    sum += gate_proj + up_proj + down_proj;
3740                }
3741                sum
3742            };
3743
3744            per_layer_elems.push(
3745                input_layernorm
3746                    + post_attention_layernorm
3747                    + q_proj
3748                    + kv_a_layernorm
3749                    + kv_a_proj_with_mqa
3750                    + kv_b_proj
3751                    + o_proj
3752                    + moe_block,
3753            );
3754        }
3755
3756        Ok(per_layer_elems
3757            .into_iter()
3758            .map(|x| x * dtype.size_in_bytes())
3759            .collect())
3760    }
3761
3762    fn num_layers(&self, config: &str) -> Result<usize> {
3763        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3764        Ok(cfg.num_hidden_layers)
3765    }
3766
3767    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3768        let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3769
3770        let cfg = ModelConfigMetadata {
3771            max_seq_len: cfg.max_position_embeddings,
3772            num_layers: cfg.num_hidden_layers,
3773            hidden_size: cfg.hidden_size,
3774            num_kv_heads: cfg.num_attention_heads,
3775            num_attn_heads: cfg.num_attention_heads,
3776            sliding_window: None,
3777            k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3778            v_head_dim: cfg.v_head_dim,
3779            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3780        };
3781
3782        Ok(Box::new(cfg))
3783    }
3784}
3785
3786/// [`NormalLoader`] for a GLM 4 MoE model (GLM-4.5).
3787///
3788/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
3789pub struct GLM4MoeLoader;
3790
3791impl NormalModelLoader for GLM4MoeLoader {
3792    fn load(
3793        &self,
3794        config: &str,
3795        vb: ShardedVarBuilder,
3796        normal_loading_metadata: NormalLoadingMetadata,
3797        attention_mechanism: AttentionImplementation,
3798    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3799        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3800        Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3801            &cfg,
3802            vb,
3803            self.is_gptx(config)?,
3804            normal_loading_metadata,
3805            attention_mechanism,
3806        )?))
3807    }
3808    fn load_xlora(
3809        &self,
3810        _config: &str,
3811        _vb: ShardedVarBuilder,
3812        _lora_config: &[((String, String), LoraConfig)],
3813        _xlora_config: Option<XLoraConfig>,
3814        _xlora_ordering: Ordering,
3815        _normal_loading_metadata: NormalLoadingMetadata,
3816        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3817    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3818        todo!()
3819    }
3820    fn is_gptx(&self, _: &str) -> Result<bool> {
3821        Ok(true)
3822    }
3823    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3824        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3825        Ok(Box::new(cfg))
3826    }
3827}
3828
3829impl IsqModelLoader for GLM4MoeLoader {
3830    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3831        let mut data = vec![
3832            Regex::new(r"lm_head\.(weight|bias)$")?,
3833            // Attention (standard GQA)
3834            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3835            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3836            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3837            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3838        ];
3839        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3840        for layer_idx in 0..cfg.num_hidden_layers {
3841            if layer_idx >= cfg.first_k_dense_replace {
3842                // MoE layer
3843                for i in 0..cfg.n_routed_experts {
3844                    data.extend(vec![
3845                        Regex::new(&format!(
3846                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3847                        ))?,
3848                        Regex::new(&format!(
3849                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3850                        ))?,
3851                        Regex::new(&format!(
3852                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3853                        ))?,
3854                    ]);
3855                }
3856                if cfg.n_shared_experts > 0 {
3857                    data.extend(vec![
3858                        Regex::new(&format!(
3859                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3860                        ))?,
3861                        Regex::new(&format!(
3862                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3863                        ))?,
3864                        Regex::new(&format!(
3865                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3866                        ))?,
3867                    ]);
3868                }
3869            } else {
3870                // Dense MLP layer
3871                data.extend(vec![
3872                    Regex::new(&format!(
3873                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3874                    ))?,
3875                    Regex::new(&format!(
3876                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3877                    ))?,
3878                    Regex::new(&format!(
3879                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3880                    ))?,
3881                ]);
3882            };
3883        }
3884        Ok(data)
3885    }
3886    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3887        self.isq_layer_regexes(config)
3888    }
3889
3890    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3891        let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3892        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3893        for layer_idx in 0..cfg.num_hidden_layers {
3894            if layer_idx >= cfg.first_k_dense_replace {
3895                // MoE layer
3896                for i in 0..cfg.n_routed_experts {
3897                    data.extend(vec![
3898                        Regex::new(&format!(
3899                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3900                        ))?,
3901                        Regex::new(&format!(
3902                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3903                        ))?,
3904                        Regex::new(&format!(
3905                            r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3906                        ))?,
3907                    ]);
3908                }
3909                if cfg.n_shared_experts > 0 {
3910                    data.extend(vec![
3911                        Regex::new(&format!(
3912                            r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3913                        ))?,
3914                        Regex::new(&format!(
3915                            r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3916                        ))?,
3917                        Regex::new(&format!(
3918                            r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3919                        ))?,
3920                    ]);
3921                }
3922            } else {
3923                // Dense MLP layer
3924                data.extend(vec![
3925                    Regex::new(&format!(
3926                        r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3927                    ))?,
3928                    Regex::new(&format!(
3929                        r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3930                    ))?,
3931                    Regex::new(&format!(
3932                        r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3933                    ))?,
3934                ]);
3935            };
3936        }
3937        Ok(data)
3938    }
3939    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3940        self.isq_layer_regexes_moqe(config)
3941    }
3942}
3943
3944impl DeviceMappedModelLoader for GLM4MoeLoader {
3945    fn mapped_max_act_size_elems(
3946        &self,
3947        config: &str,
3948        params: &AutoDeviceMapParams,
3949    ) -> Result<usize> {
3950        let AutoDeviceMapParams::Text {
3951            max_seq_len,
3952            max_batch_size,
3953        } = params
3954        else {
3955            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3956        };
3957
3958        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3959
3960        Ok(
3961            max_batch_size
3962                * cfg.num_attention_heads
3963                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3964        )
3965    }
3966    fn non_mapped_max_act_size_elems(
3967        &self,
3968        _config: &str,
3969        _params: &AutoDeviceMapParams,
3970    ) -> Result<usize> {
3971        Ok(0)
3972    }
3973
3974    fn non_mapped_size_in_bytes(
3975        &self,
3976        config: &str,
3977        dtype: DType,
3978        weight_pack_factor: usize,
3979        _matformer_config: Option<&MatformerSliceConfig>,
3980    ) -> Result<usize> {
3981        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3982        let elems = {
3983            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3984            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3985                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3986            } else {
3987                0
3988            };
3989            let norm = cfg.hidden_size;
3990            embed_tokens + lm_head + norm
3991        };
3992        Ok(elems * dtype.size_in_bytes())
3993    }
3994
3995    fn layer_sizes_in_bytes(
3996        &self,
3997        config: &str,
3998        dtype: DType,
3999        weight_pack_factor: usize,
4000        _matformer_config: Option<&MatformerSliceConfig>,
4001    ) -> Result<Vec<usize>> {
4002        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4003        let mut per_layer_elems = Vec::new();
4004
4005        let head_dim = cfg.head_dim();
4006        for layer_idx in 0..cfg.num_hidden_layers {
4007            let input_layernorm = cfg.hidden_size;
4008            let post_attention_layernorm = cfg.hidden_size;
4009
4010            // Standard GQA attention
4011            let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
4012                + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
4013            let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4014                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4015            let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4016                + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4017            let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4018
4019            // QK norm if enabled
4020            let qk_norm = if cfg.use_qk_norm {
4021                head_dim * 2 // q_norm + k_norm
4022            } else {
4023                0
4024            };
4025
4026            let moe_block = {
4027                let mut sum = 0;
4028                if layer_idx >= cfg.first_k_dense_replace {
4029                    // MoE layer
4030                    let h_size = cfg.hidden_size;
4031                    let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4032                        * cfg.n_routed_experts;
4033                    let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4034                        * cfg.n_routed_experts;
4035                    let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4036                        * cfg.n_routed_experts;
4037                    let shared_experts = if cfg.n_shared_experts > 0 {
4038                        let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4039                        let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4040                        let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4041                        gate_proj + up_proj + down_proj
4042                    } else {
4043                        0
4044                    };
4045                    let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4046                    let e_score_correction_bias = cfg.n_routed_experts;
4047                    sum += gate_proj
4048                        + up_proj
4049                        + down_proj
4050                        + shared_experts
4051                        + gate_weight
4052                        + e_score_correction_bias;
4053                } else {
4054                    // Dense MLP layer
4055                    let h_size = cfg.hidden_size;
4056                    let i_size = cfg.intermediate_size;
4057                    let gate_proj = h_size * i_size / weight_pack_factor;
4058                    let up_proj = h_size * i_size / weight_pack_factor;
4059                    let down_proj = i_size * h_size / weight_pack_factor;
4060                    sum += gate_proj + up_proj + down_proj;
4061                }
4062                sum
4063            };
4064
4065            per_layer_elems.push(
4066                input_layernorm
4067                    + post_attention_layernorm
4068                    + q_proj
4069                    + k_proj
4070                    + v_proj
4071                    + o_proj
4072                    + qk_norm
4073                    + moe_block,
4074            );
4075        }
4076
4077        Ok(per_layer_elems
4078            .into_iter()
4079            .map(|x| x * dtype.size_in_bytes())
4080            .collect())
4081    }
4082
4083    fn num_layers(&self, config: &str) -> Result<usize> {
4084        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4085        Ok(cfg.num_hidden_layers)
4086    }
4087
4088    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4089        let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4090
4091        let head_dim = cfg.head_dim();
4092        let cfg = ModelConfigMetadata {
4093            max_seq_len: cfg.max_position_embeddings,
4094            num_layers: cfg.num_hidden_layers,
4095            hidden_size: cfg.hidden_size,
4096            num_kv_heads: cfg.num_key_value_heads,
4097            num_attn_heads: cfg.num_attention_heads,
4098            sliding_window: None,
4099            k_head_dim: head_dim,
4100            v_head_dim: head_dim,
4101            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4102        };
4103
4104        Ok(Box::new(cfg))
4105    }
4106}
4107
/// [`NormalLoader`] for a Qwen 3 MoE model.
///
/// This is a stateless unit struct. Its `NormalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations receive the model configuration as a
/// JSON string; the methods that need it deserialize it into
/// `crate::models::qwen3_moe::Config` on each call.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct Qwen3MoELoader;
4112
4113impl NormalModelLoader for Qwen3MoELoader {
4114    fn load(
4115        &self,
4116        config: &str,
4117        vb: ShardedVarBuilder,
4118        normal_loading_metadata: NormalLoadingMetadata,
4119        attention_mechanism: AttentionImplementation,
4120    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4121        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4122
4123        Ok(Box::new(models::qwen3_moe::Model::new(
4124            &cfg,
4125            vb,
4126            self.is_gptx(config)?,
4127            normal_loading_metadata,
4128            attention_mechanism,
4129        )?))
4130    }
4131    fn load_xlora(
4132        &self,
4133        _config: &str,
4134        _vb: ShardedVarBuilder,
4135        _lora_config: &[((String, String), LoraConfig)],
4136        _xlora_config: Option<XLoraConfig>,
4137        _xlora_ordering: Ordering,
4138        _normal_loading_metadata: NormalLoadingMetadata,
4139        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4140    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4141        todo!()
4142    }
4143    fn is_gptx(&self, _: &str) -> Result<bool> {
4144        Ok(true)
4145    }
4146    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4147        let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4148
4149        Ok(Box::new(cfg))
4150    }
4151}
4152
4153impl IsqModelLoader for Qwen3MoELoader {
4154    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4155        Ok(vec![
4156            Regex::new(r"lm_head\.(weight|bias)$")?,
4157            // Attention
4158            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4159            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4160            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4161            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4162            // MLP
4163            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4164            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4165            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4166            // MLP MoE
4167            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4168            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4169            Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4170        ])
4171    }
4172    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4173        self.isq_layer_regexes(config)
4174    }
4175    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4176        self.isq_layer_regexes_moqe(config)
4177    }
4178}
4179
4180impl DeviceMappedModelLoader for Qwen3MoELoader {
4181    fn mapped_max_act_size_elems(
4182        &self,
4183        config: &str,
4184        params: &AutoDeviceMapParams,
4185    ) -> Result<usize> {
4186        let AutoDeviceMapParams::Text {
4187            max_seq_len,
4188            max_batch_size,
4189        } = params
4190        else {
4191            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4192        };
4193
4194        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4195
4196        Ok(
4197            max_batch_size
4198                * cfg.num_attention_heads
4199                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4200        )
4201    }
4202    fn non_mapped_max_act_size_elems(
4203        &self,
4204        _config: &str,
4205        _params: &AutoDeviceMapParams,
4206    ) -> Result<usize> {
4207        Ok(0)
4208    }
4209
4210    fn non_mapped_size_in_bytes(
4211        &self,
4212        config: &str,
4213        dtype: DType,
4214        weight_pack_factor: usize,
4215        _matformer_config: Option<&MatformerSliceConfig>,
4216    ) -> Result<usize> {
4217        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4218        let elems = {
4219            let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4220            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4221            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4222                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4223            } else {
4224                0
4225            };
4226            let norm = cfg.hidden_size;
4227            embed_tokens + lm_head + norm
4228        };
4229        Ok(elems * dtype.size_in_bytes())
4230    }
4231
4232    fn layer_sizes_in_bytes(
4233        &self,
4234        config: &str,
4235        dtype: DType,
4236        weight_pack_factor: usize,
4237        _matformer_config: Option<&MatformerSliceConfig>,
4238    ) -> Result<Vec<usize>> {
4239        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4240
4241        let mut layer_sizes_in_bytes = Vec::new();
4242        for layer_idx in 0..cfg.num_hidden_layers {
4243            let input_layernorm = cfg.hidden_size;
4244            let post_attention_layernorm = cfg.hidden_size;
4245
4246            let size_in = cfg.hidden_size;
4247            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4248            let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4249            let q_proj = size_in * size_q / weight_pack_factor;
4250            let k_proj = size_in * size_kv / weight_pack_factor;
4251            let v_proj = size_in * size_kv / weight_pack_factor;
4252            let o_proj = size_q * size_in / weight_pack_factor;
4253
4254            let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4255                && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4256            {
4257                let gate_size = cfg.hidden_size * cfg.num_experts;
4258                let expert_size = {
4259                    let h_size = cfg.hidden_size;
4260                    let i_size = cfg.moe_intermediate_size;
4261                    let gate_proj = h_size * i_size / weight_pack_factor;
4262                    let up_proj = h_size * i_size / weight_pack_factor;
4263                    let down_proj = i_size * h_size / weight_pack_factor;
4264                    gate_proj + up_proj + down_proj
4265                };
4266                expert_size * cfg.num_experts + gate_size
4267            } else {
4268                let h_size = cfg.hidden_size;
4269                let i_size = cfg.intermediate_size;
4270                let gate_proj = h_size * i_size / weight_pack_factor;
4271                let up_proj = h_size * i_size / weight_pack_factor;
4272                let down_proj = i_size * h_size / weight_pack_factor;
4273                gate_proj + up_proj + down_proj
4274            };
4275
4276            let q_norm = cfg.head_dim();
4277            let k_norm = cfg.head_dim();
4278
4279            let size_elems = input_layernorm
4280                + post_attention_layernorm
4281                + q_proj
4282                + k_proj
4283                + v_proj
4284                + o_proj
4285                + mlp_size
4286                + q_norm
4287                + k_norm;
4288
4289            let size_in_bytes = size_elems * dtype.size_in_bytes();
4290            layer_sizes_in_bytes.push(size_in_bytes);
4291        }
4292
4293        Ok(layer_sizes_in_bytes)
4294    }
4295
4296    fn num_layers(&self, config: &str) -> Result<usize> {
4297        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4298        Ok(cfg.num_hidden_layers)
4299    }
4300
4301    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4302        let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4303
4304        let cfg = ModelConfigMetadata {
4305            max_seq_len: cfg.max_position_embeddings,
4306            num_layers: cfg.num_hidden_layers,
4307            hidden_size: cfg.hidden_size,
4308            num_kv_heads: cfg.num_key_value_heads,
4309            num_attn_heads: cfg.num_attention_heads,
4310            sliding_window: cfg.sliding_window,
4311            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4312            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4313            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4314        };
4315
4316        Ok(Box::new(cfg))
4317    }
4318}
4319
4320// ======================== SmolLm3 loader
4321
/// [`NormalLoader`] for a SmolLm3 model.
///
/// This is a stateless unit struct. Its `NormalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations receive the model configuration as a
/// JSON string; the methods that need it deserialize it into
/// `crate::models::smollm3::Config` on each call.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct SmolLm3Loader;
4326
4327impl NormalModelLoader for SmolLm3Loader {
4328    fn load(
4329        &self,
4330        config: &str,
4331        vb: ShardedVarBuilder,
4332        normal_loading_metadata: NormalLoadingMetadata,
4333        attention_mechanism: AttentionImplementation,
4334    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4335        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4336
4337        Ok(Box::new(models::smollm3::SmolLm3::new(
4338            &cfg,
4339            vb,
4340            self.is_gptx(config)?,
4341            normal_loading_metadata,
4342            attention_mechanism,
4343        )?))
4344    }
4345    fn load_xlora(
4346        &self,
4347        _config: &str,
4348        _vb: ShardedVarBuilder,
4349        _lora_config: &[((String, String), LoraConfig)],
4350        _xlora_config: Option<XLoraConfig>,
4351        _xlora_ordering: Ordering,
4352        _normal_loading_metadata: NormalLoadingMetadata,
4353        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4354    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4355        todo!()
4356    }
4357    fn is_gptx(&self, _: &str) -> Result<bool> {
4358        Ok(true)
4359    }
4360    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4361        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4362        Ok(Box::new(cfg))
4363    }
4364}
4365
4366impl IsqModelLoader for SmolLm3Loader {
4367    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4368        Ok(vec![
4369            Regex::new(r"lm_head\.(weight|bias)$")?,
4370            // Attention
4371            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4372            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4373            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4374            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4375            // MLP
4376            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4377            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4378            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4379        ])
4380    }
4381    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4382        self.isq_layer_regexes(config)
4383    }
4384}
4385
4386impl DeviceMappedModelLoader for SmolLm3Loader {
4387    fn mapped_max_act_size_elems(
4388        &self,
4389        config: &str,
4390        params: &AutoDeviceMapParams,
4391    ) -> Result<usize> {
4392        let AutoDeviceMapParams::Text {
4393            max_seq_len,
4394            max_batch_size,
4395        } = params
4396        else {
4397            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4398        };
4399
4400        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4401
4402        Ok(
4403            max_batch_size
4404                * cfg.num_attention_heads
4405                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4406        )
4407    }
4408    fn non_mapped_max_act_size_elems(
4409        &self,
4410        _config: &str,
4411        _params: &AutoDeviceMapParams,
4412    ) -> Result<usize> {
4413        Ok(0)
4414    }
4415
4416    fn non_mapped_size_in_bytes(
4417        &self,
4418        config: &str,
4419        dtype: DType,
4420        weight_pack_factor: usize,
4421        _matformer_config: Option<&MatformerSliceConfig>,
4422    ) -> Result<usize> {
4423        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4424
4425        let elems = {
4426            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4427            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4428            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4429                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4430            } else {
4431                0
4432            };
4433            let norm = cfg.hidden_size;
4434            embed_tokens + lm_head + norm
4435        };
4436        Ok(elems * dtype.size_in_bytes())
4437    }
4438
4439    fn layer_sizes_in_bytes(
4440        &self,
4441        config: &str,
4442        dtype: DType,
4443        weight_pack_factor: usize,
4444        _matformer_config: Option<&MatformerSliceConfig>,
4445    ) -> Result<Vec<usize>> {
4446        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4447
4448        let per_layer_elems = {
4449            let input_layernorm = cfg.hidden_size;
4450            let post_attention_layernorm = cfg.hidden_size;
4451
4452            let size_in = cfg.hidden_size;
4453            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4454            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4455            let q_proj = size_in * size_q / weight_pack_factor;
4456            let k_proj = size_in * size_kv / weight_pack_factor;
4457            let v_proj = size_in * size_kv / weight_pack_factor;
4458            let o_proj = size_q * size_in / weight_pack_factor;
4459
4460            let h_size = cfg.hidden_size;
4461            let i_size = cfg.intermediate_size;
4462            let gate_proj = h_size * i_size / weight_pack_factor;
4463            let up_proj = h_size * i_size / weight_pack_factor;
4464            let down_proj = i_size * h_size / weight_pack_factor;
4465
4466            input_layernorm
4467                + post_attention_layernorm
4468                + q_proj
4469                + k_proj
4470                + v_proj
4471                + o_proj
4472                + gate_proj
4473                + up_proj
4474                + down_proj
4475        };
4476        Ok(vec![
4477            per_layer_elems * dtype.size_in_bytes();
4478            cfg.num_hidden_layers
4479        ])
4480    }
4481
4482    fn num_layers(&self, config: &str) -> Result<usize> {
4483        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4484
4485        Ok(cfg.num_hidden_layers)
4486    }
4487    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4488        let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4489
4490        let cfg = ModelConfigMetadata {
4491            max_seq_len: cfg.max_position_embeddings,
4492            num_layers: cfg.num_hidden_layers,
4493            hidden_size: cfg.hidden_size,
4494            num_kv_heads: cfg.num_key_value_heads,
4495            num_attn_heads: cfg.num_attention_heads,
4496            sliding_window: None,
4497            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4498            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4499            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4500        };
4501
4502        Ok(Box::new(cfg))
4503    }
4504}
4505
4506// ======================== GraniteMoeHybrid loader
4507
/// [`NormalLoader`] for a GraniteMoeHybrid model (IBM Granite 4.0).
///
/// This is a stateless unit struct. Its `NormalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations receive the model configuration as a
/// JSON string; the methods that need it deserialize it into
/// `crate::models::granite::Config` on each call.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct GraniteMoeHybridLoader;
4512
4513impl NormalModelLoader for GraniteMoeHybridLoader {
4514    fn load(
4515        &self,
4516        config: &str,
4517        vb: ShardedVarBuilder,
4518        normal_loading_metadata: NormalLoadingMetadata,
4519        attention_mechanism: AttentionImplementation,
4520    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4521        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4522
4523        Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4524            &cfg,
4525            vb,
4526            self.is_gptx(config)?,
4527            normal_loading_metadata,
4528            attention_mechanism,
4529        )?))
4530    }
4531    fn load_xlora(
4532        &self,
4533        _config: &str,
4534        _vb: ShardedVarBuilder,
4535        _lora_config: &[((String, String), LoraConfig)],
4536        _xlora_config: Option<XLoraConfig>,
4537        _xlora_ordering: Ordering,
4538        _normal_loading_metadata: NormalLoadingMetadata,
4539        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4540    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4541        anyhow::bail!("GraniteMoeHybrid does not support X-LoRA")
4542    }
4543    fn is_gptx(&self, _: &str) -> Result<bool> {
4544        Ok(true)
4545    }
4546    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4547        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4548        Ok(Box::new(cfg))
4549    }
4550    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4551        Ok(true)
4552    }
4553}
4554
4555impl IsqModelLoader for GraniteMoeHybridLoader {
4556    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4557        Ok(vec![
4558            Regex::new(r"lm_head\.(weight|bias)$")?,
4559            // Attention
4560            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4561            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4562            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4563            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4564            // MLP (GraniteMLP uses shared_mlp.input_linear and shared_mlp.output_linear)
4565            Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4566            Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4567        ])
4568    }
4569    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4570        self.isq_layer_regexes(config)
4571    }
4572}
4573
4574impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4575    fn mapped_max_act_size_elems(
4576        &self,
4577        config: &str,
4578        params: &AutoDeviceMapParams,
4579    ) -> Result<usize> {
4580        let AutoDeviceMapParams::Text {
4581            max_seq_len,
4582            max_batch_size,
4583        } = params
4584        else {
4585            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4586        };
4587
4588        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4589
4590        Ok(
4591            max_batch_size
4592                * cfg.num_attention_heads
4593                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4594        )
4595    }
4596    fn non_mapped_max_act_size_elems(
4597        &self,
4598        _config: &str,
4599        _params: &AutoDeviceMapParams,
4600    ) -> Result<usize> {
4601        Ok(0)
4602    }
4603
4604    fn non_mapped_size_in_bytes(
4605        &self,
4606        config: &str,
4607        dtype: DType,
4608        weight_pack_factor: usize,
4609        _matformer_config: Option<&MatformerSliceConfig>,
4610    ) -> Result<usize> {
4611        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4612
4613        let elems = {
4614            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4615            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4616            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4617                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4618            } else {
4619                0
4620            };
4621            let norm = cfg.hidden_size;
4622            embed_tokens + lm_head + norm
4623        };
4624        Ok(elems * dtype.size_in_bytes())
4625    }
4626
4627    fn layer_sizes_in_bytes(
4628        &self,
4629        config: &str,
4630        dtype: DType,
4631        weight_pack_factor: usize,
4632        _matformer_config: Option<&MatformerSliceConfig>,
4633    ) -> Result<Vec<usize>> {
4634        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4635
4636        let per_layer_elems = {
4637            let input_layernorm = cfg.hidden_size;
4638            let post_attention_layernorm = cfg.hidden_size;
4639
4640            let size_in = cfg.hidden_size;
4641            let size_q = cfg.head_dim() * cfg.num_attention_heads;
4642            let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4643            let q_proj = size_in * size_q / weight_pack_factor;
4644            let k_proj = size_in * size_kv / weight_pack_factor;
4645            let v_proj = size_in * size_kv / weight_pack_factor;
4646            let o_proj = size_q * size_in / weight_pack_factor;
4647
4648            let h_size = cfg.hidden_size;
4649            let shared_i_size = cfg.shared_intermediate_size();
4650            // GraniteMLP: input_linear (h_size -> shared_i_size * 2), output_linear (shared_i_size -> h_size)
4651            let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4652            let output_linear = shared_i_size * h_size / weight_pack_factor;
4653
4654            input_layernorm
4655                + post_attention_layernorm
4656                + q_proj
4657                + k_proj
4658                + v_proj
4659                + o_proj
4660                + input_linear
4661                + output_linear
4662        };
4663        Ok(vec![
4664            per_layer_elems * dtype.size_in_bytes();
4665            cfg.num_hidden_layers
4666        ])
4667    }
4668
4669    fn num_layers(&self, config: &str) -> Result<usize> {
4670        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4671
4672        Ok(cfg.num_hidden_layers)
4673    }
4674    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4675        let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4676
4677        let cfg = ModelConfigMetadata {
4678            max_seq_len: cfg.max_position_embeddings,
4679            num_layers: cfg.num_hidden_layers,
4680            hidden_size: cfg.hidden_size,
4681            num_kv_heads: cfg.num_key_value_heads(),
4682            num_attn_heads: cfg.num_attention_heads,
4683            sliding_window: None,
4684            k_head_dim: cfg.head_dim(),
4685            v_head_dim: cfg.head_dim(),
4686            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4687        };
4688
4689        Ok(Box::new(cfg))
4690    }
4691}
4692
4693// ======================== GPT-OSS loader
4694
/// [`NormalLoader`] for a GPT-OSS model.
///
/// This is a stateless unit struct. Its `NormalModelLoader` implementation
/// receives the model configuration as a JSON string and deserializes it into
/// `crate::models::gpt_oss::Config` on each call.
///
/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
pub struct GptOssLoader;
4699
4700impl NormalModelLoader for GptOssLoader {
4701    fn load(
4702        &self,
4703        config: &str,
4704        vb: ShardedVarBuilder,
4705        normal_loading_metadata: NormalLoadingMetadata,
4706        attention_mechanism: AttentionImplementation,
4707    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4708        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4709
4710        Ok(Box::new(models::gpt_oss::Model::new(
4711            &cfg,
4712            vb,
4713            self.is_gptx(config)?,
4714            normal_loading_metadata,
4715            attention_mechanism,
4716        )?))
4717    }
4718    fn load_xlora(
4719        &self,
4720        _config: &str,
4721        _vb: ShardedVarBuilder,
4722        _lora_config: &[((String, String), LoraConfig)],
4723        _xlora_config: Option<XLoraConfig>,
4724        _xlora_ordering: Ordering,
4725        _normal_loading_metadata: NormalLoadingMetadata,
4726        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4727    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4728        anyhow::bail!("GPT-OSS does not support X-LoRA")
4729    }
4730    fn is_gptx(&self, _: &str) -> Result<bool> {
4731        Ok(true)
4732    }
4733    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4734        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4735        Ok(Box::new(cfg))
4736    }
4737}
4738
4739impl IsqModelLoader for GptOssLoader {
4740    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4741        // Only attention layers are ISQ-able - MoE experts are already MXFP4 quantized
4742        Ok(vec![
4743            Regex::new(r"lm_head\.(weight|bias)$")?,
4744            // Attention
4745            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4746            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4747            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4748            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4749        ])
4750    }
4751    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4752        self.isq_layer_regexes(config)
4753    }
4754}
4755
4756impl DeviceMappedModelLoader for GptOssLoader {
4757    fn mapped_max_act_size_elems(
4758        &self,
4759        config: &str,
4760        params: &AutoDeviceMapParams,
4761    ) -> Result<usize> {
4762        let AutoDeviceMapParams::Text {
4763            max_seq_len,
4764            max_batch_size,
4765        } = params
4766        else {
4767            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4768        };
4769
4770        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4771
4772        Ok(
4773            max_batch_size
4774                * cfg.num_attention_heads
4775                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4776        )
4777    }
4778    fn non_mapped_max_act_size_elems(
4779        &self,
4780        _config: &str,
4781        _params: &AutoDeviceMapParams,
4782    ) -> Result<usize> {
4783        Ok(0)
4784    }
4785
4786    fn non_mapped_size_in_bytes(
4787        &self,
4788        config: &str,
4789        dtype: DType,
4790        weight_pack_factor: usize,
4791        _matformer_config: Option<&MatformerSliceConfig>,
4792    ) -> Result<usize> {
4793        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4794
4795        let elems = {
4796            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4797            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4798                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4799            } else {
4800                0
4801            };
4802            let norm = cfg.hidden_size;
4803            embed_tokens + lm_head + norm
4804        };
4805        Ok(elems * dtype.size_in_bytes())
4806    }
4807
4808    fn layer_sizes_in_bytes(
4809        &self,
4810        config: &str,
4811        dtype: DType,
4812        weight_pack_factor: usize,
4813        _matformer_config: Option<&MatformerSliceConfig>,
4814    ) -> Result<Vec<usize>> {
4815        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4816
4817        let per_layer_elems = {
4818            let input_layernorm = cfg.hidden_size;
4819            let post_attention_layernorm = cfg.hidden_size;
4820
4821            let size_in = cfg.hidden_size;
4822            let head_dim = cfg.head_dim();
4823            let size_q = head_dim * cfg.num_attention_heads;
4824            let size_kv = head_dim * cfg.num_key_value_heads;
4825            let q_proj =
4826                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4827            let k_proj =
4828                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4829            let v_proj =
4830                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4831            let o_proj =
4832                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4833
4834            // MoE experts - MXFP4 quantized, so very compact
4835            // gate_up_proj: [num_experts, intermediate_size * 2, hidden_size/2] packed
4836            // down_proj: [num_experts, hidden_size, intermediate_size/2] packed
4837            // At 4 bits per weight, packing factor is 2
4838            let mxfp4_pack = 2;
4839            let gate_up_proj_size =
4840                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4841            let down_proj_size =
4842                cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4843            // Plus scales at 1 byte per 32 elements
4844            let gate_up_scales =
4845                cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4846            let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4847            // Plus biases
4848            let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4849            let down_bias = cfg.num_local_experts * cfg.hidden_size;
4850            // Router
4851            let router = cfg.hidden_size * cfg.num_local_experts;
4852            // Sinks per head
4853            let sinks = cfg.num_attention_heads;
4854
4855            input_layernorm
4856                + post_attention_layernorm
4857                + q_proj
4858                + k_proj
4859                + v_proj
4860                + o_proj
4861                + gate_up_proj_size
4862                + down_proj_size
4863                + gate_up_scales
4864                + down_scales
4865                + gate_up_bias
4866                + down_bias
4867                + router
4868                + sinks
4869        };
4870        Ok(vec![
4871            per_layer_elems * dtype.size_in_bytes();
4872            cfg.num_hidden_layers
4873        ])
4874    }
4875
4876    fn num_layers(&self, config: &str) -> Result<usize> {
4877        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4878
4879        Ok(cfg.num_hidden_layers)
4880    }
4881    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4882        let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4883
4884        let head_dim = cfg.head_dim();
4885        let cfg = ModelConfigMetadata {
4886            max_seq_len: cfg.max_position_embeddings,
4887            num_layers: cfg.num_hidden_layers,
4888            hidden_size: cfg.hidden_size,
4889            num_kv_heads: cfg.num_key_value_heads,
4890            num_attn_heads: cfg.num_attention_heads,
4891            sliding_window: cfg.sliding_window,
4892            k_head_dim: head_dim,
4893            v_head_dim: head_dim,
4894            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4895        };
4896
4897        Ok(Box::new(cfg))
4898    }
4899}
4900
4901// ======================== Qwen3Next loader
4902
4903/// [`NormalLoader`] for a Qwen3Next (Qwen3-Coder-Next) model.
4904///
4905/// [`NormalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.NormalLoader.html
4906pub struct Qwen3NextLoader;
4907
4908impl NormalModelLoader for Qwen3NextLoader {
4909    fn load(
4910        &self,
4911        config: &str,
4912        vb: ShardedVarBuilder,
4913        normal_loading_metadata: NormalLoadingMetadata,
4914        attention_mechanism: AttentionImplementation,
4915    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4916        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4917
4918        Ok(Box::new(models::qwen3_next::Model::new(
4919            &cfg,
4920            vb,
4921            self.is_gptx(config)?,
4922            normal_loading_metadata,
4923            attention_mechanism,
4924        )?))
4925    }
4926    fn load_xlora(
4927        &self,
4928        _config: &str,
4929        _vb: ShardedVarBuilder,
4930        _lora_config: &[((String, String), LoraConfig)],
4931        _xlora_config: Option<XLoraConfig>,
4932        _xlora_ordering: Ordering,
4933        _normal_loading_metadata: NormalLoadingMetadata,
4934        _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4935    ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4936        anyhow::bail!("Qwen3Next does not support X-LoRA")
4937    }
4938    fn is_gptx(&self, _: &str) -> Result<bool> {
4939        Ok(true)
4940    }
4941    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4942        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4943        Ok(Box::new(cfg))
4944    }
4945    fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4946        Ok(true)
4947    }
4948}
4949
4950impl IsqModelLoader for Qwen3NextLoader {
4951    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4952        Ok(vec![
4953            Regex::new(r"lm_head\.(weight|bias)$")?,
4954            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4955            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4956            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4957            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4958            Regex::new(r"layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$")?,
4959            Regex::new(
4960                r"layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4961            )?,
4962            Regex::new(
4963                r"layers\.(\d+)\.mlp\.shared_expert\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4964            )?,
4965        ])
4966    }
4967    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4968        self.isq_layer_regexes(config)
4969    }
4970}
4971
4972impl DeviceMappedModelLoader for Qwen3NextLoader {
4973    fn mapped_max_act_size_elems(
4974        &self,
4975        config: &str,
4976        params: &AutoDeviceMapParams,
4977    ) -> Result<usize> {
4978        let AutoDeviceMapParams::Text {
4979            max_seq_len,
4980            max_batch_size,
4981        } = params
4982        else {
4983            anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4984        };
4985
4986        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4987
4988        Ok(
4989            max_batch_size
4990                * cfg.num_attention_heads
4991                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4992        )
4993    }
4994    fn non_mapped_max_act_size_elems(
4995        &self,
4996        _config: &str,
4997        _params: &AutoDeviceMapParams,
4998    ) -> Result<usize> {
4999        Ok(0)
5000    }
5001
5002    fn non_mapped_size_in_bytes(
5003        &self,
5004        config: &str,
5005        dtype: DType,
5006        weight_pack_factor: usize,
5007        _matformer_config: Option<&MatformerSliceConfig>,
5008    ) -> Result<usize> {
5009        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5010
5011        let elems = {
5012            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5013            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
5014                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5015            } else {
5016                0
5017            };
5018            let norm = cfg.hidden_size;
5019            embed_tokens + lm_head + norm
5020        };
5021        Ok(elems * dtype.size_in_bytes())
5022    }
5023
5024    fn layer_sizes_in_bytes(
5025        &self,
5026        config: &str,
5027        dtype: DType,
5028        weight_pack_factor: usize,
5029        _matformer_config: Option<&MatformerSliceConfig>,
5030    ) -> Result<Vec<usize>> {
5031        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5032        let layer_types = cfg.layer_types();
5033        let mut layer_sizes = Vec::with_capacity(cfg.num_hidden_layers);
5034
5035        for layer_type in &layer_types {
5036            let input_layernorm = cfg.hidden_size;
5037            let post_attention_layernorm = cfg.hidden_size;
5038
5039            let attn_elems = match layer_type {
5040                crate::models::qwen3_next::LayerType::FullAttention => {
5041                    let hidden = cfg.hidden_size;
5042                    let q_dim = cfg.head_dim * cfg.num_attention_heads;
5043                    let kv_dim = cfg.head_dim * cfg.num_key_value_heads;
5044                    let q_proj = hidden * q_dim * 2 / weight_pack_factor;
5045                    let k_proj = hidden * kv_dim / weight_pack_factor;
5046                    let v_proj = hidden * kv_dim / weight_pack_factor;
5047                    let o_proj = q_dim * hidden / weight_pack_factor;
5048                    let q_norm = cfg.head_dim;
5049                    let k_norm = cfg.head_dim;
5050                    q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
5051                }
5052                crate::models::qwen3_next::LayerType::LinearAttention => {
5053                    let hidden = cfg.hidden_size;
5054                    let key_dim = cfg.linear_key_dim();
5055                    let value_dim = cfg.linear_value_dim();
5056                    let conv_dim = cfg.linear_conv_dim();
5057                    // in_proj_qkvz: (2 * key_dim + 2 * value_dim, hidden)
5058                    let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
5059                    // in_proj_ba: (2 * num_v_heads, hidden)
5060                    let in_proj_ba = hidden * (cfg.linear_num_value_heads * 2);
5061                    let out_proj = value_dim * hidden / weight_pack_factor;
5062                    let conv1d = conv_dim * cfg.linear_conv_kernel_dim;
5063                    let dt_bias = cfg.linear_num_value_heads;
5064                    let a_log = cfg.linear_num_value_heads;
5065                    let norm = cfg.linear_value_head_dim;
5066                    in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
5067                }
5068            };
5069
5070            let moe_gate = cfg.hidden_size * cfg.num_experts;
5071            let shared_expert =
5072                3 * cfg.hidden_size * cfg.shared_expert_intermediate_size / weight_pack_factor;
5073            let routed_experts = cfg.num_experts * 3 * cfg.hidden_size * cfg.moe_intermediate_size
5074                / weight_pack_factor;
5075
5076            let per_layer_elems = input_layernorm
5077                + post_attention_layernorm
5078                + attn_elems
5079                + moe_gate
5080                + shared_expert
5081                + routed_experts;
5082
5083            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5084        }
5085
5086        Ok(layer_sizes)
5087    }
5088
5089    fn num_layers(&self, config: &str) -> Result<usize> {
5090        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5091        Ok(cfg.num_hidden_layers)
5092    }
5093    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5094        let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5095
5096        let cfg = ModelConfigMetadata {
5097            max_seq_len: cfg.max_position_embeddings,
5098            num_layers: cfg.num_hidden_layers,
5099            hidden_size: cfg.hidden_size,
5100            num_kv_heads: cfg.num_key_value_heads,
5101            num_attn_heads: cfg.num_attention_heads,
5102            sliding_window: None,
5103            k_head_dim: cfg.head_dim,
5104            v_head_dim: cfg.head_dim,
5105            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5106        };
5107
5108        Ok(Box::new(cfg))
5109    }
5110}