Skip to main content

hanzo_engine/pipeline/loaders/
multimodal_loaders.rs

1use std::any::Any;
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4use std::{fmt::Debug, str::FromStr};
5
6use anyhow::Result;
7use hanzo_ml::{DType, Device, Tensor, D};
8use hanzo_nn::Conv2dConfig;
9use hanzo_quant::log::once_log_debug;
10use hanzo_quant::ShardedVarBuilder;
11use image::{ColorType, DynamicImage};
12use itertools::Itertools;
13
14#[cfg(feature = "pyo3_macros")]
15use pyo3::pyclass;
16
17use regex::Regex;
18use serde::Deserialize;
19
20use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};
21
22use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
23use crate::amoe::AnyMoeBaseModelMixin;
24use crate::attention::ATTENTION_CHUNK_SIZE;
25use crate::device_map::DeviceMapper;
26use crate::layers::Conv3dConfig;
27use crate::matformer::MatformerSliceConfig;
28use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
29use crate::pipeline::isq::IsqModelLoader;
30use crate::pipeline::loaders::AutoDeviceMapParams;
31use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
32use crate::pipeline::{
33    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
34    SupportedModality,
35};
36use crate::speculative::SpeculativeTargetMixin;
37use crate::utils::varbuilder_utils::DeviceForLoadTensor;
38use crate::vision_models::clip::ClipConfig;
39use crate::vision_models::gemma3::config::Gemma3Config;
40use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
41use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
42use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
43use crate::vision_models::gemma4::config::Gemma4Config;
44use crate::vision_models::gemma4::{Gemma4Model, Gemma4Processor};
45use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
46use crate::vision_models::idefics2_input_processor::Idefics2Processor;
47use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
48use crate::vision_models::image_processor::ImagePreProcessor;
49use crate::vision_models::inputs_processor::Phi4MMProcessor;
50use crate::vision_models::llama4::{
51    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
52};
53use crate::vision_models::llava::config::Config as LLaVAConfig;
54use crate::vision_models::llava15::Model as LLaVA;
55use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
56use crate::vision_models::llava_next::Model as LLaVANext;
57use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
58use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
59use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
60use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
61use crate::vision_models::phi3_inputs_processor::Phi3Processor;
62use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
63use crate::vision_models::preprocessor_config::PreProcessorConfig;
64use crate::vision_models::processor_config::ProcessorConfig;
65use crate::vision_models::qwen2_5_vl::{
66    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
67};
68use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
69use crate::vision_models::qwen3_5::{Config as Qwen3_5Config, Qwen3_5Model, Qwen3_5Processor};
70use crate::vision_models::qwen3_5_moe::{
71    Config as Qwen3_5MoeConfig, Qwen3_5MoeModel, Qwen3_5MoeProcessor,
72};
73use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
74use crate::vision_models::qwen3_vl_moe::{
75    Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
76};
77use crate::vision_models::voxtral::config::VoxtralConfig;
78use crate::vision_models::voxtral::{VoxtralModel, VoxtralProcessor};
79use crate::vision_models::{minicpmo, phi4};
80
/// A loaded multimodal model that the pipeline drives: it runs forward passes over token ids
/// plus optional pixel inputs and exposes its device, cache and config metadata.
pub trait MultimodalModel: IsqModel + AnyMoeBaseModelMixin + SpeculativeTargetMixin {
    /// Run one forward pass and return the resulting tensor.
    ///
    /// - `model_specific_args` is an opaque per-model payload (pixel attention mask, image
    ///   sizes, or anything else); NOTE(review): implementations presumably downcast it to
    ///   their own type — confirm per model.
    /// - `metadata` is optional paged-attention input; NOTE(review): the `(Tensor, Tensor)`
    ///   pairs look like per-layer key/value caches — TODO confirm.
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    // The full set of inputs is genuinely needed here, so the argument-count lint is silenced.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// The device this model reports as its own.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Mutable access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// Maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// The model's config metadata.
    fn config(&self) -> &ModelConfigMetadata;
    /// The config metadata as a shareable trait object; by default a clone of `config()`.
    fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
        Arc::new(self.config().clone())
    }
    /// For a prompt without images. Requires batch size of 1!
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
    /// Return encoder cache hit/miss counters (hits, misses) if this model has an encoder cache.
    /// Defaults to `None`.
    fn encoder_cache_counters(&self) -> Option<(Arc<AtomicUsize>, Arc<AtomicUsize>)> {
        None
    }
    /// Reset model-specific state (e.g. cached audio embeddings) between requests.
    /// Called when the pipeline's non-granular state is reset. Defaults to a no-op.
    fn reset_model_specific_state(&self) {}
}
113
114pub trait MultimodalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
115    fn load(
116        &self,
117        config: &str,
118        vb: ShardedVarBuilder,
119        normal_loading_metadata: NormalLoadingMetadata,
120        attention_mechanism: AttentionImplementation,
121    ) -> Result<Box<dyn MultimodalModel + Send + Sync>>;
122    fn is_gptx(&self, config: &str) -> bool;
123    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
124    fn get_processor(
125        &self,
126        model_config: &str,
127        processor_config: Option<ProcessorConfig>,
128        preprocessor_config: PreProcessorConfig,
129        max_edge: Option<u32>,
130    ) -> Arc<dyn Processor + Send + Sync>;
131    fn supports_paged_attention(&self, config: &str) -> bool;
132    fn supports_prefix_cacher(&self, _config: &str) -> bool {
133        // Default is false, specific model must override.
134        false
135    }
136    fn modalities(&self, config: &str) -> Result<Modalities>;
137    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
138    /// Return a default chat template (Jinja string) for models that don't ship a
139    /// `tokenizer_config.json` or `chat_template.jinja`. Returns `None` by default.
140    /// The `config` parameter is the raw model config JSON, used by `AutoMultimodalLoader`
141    /// to delegate to the correct concrete loader.
142    fn default_chat_template(&self, _config: &str) -> Option<String> {
143        None
144    }
145    /// Return default (bos_token, eos_token) strings for models that don't ship a
146    /// `tokenizer_config.json`. Used to populate the chat template context and
147    /// EOS token detection. Returns `None` by default.
148    fn default_bos_eos(&self, _config: &str) -> Option<(String, String)> {
149        None
150    }
151    fn get_device_for_tensor(
152        &self,
153        config: &str,
154        _mapper: &dyn DeviceMapper,
155        loading_isq: bool,
156    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
157        if loading_isq {
158            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
159        } else {
160            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
161            let num_layers = self.model_config(config)?.num_layers();
162            let closure = move |name: String| {
163                if let Some(captures) = re.captures(&name) {
164                    captures
165                        .get(1)
166                        .and_then(|m| m.as_str().parse::<usize>().ok())
167                        .map(|l| l.min(num_layers))
168                        .map(DeviceForLoadTensor::Idx)
169                        .unwrap_or(DeviceForLoadTensor::Base)
170                } else {
171                    DeviceForLoadTensor::Base
172                }
173            };
174
175            Ok(Arc::new(closure))
176        }
177    }
178}
179
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
/// The architecture to load the multimodal model as.
///
/// Each variant (de)serializes as the lowercase name in its `serde(rename)`; the same names
/// are produced by `Display` and accepted by `FromStr`. `from_causal_lm_name` maps Hugging
/// Face model class names onto these variants.
// NOTE(review): with `pyclass(eq, eq_int)` Python callers may compare variants against
// integers, so variant order is probably significant — keep it stable.
pub enum MultimodalLoaderType {
    /// Hugging Face class `Phi3VForCausalLM`.
    #[serde(rename = "phi3v")]
    Phi3V,
    /// Hugging Face class `Idefics2ForConditionalGeneration`.
    #[serde(rename = "idefics2")]
    Idefics2,
    /// Hugging Face class `LlavaNextForConditionalGeneration`.
    #[serde(rename = "llava_next")]
    LLaVANext,
    /// Hugging Face class `LlavaForConditionalGeneration`.
    #[serde(rename = "llava")]
    LLaVA,
    /// Hugging Face class `MllamaForConditionalGeneration`.
    #[serde(rename = "vllama")]
    VLlama,
    /// Hugging Face class `Qwen2VLForConditionalGeneration`.
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    /// Hugging Face class `Idefics3ForConditionalGeneration`.
    #[serde(rename = "idefics3")]
    Idefics3,
    /// Hugging Face class `MiniCPMO`.
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    /// Hugging Face class `Phi4MMForCausalLM`.
    #[serde(rename = "phi4mm")]
    Phi4MM,
    /// Hugging Face class `Qwen2_5_VLForConditionalGeneration`.
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    /// Hugging Face classes `Gemma3ForConditionalGeneration` and `Gemma3ForCausalLM`.
    #[serde(rename = "gemma3")]
    Gemma3,
    /// Hugging Face class `Mistral3ForConditionalGeneration`.
    #[serde(rename = "mistral3")]
    Mistral3,
    /// Hugging Face class `Llama4ForConditionalGeneration`.
    #[serde(rename = "llama4")]
    Llama4,
    /// Hugging Face class `Gemma3nForConditionalGeneration`.
    #[serde(rename = "gemma3n")]
    Gemma3n,
    /// Hugging Face class `Qwen3VLForConditionalGeneration`.
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
    /// Hugging Face class `Qwen3VLMoeForConditionalGeneration`.
    #[serde(rename = "qwen3vlmoe")]
    Qwen3VLMoE,
    /// Hugging Face class `Qwen3_5ForConditionalGeneration`.
    #[serde(rename = "qwen3_5")]
    Qwen3_5,
    /// Hugging Face class `Qwen3_5MoeForConditionalGeneration`.
    #[serde(rename = "qwen3_5moe")]
    Qwen3_5Moe,
    /// Hugging Face classes `VoxtralForConditionalGeneration` and
    /// `VoxtralRealtimeForConditionalGeneration`.
    #[serde(rename = "voxtral")]
    Voxtral,
    /// Hugging Face class `Gemma4ForConditionalGeneration`.
    #[serde(rename = "gemma4")]
    Gemma4,
}
225
226// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
227impl MultimodalLoaderType {
228    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
229        match name {
230            "Phi3VForCausalLM" => Ok(Self::Phi3V),
231            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
232            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
233            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
234            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
235            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
236            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
237            "MiniCPMO" => Ok(Self::MiniCpmO),
238            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
239            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
240            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
241            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
242            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
243            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
244            "Gemma4ForConditionalGeneration" => Ok(Self::Gemma4),
245            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
246            "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
247            "Qwen3_5ForConditionalGeneration" => Ok(Self::Qwen3_5),
248            "Qwen3_5MoeForConditionalGeneration" => Ok(Self::Qwen3_5Moe),
249            "VoxtralForConditionalGeneration"
250            | "VoxtralRealtimeForConditionalGeneration" => Ok(Self::Voxtral),
251            other => anyhow::bail!(
252                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
253            ),
254        }
255    }
256}
257
258impl FromStr for MultimodalLoaderType {
259    type Err = String;
260    fn from_str(s: &str) -> Result<Self, Self::Err> {
261        match s {
262            "phi3v" => Ok(Self::Phi3V),
263            "idefics2" => Ok(Self::Idefics2),
264            "llava_next" => Ok(Self::LLaVANext),
265            "llava" => Ok(Self::LLaVA),
266            "vllama" => Ok(Self::VLlama),
267            "qwen2vl" => Ok(Self::Qwen2VL),
268            "idefics3" => Ok(Self::Idefics3),
269            "minicpmo" => Ok(Self::MiniCpmO),
270            "phi4mm" => Ok(Self::Phi4MM),
271            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
272            "gemma3" => Ok(Self::Gemma3),
273            "mistral3" => Ok(Self::Mistral3),
274            "llama4" => Ok(Self::Llama4),
275            "gemma3n" => Ok(Self::Gemma3n),
276            "gemma4" => Ok(Self::Gemma4),
277            "qwen3vl" => Ok(Self::Qwen3VL),
278            "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
279            "qwen3_5" => Ok(Self::Qwen3_5),
280            "qwen3_5moe" => Ok(Self::Qwen3_5Moe),
281            "voxtral" => Ok(Self::Voxtral),
282            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `gemma4`, `qwen3vl`, `qwen3vlmoe`, `qwen3_5`, `qwen3_5moe`, `voxtral`.")),
283        }
284    }
285}
286
287impl std::fmt::Display for MultimodalLoaderType {
288    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289        let name = match self {
290            MultimodalLoaderType::Phi3V => "phi3v",
291            MultimodalLoaderType::Idefics2 => "idefics2",
292            MultimodalLoaderType::LLaVANext => "llava_next",
293            MultimodalLoaderType::LLaVA => "llava",
294            MultimodalLoaderType::VLlama => "vllama",
295            MultimodalLoaderType::Qwen2VL => "qwen2vl",
296            MultimodalLoaderType::Idefics3 => "idefics3",
297            MultimodalLoaderType::MiniCpmO => "minicpmo",
298            MultimodalLoaderType::Phi4MM => "phi4mm",
299            MultimodalLoaderType::Qwen2_5VL => "qwen2_5vl",
300            MultimodalLoaderType::Gemma3 => "gemma3",
301            MultimodalLoaderType::Mistral3 => "mistral3",
302            MultimodalLoaderType::Llama4 => "llama4",
303            MultimodalLoaderType::Gemma3n => "gemma3n",
304            MultimodalLoaderType::Qwen3VL => "qwen3vl",
305            MultimodalLoaderType::Qwen3VLMoE => "qwen3vlmoe",
306            MultimodalLoaderType::Qwen3_5 => "qwen3_5",
307            MultimodalLoaderType::Qwen3_5Moe => "qwen3_5moe",
308            MultimodalLoaderType::Voxtral => "voxtral",
309            MultimodalLoaderType::Gemma4 => "gemma4",
310        };
311        write!(f, "{name}")
312    }
313}
314
/// The subset of a model config read by `AutoMultimodalLoader::get_loader` to choose a
/// concrete loader. All other keys are ignored.
#[derive(Deserialize)]
struct AutoMultimodalLoaderConfig {
    /// Hugging Face model class names; empty when the key is absent.
    #[serde(default)]
    architectures: Vec<String>,
    /// Voxtral params.json uses a `multimodal` key instead of `architectures`.
    /// Only its presence is checked; the contents are not inspected.
    #[serde(default)]
    multimodal: Option<serde_json::Value>,
}
323
/// Automatically selects a MultimodalModelLoader implementation based on the JSON `architectures` field.
///
/// A config with a `multimodal` key and no `architectures` (Voxtral's `params.json`) selects
/// the Voxtral loader. Every trait method re-parses the config and delegates to the selected
/// loader.
pub struct AutoMultimodalLoader;
326
327impl AutoMultimodalLoader {
328    fn get_loader(config: &str) -> Result<Box<dyn MultimodalModelLoader>> {
329        let auto_cfg: AutoMultimodalLoaderConfig = serde_json::from_str(config)?;
330
331        // Voxtral: params.json has `multimodal` but no `architectures`
332        if auto_cfg.multimodal.is_some() && auto_cfg.architectures.is_empty() {
333            once_log_debug("Automatic loader type determined to be `voxtral`");
334            return Ok(Box::new(VoxtralLoader));
335        }
336
337        if auto_cfg.architectures.len() != 1 {
338            anyhow::bail!("Expected exactly one architecture in config");
339        }
340
341        let name = &auto_cfg.architectures[0];
342        let tp = MultimodalLoaderType::from_causal_lm_name(name)?;
343
344        once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
345
346        // Delegate to the concrete loader
347        Ok(match tp {
348            MultimodalLoaderType::Phi3V => Box::new(Phi3VLoader),
349            MultimodalLoaderType::Idefics2 => Box::new(Idefics2Loader),
350            MultimodalLoaderType::LLaVANext => Box::new(LLaVANextLoader),
351            MultimodalLoaderType::LLaVA => Box::new(LLaVALoader),
352            MultimodalLoaderType::VLlama => Box::new(VLlamaLoader),
353            MultimodalLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
354            MultimodalLoaderType::Idefics3 => Box::new(Idefics3Loader),
355            MultimodalLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
356            MultimodalLoaderType::Phi4MM => Box::new(Phi4MMLoader),
357            MultimodalLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
358            MultimodalLoaderType::Gemma3 => Box::new(Gemma3Loader),
359            MultimodalLoaderType::Mistral3 => Box::new(Mistral3Loader),
360            MultimodalLoaderType::Llama4 => Box::new(VLlama4Loader),
361            MultimodalLoaderType::Gemma3n => Box::new(Gemma3nLoader),
362            MultimodalLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
363            MultimodalLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
364            MultimodalLoaderType::Qwen3_5 => Box::new(Qwen3_5Loader),
365            MultimodalLoaderType::Qwen3_5Moe => Box::new(Qwen3_5MoeLoader),
366            MultimodalLoaderType::Voxtral => Box::new(VoxtralLoader),
367            MultimodalLoaderType::Gemma4 => Box::new(Gemma4Loader),
368        })
369    }
370}
371
372impl MultimodalModelLoader for AutoMultimodalLoader {
373    fn load(
374        &self,
375        config: &str,
376        vb: ShardedVarBuilder,
377        normal_loading_metadata: NormalLoadingMetadata,
378        attention_mechanism: AttentionImplementation,
379    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
380        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
381    }
382
383    fn is_gptx(&self, config: &str) -> bool {
384        Self::get_loader(config)
385            .expect("AutoMultimodalLoader get_loader")
386            .is_gptx(config)
387    }
388
389    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
390        Self::get_loader(config)?.get_config_repr(config)
391    }
392
393    fn get_processor(
394        &self,
395        model_config: &str,
396        proc_cfg: Option<ProcessorConfig>,
397        preproc_cfg: PreProcessorConfig,
398        max_edge: Option<u32>,
399    ) -> Arc<dyn Processor + Send + Sync> {
400        Self::get_loader(model_config)
401            .expect("AutoMultimodalLoader get_loader")
402            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
403    }
404
405    fn supports_paged_attention(&self, config: &str) -> bool {
406        Self::get_loader(config)
407            .expect("AutoMultimodalLoader")
408            .supports_paged_attention(config)
409    }
410
411    fn modalities(&self, config: &str) -> Result<Modalities> {
412        Self::get_loader(config)?.modalities(config)
413    }
414
415    fn supports_prefix_cacher(&self, config: &str) -> bool {
416        Self::get_loader(config)
417            .expect("AutoMultimodalLoader")
418            .supports_prefix_cacher(config)
419    }
420
421    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
422        Self::get_loader(config)
423            .expect("AutoMultimodalLoader")
424            .prefixer(config)
425    }
426
427    fn default_chat_template(&self, config: &str) -> Option<String> {
428        Self::get_loader(config).ok()?.default_chat_template(config)
429    }
430
431    fn default_bos_eos(&self, config: &str) -> Option<(String, String)> {
432        Self::get_loader(config).ok()?.default_bos_eos(config)
433    }
434
435    fn get_device_for_tensor(
436        &self,
437        config: &str,
438        mapper: &dyn DeviceMapper,
439        loading_isq: bool,
440    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
441        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
442    }
443}
444
445impl IsqModelLoader for AutoMultimodalLoader {
446    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
447        Self::get_loader(config)?.isq_layer_regexes(config)
448    }
449    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
450        Self::get_loader(config)?.immediate_isq_predicates(config)
451    }
452    fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
453        Self::get_loader(config)?.isq_layer_regexes_moqe(config)
454    }
455    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
456        Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
457    }
458}
459
460impl DeviceMappedModelLoader for AutoMultimodalLoader {
461    fn mapped_max_act_size_elems(
462        &self,
463        config: &str,
464        params: &AutoDeviceMapParams,
465    ) -> Result<usize> {
466        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
467    }
468    fn non_mapped_max_act_size_elems(
469        &self,
470        config: &str,
471        params: &AutoDeviceMapParams,
472    ) -> Result<usize> {
473        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
474    }
475    fn non_mapped_size_in_bytes(
476        &self,
477        config: &str,
478        dtype: DType,
479        weight_pack_factor: usize,
480        _matformer_config: Option<&MatformerSliceConfig>,
481    ) -> Result<usize> {
482        Self::get_loader(config)?.non_mapped_size_in_bytes(
483            config,
484            dtype,
485            weight_pack_factor,
486            _matformer_config,
487        )
488    }
489    fn layer_sizes_in_bytes(
490        &self,
491        config: &str,
492        dtype: DType,
493        weight_pack_factor: usize,
494        _matformer_config: Option<&MatformerSliceConfig>,
495    ) -> Result<Vec<usize>> {
496        Self::get_loader(config)?.layer_sizes_in_bytes(
497            config,
498            dtype,
499            weight_pack_factor,
500            _matformer_config,
501        )
502    }
503    fn num_layers(&self, config: &str) -> Result<usize> {
504        Self::get_loader(config)?.num_layers(config)
505    }
506    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
507        Self::get_loader(config)?.model_config(config)
508    }
509}
510
/// Evaluates to `$size` when `$cond` holds and to `0` otherwise. This is handy for adding a
/// bias term to an element count only when the layer has a bias.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
520
521fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
522    let pre_layer_norm = cfg.hidden_size;
523    let final_layer_norm = cfg.hidden_size;
524
525    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
526    let num_positions = num_patches + 1;
527
528    let class_embedding = cfg.hidden_size;
529
530    let position_ids = num_positions;
531    let position_embedding = num_positions * cfg.hidden_size;
532
533    let conv2dconfig = Conv2dConfig {
534        stride: cfg.patch_size,
535        ..Default::default()
536    };
537    let patch_embedding =
538        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;
539
540    let encoder_layer_elems = {
541        let layer_norm1 = cfg.hidden_size;
542        let layer_norm2 = cfg.hidden_size;
543
544        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
545        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
546        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
547        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
548
549        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
550        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
551
552        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
553    };
554
555    pre_layer_norm
556        + final_layer_norm
557        + class_embedding
558        + position_ids
559        + position_embedding
560        + patch_embedding
561        + cfg.num_hidden_layers * encoder_layer_elems
562}
563
564// ======================== Phi 3 loader
565
/// [`MultimodalLoader`] for a Phi 3 Vision model.
///
/// Selected for the Hugging Face class `Phi3VForCausalLM` (architecture name `phi3v`).
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Phi3VLoader;
570
/// Prompt prefixer for Phi 3 Vision. It prepends one 1-based `<|image_N|>` tag per image.
pub struct Phi3VPrefixer;
572
573impl MultimodalPromptPrefixer for Phi3VPrefixer {
574    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
575        // Image indexing starts at 0.
576        format!(
577            "{}{prompt}",
578            image_indexes
579                .into_iter()
580                .map(|image_index| format!("<|image_{}|>", image_index + 1))
581                .join("")
582        )
583    }
584}
585
586impl MultimodalModelLoader for Phi3VLoader {
587    fn load(
588        &self,
589        config: &str,
590        vb: ShardedVarBuilder,
591        normal_loading_metadata: NormalLoadingMetadata,
592        attention_mechanism: AttentionImplementation,
593    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
594        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
595        Ok(Box::new(Phi3::new(
596            &cfg,
597            vb,
598            self.is_gptx(config),
599            normal_loading_metadata,
600            attention_mechanism,
601        )?))
602    }
603    fn is_gptx(&self, _config: &str) -> bool {
604        true
605    }
606    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
607        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
608        Ok(Box::new(cfg))
609    }
610    fn get_processor(
611        &self,
612        _model_config: &str,
613        processor_config: Option<ProcessorConfig>,
614        preprocessor_config: PreProcessorConfig,
615        _max_edge: Option<u32>,
616    ) -> Arc<dyn Processor + Send + Sync> {
617        Phi3Processor::new_processor(processor_config, preprocessor_config)
618    }
619    fn supports_paged_attention(&self, _config: &str) -> bool {
620        true
621    }
622    fn supports_prefix_cacher(&self, _config: &str) -> bool {
623        true
624    }
625    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
626        Arc::new(Phi3VPrefixer)
627    }
628    fn modalities(&self, _config: &str) -> Result<Modalities> {
629        Ok(Modalities {
630            input: vec![SupportedModality::Text, SupportedModality::Vision],
631            output: vec![SupportedModality::Text],
632        })
633    }
634}
635
636impl IsqModelLoader for Phi3VLoader {
637    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
638        Ok(vec![
639            Regex::new(r"lm_head\.(weight|bias)$")?,
640            // Attention
641            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
642            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
643            // MLP
644            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
645            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
646        ])
647    }
648    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
649        self.isq_layer_regexes(config)
650    }
651}
652
653impl DeviceMappedModelLoader for Phi3VLoader {
654    fn mapped_max_act_size_elems(
655        &self,
656        config: &str,
657        params: &AutoDeviceMapParams,
658    ) -> Result<usize> {
659        // NOTE: we ignore max_num_images although it can only be one...
660        let AutoDeviceMapParams::Multimodal {
661            max_seq_len,
662            max_batch_size,
663            max_image_shape: _,
664            max_num_images,
665        } = params
666        else {
667            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
668        };
669
670        let cfg: Phi3Config = serde_json::from_str(config)?;
671
672        let vcfg = &PHI3V_CLIP_CONFIG;
673
674        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
675        let img_seq_len = (num_patches + 1) * max_num_images;
676
677        let max_text_attn = {
678            // This model injects the vision information directly into the input embeddings
679            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
680            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
681        };
682
683        Ok(max_text_attn)
684    }
685
686    fn non_mapped_max_act_size_elems(
687        &self,
688        config: &str,
689        params: &AutoDeviceMapParams,
690    ) -> Result<usize> {
691        // NOTE: we ignore max_num_images although it can only be one...
692        let AutoDeviceMapParams::Multimodal {
693            max_seq_len: _,
694            max_batch_size,
695            max_image_shape: _,
696            max_num_images,
697        } = params
698        else {
699            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
700        };
701
702        let cfg: Phi3Config = serde_json::from_str(config)?;
703
704        let vcfg = &PHI3V_CLIP_CONFIG;
705
706        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
707        let img_seq_len = num_patches + 1;
708
709        let max_vision_attn = {
710            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
711        };
712
713        Ok(max_vision_attn)
714    }
715
716    fn non_mapped_size_in_bytes(
717        &self,
718        config: &str,
719        dtype: DType,
720        weight_pack_factor: usize,
721        _matformer_config: Option<&MatformerSliceConfig>,
722    ) -> Result<usize> {
723        let cfg: Phi3Config = serde_json::from_str(config)?;
724        let elems = {
725            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
726            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
727            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
728                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
729            } else {
730                0
731            };
732            let norm = cfg.hidden_size;
733
734            let image_embed = {
735                let projection_cls = cfg
736                    .embd_layer
737                    .projection_cls
738                    .clone()
739                    .unwrap_or("linear".to_string());
740                let with_learnable_separator =
741                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
742                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
743                let image_dim_out = cfg.img_processor.image_dim_out;
744
745                let proj = match (projection_cls.as_str(), use_hd_transform) {
746                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
747                    ("mlp", true) => {
748                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
749                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
750                        a + b
751                    }
752                    ("mlp", false) => {
753                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
754                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
755                        a + b
756                    }
757                    _ => {
758                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
759                    }
760                };
761
762                let (glb_gn, sub_gn) = if with_learnable_separator {
763                    let glb_gn = image_dim_out * 4;
764                    let sub_gn = image_dim_out * 4;
765                    (glb_gn, sub_gn)
766                } else {
767                    (0, 0)
768                };
769
770                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);
771
772                proj + glb_gn + sub_gn + clip_vit
773            };
774
775            embed_tokens + lm_head + norm + image_embed
776        };
777
778        Ok(elems * dtype.size_in_bytes())
779    }
780
781    fn layer_sizes_in_bytes(
782        &self,
783        config: &str,
784        dtype: DType,
785        weight_pack_factor: usize,
786        _matformer_config: Option<&MatformerSliceConfig>,
787    ) -> Result<Vec<usize>> {
788        let cfg: Phi3Config = serde_json::from_str(config)?;
789        let per_layer_elems = {
790            let input_layernorm = cfg.hidden_size;
791            let post_attention_layernorm = cfg.hidden_size;
792
793            let size_in = cfg.hidden_size;
794            let head_dim = cfg.head_dim();
795            let op_size =
796                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
797            let qkv_proj = size_in * op_size / weight_pack_factor;
798            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
799
800            let h_size = cfg.hidden_size;
801            let i_size = cfg.intermediate_size;
802            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
803            let down_proj = h_size * i_size / weight_pack_factor;
804
805            input_layernorm
806                + post_attention_layernorm
807                + qkv_proj
808                + o_proj
809                + gate_up_proj
810                + down_proj
811        };
812        Ok(vec![
813            per_layer_elems * dtype.size_in_bytes();
814            cfg.num_hidden_layers
815        ])
816    }
817
818    fn num_layers(&self, config: &str) -> Result<usize> {
819        let cfg: Phi3Config = serde_json::from_str(config)?;
820        Ok(cfg.num_hidden_layers)
821    }
822
823    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
824        let cfg: Phi3Config = serde_json::from_str(config)?;
825
826        let cfg = ModelConfigMetadata {
827            max_seq_len: cfg.max_position_embeddings,
828            num_layers: cfg.num_hidden_layers,
829            hidden_size: cfg.hidden_size,
830            num_kv_heads: cfg.num_key_value_heads,
831            num_attn_heads: cfg.num_attention_heads,
832            sliding_window: cfg.sliding_window,
833            k_head_dim: cfg.head_dim(),
834            v_head_dim: cfg.head_dim(),
835            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
836        };
837
838        Ok(Box::new(cfg))
839    }
840
841    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
842        Some(vec![NonMappedSubModel::Vision])
843    }
844}
845
846// ======================== Idefics 2 loader
847
/// [`MultimodalLoader`] for an Idefics 2 Vision model.
///
/// Parses the model config JSON into the Idefics 2 config type, builds the model and
/// its processor, and provides ISQ patterns and device-mapping size estimates through
/// its trait implementations.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Idefics2Loader;

/// Prompt prefixer for Idefics 2.
///
/// Returns prompts unchanged: image placeholders are inserted by the chat template.
pub struct Idefics2Prefixer;
854
855impl MultimodalPromptPrefixer for Idefics2Prefixer {
856    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
857        // Chat template does it
858        prompt.to_string()
859    }
860}
861
862impl MultimodalModelLoader for Idefics2Loader {
863    fn load(
864        &self,
865        config: &str,
866        vb: ShardedVarBuilder,
867        normal_loading_metadata: NormalLoadingMetadata,
868        attention_mechanism: AttentionImplementation,
869    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
870        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
871        Ok(Box::new(Idefics2::new(
872            &cfg,
873            vb,
874            self.is_gptx(config),
875            normal_loading_metadata,
876            attention_mechanism,
877        )?))
878    }
879    fn is_gptx(&self, _config: &str) -> bool {
880        true
881    }
882    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
883        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
884        Ok(Box::new(cfg))
885    }
886    fn get_processor(
887        &self,
888        _model_config: &str,
889        processor_config: Option<ProcessorConfig>,
890        preprocessor_config: PreProcessorConfig,
891        max_edge: Option<u32>,
892    ) -> Arc<dyn Processor + Send + Sync> {
893        Arc::new(Idefics2Processor::new(
894            processor_config.unwrap(),
895            preprocessor_config,
896            max_edge,
897        ))
898    }
899    fn supports_paged_attention(&self, _config: &str) -> bool {
900        true
901    }
902    fn supports_prefix_cacher(&self, _config: &str) -> bool {
903        true
904    }
905    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
906        Arc::new(Idefics2Prefixer)
907    }
908    fn modalities(&self, _config: &str) -> Result<Modalities> {
909        Ok(Modalities {
910            input: vec![SupportedModality::Text, SupportedModality::Vision],
911            output: vec![SupportedModality::Text],
912        })
913    }
914}
915
916impl IsqModelLoader for Idefics2Loader {
917    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
918        Ok(vec![
919            Regex::new(r"lm_head\.(weight|bias)$")?,
920            // Attention
921            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
922            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
923            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
924            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
925            // MLP
926            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
927            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
928            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
929        ])
930    }
931    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
932        Ok(vec![
933            Regex::new(r"lm_head\.(weight|bias)$")?,
934            // Attention
935            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
936            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
937            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
938            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
939            // MLP
940            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
941            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
942            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
943        ])
944    }
945}
946
947impl DeviceMappedModelLoader for Idefics2Loader {
948    fn mapped_max_act_size_elems(
949        &self,
950        config: &str,
951        params: &AutoDeviceMapParams,
952    ) -> Result<usize> {
953        let AutoDeviceMapParams::Multimodal {
954            max_seq_len,
955            max_batch_size,
956            max_image_shape: _,
957            max_num_images,
958        } = params
959        else {
960            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
961        };
962
963        let cfg: Idefics2Config = serde_json::from_str(config)?;
964
965        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
966        let img_seq_len = (num_patches + 1) * max_num_images;
967
968        let max_text_attn = {
969            // This model injects the vision information directly into the input embeddings
970            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
971            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
972        };
973
974        Ok(max_text_attn)
975    }
976
977    fn non_mapped_max_act_size_elems(
978        &self,
979        config: &str,
980        params: &AutoDeviceMapParams,
981    ) -> Result<usize> {
982        let AutoDeviceMapParams::Multimodal {
983            max_seq_len: _,
984            max_batch_size,
985            max_image_shape: _,
986            max_num_images,
987        } = params
988        else {
989            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
990        };
991
992        let cfg: Idefics2Config = serde_json::from_str(config)?;
993
994        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
995        let img_seq_len = num_patches + 1;
996
997        let max_vision_attn = {
998            // do_image_splitting = true
999            let images_factor = 5;
1000
1001            (max_batch_size * images_factor * max_num_images)
1002                * cfg.vision_config.num_attention_heads
1003                * img_seq_len
1004                * img_seq_len
1005        };
1006
1007        Ok(max_vision_attn)
1008    }
1009
1010    fn non_mapped_size_in_bytes(
1011        &self,
1012        config: &str,
1013        dtype: DType,
1014        weight_pack_factor: usize,
1015        _matformer_config: Option<&MatformerSliceConfig>,
1016    ) -> Result<usize> {
1017        let cfg: Idefics2Config = serde_json::from_str(config)?;
1018        let text_elems = {
1019            let tie_word_embeddings = cfg.tie_word_embeddings;
1020            let cfg = &cfg.text_config;
1021
1022            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1023            let lm_head = if !tie_word_embeddings {
1024                cfg.hidden_size * cfg.vocab_size
1025            } else {
1026                0
1027            };
1028            let norm = cfg.hidden_size;
1029            embed_tokens + lm_head + norm
1030        };
1031
1032        let connector_elems = {
1033            let tcfg = &cfg.text_config;
1034            let vcfg = &cfg.vision_config;
1035            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
1036            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
1037            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
1038
1039            let perceiver_elems = {
1040                let tcfg = &cfg.text_config;
1041                let pcfg = &cfg.perceiver_config;
1042
1043                let n_latents = pcfg.resampler_n_latents;
1044                let hidden_size = tcfg.hidden_size;
1045                let depth = pcfg.resampler_depth;
1046
1047                let norm = tcfg.hidden_size;
1048                let latents = n_latents * hidden_size;
1049
1050                let layer_elems = {
1051                    let input_latents_norm = hidden_size;
1052                    let input_context_norm = hidden_size;
1053                    let post_attn_norm = hidden_size;
1054
1055                    let num_heads = pcfg.resampler_n_heads;
1056                    let head_dim = pcfg.resampler_head_dim;
1057                    let num_key_value_heads = pcfg.num_key_value_heads;
1058
1059                    let q_proj = hidden_size * num_heads * head_dim;
1060                    let k_proj = hidden_size * num_key_value_heads * head_dim;
1061                    let v_proj = hidden_size * num_key_value_heads * head_dim;
1062                    let o_proj = num_heads * head_dim * hidden_size;
1063
1064                    let gate_proj = hidden_size * hidden_size * 4;
1065                    let up_proj = hidden_size * hidden_size * 4;
1066                    let down_proj = hidden_size * 4 * hidden_size;
1067
1068                    input_latents_norm
1069                        + input_context_norm
1070                        + post_attn_norm
1071                        + q_proj
1072                        + k_proj
1073                        + v_proj
1074                        + o_proj
1075                        + gate_proj
1076                        + up_proj
1077                        + down_proj
1078                };
1079
1080                norm + latents + layer_elems * depth
1081            };
1082
1083            gate_proj + up_proj + down_proj + perceiver_elems
1084        };
1085
1086        let vision_transformer = {
1087            let cfg = &cfg.vision_config;
1088
1089            let post_layernorm = cfg.hidden_size;
1090
1091            let conv_config = Conv2dConfig {
1092                stride: cfg.patch_size,
1093                ..Default::default()
1094            };
1095            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
1096                * cfg.patch_size
1097                * cfg.patch_size;
1098
1099            let num_patches_per_side = cfg.image_size / cfg.patch_size;
1100            let num_patches = num_patches_per_side.pow(2);
1101            let position_embedding = num_patches * cfg.hidden_size;
1102
1103            let layer_elems = {
1104                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1105                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1106
1107                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1108                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1109
1110                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1111                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1112                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1113                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1114
1115                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1116            };
1117
1118            post_layernorm + patch_embedding + position_embedding + layer_elems
1119        };
1120
1121        let elems = text_elems + connector_elems + vision_transformer;
1122
1123        Ok(elems * dtype.size_in_bytes())
1124    }
1125
1126    fn layer_sizes_in_bytes(
1127        &self,
1128        config: &str,
1129        dtype: DType,
1130        weight_pack_factor: usize,
1131        _matformer_config: Option<&MatformerSliceConfig>,
1132    ) -> Result<Vec<usize>> {
1133        let cfg: Idefics2Config = serde_json::from_str(config)?;
1134        let cfg = cfg.text_config;
1135        let per_layer_elems = {
1136            let input_layernorm = cfg.hidden_size;
1137            let post_attention_layernorm = cfg.hidden_size;
1138
1139            let size_in = cfg.hidden_size;
1140            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1141            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1142            let q_proj = size_in * size_q / weight_pack_factor;
1143            let k_proj = size_in * size_kv / weight_pack_factor;
1144            let v_proj = size_in * size_kv / weight_pack_factor;
1145            let o_proj = size_q * size_in / weight_pack_factor;
1146
1147            let h_size = cfg.hidden_size;
1148            let i_size = cfg.intermediate_size;
1149            let gate_proj = h_size * i_size / weight_pack_factor;
1150            let up_proj = h_size * i_size / weight_pack_factor;
1151            let down_proj = i_size * h_size / weight_pack_factor;
1152
1153            input_layernorm
1154                + post_attention_layernorm
1155                + q_proj
1156                + k_proj
1157                + v_proj
1158                + o_proj
1159                + gate_proj
1160                + up_proj
1161                + down_proj
1162        };
1163        Ok(vec![
1164            per_layer_elems * dtype.size_in_bytes();
1165            cfg.num_hidden_layers
1166        ])
1167    }
1168
1169    fn num_layers(&self, config: &str) -> Result<usize> {
1170        let cfg: Idefics2Config = serde_json::from_str(config)?;
1171        Ok(cfg.text_config.num_hidden_layers)
1172    }
1173    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1174        let cfg: Idefics2Config = serde_json::from_str(config)?;
1175        let cfg = &cfg.text_config;
1176
1177        let cfg = ModelConfigMetadata {
1178            max_seq_len: cfg.max_position_embeddings,
1179            num_layers: cfg.num_hidden_layers,
1180            hidden_size: cfg.hidden_size,
1181            num_kv_heads: cfg.num_key_value_heads,
1182            num_attn_heads: cfg.num_attention_heads,
1183            sliding_window: cfg.sliding_window,
1184            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1185            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1186            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1187        };
1188
1189        Ok(Box::new(cfg))
1190    }
1191
1192    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1193        Some(vec![NonMappedSubModel::Vision])
1194    }
1195}
1196
1197// ======================== LLaVANext Loader
1198
/// [`MultimodalLoader`] for an LLaVANext Vision model.
///
/// Parses the model config JSON into the LLaVA config type and builds a `LLaVANext`
/// model and `LLaVANextProcessor`. It also provides ISQ patterns and device-mapping
/// size estimates through its trait implementations.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct LLaVANextLoader;

/// Prompt prefixer for LLaVANext: prepends one `<image>` token per attached image.
pub struct LLaVANextPrefixer;
1205
1206impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1207    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1208        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1209    }
1210}
1211
1212impl MultimodalModelLoader for LLaVANextLoader {
1213    fn load(
1214        &self,
1215        config: &str,
1216        vb: ShardedVarBuilder,
1217        normal_loading_metadata: NormalLoadingMetadata,
1218        attention_mechanism: AttentionImplementation,
1219    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1220        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1221        Ok(Box::new(LLaVANext::new(
1222            &cfg,
1223            vb,
1224            self.is_gptx(config),
1225            normal_loading_metadata,
1226            attention_mechanism,
1227        )?))
1228    }
1229    fn is_gptx(&self, _config: &str) -> bool {
1230        false
1231    }
1232    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1233        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1234        Ok(Box::new(cfg))
1235    }
1236    fn get_processor(
1237        &self,
1238        model_config: &str,
1239        _processor_config: Option<ProcessorConfig>,
1240        _preprocessor_config: PreProcessorConfig,
1241        _max_edge: Option<u32>,
1242    ) -> Arc<dyn Processor + Send + Sync> {
1243        Arc::new(LLaVANextProcessor::new(model_config))
1244    }
1245    fn supports_paged_attention(&self, _config: &str) -> bool {
1246        true
1247    }
1248    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1249        true
1250    }
1251    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1252        Arc::new(LLaVANextPrefixer)
1253    }
1254    fn modalities(&self, _config: &str) -> Result<Modalities> {
1255        Ok(Modalities {
1256            input: vec![SupportedModality::Text, SupportedModality::Vision],
1257            output: vec![SupportedModality::Text],
1258        })
1259    }
1260}
1261
1262impl IsqModelLoader for LLaVANextLoader {
1263    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1264        Ok(vec![
1265            Regex::new(r"lm_head\.(weight|bias)$")?,
1266            // Attention
1267            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1268            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1269            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1270            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1271            // MLP
1272            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1273            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1274            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1275        ])
1276    }
1277    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1278        Ok(vec![
1279            Regex::new(r"lm_head\.(weight|bias)$")?,
1280            // Attention
1281            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1282            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1283            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1284            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1285            // MLP
1286            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1287            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1288            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1289        ])
1290    }
1291}
1292
1293impl DeviceMappedModelLoader for LLaVANextLoader {
1294    fn mapped_max_act_size_elems(
1295        &self,
1296        config: &str,
1297        params: &AutoDeviceMapParams,
1298    ) -> Result<usize> {
1299        let AutoDeviceMapParams::Multimodal {
1300            max_seq_len,
1301            max_batch_size,
1302            max_image_shape,
1303            max_num_images,
1304        } = params
1305        else {
1306            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1307        };
1308
1309        let config: LLaVAConfig = serde_json::from_str(config)?;
1310
1311        #[allow(clippy::cast_possible_truncation)]
1312        let img_seq_len =
1313            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1314                &config,
1315                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1316            );
1317        let img_seq_len = img_seq_len * max_num_images;
1318
1319        let max_text_attn = {
1320            let cfg = &config.text_config;
1321            // This model injects the vision information directly into the input embeddings
1322            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1323
1324            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1325        };
1326
1327        Ok(max_text_attn)
1328    }
1329
1330    fn non_mapped_max_act_size_elems(
1331        &self,
1332        config: &str,
1333        params: &AutoDeviceMapParams,
1334    ) -> Result<usize> {
1335        let AutoDeviceMapParams::Multimodal {
1336            max_seq_len: _,
1337            max_batch_size,
1338            max_image_shape,
1339            max_num_images,
1340        } = params
1341        else {
1342            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1343        };
1344
1345        let config: LLaVAConfig = serde_json::from_str(config)?;
1346
1347        #[allow(clippy::cast_possible_truncation)]
1348        let img_seq_len =
1349            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1350                &config,
1351                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1352            );
1353
1354        let max_vision_attn = {
1355            (max_batch_size * max_num_images)
1356                * config.vision_config.num_attention_heads
1357                * img_seq_len
1358                * img_seq_len
1359        };
1360
1361        Ok(max_vision_attn)
1362    }
1363
1364    fn non_mapped_size_in_bytes(
1365        &self,
1366        config: &str,
1367        dtype: DType,
1368        weight_pack_factor: usize,
1369        _matformer_config: Option<&MatformerSliceConfig>,
1370    ) -> Result<usize> {
1371        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1372        let text_elems = {
1373            let cfg = &cfg.text_config;
1374            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1375            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1376            let norm = cfg.hidden_size;
1377            embed_tokens + lm_head + norm
1378        };
1379
1380        let image_newline = cfg.text_config.hidden_size;
1381        let mmproj = {
1382            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1383                + cfg.text_config.hidden_size;
1384            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1385                + cfg.text_config.hidden_size;
1386
1387            linear_1 + linear_2
1388        };
1389        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1390
1391        let elems = text_elems + image_newline + mmproj + vision_tower;
1392        Ok(elems * dtype.size_in_bytes())
1393    }
1394
1395    fn layer_sizes_in_bytes(
1396        &self,
1397        config: &str,
1398        dtype: DType,
1399        weight_pack_factor: usize,
1400        _matformer_config: Option<&MatformerSliceConfig>,
1401    ) -> Result<Vec<usize>> {
1402        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1403        let per_layer_elems = {
1404            let cfg = &cfg.text_config;
1405            let input_layernorm = cfg.hidden_size;
1406            let post_attention_layernorm = cfg.hidden_size;
1407
1408            let size_in = cfg.hidden_size;
1409            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1410            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1411            let q_proj = size_in * size_q / weight_pack_factor;
1412            let k_proj = size_in * size_kv / weight_pack_factor;
1413            let v_proj = size_in * size_kv / weight_pack_factor;
1414            let o_proj = size_q * size_in / weight_pack_factor;
1415
1416            let h_size = cfg.hidden_size;
1417            let i_size = cfg.intermediate_size;
1418            let gate_proj = h_size * i_size / weight_pack_factor;
1419            let up_proj = h_size * i_size / weight_pack_factor;
1420            let down_proj = i_size * h_size / weight_pack_factor;
1421
1422            input_layernorm
1423                + post_attention_layernorm
1424                + q_proj
1425                + k_proj
1426                + v_proj
1427                + o_proj
1428                + gate_proj
1429                + up_proj
1430                + down_proj
1431        };
1432        Ok(vec![
1433            per_layer_elems * dtype.size_in_bytes();
1434            cfg.text_config.num_hidden_layers
1435        ])
1436    }
1437
1438    fn num_layers(&self, config: &str) -> Result<usize> {
1439        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1440        Ok(cfg.text_config.num_hidden_layers)
1441    }
1442
1443    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1444        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1445        let cfg = &cfg.text_config;
1446
1447        let cfg = ModelConfigMetadata {
1448            max_seq_len: cfg.max_position_embeddings,
1449            num_layers: cfg.num_hidden_layers,
1450            hidden_size: cfg.hidden_size,
1451            num_kv_heads: cfg.num_key_value_heads,
1452            num_attn_heads: cfg.num_attention_heads,
1453            sliding_window: cfg.sliding_window,
1454            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1455            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1456            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1457        };
1458
1459        Ok(Box::new(cfg))
1460    }
1461
1462    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1463        Some(vec![NonMappedSubModel::Vision])
1464    }
1465}
1466
1467// ======================== LLaVA Loader
1468
/// [`MultimodalLoader`] for an LLaVA Vision model.
///
/// Parses the model config JSON into the LLaVA config type and builds a `LLaVA` model
/// and `LLaVAProcessor`. It also provides ISQ patterns and device-mapping size
/// estimates through its trait implementations.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct LLaVALoader;

/// Prompt prefixer for LLaVA: prepends one `<image>` token per attached image.
pub struct LLaVAPrefixer;
1475
1476impl MultimodalPromptPrefixer for LLaVAPrefixer {
1477    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1478        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1479    }
1480}
1481
1482impl MultimodalModelLoader for LLaVALoader {
1483    fn load(
1484        &self,
1485        config: &str,
1486        vb: ShardedVarBuilder,
1487        normal_loading_metadata: NormalLoadingMetadata,
1488        attention_mechanism: AttentionImplementation,
1489    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1490        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1491        Ok(Box::new(LLaVA::new(
1492            &cfg,
1493            vb,
1494            self.is_gptx(config),
1495            normal_loading_metadata,
1496            attention_mechanism,
1497        )?))
1498    }
1499    fn is_gptx(&self, _config: &str) -> bool {
1500        false
1501    }
1502    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1503        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1504        Ok(Box::new(cfg))
1505    }
1506    fn get_processor(
1507        &self,
1508        model_config: &str,
1509        _processor_config: Option<ProcessorConfig>,
1510        _preprocessor_config: PreProcessorConfig,
1511        _max_edge: Option<u32>,
1512    ) -> Arc<dyn Processor + Send + Sync> {
1513        Arc::new(LLaVAProcessor::new(model_config))
1514    }
1515    fn supports_paged_attention(&self, _config: &str) -> bool {
1516        true
1517    }
1518    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1519        true
1520    }
1521    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1522        Arc::new(LLaVAPrefixer)
1523    }
1524    fn modalities(&self, _config: &str) -> Result<Modalities> {
1525        Ok(Modalities {
1526            input: vec![SupportedModality::Text, SupportedModality::Vision],
1527            output: vec![SupportedModality::Text],
1528        })
1529    }
1530}
1531
1532impl IsqModelLoader for LLaVALoader {
1533    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1534        Ok(vec![
1535            Regex::new(r"lm_head\.(weight|bias)$")?,
1536            // Attention
1537            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1538            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1539            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1540            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1541            // MLP
1542            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1543            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1544            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1545        ])
1546    }
1547    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1548        Ok(vec![
1549            Regex::new(r"lm_head\.(weight|bias)$")?,
1550            // Attention
1551            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1552            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1553            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1554            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1555            // MLP
1556            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1557            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1558            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1559        ])
1560    }
1561}
1562
1563impl DeviceMappedModelLoader for LLaVALoader {
1564    fn mapped_max_act_size_elems(
1565        &self,
1566        config: &str,
1567        params: &AutoDeviceMapParams,
1568    ) -> Result<usize> {
1569        let AutoDeviceMapParams::Multimodal {
1570            max_seq_len,
1571            max_batch_size,
1572            max_image_shape: _,
1573            max_num_images,
1574        } = params
1575        else {
1576            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1577        };
1578
1579        let config: LLaVAConfig = serde_json::from_str(config)?;
1580
1581        let img_seq_len =
1582            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1583        let img_seq_len = img_seq_len * max_num_images;
1584
1585        let max_text_attn = {
1586            let cfg = &config.text_config;
1587            // This model injects the vision information directly into the input embeddings
1588            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1589
1590            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1591        };
1592
1593        Ok(max_text_attn)
1594    }
1595
1596    fn non_mapped_max_act_size_elems(
1597        &self,
1598        config: &str,
1599        params: &AutoDeviceMapParams,
1600    ) -> Result<usize> {
1601        let AutoDeviceMapParams::Multimodal {
1602            max_seq_len: _,
1603            max_batch_size,
1604            max_image_shape: _,
1605            max_num_images,
1606        } = params
1607        else {
1608            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1609        };
1610
1611        let config: LLaVAConfig = serde_json::from_str(config)?;
1612
1613        let img_seq_len =
1614            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1615
1616        let max_vision_attn = {
1617            (max_batch_size * max_num_images)
1618                * config.vision_config.num_attention_heads
1619                * img_seq_len
1620                * img_seq_len
1621        };
1622
1623        Ok(max_vision_attn)
1624    }
1625
1626    fn non_mapped_size_in_bytes(
1627        &self,
1628        config: &str,
1629        dtype: DType,
1630        weight_pack_factor: usize,
1631        _matformer_config: Option<&MatformerSliceConfig>,
1632    ) -> Result<usize> {
1633        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1634        let text_elems = {
1635            let cfg = &cfg.text_config;
1636            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1637            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1638            let norm = cfg.hidden_size;
1639            embed_tokens + lm_head + norm
1640        };
1641
1642        let image_newline = cfg.text_config.hidden_size;
1643        let mmproj = {
1644            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1645                + cfg.text_config.hidden_size;
1646            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1647                + cfg.text_config.hidden_size;
1648
1649            linear_1 + linear_2
1650        };
1651        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1652
1653        let elems = text_elems + image_newline + mmproj + vision_tower;
1654        Ok(elems * dtype.size_in_bytes())
1655    }
1656
1657    fn layer_sizes_in_bytes(
1658        &self,
1659        config: &str,
1660        dtype: DType,
1661        weight_pack_factor: usize,
1662        _matformer_config: Option<&MatformerSliceConfig>,
1663    ) -> Result<Vec<usize>> {
1664        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1665        let per_layer_elems = {
1666            let cfg = &cfg.text_config;
1667            let input_layernorm = cfg.hidden_size;
1668            let post_attention_layernorm = cfg.hidden_size;
1669
1670            let size_in = cfg.hidden_size;
1671            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1672            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1673            let q_proj = size_in * size_q / weight_pack_factor;
1674            let k_proj = size_in * size_kv / weight_pack_factor;
1675            let v_proj = size_in * size_kv / weight_pack_factor;
1676            let o_proj = size_q * size_in / weight_pack_factor;
1677
1678            let h_size = cfg.hidden_size;
1679            let i_size = cfg.intermediate_size;
1680            let gate_proj = h_size * i_size / weight_pack_factor;
1681            let up_proj = h_size * i_size / weight_pack_factor;
1682            let down_proj = i_size * h_size / weight_pack_factor;
1683
1684            input_layernorm
1685                + post_attention_layernorm
1686                + q_proj
1687                + k_proj
1688                + v_proj
1689                + o_proj
1690                + gate_proj
1691                + up_proj
1692                + down_proj
1693        };
1694        Ok(vec![
1695            per_layer_elems * dtype.size_in_bytes();
1696            cfg.text_config.num_hidden_layers
1697        ])
1698    }
1699
1700    fn num_layers(&self, config: &str) -> Result<usize> {
1701        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1702        Ok(cfg.text_config.num_hidden_layers)
1703    }
1704
1705    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1706        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1707        let cfg = &cfg.text_config;
1708
1709        let cfg = ModelConfigMetadata {
1710            max_seq_len: cfg.max_position_embeddings,
1711            num_layers: cfg.num_hidden_layers,
1712            hidden_size: cfg.hidden_size,
1713            num_kv_heads: cfg.num_key_value_heads,
1714            num_attn_heads: cfg.num_attention_heads,
1715            sliding_window: cfg.sliding_window,
1716            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1717            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1718            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1719        };
1720
1721        Ok(Box::new(cfg))
1722    }
1723
1724    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1725        Some(vec![NonMappedSubModel::Vision])
1726    }
1727}
1728
1729// ======================== MLlama Loader
1730
/// [`MultimodalLoader`] for a Llama Vision (MLlama) model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct VLlamaLoader;
1735
1736pub struct VLlamaPrefixer;
1737
1738impl MultimodalPromptPrefixer for VLlamaPrefixer {
1739    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1740        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1741    }
1742}
1743
1744impl MultimodalModelLoader for VLlamaLoader {
1745    fn load(
1746        &self,
1747        config: &str,
1748        vb: ShardedVarBuilder,
1749        normal_loading_metadata: NormalLoadingMetadata,
1750        attention_mechanism: AttentionImplementation,
1751    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1752        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1753        Ok(Box::new(MLlamaModel::new(
1754            &cfg,
1755            vb,
1756            self.is_gptx(config),
1757            normal_loading_metadata,
1758            attention_mechanism,
1759        )?))
1760    }
1761    fn is_gptx(&self, _config: &str) -> bool {
1762        true
1763    }
1764    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1765        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1766        Ok(Box::new(cfg))
1767    }
1768    fn get_processor(
1769        &self,
1770        _model_config: &str,
1771        _processor_config: Option<ProcessorConfig>,
1772        _preprocessor_config: PreProcessorConfig,
1773        _max_edge: Option<u32>,
1774    ) -> Arc<dyn Processor + Send + Sync> {
1775        Arc::new(MLlamaProcessor::new())
1776    }
1777    fn supports_paged_attention(&self, _config: &str) -> bool {
1778        true
1779    }
1780    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1781        true
1782    }
1783    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1784        Arc::new(VLlamaPrefixer)
1785    }
1786    fn modalities(&self, _config: &str) -> Result<Modalities> {
1787        Ok(Modalities {
1788            input: vec![SupportedModality::Text, SupportedModality::Vision],
1789            output: vec![SupportedModality::Text],
1790        })
1791    }
1792}
1793
1794impl IsqModelLoader for VLlamaLoader {
1795    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1796        let config: MLlamaConfig = serde_json::from_str(config)?;
1797        let cross_attn_layers = &config.text_config.cross_attention_layers;
1798        let transformer_layers =
1799            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1800        let mut text_regexes = Vec::new();
1801        for layer in transformer_layers {
1802            text_regexes.extend(vec![
1803                // Attention text
1804                Regex::new(&format!(
1805                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1806                ))?,
1807                Regex::new(&format!(
1808                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1809                ))?,
1810                Regex::new(&format!(
1811                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1812                ))?,
1813                Regex::new(&format!(
1814                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1815                ))?,
1816                // MLP text
1817                Regex::new(&format!(
1818                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1819                ))?,
1820                Regex::new(&format!(
1821                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1822                ))?,
1823                Regex::new(&format!(
1824                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1825                ))?,
1826            ]);
1827        }
1828        let vision_regexes = vec![
1829            // Vision attention (transformer)
1830            Regex::new(
1831                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1832            )?,
1833            Regex::new(
1834                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1835            )?,
1836            Regex::new(
1837                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1838            )?,
1839            Regex::new(
1840                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1841            )?,
1842            // Vision attention (global transforemr)
1843            Regex::new(
1844                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1845            )?,
1846            Regex::new(
1847                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1848            )?,
1849            Regex::new(
1850                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1851            )?,
1852            Regex::new(
1853                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1854            )?,
1855            // MLP vision
1856            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1857            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1858        ];
1859
1860        Ok([text_regexes, vision_regexes].concat())
1861    }
1862    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1863        self.isq_layer_regexes(config)
1864    }
1865}
1866
1867impl DeviceMappedModelLoader for VLlamaLoader {
1868    fn mapped_max_act_size_elems(
1869        &self,
1870        config: &str,
1871        params: &AutoDeviceMapParams,
1872    ) -> Result<usize> {
1873        let AutoDeviceMapParams::Multimodal {
1874            max_seq_len,
1875            max_batch_size,
1876            max_image_shape: _,
1877            max_num_images,
1878        } = params
1879        else {
1880            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1881        };
1882
1883        let config: MLlamaConfig = serde_json::from_str(config)?;
1884
1885        let img_seq_len = {
1886            let cfg = &config.vision_config;
1887            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1888            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1889            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1890        };
1891        let img_seq_len = img_seq_len * max_num_images;
1892
1893        let max_cross_text_attn = {
1894            let cfg = &config.text_config;
1895            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1896        };
1897
1898        let max_self_text_attn = {
1899            let cfg = &config.text_config;
1900            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1901        };
1902
1903        Ok(max_self_text_attn.max(max_cross_text_attn))
1904    }
1905
1906    fn non_mapped_max_act_size_elems(
1907        &self,
1908        config: &str,
1909        params: &AutoDeviceMapParams,
1910    ) -> Result<usize> {
1911        let AutoDeviceMapParams::Multimodal {
1912            max_seq_len: _,
1913            max_batch_size,
1914            max_image_shape: _,
1915            max_num_images,
1916        } = params
1917        else {
1918            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1919        };
1920
1921        let config: MLlamaConfig = serde_json::from_str(config)?;
1922
1923        let img_seq_len = {
1924            let cfg = &config.vision_config;
1925            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1926            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1927            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1928        };
1929        let max_vision_attn = {
1930            let cfg = &config.vision_config;
1931            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1932        };
1933
1934        Ok(max_vision_attn)
1935    }
1936
1937    fn non_mapped_size_in_bytes(
1938        &self,
1939        config: &str,
1940        dtype: DType,
1941        weight_pack_factor: usize,
1942        _matformer_config: Option<&MatformerSliceConfig>,
1943    ) -> Result<usize> {
1944        let config: MLlamaConfig = serde_json::from_str(config)?;
1945        let text_elems = {
1946            let cfg = &config.text_config;
1947            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1948            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1949            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1950                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1951            } else {
1952                0
1953            };
1954            let norm = cfg.hidden_size;
1955            embed_tokens + lm_head + norm
1956        };
1957
1958        let vision_elems = {
1959            let cfg = &config.vision_config;
1960
1961            let conv_cfg = Conv2dConfig {
1962                stride: cfg.patch_size,
1963                ..Default::default()
1964            };
1965            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1966                * cfg.patch_size
1967                * cfg.patch_size;
1968
1969            let class_embedding = cfg.hidden_size;
1970
1971            let gated_positional_embedding = {
1972                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1973                let embedding = num_patches * cfg.hidden_size;
1974                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1975                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1976
1977                embedding + tile_embedding
1978            };
1979
1980            let pre_tile_positional_embedding =
1981                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1982            let post_tile_positional_embedding =
1983                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1984
1985            let layernorm_pre = cfg.hidden_size;
1986            let layernorm_post = cfg.hidden_size;
1987
1988            let encoder_layer = {
1989                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1990                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1991
1992                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1993                let q_proj =
1994                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1995                let k_proj =
1996                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1997                let v_proj =
1998                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1999                let o_proj =
2000                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
2001
2002                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
2003                    + cfg.intermediate_size;
2004                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
2005                    + cfg.hidden_size;
2006
2007                input_layernorm
2008                    + post_attention_layernorm
2009                    + q_proj
2010                    + k_proj
2011                    + v_proj
2012                    + o_proj
2013                    + fc1
2014                    + fc2
2015            };
2016
2017            patch_embedding
2018                + class_embedding
2019                + gated_positional_embedding
2020                + pre_tile_positional_embedding
2021                + post_tile_positional_embedding
2022                + layernorm_pre
2023                + layernorm_post
2024                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
2025        };
2026
2027        let elems = text_elems + vision_elems;
2028        Ok(elems * dtype.size_in_bytes())
2029    }
2030
2031    fn layer_sizes_in_bytes(
2032        &self,
2033        config: &str,
2034        dtype: DType,
2035        weight_pack_factor: usize,
2036        _matformer_config: Option<&MatformerSliceConfig>,
2037    ) -> Result<Vec<usize>> {
2038        let config: MLlamaConfig = serde_json::from_str(config)?;
2039        let cfg = &config.text_config;
2040
2041        let mut layer_sizes = Vec::new();
2042
2043        for i in 0..cfg.num_hidden_layers {
2044            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
2045                // No isq for cross attention
2046                1
2047            } else {
2048                weight_pack_factor
2049            };
2050
2051            let per_layer_elems = {
2052                let input_layernorm = cfg.hidden_size;
2053                let post_attention_layernorm = cfg.hidden_size;
2054
2055                let size_in = cfg.hidden_size;
2056                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2057                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2058                let q_proj = size_in * size_q / weight_pack_factor;
2059                let k_proj = size_in * size_kv / weight_pack_factor;
2060                let v_proj = size_in * size_kv / weight_pack_factor;
2061                let o_proj = size_q * size_in / weight_pack_factor;
2062
2063                let h_size = cfg.hidden_size;
2064                let i_size = cfg.intermediate_size;
2065                let gate_proj = h_size * i_size / weight_pack_factor;
2066                let up_proj = h_size * i_size / weight_pack_factor;
2067                let down_proj = i_size * h_size / weight_pack_factor;
2068
2069                input_layernorm
2070                    + post_attention_layernorm
2071                    + q_proj
2072                    + k_proj
2073                    + v_proj
2074                    + o_proj
2075                    + gate_proj
2076                    + up_proj
2077                    + down_proj
2078            };
2079
2080            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
2081        }
2082
2083        Ok(layer_sizes)
2084    }
2085
2086    fn num_layers(&self, config: &str) -> Result<usize> {
2087        let config: MLlamaConfig = serde_json::from_str(config)?;
2088        Ok(config.text_config.num_hidden_layers)
2089    }
2090
2091    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2092        let cfg: MLlamaConfig = serde_json::from_str(config)?;
2093        let cfg = &cfg.text_config;
2094
2095        let cfg = ModelConfigMetadata {
2096            max_seq_len: cfg.max_position_embeddings,
2097            num_layers: cfg.num_hidden_layers,
2098            hidden_size: cfg.hidden_size,
2099            num_kv_heads: cfg.num_key_value_heads,
2100            num_attn_heads: cfg.num_attention_heads,
2101            sliding_window: None,
2102            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2103            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2104            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2105        };
2106
2107        Ok(Box::new(cfg))
2108    }
2109
2110    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2111        Some(vec![NonMappedSubModel::Vision])
2112    }
2113}
2114
2115// ======================== Qwen2VL Loader
2116
/// [`MultimodalLoader`] for a Qwen2-VL model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen2VLLoader;
2121
2122pub struct Qwen2VLPrefixer;
2123
2124impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2125    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2126        format!(
2127            "{}{prompt}",
2128            format!(
2129                "{}{}{}",
2130                Qwen2VLProcessor::VISION_START,
2131                Qwen2VLProcessor::IMAGE_PAD,
2132                Qwen2VLProcessor::VISION_END
2133            )
2134            .repeat(image_indexes.len())
2135        )
2136    }
2137}
2138
2139impl MultimodalModelLoader for Qwen2VLLoader {
2140    fn load(
2141        &self,
2142        config: &str,
2143        vb: ShardedVarBuilder,
2144        normal_loading_metadata: NormalLoadingMetadata,
2145        attention_mechanism: AttentionImplementation,
2146    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2147        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2148        Ok(Box::new(Qwen2VLModel::new(
2149            &cfg,
2150            vb,
2151            self.is_gptx(config),
2152            normal_loading_metadata,
2153            attention_mechanism,
2154        )?))
2155    }
2156    fn is_gptx(&self, _config: &str) -> bool {
2157        true
2158    }
2159    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2160        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2161        Ok(Box::new(config))
2162    }
2163    fn get_processor(
2164        &self,
2165        _model_config: &str,
2166        _processor_config: Option<ProcessorConfig>,
2167        _preprocessor_config: PreProcessorConfig,
2168        max_edge: Option<u32>,
2169    ) -> Arc<dyn Processor + Send + Sync> {
2170        Arc::new(Qwen2VLProcessor::new(max_edge))
2171    }
2172    fn supports_paged_attention(&self, _config: &str) -> bool {
2173        false
2174    }
2175    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2176        Arc::new(Qwen2VLPrefixer)
2177    }
2178    fn modalities(&self, _config: &str) -> Result<Modalities> {
2179        Ok(Modalities {
2180            input: vec![SupportedModality::Text, SupportedModality::Vision],
2181            output: vec![SupportedModality::Text],
2182        })
2183    }
2184}
2185
2186impl IsqModelLoader for Qwen2VLLoader {
2187    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2188        Ok(vec![
2189            Regex::new(r"lm_head\.(weight|bias)$")?,
2190            // Attention
2191            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2192            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2193            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2194            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2195            // MLP
2196            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2197            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2198            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2199        ])
2200    }
2201    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2202        self.isq_layer_regexes(config)
2203    }
2204}
2205
2206impl DeviceMappedModelLoader for Qwen2VLLoader {
2207    fn mapped_max_act_size_elems(
2208        &self,
2209        config: &str,
2210        params: &AutoDeviceMapParams,
2211    ) -> Result<usize> {
2212        let AutoDeviceMapParams::Multimodal {
2213            max_seq_len,
2214            max_batch_size,
2215            max_image_shape,
2216            max_num_images,
2217        } = params
2218        else {
2219            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2220        };
2221
2222        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2223
2224        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
2225        let img_seq_len = {
2226            let cfg = &cfg.vision_config;
2227            // grid_t is 1 for images (temporal dimension is for video only)
2228            let grid_t = 1;
2229            // After patch embedding and spatial merge, the effective grid dimensions are reduced
2230            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2231            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2232            grid_t * grid_h * grid_w * max_num_images
2233        };
2234
2235        let max_text_attn = {
2236            // This model injects the vision information directly into the input embeddings
2237            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2238            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2239        };
2240
2241        Ok(max_text_attn)
2242    }
2243
2244    fn non_mapped_max_act_size_elems(
2245        &self,
2246        config: &str,
2247        params: &AutoDeviceMapParams,
2248    ) -> Result<usize> {
2249        let AutoDeviceMapParams::Multimodal {
2250            max_seq_len: _,
2251            max_batch_size,
2252            max_image_shape,
2253            max_num_images,
2254        } = params
2255        else {
2256            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2257        };
2258
2259        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2260
2261        // For the vision encoder, before spatial merging
2262        let img_seq_len = {
2263            let cfg = &cfg.vision_config;
2264            // grid_t is 1 for images
2265            let grid_t = 1;
2266            let grid_h = max_image_shape.0 / cfg.patch_size;
2267            let grid_w = max_image_shape.1 / cfg.patch_size;
2268            grid_t * grid_h * grid_w
2269        };
2270
2271        let max_vision_attn = {
2272            let cfg = &cfg.vision_config;
2273            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2274        };
2275
2276        Ok(max_vision_attn)
2277    }
2278
2279    fn non_mapped_size_in_bytes(
2280        &self,
2281        config: &str,
2282        dtype: DType,
2283        weight_pack_factor: usize,
2284        _matformer_config: Option<&MatformerSliceConfig>,
2285    ) -> Result<usize> {
2286        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2287        let text_elems = {
2288            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2289            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2290            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2291                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2292            } else {
2293                0
2294            };
2295            let norm = cfg.hidden_size;
2296            embed_tokens + lm_head + norm
2297        };
2298
2299        let patch_merger = {
2300            let cfg = &cfg.vision_config;
2301            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2302
2303            let mlp0 = hidden_size * hidden_size + hidden_size;
2304            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2305
2306            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2307
2308            mlp0 + mlp2 + ln_q
2309        };
2310
2311        let patch_embed = {
2312            let cfg = &cfg.vision_config;
2313            let conv_cfg = Conv3dConfig {
2314                stride: cfg.patch_size,
2315                ..Default::default()
2316            };
2317            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2318            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2319                * kernel_sizes[0]
2320                * kernel_sizes[1]
2321                * kernel_sizes[2]
2322        };
2323
2324        let encoder_layer = {
2325            let cfg = &cfg.vision_config;
2326            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2327            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2328
2329            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2330            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2331            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2332            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2333
2334            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2335            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2336
2337            norm1 + norm2 + fc1 + fc2 + qkv + out
2338        };
2339
2340        let elems =
2341            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2342
2343        Ok(elems * dtype.size_in_bytes())
2344    }
2345
2346    fn layer_sizes_in_bytes(
2347        &self,
2348        config: &str,
2349        dtype: DType,
2350        weight_pack_factor: usize,
2351        _matformer_config: Option<&MatformerSliceConfig>,
2352    ) -> Result<Vec<usize>> {
2353        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2354        let per_layer_elems = {
2355            let input_layernorm = cfg.hidden_size;
2356            let post_attention_layernorm = cfg.hidden_size;
2357
2358            let size_in = cfg.hidden_size;
2359            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2360            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2361            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2362            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2363            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2364            let o_proj = size_q * size_in / weight_pack_factor;
2365
2366            let h_size = cfg.hidden_size;
2367            let i_size = cfg.intermediate_size;
2368            let gate_proj = h_size * i_size / weight_pack_factor;
2369            let up_proj = h_size * i_size / weight_pack_factor;
2370            let down_proj = i_size * h_size / weight_pack_factor;
2371
2372            input_layernorm
2373                + post_attention_layernorm
2374                + q_proj
2375                + k_proj
2376                + v_proj
2377                + o_proj
2378                + gate_proj
2379                + up_proj
2380                + down_proj
2381        };
2382        Ok(vec![
2383            per_layer_elems * dtype.size_in_bytes();
2384            cfg.num_hidden_layers
2385        ])
2386    }
2387
2388    fn num_layers(&self, config: &str) -> Result<usize> {
2389        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2390        Ok(cfg.num_hidden_layers)
2391    }
2392
2393    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2394        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2395
2396        let cfg = ModelConfigMetadata {
2397            max_seq_len: cfg.max_position_embeddings,
2398            num_layers: cfg.num_hidden_layers,
2399            hidden_size: cfg.hidden_size,
2400            num_kv_heads: cfg.num_key_value_heads,
2401            num_attn_heads: cfg.num_attention_heads,
2402            sliding_window: cfg.sliding_window,
2403            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2404            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2405            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2406        };
2407
2408        Ok(Box::new(cfg))
2409    }
2410
2411    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2412        Some(vec![NonMappedSubModel::Vision])
2413    }
2414}
2415
2416// ======================== Idefics 3 loader
2417
/// [`MultimodalLoader`] for an Idefics 3 Vision model.
///
/// Provides model loading, ISQ layer selection and device-mapping size
/// estimates for Idefics 3 checkpoints.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Idefics3Loader;
2422
2423pub struct Idefics3Prefixer;
2424
2425impl MultimodalPromptPrefixer for Idefics3Prefixer {
2426    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2427        // Chat template does it
2428        prompt.to_string()
2429    }
2430}
2431
2432impl MultimodalModelLoader for Idefics3Loader {
2433    fn load(
2434        &self,
2435        config: &str,
2436        vb: ShardedVarBuilder,
2437        normal_loading_metadata: NormalLoadingMetadata,
2438        attention_mechanism: AttentionImplementation,
2439    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2440        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2441        Ok(Box::new(Idefics3Model::new(
2442            &cfg,
2443            vb,
2444            self.is_gptx(config),
2445            normal_loading_metadata,
2446            attention_mechanism,
2447        )?))
2448    }
2449    fn is_gptx(&self, _config: &str) -> bool {
2450        true
2451    }
2452    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2453        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2454        Ok(Box::new(cfg))
2455    }
2456    fn get_processor(
2457        &self,
2458        _model_config: &str,
2459        processor_config: Option<ProcessorConfig>,
2460        preprocessor_config: PreProcessorConfig,
2461        max_edge: Option<u32>,
2462    ) -> Arc<dyn Processor + Send + Sync> {
2463        Arc::new(Idefics3Processor::new(
2464            processor_config.unwrap_or_default(),
2465            preprocessor_config,
2466            max_edge,
2467        ))
2468    }
2469    fn supports_paged_attention(&self, _config: &str) -> bool {
2470        true
2471    }
2472    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2473        true
2474    }
2475    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2476        Arc::new(Idefics3Prefixer)
2477    }
2478    fn modalities(&self, _config: &str) -> Result<Modalities> {
2479        Ok(Modalities {
2480            input: vec![SupportedModality::Text, SupportedModality::Vision],
2481            output: vec![SupportedModality::Text],
2482        })
2483    }
2484}
2485
2486impl IsqModelLoader for Idefics3Loader {
2487    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2488        Ok(vec![
2489            Regex::new(r"lm_head\.(weight|bias)$")?,
2490            // Attention
2491            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2492            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2493            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2494            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2495            // MLP
2496            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2497            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2498            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2499        ])
2500    }
2501    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2502        Ok(vec![
2503            Regex::new(r"lm_head\.(weight|bias)$")?,
2504            // Attention
2505            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2506            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2507            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2508            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2509            // MLP
2510            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2511            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2512            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2513            // // Attention (vision)
2514            // Regex::new(
2515            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2516            // )?,
2517            // Regex::new(
2518            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2519            // )?,
2520            // Regex::new(
2521            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2522            // )?,
2523            // Regex::new(
2524            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2525            // )?,
2526            // MLP (vision)
2527            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2528            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2529        ])
2530    }
2531}
2532
2533impl DeviceMappedModelLoader for Idefics3Loader {
2534    fn mapped_max_act_size_elems(
2535        &self,
2536        config: &str,
2537        params: &AutoDeviceMapParams,
2538    ) -> Result<usize> {
2539        let AutoDeviceMapParams::Multimodal {
2540            max_seq_len,
2541            max_batch_size,
2542            max_image_shape: _,
2543            max_num_images,
2544        } = params
2545        else {
2546            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2547        };
2548
2549        let cfg: Idefics3Config = serde_json::from_str(config)?;
2550
2551        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2552        let img_seq_len = (num_patches + 1) * max_num_images;
2553
2554        let max_text_attn = {
2555            // This model injects the vision information directly into the input embeddings
2556            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2557            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2558        };
2559
2560        Ok(max_text_attn)
2561    }
2562
2563    fn non_mapped_max_act_size_elems(
2564        &self,
2565        config: &str,
2566        params: &AutoDeviceMapParams,
2567    ) -> Result<usize> {
2568        let AutoDeviceMapParams::Multimodal {
2569            max_seq_len: _,
2570            max_batch_size,
2571            max_image_shape: _,
2572            max_num_images,
2573        } = params
2574        else {
2575            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2576        };
2577
2578        let cfg: Idefics3Config = serde_json::from_str(config)?;
2579
2580        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2581        let img_seq_len = num_patches + 1;
2582
2583        let max_vision_attn = {
2584            // do_image_splitting = true
2585            let images_factor = 5;
2586
2587            (max_batch_size * images_factor * max_num_images)
2588                * cfg.vision_config.num_attention_heads
2589                * img_seq_len
2590                * img_seq_len
2591        };
2592
2593        Ok(max_vision_attn)
2594    }
2595
2596    fn non_mapped_size_in_bytes(
2597        &self,
2598        config: &str,
2599        dtype: DType,
2600        weight_pack_factor: usize,
2601        _matformer_config: Option<&MatformerSliceConfig>,
2602    ) -> Result<usize> {
2603        let cfg: Idefics3Config = serde_json::from_str(config)?;
2604        let text_elems = {
2605            let cfg = &cfg.text_config;
2606
2607            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2608            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2609            let norm = cfg.hidden_size;
2610            embed_tokens + lm_head + norm
2611        };
2612
2613        let connector_elems = {
2614            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2615            let out_dim = cfg.text_config.hidden_size;
2616
2617            in_dim * out_dim
2618        };
2619
2620        let vision_transformer = {
2621            let cfg = &cfg.vision_config;
2622
2623            let post_layernorm = cfg.hidden_size;
2624
2625            let conv_config = Conv2dConfig {
2626                stride: cfg.patch_size,
2627                ..Default::default()
2628            };
2629            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2630                * cfg.patch_size
2631                * cfg.patch_size;
2632
2633            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2634            let num_patches = num_patches_per_side.pow(2);
2635            let position_embedding = num_patches * cfg.hidden_size;
2636
2637            let layer_elems = {
2638                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2639                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2640
2641                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2642                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2643
2644                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2645                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2646                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2647                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2648
2649                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2650            };
2651
2652            post_layernorm
2653                + patch_embedding
2654                + position_embedding
2655                + layer_elems * cfg.num_hidden_layers
2656        };
2657
2658        let elems = text_elems + connector_elems + vision_transformer;
2659
2660        Ok(elems * dtype.size_in_bytes())
2661    }
2662
2663    fn layer_sizes_in_bytes(
2664        &self,
2665        config: &str,
2666        dtype: DType,
2667        weight_pack_factor: usize,
2668        _matformer_config: Option<&MatformerSliceConfig>,
2669    ) -> Result<Vec<usize>> {
2670        let cfg: Idefics3Config = serde_json::from_str(config)?;
2671        let cfg = cfg.text_config;
2672        let per_layer_elems = {
2673            let input_layernorm = cfg.hidden_size;
2674            let post_attention_layernorm = cfg.hidden_size;
2675
2676            let size_in = cfg.hidden_size;
2677            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2678            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2679            let q_proj = size_in * size_q / weight_pack_factor;
2680            let k_proj = size_in * size_kv / weight_pack_factor;
2681            let v_proj = size_in * size_kv / weight_pack_factor;
2682            let o_proj = size_q * size_in / weight_pack_factor;
2683
2684            let h_size = cfg.hidden_size;
2685            let i_size = cfg.intermediate_size;
2686            let gate_proj = h_size * i_size / weight_pack_factor;
2687            let up_proj = h_size * i_size / weight_pack_factor;
2688            let down_proj = i_size * h_size / weight_pack_factor;
2689
2690            input_layernorm
2691                + post_attention_layernorm
2692                + q_proj
2693                + k_proj
2694                + v_proj
2695                + o_proj
2696                + gate_proj
2697                + up_proj
2698                + down_proj
2699        };
2700        Ok(vec![
2701            per_layer_elems * dtype.size_in_bytes();
2702            cfg.num_hidden_layers
2703        ])
2704    }
2705
2706    fn num_layers(&self, config: &str) -> Result<usize> {
2707        let cfg: Idefics3Config = serde_json::from_str(config)?;
2708        Ok(cfg.text_config.num_hidden_layers)
2709    }
2710    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2711        let cfg: Idefics3Config = serde_json::from_str(config)?;
2712        let cfg = &cfg.text_config;
2713
2714        let cfg = ModelConfigMetadata {
2715            max_seq_len: cfg.max_position_embeddings,
2716            num_layers: cfg.num_hidden_layers,
2717            hidden_size: cfg.hidden_size,
2718            num_kv_heads: cfg.num_key_value_heads,
2719            num_attn_heads: cfg.num_attention_heads,
2720            sliding_window: None,
2721            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2722            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2723            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2724        };
2725
2726        Ok(Box::new(cfg))
2727    }
2728
2729    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2730        Some(vec![NonMappedSubModel::Vision])
2731    }
2732}
2733
2734// ======================== MiniCpm-O loader
2735
/// [`MultimodalLoader`] for a MiniCpm-O model.
///
/// Provides model loading, ISQ layer selection and device-mapping size
/// estimates for MiniCpm-O checkpoints.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct MiniCpmOLoader;
2740
2741pub struct MiniCpmOPrefixer;
2742
2743impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2744    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2745        format!(
2746            "{}{prompt}",
2747            "(<image>./</image>)".repeat(image_indexes.len())
2748        )
2749    }
2750}
2751
2752impl MultimodalModelLoader for MiniCpmOLoader {
2753    fn load(
2754        &self,
2755        config: &str,
2756        vb: ShardedVarBuilder,
2757        normal_loading_metadata: NormalLoadingMetadata,
2758        attention_mechanism: AttentionImplementation,
2759    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2760        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2761        Ok(Box::new(MiniCpmOModel::new(
2762            &cfg,
2763            vb,
2764            self.is_gptx(config),
2765            normal_loading_metadata,
2766            attention_mechanism,
2767        )?))
2768    }
2769    fn is_gptx(&self, _config: &str) -> bool {
2770        true
2771    }
2772    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2773        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2774        Ok(Box::new(cfg))
2775    }
2776    fn get_processor(
2777        &self,
2778        _model_config: &str,
2779        processor_config: Option<ProcessorConfig>,
2780        preprocessor_config: PreProcessorConfig,
2781        max_edge: Option<u32>,
2782    ) -> Arc<dyn Processor + Send + Sync> {
2783        Arc::new(MiniCpmOProcessor::new(
2784            processor_config.unwrap_or_default(),
2785            preprocessor_config,
2786            max_edge,
2787        ))
2788    }
2789    fn supports_paged_attention(&self, _config: &str) -> bool {
2790        true
2791    }
2792    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2793        Arc::new(MiniCpmOPrefixer)
2794    }
2795    fn modalities(&self, _config: &str) -> Result<Modalities> {
2796        Ok(Modalities {
2797            input: vec![SupportedModality::Text, SupportedModality::Vision],
2798            output: vec![SupportedModality::Text],
2799        })
2800    }
2801}
2802
2803impl IsqModelLoader for MiniCpmOLoader {
2804    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2805        Ok(vec![
2806            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2807            // Attention
2808            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2809            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2810            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2811            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2812            // MLP
2813            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2814            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2815            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2816        ])
2817    }
2818    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2819        self.isq_layer_regexes(config)
2820    }
2821}
2822
2823impl DeviceMappedModelLoader for MiniCpmOLoader {
2824    fn mapped_max_act_size_elems(
2825        &self,
2826        config: &str,
2827        params: &AutoDeviceMapParams,
2828    ) -> Result<usize> {
2829        let AutoDeviceMapParams::Multimodal {
2830            max_seq_len,
2831            max_batch_size,
2832            max_image_shape: _,
2833            max_num_images,
2834        } = params
2835        else {
2836            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2837        };
2838
2839        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2840
2841        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2842        let img_seq_len = (num_patches + 1) * max_num_images;
2843
2844        let max_text_attn = {
2845            // This model injects the vision information directly into the input embeddings
2846            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2847            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2848        };
2849
2850        Ok(max_text_attn)
2851    }
2852
2853    fn non_mapped_max_act_size_elems(
2854        &self,
2855        config: &str,
2856        params: &AutoDeviceMapParams,
2857    ) -> Result<usize> {
2858        let AutoDeviceMapParams::Multimodal {
2859            max_seq_len: _,
2860            max_batch_size,
2861            max_image_shape: _,
2862            max_num_images,
2863        } = params
2864        else {
2865            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2866        };
2867
2868        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2869
2870        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2871        let img_seq_len = num_patches + 1;
2872
2873        let max_vision_attn = {
2874            // do_image_splitting = true
2875            let images_factor = 5;
2876
2877            (max_batch_size * images_factor * max_num_images)
2878                * cfg.vision_config.num_attention_heads
2879                * img_seq_len
2880                * img_seq_len
2881        };
2882
2883        Ok(max_vision_attn)
2884    }
2885
2886    fn non_mapped_size_in_bytes(
2887        &self,
2888        config: &str,
2889        dtype: DType,
2890        weight_pack_factor: usize,
2891        _matformer_config: Option<&MatformerSliceConfig>,
2892    ) -> Result<usize> {
2893        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2894        let text_elems = {
2895            let cfg = &cfg.text_config;
2896
2897            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2898            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2899            let norm = cfg.hidden_size;
2900            embed_tokens + lm_head + norm
2901        };
2902
2903        let vision_transformer = {
2904            let cfg = &cfg.vision_config;
2905
2906            let post_layernorm = cfg.hidden_size;
2907
2908            let conv_config = Conv2dConfig {
2909                stride: cfg.patch_size,
2910                ..Default::default()
2911            };
2912            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2913                * cfg.patch_size
2914                * cfg.patch_size;
2915
2916            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2917            let num_patches = num_patches_per_side.pow(2);
2918            let position_embedding = num_patches * cfg.hidden_size;
2919
2920            let layer_elems = {
2921                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2922                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2923
2924                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2925                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2926
2927                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2928                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2929                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2930                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2931
2932                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2933            };
2934
2935            post_layernorm
2936                + patch_embedding
2937                + position_embedding
2938                + layer_elems * cfg.num_hidden_layers
2939        };
2940
2941        let elems = text_elems + vision_transformer;
2942
2943        Ok(elems * dtype.size_in_bytes())
2944    }
2945
2946    fn layer_sizes_in_bytes(
2947        &self,
2948        config: &str,
2949        dtype: DType,
2950        weight_pack_factor: usize,
2951        _matformer_config: Option<&MatformerSliceConfig>,
2952    ) -> Result<Vec<usize>> {
2953        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2954        let cfg = cfg.text_config;
2955        let per_layer_elems = {
2956            let input_layernorm = cfg.hidden_size;
2957            let post_attention_layernorm = cfg.hidden_size;
2958
2959            let size_in = cfg.hidden_size;
2960            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2961            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2962            let q_proj = size_in * size_q / weight_pack_factor;
2963            let k_proj = size_in * size_kv / weight_pack_factor;
2964            let v_proj = size_in * size_kv / weight_pack_factor;
2965            let o_proj = size_q * size_in / weight_pack_factor;
2966
2967            let h_size = cfg.hidden_size;
2968            let i_size = cfg.intermediate_size;
2969            let gate_proj = h_size * i_size / weight_pack_factor;
2970            let up_proj = h_size * i_size / weight_pack_factor;
2971            let down_proj = i_size * h_size / weight_pack_factor;
2972
2973            input_layernorm
2974                + post_attention_layernorm
2975                + q_proj
2976                + k_proj
2977                + v_proj
2978                + o_proj
2979                + gate_proj
2980                + up_proj
2981                + down_proj
2982        };
2983        Ok(vec![
2984            per_layer_elems * dtype.size_in_bytes();
2985            cfg.num_hidden_layers
2986        ])
2987    }
2988
2989    fn num_layers(&self, config: &str) -> Result<usize> {
2990        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2991        Ok(cfg.text_config.num_hidden_layers)
2992    }
2993    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2994        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2995        let cfg = &cfg.text_config;
2996
2997        let cfg = ModelConfigMetadata {
2998            max_seq_len: cfg.max_position_embeddings,
2999            num_layers: cfg.num_hidden_layers,
3000            hidden_size: cfg.hidden_size,
3001            num_kv_heads: cfg.num_key_value_heads,
3002            num_attn_heads: cfg.num_attention_heads,
3003            sliding_window: None,
3004            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3005            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3006            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3007        };
3008
3009        Ok(Box::new(cfg))
3010    }
3011}
3012
3013// ======================== Phi 4MM loader
3014
/// [`MultimodalLoader`] for a Phi 4MM Vision model.
///
/// Phi 4MM accepts text, vision and audio inputs and produces text.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Phi4MMLoader;
3019
3020pub struct Phi4MMPrefixer;
3021
3022impl MultimodalPromptPrefixer for Phi4MMPrefixer {
3023    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3024        // Image indexing starts at 0.
3025
3026        format!(
3027            "{}{prompt}",
3028            image_indexes
3029                .into_iter()
3030                .map(|image_index| format!("<|image_{}|>", image_index + 1))
3031                .join("")
3032        )
3033    }
3034    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
3035        // Image indexing starts at 0.
3036
3037        format!(
3038            "{}{prompt}",
3039            audio_indexes
3040                .into_iter()
3041                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
3042                .join("")
3043        )
3044    }
3045}
3046
3047impl MultimodalModelLoader for Phi4MMLoader {
3048    fn load(
3049        &self,
3050        config: &str,
3051        vb: ShardedVarBuilder,
3052        normal_loading_metadata: NormalLoadingMetadata,
3053        attention_mechanism: AttentionImplementation,
3054    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3055        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
3056        Ok(Box::new(Phi4MMModel::new(
3057            &cfg,
3058            vb,
3059            self.is_gptx(config),
3060            normal_loading_metadata,
3061            attention_mechanism,
3062        )?))
3063    }
3064    fn is_gptx(&self, _config: &str) -> bool {
3065        true
3066    }
3067    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3068        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
3069        Ok(Box::new(cfg))
3070    }
3071    fn get_processor(
3072        &self,
3073        _model_config: &str,
3074        processor_config: Option<ProcessorConfig>,
3075        preprocessor_config: PreProcessorConfig,
3076        _max_edge: Option<u32>,
3077    ) -> Arc<dyn Processor + Send + Sync> {
3078        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
3079    }
3080    fn supports_paged_attention(&self, _config: &str) -> bool {
3081        true
3082    }
3083    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3084        true
3085    }
3086    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3087        Arc::new(Phi4MMPrefixer)
3088    }
3089    fn modalities(&self, _config: &str) -> Result<Modalities> {
3090        Ok(Modalities {
3091            input: vec![
3092                SupportedModality::Text,
3093                SupportedModality::Vision,
3094                SupportedModality::Audio,
3095            ],
3096            output: vec![SupportedModality::Text],
3097        })
3098    }
3099}
3100
3101impl IsqModelLoader for Phi4MMLoader {
3102    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3103        Ok(vec![
3104            Regex::new(r"lm_head\.(weight|bias)$")?,
3105            // Attention
3106            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3107            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3108            // MLP
3109            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3110            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3111        ])
3112    }
3113    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3114        self.isq_layer_regexes(config)
3115    }
3116}
3117
3118impl DeviceMappedModelLoader for Phi4MMLoader {
3119    fn mapped_max_act_size_elems(
3120        &self,
3121        config: &str,
3122        params: &AutoDeviceMapParams,
3123    ) -> Result<usize> {
3124        // NOTE: we ignore max_num_images although it can only be one...
3125        let AutoDeviceMapParams::Multimodal {
3126            max_seq_len,
3127            max_batch_size,
3128            max_image_shape: _,
3129            max_num_images,
3130        } = params
3131        else {
3132            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3133        };
3134
3135        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3136
3137        let vcfg = &PHI4_MM_VISION_CFG;
3138
3139        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3140        let img_seq_len = (num_patches + 1) * max_num_images;
3141
3142        let max_text_attn = {
3143            // This model injects the vision information directly into the input embeddings
3144            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3145            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3146        };
3147
3148        Ok(max_text_attn)
3149    }
3150
3151    fn non_mapped_max_act_size_elems(
3152        &self,
3153        _config: &str,
3154        params: &AutoDeviceMapParams,
3155    ) -> Result<usize> {
3156        let AutoDeviceMapParams::Multimodal {
3157            max_seq_len: _,
3158            max_batch_size,
3159            max_image_shape,
3160            max_num_images,
3161        } = params
3162        else {
3163            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3164        };
3165
3166        let vcfg = &PHI4_MM_VISION_CFG;
3167
3168        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3169        let img_seq_len = num_patches + 1;
3170
3171        let max_batch_size = max_batch_size
3172            * (max_image_shape
3173                .0
3174                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3175                * max_image_shape
3176                    .1
3177                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3178                + 1);
3179
3180        let max_vision_attn = (max_batch_size * max_num_images)
3181            * vcfg.num_attention_heads
3182            * img_seq_len
3183            * img_seq_len;
3184        let max_qkv = 3
3185            * (max_batch_size
3186                * vcfg.num_attention_heads
3187                * img_seq_len
3188                * (vcfg.hidden_size / vcfg.num_attention_heads));
3189
3190        Ok(max_vision_attn + max_qkv)
3191    }
3192
3193    fn non_mapped_size_in_bytes(
3194        &self,
3195        config: &str,
3196        dtype: DType,
3197        weight_pack_factor: usize,
3198        _matformer_config: Option<&MatformerSliceConfig>,
3199    ) -> Result<usize> {
3200        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3201        let elems = {
3202            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3203            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3204            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3205                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3206            } else {
3207                0
3208            };
3209            let norm = cfg.hidden_size;
3210
3211            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3212                let projection_cls = img_embed
3213                    .projection_cls
3214                    .clone()
3215                    .unwrap_or("linear".to_string());
3216                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3217                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3218                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3219
3220                let proj = match (projection_cls.as_str(), use_hd_transform) {
3221                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3222                    ("mlp", true) => {
3223                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3224                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3225                        a + b
3226                    }
3227                    ("mlp", false) => {
3228                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3229                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3230                        a + b
3231                    }
3232                    _ => {
3233                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3234                    }
3235                };
3236
3237                let (glb_gn, sub_gn) = if with_learnable_separator {
3238                    let glb_gn = image_dim_out * 4;
3239                    let sub_gn = image_dim_out * 4;
3240                    (glb_gn, sub_gn)
3241                } else {
3242                    (0, 0)
3243                };
3244
3245                let vision_transformer = {
3246                    let cfg = &PHI4_MM_VISION_CFG;
3247
3248                    let post_layernorm = cfg.hidden_size;
3249
3250                    let conv_config = Conv2dConfig {
3251                        stride: cfg.patch_size,
3252                        ..Default::default()
3253                    };
3254                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3255                        * cfg.patch_size
3256                        * cfg.patch_size;
3257
3258                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3259                    let num_patches = num_patches_per_side.pow(2);
3260                    let position_embedding = num_patches * cfg.hidden_size;
3261
3262                    let layer_elems = {
3263                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3264                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3265
3266                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3267                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3268
3269                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3270                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3271                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3272                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3273
3274                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3275                    };
3276
3277                    post_layernorm
3278                        + patch_embedding
3279                        + position_embedding
3280                        + layer_elems * cfg.num_hidden_layers
3281                };
3282
3283                proj + glb_gn + sub_gn + vision_transformer
3284            } else {
3285                0
3286            };
3287
3288            embed_tokens + lm_head + norm + image_embed
3289        };
3290
3291        Ok(elems * dtype.size_in_bytes())
3292    }
3293
3294    fn layer_sizes_in_bytes(
3295        &self,
3296        config: &str,
3297        dtype: DType,
3298        weight_pack_factor: usize,
3299        _matformer_config: Option<&MatformerSliceConfig>,
3300    ) -> Result<Vec<usize>> {
3301        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3302        let per_layer_elems = {
3303            let input_layernorm = cfg.hidden_size;
3304            let post_attention_layernorm = cfg.hidden_size;
3305
3306            let size_in = cfg.hidden_size;
3307            let head_dim = cfg.head_dim();
3308            let op_size =
3309                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3310            let qkv_proj = size_in * op_size / weight_pack_factor;
3311            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3312
3313            let h_size = cfg.hidden_size;
3314            let i_size = cfg.intermediate_size;
3315            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3316            let down_proj = h_size * i_size / weight_pack_factor;
3317
3318            input_layernorm
3319                + post_attention_layernorm
3320                + qkv_proj
3321                + o_proj
3322                + gate_up_proj
3323                + down_proj
3324        };
3325        Ok(vec![
3326            per_layer_elems * dtype.size_in_bytes();
3327            cfg.num_hidden_layers
3328        ])
3329    }
3330
3331    fn num_layers(&self, config: &str) -> Result<usize> {
3332        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3333        Ok(cfg.num_hidden_layers)
3334    }
3335
3336    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3337        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3338
3339        let cfg = ModelConfigMetadata {
3340            max_seq_len: cfg.max_position_embeddings,
3341            num_layers: cfg.num_hidden_layers,
3342            hidden_size: cfg.hidden_size,
3343            num_kv_heads: cfg.num_key_value_heads(),
3344            num_attn_heads: cfg.num_attention_heads,
3345            sliding_window: cfg.sliding_window,
3346            k_head_dim: cfg.head_dim(),
3347            v_head_dim: cfg.head_dim(),
3348            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3349        };
3350
3351        Ok(Box::new(cfg))
3352    }
3353
3354    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3355        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3356    }
3357}
3358
3359// ======================== Qwen2_5VL Loader
3360
/// [`MultimodalLoader`] for a Qwen2.5-VL model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen2_5VLLoader;

/// Prompt prefixer for Qwen2.5-VL: prepends one vision-start / image-pad / vision-end
/// placeholder group per attached image.
pub struct Qwen2_5VLPrefixer;
3367
3368impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3369    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3370        format!(
3371            "{}{prompt}",
3372            format!(
3373                "{}{}{}",
3374                Qwen2_5VLProcessor::VISION_START,
3375                Qwen2_5VLProcessor::IMAGE_PAD,
3376                Qwen2_5VLProcessor::VISION_END
3377            )
3378            .repeat(image_indexes.len())
3379        )
3380    }
3381}
3382
3383impl MultimodalModelLoader for Qwen2_5VLLoader {
3384    fn load(
3385        &self,
3386        config: &str,
3387        vb: ShardedVarBuilder,
3388        normal_loading_metadata: NormalLoadingMetadata,
3389        attention_mechanism: AttentionImplementation,
3390    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3391        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3392        Ok(Box::new(Qwen2_5VLModel::new(
3393            &cfg,
3394            vb,
3395            self.is_gptx(config),
3396            normal_loading_metadata,
3397            attention_mechanism,
3398        )?))
3399    }
3400    fn is_gptx(&self, _config: &str) -> bool {
3401        true
3402    }
3403    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3404        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3405        Ok(Box::new(config))
3406    }
3407    fn get_processor(
3408        &self,
3409        _model_config: &str,
3410        _processor_config: Option<ProcessorConfig>,
3411        _preprocessor_config: PreProcessorConfig,
3412        max_edge: Option<u32>,
3413    ) -> Arc<dyn Processor + Send + Sync> {
3414        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3415    }
3416    fn supports_paged_attention(&self, _config: &str) -> bool {
3417        false
3418    }
3419    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3420        Arc::new(Qwen2_5VLPrefixer)
3421    }
3422    fn modalities(&self, _config: &str) -> Result<Modalities> {
3423        Ok(Modalities {
3424            input: vec![SupportedModality::Text, SupportedModality::Vision],
3425            output: vec![SupportedModality::Text],
3426        })
3427    }
3428}
3429
3430impl IsqModelLoader for Qwen2_5VLLoader {
3431    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3432        Ok(vec![
3433            Regex::new(r"lm_head\.(weight|bias)$")?,
3434            // Attention
3435            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3436            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3437            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3438            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3439            // MLP
3440            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3441            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3442            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3443        ])
3444    }
3445    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3446        self.isq_layer_regexes(config)
3447    }
3448}
3449
3450impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3451    fn mapped_max_act_size_elems(
3452        &self,
3453        config: &str,
3454        params: &AutoDeviceMapParams,
3455    ) -> Result<usize> {
3456        let AutoDeviceMapParams::Multimodal {
3457            max_seq_len,
3458            max_batch_size,
3459            max_image_shape,
3460            max_num_images,
3461        } = params
3462        else {
3463            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3464        };
3465
3466        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3467
3468        let img_seq_len = {
3469            let cfg = &cfg.vision_config;
3470            let grid_t = max_num_images / cfg.temporal_patch_size;
3471            let grid_h = max_image_shape.0 / cfg.patch_size;
3472            let grid_w = max_image_shape.1 / cfg.patch_size;
3473            grid_t * grid_h * grid_w
3474        };
3475        let img_seq_len = img_seq_len * max_num_images;
3476
3477        let max_text_attn = {
3478            // This model injects the vision information directly into the input embeddings
3479            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3480            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3481        };
3482
3483        Ok(max_text_attn)
3484    }
3485
3486    fn non_mapped_max_act_size_elems(
3487        &self,
3488        config: &str,
3489        params: &AutoDeviceMapParams,
3490    ) -> Result<usize> {
3491        let AutoDeviceMapParams::Multimodal {
3492            max_seq_len: _,
3493            max_batch_size,
3494            max_image_shape,
3495            max_num_images,
3496        } = params
3497        else {
3498            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3499        };
3500
3501        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3502
3503        let img_seq_len = {
3504            let cfg = &cfg.vision_config;
3505            let grid_t = max_num_images / cfg.temporal_patch_size;
3506            let grid_h = max_image_shape.0 / cfg.patch_size;
3507            let grid_w = max_image_shape.1 / cfg.patch_size;
3508            grid_t * grid_h * grid_w
3509        };
3510
3511        let max_vision_attn = {
3512            let cfg = &cfg.vision_config;
3513            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3514        };
3515
3516        Ok(max_vision_attn)
3517    }
3518
3519    fn non_mapped_size_in_bytes(
3520        &self,
3521        config: &str,
3522        dtype: DType,
3523        weight_pack_factor: usize,
3524        _matformer_config: Option<&MatformerSliceConfig>,
3525    ) -> Result<usize> {
3526        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3527        let text_elems = {
3528            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3529            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3530            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3531                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3532            } else {
3533                0
3534            };
3535            let norm = cfg.hidden_size;
3536            embed_tokens + lm_head + norm
3537        };
3538
3539        let patch_merger = {
3540            let cfg = &cfg.vision_config;
3541            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3542
3543            let mlp0 = hidden_size * hidden_size + hidden_size;
3544            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3545
3546            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3547
3548            mlp0 + mlp2 + ln_q
3549        };
3550
3551        let patch_embed = {
3552            let cfg = &cfg.vision_config;
3553            let conv_cfg = Conv3dConfig {
3554                stride: cfg.patch_size,
3555                ..Default::default()
3556            };
3557            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3558            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3559                * kernel_sizes[0]
3560                * kernel_sizes[1]
3561                * kernel_sizes[2]
3562        };
3563
3564        let encoder_layer = {
3565            let cfg = &cfg.vision_config;
3566            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3567            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3568
3569            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3570            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3571            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3572
3573            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3574            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3575
3576            norm1 + norm2 + fc1 + fc2 + qkv + out
3577        };
3578
3579        let elems =
3580            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3581
3582        Ok(elems * dtype.size_in_bytes())
3583    }
3584
3585    fn layer_sizes_in_bytes(
3586        &self,
3587        config: &str,
3588        dtype: DType,
3589        weight_pack_factor: usize,
3590        _matformer_config: Option<&MatformerSliceConfig>,
3591    ) -> Result<Vec<usize>> {
3592        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3593        let per_layer_elems = {
3594            let input_layernorm = cfg.hidden_size;
3595            let post_attention_layernorm = cfg.hidden_size;
3596
3597            let size_in = cfg.hidden_size;
3598            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3599            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3600            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3601            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3602            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3603            let o_proj = size_q * size_in / weight_pack_factor;
3604
3605            let h_size = cfg.hidden_size;
3606            let i_size = cfg.intermediate_size;
3607            let gate_proj = h_size * i_size / weight_pack_factor;
3608            let up_proj = h_size * i_size / weight_pack_factor;
3609            let down_proj = i_size * h_size / weight_pack_factor;
3610
3611            input_layernorm
3612                + post_attention_layernorm
3613                + q_proj
3614                + k_proj
3615                + v_proj
3616                + o_proj
3617                + gate_proj
3618                + up_proj
3619                + down_proj
3620        };
3621        Ok(vec![
3622            per_layer_elems * dtype.size_in_bytes();
3623            cfg.num_hidden_layers
3624        ])
3625    }
3626
3627    fn num_layers(&self, config: &str) -> Result<usize> {
3628        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3629        Ok(cfg.num_hidden_layers)
3630    }
3631
3632    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3633        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3634
3635        let cfg = ModelConfigMetadata {
3636            max_seq_len: cfg.max_position_embeddings,
3637            num_layers: cfg.num_hidden_layers,
3638            hidden_size: cfg.hidden_size,
3639            num_kv_heads: cfg.num_key_value_heads,
3640            num_attn_heads: cfg.num_attention_heads,
3641            sliding_window: cfg.sliding_window,
3642            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3643            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3644            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3645        };
3646
3647        Ok(Box::new(cfg))
3648    }
3649
3650    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3651        Some(vec![NonMappedSubModel::Vision])
3652    }
3653}
3654
3655// ======================== Gemma 3 Loader
3656
/// [`MultimodalLoader`] for a Gemma 3 model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Gemma3Loader;

/// Prompt prefixer for Gemma 3; it returns prompts unchanged.
pub struct Gemma3Prefixer;
3663
3664impl MultimodalPromptPrefixer for Gemma3Prefixer {
3665    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3666        prompt.to_string()
3667    }
3668}
3669
3670impl MultimodalModelLoader for Gemma3Loader {
3671    fn load(
3672        &self,
3673        config: &str,
3674        vb: ShardedVarBuilder,
3675        normal_loading_metadata: NormalLoadingMetadata,
3676        attention_mechanism: AttentionImplementation,
3677    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3678        let cfg: Gemma3Config = serde_json::from_str(config)?;
3679        Ok(Box::new(Gemma3Model::new(
3680            &cfg,
3681            vb,
3682            self.is_gptx(config),
3683            normal_loading_metadata,
3684            attention_mechanism,
3685        )?))
3686    }
3687    fn is_gptx(&self, _config: &str) -> bool {
3688        true
3689    }
3690    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3691        let config: Gemma3Config = serde_json::from_str(config)?;
3692        Ok(Box::new(config))
3693    }
3694    fn get_processor(
3695        &self,
3696        config: &str,
3697        processor_config: Option<ProcessorConfig>,
3698        _preprocessor_config: PreProcessorConfig,
3699        _max_edge: Option<u32>,
3700    ) -> Arc<dyn Processor + Send + Sync> {
3701        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3702        // Handle the Gemma 3 1b case here
3703        Arc::new(Gemma3Processor::new(
3704            processor_config.unwrap_or_default(),
3705            matches!(config, Gemma3Config::WithVision { .. }),
3706        ))
3707    }
3708    fn supports_paged_attention(&self, _config: &str) -> bool {
3709        true
3710    }
3711    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3712        true
3713    }
3714    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3715        Arc::new(Gemma3Prefixer)
3716    }
3717    fn modalities(&self, _config: &str) -> Result<Modalities> {
3718        Ok(Modalities {
3719            input: vec![SupportedModality::Text, SupportedModality::Vision],
3720            output: vec![SupportedModality::Text],
3721        })
3722    }
3723}
3724
3725impl IsqModelLoader for Gemma3Loader {
3726    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3727        Ok(vec![
3728            Regex::new(r"lm_head\.(weight|bias)$")?,
3729            // Attention
3730            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3731            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3732            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3733            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3734            // MLP
3735            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3736            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3737            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3738        ])
3739    }
3740    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3741        Ok(vec![
3742            Regex::new(r"lm_head\.(weight|bias)$")?,
3743            // Attention
3744            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3745            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3746            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3747            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3748            // MLP
3749            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3750            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3751            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3752        ])
3753    }
3754}
3755
3756impl DeviceMappedModelLoader for Gemma3Loader {
3757    fn mapped_max_act_size_elems(
3758        &self,
3759        config: &str,
3760        params: &AutoDeviceMapParams,
3761    ) -> Result<usize> {
3762        let AutoDeviceMapParams::Multimodal {
3763            max_seq_len,
3764            max_batch_size,
3765            max_image_shape: _,
3766            max_num_images,
3767        } = params
3768        else {
3769            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3770        };
3771
3772        let cfg: Gemma3Config = serde_json::from_str(config)?;
3773
3774        match cfg {
3775            Gemma3Config::Text(text_config) => Ok(max_batch_size
3776                * text_config.num_attention_heads
3777                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3778            Gemma3Config::WithVision {
3779                text_config,
3780                vision_config,
3781                ..
3782            } => {
3783                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3784                let img_seq_len = (num_patches + 1) * max_num_images;
3785
3786                let max_text_attn = {
3787                    // This model injects the vision information directly into the input embeddings
3788                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3789                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3790                };
3791                Ok(max_text_attn)
3792            }
3793        }
3794    }
3795
3796    fn non_mapped_max_act_size_elems(
3797        &self,
3798        config: &str,
3799        params: &AutoDeviceMapParams,
3800    ) -> Result<usize> {
3801        let AutoDeviceMapParams::Multimodal {
3802            max_seq_len: _,
3803            max_batch_size,
3804            max_image_shape: _,
3805            max_num_images,
3806        } = params
3807        else {
3808            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3809        };
3810
3811        let cfg: Gemma3Config = serde_json::from_str(config)?;
3812
3813        match cfg {
3814            Gemma3Config::WithVision { vision_config, .. } => {
3815                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3816                let img_seq_len = num_patches + 1;
3817
3818                let max_vision_attn = {
3819                    (max_batch_size * max_num_images)
3820                        * vision_config.num_attention_heads
3821                        * img_seq_len
3822                        * img_seq_len
3823                };
3824
3825                Ok(max_vision_attn)
3826            }
3827            Gemma3Config::Text(_) => Ok(0),
3828        }
3829    }
3830
3831    fn non_mapped_size_in_bytes(
3832        &self,
3833        config: &str,
3834        dtype: DType,
3835        weight_pack_factor: usize,
3836        _matformer_config: Option<&MatformerSliceConfig>,
3837    ) -> Result<usize> {
3838        let cfg: Gemma3Config = serde_json::from_str(config)?;
3839
3840        let text_elems = {
3841            let cfg = match &cfg {
3842                Gemma3Config::Text(cfg) => cfg,
3843                Gemma3Config::WithVision { text_config, .. } => text_config,
3844            };
3845            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3846            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3847            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3848                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3849            } else {
3850                0
3851            };
3852            let norm = cfg.hidden_size;
3853            embed_tokens + lm_head + norm
3854        };
3855
3856        let vision_transformer = if let Gemma3Config::WithVision {
3857            vision_config: cfg, ..
3858        } = &cfg
3859        {
3860            let post_layernorm = cfg.hidden_size;
3861
3862            let conv_config = Conv2dConfig {
3863                stride: cfg.patch_size,
3864                ..Default::default()
3865            };
3866            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3867                * cfg.patch_size
3868                * cfg.patch_size;
3869
3870            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3871            let num_patches = num_patches_per_side.pow(2);
3872            let position_embedding = num_patches * cfg.hidden_size;
3873
3874            let layer_elems = {
3875                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3876                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3877
3878                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3879                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3880
3881                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3882                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3883                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3884                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3885
3886                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3887            };
3888
3889            post_layernorm
3890                + patch_embedding
3891                + position_embedding
3892                + layer_elems * cfg.num_hidden_layers
3893        } else {
3894            0
3895        };
3896
3897        let elems = text_elems + vision_transformer;
3898
3899        Ok(elems * dtype.size_in_bytes())
3900    }
3901
3902    fn layer_sizes_in_bytes(
3903        &self,
3904        config: &str,
3905        dtype: DType,
3906        weight_pack_factor: usize,
3907        _matformer_config: Option<&MatformerSliceConfig>,
3908    ) -> Result<Vec<usize>> {
3909        let cfg: Gemma3Config = serde_json::from_str(config)?;
3910
3911        let txt_cfg = match &cfg {
3912            Gemma3Config::Text(cfg) => cfg,
3913            Gemma3Config::WithVision { text_config, .. } => text_config,
3914        };
3915        let per_layer_elems = {
3916            let cfg = txt_cfg;
3917
3918            let input_layernorm = cfg.hidden_size;
3919            let post_attention_layernorm = cfg.hidden_size;
3920
3921            let size_in = cfg.hidden_size;
3922            let size_q = cfg.head_dim * cfg.num_attention_heads;
3923            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3924            let q_proj =
3925                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3926            let k_proj =
3927                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3928            let v_proj =
3929                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3930            let o_proj =
3931                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3932
3933            let h_size = cfg.hidden_size;
3934            let i_size = cfg.intermediate_size;
3935            let gate_proj = h_size * i_size / weight_pack_factor;
3936            let up_proj = h_size * i_size / weight_pack_factor;
3937            let down_proj = i_size * h_size / weight_pack_factor;
3938
3939            input_layernorm
3940                + post_attention_layernorm
3941                + q_proj
3942                + k_proj
3943                + v_proj
3944                + o_proj
3945                + gate_proj
3946                + up_proj
3947                + down_proj
3948        };
3949        Ok(vec![
3950            per_layer_elems * dtype.size_in_bytes();
3951            txt_cfg.num_hidden_layers
3952        ])
3953    }
3954
3955    fn num_layers(&self, config: &str) -> Result<usize> {
3956        let cfg: Gemma3Config = serde_json::from_str(config)?;
3957
3958        let txt_cfg = match &cfg {
3959            Gemma3Config::Text(cfg) => cfg,
3960            Gemma3Config::WithVision { text_config, .. } => text_config,
3961        };
3962
3963        Ok(txt_cfg.num_hidden_layers)
3964    }
3965
3966    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3967        let cfg: Gemma3Config = serde_json::from_str(config)?;
3968
3969        let cfg = match &cfg {
3970            Gemma3Config::Text(cfg) => cfg,
3971            Gemma3Config::WithVision { text_config, .. } => text_config,
3972        };
3973
3974        let cfg = ModelConfigMetadata {
3975            max_seq_len: cfg.max_position_embeddings,
3976            num_layers: cfg.num_hidden_layers,
3977            hidden_size: cfg.hidden_size,
3978            num_kv_heads: cfg.num_key_value_heads,
3979            num_attn_heads: cfg.num_attention_heads,
3980            sliding_window: None, // None to be more forgiving, some do not
3981            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3982            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3983            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3984        };
3985
3986        Ok(Box::new(cfg))
3987    }
3988
3989    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3990        Some(vec![NonMappedSubModel::Vision])
3991    }
3992}
3993
3994// ======================== Mistral 3 Loader
3995
/// [`MultimodalLoader`] for a Mistral 3 model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Mistral3Loader;
4000
4001pub struct Mistral3Prefixer;
4002
4003impl MultimodalPromptPrefixer for Mistral3Prefixer {
4004    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4005        prompt.to_string()
4006    }
4007}
4008
4009impl MultimodalModelLoader for Mistral3Loader {
4010    fn load(
4011        &self,
4012        config: &str,
4013        vb: ShardedVarBuilder,
4014        normal_loading_metadata: NormalLoadingMetadata,
4015        attention_mechanism: AttentionImplementation,
4016    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4017        let mut cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
4018        cfg.propagate_quantization_config();
4019        Ok(Box::new(Mistral3Model::new(
4020            &cfg,
4021            vb,
4022            self.is_gptx(config),
4023            normal_loading_metadata,
4024            attention_mechanism,
4025        )?))
4026    }
4027    fn is_gptx(&self, _config: &str) -> bool {
4028        true
4029    }
4030    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4031        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
4032        Ok(Box::new(cfg))
4033    }
4034    fn get_processor(
4035        &self,
4036        _model_config: &str,
4037        processor_config: Option<ProcessorConfig>,
4038        _preprocessor_config: PreProcessorConfig,
4039        _max_edge: Option<u32>,
4040    ) -> Arc<dyn Processor + Send + Sync> {
4041        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
4042    }
4043    fn supports_paged_attention(&self, _config: &str) -> bool {
4044        true
4045    }
4046    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4047        true
4048    }
4049    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4050        Arc::new(Mistral3Prefixer)
4051    }
4052    fn modalities(&self, _config: &str) -> Result<Modalities> {
4053        Ok(Modalities {
4054            input: vec![SupportedModality::Text, SupportedModality::Vision],
4055            output: vec![SupportedModality::Text],
4056        })
4057    }
4058}
4059
4060impl IsqModelLoader for Mistral3Loader {
4061    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4062        Ok(vec![
4063            Regex::new(r"lm_head\.(weight|bias)$")?,
4064            // Attention
4065            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4066            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4067            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4068            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4069            // MLP
4070            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4071            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4072            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4073        ])
4074    }
4075    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4076        Ok(vec![
4077            Regex::new(r"lm_head\.(weight|bias)$")?,
4078            // Attention
4079            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4080            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4081            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4082            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4083            // MLP
4084            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4085            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4086            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4087        ])
4088    }
4089}
4090
4091#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4092impl DeviceMappedModelLoader for Mistral3Loader {
4093    fn mapped_max_act_size_elems(
4094        &self,
4095        config: &str,
4096        params: &AutoDeviceMapParams,
4097    ) -> Result<usize> {
4098        let cfg: Mistral3Config = serde_json::from_str(config)?;
4099        let vcfg = &cfg.vision_config;
4100        let tcfg = &cfg.text_config;
4101
4102        let AutoDeviceMapParams::Multimodal {
4103            max_seq_len,
4104            max_batch_size,
4105            max_image_shape: (mut height, mut width),
4106            max_num_images,
4107        } = params
4108        else {
4109            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4110        };
4111
4112        let img_seq_len = {
4113            // Reshaping algorithm
4114
4115            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4116            let (max_height, max_width) = (1540, 1540);
4117            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4118            if ratio > 1. {
4119                height = (height as f64 / ratio).floor() as usize;
4120                width = (width as f64 / ratio).floor() as usize;
4121            }
4122
4123            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4124            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4125
4126            height = num_height_tokens * vcfg.patch_size;
4127            width = num_width_tokens * vcfg.patch_size;
4128
4129            let num_height_tokens = height / vcfg.patch_size;
4130            let num_width_tokens = width / vcfg.patch_size;
4131
4132            (num_width_tokens + 1) * num_height_tokens
4133        };
4134
4135        // This model injects the vision information directly into the input embeddings
4136        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4137        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4138    }
4139
4140    fn non_mapped_max_act_size_elems(
4141        &self,
4142        config: &str,
4143        params: &AutoDeviceMapParams,
4144    ) -> Result<usize> {
4145        let cfg: Mistral3Config = serde_json::from_str(config)?;
4146        let cfg = &cfg.vision_config;
4147
4148        let AutoDeviceMapParams::Multimodal {
4149            max_seq_len: _,
4150            max_batch_size,
4151            max_image_shape: (mut height, mut width),
4152            max_num_images,
4153        } = params
4154        else {
4155            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4156        };
4157
4158        let img_seq_len = {
4159            // Reshaping algorithm
4160
4161            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4162            let (max_height, max_width) = (1540, 1540);
4163            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4164            if ratio > 1. {
4165                height = (height as f64 / ratio).floor() as usize;
4166                width = (width as f64 / ratio).floor() as usize;
4167            }
4168
4169            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4170            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4171
4172            height = num_height_tokens * cfg.patch_size;
4173            width = num_width_tokens * cfg.patch_size;
4174
4175            let num_height_tokens = height / cfg.patch_size;
4176            let num_width_tokens = width / cfg.patch_size;
4177
4178            (num_width_tokens + 1) * num_height_tokens
4179        };
4180
4181        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4182    }
4183
4184    fn non_mapped_size_in_bytes(
4185        &self,
4186        config: &str,
4187        dtype: DType,
4188        weight_pack_factor: usize,
4189        _matformer_config: Option<&MatformerSliceConfig>,
4190    ) -> Result<usize> {
4191        let cfg: Mistral3Config = serde_json::from_str(config)?;
4192
4193        let text_elems = {
4194            let cfg = &cfg.text_config;
4195
4196            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4197            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4198            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4199                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4200            } else {
4201                0
4202            };
4203            let norm = cfg.hidden_size;
4204            embed_tokens + lm_head + norm
4205        };
4206
4207        let vision_elems = {
4208            let cfg = &cfg.vision_config;
4209
4210            let patch_embed = {
4211                let conv_cfg = Conv2dConfig {
4212                    stride: cfg.patch_size,
4213                    ..Default::default()
4214                };
4215                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4216                    * cfg.patch_size
4217                    * cfg.patch_size
4218                    * cfg.patch_size
4219            };
4220            let ln_pre = cfg.hidden_size;
4221            let vision_layer = {
4222                let attn_norm = cfg.hidden_size;
4223                let ffn_norm = cfg.hidden_size;
4224
4225                let gate = cfg.hidden_size * cfg.intermediate_size;
4226                let up = cfg.hidden_size * cfg.intermediate_size;
4227                let down = cfg.hidden_size * cfg.intermediate_size;
4228
4229                let q = cfg.hidden_size * cfg.hidden_size;
4230                let k = cfg.hidden_size * cfg.hidden_size;
4231                let v = cfg.hidden_size * cfg.hidden_size;
4232                let o = cfg.hidden_size * cfg.hidden_size;
4233
4234                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4235            };
4236
4237            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4238        };
4239
4240        let elems = text_elems + vision_elems;
4241
4242        Ok(elems * dtype.size_in_bytes())
4243    }
4244
4245    fn layer_sizes_in_bytes(
4246        &self,
4247        config: &str,
4248        dtype: DType,
4249        weight_pack_factor: usize,
4250        _matformer_config: Option<&MatformerSliceConfig>,
4251    ) -> Result<Vec<usize>> {
4252        let cfg: Mistral3Config = serde_json::from_str(config)?;
4253        let cfg = &cfg.text_config;
4254
4255        let per_layer_elems = {
4256            let input_layernorm = cfg.hidden_size;
4257            let post_attention_layernorm = cfg.hidden_size;
4258
4259            let size_in = cfg.hidden_size;
4260            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4261            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4262            let q_proj = size_in * size_q / weight_pack_factor;
4263            let k_proj = size_in * size_kv / weight_pack_factor;
4264            let v_proj = size_in * size_kv / weight_pack_factor;
4265            let o_proj = size_q * size_in / weight_pack_factor;
4266
4267            let h_size = cfg.hidden_size;
4268            let i_size = cfg.intermediate_size;
4269            let gate_proj = h_size * i_size / weight_pack_factor;
4270            let up_proj = h_size * i_size / weight_pack_factor;
4271            let down_proj = i_size * h_size / weight_pack_factor;
4272
4273            input_layernorm
4274                + post_attention_layernorm
4275                + q_proj
4276                + k_proj
4277                + v_proj
4278                + o_proj
4279                + gate_proj
4280                + up_proj
4281                + down_proj
4282        };
4283        Ok(vec![
4284            per_layer_elems * dtype.size_in_bytes();
4285            cfg.num_hidden_layers
4286        ])
4287    }
4288
4289    fn num_layers(&self, config: &str) -> Result<usize> {
4290        let cfg: Mistral3Config = serde_json::from_str(config)?;
4291        let cfg = &cfg.text_config;
4292        Ok(cfg.num_hidden_layers)
4293    }
4294
4295    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4296        let cfg: Mistral3Config = serde_json::from_str(config)?;
4297        let cfg = &cfg.text_config;
4298
4299        let cfg = ModelConfigMetadata {
4300            max_seq_len: cfg.max_position_embeddings,
4301            num_layers: cfg.num_hidden_layers,
4302            hidden_size: cfg.hidden_size,
4303            num_kv_heads: cfg.num_key_value_heads,
4304            num_attn_heads: cfg.num_attention_heads,
4305            sliding_window: cfg.sliding_window,
4306            k_head_dim: cfg.head_dim(),
4307            v_head_dim: cfg.head_dim(),
4308            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4309        };
4310
4311        Ok(Box::new(cfg))
4312    }
4313
4314    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4315        Some(vec![NonMappedSubModel::Vision])
4316    }
4317}
4318
4319// ======================== Llama 4 Loader
4320
/// [`MultimodalLoader`] for a Llama 4 (vision-capable) model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct VLlama4Loader;
4325
4326pub struct VLlama4Prefixer;
4327
4328impl MultimodalPromptPrefixer for VLlama4Prefixer {
4329    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4330        format!(
4331            "{}{prompt}",
4332            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4333        )
4334    }
4335}
4336
4337impl MultimodalModelLoader for VLlama4Loader {
4338    fn load(
4339        &self,
4340        config: &str,
4341        vb: ShardedVarBuilder,
4342        normal_loading_metadata: NormalLoadingMetadata,
4343        attention_mechanism: AttentionImplementation,
4344    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4345        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4346        cfg.propagate_quantization_config();
4347        Ok(Box::new(Llama4Model::new(
4348            &cfg,
4349            vb,
4350            self.is_gptx(config),
4351            normal_loading_metadata,
4352            attention_mechanism,
4353        )?))
4354    }
4355    fn is_gptx(&self, _config: &str) -> bool {
4356        false
4357    }
4358    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4359        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4360        cfg.propagate_quantization_config();
4361        Ok(Box::new(cfg))
4362    }
4363    fn get_processor(
4364        &self,
4365        _model_config: &str,
4366        processor_config: Option<ProcessorConfig>,
4367        _preprocessor_config: PreProcessorConfig,
4368        _max_edge: Option<u32>,
4369    ) -> Arc<dyn Processor + Send + Sync> {
4370        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4371    }
4372    fn supports_paged_attention(&self, _config: &str) -> bool {
4373        true
4374    }
4375    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4376        Arc::new(VLlama4Prefixer)
4377    }
4378    fn modalities(&self, _config: &str) -> Result<Modalities> {
4379        Ok(Modalities {
4380            input: vec![SupportedModality::Text, SupportedModality::Vision],
4381            output: vec![SupportedModality::Text],
4382        })
4383    }
4384}
4385
4386impl IsqModelLoader for VLlama4Loader {
4387    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4388        Ok(vec![
4389            Regex::new(r"lm_head\.(weight|bias)$")?,
4390            // Attention
4391            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4392            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4393            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4394            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4395            // FF MoE
4396            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4397            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4398            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4399            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4400            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4401            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4402            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4403            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4404            // FF MLP
4405            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4406            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4407            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4408        ])
4409    }
4410    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4411        Ok(vec![
4412            Regex::new(r"lm_head\.(weight|bias)$")?,
4413            // Attention
4414            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4415            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4416            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4417            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4418            // FF MoE
4419            Regex::new(
4420                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4421            )?,
4422            Regex::new(
4423                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4424            )?,
4425            Regex::new(
4426                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4427            )?,
4428            Regex::new(
4429                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4430            )?,
4431            Regex::new(
4432                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4433            )?,
4434            Regex::new(
4435                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4436            )?,
4437            Regex::new(
4438                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4439            )?,
4440            Regex::new(
4441                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4442            )?,
4443            // FF MLP
4444            Regex::new(
4445                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4446            )?,
4447            Regex::new(
4448                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4449            )?,
4450            Regex::new(
4451                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4452            )?,
4453        ])
4454    }
4455}
4456
4457impl VLlama4Loader {
4458    /// This incorporates the max batch size!
4459    /// Returns (pixels max batch size, num text image tokens)
4460    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4461    fn run_dummy_processing(
4462        &self,
4463        cfg: &Llama4Config,
4464        height: usize,
4465        width: usize,
4466        max_num_images: usize,
4467        max_batch_size: usize,
4468    ) -> Result<(usize, usize)> {
4469        let cfg = &cfg.vision_config;
4470
4471        let img_processor =
4472            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4473        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4474        let res = img_processor.preprocess(
4475            vec![image; max_num_images],
4476            vec![],
4477            &PreProcessorConfig::default(),
4478            &Device::Cpu,
4479            (max_batch_size, max_num_images),
4480        )?;
4481
4482        let pixels_batch_size = res.pixel_values.dim(0)?;
4483        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4484
4485        let (image_h, image_w) = (
4486            res.pixel_values.dim(D::Minus2).unwrap(),
4487            res.pixel_values.dim(D::Minus1).unwrap(),
4488        );
4489        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4490            * (image_w / img_processor.patch_size)
4491            / img_processor.downsample_ratio;
4492
4493        Ok((
4494            pixels_max_batch_size,
4495            num_patches_per_chunk * pixels_max_batch_size,
4496        ))
4497    }
4498}
4499
4500impl DeviceMappedModelLoader for VLlama4Loader {
4501    fn mapped_max_act_size_elems(
4502        &self,
4503        config: &str,
4504        params: &AutoDeviceMapParams,
4505    ) -> Result<usize> {
4506        let AutoDeviceMapParams::Multimodal {
4507            max_seq_len,
4508            max_batch_size,
4509            max_image_shape: (height, width),
4510            max_num_images,
4511        } = params
4512        else {
4513            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4514        };
4515
4516        let cfg: Llama4Config = serde_json::from_str(config)?;
4517
4518        let (_pixels_batch_size, num_text_image_toks) =
4519            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4520
4521        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4522
4523        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4524    }
4525    fn non_mapped_max_act_size_elems(
4526        &self,
4527        config: &str,
4528        params: &AutoDeviceMapParams,
4529    ) -> Result<usize> {
4530        let AutoDeviceMapParams::Multimodal {
4531            max_seq_len: _,
4532            max_batch_size,
4533            max_image_shape: (height, width),
4534            max_num_images,
4535        } = params
4536        else {
4537            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4538        };
4539
4540        let cfg: Llama4Config = serde_json::from_str(config)?;
4541
4542        let (pixels_batch_size, _num_text_image_toks) =
4543            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4544        let max_seq_len = cfg.vision_config.num_patches();
4545
4546        Ok((max_batch_size * pixels_batch_size)
4547            * cfg.vision_config.num_attention_heads
4548            * max_seq_len
4549            * max_seq_len)
4550    }
4551
4552    fn non_mapped_size_in_bytes(
4553        &self,
4554        config: &str,
4555        dtype: DType,
4556        weight_pack_factor: usize,
4557        _matformer_config: Option<&MatformerSliceConfig>,
4558    ) -> Result<usize> {
4559        let cfg: Llama4Config = serde_json::from_str(config)?;
4560        let tcfg = &cfg.text_config;
4561
4562        let text_elems = {
4563            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4564            let lm_head = if !tcfg.tie_word_embeddings {
4565                tcfg.hidden_size * tcfg.vocab_size
4566            } else {
4567                0
4568            };
4569            let norm = tcfg.hidden_size;
4570            embed_tokens + lm_head + norm
4571        };
4572
4573        let vision_elems = {
4574            let cfg = &cfg.vision_config;
4575
4576            let num_patches = cfg.num_patches();
4577
4578            let unfold_elems =
4579                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4580            let class_embeddng_elems = cfg.hidden_size;
4581            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4582            let layernorm_pre_elems = cfg.hidden_size;
4583            let layernorm_post_elems = cfg.hidden_size;
4584
4585            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4586                / weight_pack_factor
4587                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4588
4589            let encoder_layer = {
4590                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4591                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4592
4593                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4594                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4595                    / weight_pack_factor
4596                    + cfg.num_attention_heads * head_dim;
4597                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4598                    / weight_pack_factor
4599                    + cfg.num_attention_heads * head_dim;
4600                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4601                    / weight_pack_factor
4602                    + cfg.num_attention_heads * head_dim;
4603                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4604                    / weight_pack_factor
4605                    + cfg.num_attention_heads * head_dim;
4606
4607                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4608                    + cfg.intermediate_size;
4609                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4610                    + cfg.hidden_size;
4611
4612                input_layernorm
4613                    + post_attention_layernorm
4614                    + q_proj
4615                    + k_proj
4616                    + v_proj
4617                    + o_proj
4618                    + fc1
4619                    + fc2
4620            };
4621
4622            unfold_elems
4623                + class_embeddng_elems
4624                + positional_embedding_vlm_elems
4625                + layernorm_post_elems
4626                + layernorm_pre_elems
4627                + pixel_shuffle_elems
4628                + encoder_layer * cfg.num_hidden_layers
4629        };
4630
4631        let elems = text_elems + vision_elems;
4632
4633        Ok(elems * dtype.size_in_bytes())
4634    }
4635
4636    fn layer_sizes_in_bytes(
4637        &self,
4638        config: &str,
4639        dtype: DType,
4640        weight_pack_factor: usize,
4641        _matformer_config: Option<&MatformerSliceConfig>,
4642    ) -> Result<Vec<usize>> {
4643        let cfg: Llama4Config = serde_json::from_str(config)?;
4644        let tcfg = &cfg.text_config;
4645
4646        let mut per_layer_elems = Vec::new();
4647
4648        for layer_idx in 0..tcfg.num_hidden_layers {
4649            let input_layernorm = tcfg.hidden_size;
4650            let post_attention_layernorm = tcfg.hidden_size;
4651
4652            let size_in = tcfg.hidden_size;
4653            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4654            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4655            let q_proj = size_in * size_q / weight_pack_factor;
4656            let k_proj = size_in * size_kv / weight_pack_factor;
4657            let v_proj = size_in * size_kv / weight_pack_factor;
4658            let o_proj = size_q * size_in / weight_pack_factor;
4659
4660            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4661            let moe_block = if use_moe {
4662                let h_size = tcfg.hidden_size;
4663                let i_size = tcfg.intermediate_size;
4664                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4665                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4666                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4667
4668                gate_proj + up_proj + down_proj
4669            } else {
4670                let h_size = tcfg.hidden_size;
4671                let i_size = tcfg.intermediate_size_mlp;
4672                let gate_proj = h_size * i_size / weight_pack_factor;
4673                let up_proj = h_size * i_size / weight_pack_factor;
4674                let down_proj = i_size * h_size / weight_pack_factor;
4675
4676                gate_proj + up_proj + down_proj
4677            };
4678
4679            per_layer_elems.push(
4680                input_layernorm
4681                    + post_attention_layernorm
4682                    + q_proj
4683                    + k_proj
4684                    + v_proj
4685                    + o_proj
4686                    + moe_block,
4687            );
4688        }
4689
4690        Ok(per_layer_elems
4691            .into_iter()
4692            .map(|x| x * dtype.size_in_bytes())
4693            .collect())
4694    }
4695
4696    fn num_layers(&self, config: &str) -> Result<usize> {
4697        let cfg: Llama4Config = serde_json::from_str(config)?;
4698        Ok(cfg.text_config.num_hidden_layers)
4699    }
4700
4701    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4702        let cfg: Llama4Config = serde_json::from_str(config)?;
4703        let cfg = &cfg.text_config;
4704
4705        let cfg = ModelConfigMetadata {
4706            max_seq_len: cfg.max_position_embeddings,
4707            num_layers: cfg.num_hidden_layers,
4708            hidden_size: cfg.hidden_size,
4709            num_kv_heads: cfg.num_attention_heads,
4710            num_attn_heads: cfg.num_attention_heads,
4711            sliding_window: None,
4712            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4713            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4714            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4715        };
4716
4717        Ok(Box::new(cfg))
4718    }
4719
4720    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4721        Some(vec![NonMappedSubModel::Vision])
4722    }
4723}
4724
4725// ======================== Gemma 3n Loader
4726
/// [`MultimodalLoader`] for a Gemma 3n model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Gemma3nLoader;
4731
4732#[allow(dead_code)]
4733pub struct Gemma3nPrefixer;
4734
4735impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4736    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4737        prompt.to_string()
4738    }
4739}
4740
4741impl MultimodalModelLoader for Gemma3nLoader {
4742    fn load(
4743        &self,
4744        config: &str,
4745        vb: ShardedVarBuilder,
4746        normal_loading_metadata: NormalLoadingMetadata,
4747        attention_mechanism: AttentionImplementation,
4748    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4749        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4750        Ok(Box::new(Gemma3nModel::new(
4751            &cfg,
4752            vb,
4753            self.is_gptx(config),
4754            normal_loading_metadata,
4755            attention_mechanism,
4756        )?))
4757    }
4758    fn is_gptx(&self, _config: &str) -> bool {
4759        true
4760    }
4761    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4762        let config: Gemma3nConfig = serde_json::from_str(config)?;
4763        Ok(Box::new(config))
4764    }
4765    fn get_processor(
4766        &self,
4767        _config: &str,
4768        processor_config: Option<ProcessorConfig>,
4769        _preprocessor_config: PreProcessorConfig,
4770        _max_edge: Option<u32>,
4771    ) -> Arc<dyn Processor + Send + Sync> {
4772        // Handle the Gemma 3 1b case here
4773        Arc::new(Gemma3nProcessor::new(
4774            processor_config.unwrap_or_default(),
4775            true,
4776        ))
4777    }
4778    fn supports_paged_attention(&self, _config: &str) -> bool {
4779        false
4780    }
4781    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4782        true
4783    }
4784    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4785        Arc::new(Gemma3Prefixer)
4786    }
4787    fn modalities(&self, _config: &str) -> Result<Modalities> {
4788        Ok(Modalities {
4789            input: vec![
4790                SupportedModality::Text,
4791                SupportedModality::Vision,
4792                SupportedModality::Audio,
4793            ],
4794            output: vec![SupportedModality::Text],
4795        })
4796    }
4797}
4798
4799impl IsqModelLoader for Gemma3nLoader {
4800    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4801        Ok(vec![
4802            Regex::new(r"lm_head\.(weight|bias)$")?,
4803            // Language model attention
4804            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4805            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4806            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4807            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4808            // Language model MLP
4809            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4810            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4811            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4812            // Audio conformer attention layers
4813            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4814            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4815            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4816            Regex::new(
4817                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4818            )?,
4819            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4820            // Audio conformer FFW layers
4821            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4822            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4823            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4824            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4825            // Audio conformer conv1d layers
4826            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4827            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4828            // Audio subsample projection
4829            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4830            // Multimodal embedders
4831            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4832            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4833        ])
4834    }
4835    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4836        Ok(vec![
4837            Regex::new(r"lm_head\.(weight|bias)$")?,
4838            // Language model attention
4839            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4840            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4841            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4842            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4843            // Language model MLP
4844            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4845            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4846            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4847            // Projections
4848            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4849            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4850            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4851            // Audio conformer attention layers
4852            Regex::new(
4853                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4854            )?,
4855            Regex::new(
4856                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4857            )?,
4858            Regex::new(
4859                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4860            )?,
4861            Regex::new(
4862                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4863            )?,
4864            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4865            // Audio conformer FFW layers
4866            Regex::new(
4867                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4868            )?,
4869            Regex::new(
4870                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4871            )?,
4872            Regex::new(
4873                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4874            )?,
4875            Regex::new(
4876                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4877            )?,
4878            // Audio conformer conv1d layers
4879            Regex::new(
4880                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4881            )?,
4882            Regex::new(
4883                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4884            )?,
4885            // Audio subsample projection
4886            Regex::new(
4887                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4888            )?,
4889            // Multimodal embedders
4890            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4891            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4892        ])
4893    }
4894}
4895
4896impl DeviceMappedModelLoader for Gemma3nLoader {
4897    fn mapped_max_act_size_elems(
4898        &self,
4899        config: &str,
4900        params: &AutoDeviceMapParams,
4901    ) -> Result<usize> {
4902        let AutoDeviceMapParams::Multimodal {
4903            max_seq_len,
4904            max_batch_size,
4905            max_image_shape: _,
4906            max_num_images,
4907        } = params
4908        else {
4909            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4910        };
4911
4912        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4913        let text_cfg = &cfg.text_config;
4914
4915        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4916        // We need to account for vision and audio tokens in the sequence length
4917
4918        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4919
4920        // Add vision tokens
4921        {
4922            // Vision tokens are injected into the prompt
4923            // MSFA outputs fixed 16x16 features regardless of input size
4924            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4925            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4926            total_seq_len += vision_tokens_per_image * max_num_images;
4927        }
4928
4929        // Add audio tokens
4930        {
4931            // Audio tokens are injected into the prompt
4932            // From config field audio_soft_tokens_per_image (typically 188)
4933            let audio_tokens = cfg.audio_soft_tokens_per_image;
4934            total_seq_len += audio_tokens;
4935        }
4936
4937        // Calculate max attention size for text model with all injected tokens
4938        let max_text_attn =
4939            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4940
4941        Ok(max_text_attn)
4942    }
4943
4944    fn non_mapped_max_act_size_elems(
4945        &self,
4946        config: &str,
4947        params: &AutoDeviceMapParams,
4948    ) -> Result<usize> {
4949        let AutoDeviceMapParams::Multimodal {
4950            max_seq_len: _,
4951            max_batch_size,
4952            max_image_shape: _,
4953            max_num_images,
4954        } = params
4955        else {
4956            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4957        };
4958
4959        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4960
4961        // Calculate max activation sizes for each modality
4962        let mut max_activation = 0;
4963
4964        // Vision activation size
4965        {
4966            // Vision is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4967            // The peak activation is in the Multi-Query Attention layers
4968
4969            // From the architecture: stages 3 and 4 have MMQA blocks
4970            // Input images are 768x768 (from inputs_processor.rs)
4971            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4972            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4973            // MSFA output: 2048 channels at fixed 16x16
4974
4975            let vision_tower_act = {
4976                // Peak is during MMQA attention computation in stage 4
4977                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4978                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4979                let num_heads = 16; // Stage 4 configuration
4980                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4981                let seq_len = spatial_size * spatial_size;
4982
4983                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4984                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4985            };
4986
4987            // Vision embedder activations
4988            let vision_embed_act = {
4989                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4990                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4991                let spatial_size = 16; // Fixed output resolution from MSFA
4992                let vision_features =
4993                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4994
4995                // After embedding projection to text hidden size
4996                let projected = max_batch_size
4997                    * max_num_images
4998                    * spatial_size
4999                    * spatial_size
5000                    * cfg.text_config.hidden_size;
5001
5002                vision_features.max(projected)
5003            };
5004
5005            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
5006        }
5007
5008        // Audio activation size
5009        {
5010            let audio_cfg = &cfg.audio_config;
5011
5012            // Calculate max audio sequence length based on config
5013            // Audio uses conformer with subsampling and reduction
5014
5015            // A rough estimate of max_audio_frames
5016            let max_audio_frames = 1280;
5017
5018            let subsample_factor: usize = audio_cfg
5019                .sscp_conv_stride_size
5020                .iter()
5021                .map(|stride| stride[0]) // Time dimension stride
5022                .product();
5023            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
5024
5025            // Audio encoder activations
5026            let audio_encoder_act = {
5027                // Conformer FFW layers have expansion factor from config
5028                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
5029
5030                // Peak is in the FFW layers before reduction
5031                max_batch_size * audio_seq_after_subsample * intermediate_size
5032            };
5033
5034            // Audio attention activations
5035            let audio_attn_act = {
5036                // Attention uses chunked processing with specific context sizes
5037                let chunk_size = audio_cfg.conf_attention_chunk_size;
5038                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5039                    + audio_cfg.conf_attention_context_right;
5040
5041                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
5042                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
5043
5044                max_batch_size
5045                    * audio_cfg.conf_num_attention_heads
5046                    * num_chunks
5047                    * chunk_size
5048                    * context_size
5049            };
5050
5051            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
5052        }
5053
5054        Ok(max_activation)
5055    }
5056
5057    fn non_mapped_size_in_bytes(
5058        &self,
5059        config: &str,
5060        dtype: DType,
5061        weight_pack_factor: usize,
5062        matformer_config: Option<&MatformerSliceConfig>,
5063    ) -> Result<usize> {
5064        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5065
5066        // Apply matformer slicing if configured
5067        let text_cfg = if let Some(matformer_cfg) = matformer_config {
5068            use crate::device_map::DummyDeviceMapper;
5069            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5070
5071            let dummy_mapper = DummyDeviceMapper {
5072                nm_device: Device::Cpu,
5073            };
5074            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
5075                &cfg.text_config,
5076                &Some(matformer_cfg.clone()),
5077                &dummy_mapper,
5078            )?;
5079            adjusted_cfg
5080        } else {
5081            cfg.text_config.clone()
5082        };
5083
5084        let text_cfg = &text_cfg;
5085
5086        // Text components that are not device-mapped
5087        let text_elems = {
5088            // Embeddings
5089            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
5090            let embed_tokens_per_layer = text_cfg.num_hidden_layers
5091                * text_cfg.hidden_size_per_layer_input
5092                * text_cfg.vocab_size_per_layer_input;
5093
5094            // LM head (if not tied)
5095            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
5096                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
5097            } else {
5098                0
5099            };
5100
5101            // Final layer norm
5102            let norm = text_cfg.hidden_size;
5103
5104            // AltUp projections (not device-mapped)
5105            let altup_projections =
5106                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5107                    / weight_pack_factor;
5108            let altup_unembed_projections =
5109                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5110                    / weight_pack_factor;
5111
5112            // Per-layer model projection
5113            let per_layer_model_projection = text_cfg.num_hidden_layers
5114                * text_cfg.hidden_size
5115                * text_cfg.hidden_size_per_layer_input
5116                / weight_pack_factor;
5117            let per_layer_projection_norm = text_cfg.hidden_size;
5118
5119            embed_tokens
5120                + embed_tokens_per_layer
5121                + lm_head
5122                + norm
5123                + altup_projections
5124                + altup_unembed_projections
5125                + per_layer_model_projection
5126                + per_layer_projection_norm
5127        };
5128
5129        // Vision components
5130        let vision_elems = {
5131            let multimodal_cfg = &cfg.vision_config;
5132            // Vision tower - calculated from actual Gemma3n architecture
5133            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5134            // so NONE of these should be divided by weight_pack_factor
5135            let vision_tower_elems = {
5136                use crate::vision_models::gemma3n::vision::{
5137                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5138                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5139                    STEM_OUT_CHANNELS,
5140                };
5141
5142                // Stem: ConvNormAct (Conv2d + RMSNorm)
5143                let stem_conv =
5144                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5145                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5146
5147                // Track input channels through the network
5148                let mut in_chs = STEM_OUT_CHANNELS;
5149                let mut total_elems = stem_conv + stem_norm;
5150
5151                // Process all stages from gemma3n_mobilenet_def
5152                let block_defs = gemma3n_mobilenet_def();
5153
5154                for stage_blocks in block_defs.iter() {
5155                    for block_type in stage_blocks.iter() {
5156                        match block_type {
5157                            BlockType::EdgeResidual {
5158                                out_channels,
5159                                kernel_size,
5160                                stride: _,
5161                                expand_ratio,
5162                                ..
5163                            } => {
5164                                #[allow(clippy::cast_precision_loss)]
5165                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5166                                // EdgeResidual: all Conv2d layers, not quantizable
5167                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5168                                total_elems += mid_chs; // bn1 weight
5169                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5170                                total_elems += out_channels; // bn2 weight
5171                                in_chs = *out_channels;
5172                            }
5173                            BlockType::UniversalInvertedResidual {
5174                                out_channels,
5175                                start_kernel_size,
5176                                mid_kernel_size,
5177                                stride: _,
5178                                expand_ratio,
5179                                ..
5180                            } => {
5181                                #[allow(clippy::cast_precision_loss)]
5182                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5183                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5184                                if *expand_ratio != 1.0 {
5185                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5186                                    total_elems += mid_chs; // expand norm
5187                                }
5188                                if *start_kernel_size > 0 {
5189                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5190                                    total_elems += mid_chs; // norm
5191                                }
5192                                if *mid_kernel_size > 0 {
5193                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5194                                    total_elems += mid_chs; // norm
5195                                }
5196                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5197                                total_elems += out_channels; // project norm
5198                                total_elems += out_channels; // layer scale gamma
5199                                in_chs = *out_channels;
5200                            }
5201                            BlockType::MultiQueryAttention {
5202                                num_heads,
5203                                kv_dim,
5204                                kv_stride: _,
5205                                ..
5206                            } => {
5207                                // MMQA: all Conv2d layers, not quantizable
5208                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5209                                total_elems += in_chs; // norm weight
5210                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5211                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5212                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5213                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5214                                total_elems += 1; // value_norm weight
5215                                total_elems += *kv_dim; // value_proj (Conv2d)
5216                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5217                                total_elems += in_chs; // layer scale
5218                            }
5219                        }
5220                    }
5221                }
5222
5223                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5224                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5225                let msfa_out = MSFA_OUT_CHANNELS;
5226                #[allow(clippy::cast_precision_loss)]
5227                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5228
5229                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5230                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5231                total_elems += msfa_mid; // expand norm
5232                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5233                total_elems += msfa_out; // project norm
5234                total_elems += msfa_out; // final norm
5235
5236                total_elems
5237            };
5238
5239            // Vision multimodal embedder components
5240            let embed_vision_elems = {
5241                // Embedding layer (not quantizable)
5242                let embedding = multimodal_cfg.vocab_size * multimodal_cfg.hidden_size;
5243
5244                // Normalization layers (not quantizable)
5245                let hard_norm = multimodal_cfg.hidden_size;
5246                let soft_norm = multimodal_cfg.hidden_size;
5247
5248                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5249                let projection =
5250                    multimodal_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5251
5252                // Post-projection norm (not quantizable)
5253                let post_norm = text_cfg.hidden_size;
5254
5255                embedding + hard_norm + soft_norm + projection + post_norm
5256            };
5257
5258            vision_tower_elems + embed_vision_elems
5259        };
5260
5261        // Audio components - based on actual audio.rs structure
5262        let audio_elems = {
5263            let audio_cfg = &cfg.audio_config;
5264
5265            // SubSampleConvProjection components
5266            let subsample_conv_projection_elems = {
5267                // Conv blocks (Conv2d layers - NOT quantizable)
5268                let mut conv_elems = 0;
5269
5270                // conv_0: Conv2d from 1 channel to first channel size
5271                let in_ch_0 = 1;
5272                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5273                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5274                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5275
5276                // conv_1: Conv2d from first to second channel size
5277                let in_ch_1 = out_ch_0;
5278                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5279                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5280                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5281
5282                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5283                let norm_0 = out_ch_0; // norm weight for conv_0
5284                let norm_1 = out_ch_1; // norm weight for conv_1
5285
5286                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5287                let mut f_out = audio_cfg.input_feat_size;
5288                for i in 0..2 {
5289                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5290                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5291                    let pad_left = 1;
5292                    let pad_right = 1;
5293                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5294                }
5295                let input_proj_in_features = out_ch_1 * f_out;
5296                let input_proj_linear =
5297                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5298
5299                conv_elems + norm_0 + norm_1 + input_proj_linear
5300            };
5301
5302            // Conformer blocks
5303            let conformer_elems = {
5304                let mut total = 0;
5305
5306                for _ in 0..audio_cfg.conf_num_hidden_layers {
5307                    // ConformerAttention
5308                    let attention_elems = {
5309                        // Norms (NOT quantizable)
5310                        let pre_attn_norm = audio_cfg.hidden_size;
5311                        let post_norm = audio_cfg.hidden_size;
5312
5313                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5314                        let q_proj =
5315                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5316                        let k_proj =
5317                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5318                        let v_proj =
5319                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5320                        let post =
5321                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5322
5323                        // RelativePositionEmbedding
5324                        let pos_proj =
5325                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5326                        let per_dim_scale =
5327                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5328                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5329                        let pos_indices = audio_cfg.conf_attention_context_left
5330                            + audio_cfg.conf_attention_context_right
5331                            + 1;
5332
5333                        // Local causal masks (precomputed tensors)
5334                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5335                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5336                            + audio_cfg.conf_attention_context_right;
5337                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5338                        let invalid_logits_tensor = 1; // single f32 value
5339
5340                        pre_attn_norm
5341                            + post_norm
5342                            + q_proj
5343                            + k_proj
5344                            + v_proj
5345                            + post
5346                            + pos_proj
5347                            + per_dim_scale
5348                            + inv_timescales
5349                            + pos_indices
5350                            + local_causal_valid_mask
5351                            + invalid_logits_tensor
5352                    };
5353
5354                    // ConformerFeedForward (start and end)
5355                    let ffw_elems = {
5356                        // Each FFW has:
5357                        // - pre_layer_norm (NOT quantizable)
5358                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5359                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5360                        // - post_layer_norm (NOT quantizable)
5361                        let intermediate_size = audio_cfg.hidden_size * 4;
5362
5363                        let ffw_start = {
5364                            let pre_norm = audio_cfg.hidden_size;
5365                            let layer_1 =
5366                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5367                            let layer_2 =
5368                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5369                            let post_norm = audio_cfg.hidden_size;
5370                            pre_norm + layer_1 + layer_2 + post_norm
5371                        };
5372
5373                        let ffw_end = ffw_start; // Same structure
5374
5375                        ffw_start + ffw_end
5376                    };
5377
5378                    // ConformerLightConv1d
5379                    let lconv1d_elems = {
5380                        // Norms (NOT quantizable)
5381                        let pre_layer_norm = audio_cfg.hidden_size;
5382                        let conv_norm = audio_cfg.hidden_size;
5383
5384                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5385                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5386                            / weight_pack_factor;
5387                        let linear_end =
5388                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5389
5390                        // depthwise_conv1d (Conv1d - NOT quantizable)
5391                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5392
5393                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5394                    };
5395
5396                    // Final norm for conformer block (NOT quantizable)
5397                    let block_norm = audio_cfg.hidden_size;
5398
5399                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5400                }
5401
5402                total
5403            };
5404
5405            // Audio multimodal embedder (embed_audio)
5406            let embed_audio_elems = {
5407                // Embedding layer (ScaledEmbedding - NOT quantizable)
5408                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5409
5410                // RMS norms (NOT quantizable)
5411                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5412                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5413                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5414
5415                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5416                let embedding_projection =
5417                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5418
5419                embedding
5420                    + hard_embedding_norm
5421                    + soft_embedding_norm
5422                    + embedding_post_projection_norm
5423                    + embedding_projection
5424            };
5425
5426            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5427        };
5428
5429        let vision_dtype = if dtype == DType::F16 {
5430            // f16 -> f32 for vision model in particular.
5431            DType::F32
5432        } else {
5433            dtype
5434        };
5435
5436        let total_elems = text_elems * dtype.size_in_bytes()
5437            + vision_elems * vision_dtype.size_in_bytes()
5438            + audio_elems * dtype.size_in_bytes();
5439
5440        Ok(total_elems)
5441    }
5442
5443    fn layer_sizes_in_bytes(
5444        &self,
5445        config: &str,
5446        dtype: DType,
5447        weight_pack_factor: usize,
5448        matformer_config: Option<&MatformerSliceConfig>,
5449    ) -> Result<Vec<usize>> {
5450        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5451
5452        // Apply matformer slicing if configured
5453        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5454            matformer_config
5455        {
5456            use crate::device_map::DummyDeviceMapper;
5457            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5458
5459            let dummy_mapper = DummyDeviceMapper {
5460                nm_device: Device::Cpu,
5461            };
5462            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5463                &cfg.text_config,
5464                &Some(matformer_cfg.clone()),
5465                &dummy_mapper,
5466            )?;
5467            (adjusted_cfg, layer_rename_map, layers_skipped)
5468        } else {
5469            (cfg.text_config.clone(), None, None)
5470        };
5471
5472        let text_cfg = &text_cfg;
5473
5474        // When matformer slicing is applied, we only include the layers that are kept
5475        let mut layer_sizes = Vec::new();
5476
5477        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5478        // already has the correct intermediate sizes after matformer slicing
5479
5480        for layer_idx in 0..text_cfg.num_hidden_layers {
5481            let per_layer_elems = {
5482                // Layer norms
5483                let input_layernorm = text_cfg.hidden_size;
5484                let post_attention_layernorm = text_cfg.hidden_size;
5485                let pre_feedforward_layernorm = text_cfg.hidden_size;
5486                let post_feedforward_layernorm = text_cfg.hidden_size;
5487                let post_per_layer_input_norm = text_cfg.hidden_size;
5488
5489                // Attention components
5490                let size_in = text_cfg.hidden_size;
5491                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5492                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5493
5494                let q_proj = size_in * size_q / weight_pack_factor;
5495                let k_proj = size_in * size_kv / weight_pack_factor;
5496                let v_proj = size_in * size_kv / weight_pack_factor;
5497                let o_proj = size_q * size_in / weight_pack_factor;
5498
5499                // Q, K, V norms
5500                let q_norm = text_cfg.head_dim;
5501                let k_norm = text_cfg.head_dim;
5502                let v_norm = text_cfg.head_dim; // No bias for v_norm
5503
5504                // MLP components - use the adjusted intermediate sizes from matformer
5505                let intermediate_size = match &text_cfg.intermediate_size {
5506                    IntermediateSize::Single(size) => *size,
5507                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5508                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5509                };
5510                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5511                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5512                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5513
5514                // AltUp components (per layer)
5515                let altup_elems = {
5516                    let correct_output_scale = text_cfg.hidden_size;
5517                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5518                    let prediction_coefs =
5519                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5520                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5521                    let router_norm = text_cfg.hidden_size;
5522
5523                    correct_output_scale
5524                        + correction_coefs
5525                        + prediction_coefs
5526                        + modality_router
5527                        + router_norm
5528                };
5529
5530                // Laurel block components
5531                let laurel_elems = {
5532                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5533                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5534                    let post_norm = text_cfg.hidden_size;
5535
5536                    left + right + post_norm
5537                };
5538
5539                // Per-layer input components
5540                let per_layer_input_gate =
5541                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5542                let per_layer_projection =
5543                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5544
5545                input_layernorm
5546                    + post_attention_layernorm
5547                    + pre_feedforward_layernorm
5548                    + post_feedforward_layernorm
5549                    + post_per_layer_input_norm
5550                    + q_proj
5551                    + k_proj
5552                    + v_proj
5553                    + o_proj
5554                    + q_norm
5555                    + k_norm
5556                    + v_norm
5557                    + gate_proj
5558                    + up_proj
5559                    + down_proj
5560                    + altup_elems
5561                    + laurel_elems
5562                    + per_layer_input_gate
5563                    + per_layer_projection
5564            };
5565
5566            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5567        }
5568
5569        Ok(layer_sizes)
5570    }
5571
5572    fn num_layers(&self, config: &str) -> Result<usize> {
5573        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5574        Ok(cfg.text_config.num_hidden_layers)
5575    }
5576
5577    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5578        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5579        let cfg = cfg.text_config;
5580
5581        let cfg = ModelConfigMetadata {
5582            max_seq_len: cfg.max_position_embeddings,
5583            num_layers: cfg.num_hidden_layers,
5584            hidden_size: cfg.hidden_size,
5585            num_kv_heads: cfg.num_key_value_heads,
5586            num_attn_heads: cfg.num_attention_heads,
5587            sliding_window: None, // None to be more forgiving, some do not
5588            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5589            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5590            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5591        };
5592
5593        Ok(Box::new(cfg))
5594    }
5595
5596    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5597        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5598    }
5599}
5600
5601// ======================== Qwen3VL Loader
5602
/// [`MultimodalLoader`] for a Qwen3VL model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen3VLLoader;

/// Prompt prefixer for Qwen3VL.
///
/// Its [`MultimodalPromptPrefixer`] impl overrides no methods, so the trait's default
/// methods are used as-is.
pub struct Qwen3VLPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}
5614
5615impl MultimodalModelLoader for Qwen3VLLoader {
5616    fn load(
5617        &self,
5618        config: &str,
5619        vb: ShardedVarBuilder,
5620        normal_loading_metadata: NormalLoadingMetadata,
5621        attention_mechanism: AttentionImplementation,
5622    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
5623        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5624        Ok(Box::new(Qwen3VLModel::new(
5625            &cfg,
5626            vb,
5627            self.is_gptx(config),
5628            normal_loading_metadata,
5629            attention_mechanism,
5630        )?))
5631    }
5632    fn is_gptx(&self, _config: &str) -> bool {
5633        true
5634    }
5635    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5636        let config: Qwen3VLConfig = serde_json::from_str(config)?;
5637        Ok(Box::new(config))
5638    }
5639    fn get_processor(
5640        &self,
5641        _model_config: &str,
5642        _processor_config: Option<ProcessorConfig>,
5643        _preprocessor_config: PreProcessorConfig,
5644        max_edge: Option<u32>,
5645    ) -> Arc<dyn Processor + Send + Sync> {
5646        Arc::new(Qwen3VLProcessor::new(max_edge))
5647    }
5648    fn supports_paged_attention(&self, _config: &str) -> bool {
5649        true
5650    }
5651    fn supports_prefix_cacher(&self, _config: &str) -> bool {
5652        true
5653    }
5654    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5655        Arc::new(Qwen3VLPrefixer)
5656    }
5657    fn modalities(&self, _config: &str) -> Result<Modalities> {
5658        Ok(Modalities {
5659            input: vec![SupportedModality::Text, SupportedModality::Vision],
5660            output: vec![SupportedModality::Text],
5661        })
5662    }
5663}
5664
5665impl IsqModelLoader for Qwen3VLLoader {
5666    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5667        Ok(vec![
5668            Regex::new(r"lm_head\.(weight|bias)$")?,
5669            // Attention
5670            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5671            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5672            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5673            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5674            // MLP
5675            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5676            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5677            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5678        ])
5679    }
5680    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5681        self.isq_layer_regexes(config)
5682    }
5683}
5684
5685impl DeviceMappedModelLoader for Qwen3VLLoader {
5686    fn mapped_max_act_size_elems(
5687        &self,
5688        config: &str,
5689        params: &AutoDeviceMapParams,
5690    ) -> Result<usize> {
5691        let AutoDeviceMapParams::Multimodal {
5692            max_seq_len,
5693            max_batch_size,
5694            max_image_shape,
5695            max_num_images,
5696        } = params
5697        else {
5698            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
5699        };
5700
5701        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5702
5703        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
5704        let img_seq_len = {
5705            let cfg = &cfg.vision_config;
5706            // grid_t is 1 for images (temporal dimension is for video only)
5707            let grid_t = 1;
5708            // After patch embedding and spatial merge, the effective grid dimensions are reduced
5709            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5710            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5711            grid_t * grid_h * grid_w * max_num_images
5712        };
5713
5714        let max_text_attn = {
5715            let cfg = &cfg.text_config;
5716            // This model injects the vision information directly into the input embeddings
5717            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5718            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5719        };
5720
5721        Ok(max_text_attn)
5722    }
5723
5724    fn non_mapped_max_act_size_elems(
5725        &self,
5726        config: &str,
5727        params: &AutoDeviceMapParams,
5728    ) -> Result<usize> {
5729        let AutoDeviceMapParams::Multimodal {
5730            max_seq_len: _,
5731            max_batch_size,
5732            max_image_shape,
5733            max_num_images,
5734        } = params
5735        else {
5736            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
5737        };
5738
5739        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5740
5741        // For the vision encoder, before spatial merging
5742        let img_seq_len = {
5743            let cfg = &cfg.vision_config;
5744            // grid_t is 1 for images
5745            let grid_t = 1;
5746            let grid_h = max_image_shape.0 / cfg.patch_size;
5747            let grid_w = max_image_shape.1 / cfg.patch_size;
5748            grid_t * grid_h * grid_w
5749        };
5750
5751        let max_vision_attn = {
5752            let cfg = &cfg.vision_config;
5753            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5754        };
5755
5756        Ok(max_vision_attn)
5757    }
5758
5759    fn non_mapped_size_in_bytes(
5760        &self,
5761        config: &str,
5762        dtype: DType,
5763        weight_pack_factor: usize,
5764        _matformer_config: Option<&MatformerSliceConfig>,
5765    ) -> Result<usize> {
5766        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5767        let tie = cfg.tie_word_embeddings;
5768        let text_elems = {
5769            let cfg = &cfg.text_config;
5770            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5771            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
5772            let lm_head = if !tie || weight_pack_factor != 1 {
5773                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5774            } else {
5775                0
5776            };
5777            let norm = cfg.hidden_size;
5778            embed_tokens + lm_head + norm
5779        };
5780
5781        let (patch_merger, deepstack_mergers) = {
5782            let cfg = &cfg.vision_config;
5783            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
5784
5785            let mlp0 = hidden_size * hidden_size + hidden_size;
5786            let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
5787
5788            // Main merger: norm uses cfg.hidden_size
5789            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5790            let merger = mlp0 + mlp2 + ln_q;
5791
5792            // Deepstack mergers: norm uses merged hidden_size
5793            let ds_ln = hidden_size + bias_if!(true, hidden_size);
5794            let ds_merger = mlp0 + mlp2 + ds_ln;
5795            let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
5796
5797            (merger, deepstack)
5798        };
5799
5800        let patch_embed = {
5801            let cfg = &cfg.vision_config;
5802            let conv_cfg = Conv3dConfig {
5803                stride: cfg.patch_size,
5804                ..Default::default()
5805            };
5806            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
5807            let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
5808                * kernel_sizes[0]
5809                * kernel_sizes[1]
5810                * kernel_sizes[2];
5811            let bias = cfg.hidden_size;
5812            weight + bias
5813        };
5814
5815        let pos_embed = {
5816            let cfg = &cfg.vision_config;
5817            cfg.num_position_embeddings * cfg.hidden_size
5818        };
5819
5820        let encoder_layer = {
5821            let cfg = &cfg.vision_config;
5822            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5823            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5824
5825            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
5826            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
5827            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
5828
5829            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
5830            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
5831
5832            norm1 + norm2 + fc1 + fc2 + qkv + out
5833        };
5834
5835        let elems = text_elems
5836            + patch_merger
5837            + deepstack_mergers
5838            + patch_embed
5839            + pos_embed
5840            + encoder_layer * cfg.vision_config.depth;
5841
5842        Ok(elems * dtype.size_in_bytes())
5843    }
5844
5845    fn layer_sizes_in_bytes(
5846        &self,
5847        config: &str,
5848        dtype: DType,
5849        weight_pack_factor: usize,
5850        _matformer_config: Option<&MatformerSliceConfig>,
5851    ) -> Result<Vec<usize>> {
5852        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5853        let per_layer_elems = {
5854            let cfg = &cfg.text_config;
5855            let input_layernorm = cfg.hidden_size;
5856            let post_attention_layernorm = cfg.hidden_size;
5857
5858            let size_in = cfg.hidden_size;
5859            let size_q = cfg.head_dim * cfg.num_attention_heads;
5860            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
5861            let q_proj = size_in * size_q / weight_pack_factor;
5862            let k_proj = size_in * size_kv / weight_pack_factor;
5863            let v_proj = size_in * size_kv / weight_pack_factor;
5864            let o_proj = size_q * size_in / weight_pack_factor;
5865
5866            let q_norm = cfg.head_dim;
5867            let k_norm = cfg.head_dim;
5868
5869            let h_size = cfg.hidden_size;
5870            let i_size = cfg.intermediate_size;
5871            let gate_proj = h_size * i_size / weight_pack_factor;
5872            let up_proj = h_size * i_size / weight_pack_factor;
5873            let down_proj = i_size * h_size / weight_pack_factor;
5874
5875            input_layernorm
5876                + post_attention_layernorm
5877                + q_proj
5878                + k_proj
5879                + v_proj
5880                + o_proj
5881                + q_norm
5882                + k_norm
5883                + gate_proj
5884                + up_proj
5885                + down_proj
5886        };
5887        Ok(vec![
5888            per_layer_elems * dtype.size_in_bytes();
5889            cfg.text_config.num_hidden_layers
5890        ])
5891    }
5892
5893    fn num_layers(&self, config: &str) -> Result<usize> {
5894        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5895        let cfg = &cfg.text_config;
5896        Ok(cfg.num_hidden_layers)
5897    }
5898
5899    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5900        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5901        let cfg = &cfg.text_config;
5902
5903        let cfg = ModelConfigMetadata {
5904            max_seq_len: cfg.max_position_embeddings,
5905            num_layers: cfg.num_hidden_layers,
5906            hidden_size: cfg.hidden_size,
5907            num_kv_heads: cfg.num_key_value_heads,
5908            num_attn_heads: cfg.num_attention_heads,
5909            sliding_window: cfg.sliding_window,
5910            k_head_dim: cfg.head_dim,
5911            v_head_dim: cfg.head_dim,
5912            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5913        };
5914
5915        Ok(Box::new(cfg))
5916    }
5917
5918    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5919        Some(vec![NonMappedSubModel::Vision])
5920    }
5921}
5922
5923// ======================== Qwen3VLMoE Loader
5924
/// [`MultimodalLoader`] for a Qwen3VLMoE model.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen3VLMoELoader;

/// Prompt prefixer for Qwen3VLMoE.
///
/// Its [`MultimodalPromptPrefixer`] impl overrides no methods, so the trait's default
/// methods are used as-is.
pub struct Qwen3VLMoEPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}
5936
5937impl MultimodalModelLoader for Qwen3VLMoELoader {
5938    fn load(
5939        &self,
5940        config: &str,
5941        vb: ShardedVarBuilder,
5942        normal_loading_metadata: NormalLoadingMetadata,
5943        attention_mechanism: AttentionImplementation,
5944    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
5945        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5946        Ok(Box::new(Qwen3VLMoEModel::new(
5947            &cfg,
5948            vb,
5949            self.is_gptx(config),
5950            normal_loading_metadata,
5951            attention_mechanism,
5952        )?))
5953    }
5954    fn is_gptx(&self, _config: &str) -> bool {
5955        true
5956    }
5957    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5958        let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5959        Ok(Box::new(config))
5960    }
5961    fn get_processor(
5962        &self,
5963        _model_config: &str,
5964        _processor_config: Option<ProcessorConfig>,
5965        _preprocessor_config: PreProcessorConfig,
5966        max_edge: Option<u32>,
5967    ) -> Arc<dyn Processor + Send + Sync> {
5968        Arc::new(Qwen3VLMoEProcessor::new(max_edge))
5969    }
5970    fn supports_paged_attention(&self, _config: &str) -> bool {
5971        true
5972    }
5973    fn supports_prefix_cacher(&self, _config: &str) -> bool {
5974        true
5975    }
5976    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5977        Arc::new(Qwen3VLMoEPrefixer)
5978    }
5979    fn modalities(&self, _config: &str) -> Result<Modalities> {
5980        Ok(Modalities {
5981            input: vec![SupportedModality::Text, SupportedModality::Vision],
5982            output: vec![SupportedModality::Text],
5983        })
5984    }
5985}
5986
5987impl IsqModelLoader for Qwen3VLMoELoader {
5988    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5989        Ok(vec![
5990            Regex::new(r"lm_head\.(weight|bias)$")?,
5991            // Attention
5992            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5993            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5994            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5995            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5996            // MLP (dense layers)
5997            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5998            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5999            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6000            // MoE router
6001            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
6002            // MoE experts - now unpacked into individual experts
6003            Regex::new(
6004                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6005            )?,
6006            Regex::new(
6007                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6008            )?,
6009            Regex::new(
6010                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6011            )?,
6012        ])
6013    }
6014    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6015        self.isq_layer_regexes(config)
6016    }
6017    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
6018        Ok(vec![
6019            Regex::new(r"lm_head\.(weight|bias)$")?,
6020            // MLP (dense layers)
6021            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
6022            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
6023            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6024            // MoE router
6025            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
6026            // MoE experts
6027            Regex::new(
6028                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6029            )?,
6030            Regex::new(
6031                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6032            )?,
6033            Regex::new(
6034                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6035            )?,
6036        ])
6037    }
6038    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
6039        self.isq_layer_regexes_moqe(config)
6040    }
6041}
6042
6043impl DeviceMappedModelLoader for Qwen3VLMoELoader {
6044    fn mapped_max_act_size_elems(
6045        &self,
6046        config: &str,
6047        params: &AutoDeviceMapParams,
6048    ) -> Result<usize> {
6049        let AutoDeviceMapParams::Multimodal {
6050            max_seq_len,
6051            max_batch_size,
6052            max_image_shape,
6053            max_num_images,
6054        } = params
6055        else {
6056            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6057        };
6058
6059        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6060
6061        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
6062        let img_seq_len = {
6063            let cfg = &cfg.vision_config;
6064            // grid_t is 1 for images (temporal dimension is for video only)
6065            let grid_t = 1;
6066            // After patch embedding and spatial merge, the effective grid dimensions are reduced
6067            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6068            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6069            grid_t * grid_h * grid_w * max_num_images
6070        };
6071
6072        let max_text_attn = {
6073            let cfg = &cfg.text_config;
6074            // This model injects the vision information directly into the input embeddings
6075            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6076            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6077        };
6078
6079        Ok(max_text_attn)
6080    }
6081
6082    fn non_mapped_max_act_size_elems(
6083        &self,
6084        config: &str,
6085        params: &AutoDeviceMapParams,
6086    ) -> Result<usize> {
6087        let AutoDeviceMapParams::Multimodal {
6088            max_seq_len: _,
6089            max_batch_size,
6090            max_image_shape,
6091            max_num_images,
6092        } = params
6093        else {
6094            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6095        };
6096
6097        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6098
6099        // For the vision encoder, before spatial merging
6100        let img_seq_len = {
6101            let cfg = &cfg.vision_config;
6102            // grid_t is 1 for images
6103            let grid_t = 1;
6104            let grid_h = max_image_shape.0 / cfg.patch_size;
6105            let grid_w = max_image_shape.1 / cfg.patch_size;
6106            grid_t * grid_h * grid_w
6107        };
6108
6109        let max_vision_attn = {
6110            let cfg = &cfg.vision_config;
6111            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6112        };
6113
6114        Ok(max_vision_attn)
6115    }
6116
6117    fn non_mapped_size_in_bytes(
6118        &self,
6119        config: &str,
6120        dtype: DType,
6121        weight_pack_factor: usize,
6122        _matformer_config: Option<&MatformerSliceConfig>,
6123    ) -> Result<usize> {
6124        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6125        let tie = cfg.tie_word_embeddings;
6126        let text_elems = {
6127            let cfg = &cfg.text_config;
6128            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6129            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
6130            let lm_head = if !tie || weight_pack_factor != 1 {
6131                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6132            } else {
6133                0
6134            };
6135            let norm = cfg.hidden_size;
6136            embed_tokens + lm_head + norm
6137        };
6138
6139        let (patch_merger, deepstack_mergers) = {
6140            let cfg = &cfg.vision_config;
6141            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6142
6143            let mlp0 = hidden_size * hidden_size + hidden_size;
6144            let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6145
6146            // Main merger: norm uses cfg.hidden_size
6147            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6148            let merger = mlp0 + mlp2 + ln_q;
6149
6150            // Deepstack mergers: norm uses merged hidden_size
6151            let ds_ln = hidden_size + bias_if!(true, hidden_size);
6152            let ds_merger = mlp0 + mlp2 + ds_ln;
6153            let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6154
6155            (merger, deepstack)
6156        };
6157
6158        let patch_embed = {
6159            let cfg = &cfg.vision_config;
6160            let conv_cfg = Conv3dConfig {
6161                stride: cfg.patch_size,
6162                ..Default::default()
6163            };
6164            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6165            let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6166                * kernel_sizes[0]
6167                * kernel_sizes[1]
6168                * kernel_sizes[2];
6169            let bias = cfg.hidden_size;
6170            weight + bias
6171        };
6172
6173        let pos_embed = {
6174            let cfg = &cfg.vision_config;
6175            cfg.num_position_embeddings * cfg.hidden_size
6176        };
6177
6178        let encoder_layer = {
6179            let cfg = &cfg.vision_config;
6180            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6181            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6182
6183            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
6184            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6185            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6186
6187            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6188            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6189
6190            norm1 + norm2 + fc1 + fc2 + qkv + out
6191        };
6192
6193        let elems = text_elems
6194            + patch_merger
6195            + deepstack_mergers
6196            + patch_embed
6197            + pos_embed
6198            + encoder_layer * cfg.vision_config.depth;
6199
6200        Ok(elems * dtype.size_in_bytes())
6201    }
6202
6203    fn layer_sizes_in_bytes(
6204        &self,
6205        config: &str,
6206        dtype: DType,
6207        weight_pack_factor: usize,
6208        _matformer_config: Option<&MatformerSliceConfig>,
6209    ) -> Result<Vec<usize>> {
6210        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6211        let text_cfg = &cfg.text_config;
6212
6213        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6214
6215        for layer_idx in 0..text_cfg.num_hidden_layers {
6216            let input_layernorm = text_cfg.hidden_size;
6217            let post_attention_layernorm = text_cfg.hidden_size;
6218
6219            let size_in = text_cfg.hidden_size;
6220            let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6221            let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6222            let q_proj = size_in * size_q / weight_pack_factor;
6223            let k_proj = size_in * size_kv / weight_pack_factor;
6224            let v_proj = size_in * size_kv / weight_pack_factor;
6225            let o_proj = size_q * size_in / weight_pack_factor;
6226
6227            let q_norm = text_cfg.head_dim;
6228            let k_norm = text_cfg.head_dim;
6229
6230            // Check if this is a MoE layer
6231            let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
6232                && (text_cfg.num_experts > 0
6233                    && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
6234
6235            let mlp_elems = if is_moe {
6236                // MoE layer: gate + experts
6237                let gate = text_cfg.hidden_size * text_cfg.num_experts;
6238                let per_expert = {
6239                    let h_size = text_cfg.hidden_size;
6240                    let i_size = text_cfg.moe_intermediate_size;
6241                    let gate_proj = h_size * i_size / weight_pack_factor;
6242                    let up_proj = h_size * i_size / weight_pack_factor;
6243                    let down_proj = i_size * h_size / weight_pack_factor;
6244                    gate_proj + up_proj + down_proj
6245                };
6246                gate + per_expert * text_cfg.num_experts
6247            } else {
6248                // Dense MLP layer
6249                let h_size = text_cfg.hidden_size;
6250                let i_size = text_cfg.intermediate_size;
6251                let gate_proj = h_size * i_size / weight_pack_factor;
6252                let up_proj = h_size * i_size / weight_pack_factor;
6253                let down_proj = i_size * h_size / weight_pack_factor;
6254                gate_proj + up_proj + down_proj
6255            };
6256
6257            let per_layer_elems = input_layernorm
6258                + post_attention_layernorm
6259                + q_proj
6260                + k_proj
6261                + v_proj
6262                + o_proj
6263                + q_norm
6264                + k_norm
6265                + mlp_elems;
6266
6267            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6268        }
6269
6270        Ok(layer_sizes)
6271    }
6272
6273    fn num_layers(&self, config: &str) -> Result<usize> {
6274        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6275        let cfg = &cfg.text_config;
6276        Ok(cfg.num_hidden_layers)
6277    }
6278
6279    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6280        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6281        let cfg = &cfg.text_config;
6282
6283        let cfg = ModelConfigMetadata {
6284            max_seq_len: cfg.max_position_embeddings,
6285            num_layers: cfg.num_hidden_layers,
6286            hidden_size: cfg.hidden_size,
6287            num_kv_heads: cfg.num_key_value_heads,
6288            num_attn_heads: cfg.num_attention_heads,
6289            sliding_window: cfg.sliding_window,
6290            k_head_dim: cfg.head_dim,
6291            v_head_dim: cfg.head_dim,
6292            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
6293        };
6294
6295        Ok(Box::new(cfg))
6296    }
6297
6298    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6299        Some(vec![NonMappedSubModel::Vision])
6300    }
6301}
6302
6303// ======================== Qwen3_5 (Dense) Loader
6304
/// [`MultimodalLoader`] for a Qwen3.5 dense (hybrid GDN + full attention) model.
///
/// Stateless marker type: the model-loading, ISQ and device-mapping behavior lives in its
/// trait implementations, which parse the JSON config string passed to each call.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen3_5Loader;
6309
/// Prompt prefixer for Qwen3.5 dense models.
///
/// The impl below overrides nothing, so every method uses the trait's default.
pub struct Qwen3_5Prefixer;

impl MultimodalPromptPrefixer for Qwen3_5Prefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}
6316
6317impl MultimodalModelLoader for Qwen3_5Loader {
6318    fn load(
6319        &self,
6320        config: &str,
6321        vb: ShardedVarBuilder,
6322        normal_loading_metadata: NormalLoadingMetadata,
6323        attention_mechanism: AttentionImplementation,
6324    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
6325        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6326        Ok(Box::new(Qwen3_5Model::new(
6327            &cfg,
6328            vb,
6329            self.is_gptx(config),
6330            normal_loading_metadata,
6331            attention_mechanism,
6332        )?))
6333    }
6334    fn is_gptx(&self, _config: &str) -> bool {
6335        true
6336    }
6337    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
6338        let config: Qwen3_5Config = serde_json::from_str(config)?;
6339        Ok(Box::new(config))
6340    }
6341    fn get_processor(
6342        &self,
6343        _model_config: &str,
6344        _processor_config: Option<ProcessorConfig>,
6345        _preprocessor_config: PreProcessorConfig,
6346        max_edge: Option<u32>,
6347    ) -> Arc<dyn Processor + Send + Sync> {
6348        Arc::new(Qwen3_5Processor::new(max_edge))
6349    }
6350    fn supports_paged_attention(&self, _config: &str) -> bool {
6351        true
6352    }
6353    fn supports_prefix_cacher(&self, _config: &str) -> bool {
6354        true
6355    }
6356    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
6357        Arc::new(Qwen3_5Prefixer)
6358    }
6359    fn modalities(&self, _config: &str) -> Result<Modalities> {
6360        Ok(Modalities {
6361            input: vec![SupportedModality::Text, SupportedModality::Vision],
6362            output: vec![SupportedModality::Text],
6363        })
6364    }
6365}
6366
6367impl IsqModelLoader for Qwen3_5Loader {
6368    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
6369        Ok(vec![
6370            Regex::new(r"lm_head\.(weight|bias)$")?,
6371            // Full attention projections
6372            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
6373            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
6374            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
6375            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
6376            // GDN linear attention output projection
6377            Regex::new(
6378                r"model\.language_model\.layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$",
6379            )?,
6380            // Dense MLP
6381            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
6382            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
6383            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6384        ])
6385    }
6386    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6387        self.isq_layer_regexes(config)
6388    }
6389}
6390
6391impl DeviceMappedModelLoader for Qwen3_5Loader {
6392    fn mapped_max_act_size_elems(
6393        &self,
6394        config: &str,
6395        params: &AutoDeviceMapParams,
6396    ) -> Result<usize> {
6397        let AutoDeviceMapParams::Multimodal {
6398            max_seq_len,
6399            max_batch_size,
6400            max_image_shape,
6401            max_num_images,
6402        } = params
6403        else {
6404            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6405        };
6406
6407        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6408
6409        let img_seq_len = {
6410            let cfg = &cfg.vision_config;
6411            let grid_t = 1;
6412            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6413            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6414            grid_t * grid_h * grid_w * max_num_images
6415        };
6416
6417        let max_text_attn = {
6418            let cfg = &cfg.text_config;
6419            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6420            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6421        };
6422
6423        Ok(max_text_attn)
6424    }
6425
6426    fn non_mapped_max_act_size_elems(
6427        &self,
6428        config: &str,
6429        params: &AutoDeviceMapParams,
6430    ) -> Result<usize> {
6431        let AutoDeviceMapParams::Multimodal {
6432            max_seq_len: _,
6433            max_batch_size,
6434            max_image_shape,
6435            max_num_images,
6436        } = params
6437        else {
6438            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6439        };
6440
6441        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6442
6443        let img_seq_len = {
6444            let cfg = &cfg.vision_config;
6445            let grid_t = 1;
6446            let grid_h = max_image_shape.0 / cfg.patch_size;
6447            let grid_w = max_image_shape.1 / cfg.patch_size;
6448            grid_t * grid_h * grid_w
6449        };
6450
6451        let max_vision_attn = {
6452            let cfg = &cfg.vision_config;
6453            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6454        };
6455
6456        Ok(max_vision_attn)
6457    }
6458
6459    fn non_mapped_size_in_bytes(
6460        &self,
6461        config: &str,
6462        dtype: DType,
6463        weight_pack_factor: usize,
6464        _matformer_config: Option<&MatformerSliceConfig>,
6465    ) -> Result<usize> {
6466        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6467        let tie = cfg.tie_word_embeddings;
6468        let text_elems = {
6469            let cfg = &cfg.text_config;
6470            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6471            let lm_head = if !tie || weight_pack_factor != 1 {
6472                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6473            } else {
6474                0
6475            };
6476            let norm = cfg.hidden_size;
6477            embed_tokens + lm_head + norm
6478        };
6479
6480        let (patch_merger, deepstack_mergers) = {
6481            let cfg = &cfg.vision_config;
6482            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6483
6484            let mlp0 = hidden_size * hidden_size + hidden_size;
6485            let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6486
6487            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6488            let merger = mlp0 + mlp2 + ln_q;
6489
6490            let ds_ln = hidden_size + bias_if!(true, hidden_size);
6491            let ds_merger = mlp0 + mlp2 + ds_ln;
6492            let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6493
6494            (merger, deepstack)
6495        };
6496
6497        let patch_embed = {
6498            let cfg = &cfg.vision_config;
6499            let conv_cfg = Conv3dConfig {
6500                stride: cfg.patch_size,
6501                ..Default::default()
6502            };
6503            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6504            let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6505                * kernel_sizes[0]
6506                * kernel_sizes[1]
6507                * kernel_sizes[2];
6508            let bias = cfg.hidden_size;
6509            weight + bias
6510        };
6511
6512        let pos_embed = {
6513            let cfg = &cfg.vision_config;
6514            cfg.num_position_embeddings * cfg.hidden_size
6515        };
6516
6517        let encoder_layer = {
6518            let cfg = &cfg.vision_config;
6519            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6520            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6521
6522            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6523            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6524
6525            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6526            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6527
6528            norm1 + norm2 + fc1 + fc2 + qkv + out
6529        };
6530
6531        let elems = text_elems
6532            + patch_merger
6533            + deepstack_mergers
6534            + patch_embed
6535            + pos_embed
6536            + encoder_layer * cfg.vision_config.depth;
6537
6538        Ok(elems * dtype.size_in_bytes())
6539    }
6540
6541    fn layer_sizes_in_bytes(
6542        &self,
6543        config: &str,
6544        dtype: DType,
6545        weight_pack_factor: usize,
6546        _matformer_config: Option<&MatformerSliceConfig>,
6547    ) -> Result<Vec<usize>> {
6548        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6549        let text_cfg = &cfg.text_config;
6550        let layer_types = text_cfg.layer_types();
6551
6552        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6553
6554        for layer_type in &layer_types {
6555            let input_layernorm = text_cfg.hidden_size;
6556            let post_attention_layernorm = text_cfg.hidden_size;
6557
6558            let attn_elems = match layer_type {
6559                crate::vision_models::qwen3_5::config::LayerType::FullAttention => {
6560                    let size_in = text_cfg.hidden_size;
6561                    let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6562                    let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6563                    let q_proj = size_in * size_q * 2 / weight_pack_factor;
6564                    let k_proj = size_in * size_kv / weight_pack_factor;
6565                    let v_proj = size_in * size_kv / weight_pack_factor;
6566                    let o_proj = size_q * size_in / weight_pack_factor;
6567                    let q_norm = text_cfg.head_dim;
6568                    let k_norm = text_cfg.head_dim;
6569                    q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
6570                }
6571                crate::vision_models::qwen3_5::config::LayerType::LinearAttention => {
6572                    let hidden = text_cfg.hidden_size;
6573                    let key_dim = text_cfg.linear_key_dim();
6574                    let value_dim = text_cfg.linear_value_dim();
6575                    let conv_dim = text_cfg.linear_conv_dim();
6576                    // in_proj_qkvz: (2 * key_dim + 2 * value_dim, hidden)
6577                    let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
6578                    // in_proj_ba: (2 * num_v_heads, hidden)
6579                    let in_proj_ba = hidden * (text_cfg.linear_num_value_heads * 2);
6580                    let out_proj = value_dim * hidden / weight_pack_factor;
6581                    let conv1d = conv_dim * text_cfg.linear_conv_kernel_dim;
6582                    let dt_bias = text_cfg.linear_num_value_heads;
6583                    let a_log = text_cfg.linear_num_value_heads;
6584                    // RmsNormGated over per-head value dim
6585                    let norm = text_cfg.linear_value_head_dim;
6586                    in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
6587                }
6588            };
6589
6590            // Dense MLP
6591            let mlp_elems = {
6592                let h_size = text_cfg.hidden_size;
6593                let i_size = text_cfg.intermediate_size;
6594                let gate_proj = h_size * i_size / weight_pack_factor;
6595                let up_proj = h_size * i_size / weight_pack_factor;
6596                let down_proj = i_size * h_size / weight_pack_factor;
6597                gate_proj + up_proj + down_proj
6598            };
6599
6600            let per_layer_elems =
6601                input_layernorm + post_attention_layernorm + attn_elems + mlp_elems;
6602
6603            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6604        }
6605
6606        Ok(layer_sizes)
6607    }
6608
6609    fn num_layers(&self, config: &str) -> Result<usize> {
6610        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6611        Ok(cfg.text_config.num_hidden_layers)
6612    }
6613
6614    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6615        let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6616        let cfg = &cfg.text_config;
6617
6618        let cfg = ModelConfigMetadata {
6619            max_seq_len: cfg.max_position_embeddings,
6620            num_layers: cfg.num_hidden_layers,
6621            hidden_size: cfg.hidden_size,
6622            num_kv_heads: cfg.num_key_value_heads,
6623            num_attn_heads: cfg.num_attention_heads,
6624            sliding_window: None,
6625            k_head_dim: cfg.head_dim,
6626            v_head_dim: cfg.head_dim,
6627            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
6628        };
6629
6630        Ok(Box::new(cfg))
6631    }
6632
6633    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6634        Some(vec![NonMappedSubModel::Vision])
6635    }
6636}
6637
6638// ======================== Qwen3_5Moe Loader
6639
/// [`MultimodalLoader`] for a Qwen3.5 MoE (hybrid GDN + full attention) model.
///
/// Stateless marker type: the model-loading, ISQ (including MoQE) and device-mapping
/// behavior lives in its trait implementations, which parse the JSON config string passed
/// to each call.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct Qwen3_5MoeLoader;
6644
/// Prompt prefixer for Qwen3.5 MoE models.
///
/// The impl below overrides nothing, so every method uses the trait's default.
pub struct Qwen3_5MoePrefixer;

impl MultimodalPromptPrefixer for Qwen3_5MoePrefixer {
    // No-op: With MessagesAction::Keep, the chat template handles image tokens
    // when it sees {"type": "image"} entries in the content.
}
6651
6652impl MultimodalModelLoader for Qwen3_5MoeLoader {
6653    fn load(
6654        &self,
6655        config: &str,
6656        vb: ShardedVarBuilder,
6657        normal_loading_metadata: NormalLoadingMetadata,
6658        attention_mechanism: AttentionImplementation,
6659    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
6660        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6661        Ok(Box::new(Qwen3_5MoeModel::new(
6662            &cfg,
6663            vb,
6664            self.is_gptx(config),
6665            normal_loading_metadata,
6666            attention_mechanism,
6667        )?))
6668    }
6669    fn is_gptx(&self, _config: &str) -> bool {
6670        true
6671    }
6672    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
6673        let config: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6674        Ok(Box::new(config))
6675    }
6676    fn get_processor(
6677        &self,
6678        _model_config: &str,
6679        _processor_config: Option<ProcessorConfig>,
6680        _preprocessor_config: PreProcessorConfig,
6681        max_edge: Option<u32>,
6682    ) -> Arc<dyn Processor + Send + Sync> {
6683        Arc::new(Qwen3_5MoeProcessor::new(max_edge))
6684    }
6685    fn supports_paged_attention(&self, _config: &str) -> bool {
6686        true
6687    }
6688    fn supports_prefix_cacher(&self, _config: &str) -> bool {
6689        true
6690    }
6691    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
6692        Arc::new(Qwen3_5MoePrefixer)
6693    }
6694    fn modalities(&self, _config: &str) -> Result<Modalities> {
6695        Ok(Modalities {
6696            input: vec![SupportedModality::Text, SupportedModality::Vision],
6697            output: vec![SupportedModality::Text],
6698        })
6699    }
6700}
6701
6702impl IsqModelLoader for Qwen3_5MoeLoader {
6703    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
6704        Ok(vec![
6705            Regex::new(r"lm_head\.(weight|bias)$")?,
6706            // Full attention projections
6707            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
6708            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
6709            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
6710            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
6711            // GDN linear attention output projection
6712            Regex::new(
6713                r"model\.language_model\.layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$",
6714            )?,
6715            // MoE experts
6716            Regex::new(
6717                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6718            )?,
6719            Regex::new(
6720                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6721            )?,
6722            Regex::new(
6723                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6724            )?,
6725            // Shared expert
6726            Regex::new(
6727                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.gate_proj\.(weight|bias)$",
6728            )?,
6729            Regex::new(
6730                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.up_proj\.(weight|bias)$",
6731            )?,
6732            Regex::new(
6733                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.down_proj\.(weight|bias)$",
6734            )?,
6735        ])
6736    }
6737    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6738        self.isq_layer_regexes(config)
6739    }
6740    fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
6741        Ok(vec![
6742            Regex::new(r"lm_head\.(weight|bias)$")?,
6743            // MoE experts
6744            Regex::new(
6745                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6746            )?,
6747            Regex::new(
6748                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6749            )?,
6750            Regex::new(
6751                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6752            )?,
6753            // Shared expert
6754            Regex::new(
6755                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.gate_proj\.(weight|bias)$",
6756            )?,
6757            Regex::new(
6758                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.up_proj\.(weight|bias)$",
6759            )?,
6760            Regex::new(
6761                r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.down_proj\.(weight|bias)$",
6762            )?,
6763        ])
6764    }
6765    fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
6766        self.isq_layer_regexes_moqe(config)
6767    }
6768}
6769
6770impl DeviceMappedModelLoader for Qwen3_5MoeLoader {
6771    fn mapped_max_act_size_elems(
6772        &self,
6773        config: &str,
6774        params: &AutoDeviceMapParams,
6775    ) -> Result<usize> {
6776        let AutoDeviceMapParams::Multimodal {
6777            max_seq_len,
6778            max_batch_size,
6779            max_image_shape,
6780            max_num_images,
6781        } = params
6782        else {
6783            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6784        };
6785
6786        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6787
6788        let img_seq_len = {
6789            let cfg = &cfg.vision_config;
6790            let grid_t = 1;
6791            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6792            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6793            grid_t * grid_h * grid_w * max_num_images
6794        };
6795
6796        let max_text_attn = {
6797            let cfg = &cfg.text_config;
6798            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6799            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6800        };
6801
6802        Ok(max_text_attn)
6803    }
6804
6805    fn non_mapped_max_act_size_elems(
6806        &self,
6807        config: &str,
6808        params: &AutoDeviceMapParams,
6809    ) -> Result<usize> {
6810        let AutoDeviceMapParams::Multimodal {
6811            max_seq_len: _,
6812            max_batch_size,
6813            max_image_shape,
6814            max_num_images,
6815        } = params
6816        else {
6817            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6818        };
6819
6820        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6821
6822        let img_seq_len = {
6823            let cfg = &cfg.vision_config;
6824            let grid_t = 1;
6825            let grid_h = max_image_shape.0 / cfg.patch_size;
6826            let grid_w = max_image_shape.1 / cfg.patch_size;
6827            grid_t * grid_h * grid_w
6828        };
6829
6830        let max_vision_attn = {
6831            let cfg = &cfg.vision_config;
6832            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6833        };
6834
6835        Ok(max_vision_attn)
6836    }
6837
6838    fn non_mapped_size_in_bytes(
6839        &self,
6840        config: &str,
6841        dtype: DType,
6842        weight_pack_factor: usize,
6843        _matformer_config: Option<&MatformerSliceConfig>,
6844    ) -> Result<usize> {
6845        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6846        let tie = cfg.tie_word_embeddings;
6847        let text_elems = {
6848            let cfg = &cfg.text_config;
6849            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6850            let lm_head = if !tie || weight_pack_factor != 1 {
6851                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6852            } else {
6853                0
6854            };
6855            let norm = cfg.hidden_size;
6856            embed_tokens + lm_head + norm
6857        };
6858
6859        let (patch_merger, deepstack_mergers) = {
6860            let cfg = &cfg.vision_config;
6861            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6862
6863            let mlp0 = hidden_size * hidden_size + hidden_size;
6864            let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6865
6866            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6867            let merger = mlp0 + mlp2 + ln_q;
6868
6869            let ds_ln = hidden_size + bias_if!(true, hidden_size);
6870            let ds_merger = mlp0 + mlp2 + ds_ln;
6871            let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6872
6873            (merger, deepstack)
6874        };
6875
6876        let patch_embed = {
6877            let cfg = &cfg.vision_config;
6878            let conv_cfg = Conv3dConfig {
6879                stride: cfg.patch_size,
6880                ..Default::default()
6881            };
6882            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6883            let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6884                * kernel_sizes[0]
6885                * kernel_sizes[1]
6886                * kernel_sizes[2];
6887            let bias = cfg.hidden_size;
6888            weight + bias
6889        };
6890
6891        let pos_embed = {
6892            let cfg = &cfg.vision_config;
6893            cfg.num_position_embeddings * cfg.hidden_size
6894        };
6895
6896        let encoder_layer = {
6897            let cfg = &cfg.vision_config;
6898            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6899            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6900
6901            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6902            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6903
6904            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6905            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6906
6907            norm1 + norm2 + fc1 + fc2 + qkv + out
6908        };
6909
6910        let elems = text_elems
6911            + patch_merger
6912            + deepstack_mergers
6913            + patch_embed
6914            + pos_embed
6915            + encoder_layer * cfg.vision_config.depth;
6916
6917        Ok(elems * dtype.size_in_bytes())
6918    }
6919
6920    fn layer_sizes_in_bytes(
6921        &self,
6922        config: &str,
6923        dtype: DType,
6924        weight_pack_factor: usize,
6925        _matformer_config: Option<&MatformerSliceConfig>,
6926    ) -> Result<Vec<usize>> {
6927        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6928        let text_cfg = &cfg.text_config;
6929        let layer_types = text_cfg.layer_types();
6930
6931        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6932
6933        for layer_type in &layer_types {
6934            let input_layernorm = text_cfg.hidden_size;
6935            let post_attention_layernorm = text_cfg.hidden_size;
6936
6937            let attn_elems = match layer_type {
6938                crate::vision_models::qwen3_5_moe::config::LayerType::FullAttention => {
6939                    let size_in = text_cfg.hidden_size;
6940                    let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6941                    let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6942                    let q_proj = size_in * size_q * 2 / weight_pack_factor;
6943                    let k_proj = size_in * size_kv / weight_pack_factor;
6944                    let v_proj = size_in * size_kv / weight_pack_factor;
6945                    let o_proj = size_q * size_in / weight_pack_factor;
6946                    let q_norm = text_cfg.head_dim;
6947                    let k_norm = text_cfg.head_dim;
6948                    q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
6949                }
6950                crate::vision_models::qwen3_5_moe::config::LayerType::LinearAttention => {
6951                    let hidden = text_cfg.hidden_size;
6952                    let key_dim = text_cfg.linear_key_dim();
6953                    let value_dim = text_cfg.linear_value_dim();
6954                    let conv_dim = text_cfg.linear_conv_dim();
6955                    // in_proj_qkvz: (2 * key_dim + 2 * value_dim, hidden)
6956                    let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
6957                    // in_proj_ba: (2 * num_v_heads, hidden)
6958                    let in_proj_ba = hidden * (text_cfg.linear_num_value_heads * 2);
6959                    // out_proj: value_dim -> hidden
6960                    let out_proj = value_dim * hidden / weight_pack_factor;
6961                    // conv1d weight
6962                    let conv1d = conv_dim * text_cfg.linear_conv_kernel_dim;
6963                    // dt_bias, A_log, norm weight
6964                    let dt_bias = text_cfg.linear_num_value_heads;
6965                    let a_log = text_cfg.linear_num_value_heads;
6966                    // RmsNormGated over per-head value dim
6967                    let norm = text_cfg.linear_value_head_dim;
6968                    in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
6969                }
6970            };
6971
6972            // All layers have MoE
6973            let moe_elems = {
6974                let gate = text_cfg.hidden_size * text_cfg.num_experts;
6975                let per_expert = {
6976                    let h_size = text_cfg.hidden_size;
6977                    let i_size = text_cfg.moe_intermediate_size;
6978                    let gate_proj = h_size * i_size / weight_pack_factor;
6979                    let up_proj = h_size * i_size / weight_pack_factor;
6980                    let down_proj = i_size * h_size / weight_pack_factor;
6981                    gate_proj + up_proj + down_proj
6982                };
6983                let shared_expert = {
6984                    let h_size = text_cfg.hidden_size;
6985                    let i_size = text_cfg.shared_expert_intermediate_size;
6986                    let gate_proj = h_size * i_size / weight_pack_factor;
6987                    let up_proj = h_size * i_size / weight_pack_factor;
6988                    let down_proj = i_size * h_size / weight_pack_factor;
6989                    gate_proj + up_proj + down_proj
6990                };
6991                let shared_expert_gate = text_cfg.hidden_size;
6992                gate + per_expert * text_cfg.num_experts + shared_expert + shared_expert_gate
6993            };
6994
6995            let per_layer_elems =
6996                input_layernorm + post_attention_layernorm + attn_elems + moe_elems;
6997
6998            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6999        }
7000
7001        Ok(layer_sizes)
7002    }
7003
7004    fn num_layers(&self, config: &str) -> Result<usize> {
7005        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
7006        Ok(cfg.text_config.num_hidden_layers)
7007    }
7008
7009    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7010        let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
7011        let cfg = &cfg.text_config;
7012
7013        let cfg = ModelConfigMetadata {
7014            max_seq_len: cfg.max_position_embeddings,
7015            num_layers: cfg.num_hidden_layers,
7016            hidden_size: cfg.hidden_size,
7017            num_kv_heads: cfg.num_key_value_heads,
7018            num_attn_heads: cfg.num_attention_heads,
7019            sliding_window: None,
7020            k_head_dim: cfg.head_dim,
7021            v_head_dim: cfg.head_dim,
7022            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7023        };
7024
7025        Ok(Box::new(cfg))
7026    }
7027
7028    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
7029        Some(vec![NonMappedSubModel::Vision])
7030    }
7031}
7032
7033// ─── Voxtral ────────────────────────────────────────────────────────────────
7034
/// [`MultimodalLoader`] for a Voxtral model.
///
/// Voxtral takes text and audio as input and produces text. This is a unit struct; all of its
/// behavior lives in the `MultimodalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader`
/// impls for it.
///
/// [`MultimodalLoader`]: https://docs.rs/hanzo/latest/hanzo/struct.MultimodalLoader.html
pub struct VoxtralLoader;
7039
7040pub struct VoxtralPrefixer;
7041
7042impl MultimodalPromptPrefixer for VoxtralPrefixer {
7043    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
7044        prompt.to_string()
7045    }
7046}
7047
7048impl MultimodalModelLoader for VoxtralLoader {
7049    fn load(
7050        &self,
7051        config: &str,
7052        vb: ShardedVarBuilder,
7053        normal_loading_metadata: NormalLoadingMetadata,
7054        attention_mechanism: AttentionImplementation,
7055    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
7056        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7057        Ok(Box::new(VoxtralModel::new(
7058            &cfg,
7059            vb,
7060            self.is_gptx(config),
7061            normal_loading_metadata,
7062            attention_mechanism,
7063        )?))
7064    }
7065    fn is_gptx(&self, _config: &str) -> bool {
7066        true
7067    }
7068    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
7069        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7070        Ok(Box::new(cfg))
7071    }
7072    fn get_processor(
7073        &self,
7074        model_config: &str,
7075        _processor_config: Option<ProcessorConfig>,
7076        _preprocessor_config: PreProcessorConfig,
7077        _max_edge: Option<u32>,
7078    ) -> Arc<dyn Processor + Send + Sync> {
7079        let cfg: VoxtralConfig =
7080            serde_json::from_str(model_config).expect("Failed to parse VoxtralConfig");
7081        Arc::new(VoxtralProcessor::new(&cfg))
7082    }
7083    fn supports_paged_attention(&self, _config: &str) -> bool {
7084        false
7085    }
7086    fn supports_prefix_cacher(&self, _config: &str) -> bool {
7087        false
7088    }
7089    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
7090        Arc::new(VoxtralPrefixer)
7091    }
7092    fn modalities(&self, _config: &str) -> Result<Modalities> {
7093        Ok(Modalities {
7094            input: vec![SupportedModality::Text, SupportedModality::Audio],
7095            output: vec![SupportedModality::Text],
7096        })
7097    }
7098    fn default_chat_template(&self, _config: &str) -> Option<String> {
7099        // Mistral v7 instruct format using [INST]/[/INST] tokens
7100        Some("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string())
7101    }
7102    fn default_bos_eos(&self, _config: &str) -> Option<(String, String)> {
7103        // Mistral tekken tokenizer: <s> = ID 1, </s> = ID 2
7104        Some(("<s>".to_string(), "</s>".to_string()))
7105    }
7106}
7107
7108impl IsqModelLoader for VoxtralLoader {
7109    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
7110        Ok(vec![
7111            // Output / lm_head (tied with tok_embeddings)
7112            Regex::new(r"lm_head\.(weight|bias)$")?,
7113            // Decoder attention (Mistral-native naming)
7114            Regex::new(r"layers\.(\d+)\.attention\.wq\.(weight|bias)$")?,
7115            Regex::new(r"layers\.(\d+)\.attention\.wk\.(weight|bias)$")?,
7116            Regex::new(r"layers\.(\d+)\.attention\.wv\.(weight|bias)$")?,
7117            Regex::new(r"layers\.(\d+)\.attention\.wo\.(weight|bias)$")?,
7118            // Decoder MLP (Mistral-native naming)
7119            Regex::new(r"layers\.(\d+)\.feed_forward\.w1\.(weight|bias)$")?,
7120            Regex::new(r"layers\.(\d+)\.feed_forward\.w3\.(weight|bias)$")?,
7121            Regex::new(r"layers\.(\d+)\.feed_forward\.w2\.(weight|bias)$")?,
7122        ])
7123    }
7124    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
7125        Ok(vec![
7126            Regex::new(r"tok_embeddings\.(weight|bias)$")?,
7127            // Decoder attention
7128            Regex::new(r"layers\.(\d+)\.attention\.wq\.(weight|bias)$")?,
7129            Regex::new(r"layers\.(\d+)\.attention\.wk\.(weight|bias)$")?,
7130            Regex::new(r"layers\.(\d+)\.attention\.wv\.(weight|bias)$")?,
7131            Regex::new(r"layers\.(\d+)\.attention\.wo\.(weight|bias)$")?,
7132            // Decoder MLP
7133            Regex::new(r"layers\.(\d+)\.feed_forward\.w1\.(weight|bias)$")?,
7134            Regex::new(r"layers\.(\d+)\.feed_forward\.w3\.(weight|bias)$")?,
7135            Regex::new(r"layers\.(\d+)\.feed_forward\.w2\.(weight|bias)$")?,
7136        ])
7137    }
7138}
7139
7140#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
7141impl DeviceMappedModelLoader for VoxtralLoader {
7142    fn mapped_max_act_size_elems(
7143        &self,
7144        config: &str,
7145        params: &AutoDeviceMapParams,
7146    ) -> Result<usize> {
7147        let AutoDeviceMapParams::Multimodal {
7148            max_seq_len,
7149            max_batch_size,
7150            ..
7151        } = params
7152        else {
7153            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7154        };
7155
7156        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7157
7158        // Audio tokens are prepended: max audio len + text seq len
7159        // Audio: ~30s at 16kHz = 480k samples, /160 hop = 3000 frames, /2 conv stride = 1500, /4 adapter = 375 tokens
7160        let max_audio_tokens = 375;
7161        let total_seq = max_audio_tokens + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
7162        Ok(max_batch_size * cfg.n_heads * total_seq * total_seq)
7163    }
7164
7165    fn non_mapped_max_act_size_elems(
7166        &self,
7167        config: &str,
7168        params: &AutoDeviceMapParams,
7169    ) -> Result<usize> {
7170        let AutoDeviceMapParams::Multimodal { max_batch_size, .. } = params else {
7171            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7172        };
7173
7174        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7175        let enc = &cfg.multimodal.whisper_model_args.encoder_args;
7176        // Encoder max activation: attention matrix
7177        // ~3000 mel frames, encoder has 32 heads, seq_len^2
7178        let max_enc_seq = 3000usize;
7179        Ok(max_batch_size * enc.n_heads * max_enc_seq * max_enc_seq)
7180    }
7181
7182    fn non_mapped_size_in_bytes(
7183        &self,
7184        config: &str,
7185        dtype: DType,
7186        _weight_pack_factor: usize,
7187        _matformer_config: Option<&MatformerSliceConfig>,
7188    ) -> Result<usize> {
7189        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7190        let enc = &cfg.multimodal.whisper_model_args.encoder_args;
7191        let ds = &cfg.multimodal.whisper_model_args.downsample_args;
7192
7193        let elem = dtype.size_in_bytes();
7194
7195        // Encoder conv layers
7196        let conv1 = enc.dim * enc.audio_encoding_args.num_mel_bins * 3 + enc.dim; // weight + bias
7197        let conv2 = enc.dim * enc.dim * 3 + enc.dim;
7198
7199        // Encoder layers
7200        let enc_attn_per_layer = 4 * enc.dim * enc.dim; // wq, wk, wv, wo (full heads)
7201        let enc_mlp_per_layer = 3 * enc.dim * enc.hidden_dim; // w1, w2, w3
7202        let enc_norm_per_layer = 2 * enc.dim; // attention_norm, ffn_norm
7203        let enc_layers =
7204            enc.n_layers * (enc_attn_per_layer + enc_mlp_per_layer + enc_norm_per_layer);
7205        let enc_final_norm = enc.dim;
7206
7207        // Adapter
7208        let adapter_in_features = enc.dim * ds.downsample_factor;
7209        let adapter = adapter_in_features * cfg.dim + cfg.dim + cfg.dim * cfg.dim + cfg.dim;
7210
7211        let total_encoder = conv1 + conv2 + enc_layers + enc_final_norm + adapter;
7212
7213        // Decoder embeddings
7214        let embeddings = cfg.vocab_size * cfg.dim;
7215
7216        Ok((total_encoder + embeddings) * elem)
7217    }
7218
7219    fn layer_sizes_in_bytes(
7220        &self,
7221        config: &str,
7222        dtype: DType,
7223        weight_pack_factor: usize,
7224        _matformer_config: Option<&MatformerSliceConfig>,
7225    ) -> Result<Vec<usize>> {
7226        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7227        let elem = dtype.size_in_bytes();
7228
7229        let attn = (cfg.dim * cfg.n_heads * cfg.head_dim
7230            + cfg.dim * cfg.n_kv_heads * cfg.head_dim
7231            + cfg.dim * cfg.n_kv_heads * cfg.head_dim
7232            + cfg.n_heads * cfg.head_dim * cfg.dim)
7233            / weight_pack_factor;
7234        let mlp = (cfg.dim * cfg.hidden_dim + cfg.hidden_dim * cfg.dim + cfg.dim * cfg.hidden_dim)
7235            / weight_pack_factor;
7236        let norms = 2 * cfg.dim; // attention_norm + ffn_norm
7237
7238        let per_layer = (attn + mlp + norms) * elem;
7239
7240        Ok(vec![per_layer; cfg.n_layers])
7241    }
7242
7243    fn num_layers(&self, config: &str) -> Result<usize> {
7244        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7245        Ok(cfg.n_layers)
7246    }
7247
7248    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7249        let cfg: VoxtralConfig = serde_json::from_str(config)?;
7250
7251        let cfg = ModelConfigMetadata {
7252            max_seq_len: cfg.model_max_length,
7253            num_layers: cfg.n_layers,
7254            hidden_size: cfg.dim,
7255            num_kv_heads: cfg.n_kv_heads,
7256            num_attn_heads: cfg.n_heads,
7257            sliding_window: cfg.sliding_window,
7258            k_head_dim: cfg.head_dim,
7259            v_head_dim: cfg.head_dim,
7260            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7261        };
7262
7263        Ok(Box::new(cfg))
7264    }
7265}
7266
7267// ── Gemma4 ─────────────────────────────────────────────────────────────────
7268
/// Multimodal loader for Gemma 4 models.
///
/// Inputs are text, images and video, plus audio when the config has an `audio_config`; the
/// output is text.
pub struct Gemma4Loader;
7270
7271#[allow(dead_code)]
7272pub struct Gemma4Prefixer;
7273
7274impl MultimodalPromptPrefixer for Gemma4Prefixer {
7275    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
7276        prompt.to_string()
7277    }
7278    fn prefix_video(&self, _video_indexes: Vec<usize>, prompt: &str) -> String {
7279        prompt.to_string()
7280    }
7281}
7282
7283impl MultimodalModelLoader for Gemma4Loader {
7284    fn load(
7285        &self,
7286        config: &str,
7287        vb: ShardedVarBuilder,
7288        normal_loading_metadata: NormalLoadingMetadata,
7289        attention_mechanism: AttentionImplementation,
7290    ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
7291        let cfg: Gemma4Config = serde_json::from_str(config)?;
7292        Ok(Box::new(Gemma4Model::new(
7293            &cfg,
7294            vb,
7295            self.is_gptx(config),
7296            normal_loading_metadata,
7297            attention_mechanism,
7298        )?))
7299    }
7300    fn is_gptx(&self, _config: &str) -> bool {
7301        true
7302    }
7303    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
7304        let config: Gemma4Config = serde_json::from_str(config)?;
7305        Ok(Box::new(config))
7306    }
7307    fn get_processor(
7308        &self,
7309        config: &str,
7310        processor_config: Option<ProcessorConfig>,
7311        _preprocessor_config: PreProcessorConfig,
7312        _max_edge: Option<u32>,
7313    ) -> Arc<dyn Processor + Send + Sync> {
7314        let cfg: Gemma4Config = serde_json::from_str(config).expect("Failed to parse Gemma4Config");
7315        Arc::new(Gemma4Processor::new(
7316            processor_config.unwrap_or_default(),
7317            cfg.vision_config.patch_size,
7318            cfg.vision_config.pooling_kernel_size,
7319            cfg.vision_config.default_output_length,
7320            true,
7321            cfg.audio_config.is_some(),
7322        ))
7323    }
7324    fn supports_paged_attention(&self, _config: &str) -> bool {
7325        true
7326    }
7327    fn supports_prefix_cacher(&self, _config: &str) -> bool {
7328        true
7329    }
7330    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
7331        Arc::new(Gemma4Prefixer)
7332    }
7333    fn modalities(&self, config: &str) -> Result<Modalities> {
7334        let cfg: Gemma4Config = serde_json::from_str(config)?;
7335        let mut input = vec![
7336            SupportedModality::Text,
7337            SupportedModality::Vision,
7338            SupportedModality::Video,
7339        ];
7340        if cfg.audio_config.is_some() {
7341            input.push(SupportedModality::Audio);
7342        }
7343        Ok(Modalities {
7344            input,
7345            output: vec![SupportedModality::Text],
7346        })
7347    }
7348}
7349
7350impl IsqModelLoader for Gemma4Loader {
7351    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
7352        // `embed_vision.embedding_projection` is intentionally excluded.
7353        Ok(vec![
7354            Regex::new(r"lm_head\.(weight|bias)$")?,
7355            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
7356            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
7357            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
7358            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
7359            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
7360            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
7361            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
7362            Regex::new(r"layers\.(\d+)\.moe\.gate_up_proj\.weight$")?,
7363            Regex::new(r"layers\.(\d+)\.moe\.down_proj\.weight$")?,
7364            Regex::new(r"layers\.(\d+)\.experts\.gate_up_proj\.weight$")?,
7365            Regex::new(r"layers\.(\d+)\.experts\.down_proj\.weight$")?,
7366            Regex::new(r"per_layer_model_projection\.(weight|bias)$")?,
7367            Regex::new(r"layers\.(\d+)\.per_layer_input_gate\.(weight|bias)$")?,
7368            Regex::new(r"layers\.(\d+)\.per_layer_projection\.(weight|bias)$")?,
7369        ])
7370    }
7371    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
7372        Ok(vec![
7373            Regex::new(r"lm_head\.(weight|bias)$")?,
7374            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
7375            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
7376            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
7377            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
7378            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
7379            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
7380            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
7381            Regex::new(r"model\.language_model\.layers\.(\d+)\.moe\.gate_up_proj\.weight$")?,
7382            Regex::new(r"model\.language_model\.layers\.(\d+)\.moe\.down_proj\.weight$")?,
7383            Regex::new(r"model\.language_model\.layers\.(\d+)\.experts\.gate_up_proj\.weight$")?,
7384            Regex::new(r"model\.language_model\.layers\.(\d+)\.experts\.down_proj\.weight$")?,
7385            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
7386            Regex::new(
7387                r"model\.language_model\.layers\.(\d+)\.per_layer_input_gate\.(weight|bias)$",
7388            )?,
7389            Regex::new(
7390                r"model\.language_model\.layers\.(\d+)\.per_layer_projection\.(weight|bias)$",
7391            )?,
7392        ])
7393    }
7394}
7395
7396impl DeviceMappedModelLoader for Gemma4Loader {
7397    fn mapped_max_act_size_elems(
7398        &self,
7399        config: &str,
7400        params: &AutoDeviceMapParams,
7401    ) -> Result<usize> {
7402        let AutoDeviceMapParams::Multimodal {
7403            max_seq_len,
7404            max_batch_size,
7405            max_image_shape: _,
7406            max_num_images,
7407        } = params
7408        else {
7409            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7410        };
7411
7412        let cfg: Gemma4Config = serde_json::from_str(config)?;
7413        let tc = &cfg.text_config;
7414
7415        let vision_tokens_per_image = cfg.vision_soft_tokens_per_image.unwrap_or(280);
7416        let audio_tokens = if cfg.audio_config.is_some() { 750 } else { 0 };
7417        let total_seq_len = *max_seq_len + vision_tokens_per_image * max_num_images + audio_tokens;
7418        let max_text_attn = max_batch_size * tc.num_attention_heads * total_seq_len * total_seq_len;
7419
7420        Ok(max_text_attn)
7421    }
7422
7423    fn non_mapped_max_act_size_elems(
7424        &self,
7425        config: &str,
7426        params: &AutoDeviceMapParams,
7427    ) -> Result<usize> {
7428        let AutoDeviceMapParams::Multimodal {
7429            max_seq_len: _,
7430            max_batch_size,
7431            max_image_shape: _,
7432            max_num_images,
7433        } = params
7434        else {
7435            anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7436        };
7437
7438        let cfg: Gemma4Config = serde_json::from_str(config)?;
7439        let vc = &cfg.vision_config;
7440
7441        let max_patches =
7442            vc.default_output_length * vc.pooling_kernel_size * vc.pooling_kernel_size;
7443        let max_vision_attn =
7444            max_batch_size * max_num_images * vc.num_attention_heads * max_patches * max_patches;
7445        let max_vision_hidden = max_batch_size
7446            * max_num_images
7447            * max_patches
7448            * vc.hidden_size.max(vc.intermediate_size);
7449
7450        let max_audio_activation = cfg.audio_config.as_ref().map_or(0, |audio_cfg| {
7451            let subsample_factor: usize = audio_cfg
7452                .sscp_conv_stride_size
7453                .iter()
7454                .map(|stride| stride[0])
7455                .product();
7456            let max_audio_frames = 750 * subsample_factor.max(1);
7457            let audio_seq_after_subsample = max_audio_frames / subsample_factor.max(1);
7458
7459            let audio_encoder_act = audio_seq_after_subsample * (audio_cfg.hidden_size * 4);
7460            let chunk_size = audio_cfg.conf_attention_chunk_size;
7461            let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
7462                + audio_cfg.conf_attention_context_right;
7463            let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
7464            let audio_attn_act =
7465                audio_cfg.conf_num_attention_heads * num_chunks * chunk_size * context_size;
7466
7467            max_batch_size * audio_encoder_act.max(audio_attn_act)
7468        });
7469
7470        Ok(max_vision_attn
7471            .max(max_vision_hidden)
7472            .max(max_audio_activation))
7473    }
7474
7475    fn non_mapped_size_in_bytes(
7476        &self,
7477        config: &str,
7478        dtype: DType,
7479        weight_pack_factor: usize,
7480        _matformer_config: Option<&MatformerSliceConfig>,
7481    ) -> Result<usize> {
7482        let cfg: Gemma4Config = serde_json::from_str(config)?;
7483        let tc = &cfg.text_config;
7484        let vc = &cfg.vision_config;
7485
7486        let text_elems = {
7487            let embed_tokens = tc.hidden_size * tc.vocab_size;
7488            let lm_head = if !tc.tie_word_embeddings || weight_pack_factor != 1 {
7489                tc.hidden_size * tc.vocab_size / weight_pack_factor
7490            } else {
7491                0
7492            };
7493            let norm = tc.hidden_size;
7494
7495            let ple_dim = tc.hidden_size_per_layer_input.unwrap_or(0);
7496            let ple_vocab = tc.vocab_size_per_layer_input.unwrap_or(tc.vocab_size);
7497            let embed_tokens_per_layer = if ple_dim > 0 {
7498                ple_vocab * tc.num_hidden_layers * ple_dim
7499            } else {
7500                0
7501            };
7502            let per_layer_model_projection = if ple_dim > 0 {
7503                tc.hidden_size * tc.num_hidden_layers * ple_dim / weight_pack_factor
7504            } else {
7505                0
7506            };
7507            let per_layer_projection_norm = ple_dim;
7508
7509            embed_tokens
7510                + lm_head
7511                + norm
7512                + embed_tokens_per_layer
7513                + per_layer_model_projection
7514                + per_layer_projection_norm
7515        };
7516
7517        let vision_layer_elems = {
7518            let quantized = vc.hidden_size * vc.num_attention_heads * vc.head_dim
7519                + 3 * (vc.hidden_size * vc.num_key_value_heads * vc.head_dim)
7520                + 2 * (vc.hidden_size * vc.intermediate_size)
7521                + vc.intermediate_size * vc.hidden_size;
7522            let norms = 2 * vc.head_dim + 4 * vc.hidden_size;
7523            quantized / weight_pack_factor + norms
7524        };
7525        let vision_elems = {
7526            let patch_embed = vc.patch_size * vc.patch_size * 3 * vc.hidden_size;
7527            let position_embedding_table = 2 * vc.position_embedding_size * vc.hidden_size;
7528            let patch_embedder = patch_embed / weight_pack_factor + position_embedding_table;
7529            let encoder = vc.num_hidden_layers * vision_layer_elems;
7530            let embed_vision = vc.hidden_size * tc.hidden_size / weight_pack_factor;
7531
7532            patch_embedder + encoder + embed_vision
7533        };
7534
7535        let audio_elems = cfg.audio_config.as_ref().map_or(0, |audio_cfg| {
7536            let mut f_out = audio_cfg.input_feat_size;
7537            for i in 0..2 {
7538                let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
7539                let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
7540                let pad_left = 1;
7541                let pad_right = 1;
7542                f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
7543            }
7544
7545            let subsample_conv_projection = {
7546                let conv_0 = audio_cfg.sscp_conv_channel_size[0]
7547                    * audio_cfg.sscp_conv_kernel_size[0][0]
7548                    * audio_cfg.sscp_conv_kernel_size[0][1];
7549                let conv_1 = audio_cfg.sscp_conv_channel_size[0]
7550                    * audio_cfg.sscp_conv_channel_size[1]
7551                    * audio_cfg.sscp_conv_kernel_size[1][0]
7552                    * audio_cfg.sscp_conv_kernel_size[1][1];
7553                let norms =
7554                    audio_cfg.sscp_conv_channel_size[0] + audio_cfg.sscp_conv_channel_size[1];
7555                let input_proj =
7556                    audio_cfg.sscp_conv_channel_size[1] * f_out * audio_cfg.hidden_size
7557                        / weight_pack_factor;
7558                conv_0 + conv_1 + norms + input_proj
7559            };
7560
7561            let conformer_block = {
7562                let attention = 5 * (audio_cfg.hidden_size * audio_cfg.hidden_size)
7563                    / weight_pack_factor
7564                    + 2 * audio_cfg.hidden_size
7565                    + audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads
7566                    + audio_cfg.hidden_size / 2
7567                    + (audio_cfg.conf_attention_context_left
7568                        + audio_cfg.conf_attention_context_right
7569                        + 1)
7570                    + (audio_cfg.conf_attention_chunk_size
7571                        * (audio_cfg.conf_attention_chunk_size
7572                            + audio_cfg.conf_attention_context_left
7573                            - 1
7574                            + audio_cfg.conf_attention_context_right))
7575                    + 1;
7576                let ffw = 2
7577                    * (2 * audio_cfg.hidden_size
7578                        + 2 * (audio_cfg.hidden_size * (audio_cfg.hidden_size * 4))
7579                            / weight_pack_factor);
7580                let conv = 2 * audio_cfg.hidden_size
7581                    + audio_cfg.hidden_size * (audio_cfg.hidden_size * 2) / weight_pack_factor
7582                    + audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor
7583                    + audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
7584                attention + ffw + conv + audio_cfg.hidden_size
7585            };
7586
7587            let output_proj = audio_cfg.output_proj_dims.map_or(0, |output_dim| {
7588                audio_cfg.hidden_size * output_dim / weight_pack_factor + output_dim
7589            });
7590            let audio_embed_hidden = audio_cfg.output_proj_dims.unwrap_or(audio_cfg.hidden_size);
7591            let embed_audio = audio_embed_hidden * tc.hidden_size / weight_pack_factor;
7592
7593            subsample_conv_projection
7594                + audio_cfg.conf_num_hidden_layers * conformer_block
7595                + output_proj
7596                + embed_audio
7597        });
7598
7599        let vision_dtype = if dtype == DType::F16 {
7600            DType::F32
7601        } else {
7602            dtype
7603        };
7604
7605        Ok(text_elems * dtype.size_in_bytes()
7606            + vision_elems * vision_dtype.size_in_bytes()
7607            + audio_elems * dtype.size_in_bytes())
7608    }
7609
7610    fn layer_sizes_in_bytes(
7611        &self,
7612        config: &str,
7613        dtype: DType,
7614        weight_pack_factor: usize,
7615        _matformer_config: Option<&MatformerSliceConfig>,
7616    ) -> Result<Vec<usize>> {
7617        let cfg: Gemma4Config = serde_json::from_str(config)?;
7618        let tc = &cfg.text_config;
7619        let sizes: Vec<usize> = (0..tc.num_hidden_layers)
7620            .map(|layer_idx| {
7621                let is_sliding = {
7622                    let is_last = layer_idx == tc.num_hidden_layers - 1;
7623                    !is_last && (layer_idx + 1) % tc.sliding_window_pattern != 0
7624                };
7625                let hd = if is_sliding {
7626                    tc.head_dim
7627                } else {
7628                    tc.global_head_dim
7629                };
7630                let nkv = if is_sliding {
7631                    tc.num_key_value_heads
7632                } else {
7633                    tc.num_global_key_value_heads
7634                        .unwrap_or(tc.num_key_value_heads)
7635                };
7636                let use_k_eq_v = tc.attention_k_eq_v && !is_sliding;
7637
7638                let mut attn = tc.hidden_size * tc.num_attention_heads * hd
7639                    + tc.hidden_size * nkv * hd
7640                    + tc.num_attention_heads * hd * tc.hidden_size;
7641                if !use_k_eq_v {
7642                    attn += tc.hidden_size * nkv * hd;
7643                }
7644                attn += 2 * hd;
7645
7646                let mlp = 3 * tc.hidden_size * tc.intermediate_size;
7647
7648                let moe = if tc.enable_moe_block {
7649                    let ne = tc.num_experts.unwrap_or(0);
7650                    let ei = tc.expert_intermediate_size().unwrap_or(0);
7651                    ne * tc.hidden_size * ei * 2
7652                        + ne * ei * tc.hidden_size
7653                        + ne
7654                        + ne * tc.hidden_size
7655                        + tc.hidden_size
7656                        + 3 * tc.hidden_size
7657                } else {
7658                    0
7659                };
7660
7661                let ple = if tc.hidden_size_per_layer_input.unwrap_or(0) > 0 {
7662                    let pd = tc.hidden_size_per_layer_input.unwrap();
7663                    tc.hidden_size * pd + pd * tc.hidden_size + tc.hidden_size
7664                } else {
7665                    0
7666                };
7667
7668                let norms = 4 * tc.hidden_size + 1;
7669
7670                (attn + mlp + moe + ple + norms) * dtype.size_in_bytes() / weight_pack_factor
7671            })
7672            .collect();
7673        Ok(sizes)
7674    }
7675
7676    fn num_layers(&self, config: &str) -> Result<usize> {
7677        let cfg: Gemma4Config = serde_json::from_str(config)?;
7678        Ok(cfg.text_config.num_hidden_layers)
7679    }
7680
7681    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
7682        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
7683    }
7684
7685    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7686        let cfg: Gemma4Config = serde_json::from_str(config)?;
7687        let tc = &cfg.text_config;
7688
7689        let cfg = ModelConfigMetadata {
7690            max_seq_len: tc.max_position_embeddings,
7691            num_layers: tc.num_hidden_layers,
7692            hidden_size: tc.hidden_size,
7693            num_kv_heads: tc.num_key_value_heads,
7694            num_attn_heads: tc.num_attention_heads,
7695            sliding_window: Some(tc.sliding_window),
7696            k_head_dim: tc.global_head_dim,
7697            v_head_dim: tc.global_head_dim,
7698            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7699        };
7700
7701        Ok(Box::new(cfg))
7702    }
7703}