Skip to main content

mistralrs_core/pipeline/loaders/
vision_loaders.rs

1use std::any::Any;
2use std::sync::Arc;
3use std::{fmt::Debug, str::FromStr};
4
5use anyhow::Result;
6use candle_core::{DType, Device, Tensor, D};
7use candle_nn::Conv2dConfig;
8use image::{ColorType, DynamicImage};
9use itertools::Itertools;
10use mistralrs_quant::log::once_log_info;
11use mistralrs_quant::ShardedVarBuilder;
12
13#[cfg(feature = "pyo3_macros")]
14use pyo3::pyclass;
15
16use regex::Regex;
17use serde::Deserialize;
18
19use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};
20
21use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
22use crate::amoe::AnyMoeBaseModelMixin;
23use crate::attention::ATTENTION_CHUNK_SIZE;
24use crate::device_map::DeviceMapper;
25use crate::layers::Conv3dConfig;
26use crate::matformer::MatformerSliceConfig;
27use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
28use crate::pipeline::isq::IsqModelLoader;
29use crate::pipeline::loaders::AutoDeviceMapParams;
30use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
31use crate::pipeline::{
32    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
33    SupportedModality,
34};
35use crate::utils::varbuilder_utils::DeviceForLoadTensor;
36use crate::vision_models::clip::ClipConfig;
37use crate::vision_models::gemma3::config::Gemma3Config;
38use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
39use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
40use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
41use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
42use crate::vision_models::idefics2_input_processor::Idefics2Processor;
43use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
44use crate::vision_models::image_processor::ImagePreProcessor;
45use crate::vision_models::inputs_processor::Phi4MMProcessor;
46use crate::vision_models::llama4::{
47    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
48};
49use crate::vision_models::llava::config::Config as LLaVAConfig;
50use crate::vision_models::llava15::Model as LLaVA;
51use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
52use crate::vision_models::llava_next::Model as LLaVANext;
53use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
54use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
55use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
56use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
57use crate::vision_models::phi3_inputs_processor::Phi3Processor;
58use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
59use crate::vision_models::preprocessor_config::PreProcessorConfig;
60use crate::vision_models::processor_config::ProcessorConfig;
61use crate::vision_models::qwen2_5_vl::{
62    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
63};
64use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
65use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
66use crate::vision_models::qwen3_vl_moe::{
67    Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
68};
69use crate::vision_models::{minicpmo, phi4};
70
/// A loaded multimodal (vision + text) model that the pipeline can run.
///
/// ISQ support and the AnyMoE base-model hooks come from the [`IsqModel`] and
/// [`AnyMoeBaseModelMixin`] supertraits.
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    /// Run one forward pass over `input_ids`, returning a [`Tensor`].
    ///
    /// - `pixel_values`: image inputs; only specified for prompt sequences.
    /// - `seqlen_offsets`, `context_lens`, `position_ids`: NOTE(review): per-sequence position
    ///   bookkeeping whose exact meaning is model-specific — confirm against implementors.
    /// - `model_specific_args`: type-erased extra inputs (pixel attention mask, image sizes, or
    ///   anything else); presumably downcast by each implementor to its own concrete type.
    /// - `metadata`: optional paged-attention inputs; NOTE(review): presumably `Some` only when
    ///   paged attention is in use.
    /// - `flash_params`: flash-attention parameters.
    // pixel_values and pixel_attention_mask only specified for prompt seqs
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, // pixel attention mask, or image sizes, or anything else
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    /// The [`Device`] the model lives on.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Mutable access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// The maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// Metadata describing the model configuration.
    fn config(&self) -> &ModelConfigMetadata;
    /// For a prompt without images. Requires batch size of 1!
    ///
    /// NOTE(review): presumably the value to pass as `model_specific_args` to `forward` when
    /// no image inputs are present — confirm against implementors.
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}
93
94pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
95    fn load(
96        &self,
97        config: &str,
98        vb: ShardedVarBuilder,
99        normal_loading_metadata: NormalLoadingMetadata,
100        attention_mechanism: AttentionImplementation,
101    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
102    fn is_gptx(&self, config: &str) -> bool;
103    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
104    fn get_processor(
105        &self,
106        model_config: &str,
107        processor_config: Option<ProcessorConfig>,
108        preprocessor_config: PreProcessorConfig,
109        max_edge: Option<u32>,
110    ) -> Arc<dyn Processor + Send + Sync>;
111    fn supports_paged_attention(&self, config: &str) -> bool;
112    fn supports_prefix_cacher(&self, _config: &str) -> bool {
113        // Default is false, specific model must override.
114        false
115    }
116    fn modalities(&self, config: &str) -> Result<Modalities>;
117    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
118    fn get_device_for_tensor(
119        &self,
120        config: &str,
121        _mapper: &dyn DeviceMapper,
122        loading_isq: bool,
123    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
124        if loading_isq {
125            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
126        } else {
127            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
128            let num_layers = self.model_config(config)?.num_layers();
129            let closure = move |name: String| {
130                if let Some(captures) = re.captures(&name) {
131                    captures
132                        .get(1)
133                        .and_then(|m| m.as_str().parse::<usize>().ok())
134                        .map(|l| l.min(num_layers))
135                        .map(DeviceForLoadTensor::Idx)
136                        .unwrap_or(DeviceForLoadTensor::Base)
137                } else {
138                    DeviceForLoadTensor::Base
139                }
140            };
141
142            Ok(Arc::new(closure))
143        }
144    }
145}
146
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
/// The architecture to load the vision model as.
///
/// Each variant (de)serializes with serde as the lowercase name in its
/// `#[serde(rename = ...)]`. The `FromStr` and `Display` impls for this type accept and
/// produce the same names. `VisionLoaderType::from_causal_lm_name` maps Hugging Face
/// architecture class names (listed per variant below) onto these variants.
pub enum VisionLoaderType {
    /// `Phi3VForCausalLM`.
    #[serde(rename = "phi3v")]
    Phi3V,
    /// `Idefics2ForConditionalGeneration`.
    #[serde(rename = "idefics2")]
    Idefics2,
    /// `LlavaNextForConditionalGeneration`.
    #[serde(rename = "llava_next")]
    LLaVANext,
    /// `LlavaForConditionalGeneration`.
    #[serde(rename = "llava")]
    LLaVA,
    /// `MllamaForConditionalGeneration`.
    #[serde(rename = "vllama")]
    VLlama,
    /// `Qwen2VLForConditionalGeneration`.
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    /// `Idefics3ForConditionalGeneration`.
    #[serde(rename = "idefics3")]
    Idefics3,
    /// `MiniCPMO`.
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    /// `Phi4MMForCausalLM`.
    #[serde(rename = "phi4mm")]
    Phi4MM,
    /// `Qwen2_5_VLForConditionalGeneration`.
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    /// `Gemma3ForConditionalGeneration` or `Gemma3ForCausalLM`.
    #[serde(rename = "gemma3")]
    Gemma3,
    /// `Mistral3ForConditionalGeneration`.
    #[serde(rename = "mistral3")]
    Mistral3,
    /// `Llama4ForConditionalGeneration`.
    #[serde(rename = "llama4")]
    Llama4,
    /// `Gemma3nForConditionalGeneration`.
    #[serde(rename = "gemma3n")]
    Gemma3n,
    /// `Qwen3VLForConditionalGeneration`.
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
    /// `Qwen3VLMoeForConditionalGeneration`.
    #[serde(rename = "qwen3vlmoe")]
    Qwen3VLMoE,
}
184
185// https://github.com/huggingface/transformers/blob/cff06aac6fad28019930be03f5d467055bf62177/src/transformers/models/auto/modeling_auto.py#L448
186impl VisionLoaderType {
187    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
188        match name {
189            "Phi3VForCausalLM" => Ok(Self::Phi3V),
190            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
191            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
192            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
193            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
194            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
195            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
196            "MiniCPMO" => Ok(Self::MiniCpmO),
197            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
198            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
199            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
200            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
201            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
202            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
203            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
204            "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
205            other => anyhow::bail!(
206                "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
207            ),
208        }
209    }
210}
211
212impl FromStr for VisionLoaderType {
213    type Err = String;
214    fn from_str(s: &str) -> Result<Self, Self::Err> {
215        match s {
216            "phi3v" => Ok(Self::Phi3V),
217            "idefics2" => Ok(Self::Idefics2),
218            "llava_next" => Ok(Self::LLaVANext),
219            "llava" => Ok(Self::LLaVA),
220            "vllama" => Ok(Self::VLlama),
221            "qwen2vl" => Ok(Self::Qwen2VL),
222            "idefics3" => Ok(Self::Idefics3),
223            "minicpmo" => Ok(Self::MiniCpmO),
224            "phi4mm" => Ok(Self::Phi4MM),
225            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
226            "gemma3" => Ok(Self::Gemma3),
227            "mistral3" => Ok(Self::Mistral3),
228            "llama4" => Ok(Self::Llama4),
229            "gemma3n" => Ok(Self::Gemma3n),
230            "qwen3vl" => Ok(Self::Qwen3VL),
231            "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
232            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`, `qwen3vlmoe`.")),
233        }
234    }
235}
236
237impl std::fmt::Display for VisionLoaderType {
238    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
239        let name = match self {
240            VisionLoaderType::Phi3V => "phi3v",
241            VisionLoaderType::Idefics2 => "idefics2",
242            VisionLoaderType::LLaVANext => "llava_next",
243            VisionLoaderType::LLaVA => "llava",
244            VisionLoaderType::VLlama => "vllama",
245            VisionLoaderType::Qwen2VL => "qwen2vl",
246            VisionLoaderType::Idefics3 => "idefics3",
247            VisionLoaderType::MiniCpmO => "minicpmo",
248            VisionLoaderType::Phi4MM => "phi4mm",
249            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
250            VisionLoaderType::Gemma3 => "gemma3",
251            VisionLoaderType::Mistral3 => "mistral3",
252            VisionLoaderType::Llama4 => "llama4",
253            VisionLoaderType::Gemma3n => "gemma3n",
254            VisionLoaderType::Qwen3VL => "qwen3vl",
255            VisionLoaderType::Qwen3VLMoE => "qwen3vlmoe",
256        };
257        write!(f, "{name}")
258    }
259}
260
261#[derive(Deserialize)]
262struct AutoVisionLoaderConfig {
263    architectures: Vec<String>,
264}
265
266/// Automatically selects a VisionModelLoader implementation based on the JSON `architectures` field.
267pub struct AutoVisionLoader;
268
269impl AutoVisionLoader {
270    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
271        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
272        if auto_cfg.architectures.len() != 1 {
273            anyhow::bail!("Expected exactly one architecture in config");
274        }
275
276        let name = &auto_cfg.architectures[0];
277        let tp = VisionLoaderType::from_causal_lm_name(name)?;
278
279        once_log_info(format!("Automatic loader type determined to be `{tp}`"));
280
281        // Delegate to the concrete loader
282        Ok(match tp {
283            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
284            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
285            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
286            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
287            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
288            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
289            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
290            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
291            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
292            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
293            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
294            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
295            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
296            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
297            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
298            VisionLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
299        })
300    }
301}
302
303impl VisionModelLoader for AutoVisionLoader {
304    fn load(
305        &self,
306        config: &str,
307        vb: ShardedVarBuilder,
308        normal_loading_metadata: NormalLoadingMetadata,
309        attention_mechanism: AttentionImplementation,
310    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
311        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
312    }
313
314    fn is_gptx(&self, config: &str) -> bool {
315        Self::get_loader(config)
316            .expect("AutoVisionLoader get_loader")
317            .is_gptx(config)
318    }
319
320    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
321        Self::get_loader(config)?.get_config_repr(config)
322    }
323
324    fn get_processor(
325        &self,
326        model_config: &str,
327        proc_cfg: Option<ProcessorConfig>,
328        preproc_cfg: PreProcessorConfig,
329        max_edge: Option<u32>,
330    ) -> Arc<dyn Processor + Send + Sync> {
331        Self::get_loader(model_config)
332            .expect("AutoVisionLoader get_loader")
333            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
334    }
335
336    fn supports_paged_attention(&self, config: &str) -> bool {
337        Self::get_loader(config)
338            .expect("AutoVisionLoader")
339            .supports_paged_attention(config)
340    }
341
342    fn modalities(&self, config: &str) -> Result<Modalities> {
343        Self::get_loader(config)?.modalities(config)
344    }
345
346    fn supports_prefix_cacher(&self, config: &str) -> bool {
347        Self::get_loader(config)
348            .expect("AutoVisionLoader")
349            .supports_prefix_cacher(config)
350    }
351
352    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
353        Self::get_loader(config)
354            .expect("AutoVisionLoader")
355            .prefixer(config)
356    }
357
358    fn get_device_for_tensor(
359        &self,
360        config: &str,
361        mapper: &dyn DeviceMapper,
362        loading_isq: bool,
363    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
364        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
365    }
366}
367
368impl IsqModelLoader for AutoVisionLoader {
369    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
370        Self::get_loader(config)?.isq_layer_regexes(config)
371    }
372    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
373        Self::get_loader(config)?.immediate_isq_predicates(config)
374    }
375}
376
377impl DeviceMappedModelLoader for AutoVisionLoader {
378    fn mapped_max_act_size_elems(
379        &self,
380        config: &str,
381        params: &AutoDeviceMapParams,
382    ) -> Result<usize> {
383        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
384    }
385    fn non_mapped_max_act_size_elems(
386        &self,
387        config: &str,
388        params: &AutoDeviceMapParams,
389    ) -> Result<usize> {
390        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
391    }
392    fn non_mapped_size_in_bytes(
393        &self,
394        config: &str,
395        dtype: DType,
396        weight_pack_factor: usize,
397        _matformer_config: Option<&MatformerSliceConfig>,
398    ) -> Result<usize> {
399        Self::get_loader(config)?.non_mapped_size_in_bytes(
400            config,
401            dtype,
402            weight_pack_factor,
403            _matformer_config,
404        )
405    }
406    fn layer_sizes_in_bytes(
407        &self,
408        config: &str,
409        dtype: DType,
410        weight_pack_factor: usize,
411        _matformer_config: Option<&MatformerSliceConfig>,
412    ) -> Result<Vec<usize>> {
413        Self::get_loader(config)?.layer_sizes_in_bytes(
414            config,
415            dtype,
416            weight_pack_factor,
417            _matformer_config,
418        )
419    }
420    fn num_layers(&self, config: &str) -> Result<usize> {
421        Self::get_loader(config)?.num_layers(config)
422    }
423    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
424        Self::get_loader(config)?.model_config(config)
425    }
426}
427
/// Expands to `$size` when the boolean `$cond` holds and to `0` otherwise.
///
/// `$size` is only evaluated when `$cond` is true. Judging by its name, this is intended for
/// counting optional bias elements.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
437
438fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
439    let pre_layer_norm = cfg.hidden_size;
440    let final_layer_norm = cfg.hidden_size;
441
442    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
443    let num_positions = num_patches + 1;
444
445    let class_embedding = cfg.hidden_size;
446
447    let position_ids = num_positions;
448    let position_embedding = num_positions * cfg.hidden_size;
449
450    let conv2dconfig = Conv2dConfig {
451        stride: cfg.patch_size,
452        ..Default::default()
453    };
454    let patch_embedding =
455        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;
456
457    let encoder_layer_elems = {
458        let layer_norm1 = cfg.hidden_size;
459        let layer_norm2 = cfg.hidden_size;
460
461        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
462        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
463        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
464        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
465
466        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
467        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
468
469        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
470    };
471
472    pre_layer_norm
473        + final_layer_norm
474        + class_embedding
475        + position_ids
476        + position_embedding
477        + patch_embedding
478        + cfg.num_hidden_layers * encoder_layer_elems
479}
480
481// ======================== Phi 3 loader
482
483/// [`VisionLoader`] for a Phi 3 Vision model.
484///
485/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
486pub struct Phi3VLoader;
487
488pub struct Phi3VPrefixer;
489
490impl MultimodalPromptPrefixer for Phi3VPrefixer {
491    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
492        // Image indexing starts at 0.
493        format!(
494            "{}{prompt}",
495            image_indexes
496                .into_iter()
497                .map(|image_index| format!("<|image_{}|>", image_index + 1))
498                .join("")
499        )
500    }
501}
502
503impl VisionModelLoader for Phi3VLoader {
504    fn load(
505        &self,
506        config: &str,
507        vb: ShardedVarBuilder,
508        normal_loading_metadata: NormalLoadingMetadata,
509        attention_mechanism: AttentionImplementation,
510    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
511        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
512        Ok(Box::new(Phi3::new(
513            &cfg,
514            vb,
515            self.is_gptx(config),
516            normal_loading_metadata,
517            attention_mechanism,
518        )?))
519    }
520    fn is_gptx(&self, _config: &str) -> bool {
521        true
522    }
523    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
524        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
525        Ok(Box::new(cfg))
526    }
527    fn get_processor(
528        &self,
529        _model_config: &str,
530        processor_config: Option<ProcessorConfig>,
531        preprocessor_config: PreProcessorConfig,
532        _max_edge: Option<u32>,
533    ) -> Arc<dyn Processor + Send + Sync> {
534        Phi3Processor::new_processor(processor_config, preprocessor_config)
535    }
536    fn supports_paged_attention(&self, _config: &str) -> bool {
537        true
538    }
539    fn supports_prefix_cacher(&self, _config: &str) -> bool {
540        true
541    }
542    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
543        Arc::new(Phi3VPrefixer)
544    }
545    fn modalities(&self, _config: &str) -> Result<Modalities> {
546        Ok(Modalities {
547            input: vec![SupportedModality::Text, SupportedModality::Vision],
548            output: vec![SupportedModality::Text],
549        })
550    }
551}
552
553impl IsqModelLoader for Phi3VLoader {
554    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
555        Ok(vec![
556            Regex::new(r"lm_head\.(weight|bias)$")?,
557            // Attention
558            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
559            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
560            // MLP
561            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
562            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
563        ])
564    }
565    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
566        self.isq_layer_regexes(config)
567    }
568}
569
570impl DeviceMappedModelLoader for Phi3VLoader {
571    fn mapped_max_act_size_elems(
572        &self,
573        config: &str,
574        params: &AutoDeviceMapParams,
575    ) -> Result<usize> {
576        // NOTE: we ignore max_num_images although it can only be one...
577        let AutoDeviceMapParams::Vision {
578            max_seq_len,
579            max_batch_size,
580            max_image_shape: _,
581            max_num_images,
582        } = params
583        else {
584            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
585        };
586
587        let cfg: Phi3Config = serde_json::from_str(config)?;
588
589        let vcfg = &PHI3V_CLIP_CONFIG;
590
591        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
592        let img_seq_len = (num_patches + 1) * max_num_images;
593
594        let max_text_attn = {
595            // This model injects the vision information directly into the input embeddings
596            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
597            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
598        };
599
600        Ok(max_text_attn)
601    }
602
603    fn non_mapped_max_act_size_elems(
604        &self,
605        config: &str,
606        params: &AutoDeviceMapParams,
607    ) -> Result<usize> {
608        // NOTE: we ignore max_num_images although it can only be one...
609        let AutoDeviceMapParams::Vision {
610            max_seq_len: _,
611            max_batch_size,
612            max_image_shape: _,
613            max_num_images,
614        } = params
615        else {
616            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
617        };
618
619        let cfg: Phi3Config = serde_json::from_str(config)?;
620
621        let vcfg = &PHI3V_CLIP_CONFIG;
622
623        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
624        let img_seq_len = num_patches + 1;
625
626        let max_vision_attn = {
627            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
628        };
629
630        Ok(max_vision_attn)
631    }
632
633    fn non_mapped_size_in_bytes(
634        &self,
635        config: &str,
636        dtype: DType,
637        weight_pack_factor: usize,
638        _matformer_config: Option<&MatformerSliceConfig>,
639    ) -> Result<usize> {
640        let cfg: Phi3Config = serde_json::from_str(config)?;
641        let elems = {
642            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
643            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
644            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
645                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
646            } else {
647                0
648            };
649            let norm = cfg.hidden_size;
650
651            let image_embed = {
652                let projection_cls = cfg
653                    .embd_layer
654                    .projection_cls
655                    .clone()
656                    .unwrap_or("linear".to_string());
657                let with_learnable_separator =
658                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
659                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
660                let image_dim_out = cfg.img_processor.image_dim_out;
661
662                let proj = match (projection_cls.as_str(), use_hd_transform) {
663                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
664                    ("mlp", true) => {
665                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
666                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
667                        a + b
668                    }
669                    ("mlp", false) => {
670                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
671                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
672                        a + b
673                    }
674                    _ => {
675                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
676                    }
677                };
678
679                let (glb_gn, sub_gn) = if with_learnable_separator {
680                    let glb_gn = image_dim_out * 4;
681                    let sub_gn = image_dim_out * 4;
682                    (glb_gn, sub_gn)
683                } else {
684                    (0, 0)
685                };
686
687                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);
688
689                proj + glb_gn + sub_gn + clip_vit
690            };
691
692            embed_tokens + lm_head + norm + image_embed
693        };
694
695        Ok(elems * dtype.size_in_bytes())
696    }
697
698    fn layer_sizes_in_bytes(
699        &self,
700        config: &str,
701        dtype: DType,
702        weight_pack_factor: usize,
703        _matformer_config: Option<&MatformerSliceConfig>,
704    ) -> Result<Vec<usize>> {
705        let cfg: Phi3Config = serde_json::from_str(config)?;
706        let per_layer_elems = {
707            let input_layernorm = cfg.hidden_size;
708            let post_attention_layernorm = cfg.hidden_size;
709
710            let size_in = cfg.hidden_size;
711            let head_dim = cfg.head_dim();
712            let op_size =
713                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
714            let qkv_proj = size_in * op_size / weight_pack_factor;
715            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
716
717            let h_size = cfg.hidden_size;
718            let i_size = cfg.intermediate_size;
719            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
720            let down_proj = h_size * i_size / weight_pack_factor;
721
722            input_layernorm
723                + post_attention_layernorm
724                + qkv_proj
725                + o_proj
726                + gate_up_proj
727                + down_proj
728        };
729        Ok(vec![
730            per_layer_elems * dtype.size_in_bytes();
731            cfg.num_hidden_layers
732        ])
733    }
734
735    fn num_layers(&self, config: &str) -> Result<usize> {
736        let cfg: Phi3Config = serde_json::from_str(config)?;
737        Ok(cfg.num_hidden_layers)
738    }
739
740    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
741        let cfg: Phi3Config = serde_json::from_str(config)?;
742
743        let cfg = ModelConfigMetadata {
744            max_seq_len: cfg.max_position_embeddings,
745            num_layers: cfg.num_hidden_layers,
746            hidden_size: cfg.hidden_size,
747            num_kv_heads: cfg.num_key_value_heads,
748            num_attn_heads: cfg.num_attention_heads,
749            sliding_window: cfg.sliding_window,
750            k_head_dim: cfg.head_dim(),
751            v_head_dim: cfg.head_dim(),
752            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
753        };
754
755        Ok(Box::new(cfg))
756    }
757
758    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
759        Some(vec![NonMappedSubModel::Vision])
760    }
761}
762
763// ======================== Idefics 2 loader
764
765/// [`VisionLoader`] for an Idefics 2 Vision model.
766///
767/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
768pub struct Idefics2Loader;
769
770pub struct Idefics2Prefixer;
771
772impl MultimodalPromptPrefixer for Idefics2Prefixer {
773    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
774        // Chat template does it
775        prompt.to_string()
776    }
777}
778
779impl VisionModelLoader for Idefics2Loader {
780    fn load(
781        &self,
782        config: &str,
783        vb: ShardedVarBuilder,
784        normal_loading_metadata: NormalLoadingMetadata,
785        attention_mechanism: AttentionImplementation,
786    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
787        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
788        Ok(Box::new(Idefics2::new(
789            &cfg,
790            vb,
791            self.is_gptx(config),
792            normal_loading_metadata,
793            attention_mechanism,
794        )?))
795    }
796    fn is_gptx(&self, _config: &str) -> bool {
797        true
798    }
799    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
800        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
801        Ok(Box::new(cfg))
802    }
803    fn get_processor(
804        &self,
805        _model_config: &str,
806        processor_config: Option<ProcessorConfig>,
807        preprocessor_config: PreProcessorConfig,
808        max_edge: Option<u32>,
809    ) -> Arc<dyn Processor + Send + Sync> {
810        Arc::new(Idefics2Processor::new(
811            processor_config.unwrap(),
812            preprocessor_config,
813            max_edge,
814        ))
815    }
816    fn supports_paged_attention(&self, _config: &str) -> bool {
817        true
818    }
819    fn supports_prefix_cacher(&self, _config: &str) -> bool {
820        true
821    }
822    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
823        Arc::new(Idefics2Prefixer)
824    }
825    fn modalities(&self, _config: &str) -> Result<Modalities> {
826        Ok(Modalities {
827            input: vec![SupportedModality::Text, SupportedModality::Vision],
828            output: vec![SupportedModality::Text],
829        })
830    }
831}
832
833impl IsqModelLoader for Idefics2Loader {
834    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
835        Ok(vec![
836            Regex::new(r"lm_head\.(weight|bias)$")?,
837            // Attention
838            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
839            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
840            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
841            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
842            // MLP
843            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
844            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
845            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
846        ])
847    }
848    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
849        Ok(vec![
850            Regex::new(r"lm_head\.(weight|bias)$")?,
851            // Attention
852            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
853            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
854            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
855            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
856            // MLP
857            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
858            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
859            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
860        ])
861    }
862}
863
864impl DeviceMappedModelLoader for Idefics2Loader {
865    fn mapped_max_act_size_elems(
866        &self,
867        config: &str,
868        params: &AutoDeviceMapParams,
869    ) -> Result<usize> {
870        let AutoDeviceMapParams::Vision {
871            max_seq_len,
872            max_batch_size,
873            max_image_shape: _,
874            max_num_images,
875        } = params
876        else {
877            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
878        };
879
880        let cfg: Idefics2Config = serde_json::from_str(config)?;
881
882        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
883        let img_seq_len = (num_patches + 1) * max_num_images;
884
885        let max_text_attn = {
886            // This model injects the vision information directly into the input embeddings
887            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
888            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
889        };
890
891        Ok(max_text_attn)
892    }
893
894    fn non_mapped_max_act_size_elems(
895        &self,
896        config: &str,
897        params: &AutoDeviceMapParams,
898    ) -> Result<usize> {
899        let AutoDeviceMapParams::Vision {
900            max_seq_len: _,
901            max_batch_size,
902            max_image_shape: _,
903            max_num_images,
904        } = params
905        else {
906            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
907        };
908
909        let cfg: Idefics2Config = serde_json::from_str(config)?;
910
911        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
912        let img_seq_len = num_patches + 1;
913
914        let max_vision_attn = {
915            // do_image_splitting = true
916            let images_factor = 5;
917
918            (max_batch_size * images_factor * max_num_images)
919                * cfg.vision_config.num_attention_heads
920                * img_seq_len
921                * img_seq_len
922        };
923
924        Ok(max_vision_attn)
925    }
926
927    fn non_mapped_size_in_bytes(
928        &self,
929        config: &str,
930        dtype: DType,
931        weight_pack_factor: usize,
932        _matformer_config: Option<&MatformerSliceConfig>,
933    ) -> Result<usize> {
934        let cfg: Idefics2Config = serde_json::from_str(config)?;
935        let text_elems = {
936            let tie_word_embeddings = cfg.tie_word_embeddings;
937            let cfg = &cfg.text_config;
938
939            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
940            let lm_head = if !tie_word_embeddings {
941                cfg.hidden_size * cfg.vocab_size
942            } else {
943                0
944            };
945            let norm = cfg.hidden_size;
946            embed_tokens + lm_head + norm
947        };
948
949        let connector_elems = {
950            let tcfg = &cfg.text_config;
951            let vcfg = &cfg.vision_config;
952            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
953            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
954            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
955
956            let perceiver_elems = {
957                let tcfg = &cfg.text_config;
958                let pcfg = &cfg.perceiver_config;
959
960                let n_latents = pcfg.resampler_n_latents;
961                let hidden_size = tcfg.hidden_size;
962                let depth = pcfg.resampler_depth;
963
964                let norm = tcfg.hidden_size;
965                let latents = n_latents * hidden_size;
966
967                let layer_elems = {
968                    let input_latents_norm = hidden_size;
969                    let input_context_norm = hidden_size;
970                    let post_attn_norm = hidden_size;
971
972                    let num_heads = pcfg.resampler_n_heads;
973                    let head_dim = pcfg.resampler_head_dim;
974                    let num_key_value_heads = pcfg.num_key_value_heads;
975
976                    let q_proj = hidden_size * num_heads * head_dim;
977                    let k_proj = hidden_size * num_key_value_heads * head_dim;
978                    let v_proj = hidden_size * num_key_value_heads * head_dim;
979                    let o_proj = num_heads * head_dim * hidden_size;
980
981                    let gate_proj = hidden_size * hidden_size * 4;
982                    let up_proj = hidden_size * hidden_size * 4;
983                    let down_proj = hidden_size * 4 * hidden_size;
984
985                    input_latents_norm
986                        + input_context_norm
987                        + post_attn_norm
988                        + q_proj
989                        + k_proj
990                        + v_proj
991                        + o_proj
992                        + gate_proj
993                        + up_proj
994                        + down_proj
995                };
996
997                norm + latents + layer_elems * depth
998            };
999
1000            gate_proj + up_proj + down_proj + perceiver_elems
1001        };
1002
1003        let vision_transformer = {
1004            let cfg = &cfg.vision_config;
1005
1006            let post_layernorm = cfg.hidden_size;
1007
1008            let conv_config = Conv2dConfig {
1009                stride: cfg.patch_size,
1010                ..Default::default()
1011            };
1012            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
1013                * cfg.patch_size
1014                * cfg.patch_size;
1015
1016            let num_patches_per_side = cfg.image_size / cfg.patch_size;
1017            let num_patches = num_patches_per_side.pow(2);
1018            let position_embedding = num_patches * cfg.hidden_size;
1019
1020            let layer_elems = {
1021                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1022                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1023
1024                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1025                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1026
1027                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1028                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1029                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1030                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1031
1032                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1033            };
1034
1035            post_layernorm + patch_embedding + position_embedding + layer_elems
1036        };
1037
1038        let elems = text_elems + connector_elems + vision_transformer;
1039
1040        Ok(elems * dtype.size_in_bytes())
1041    }
1042
1043    fn layer_sizes_in_bytes(
1044        &self,
1045        config: &str,
1046        dtype: DType,
1047        weight_pack_factor: usize,
1048        _matformer_config: Option<&MatformerSliceConfig>,
1049    ) -> Result<Vec<usize>> {
1050        let cfg: Idefics2Config = serde_json::from_str(config)?;
1051        let cfg = cfg.text_config;
1052        let per_layer_elems = {
1053            let input_layernorm = cfg.hidden_size;
1054            let post_attention_layernorm = cfg.hidden_size;
1055
1056            let size_in = cfg.hidden_size;
1057            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1058            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1059            let q_proj = size_in * size_q / weight_pack_factor;
1060            let k_proj = size_in * size_kv / weight_pack_factor;
1061            let v_proj = size_in * size_kv / weight_pack_factor;
1062            let o_proj = size_q * size_in / weight_pack_factor;
1063
1064            let h_size = cfg.hidden_size;
1065            let i_size = cfg.intermediate_size;
1066            let gate_proj = h_size * i_size / weight_pack_factor;
1067            let up_proj = h_size * i_size / weight_pack_factor;
1068            let down_proj = i_size * h_size / weight_pack_factor;
1069
1070            input_layernorm
1071                + post_attention_layernorm
1072                + q_proj
1073                + k_proj
1074                + v_proj
1075                + o_proj
1076                + gate_proj
1077                + up_proj
1078                + down_proj
1079        };
1080        Ok(vec![
1081            per_layer_elems * dtype.size_in_bytes();
1082            cfg.num_hidden_layers
1083        ])
1084    }
1085
1086    fn num_layers(&self, config: &str) -> Result<usize> {
1087        let cfg: Idefics2Config = serde_json::from_str(config)?;
1088        Ok(cfg.text_config.num_hidden_layers)
1089    }
1090    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1091        let cfg: Idefics2Config = serde_json::from_str(config)?;
1092        let cfg = &cfg.text_config;
1093
1094        let cfg = ModelConfigMetadata {
1095            max_seq_len: cfg.max_position_embeddings,
1096            num_layers: cfg.num_hidden_layers,
1097            hidden_size: cfg.hidden_size,
1098            num_kv_heads: cfg.num_key_value_heads,
1099            num_attn_heads: cfg.num_attention_heads,
1100            sliding_window: cfg.sliding_window,
1101            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1102            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1103            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1104        };
1105
1106        Ok(Box::new(cfg))
1107    }
1108
1109    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1110        Some(vec![NonMappedSubModel::Vision])
1111    }
1112}
1113
1114// ======================== LLaVANext Loader
1115
/// [`VisionLoader`] for an LLaVANext Vision model.
///
/// A stateless unit struct: model loading, ISQ layer selection and device-map size
/// estimation are all provided by the trait impls for this type.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct LLaVANextLoader;
1120
1121pub struct LLaVANextPrefixer;
1122
1123impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1124    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1125        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1126    }
1127}
1128
1129impl VisionModelLoader for LLaVANextLoader {
1130    fn load(
1131        &self,
1132        config: &str,
1133        vb: ShardedVarBuilder,
1134        normal_loading_metadata: NormalLoadingMetadata,
1135        attention_mechanism: AttentionImplementation,
1136    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1137        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1138        Ok(Box::new(LLaVANext::new(
1139            &cfg,
1140            vb,
1141            self.is_gptx(config),
1142            normal_loading_metadata,
1143            attention_mechanism,
1144        )?))
1145    }
1146    fn is_gptx(&self, _config: &str) -> bool {
1147        false
1148    }
1149    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1150        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1151        Ok(Box::new(cfg))
1152    }
1153    fn get_processor(
1154        &self,
1155        model_config: &str,
1156        _processor_config: Option<ProcessorConfig>,
1157        _preprocessor_config: PreProcessorConfig,
1158        _max_edge: Option<u32>,
1159    ) -> Arc<dyn Processor + Send + Sync> {
1160        Arc::new(LLaVANextProcessor::new(model_config))
1161    }
1162    fn supports_paged_attention(&self, _config: &str) -> bool {
1163        true
1164    }
1165    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1166        true
1167    }
1168    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1169        Arc::new(LLaVANextPrefixer)
1170    }
1171    fn modalities(&self, _config: &str) -> Result<Modalities> {
1172        Ok(Modalities {
1173            input: vec![SupportedModality::Text, SupportedModality::Vision],
1174            output: vec![SupportedModality::Text],
1175        })
1176    }
1177}
1178
1179impl IsqModelLoader for LLaVANextLoader {
1180    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1181        Ok(vec![
1182            Regex::new(r"lm_head\.(weight|bias)$")?,
1183            // Attention
1184            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1185            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1186            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1187            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1188            // MLP
1189            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1190            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1191            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1192        ])
1193    }
1194    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1195        Ok(vec![
1196            Regex::new(r"lm_head\.(weight|bias)$")?,
1197            // Attention
1198            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1199            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1200            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1201            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1202            // MLP
1203            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1204            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1205            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1206        ])
1207    }
1208}
1209
1210impl DeviceMappedModelLoader for LLaVANextLoader {
1211    fn mapped_max_act_size_elems(
1212        &self,
1213        config: &str,
1214        params: &AutoDeviceMapParams,
1215    ) -> Result<usize> {
1216        let AutoDeviceMapParams::Vision {
1217            max_seq_len,
1218            max_batch_size,
1219            max_image_shape,
1220            max_num_images,
1221        } = params
1222        else {
1223            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1224        };
1225
1226        let config: LLaVAConfig = serde_json::from_str(config)?;
1227
1228        #[allow(clippy::cast_possible_truncation)]
1229        let img_seq_len =
1230            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1231                &config,
1232                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1233            );
1234        let img_seq_len = img_seq_len * max_num_images;
1235
1236        let max_text_attn = {
1237            let cfg = &config.text_config;
1238            // This model injects the vision information directly into the input embeddings
1239            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1240
1241            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1242        };
1243
1244        Ok(max_text_attn)
1245    }
1246
1247    fn non_mapped_max_act_size_elems(
1248        &self,
1249        config: &str,
1250        params: &AutoDeviceMapParams,
1251    ) -> Result<usize> {
1252        let AutoDeviceMapParams::Vision {
1253            max_seq_len: _,
1254            max_batch_size,
1255            max_image_shape,
1256            max_num_images,
1257        } = params
1258        else {
1259            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1260        };
1261
1262        let config: LLaVAConfig = serde_json::from_str(config)?;
1263
1264        #[allow(clippy::cast_possible_truncation)]
1265        let img_seq_len =
1266            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1267                &config,
1268                (max_image_shape.0 as u32, max_image_shape.1 as u32),
1269            );
1270
1271        let max_vision_attn = {
1272            (max_batch_size * max_num_images)
1273                * config.vision_config.num_attention_heads
1274                * img_seq_len
1275                * img_seq_len
1276        };
1277
1278        Ok(max_vision_attn)
1279    }
1280
1281    fn non_mapped_size_in_bytes(
1282        &self,
1283        config: &str,
1284        dtype: DType,
1285        weight_pack_factor: usize,
1286        _matformer_config: Option<&MatformerSliceConfig>,
1287    ) -> Result<usize> {
1288        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1289        let text_elems = {
1290            let cfg = &cfg.text_config;
1291            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1292            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1293            let norm = cfg.hidden_size;
1294            embed_tokens + lm_head + norm
1295        };
1296
1297        let image_newline = cfg.text_config.hidden_size;
1298        let mmproj = {
1299            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1300                + cfg.text_config.hidden_size;
1301            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1302                + cfg.text_config.hidden_size;
1303
1304            linear_1 + linear_2
1305        };
1306        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1307
1308        let elems = text_elems + image_newline + mmproj + vision_tower;
1309        Ok(elems * dtype.size_in_bytes())
1310    }
1311
1312    fn layer_sizes_in_bytes(
1313        &self,
1314        config: &str,
1315        dtype: DType,
1316        weight_pack_factor: usize,
1317        _matformer_config: Option<&MatformerSliceConfig>,
1318    ) -> Result<Vec<usize>> {
1319        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1320        let per_layer_elems = {
1321            let cfg = &cfg.text_config;
1322            let input_layernorm = cfg.hidden_size;
1323            let post_attention_layernorm = cfg.hidden_size;
1324
1325            let size_in = cfg.hidden_size;
1326            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1327            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1328            let q_proj = size_in * size_q / weight_pack_factor;
1329            let k_proj = size_in * size_kv / weight_pack_factor;
1330            let v_proj = size_in * size_kv / weight_pack_factor;
1331            let o_proj = size_q * size_in / weight_pack_factor;
1332
1333            let h_size = cfg.hidden_size;
1334            let i_size = cfg.intermediate_size;
1335            let gate_proj = h_size * i_size / weight_pack_factor;
1336            let up_proj = h_size * i_size / weight_pack_factor;
1337            let down_proj = i_size * h_size / weight_pack_factor;
1338
1339            input_layernorm
1340                + post_attention_layernorm
1341                + q_proj
1342                + k_proj
1343                + v_proj
1344                + o_proj
1345                + gate_proj
1346                + up_proj
1347                + down_proj
1348        };
1349        Ok(vec![
1350            per_layer_elems * dtype.size_in_bytes();
1351            cfg.text_config.num_hidden_layers
1352        ])
1353    }
1354
1355    fn num_layers(&self, config: &str) -> Result<usize> {
1356        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1357        Ok(cfg.text_config.num_hidden_layers)
1358    }
1359
1360    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1361        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1362        let cfg = &cfg.text_config;
1363
1364        let cfg = ModelConfigMetadata {
1365            max_seq_len: cfg.max_position_embeddings,
1366            num_layers: cfg.num_hidden_layers,
1367            hidden_size: cfg.hidden_size,
1368            num_kv_heads: cfg.num_key_value_heads,
1369            num_attn_heads: cfg.num_attention_heads,
1370            sliding_window: cfg.sliding_window,
1371            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1372            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1373            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1374        };
1375
1376        Ok(Box::new(cfg))
1377    }
1378
1379    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1380        Some(vec![NonMappedSubModel::Vision])
1381    }
1382}
1383
1384// ======================== LLaVA Loader
1385
/// [`VisionLoader`] for an LLaVA Vision model.
///
/// A stateless unit struct: model loading, ISQ layer selection and device-map size
/// estimation are all provided by the trait impls for this type.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct LLaVALoader;
1390
1391pub struct LLaVAPrefixer;
1392
1393impl MultimodalPromptPrefixer for LLaVAPrefixer {
1394    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1395        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1396    }
1397}
1398
1399impl VisionModelLoader for LLaVALoader {
1400    fn load(
1401        &self,
1402        config: &str,
1403        vb: ShardedVarBuilder,
1404        normal_loading_metadata: NormalLoadingMetadata,
1405        attention_mechanism: AttentionImplementation,
1406    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1407        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1408        Ok(Box::new(LLaVA::new(
1409            &cfg,
1410            vb,
1411            self.is_gptx(config),
1412            normal_loading_metadata,
1413            attention_mechanism,
1414        )?))
1415    }
1416    fn is_gptx(&self, _config: &str) -> bool {
1417        false
1418    }
1419    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1420        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1421        Ok(Box::new(cfg))
1422    }
1423    fn get_processor(
1424        &self,
1425        model_config: &str,
1426        _processor_config: Option<ProcessorConfig>,
1427        _preprocessor_config: PreProcessorConfig,
1428        _max_edge: Option<u32>,
1429    ) -> Arc<dyn Processor + Send + Sync> {
1430        Arc::new(LLaVAProcessor::new(model_config))
1431    }
1432    fn supports_paged_attention(&self, _config: &str) -> bool {
1433        true
1434    }
1435    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1436        true
1437    }
1438    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1439        Arc::new(LLaVAPrefixer)
1440    }
1441    fn modalities(&self, _config: &str) -> Result<Modalities> {
1442        Ok(Modalities {
1443            input: vec![SupportedModality::Text, SupportedModality::Vision],
1444            output: vec![SupportedModality::Text],
1445        })
1446    }
1447}
1448
1449impl IsqModelLoader for LLaVALoader {
1450    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1451        Ok(vec![
1452            Regex::new(r"lm_head\.(weight|bias)$")?,
1453            // Attention
1454            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1455            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1456            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1457            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1458            // MLP
1459            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1460            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1461            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1462        ])
1463    }
1464    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1465        Ok(vec![
1466            Regex::new(r"lm_head\.(weight|bias)$")?,
1467            // Attention
1468            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1469            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1470            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1471            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1472            // MLP
1473            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1474            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1475            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1476        ])
1477    }
1478}
1479
1480impl DeviceMappedModelLoader for LLaVALoader {
1481    fn mapped_max_act_size_elems(
1482        &self,
1483        config: &str,
1484        params: &AutoDeviceMapParams,
1485    ) -> Result<usize> {
1486        let AutoDeviceMapParams::Vision {
1487            max_seq_len,
1488            max_batch_size,
1489            max_image_shape: _,
1490            max_num_images,
1491        } = params
1492        else {
1493            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1494        };
1495
1496        let config: LLaVAConfig = serde_json::from_str(config)?;
1497
1498        let img_seq_len =
1499            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1500        let img_seq_len = img_seq_len * max_num_images;
1501
1502        let max_text_attn = {
1503            let cfg = &config.text_config;
1504            // This model injects the vision information directly into the input embeddings
1505            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1506
1507            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1508        };
1509
1510        Ok(max_text_attn)
1511    }
1512
1513    fn non_mapped_max_act_size_elems(
1514        &self,
1515        config: &str,
1516        params: &AutoDeviceMapParams,
1517    ) -> Result<usize> {
1518        let AutoDeviceMapParams::Vision {
1519            max_seq_len: _,
1520            max_batch_size,
1521            max_image_shape: _,
1522            max_num_images,
1523        } = params
1524        else {
1525            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1526        };
1527
1528        let config: LLaVAConfig = serde_json::from_str(config)?;
1529
1530        let img_seq_len =
1531            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1532
1533        let max_vision_attn = {
1534            (max_batch_size * max_num_images)
1535                * config.vision_config.num_attention_heads
1536                * img_seq_len
1537                * img_seq_len
1538        };
1539
1540        Ok(max_vision_attn)
1541    }
1542
1543    fn non_mapped_size_in_bytes(
1544        &self,
1545        config: &str,
1546        dtype: DType,
1547        weight_pack_factor: usize,
1548        _matformer_config: Option<&MatformerSliceConfig>,
1549    ) -> Result<usize> {
1550        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1551        let text_elems = {
1552            let cfg = &cfg.text_config;
1553            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1554            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1555            let norm = cfg.hidden_size;
1556            embed_tokens + lm_head + norm
1557        };
1558
1559        let image_newline = cfg.text_config.hidden_size;
1560        let mmproj = {
1561            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1562                + cfg.text_config.hidden_size;
1563            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1564                + cfg.text_config.hidden_size;
1565
1566            linear_1 + linear_2
1567        };
1568        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1569
1570        let elems = text_elems + image_newline + mmproj + vision_tower;
1571        Ok(elems * dtype.size_in_bytes())
1572    }
1573
1574    fn layer_sizes_in_bytes(
1575        &self,
1576        config: &str,
1577        dtype: DType,
1578        weight_pack_factor: usize,
1579        _matformer_config: Option<&MatformerSliceConfig>,
1580    ) -> Result<Vec<usize>> {
1581        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1582        let per_layer_elems = {
1583            let cfg = &cfg.text_config;
1584            let input_layernorm = cfg.hidden_size;
1585            let post_attention_layernorm = cfg.hidden_size;
1586
1587            let size_in = cfg.hidden_size;
1588            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1589            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1590            let q_proj = size_in * size_q / weight_pack_factor;
1591            let k_proj = size_in * size_kv / weight_pack_factor;
1592            let v_proj = size_in * size_kv / weight_pack_factor;
1593            let o_proj = size_q * size_in / weight_pack_factor;
1594
1595            let h_size = cfg.hidden_size;
1596            let i_size = cfg.intermediate_size;
1597            let gate_proj = h_size * i_size / weight_pack_factor;
1598            let up_proj = h_size * i_size / weight_pack_factor;
1599            let down_proj = i_size * h_size / weight_pack_factor;
1600
1601            input_layernorm
1602                + post_attention_layernorm
1603                + q_proj
1604                + k_proj
1605                + v_proj
1606                + o_proj
1607                + gate_proj
1608                + up_proj
1609                + down_proj
1610        };
1611        Ok(vec![
1612            per_layer_elems * dtype.size_in_bytes();
1613            cfg.text_config.num_hidden_layers
1614        ])
1615    }
1616
1617    fn num_layers(&self, config: &str) -> Result<usize> {
1618        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1619        Ok(cfg.text_config.num_hidden_layers)
1620    }
1621
1622    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1623        let cfg: LLaVAConfig = serde_json::from_str(config)?;
1624        let cfg = &cfg.text_config;
1625
1626        let cfg = ModelConfigMetadata {
1627            max_seq_len: cfg.max_position_embeddings,
1628            num_layers: cfg.num_hidden_layers,
1629            hidden_size: cfg.hidden_size,
1630            num_kv_heads: cfg.num_key_value_heads,
1631            num_attn_heads: cfg.num_attention_heads,
1632            sliding_window: cfg.sliding_window,
1633            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1634            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1635            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1636        };
1637
1638        Ok(Box::new(cfg))
1639    }
1640
1641    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1642        Some(vec![NonMappedSubModel::Vision])
1643    }
1644}
1645
1646// ======================== MLlama Loader
1647
1648/// [`VisionLoader`] for an Llama Vision model.
1649///
1650/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
1651pub struct VLlamaLoader;
1652
1653pub struct VLlamaPrefixer;
1654
1655impl MultimodalPromptPrefixer for VLlamaPrefixer {
1656    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1657        format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1658    }
1659}
1660
1661impl VisionModelLoader for VLlamaLoader {
1662    fn load(
1663        &self,
1664        config: &str,
1665        vb: ShardedVarBuilder,
1666        normal_loading_metadata: NormalLoadingMetadata,
1667        attention_mechanism: AttentionImplementation,
1668    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1669        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1670        Ok(Box::new(MLlamaModel::new(
1671            &cfg,
1672            vb,
1673            self.is_gptx(config),
1674            normal_loading_metadata,
1675            attention_mechanism,
1676        )?))
1677    }
1678    fn is_gptx(&self, _config: &str) -> bool {
1679        true
1680    }
1681    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1682        let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1683        Ok(Box::new(cfg))
1684    }
1685    fn get_processor(
1686        &self,
1687        _model_config: &str,
1688        _processor_config: Option<ProcessorConfig>,
1689        _preprocessor_config: PreProcessorConfig,
1690        _max_edge: Option<u32>,
1691    ) -> Arc<dyn Processor + Send + Sync> {
1692        Arc::new(MLlamaProcessor::new())
1693    }
1694    fn supports_paged_attention(&self, _config: &str) -> bool {
1695        false
1696    }
1697    fn supports_prefix_cacher(&self, _config: &str) -> bool {
1698        true
1699    }
1700    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1701        Arc::new(VLlamaPrefixer)
1702    }
1703    fn modalities(&self, _config: &str) -> Result<Modalities> {
1704        Ok(Modalities {
1705            input: vec![SupportedModality::Text, SupportedModality::Vision],
1706            output: vec![SupportedModality::Text],
1707        })
1708    }
1709}
1710
1711impl IsqModelLoader for VLlamaLoader {
1712    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1713        let config: MLlamaConfig = serde_json::from_str(config)?;
1714        let cross_attn_layers = &config.text_config.cross_attention_layers;
1715        let transformer_layers =
1716            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1717        let mut text_regexes = Vec::new();
1718        for layer in transformer_layers {
1719            text_regexes.extend(vec![
1720                // Attention text
1721                Regex::new(&format!(
1722                    r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1723                ))?,
1724                Regex::new(&format!(
1725                    r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1726                ))?,
1727                Regex::new(&format!(
1728                    r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1729                ))?,
1730                Regex::new(&format!(
1731                    r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1732                ))?,
1733                // MLP text
1734                Regex::new(&format!(
1735                    r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1736                ))?,
1737                Regex::new(&format!(
1738                    r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1739                ))?,
1740                Regex::new(&format!(
1741                    r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1742                ))?,
1743            ]);
1744        }
1745        let vision_regexes = vec![
1746            // Vision attention (transformer)
1747            Regex::new(
1748                r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1749            )?,
1750            Regex::new(
1751                r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1752            )?,
1753            Regex::new(
1754                r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1755            )?,
1756            Regex::new(
1757                r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1758            )?,
1759            // Vision attention (global transforemr)
1760            Regex::new(
1761                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1762            )?,
1763            Regex::new(
1764                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1765            )?,
1766            Regex::new(
1767                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1768            )?,
1769            Regex::new(
1770                r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1771            )?,
1772            // MLP vision
1773            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1774            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1775        ];
1776
1777        Ok([text_regexes, vision_regexes].concat())
1778    }
1779    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1780        self.isq_layer_regexes(config)
1781    }
1782}
1783
1784impl DeviceMappedModelLoader for VLlamaLoader {
1785    fn mapped_max_act_size_elems(
1786        &self,
1787        config: &str,
1788        params: &AutoDeviceMapParams,
1789    ) -> Result<usize> {
1790        let AutoDeviceMapParams::Vision {
1791            max_seq_len,
1792            max_batch_size,
1793            max_image_shape: _,
1794            max_num_images,
1795        } = params
1796        else {
1797            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1798        };
1799
1800        let config: MLlamaConfig = serde_json::from_str(config)?;
1801
1802        let img_seq_len = {
1803            let cfg = &config.vision_config;
1804            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1805            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1806            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1807        };
1808        let img_seq_len = img_seq_len * max_num_images;
1809
1810        let max_cross_text_attn = {
1811            let cfg = &config.text_config;
1812            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1813        };
1814
1815        let max_self_text_attn = {
1816            let cfg = &config.text_config;
1817            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1818        };
1819
1820        Ok(max_self_text_attn.max(max_cross_text_attn))
1821    }
1822
1823    fn non_mapped_max_act_size_elems(
1824        &self,
1825        config: &str,
1826        params: &AutoDeviceMapParams,
1827    ) -> Result<usize> {
1828        let AutoDeviceMapParams::Vision {
1829            max_seq_len: _,
1830            max_batch_size,
1831            max_image_shape: _,
1832            max_num_images,
1833        } = params
1834        else {
1835            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
1836        };
1837
1838        let config: MLlamaConfig = serde_json::from_str(config)?;
1839
1840        let img_seq_len = {
1841            let cfg = &config.vision_config;
1842            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1843            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1844            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1845        };
1846        let max_vision_attn = {
1847            let cfg = &config.vision_config;
1848            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1849        };
1850
1851        Ok(max_vision_attn)
1852    }
1853
1854    fn non_mapped_size_in_bytes(
1855        &self,
1856        config: &str,
1857        dtype: DType,
1858        weight_pack_factor: usize,
1859        _matformer_config: Option<&MatformerSliceConfig>,
1860    ) -> Result<usize> {
1861        let config: MLlamaConfig = serde_json::from_str(config)?;
1862        let text_elems = {
1863            let cfg = &config.text_config;
1864            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1865            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
1866            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1867                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1868            } else {
1869                0
1870            };
1871            let norm = cfg.hidden_size;
1872            embed_tokens + lm_head + norm
1873        };
1874
1875        let vision_elems = {
1876            let cfg = &config.vision_config;
1877
1878            let conv_cfg = Conv2dConfig {
1879                stride: cfg.patch_size,
1880                ..Default::default()
1881            };
1882            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1883                * cfg.patch_size
1884                * cfg.patch_size;
1885
1886            let class_embedding = cfg.hidden_size;
1887
1888            let gated_positional_embedding = {
1889                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1890                let embedding = num_patches * cfg.hidden_size;
1891                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1892                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1893
1894                embedding + tile_embedding
1895            };
1896
1897            let pre_tile_positional_embedding =
1898                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1899            let post_tile_positional_embedding =
1900                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1901
1902            let layernorm_pre = cfg.hidden_size;
1903            let layernorm_post = cfg.hidden_size;
1904
1905            let encoder_layer = {
1906                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1907                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1908
1909                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1910                let q_proj =
1911                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1912                let k_proj =
1913                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1914                let v_proj =
1915                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1916                let o_proj =
1917                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1918
1919                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
1920                    + cfg.intermediate_size;
1921                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
1922                    + cfg.hidden_size;
1923
1924                input_layernorm
1925                    + post_attention_layernorm
1926                    + q_proj
1927                    + k_proj
1928                    + v_proj
1929                    + o_proj
1930                    + fc1
1931                    + fc2
1932            };
1933
1934            patch_embedding
1935                + class_embedding
1936                + gated_positional_embedding
1937                + pre_tile_positional_embedding
1938                + post_tile_positional_embedding
1939                + layernorm_pre
1940                + layernorm_post
1941                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
1942        };
1943
1944        let elems = text_elems + vision_elems;
1945        Ok(elems * dtype.size_in_bytes())
1946    }
1947
1948    fn layer_sizes_in_bytes(
1949        &self,
1950        config: &str,
1951        dtype: DType,
1952        weight_pack_factor: usize,
1953        _matformer_config: Option<&MatformerSliceConfig>,
1954    ) -> Result<Vec<usize>> {
1955        let config: MLlamaConfig = serde_json::from_str(config)?;
1956        let cfg = &config.text_config;
1957
1958        let mut layer_sizes = Vec::new();
1959
1960        for i in 0..cfg.num_hidden_layers {
1961            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
1962                // No isq for cross attention
1963                1
1964            } else {
1965                weight_pack_factor
1966            };
1967
1968            let per_layer_elems = {
1969                let input_layernorm = cfg.hidden_size;
1970                let post_attention_layernorm = cfg.hidden_size;
1971
1972                let size_in = cfg.hidden_size;
1973                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1974                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1975                let q_proj = size_in * size_q / weight_pack_factor;
1976                let k_proj = size_in * size_kv / weight_pack_factor;
1977                let v_proj = size_in * size_kv / weight_pack_factor;
1978                let o_proj = size_q * size_in / weight_pack_factor;
1979
1980                let h_size = cfg.hidden_size;
1981                let i_size = cfg.intermediate_size;
1982                let gate_proj = h_size * i_size / weight_pack_factor;
1983                let up_proj = h_size * i_size / weight_pack_factor;
1984                let down_proj = i_size * h_size / weight_pack_factor;
1985
1986                input_layernorm
1987                    + post_attention_layernorm
1988                    + q_proj
1989                    + k_proj
1990                    + v_proj
1991                    + o_proj
1992                    + gate_proj
1993                    + up_proj
1994                    + down_proj
1995            };
1996
1997            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
1998        }
1999
2000        Ok(layer_sizes)
2001    }
2002
2003    fn num_layers(&self, config: &str) -> Result<usize> {
2004        let config: MLlamaConfig = serde_json::from_str(config)?;
2005        Ok(config.text_config.num_hidden_layers)
2006    }
2007
2008    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2009        let cfg: MLlamaConfig = serde_json::from_str(config)?;
2010        let cfg = &cfg.text_config;
2011
2012        let cfg = ModelConfigMetadata {
2013            max_seq_len: cfg.max_position_embeddings,
2014            num_layers: cfg.num_hidden_layers,
2015            hidden_size: cfg.hidden_size,
2016            num_kv_heads: cfg.num_key_value_heads,
2017            num_attn_heads: cfg.num_attention_heads,
2018            sliding_window: None,
2019            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2020            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2021            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2022        };
2023
2024        Ok(Box::new(cfg))
2025    }
2026
2027    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2028        Some(vec![NonMappedSubModel::Vision])
2029    }
2030}
2031
2032// ======================== Qwen2VL Loader
2033
2034/// [`VisionLoader`] for an Qwen2-VL model.
2035///
2036/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
2037pub struct Qwen2VLLoader;
2038
2039pub struct Qwen2VLPrefixer;
2040
2041impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2042    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2043        format!(
2044            "{}{prompt}",
2045            format!(
2046                "{}{}{}",
2047                Qwen2VLProcessor::VISION_START,
2048                Qwen2VLProcessor::IMAGE_PAD,
2049                Qwen2VLProcessor::VISION_END
2050            )
2051            .repeat(image_indexes.len())
2052        )
2053    }
2054}
2055
2056impl VisionModelLoader for Qwen2VLLoader {
2057    fn load(
2058        &self,
2059        config: &str,
2060        vb: ShardedVarBuilder,
2061        normal_loading_metadata: NormalLoadingMetadata,
2062        attention_mechanism: AttentionImplementation,
2063    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2064        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2065        Ok(Box::new(Qwen2VLModel::new(
2066            &cfg,
2067            vb,
2068            self.is_gptx(config),
2069            normal_loading_metadata,
2070            attention_mechanism,
2071        )?))
2072    }
2073    fn is_gptx(&self, _config: &str) -> bool {
2074        true
2075    }
2076    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2077        let config: Qwen2VLConfig = serde_json::from_str(config)?;
2078        Ok(Box::new(config))
2079    }
2080    fn get_processor(
2081        &self,
2082        _model_config: &str,
2083        _processor_config: Option<ProcessorConfig>,
2084        _preprocessor_config: PreProcessorConfig,
2085        max_edge: Option<u32>,
2086    ) -> Arc<dyn Processor + Send + Sync> {
2087        Arc::new(Qwen2VLProcessor::new(max_edge))
2088    }
2089    fn supports_paged_attention(&self, _config: &str) -> bool {
2090        false
2091    }
2092    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2093        Arc::new(Qwen2VLPrefixer)
2094    }
2095    fn modalities(&self, _config: &str) -> Result<Modalities> {
2096        Ok(Modalities {
2097            input: vec![SupportedModality::Text, SupportedModality::Vision],
2098            output: vec![SupportedModality::Text],
2099        })
2100    }
2101}
2102
2103impl IsqModelLoader for Qwen2VLLoader {
2104    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2105        Ok(vec![
2106            Regex::new(r"lm_head\.(weight|bias)$")?,
2107            // Attention
2108            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2109            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2110            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2111            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2112            // MLP
2113            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2114            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2115            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2116        ])
2117    }
2118    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2119        self.isq_layer_regexes(config)
2120    }
2121}
2122
2123impl DeviceMappedModelLoader for Qwen2VLLoader {
2124    fn mapped_max_act_size_elems(
2125        &self,
2126        config: &str,
2127        params: &AutoDeviceMapParams,
2128    ) -> Result<usize> {
2129        let AutoDeviceMapParams::Vision {
2130            max_seq_len,
2131            max_batch_size,
2132            max_image_shape,
2133            max_num_images,
2134        } = params
2135        else {
2136            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2137        };
2138
2139        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2140
2141        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
2142        let img_seq_len = {
2143            let cfg = &cfg.vision_config;
2144            // grid_t is 1 for images (temporal dimension is for video only)
2145            let grid_t = 1;
2146            // After patch embedding and spatial merge, the effective grid dimensions are reduced
2147            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2148            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2149            grid_t * grid_h * grid_w * max_num_images
2150        };
2151
2152        let max_text_attn = {
2153            // This model injects the vision information directly into the input embeddings
2154            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2155            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2156        };
2157
2158        Ok(max_text_attn)
2159    }
2160
2161    fn non_mapped_max_act_size_elems(
2162        &self,
2163        config: &str,
2164        params: &AutoDeviceMapParams,
2165    ) -> Result<usize> {
2166        let AutoDeviceMapParams::Vision {
2167            max_seq_len: _,
2168            max_batch_size,
2169            max_image_shape,
2170            max_num_images,
2171        } = params
2172        else {
2173            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2174        };
2175
2176        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2177
2178        // For the vision encoder, before spatial merging
2179        let img_seq_len = {
2180            let cfg = &cfg.vision_config;
2181            // grid_t is 1 for images
2182            let grid_t = 1;
2183            let grid_h = max_image_shape.0 / cfg.patch_size;
2184            let grid_w = max_image_shape.1 / cfg.patch_size;
2185            grid_t * grid_h * grid_w
2186        };
2187
2188        let max_vision_attn = {
2189            let cfg = &cfg.vision_config;
2190            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2191        };
2192
2193        Ok(max_vision_attn)
2194    }
2195
2196    fn non_mapped_size_in_bytes(
2197        &self,
2198        config: &str,
2199        dtype: DType,
2200        weight_pack_factor: usize,
2201        _matformer_config: Option<&MatformerSliceConfig>,
2202    ) -> Result<usize> {
2203        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2204        let text_elems = {
2205            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2206            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
2207            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2208                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2209            } else {
2210                0
2211            };
2212            let norm = cfg.hidden_size;
2213            embed_tokens + lm_head + norm
2214        };
2215
2216        let patch_merger = {
2217            let cfg = &cfg.vision_config;
2218            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2219
2220            let mlp0 = hidden_size * hidden_size + hidden_size;
2221            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2222
2223            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2224
2225            mlp0 + mlp2 + ln_q
2226        };
2227
2228        let patch_embed = {
2229            let cfg = &cfg.vision_config;
2230            let conv_cfg = Conv3dConfig {
2231                stride: cfg.patch_size,
2232                ..Default::default()
2233            };
2234            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2235            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2236                * kernel_sizes[0]
2237                * kernel_sizes[1]
2238                * kernel_sizes[2]
2239        };
2240
2241        let encoder_layer = {
2242            let cfg = &cfg.vision_config;
2243            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2244            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2245
2246            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2247            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2248            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2249            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2250
2251            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2252            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2253
2254            norm1 + norm2 + fc1 + fc2 + qkv + out
2255        };
2256
2257        let elems =
2258            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2259
2260        Ok(elems * dtype.size_in_bytes())
2261    }
2262
2263    fn layer_sizes_in_bytes(
2264        &self,
2265        config: &str,
2266        dtype: DType,
2267        weight_pack_factor: usize,
2268        _matformer_config: Option<&MatformerSliceConfig>,
2269    ) -> Result<Vec<usize>> {
2270        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2271        let per_layer_elems = {
2272            let input_layernorm = cfg.hidden_size;
2273            let post_attention_layernorm = cfg.hidden_size;
2274
2275            let size_in = cfg.hidden_size;
2276            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2277            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2278            let q_proj = size_in * size_q / weight_pack_factor + size_q;
2279            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2280            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2281            let o_proj = size_q * size_in / weight_pack_factor;
2282
2283            let h_size = cfg.hidden_size;
2284            let i_size = cfg.intermediate_size;
2285            let gate_proj = h_size * i_size / weight_pack_factor;
2286            let up_proj = h_size * i_size / weight_pack_factor;
2287            let down_proj = i_size * h_size / weight_pack_factor;
2288
2289            input_layernorm
2290                + post_attention_layernorm
2291                + q_proj
2292                + k_proj
2293                + v_proj
2294                + o_proj
2295                + gate_proj
2296                + up_proj
2297                + down_proj
2298        };
2299        Ok(vec![
2300            per_layer_elems * dtype.size_in_bytes();
2301            cfg.num_hidden_layers
2302        ])
2303    }
2304
2305    fn num_layers(&self, config: &str) -> Result<usize> {
2306        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2307        Ok(cfg.num_hidden_layers)
2308    }
2309
2310    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2311        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2312
2313        let cfg = ModelConfigMetadata {
2314            max_seq_len: cfg.max_position_embeddings,
2315            num_layers: cfg.num_hidden_layers,
2316            hidden_size: cfg.hidden_size,
2317            num_kv_heads: cfg.num_key_value_heads,
2318            num_attn_heads: cfg.num_attention_heads,
2319            sliding_window: cfg.sliding_window,
2320            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2321            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2322            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2323        };
2324
2325        Ok(Box::new(cfg))
2326    }
2327
2328    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2329        Some(vec![NonMappedSubModel::Vision])
2330    }
2331}
2332
2333// ======================== Idefics 3 loader
2334
2335/// [`VisionLoader`] for an Idefics 3 Vision model.
2336///
2337/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
2338pub struct Idefics3Loader;
2339
2340pub struct Idefics3Prefixer;
2341
2342impl MultimodalPromptPrefixer for Idefics3Prefixer {
2343    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2344        // Chat template does it
2345        prompt.to_string()
2346    }
2347}
2348
2349impl VisionModelLoader for Idefics3Loader {
2350    fn load(
2351        &self,
2352        config: &str,
2353        vb: ShardedVarBuilder,
2354        normal_loading_metadata: NormalLoadingMetadata,
2355        attention_mechanism: AttentionImplementation,
2356    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2357        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2358        Ok(Box::new(Idefics3Model::new(
2359            &cfg,
2360            vb,
2361            self.is_gptx(config),
2362            normal_loading_metadata,
2363            attention_mechanism,
2364        )?))
2365    }
2366    fn is_gptx(&self, _config: &str) -> bool {
2367        true
2368    }
2369    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2370        let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2371        Ok(Box::new(cfg))
2372    }
2373    fn get_processor(
2374        &self,
2375        _model_config: &str,
2376        processor_config: Option<ProcessorConfig>,
2377        preprocessor_config: PreProcessorConfig,
2378        max_edge: Option<u32>,
2379    ) -> Arc<dyn Processor + Send + Sync> {
2380        Arc::new(Idefics3Processor::new(
2381            processor_config.unwrap_or_default(),
2382            preprocessor_config,
2383            max_edge,
2384        ))
2385    }
2386    fn supports_paged_attention(&self, _config: &str) -> bool {
2387        true
2388    }
2389    fn supports_prefix_cacher(&self, _config: &str) -> bool {
2390        true
2391    }
2392    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2393        Arc::new(Idefics3Prefixer)
2394    }
2395    fn modalities(&self, _config: &str) -> Result<Modalities> {
2396        Ok(Modalities {
2397            input: vec![SupportedModality::Text, SupportedModality::Vision],
2398            output: vec![SupportedModality::Text],
2399        })
2400    }
2401}
2402
2403impl IsqModelLoader for Idefics3Loader {
2404    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2405        Ok(vec![
2406            Regex::new(r"lm_head\.(weight|bias)$")?,
2407            // Attention
2408            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2409            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2410            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2411            Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2412            // MLP
2413            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2414            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2415            Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2416        ])
2417    }
2418    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2419        Ok(vec![
2420            Regex::new(r"lm_head\.(weight|bias)$")?,
2421            // Attention
2422            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2423            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2424            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2425            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2426            // MLP
2427            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2428            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2429            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2430            // // Attention (vision)
2431            // Regex::new(
2432            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2433            // )?,
2434            // Regex::new(
2435            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
2436            // )?,
2437            // Regex::new(
2438            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
2439            // )?,
2440            // Regex::new(
2441            //     r"model\.vision_model\.encoder\.layers\.(\d+)\.self_attn\.out_proj\.(weight|bias)$",
2442            // )?,
2443            // MLP (vision)
2444            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2445            // Regex::new(r"model\.vision_model\.encoder\.layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
2446        ])
2447    }
2448}
2449
2450impl DeviceMappedModelLoader for Idefics3Loader {
2451    fn mapped_max_act_size_elems(
2452        &self,
2453        config: &str,
2454        params: &AutoDeviceMapParams,
2455    ) -> Result<usize> {
2456        let AutoDeviceMapParams::Vision {
2457            max_seq_len,
2458            max_batch_size,
2459            max_image_shape: _,
2460            max_num_images,
2461        } = params
2462        else {
2463            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2464        };
2465
2466        let cfg: Idefics3Config = serde_json::from_str(config)?;
2467
2468        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2469        let img_seq_len = (num_patches + 1) * max_num_images;
2470
2471        let max_text_attn = {
2472            // This model injects the vision information directly into the input embeddings
2473            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2474            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2475        };
2476
2477        Ok(max_text_attn)
2478    }
2479
2480    fn non_mapped_max_act_size_elems(
2481        &self,
2482        config: &str,
2483        params: &AutoDeviceMapParams,
2484    ) -> Result<usize> {
2485        let AutoDeviceMapParams::Vision {
2486            max_seq_len: _,
2487            max_batch_size,
2488            max_image_shape: _,
2489            max_num_images,
2490        } = params
2491        else {
2492            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2493        };
2494
2495        let cfg: Idefics3Config = serde_json::from_str(config)?;
2496
2497        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2498        let img_seq_len = num_patches + 1;
2499
2500        let max_vision_attn = {
2501            // do_image_splitting = true
2502            let images_factor = 5;
2503
2504            (max_batch_size * images_factor * max_num_images)
2505                * cfg.vision_config.num_attention_heads
2506                * img_seq_len
2507                * img_seq_len
2508        };
2509
2510        Ok(max_vision_attn)
2511    }
2512
2513    fn non_mapped_size_in_bytes(
2514        &self,
2515        config: &str,
2516        dtype: DType,
2517        weight_pack_factor: usize,
2518        _matformer_config: Option<&MatformerSliceConfig>,
2519    ) -> Result<usize> {
2520        let cfg: Idefics3Config = serde_json::from_str(config)?;
2521        let text_elems = {
2522            let cfg = &cfg.text_config;
2523
2524            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2525            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2526            let norm = cfg.hidden_size;
2527            embed_tokens + lm_head + norm
2528        };
2529
2530        let connector_elems = {
2531            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2532            let out_dim = cfg.text_config.hidden_size;
2533
2534            in_dim * out_dim
2535        };
2536
2537        let vision_transformer = {
2538            let cfg = &cfg.vision_config;
2539
2540            let post_layernorm = cfg.hidden_size;
2541
2542            let conv_config = Conv2dConfig {
2543                stride: cfg.patch_size,
2544                ..Default::default()
2545            };
2546            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2547                * cfg.patch_size
2548                * cfg.patch_size;
2549
2550            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2551            let num_patches = num_patches_per_side.pow(2);
2552            let position_embedding = num_patches * cfg.hidden_size;
2553
2554            let layer_elems = {
2555                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2556                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2557
2558                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2559                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2560
2561                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2562                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2563                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2564                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2565
2566                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2567            };
2568
2569            post_layernorm
2570                + patch_embedding
2571                + position_embedding
2572                + layer_elems * cfg.num_hidden_layers
2573        };
2574
2575        let elems = text_elems + connector_elems + vision_transformer;
2576
2577        Ok(elems * dtype.size_in_bytes())
2578    }
2579
2580    fn layer_sizes_in_bytes(
2581        &self,
2582        config: &str,
2583        dtype: DType,
2584        weight_pack_factor: usize,
2585        _matformer_config: Option<&MatformerSliceConfig>,
2586    ) -> Result<Vec<usize>> {
2587        let cfg: Idefics3Config = serde_json::from_str(config)?;
2588        let cfg = cfg.text_config;
2589        let per_layer_elems = {
2590            let input_layernorm = cfg.hidden_size;
2591            let post_attention_layernorm = cfg.hidden_size;
2592
2593            let size_in = cfg.hidden_size;
2594            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2595            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2596            let q_proj = size_in * size_q / weight_pack_factor;
2597            let k_proj = size_in * size_kv / weight_pack_factor;
2598            let v_proj = size_in * size_kv / weight_pack_factor;
2599            let o_proj = size_q * size_in / weight_pack_factor;
2600
2601            let h_size = cfg.hidden_size;
2602            let i_size = cfg.intermediate_size;
2603            let gate_proj = h_size * i_size / weight_pack_factor;
2604            let up_proj = h_size * i_size / weight_pack_factor;
2605            let down_proj = i_size * h_size / weight_pack_factor;
2606
2607            input_layernorm
2608                + post_attention_layernorm
2609                + q_proj
2610                + k_proj
2611                + v_proj
2612                + o_proj
2613                + gate_proj
2614                + up_proj
2615                + down_proj
2616        };
2617        Ok(vec![
2618            per_layer_elems * dtype.size_in_bytes();
2619            cfg.num_hidden_layers
2620        ])
2621    }
2622
2623    fn num_layers(&self, config: &str) -> Result<usize> {
2624        let cfg: Idefics3Config = serde_json::from_str(config)?;
2625        Ok(cfg.text_config.num_hidden_layers)
2626    }
2627    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2628        let cfg: Idefics3Config = serde_json::from_str(config)?;
2629        let cfg = &cfg.text_config;
2630
2631        let cfg = ModelConfigMetadata {
2632            max_seq_len: cfg.max_position_embeddings,
2633            num_layers: cfg.num_hidden_layers,
2634            hidden_size: cfg.hidden_size,
2635            num_kv_heads: cfg.num_key_value_heads,
2636            num_attn_heads: cfg.num_attention_heads,
2637            sliding_window: None,
2638            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2639            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2640            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2641        };
2642
2643        Ok(Box::new(cfg))
2644    }
2645
2646    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2647        Some(vec![NonMappedSubModel::Vision])
2648    }
2649}
2650
2651// ======================== MiniCpm-O loader
2652
/// MiniCPM-O model loader, for use with [`VisionLoader`].
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct MiniCpmOLoader;
2657
/// Prompt prefixer for MiniCPM-O: prepends one `(<image>./</image>)` placeholder per image.
pub struct MiniCpmOPrefixer;
2659
2660impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2661    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2662        format!(
2663            "{}{prompt}",
2664            "(<image>./</image>)".repeat(image_indexes.len())
2665        )
2666    }
2667}
2668
2669impl VisionModelLoader for MiniCpmOLoader {
2670    fn load(
2671        &self,
2672        config: &str,
2673        vb: ShardedVarBuilder,
2674        normal_loading_metadata: NormalLoadingMetadata,
2675        attention_mechanism: AttentionImplementation,
2676    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2677        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2678        Ok(Box::new(MiniCpmOModel::new(
2679            &cfg,
2680            vb,
2681            self.is_gptx(config),
2682            normal_loading_metadata,
2683            attention_mechanism,
2684        )?))
2685    }
2686    fn is_gptx(&self, _config: &str) -> bool {
2687        true
2688    }
2689    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2690        let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2691        Ok(Box::new(cfg))
2692    }
2693    fn get_processor(
2694        &self,
2695        _model_config: &str,
2696        processor_config: Option<ProcessorConfig>,
2697        preprocessor_config: PreProcessorConfig,
2698        max_edge: Option<u32>,
2699    ) -> Arc<dyn Processor + Send + Sync> {
2700        Arc::new(MiniCpmOProcessor::new(
2701            processor_config.unwrap_or_default(),
2702            preprocessor_config,
2703            max_edge,
2704        ))
2705    }
2706    fn supports_paged_attention(&self, _config: &str) -> bool {
2707        true
2708    }
2709    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2710        Arc::new(MiniCpmOPrefixer)
2711    }
2712    fn modalities(&self, _config: &str) -> Result<Modalities> {
2713        Ok(Modalities {
2714            input: vec![SupportedModality::Text, SupportedModality::Vision],
2715            output: vec![SupportedModality::Text],
2716        })
2717    }
2718}
2719
2720impl IsqModelLoader for MiniCpmOLoader {
2721    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2722        Ok(vec![
2723            Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2724            // Attention
2725            Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2726            Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2727            Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2728            Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2729            // MLP
2730            Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2731            Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2732            Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2733        ])
2734    }
2735    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2736        self.isq_layer_regexes(config)
2737    }
2738}
2739
2740impl DeviceMappedModelLoader for MiniCpmOLoader {
2741    fn mapped_max_act_size_elems(
2742        &self,
2743        config: &str,
2744        params: &AutoDeviceMapParams,
2745    ) -> Result<usize> {
2746        let AutoDeviceMapParams::Vision {
2747            max_seq_len,
2748            max_batch_size,
2749            max_image_shape: _,
2750            max_num_images,
2751        } = params
2752        else {
2753            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2754        };
2755
2756        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2757
2758        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2759        let img_seq_len = (num_patches + 1) * max_num_images;
2760
2761        let max_text_attn = {
2762            // This model injects the vision information directly into the input embeddings
2763            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2764            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2765        };
2766
2767        Ok(max_text_attn)
2768    }
2769
2770    fn non_mapped_max_act_size_elems(
2771        &self,
2772        config: &str,
2773        params: &AutoDeviceMapParams,
2774    ) -> Result<usize> {
2775        let AutoDeviceMapParams::Vision {
2776            max_seq_len: _,
2777            max_batch_size,
2778            max_image_shape: _,
2779            max_num_images,
2780        } = params
2781        else {
2782            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
2783        };
2784
2785        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2786
2787        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2788        let img_seq_len = num_patches + 1;
2789
2790        let max_vision_attn = {
2791            // do_image_splitting = true
2792            let images_factor = 5;
2793
2794            (max_batch_size * images_factor * max_num_images)
2795                * cfg.vision_config.num_attention_heads
2796                * img_seq_len
2797                * img_seq_len
2798        };
2799
2800        Ok(max_vision_attn)
2801    }
2802
2803    fn non_mapped_size_in_bytes(
2804        &self,
2805        config: &str,
2806        dtype: DType,
2807        weight_pack_factor: usize,
2808        _matformer_config: Option<&MatformerSliceConfig>,
2809    ) -> Result<usize> {
2810        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2811        let text_elems = {
2812            let cfg = &cfg.text_config;
2813
2814            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2815            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2816            let norm = cfg.hidden_size;
2817            embed_tokens + lm_head + norm
2818        };
2819
2820        let vision_transformer = {
2821            let cfg = &cfg.vision_config;
2822
2823            let post_layernorm = cfg.hidden_size;
2824
2825            let conv_config = Conv2dConfig {
2826                stride: cfg.patch_size,
2827                ..Default::default()
2828            };
2829            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2830                * cfg.patch_size
2831                * cfg.patch_size;
2832
2833            let num_patches_per_side = cfg.image_size / cfg.patch_size;
2834            let num_patches = num_patches_per_side.pow(2);
2835            let position_embedding = num_patches * cfg.hidden_size;
2836
2837            let layer_elems = {
2838                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2839                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2840
2841                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2842                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2843
2844                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2845                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2846                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2847                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2848
2849                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2850            };
2851
2852            post_layernorm
2853                + patch_embedding
2854                + position_embedding
2855                + layer_elems * cfg.num_hidden_layers
2856        };
2857
2858        let elems = text_elems + vision_transformer;
2859
2860        Ok(elems * dtype.size_in_bytes())
2861    }
2862
2863    fn layer_sizes_in_bytes(
2864        &self,
2865        config: &str,
2866        dtype: DType,
2867        weight_pack_factor: usize,
2868        _matformer_config: Option<&MatformerSliceConfig>,
2869    ) -> Result<Vec<usize>> {
2870        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2871        let cfg = cfg.text_config;
2872        let per_layer_elems = {
2873            let input_layernorm = cfg.hidden_size;
2874            let post_attention_layernorm = cfg.hidden_size;
2875
2876            let size_in = cfg.hidden_size;
2877            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2878            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2879            let q_proj = size_in * size_q / weight_pack_factor;
2880            let k_proj = size_in * size_kv / weight_pack_factor;
2881            let v_proj = size_in * size_kv / weight_pack_factor;
2882            let o_proj = size_q * size_in / weight_pack_factor;
2883
2884            let h_size = cfg.hidden_size;
2885            let i_size = cfg.intermediate_size;
2886            let gate_proj = h_size * i_size / weight_pack_factor;
2887            let up_proj = h_size * i_size / weight_pack_factor;
2888            let down_proj = i_size * h_size / weight_pack_factor;
2889
2890            input_layernorm
2891                + post_attention_layernorm
2892                + q_proj
2893                + k_proj
2894                + v_proj
2895                + o_proj
2896                + gate_proj
2897                + up_proj
2898                + down_proj
2899        };
2900        Ok(vec![
2901            per_layer_elems * dtype.size_in_bytes();
2902            cfg.num_hidden_layers
2903        ])
2904    }
2905
2906    fn num_layers(&self, config: &str) -> Result<usize> {
2907        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2908        Ok(cfg.text_config.num_hidden_layers)
2909    }
2910    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2911        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2912        let cfg = &cfg.text_config;
2913
2914        let cfg = ModelConfigMetadata {
2915            max_seq_len: cfg.max_position_embeddings,
2916            num_layers: cfg.num_hidden_layers,
2917            hidden_size: cfg.hidden_size,
2918            num_kv_heads: cfg.num_key_value_heads,
2919            num_attn_heads: cfg.num_attention_heads,
2920            sliding_window: None,
2921            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2922            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2923            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2924        };
2925
2926        Ok(Box::new(cfg))
2927    }
2928}
2929
2930// ======================== Phi 4MM loader
2931
/// Phi 4MM (multimodal) model loader, for use with [`VisionLoader`].
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct Phi4MMLoader;
2936
/// Prompt prefixer for Phi 4MM: prepends 1-based `<|image_N|>` / `<|audio_N|>` tags.
pub struct Phi4MMPrefixer;
2938
2939impl MultimodalPromptPrefixer for Phi4MMPrefixer {
2940    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2941        // Image indexing starts at 0.
2942
2943        format!(
2944            "{}{prompt}",
2945            image_indexes
2946                .into_iter()
2947                .map(|image_index| format!("<|image_{}|>", image_index + 1))
2948                .join("")
2949        )
2950    }
2951    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
2952        // Image indexing starts at 0.
2953
2954        format!(
2955            "{}{prompt}",
2956            audio_indexes
2957                .into_iter()
2958                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
2959                .join("")
2960        )
2961    }
2962}
2963
2964impl VisionModelLoader for Phi4MMLoader {
2965    fn load(
2966        &self,
2967        config: &str,
2968        vb: ShardedVarBuilder,
2969        normal_loading_metadata: NormalLoadingMetadata,
2970        attention_mechanism: AttentionImplementation,
2971    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
2972        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2973        Ok(Box::new(Phi4MMModel::new(
2974            &cfg,
2975            vb,
2976            self.is_gptx(config),
2977            normal_loading_metadata,
2978            attention_mechanism,
2979        )?))
2980    }
2981    fn is_gptx(&self, _config: &str) -> bool {
2982        true
2983    }
2984    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2985        let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
2986        Ok(Box::new(cfg))
2987    }
2988    fn get_processor(
2989        &self,
2990        _model_config: &str,
2991        processor_config: Option<ProcessorConfig>,
2992        preprocessor_config: PreProcessorConfig,
2993        _max_edge: Option<u32>,
2994    ) -> Arc<dyn Processor + Send + Sync> {
2995        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
2996    }
2997    fn supports_paged_attention(&self, _config: &str) -> bool {
2998        true
2999    }
3000    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3001        true
3002    }
3003    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3004        Arc::new(Phi4MMPrefixer)
3005    }
3006    fn modalities(&self, _config: &str) -> Result<Modalities> {
3007        Ok(Modalities {
3008            input: vec![
3009                SupportedModality::Text,
3010                SupportedModality::Vision,
3011                SupportedModality::Audio,
3012            ],
3013            output: vec![SupportedModality::Text],
3014        })
3015    }
3016}
3017
3018impl IsqModelLoader for Phi4MMLoader {
3019    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3020        Ok(vec![
3021            Regex::new(r"lm_head\.(weight|bias)$")?,
3022            // Attention
3023            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3024            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3025            // MLP
3026            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3027            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3028        ])
3029    }
3030    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3031        self.isq_layer_regexes(config)
3032    }
3033}
3034
3035impl DeviceMappedModelLoader for Phi4MMLoader {
3036    fn mapped_max_act_size_elems(
3037        &self,
3038        config: &str,
3039        params: &AutoDeviceMapParams,
3040    ) -> Result<usize> {
3041        // NOTE: we ignore max_num_images although it can only be one...
3042        let AutoDeviceMapParams::Vision {
3043            max_seq_len,
3044            max_batch_size,
3045            max_image_shape: _,
3046            max_num_images,
3047        } = params
3048        else {
3049            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3050        };
3051
3052        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3053
3054        let vcfg = &PHI4_MM_VISION_CFG;
3055
3056        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3057        let img_seq_len = (num_patches + 1) * max_num_images;
3058
3059        let max_text_attn = {
3060            // This model injects the vision information directly into the input embeddings
3061            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3062            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3063        };
3064
3065        Ok(max_text_attn)
3066    }
3067
3068    fn non_mapped_max_act_size_elems(
3069        &self,
3070        _config: &str,
3071        params: &AutoDeviceMapParams,
3072    ) -> Result<usize> {
3073        let AutoDeviceMapParams::Vision {
3074            max_seq_len: _,
3075            max_batch_size,
3076            max_image_shape,
3077            max_num_images,
3078        } = params
3079        else {
3080            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3081        };
3082
3083        let vcfg = &PHI4_MM_VISION_CFG;
3084
3085        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3086        let img_seq_len = num_patches + 1;
3087
3088        let max_batch_size = max_batch_size
3089            * (max_image_shape
3090                .0
3091                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3092                * max_image_shape
3093                    .1
3094                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3095                + 1);
3096
3097        let max_vision_attn = (max_batch_size * max_num_images)
3098            * vcfg.num_attention_heads
3099            * img_seq_len
3100            * img_seq_len;
3101        let max_qkv = 3
3102            * (max_batch_size
3103                * vcfg.num_attention_heads
3104                * img_seq_len
3105                * (vcfg.hidden_size / vcfg.num_attention_heads));
3106
3107        Ok(max_vision_attn + max_qkv)
3108    }
3109
3110    fn non_mapped_size_in_bytes(
3111        &self,
3112        config: &str,
3113        dtype: DType,
3114        weight_pack_factor: usize,
3115        _matformer_config: Option<&MatformerSliceConfig>,
3116    ) -> Result<usize> {
3117        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3118        let elems = {
3119            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3120            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3121            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3122                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3123            } else {
3124                0
3125            };
3126            let norm = cfg.hidden_size;
3127
3128            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3129                let projection_cls = img_embed
3130                    .projection_cls
3131                    .clone()
3132                    .unwrap_or("linear".to_string());
3133                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3134                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3135                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3136
3137                let proj = match (projection_cls.as_str(), use_hd_transform) {
3138                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3139                    ("mlp", true) => {
3140                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3141                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3142                        a + b
3143                    }
3144                    ("mlp", false) => {
3145                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3146                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3147                        a + b
3148                    }
3149                    _ => {
3150                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3151                    }
3152                };
3153
3154                let (glb_gn, sub_gn) = if with_learnable_separator {
3155                    let glb_gn = image_dim_out * 4;
3156                    let sub_gn = image_dim_out * 4;
3157                    (glb_gn, sub_gn)
3158                } else {
3159                    (0, 0)
3160                };
3161
3162                let vision_transformer = {
3163                    let cfg = &PHI4_MM_VISION_CFG;
3164
3165                    let post_layernorm = cfg.hidden_size;
3166
3167                    let conv_config = Conv2dConfig {
3168                        stride: cfg.patch_size,
3169                        ..Default::default()
3170                    };
3171                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3172                        * cfg.patch_size
3173                        * cfg.patch_size;
3174
3175                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
3176                    let num_patches = num_patches_per_side.pow(2);
3177                    let position_embedding = num_patches * cfg.hidden_size;
3178
3179                    let layer_elems = {
3180                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3181                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3182
3183                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3184                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3185
3186                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3187                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3188                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3189                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3190
3191                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3192                    };
3193
3194                    post_layernorm
3195                        + patch_embedding
3196                        + position_embedding
3197                        + layer_elems * cfg.num_hidden_layers
3198                };
3199
3200                proj + glb_gn + sub_gn + vision_transformer
3201            } else {
3202                0
3203            };
3204
3205            embed_tokens + lm_head + norm + image_embed
3206        };
3207
3208        Ok(elems * dtype.size_in_bytes())
3209    }
3210
3211    fn layer_sizes_in_bytes(
3212        &self,
3213        config: &str,
3214        dtype: DType,
3215        weight_pack_factor: usize,
3216        _matformer_config: Option<&MatformerSliceConfig>,
3217    ) -> Result<Vec<usize>> {
3218        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3219        let per_layer_elems = {
3220            let input_layernorm = cfg.hidden_size;
3221            let post_attention_layernorm = cfg.hidden_size;
3222
3223            let size_in = cfg.hidden_size;
3224            let head_dim = cfg.head_dim();
3225            let op_size =
3226                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3227            let qkv_proj = size_in * op_size / weight_pack_factor;
3228            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3229
3230            let h_size = cfg.hidden_size;
3231            let i_size = cfg.intermediate_size;
3232            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3233            let down_proj = h_size * i_size / weight_pack_factor;
3234
3235            input_layernorm
3236                + post_attention_layernorm
3237                + qkv_proj
3238                + o_proj
3239                + gate_up_proj
3240                + down_proj
3241        };
3242        Ok(vec![
3243            per_layer_elems * dtype.size_in_bytes();
3244            cfg.num_hidden_layers
3245        ])
3246    }
3247
3248    fn num_layers(&self, config: &str) -> Result<usize> {
3249        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3250        Ok(cfg.num_hidden_layers)
3251    }
3252
3253    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3254        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3255
3256        let cfg = ModelConfigMetadata {
3257            max_seq_len: cfg.max_position_embeddings,
3258            num_layers: cfg.num_hidden_layers,
3259            hidden_size: cfg.hidden_size,
3260            num_kv_heads: cfg.num_key_value_heads(),
3261            num_attn_heads: cfg.num_attention_heads,
3262            sliding_window: cfg.sliding_window,
3263            k_head_dim: cfg.head_dim(),
3264            v_head_dim: cfg.head_dim(),
3265            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3266        };
3267
3268        Ok(Box::new(cfg))
3269    }
3270
3271    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3272        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3273    }
3274}
3275
3276// ======================== Qwen2_5VL Loader
3277
3278/// [`VisionLoader`] for an Qwen2_5VL model.
3279///
3280/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
3281pub struct Qwen2_5VLLoader;
3282
3283pub struct Qwen2_5VLPrefixer;
3284
3285impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3286    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3287        format!(
3288            "{}{prompt}",
3289            format!(
3290                "{}{}{}",
3291                Qwen2_5VLProcessor::VISION_START,
3292                Qwen2_5VLProcessor::IMAGE_PAD,
3293                Qwen2_5VLProcessor::VISION_END
3294            )
3295            .repeat(image_indexes.len())
3296        )
3297    }
3298}
3299
3300impl VisionModelLoader for Qwen2_5VLLoader {
3301    fn load(
3302        &self,
3303        config: &str,
3304        vb: ShardedVarBuilder,
3305        normal_loading_metadata: NormalLoadingMetadata,
3306        attention_mechanism: AttentionImplementation,
3307    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3308        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3309        Ok(Box::new(Qwen2_5VLModel::new(
3310            &cfg,
3311            vb,
3312            self.is_gptx(config),
3313            normal_loading_metadata,
3314            attention_mechanism,
3315        )?))
3316    }
3317    fn is_gptx(&self, _config: &str) -> bool {
3318        true
3319    }
3320    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3321        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3322        Ok(Box::new(config))
3323    }
3324    fn get_processor(
3325        &self,
3326        _model_config: &str,
3327        _processor_config: Option<ProcessorConfig>,
3328        _preprocessor_config: PreProcessorConfig,
3329        max_edge: Option<u32>,
3330    ) -> Arc<dyn Processor + Send + Sync> {
3331        Arc::new(Qwen2_5VLProcessor::new(max_edge))
3332    }
3333    fn supports_paged_attention(&self, _config: &str) -> bool {
3334        false
3335    }
3336    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3337        Arc::new(Qwen2_5VLPrefixer)
3338    }
3339    fn modalities(&self, _config: &str) -> Result<Modalities> {
3340        Ok(Modalities {
3341            input: vec![SupportedModality::Text, SupportedModality::Vision],
3342            output: vec![SupportedModality::Text],
3343        })
3344    }
3345}
3346
3347impl IsqModelLoader for Qwen2_5VLLoader {
3348    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3349        Ok(vec![
3350            Regex::new(r"lm_head\.(weight|bias)$")?,
3351            // Attention
3352            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3353            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3354            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3355            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3356            // MLP
3357            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3358            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3359            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3360        ])
3361    }
3362    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3363        self.isq_layer_regexes(config)
3364    }
3365}
3366
3367impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3368    fn mapped_max_act_size_elems(
3369        &self,
3370        config: &str,
3371        params: &AutoDeviceMapParams,
3372    ) -> Result<usize> {
3373        let AutoDeviceMapParams::Vision {
3374            max_seq_len,
3375            max_batch_size,
3376            max_image_shape,
3377            max_num_images,
3378        } = params
3379        else {
3380            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3381        };
3382
3383        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3384
3385        let img_seq_len = {
3386            let cfg = &cfg.vision_config;
3387            let grid_t = max_num_images / cfg.temporal_patch_size;
3388            let grid_h = max_image_shape.0 / cfg.patch_size;
3389            let grid_w = max_image_shape.1 / cfg.patch_size;
3390            grid_t * grid_h * grid_w
3391        };
3392        let img_seq_len = img_seq_len * max_num_images;
3393
3394        let max_text_attn = {
3395            // This model injects the vision information directly into the input embeddings
3396            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3397            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3398        };
3399
3400        Ok(max_text_attn)
3401    }
3402
3403    fn non_mapped_max_act_size_elems(
3404        &self,
3405        config: &str,
3406        params: &AutoDeviceMapParams,
3407    ) -> Result<usize> {
3408        let AutoDeviceMapParams::Vision {
3409            max_seq_len: _,
3410            max_batch_size,
3411            max_image_shape,
3412            max_num_images,
3413        } = params
3414        else {
3415            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3416        };
3417
3418        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3419
3420        let img_seq_len = {
3421            let cfg = &cfg.vision_config;
3422            let grid_t = max_num_images / cfg.temporal_patch_size;
3423            let grid_h = max_image_shape.0 / cfg.patch_size;
3424            let grid_w = max_image_shape.1 / cfg.patch_size;
3425            grid_t * grid_h * grid_w
3426        };
3427
3428        let max_vision_attn = {
3429            let cfg = &cfg.vision_config;
3430            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3431        };
3432
3433        Ok(max_vision_attn)
3434    }
3435
3436    fn non_mapped_size_in_bytes(
3437        &self,
3438        config: &str,
3439        dtype: DType,
3440        weight_pack_factor: usize,
3441        _matformer_config: Option<&MatformerSliceConfig>,
3442    ) -> Result<usize> {
3443        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3444        let text_elems = {
3445            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3446            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3447            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3448                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3449            } else {
3450                0
3451            };
3452            let norm = cfg.hidden_size;
3453            embed_tokens + lm_head + norm
3454        };
3455
3456        let patch_merger = {
3457            let cfg = &cfg.vision_config;
3458            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3459
3460            let mlp0 = hidden_size * hidden_size + hidden_size;
3461            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3462
3463            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3464
3465            mlp0 + mlp2 + ln_q
3466        };
3467
3468        let patch_embed = {
3469            let cfg = &cfg.vision_config;
3470            let conv_cfg = Conv3dConfig {
3471                stride: cfg.patch_size,
3472                ..Default::default()
3473            };
3474            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3475            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3476                * kernel_sizes[0]
3477                * kernel_sizes[1]
3478                * kernel_sizes[2]
3479        };
3480
3481        let encoder_layer = {
3482            let cfg = &cfg.vision_config;
3483            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3484            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3485
3486            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3487            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3488            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3489
3490            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3491            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3492
3493            norm1 + norm2 + fc1 + fc2 + qkv + out
3494        };
3495
3496        let elems =
3497            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3498
3499        Ok(elems * dtype.size_in_bytes())
3500    }
3501
3502    fn layer_sizes_in_bytes(
3503        &self,
3504        config: &str,
3505        dtype: DType,
3506        weight_pack_factor: usize,
3507        _matformer_config: Option<&MatformerSliceConfig>,
3508    ) -> Result<Vec<usize>> {
3509        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3510        let per_layer_elems = {
3511            let input_layernorm = cfg.hidden_size;
3512            let post_attention_layernorm = cfg.hidden_size;
3513
3514            let size_in = cfg.hidden_size;
3515            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3516            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3517            let q_proj = size_in * size_q / weight_pack_factor + size_q;
3518            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3519            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3520            let o_proj = size_q * size_in / weight_pack_factor;
3521
3522            let h_size = cfg.hidden_size;
3523            let i_size = cfg.intermediate_size;
3524            let gate_proj = h_size * i_size / weight_pack_factor;
3525            let up_proj = h_size * i_size / weight_pack_factor;
3526            let down_proj = i_size * h_size / weight_pack_factor;
3527
3528            input_layernorm
3529                + post_attention_layernorm
3530                + q_proj
3531                + k_proj
3532                + v_proj
3533                + o_proj
3534                + gate_proj
3535                + up_proj
3536                + down_proj
3537        };
3538        Ok(vec![
3539            per_layer_elems * dtype.size_in_bytes();
3540            cfg.num_hidden_layers
3541        ])
3542    }
3543
3544    fn num_layers(&self, config: &str) -> Result<usize> {
3545        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3546        Ok(cfg.num_hidden_layers)
3547    }
3548
3549    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3550        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3551
3552        let cfg = ModelConfigMetadata {
3553            max_seq_len: cfg.max_position_embeddings,
3554            num_layers: cfg.num_hidden_layers,
3555            hidden_size: cfg.hidden_size,
3556            num_kv_heads: cfg.num_key_value_heads,
3557            num_attn_heads: cfg.num_attention_heads,
3558            sliding_window: cfg.sliding_window,
3559            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3560            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3561            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3562        };
3563
3564        Ok(Box::new(cfg))
3565    }
3566
3567    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3568        Some(vec![NonMappedSubModel::Vision])
3569    }
3570}
3571
3572// ======================== Gemma 3 Loader
3573
3574/// [`VisionLoader`] for an Gemma 3 model.
3575///
3576/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
3577pub struct Gemma3Loader;
3578
3579pub struct Gemma3Prefixer;
3580
3581impl MultimodalPromptPrefixer for Gemma3Prefixer {
3582    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3583        prompt.to_string()
3584    }
3585}
3586
3587impl VisionModelLoader for Gemma3Loader {
3588    fn load(
3589        &self,
3590        config: &str,
3591        vb: ShardedVarBuilder,
3592        normal_loading_metadata: NormalLoadingMetadata,
3593        attention_mechanism: AttentionImplementation,
3594    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3595        let cfg: Gemma3Config = serde_json::from_str(config)?;
3596        Ok(Box::new(Gemma3Model::new(
3597            &cfg,
3598            vb,
3599            self.is_gptx(config),
3600            normal_loading_metadata,
3601            attention_mechanism,
3602        )?))
3603    }
3604    fn is_gptx(&self, _config: &str) -> bool {
3605        true
3606    }
3607    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3608        let config: Gemma3Config = serde_json::from_str(config)?;
3609        Ok(Box::new(config))
3610    }
3611    fn get_processor(
3612        &self,
3613        config: &str,
3614        processor_config: Option<ProcessorConfig>,
3615        _preprocessor_config: PreProcessorConfig,
3616        _max_edge: Option<u32>,
3617    ) -> Arc<dyn Processor + Send + Sync> {
3618        let config: Gemma3Config = serde_json::from_str(config).unwrap();
3619        // Handle the Gemma 3 1b case here
3620        Arc::new(Gemma3Processor::new(
3621            processor_config.unwrap_or_default(),
3622            matches!(config, Gemma3Config::WithVision { .. }),
3623        ))
3624    }
3625    fn supports_paged_attention(&self, _config: &str) -> bool {
3626        true
3627    }
3628    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3629        true
3630    }
3631    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3632        Arc::new(Gemma3Prefixer)
3633    }
3634    fn modalities(&self, _config: &str) -> Result<Modalities> {
3635        Ok(Modalities {
3636            input: vec![SupportedModality::Text, SupportedModality::Vision],
3637            output: vec![SupportedModality::Text],
3638        })
3639    }
3640}
3641
3642impl IsqModelLoader for Gemma3Loader {
3643    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3644        Ok(vec![
3645            Regex::new(r"lm_head\.(weight|bias)$")?,
3646            // Attention
3647            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3648            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3649            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3650            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3651            // MLP
3652            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3653            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3654            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3655        ])
3656    }
3657    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3658        Ok(vec![
3659            Regex::new(r"lm_head\.(weight|bias)$")?,
3660            // Attention
3661            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3662            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3663            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3664            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3665            // MLP
3666            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3667            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3668            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3669        ])
3670    }
3671}
3672
3673impl DeviceMappedModelLoader for Gemma3Loader {
3674    fn mapped_max_act_size_elems(
3675        &self,
3676        config: &str,
3677        params: &AutoDeviceMapParams,
3678    ) -> Result<usize> {
3679        let AutoDeviceMapParams::Vision {
3680            max_seq_len,
3681            max_batch_size,
3682            max_image_shape: _,
3683            max_num_images,
3684        } = params
3685        else {
3686            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3687        };
3688
3689        let cfg: Gemma3Config = serde_json::from_str(config)?;
3690
3691        match cfg {
3692            Gemma3Config::Text(text_config) => Ok(max_batch_size
3693                * text_config.num_attention_heads
3694                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3695            Gemma3Config::WithVision {
3696                text_config,
3697                vision_config,
3698                ..
3699            } => {
3700                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3701                let img_seq_len = (num_patches + 1) * max_num_images;
3702
3703                let max_text_attn = {
3704                    // This model injects the vision information directly into the input embeddings
3705                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3706                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3707                };
3708                Ok(max_text_attn)
3709            }
3710        }
3711    }
3712
3713    fn non_mapped_max_act_size_elems(
3714        &self,
3715        config: &str,
3716        params: &AutoDeviceMapParams,
3717    ) -> Result<usize> {
3718        let AutoDeviceMapParams::Vision {
3719            max_seq_len: _,
3720            max_batch_size,
3721            max_image_shape: _,
3722            max_num_images,
3723        } = params
3724        else {
3725            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
3726        };
3727
3728        let cfg: Gemma3Config = serde_json::from_str(config)?;
3729
3730        match cfg {
3731            Gemma3Config::WithVision { vision_config, .. } => {
3732                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3733                let img_seq_len = num_patches + 1;
3734
3735                let max_vision_attn = {
3736                    (max_batch_size * max_num_images)
3737                        * vision_config.num_attention_heads
3738                        * img_seq_len
3739                        * img_seq_len
3740                };
3741
3742                Ok(max_vision_attn)
3743            }
3744            Gemma3Config::Text(_) => Ok(0),
3745        }
3746    }
3747
3748    fn non_mapped_size_in_bytes(
3749        &self,
3750        config: &str,
3751        dtype: DType,
3752        weight_pack_factor: usize,
3753        _matformer_config: Option<&MatformerSliceConfig>,
3754    ) -> Result<usize> {
3755        let cfg: Gemma3Config = serde_json::from_str(config)?;
3756
3757        let text_elems = {
3758            let cfg = match &cfg {
3759                Gemma3Config::Text(cfg) => cfg,
3760                Gemma3Config::WithVision { text_config, .. } => text_config,
3761            };
3762            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3763            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
3764            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3765                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3766            } else {
3767                0
3768            };
3769            let norm = cfg.hidden_size;
3770            embed_tokens + lm_head + norm
3771        };
3772
3773        let vision_transformer = if let Gemma3Config::WithVision {
3774            vision_config: cfg, ..
3775        } = &cfg
3776        {
3777            let post_layernorm = cfg.hidden_size;
3778
3779            let conv_config = Conv2dConfig {
3780                stride: cfg.patch_size,
3781                ..Default::default()
3782            };
3783            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3784                * cfg.patch_size
3785                * cfg.patch_size;
3786
3787            let num_patches_per_side = cfg.image_size / cfg.patch_size;
3788            let num_patches = num_patches_per_side.pow(2);
3789            let position_embedding = num_patches * cfg.hidden_size;
3790
3791            let layer_elems = {
3792                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3793                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3794
3795                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3796                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3797
3798                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3799                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3800                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3801                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3802
3803                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3804            };
3805
3806            post_layernorm
3807                + patch_embedding
3808                + position_embedding
3809                + layer_elems * cfg.num_hidden_layers
3810        } else {
3811            0
3812        };
3813
3814        let elems = text_elems + vision_transformer;
3815
3816        Ok(elems * dtype.size_in_bytes())
3817    }
3818
3819    fn layer_sizes_in_bytes(
3820        &self,
3821        config: &str,
3822        dtype: DType,
3823        weight_pack_factor: usize,
3824        _matformer_config: Option<&MatformerSliceConfig>,
3825    ) -> Result<Vec<usize>> {
3826        let cfg: Gemma3Config = serde_json::from_str(config)?;
3827
3828        let txt_cfg = match &cfg {
3829            Gemma3Config::Text(cfg) => cfg,
3830            Gemma3Config::WithVision { text_config, .. } => text_config,
3831        };
3832        let per_layer_elems = {
3833            let cfg = txt_cfg;
3834
3835            let input_layernorm = cfg.hidden_size;
3836            let post_attention_layernorm = cfg.hidden_size;
3837
3838            let size_in = cfg.hidden_size;
3839            let size_q = cfg.head_dim * cfg.num_attention_heads;
3840            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3841            let q_proj =
3842                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3843            let k_proj =
3844                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3845            let v_proj =
3846                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3847            let o_proj =
3848                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3849
3850            let h_size = cfg.hidden_size;
3851            let i_size = cfg.intermediate_size;
3852            let gate_proj = h_size * i_size / weight_pack_factor;
3853            let up_proj = h_size * i_size / weight_pack_factor;
3854            let down_proj = i_size * h_size / weight_pack_factor;
3855
3856            input_layernorm
3857                + post_attention_layernorm
3858                + q_proj
3859                + k_proj
3860                + v_proj
3861                + o_proj
3862                + gate_proj
3863                + up_proj
3864                + down_proj
3865        };
3866        Ok(vec![
3867            per_layer_elems * dtype.size_in_bytes();
3868            txt_cfg.num_hidden_layers
3869        ])
3870    }
3871
3872    fn num_layers(&self, config: &str) -> Result<usize> {
3873        let cfg: Gemma3Config = serde_json::from_str(config)?;
3874
3875        let txt_cfg = match &cfg {
3876            Gemma3Config::Text(cfg) => cfg,
3877            Gemma3Config::WithVision { text_config, .. } => text_config,
3878        };
3879
3880        Ok(txt_cfg.num_hidden_layers)
3881    }
3882
3883    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3884        let cfg: Gemma3Config = serde_json::from_str(config)?;
3885
3886        let cfg = match &cfg {
3887            Gemma3Config::Text(cfg) => cfg,
3888            Gemma3Config::WithVision { text_config, .. } => text_config,
3889        };
3890
3891        let cfg = ModelConfigMetadata {
3892            max_seq_len: cfg.max_position_embeddings,
3893            num_layers: cfg.num_hidden_layers,
3894            hidden_size: cfg.hidden_size,
3895            num_kv_heads: cfg.num_key_value_heads,
3896            num_attn_heads: cfg.num_attention_heads,
3897            sliding_window: None, // None to be more forgiving, some do not
3898            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3899            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3900            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3901        };
3902
3903        Ok(Box::new(cfg))
3904    }
3905
3906    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3907        Some(vec![NonMappedSubModel::Vision])
3908    }
3909}
3910
3911// ======================== Mistral 3 Loader
3912
3913/// [`VisionLoader`] for an Mistral 3 model.
3914///
3915/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
3916pub struct Mistral3Loader;
3917
3918pub struct Mistral3Prefixer;
3919
3920impl MultimodalPromptPrefixer for Mistral3Prefixer {
3921    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3922        prompt.to_string()
3923    }
3924}
3925
3926impl VisionModelLoader for Mistral3Loader {
3927    fn load(
3928        &self,
3929        config: &str,
3930        vb: ShardedVarBuilder,
3931        normal_loading_metadata: NormalLoadingMetadata,
3932        attention_mechanism: AttentionImplementation,
3933    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3934        let mut cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3935        cfg.propagate_quantization_config();
3936        Ok(Box::new(Mistral3Model::new(
3937            &cfg,
3938            vb,
3939            self.is_gptx(config),
3940            normal_loading_metadata,
3941            attention_mechanism,
3942        )?))
3943    }
3944    fn is_gptx(&self, _config: &str) -> bool {
3945        true
3946    }
3947    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3948        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
3949        Ok(Box::new(cfg))
3950    }
3951    fn get_processor(
3952        &self,
3953        _model_config: &str,
3954        processor_config: Option<ProcessorConfig>,
3955        _preprocessor_config: PreProcessorConfig,
3956        _max_edge: Option<u32>,
3957    ) -> Arc<dyn Processor + Send + Sync> {
3958        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
3959    }
3960    fn supports_paged_attention(&self, _config: &str) -> bool {
3961        true
3962    }
3963    fn supports_prefix_cacher(&self, _config: &str) -> bool {
3964        true
3965    }
3966    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3967        Arc::new(Mistral3Prefixer)
3968    }
3969    fn modalities(&self, _config: &str) -> Result<Modalities> {
3970        Ok(Modalities {
3971            input: vec![SupportedModality::Text, SupportedModality::Vision],
3972            output: vec![SupportedModality::Text],
3973        })
3974    }
3975}
3976
3977impl IsqModelLoader for Mistral3Loader {
3978    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3979        Ok(vec![
3980            Regex::new(r"lm_head\.(weight|bias)$")?,
3981            // Attention
3982            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3983            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3984            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3985            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3986            // MLP
3987            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3988            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3989            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3990        ])
3991    }
3992    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3993        Ok(vec![
3994            Regex::new(r"lm_head\.(weight|bias)$")?,
3995            // Attention
3996            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3997            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3998            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3999            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4000            // MLP
4001            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4002            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4003            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4004        ])
4005    }
4006}
4007
4008#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4009impl DeviceMappedModelLoader for Mistral3Loader {
4010    fn mapped_max_act_size_elems(
4011        &self,
4012        config: &str,
4013        params: &AutoDeviceMapParams,
4014    ) -> Result<usize> {
4015        let cfg: Mistral3Config = serde_json::from_str(config)?;
4016        let vcfg = &cfg.vision_config;
4017        let tcfg = &cfg.text_config;
4018
4019        let AutoDeviceMapParams::Vision {
4020            max_seq_len,
4021            max_batch_size,
4022            max_image_shape: (mut height, mut width),
4023            max_num_images,
4024        } = params
4025        else {
4026            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4027        };
4028
4029        let img_seq_len = {
4030            // Reshaping algorithm
4031
4032            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4033            let (max_height, max_width) = (1540, 1540);
4034            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4035            if ratio > 1. {
4036                height = (height as f64 / ratio).floor() as usize;
4037                width = (width as f64 / ratio).floor() as usize;
4038            }
4039
4040            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4041            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4042
4043            height = num_height_tokens * vcfg.patch_size;
4044            width = num_width_tokens * vcfg.patch_size;
4045
4046            let num_height_tokens = height / vcfg.patch_size;
4047            let num_width_tokens = width / vcfg.patch_size;
4048
4049            (num_width_tokens + 1) * num_height_tokens
4050        };
4051
4052        // This model injects the vision information directly into the input embeddings
4053        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4054        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4055    }
4056
4057    fn non_mapped_max_act_size_elems(
4058        &self,
4059        config: &str,
4060        params: &AutoDeviceMapParams,
4061    ) -> Result<usize> {
4062        let cfg: Mistral3Config = serde_json::from_str(config)?;
4063        let cfg = &cfg.vision_config;
4064
4065        let AutoDeviceMapParams::Vision {
4066            max_seq_len: _,
4067            max_batch_size,
4068            max_image_shape: (mut height, mut width),
4069            max_num_images,
4070        } = params
4071        else {
4072            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4073        };
4074
4075        let img_seq_len = {
4076            // Reshaping algorithm
4077
4078            // https://huggingface.co/mistralai/Mistral-Small-3.1-24B-Instruct-2503/blob/main/preprocessor_config.json#L29
4079            let (max_height, max_width) = (1540, 1540);
4080            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4081            if ratio > 1. {
4082                height = (height as f64 / ratio).floor() as usize;
4083                width = (width as f64 / ratio).floor() as usize;
4084            }
4085
4086            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4087            let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4088
4089            height = num_height_tokens * cfg.patch_size;
4090            width = num_width_tokens * cfg.patch_size;
4091
4092            let num_height_tokens = height / cfg.patch_size;
4093            let num_width_tokens = width / cfg.patch_size;
4094
4095            (num_width_tokens + 1) * num_height_tokens
4096        };
4097
4098        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4099    }
4100
4101    fn non_mapped_size_in_bytes(
4102        &self,
4103        config: &str,
4104        dtype: DType,
4105        weight_pack_factor: usize,
4106        _matformer_config: Option<&MatformerSliceConfig>,
4107    ) -> Result<usize> {
4108        let cfg: Mistral3Config = serde_json::from_str(config)?;
4109
4110        let text_elems = {
4111            let cfg = &cfg.text_config;
4112
4113            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4114            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
4115            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4116                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4117            } else {
4118                0
4119            };
4120            let norm = cfg.hidden_size;
4121            embed_tokens + lm_head + norm
4122        };
4123
4124        let vision_elems = {
4125            let cfg = &cfg.vision_config;
4126
4127            let patch_embed = {
4128                let conv_cfg = Conv2dConfig {
4129                    stride: cfg.patch_size,
4130                    ..Default::default()
4131                };
4132                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4133                    * cfg.patch_size
4134                    * cfg.patch_size
4135                    * cfg.patch_size
4136            };
4137            let ln_pre = cfg.hidden_size;
4138            let vision_layer = {
4139                let attn_norm = cfg.hidden_size;
4140                let ffn_norm = cfg.hidden_size;
4141
4142                let gate = cfg.hidden_size * cfg.intermediate_size;
4143                let up = cfg.hidden_size * cfg.intermediate_size;
4144                let down = cfg.hidden_size * cfg.intermediate_size;
4145
4146                let q = cfg.hidden_size * cfg.hidden_size;
4147                let k = cfg.hidden_size * cfg.hidden_size;
4148                let v = cfg.hidden_size * cfg.hidden_size;
4149                let o = cfg.hidden_size * cfg.hidden_size;
4150
4151                attn_norm + ffn_norm + gate + up + down + q + k + v + o
4152            };
4153
4154            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4155        };
4156
4157        let elems = text_elems + vision_elems;
4158
4159        Ok(elems * dtype.size_in_bytes())
4160    }
4161
4162    fn layer_sizes_in_bytes(
4163        &self,
4164        config: &str,
4165        dtype: DType,
4166        weight_pack_factor: usize,
4167        _matformer_config: Option<&MatformerSliceConfig>,
4168    ) -> Result<Vec<usize>> {
4169        let cfg: Mistral3Config = serde_json::from_str(config)?;
4170        let cfg = &cfg.text_config;
4171
4172        let per_layer_elems = {
4173            let input_layernorm = cfg.hidden_size;
4174            let post_attention_layernorm = cfg.hidden_size;
4175
4176            let size_in = cfg.hidden_size;
4177            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4178            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4179            let q_proj = size_in * size_q / weight_pack_factor;
4180            let k_proj = size_in * size_kv / weight_pack_factor;
4181            let v_proj = size_in * size_kv / weight_pack_factor;
4182            let o_proj = size_q * size_in / weight_pack_factor;
4183
4184            let h_size = cfg.hidden_size;
4185            let i_size = cfg.intermediate_size;
4186            let gate_proj = h_size * i_size / weight_pack_factor;
4187            let up_proj = h_size * i_size / weight_pack_factor;
4188            let down_proj = i_size * h_size / weight_pack_factor;
4189
4190            input_layernorm
4191                + post_attention_layernorm
4192                + q_proj
4193                + k_proj
4194                + v_proj
4195                + o_proj
4196                + gate_proj
4197                + up_proj
4198                + down_proj
4199        };
4200        Ok(vec![
4201            per_layer_elems * dtype.size_in_bytes();
4202            cfg.num_hidden_layers
4203        ])
4204    }
4205
4206    fn num_layers(&self, config: &str) -> Result<usize> {
4207        let cfg: Mistral3Config = serde_json::from_str(config)?;
4208        let cfg = &cfg.text_config;
4209        Ok(cfg.num_hidden_layers)
4210    }
4211
4212    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4213        let cfg: Mistral3Config = serde_json::from_str(config)?;
4214        let cfg = &cfg.text_config;
4215
4216        let cfg = ModelConfigMetadata {
4217            max_seq_len: cfg.max_position_embeddings,
4218            num_layers: cfg.num_hidden_layers,
4219            hidden_size: cfg.hidden_size,
4220            num_kv_heads: cfg.num_key_value_heads,
4221            num_attn_heads: cfg.num_attention_heads,
4222            sliding_window: cfg.sliding_window,
4223            k_head_dim: cfg.head_dim(),
4224            v_head_dim: cfg.head_dim(),
4225            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4226        };
4227
4228        Ok(Box::new(cfg))
4229    }
4230
4231    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4232        Some(vec![NonMappedSubModel::Vision])
4233    }
4234}
4235
4236// ======================== Llama 4 Loader
4237
/// [`VisionLoader`] for a Llama 4 vision model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct VLlama4Loader;
4242
4243pub struct VLlama4Prefixer;
4244
4245impl MultimodalPromptPrefixer for VLlama4Prefixer {
4246    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4247        format!(
4248            "{}{prompt}",
4249            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4250        )
4251    }
4252}
4253
4254impl VisionModelLoader for VLlama4Loader {
4255    fn load(
4256        &self,
4257        config: &str,
4258        vb: ShardedVarBuilder,
4259        normal_loading_metadata: NormalLoadingMetadata,
4260        attention_mechanism: AttentionImplementation,
4261    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4262        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4263        cfg.propagate_quantization_config();
4264        Ok(Box::new(Llama4Model::new(
4265            &cfg,
4266            vb,
4267            self.is_gptx(config),
4268            normal_loading_metadata,
4269            attention_mechanism,
4270        )?))
4271    }
4272    fn is_gptx(&self, _config: &str) -> bool {
4273        false
4274    }
4275    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4276        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4277        cfg.propagate_quantization_config();
4278        Ok(Box::new(cfg))
4279    }
4280    fn get_processor(
4281        &self,
4282        _model_config: &str,
4283        processor_config: Option<ProcessorConfig>,
4284        _preprocessor_config: PreProcessorConfig,
4285        _max_edge: Option<u32>,
4286    ) -> Arc<dyn Processor + Send + Sync> {
4287        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4288    }
4289    fn supports_paged_attention(&self, _config: &str) -> bool {
4290        true
4291    }
4292    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4293        Arc::new(VLlama4Prefixer)
4294    }
4295    fn modalities(&self, _config: &str) -> Result<Modalities> {
4296        Ok(Modalities {
4297            input: vec![SupportedModality::Text, SupportedModality::Vision],
4298            output: vec![SupportedModality::Text],
4299        })
4300    }
4301}
4302
4303impl IsqModelLoader for VLlama4Loader {
4304    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4305        Ok(vec![
4306            Regex::new(r"lm_head\.(weight|bias)$")?,
4307            // Attention
4308            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4309            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4310            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4311            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4312            // FF MoE
4313            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4314            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4315            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4316            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4317            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4318            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4319            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4320            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4321            // FF MLP
4322            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4323            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4324            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4325        ])
4326    }
4327    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4328        Ok(vec![
4329            Regex::new(r"lm_head\.(weight|bias)$")?,
4330            // Attention
4331            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4332            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4333            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4334            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4335            // FF MoE
4336            Regex::new(
4337                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4338            )?,
4339            Regex::new(
4340                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4341            )?,
4342            Regex::new(
4343                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4344            )?,
4345            Regex::new(
4346                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4347            )?,
4348            Regex::new(
4349                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4350            )?,
4351            Regex::new(
4352                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4353            )?,
4354            Regex::new(
4355                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4356            )?,
4357            Regex::new(
4358                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4359            )?,
4360            // FF MLP
4361            Regex::new(
4362                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4363            )?,
4364            Regex::new(
4365                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4366            )?,
4367            Regex::new(
4368                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4369            )?,
4370        ])
4371    }
4372}
4373
4374impl VLlama4Loader {
4375    /// This incorporates the max batch size!
4376    /// Returns (pixels max batch size, num text image tokens)
4377    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4378    fn run_dummy_processing(
4379        &self,
4380        cfg: &Llama4Config,
4381        height: usize,
4382        width: usize,
4383        max_num_images: usize,
4384        max_batch_size: usize,
4385    ) -> Result<(usize, usize)> {
4386        let cfg = &cfg.vision_config;
4387
4388        let img_processor =
4389            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4390        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4391        let res = img_processor.preprocess(
4392            vec![image; max_num_images],
4393            vec![],
4394            &PreProcessorConfig::default(),
4395            &Device::Cpu,
4396            (max_batch_size, max_num_images),
4397        )?;
4398
4399        let pixels_batch_size = res.pixel_values.dim(0)?;
4400        let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4401
4402        let (image_h, image_w) = (
4403            res.pixel_values.dim(D::Minus2).unwrap(),
4404            res.pixel_values.dim(D::Minus1).unwrap(),
4405        );
4406        let num_patches_per_chunk = (image_h / img_processor.patch_size)
4407            * (image_w / img_processor.patch_size)
4408            / img_processor.downsample_ratio;
4409
4410        Ok((
4411            pixels_max_batch_size,
4412            num_patches_per_chunk * pixels_max_batch_size,
4413        ))
4414    }
4415}
4416
4417impl DeviceMappedModelLoader for VLlama4Loader {
4418    fn mapped_max_act_size_elems(
4419        &self,
4420        config: &str,
4421        params: &AutoDeviceMapParams,
4422    ) -> Result<usize> {
4423        let AutoDeviceMapParams::Vision {
4424            max_seq_len,
4425            max_batch_size,
4426            max_image_shape: (height, width),
4427            max_num_images,
4428        } = params
4429        else {
4430            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4431        };
4432
4433        let cfg: Llama4Config = serde_json::from_str(config)?;
4434
4435        let (_pixels_batch_size, num_text_image_toks) =
4436            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4437
4438        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4439
4440        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4441    }
4442    fn non_mapped_max_act_size_elems(
4443        &self,
4444        config: &str,
4445        params: &AutoDeviceMapParams,
4446    ) -> Result<usize> {
4447        let AutoDeviceMapParams::Vision {
4448            max_seq_len: _,
4449            max_batch_size,
4450            max_image_shape: (height, width),
4451            max_num_images,
4452        } = params
4453        else {
4454            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4455        };
4456
4457        let cfg: Llama4Config = serde_json::from_str(config)?;
4458
4459        let (pixels_batch_size, _num_text_image_toks) =
4460            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4461        let max_seq_len = cfg.vision_config.num_patches();
4462
4463        Ok((max_batch_size * pixels_batch_size)
4464            * cfg.vision_config.num_attention_heads
4465            * max_seq_len
4466            * max_seq_len)
4467    }
4468
4469    fn non_mapped_size_in_bytes(
4470        &self,
4471        config: &str,
4472        dtype: DType,
4473        weight_pack_factor: usize,
4474        _matformer_config: Option<&MatformerSliceConfig>,
4475    ) -> Result<usize> {
4476        let cfg: Llama4Config = serde_json::from_str(config)?;
4477        let tcfg = &cfg.text_config;
4478
4479        let text_elems = {
4480            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4481            let lm_head = if !tcfg.tie_word_embeddings {
4482                tcfg.hidden_size * tcfg.vocab_size
4483            } else {
4484                0
4485            };
4486            let norm = tcfg.hidden_size;
4487            embed_tokens + lm_head + norm
4488        };
4489
4490        let vision_elems = {
4491            let cfg = &cfg.vision_config;
4492
4493            let num_patches = cfg.num_patches();
4494
4495            let unfold_elems =
4496                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4497            let class_embeddng_elems = cfg.hidden_size;
4498            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4499            let layernorm_pre_elems = cfg.hidden_size;
4500            let layernorm_post_elems = cfg.hidden_size;
4501
4502            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4503                / weight_pack_factor
4504                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4505
4506            let encoder_layer = {
4507                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4508                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4509
4510                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4511                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4512                    / weight_pack_factor
4513                    + cfg.num_attention_heads * head_dim;
4514                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4515                    / weight_pack_factor
4516                    + cfg.num_attention_heads * head_dim;
4517                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4518                    / weight_pack_factor
4519                    + cfg.num_attention_heads * head_dim;
4520                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4521                    / weight_pack_factor
4522                    + cfg.num_attention_heads * head_dim;
4523
4524                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4525                    + cfg.intermediate_size;
4526                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4527                    + cfg.hidden_size;
4528
4529                input_layernorm
4530                    + post_attention_layernorm
4531                    + q_proj
4532                    + k_proj
4533                    + v_proj
4534                    + o_proj
4535                    + fc1
4536                    + fc2
4537            };
4538
4539            unfold_elems
4540                + class_embeddng_elems
4541                + positional_embedding_vlm_elems
4542                + layernorm_post_elems
4543                + layernorm_pre_elems
4544                + pixel_shuffle_elems
4545                + encoder_layer * cfg.num_hidden_layers
4546        };
4547
4548        let elems = text_elems + vision_elems;
4549
4550        Ok(elems * dtype.size_in_bytes())
4551    }
4552
4553    fn layer_sizes_in_bytes(
4554        &self,
4555        config: &str,
4556        dtype: DType,
4557        weight_pack_factor: usize,
4558        _matformer_config: Option<&MatformerSliceConfig>,
4559    ) -> Result<Vec<usize>> {
4560        let cfg: Llama4Config = serde_json::from_str(config)?;
4561        let tcfg = &cfg.text_config;
4562
4563        let mut per_layer_elems = Vec::new();
4564
4565        for layer_idx in 0..tcfg.num_hidden_layers {
4566            let input_layernorm = tcfg.hidden_size;
4567            let post_attention_layernorm = tcfg.hidden_size;
4568
4569            let size_in = tcfg.hidden_size;
4570            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4571            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4572            let q_proj = size_in * size_q / weight_pack_factor;
4573            let k_proj = size_in * size_kv / weight_pack_factor;
4574            let v_proj = size_in * size_kv / weight_pack_factor;
4575            let o_proj = size_q * size_in / weight_pack_factor;
4576
4577            let use_moe = tcfg.moe_layers().contains(&layer_idx);
4578            let moe_block = if use_moe {
4579                let h_size = tcfg.hidden_size;
4580                let i_size = tcfg.intermediate_size;
4581                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4582                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4583                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4584
4585                gate_proj + up_proj + down_proj
4586            } else {
4587                let h_size = tcfg.hidden_size;
4588                let i_size = tcfg.intermediate_size_mlp;
4589                let gate_proj = h_size * i_size / weight_pack_factor;
4590                let up_proj = h_size * i_size / weight_pack_factor;
4591                let down_proj = i_size * h_size / weight_pack_factor;
4592
4593                gate_proj + up_proj + down_proj
4594            };
4595
4596            per_layer_elems.push(
4597                input_layernorm
4598                    + post_attention_layernorm
4599                    + q_proj
4600                    + k_proj
4601                    + v_proj
4602                    + o_proj
4603                    + moe_block,
4604            );
4605        }
4606
4607        Ok(per_layer_elems
4608            .into_iter()
4609            .map(|x| x * dtype.size_in_bytes())
4610            .collect())
4611    }
4612
4613    fn num_layers(&self, config: &str) -> Result<usize> {
4614        let cfg: Llama4Config = serde_json::from_str(config)?;
4615        Ok(cfg.text_config.num_hidden_layers)
4616    }
4617
4618    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4619        let cfg: Llama4Config = serde_json::from_str(config)?;
4620        let cfg = &cfg.text_config;
4621
4622        let cfg = ModelConfigMetadata {
4623            max_seq_len: cfg.max_position_embeddings,
4624            num_layers: cfg.num_hidden_layers,
4625            hidden_size: cfg.hidden_size,
4626            num_kv_heads: cfg.num_attention_heads,
4627            num_attn_heads: cfg.num_attention_heads,
4628            sliding_window: None,
4629            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4630            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4631            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4632        };
4633
4634        Ok(Box::new(cfg))
4635    }
4636
4637    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4638        Some(vec![NonMappedSubModel::Vision])
4639    }
4640}
4641
4642// ======================== Gemma 3n Loader
4643
/// [`VisionLoader`] for a Gemma 3n model.
///
/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
pub struct Gemma3nLoader;
4648
4649#[allow(dead_code)]
4650pub struct Gemma3nPrefixer;
4651
4652impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4653    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4654        prompt.to_string()
4655    }
4656}
4657
4658impl VisionModelLoader for Gemma3nLoader {
4659    fn load(
4660        &self,
4661        config: &str,
4662        vb: ShardedVarBuilder,
4663        normal_loading_metadata: NormalLoadingMetadata,
4664        attention_mechanism: AttentionImplementation,
4665    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
4666        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4667        Ok(Box::new(Gemma3nModel::new(
4668            &cfg,
4669            vb,
4670            self.is_gptx(config),
4671            normal_loading_metadata,
4672            attention_mechanism,
4673        )?))
4674    }
4675    fn is_gptx(&self, _config: &str) -> bool {
4676        true
4677    }
4678    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4679        let config: Gemma3nConfig = serde_json::from_str(config)?;
4680        Ok(Box::new(config))
4681    }
4682    fn get_processor(
4683        &self,
4684        _config: &str,
4685        processor_config: Option<ProcessorConfig>,
4686        _preprocessor_config: PreProcessorConfig,
4687        _max_edge: Option<u32>,
4688    ) -> Arc<dyn Processor + Send + Sync> {
4689        // Handle the Gemma 3 1b case here
4690        Arc::new(Gemma3nProcessor::new(
4691            processor_config.unwrap_or_default(),
4692            true,
4693        ))
4694    }
4695    fn supports_paged_attention(&self, _config: &str) -> bool {
4696        false
4697    }
4698    fn supports_prefix_cacher(&self, _config: &str) -> bool {
4699        true
4700    }
4701    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4702        Arc::new(Gemma3Prefixer)
4703    }
4704    fn modalities(&self, _config: &str) -> Result<Modalities> {
4705        Ok(Modalities {
4706            input: vec![
4707                SupportedModality::Text,
4708                SupportedModality::Vision,
4709                SupportedModality::Audio,
4710            ],
4711            output: vec![SupportedModality::Text],
4712        })
4713    }
4714}
4715
4716impl IsqModelLoader for Gemma3nLoader {
4717    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4718        Ok(vec![
4719            Regex::new(r"lm_head\.(weight|bias)$")?,
4720            // Language model attention
4721            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4722            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4723            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4724            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4725            // Language model MLP
4726            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4727            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4728            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4729            // Audio conformer attention layers
4730            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4731            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4732            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4733            Regex::new(
4734                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4735            )?,
4736            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4737            // Audio conformer FFW layers
4738            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4739            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4740            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4741            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4742            // Audio conformer conv1d layers
4743            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4744            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4745            // Audio subsample projection
4746            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4747            // Multimodal embedders
4748            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4749            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4750        ])
4751    }
4752    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4753        Ok(vec![
4754            Regex::new(r"lm_head\.(weight|bias)$")?,
4755            // Language model attention
4756            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4757            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4758            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4759            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4760            // Language model MLP
4761            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4762            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4763            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4764            // Projections
4765            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4766            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4767            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4768            // Audio conformer attention layers
4769            Regex::new(
4770                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4771            )?,
4772            Regex::new(
4773                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4774            )?,
4775            Regex::new(
4776                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4777            )?,
4778            Regex::new(
4779                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4780            )?,
4781            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4782            // Audio conformer FFW layers
4783            Regex::new(
4784                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4785            )?,
4786            Regex::new(
4787                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4788            )?,
4789            Regex::new(
4790                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4791            )?,
4792            Regex::new(
4793                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4794            )?,
4795            // Audio conformer conv1d layers
4796            Regex::new(
4797                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4798            )?,
4799            Regex::new(
4800                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4801            )?,
4802            // Audio subsample projection
4803            Regex::new(
4804                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4805            )?,
4806            // Multimodal embedders
4807            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4808            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4809        ])
4810    }
4811}
4812
4813impl DeviceMappedModelLoader for Gemma3nLoader {
4814    fn mapped_max_act_size_elems(
4815        &self,
4816        config: &str,
4817        params: &AutoDeviceMapParams,
4818    ) -> Result<usize> {
4819        let AutoDeviceMapParams::Vision {
4820            max_seq_len,
4821            max_batch_size,
4822            max_image_shape: _,
4823            max_num_images,
4824        } = params
4825        else {
4826            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4827        };
4828
4829        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4830        let text_cfg = &cfg.text_config;
4831
4832        // Gemma3n is an "inject into the prompt" model, similar to Gemma3
4833        // We need to account for vision and audio tokens in the sequence length
4834
4835        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4836
4837        // Add vision tokens
4838        {
4839            // Vision tokens are injected into the prompt
4840            // MSFA outputs fixed 16x16 features regardless of input size
4841            let msfa_spatial_size = 16; // Fixed from vision.rs line 1115
4842            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; // 256 tokens
4843            total_seq_len += vision_tokens_per_image * max_num_images;
4844        }
4845
4846        // Add audio tokens
4847        {
4848            // Audio tokens are injected into the prompt
4849            // From config field audio_soft_tokens_per_image (typically 188)
4850            let audio_tokens = cfg.audio_soft_tokens_per_image;
4851            total_seq_len += audio_tokens;
4852        }
4853
4854        // Calculate max attention size for text model with all injected tokens
4855        let max_text_attn =
4856            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4857
4858        Ok(max_text_attn)
4859    }
4860
4861    fn non_mapped_max_act_size_elems(
4862        &self,
4863        config: &str,
4864        params: &AutoDeviceMapParams,
4865    ) -> Result<usize> {
4866        let AutoDeviceMapParams::Vision {
4867            max_seq_len: _,
4868            max_batch_size,
4869            max_image_shape: _,
4870            max_num_images,
4871        } = params
4872        else {
4873            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
4874        };
4875
4876        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4877
4878        // Calculate max activation sizes for each modality
4879        let mut max_activation = 0;
4880
4881        // Vision activation size
4882        {
4883            // Vision is Gemma3n's MobileNetV5 architecture with Multi-Query Attention
4884            // The peak activation is in the Multi-Query Attention layers
4885
4886            // From the architecture: stages 3 and 4 have MMQA blocks
4887            // Input images are 768x768 (from inputs_processor.rs)
4888            // Stage 3: 640 channels at 48x48 (768/16 downsampling), MMQA with num_heads=12, kv_dim=64
4889            // Stage 4: 1280 channels at 24x24 (768/32 downsampling), MMQA with num_heads=16, kv_dim=96
4890            // MSFA output: 2048 channels at fixed 16x16
4891
4892            let vision_tower_act = {
4893                // Peak is during MMQA attention computation in stage 4
4894                // Stage 4 has higher memory usage than Stage 3 due to more heads (16 vs 12)
4895                // From vision.rs: Stage 4 has num_heads=16, kv_dim=96, kv_stride=1
4896                let num_heads = 16; // Stage 4 configuration
4897                let spatial_size = 24; // 768 / 32 = 24 (input 768x768, stage 4 has 32x downsampling)
4898                let seq_len = spatial_size * spatial_size;
4899
4900                // Attention scores: [B * num_images, num_heads, seq_len, seq_len]
4901                max_batch_size * max_num_images * num_heads * seq_len * seq_len
4902            };
4903
4904            // Vision embedder activations
4905            let vision_embed_act = {
4906                // MSFA output: 2048 channels at fixed 16x16 spatial (from vision.rs line 1115)
4907                let msfa_channels = 2048; // MSFA_OUT_CHANNELS from vision.rs
4908                let spatial_size = 16; // Fixed output resolution from MSFA
4909                let vision_features =
4910                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4911
4912                // After embedding projection to text hidden size
4913                let projected = max_batch_size
4914                    * max_num_images
4915                    * spatial_size
4916                    * spatial_size
4917                    * cfg.text_config.hidden_size;
4918
4919                vision_features.max(projected)
4920            };
4921
4922            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
4923        }
4924
4925        // Audio activation size
4926        {
4927            let audio_cfg = &cfg.audio_config;
4928
4929            // Calculate max audio sequence length based on config
4930            // Audio uses conformer with subsampling and reduction
4931
4932            // A rough estimate of max_audio_frames
4933            let max_audio_frames = 1280;
4934
4935            let subsample_factor: usize = audio_cfg
4936                .sscp_conv_stride_size
4937                .iter()
4938                .map(|stride| stride[0]) // Time dimension stride
4939                .product();
4940            let audio_seq_after_subsample = max_audio_frames / subsample_factor;
4941
4942            // Audio encoder activations
4943            let audio_encoder_act = {
4944                // Conformer FFW layers have expansion factor from config
4945                let intermediate_size = audio_cfg.hidden_size * 4; // FFW expansion factor
4946
4947                // Peak is in the FFW layers before reduction
4948                max_batch_size * audio_seq_after_subsample * intermediate_size
4949            };
4950
4951            // Audio attention activations
4952            let audio_attn_act = {
4953                // Attention uses chunked processing with specific context sizes
4954                let chunk_size = audio_cfg.conf_attention_chunk_size;
4955                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
4956                    + audio_cfg.conf_attention_context_right;
4957
4958                // Peak is attention scores: [B, num_heads, num_chunks, chunk_size, context_size]
4959                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
4960
4961                max_batch_size
4962                    * audio_cfg.conf_num_attention_heads
4963                    * num_chunks
4964                    * chunk_size
4965                    * context_size
4966            };
4967
4968            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
4969        }
4970
4971        Ok(max_activation)
4972    }
4973
4974    fn non_mapped_size_in_bytes(
4975        &self,
4976        config: &str,
4977        dtype: DType,
4978        weight_pack_factor: usize,
4979        matformer_config: Option<&MatformerSliceConfig>,
4980    ) -> Result<usize> {
4981        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4982
4983        // Apply matformer slicing if configured
4984        let text_cfg = if let Some(matformer_cfg) = matformer_config {
4985            use crate::device_map::DummyDeviceMapper;
4986            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
4987
4988            let dummy_mapper = DummyDeviceMapper {
4989                nm_device: Device::Cpu,
4990            };
4991            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
4992                &cfg.text_config,
4993                &Some(matformer_cfg.clone()),
4994                &dummy_mapper,
4995            )?;
4996            adjusted_cfg
4997        } else {
4998            cfg.text_config.clone()
4999        };
5000
5001        let text_cfg = &text_cfg;
5002
5003        // Text components that are not device-mapped
5004        let text_elems = {
5005            // Embeddings
5006            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
5007            let embed_tokens_per_layer = text_cfg.num_hidden_layers
5008                * text_cfg.hidden_size_per_layer_input
5009                * text_cfg.vocab_size_per_layer_input;
5010
5011            // LM head (if not tied)
5012            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
5013                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
5014            } else {
5015                0
5016            };
5017
5018            // Final layer norm
5019            let norm = text_cfg.hidden_size;
5020
5021            // AltUp projections (not device-mapped)
5022            let altup_projections =
5023                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5024                    / weight_pack_factor;
5025            let altup_unembed_projections =
5026                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5027                    / weight_pack_factor;
5028
5029            // Per-layer model projection
5030            let per_layer_model_projection = text_cfg.num_hidden_layers
5031                * text_cfg.hidden_size
5032                * text_cfg.hidden_size_per_layer_input
5033                / weight_pack_factor;
5034            let per_layer_projection_norm = text_cfg.hidden_size;
5035
5036            embed_tokens
5037                + embed_tokens_per_layer
5038                + lm_head
5039                + norm
5040                + altup_projections
5041                + altup_unembed_projections
5042                + per_layer_model_projection
5043                + per_layer_projection_norm
5044        };
5045
5046        // Vision components
5047        let vision_elems = {
5048            let vision_cfg = &cfg.vision_config;
5049            // Vision tower - calculated from actual Gemma3n architecture
5050            // NOTE: Vision tower uses only Conv2d layers, NOT Arc<dyn QuantMethod>,
5051            // so NONE of these should be divided by weight_pack_factor
5052            let vision_tower_elems = {
5053                use crate::vision_models::gemma3n::vision::{
5054                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5055                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5056                    STEM_OUT_CHANNELS,
5057                };
5058
5059                // Stem: ConvNormAct (Conv2d + RMSNorm)
5060                let stem_conv =
5061                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5062                let stem_norm = STEM_OUT_CHANNELS; // RMSNorm weight
5063
5064                // Track input channels through the network
5065                let mut in_chs = STEM_OUT_CHANNELS;
5066                let mut total_elems = stem_conv + stem_norm;
5067
5068                // Process all stages from gemma3n_mobilenet_def
5069                let block_defs = gemma3n_mobilenet_def();
5070
5071                for stage_blocks in block_defs.iter() {
5072                    for block_type in stage_blocks.iter() {
5073                        match block_type {
5074                            BlockType::EdgeResidual {
5075                                out_channels,
5076                                kernel_size,
5077                                stride: _,
5078                                expand_ratio,
5079                                ..
5080                            } => {
5081                                #[allow(clippy::cast_precision_loss)]
5082                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5083                                // EdgeResidual: all Conv2d layers, not quantizable
5084                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // conv_exp (Conv2d)
5085                                total_elems += mid_chs; // bn1 weight
5086                                total_elems += mid_chs * out_channels; // conv_pwl (Conv2d)
5087                                total_elems += out_channels; // bn2 weight
5088                                in_chs = *out_channels;
5089                            }
5090                            BlockType::UniversalInvertedResidual {
5091                                out_channels,
5092                                start_kernel_size,
5093                                mid_kernel_size,
5094                                stride: _,
5095                                expand_ratio,
5096                                ..
5097                            } => {
5098                                #[allow(clippy::cast_precision_loss)]
5099                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5100                                // UniversalInvertedResidual: all Conv2d layers, not quantizable
5101                                if *expand_ratio != 1.0 {
5102                                    total_elems += in_chs * mid_chs; // expand conv (Conv2d)
5103                                    total_elems += mid_chs; // expand norm
5104                                }
5105                                if *start_kernel_size > 0 {
5106                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise start (Conv2d)
5107                                    total_elems += mid_chs; // norm
5108                                }
5109                                if *mid_kernel_size > 0 {
5110                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise mid (Conv2d)
5111                                    total_elems += mid_chs; // norm
5112                                }
5113                                total_elems += mid_chs * out_channels; // project conv (Conv2d)
5114                                total_elems += out_channels; // project norm
5115                                total_elems += out_channels; // layer scale gamma
5116                                in_chs = *out_channels;
5117                            }
5118                            BlockType::MultiQueryAttention {
5119                                num_heads,
5120                                kv_dim,
5121                                kv_stride: _,
5122                                ..
5123                            } => {
5124                                // MMQA: all Conv2d layers, not quantizable
5125                                let dw_kernel_size = 3; // Default dw_kernel_size for MMQA
5126                                total_elems += in_chs; // norm weight
5127                                total_elems += in_chs * num_heads * kv_dim; // query_proj (Conv2d)
5128                                total_elems += in_chs * kv_dim; // key_proj (Conv2d)
5129                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // key_dw_conv (Conv2d)
5130                                total_elems += *kv_dim; // value_down_conv (Conv2d)
5131                                total_elems += 1; // value_norm weight
5132                                total_elems += *kv_dim; // value_proj (Conv2d)
5133                                total_elems += num_heads * kv_dim * in_chs; // output_proj (Conv2d)
5134                                total_elems += in_chs; // layer scale
5135                            }
5136                        }
5137                    }
5138                }
5139
5140                // Multi-scale fusion adapter (msfa) - also uses Conv2d layers
5141                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5142                let msfa_out = MSFA_OUT_CHANNELS;
5143                #[allow(clippy::cast_precision_loss)]
5144                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5145
5146                // MSFA FFN (UIR with expansion_ratio) - Conv2d layers, not quantizable
5147                total_elems += msfa_in * msfa_mid; // expand (Conv2d)
5148                total_elems += msfa_mid; // expand norm
5149                total_elems += msfa_mid * msfa_out; // project (Conv2d)
5150                total_elems += msfa_out; // project norm
5151                total_elems += msfa_out; // final norm
5152
5153                total_elems
5154            };
5155
5156            // Vision multimodal embedder components
5157            let embed_vision_elems = {
5158                // Embedding layer (not quantizable)
5159                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;
5160
5161                // Normalization layers (not quantizable)
5162                let hard_norm = vision_cfg.hidden_size;
5163                let soft_norm = vision_cfg.hidden_size;
5164
5165                // Projection from vision to text hidden size (IS Arc<dyn QuantMethod>, so quantizable)
5166                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5167
5168                // Post-projection norm (not quantizable)
5169                let post_norm = text_cfg.hidden_size;
5170
5171                embedding + hard_norm + soft_norm + projection + post_norm
5172            };
5173
5174            vision_tower_elems + embed_vision_elems
5175        };
5176
5177        // Audio components - based on actual audio.rs structure
5178        let audio_elems = {
5179            let audio_cfg = &cfg.audio_config;
5180
5181            // SubSampleConvProjection components
5182            let subsample_conv_projection_elems = {
5183                // Conv blocks (Conv2d layers - NOT quantizable)
5184                let mut conv_elems = 0;
5185
5186                // conv_0: Conv2d from 1 channel to first channel size
5187                let in_ch_0 = 1;
5188                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5189                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5190                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5191
5192                // conv_1: Conv2d from first to second channel size
5193                let in_ch_1 = out_ch_0;
5194                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5195                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5196                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5197
5198                // CumulativeGroupNorm for each conv block (weight only, no bias by default)
5199                let norm_0 = out_ch_0; // norm weight for conv_0
5200                let norm_1 = out_ch_1; // norm weight for conv_1
5201
5202                // input_proj_linear (Arc<dyn QuantMethod> - IS quantizable)
5203                let mut f_out = audio_cfg.input_feat_size;
5204                for i in 0..2 {
5205                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5206                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5207                    let pad_left = 1;
5208                    let pad_right = 1;
5209                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5210                }
5211                let input_proj_in_features = out_ch_1 * f_out;
5212                let input_proj_linear =
5213                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5214
5215                conv_elems + norm_0 + norm_1 + input_proj_linear
5216            };
5217
5218            // Conformer blocks
5219            let conformer_elems = {
5220                let mut total = 0;
5221
5222                for _ in 0..audio_cfg.conf_num_hidden_layers {
5223                    // ConformerAttention
5224                    let attention_elems = {
5225                        // Norms (NOT quantizable)
5226                        let pre_attn_norm = audio_cfg.hidden_size;
5227                        let post_norm = audio_cfg.hidden_size;
5228
5229                        // Attention projections (Arc<dyn QuantMethod> - IS quantizable)
5230                        let q_proj =
5231                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5232                        let k_proj =
5233                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5234                        let v_proj =
5235                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5236                        let post =
5237                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5238
5239                        // RelativePositionEmbedding
5240                        let pos_proj =
5241                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5242                        let per_dim_scale =
5243                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; // head_dim
5244                        let inv_timescales = audio_cfg.hidden_size / 2; // num_timescales
5245                        let pos_indices = audio_cfg.conf_attention_context_left
5246                            + audio_cfg.conf_attention_context_right
5247                            + 1;
5248
5249                        // Local causal masks (precomputed tensors)
5250                        let chunk_size = audio_cfg.conf_attention_chunk_size;
5251                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5252                            + audio_cfg.conf_attention_context_right;
5253                        let local_causal_valid_mask = chunk_size * context_size; // U8 tensor
5254                        let invalid_logits_tensor = 1; // single f32 value
5255
5256                        pre_attn_norm
5257                            + post_norm
5258                            + q_proj
5259                            + k_proj
5260                            + v_proj
5261                            + post
5262                            + pos_proj
5263                            + per_dim_scale
5264                            + inv_timescales
5265                            + pos_indices
5266                            + local_causal_valid_mask
5267                            + invalid_logits_tensor
5268                    };
5269
5270                    // ConformerFeedForward (start and end)
5271                    let ffw_elems = {
5272                        // Each FFW has:
5273                        // - pre_layer_norm (NOT quantizable)
5274                        // - ffw_layer_1 (Arc<dyn QuantMethod> - IS quantizable)
5275                        // - ffw_layer_2 (Arc<dyn QuantMethod> - IS quantizable)
5276                        // - post_layer_norm (NOT quantizable)
5277                        let intermediate_size = audio_cfg.hidden_size * 4;
5278
5279                        let ffw_start = {
5280                            let pre_norm = audio_cfg.hidden_size;
5281                            let layer_1 =
5282                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5283                            let layer_2 =
5284                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5285                            let post_norm = audio_cfg.hidden_size;
5286                            pre_norm + layer_1 + layer_2 + post_norm
5287                        };
5288
5289                        let ffw_end = ffw_start; // Same structure
5290
5291                        ffw_start + ffw_end
5292                    };
5293
5294                    // ConformerLightConv1d
5295                    let lconv1d_elems = {
5296                        // Norms (NOT quantizable)
5297                        let pre_layer_norm = audio_cfg.hidden_size;
5298                        let conv_norm = audio_cfg.hidden_size;
5299
5300                        // Linear layers (Arc<dyn QuantMethod> - IS quantizable)
5301                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5302                            / weight_pack_factor;
5303                        let linear_end =
5304                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5305
5306                        // depthwise_conv1d (Conv1d - NOT quantizable)
5307                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5308
5309                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5310                    };
5311
5312                    // Final norm for conformer block (NOT quantizable)
5313                    let block_norm = audio_cfg.hidden_size;
5314
5315                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5316                }
5317
5318                total
5319            };
5320
5321            // Audio multimodal embedder (embed_audio)
5322            let embed_audio_elems = {
5323                // Embedding layer (ScaledEmbedding - NOT quantizable)
5324                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5325
5326                // RMS norms (NOT quantizable)
5327                let hard_embedding_norm = audio_cfg.hidden_size; // with scale
5328                let soft_embedding_norm = audio_cfg.hidden_size; // with scale
5329                let embedding_post_projection_norm = text_cfg.hidden_size; // without scale
5330
5331                // Projection (Arc<dyn QuantMethod> - IS quantizable)
5332                let embedding_projection =
5333                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5334
5335                embedding
5336                    + hard_embedding_norm
5337                    + soft_embedding_norm
5338                    + embedding_post_projection_norm
5339                    + embedding_projection
5340            };
5341
5342            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5343        };
5344
5345        let vision_dtype = if dtype == DType::F16 {
5346            // f16 -> f32 for vision model in particular.
5347            DType::F32
5348        } else {
5349            dtype
5350        };
5351
5352        let total_elems = text_elems * dtype.size_in_bytes()
5353            + vision_elems * vision_dtype.size_in_bytes()
5354            + audio_elems * dtype.size_in_bytes();
5355
5356        Ok(total_elems)
5357    }
5358
5359    fn layer_sizes_in_bytes(
5360        &self,
5361        config: &str,
5362        dtype: DType,
5363        weight_pack_factor: usize,
5364        matformer_config: Option<&MatformerSliceConfig>,
5365    ) -> Result<Vec<usize>> {
5366        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5367
5368        // Apply matformer slicing if configured
5369        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5370            matformer_config
5371        {
5372            use crate::device_map::DummyDeviceMapper;
5373            use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5374
5375            let dummy_mapper = DummyDeviceMapper {
5376                nm_device: Device::Cpu,
5377            };
5378            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5379                &cfg.text_config,
5380                &Some(matformer_cfg.clone()),
5381                &dummy_mapper,
5382            )?;
5383            (adjusted_cfg, layer_rename_map, layers_skipped)
5384        } else {
5385            (cfg.text_config.clone(), None, None)
5386        };
5387
5388        let text_cfg = &text_cfg;
5389
5390        // When matformer slicing is applied, we only include the layers that are kept
5391        let mut layer_sizes = Vec::new();
5392
5393        // Note: We don't need orig_intermediate_sizes anymore since the adjusted config
5394        // already has the correct intermediate sizes after matformer slicing
5395
5396        for layer_idx in 0..text_cfg.num_hidden_layers {
5397            let per_layer_elems = {
5398                // Layer norms
5399                let input_layernorm = text_cfg.hidden_size;
5400                let post_attention_layernorm = text_cfg.hidden_size;
5401                let pre_feedforward_layernorm = text_cfg.hidden_size;
5402                let post_feedforward_layernorm = text_cfg.hidden_size;
5403                let post_per_layer_input_norm = text_cfg.hidden_size;
5404
5405                // Attention components
5406                let size_in = text_cfg.hidden_size;
5407                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5408                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5409
5410                let q_proj = size_in * size_q / weight_pack_factor;
5411                let k_proj = size_in * size_kv / weight_pack_factor;
5412                let v_proj = size_in * size_kv / weight_pack_factor;
5413                let o_proj = size_q * size_in / weight_pack_factor;
5414
5415                // Q, K, V norms
5416                let q_norm = text_cfg.head_dim;
5417                let k_norm = text_cfg.head_dim;
5418                let v_norm = text_cfg.head_dim; // No bias for v_norm
5419
5420                // MLP components - use the adjusted intermediate sizes from matformer
5421                let intermediate_size = match &text_cfg.intermediate_size {
5422                    IntermediateSize::Single(size) => *size,
5423                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5424                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5425                };
5426                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5427                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5428                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5429
5430                // AltUp components (per layer)
5431                let altup_elems = {
5432                    let correct_output_scale = text_cfg.hidden_size;
5433                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5434                    let prediction_coefs =
5435                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5436                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5437                    let router_norm = text_cfg.hidden_size;
5438
5439                    correct_output_scale
5440                        + correction_coefs
5441                        + prediction_coefs
5442                        + modality_router
5443                        + router_norm
5444                };
5445
5446                // Laurel block components
5447                let laurel_elems = {
5448                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5449                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5450                    let post_norm = text_cfg.hidden_size;
5451
5452                    left + right + post_norm
5453                };
5454
5455                // Per-layer input components
5456                let per_layer_input_gate =
5457                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5458                let per_layer_projection =
5459                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5460
5461                input_layernorm
5462                    + post_attention_layernorm
5463                    + pre_feedforward_layernorm
5464                    + post_feedforward_layernorm
5465                    + post_per_layer_input_norm
5466                    + q_proj
5467                    + k_proj
5468                    + v_proj
5469                    + o_proj
5470                    + q_norm
5471                    + k_norm
5472                    + v_norm
5473                    + gate_proj
5474                    + up_proj
5475                    + down_proj
5476                    + altup_elems
5477                    + laurel_elems
5478                    + per_layer_input_gate
5479                    + per_layer_projection
5480            };
5481
5482            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5483        }
5484
5485        Ok(layer_sizes)
5486    }
5487
5488    fn num_layers(&self, config: &str) -> Result<usize> {
5489        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5490        Ok(cfg.text_config.num_hidden_layers)
5491    }
5492
5493    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5494        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5495        let cfg = cfg.text_config;
5496
5497        let cfg = ModelConfigMetadata {
5498            max_seq_len: cfg.max_position_embeddings,
5499            num_layers: cfg.num_hidden_layers,
5500            hidden_size: cfg.hidden_size,
5501            num_kv_heads: cfg.num_key_value_heads,
5502            num_attn_heads: cfg.num_attention_heads,
5503            sliding_window: None, // None to be more forgiving, some do not
5504            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5505            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5506            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5507        };
5508
5509        Ok(Box::new(cfg))
5510    }
5511
5512    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5513        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5514    }
5515}
5516
5517// ======================== Qwen3VL Loader
5518
5519/// [`VisionLoader`] for an Qwen3VL model.
5520///
5521/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
5522pub struct Qwen3VLLoader;
5523
5524pub struct Qwen3VLPrefixer;
5525
5526impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
5527    // No-op: With MessagesAction::Keep, the chat template handles image tokens
5528    // when it sees {"type": "image"} entries in the content.
5529}
5530
5531impl VisionModelLoader for Qwen3VLLoader {
5532    fn load(
5533        &self,
5534        config: &str,
5535        vb: ShardedVarBuilder,
5536        normal_loading_metadata: NormalLoadingMetadata,
5537        attention_mechanism: AttentionImplementation,
5538    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5539        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5540        Ok(Box::new(Qwen3VLModel::new(
5541            &cfg,
5542            vb,
5543            self.is_gptx(config),
5544            normal_loading_metadata,
5545            attention_mechanism,
5546        )?))
5547    }
5548    fn is_gptx(&self, _config: &str) -> bool {
5549        true
5550    }
5551    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5552        let config: Qwen3VLConfig = serde_json::from_str(config)?;
5553        Ok(Box::new(config))
5554    }
5555    fn get_processor(
5556        &self,
5557        _model_config: &str,
5558        _processor_config: Option<ProcessorConfig>,
5559        _preprocessor_config: PreProcessorConfig,
5560        max_edge: Option<u32>,
5561    ) -> Arc<dyn Processor + Send + Sync> {
5562        Arc::new(Qwen3VLProcessor::new(max_edge))
5563    }
5564    fn supports_paged_attention(&self, _config: &str) -> bool {
5565        true
5566    }
5567    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5568        Arc::new(Qwen3VLPrefixer)
5569    }
5570    fn modalities(&self, _config: &str) -> Result<Modalities> {
5571        Ok(Modalities {
5572            input: vec![SupportedModality::Text, SupportedModality::Vision],
5573            output: vec![SupportedModality::Text],
5574        })
5575    }
5576}
5577
5578impl IsqModelLoader for Qwen3VLLoader {
5579    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5580        Ok(vec![
5581            Regex::new(r"lm_head\.(weight|bias)$")?,
5582            // Attention
5583            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5584            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5585            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5586            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5587            // MLP
5588            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5589            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5590            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5591        ])
5592    }
5593    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5594        self.isq_layer_regexes(config)
5595    }
5596}
5597
5598impl DeviceMappedModelLoader for Qwen3VLLoader {
5599    fn mapped_max_act_size_elems(
5600        &self,
5601        config: &str,
5602        params: &AutoDeviceMapParams,
5603    ) -> Result<usize> {
5604        let AutoDeviceMapParams::Vision {
5605            max_seq_len,
5606            max_batch_size,
5607            max_image_shape,
5608            max_num_images,
5609        } = params
5610        else {
5611            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5612        };
5613
5614        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5615
5616        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
5617        let img_seq_len = {
5618            let cfg = &cfg.vision_config;
5619            // grid_t is 1 for images (temporal dimension is for video only)
5620            let grid_t = 1;
5621            // After patch embedding and spatial merge, the effective grid dimensions are reduced
5622            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5623            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5624            grid_t * grid_h * grid_w * max_num_images
5625        };
5626
5627        let max_text_attn = {
5628            let cfg = &cfg.text_config;
5629            // This model injects the vision information directly into the input embeddings
5630            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5631            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5632        };
5633
5634        Ok(max_text_attn)
5635    }
5636
5637    fn non_mapped_max_act_size_elems(
5638        &self,
5639        config: &str,
5640        params: &AutoDeviceMapParams,
5641    ) -> Result<usize> {
5642        let AutoDeviceMapParams::Vision {
5643            max_seq_len: _,
5644            max_batch_size,
5645            max_image_shape,
5646            max_num_images,
5647        } = params
5648        else {
5649            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5650        };
5651
5652        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5653
5654        // For the vision encoder, before spatial merging
5655        let img_seq_len = {
5656            let cfg = &cfg.vision_config;
5657            // grid_t is 1 for images
5658            let grid_t = 1;
5659            let grid_h = max_image_shape.0 / cfg.patch_size;
5660            let grid_w = max_image_shape.1 / cfg.patch_size;
5661            grid_t * grid_h * grid_w
5662        };
5663
5664        let max_vision_attn = {
5665            let cfg = &cfg.vision_config;
5666            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5667        };
5668
5669        Ok(max_vision_attn)
5670    }
5671
5672    fn non_mapped_size_in_bytes(
5673        &self,
5674        config: &str,
5675        dtype: DType,
5676        weight_pack_factor: usize,
5677        _matformer_config: Option<&MatformerSliceConfig>,
5678    ) -> Result<usize> {
5679        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5680        let tie = cfg.tie_word_embeddings;
5681        let text_elems = {
5682            let cfg = &cfg.text_config;
5683            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5684            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
5685            let lm_head = if !tie || weight_pack_factor != 1 {
5686                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5687            } else {
5688                0
5689            };
5690            let norm = cfg.hidden_size;
5691            embed_tokens + lm_head + norm
5692        };
5693
5694        let patch_merger = {
5695            let cfg = &cfg.vision_config;
5696            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
5697
5698            let mlp0 = hidden_size * hidden_size + hidden_size;
5699            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
5700
5701            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5702
5703            mlp0 + mlp2 + ln_q
5704        };
5705
5706        let patch_embed = {
5707            let cfg = &cfg.vision_config;
5708            let conv_cfg = Conv3dConfig {
5709                stride: cfg.patch_size,
5710                ..Default::default()
5711            };
5712            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
5713            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
5714                * kernel_sizes[0]
5715                * kernel_sizes[1]
5716                * kernel_sizes[2]
5717        };
5718
5719        let encoder_layer = {
5720            let cfg = &cfg.vision_config;
5721            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5722            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5723
5724            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
5725            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
5726            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
5727
5728            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
5729            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
5730
5731            norm1 + norm2 + fc1 + fc2 + qkv + out
5732        };
5733
5734        let elems =
5735            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
5736
5737        Ok(elems * dtype.size_in_bytes())
5738    }
5739
5740    fn layer_sizes_in_bytes(
5741        &self,
5742        config: &str,
5743        dtype: DType,
5744        weight_pack_factor: usize,
5745        _matformer_config: Option<&MatformerSliceConfig>,
5746    ) -> Result<Vec<usize>> {
5747        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5748        let per_layer_elems = {
5749            let cfg = &cfg.text_config;
5750            let input_layernorm = cfg.hidden_size;
5751            let post_attention_layernorm = cfg.hidden_size;
5752
5753            let size_in = cfg.hidden_size;
5754            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
5755            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
5756            let q_proj = size_in * size_q / weight_pack_factor + size_q;
5757            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
5758            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
5759            let o_proj = size_q * size_in / weight_pack_factor;
5760
5761            let h_size = cfg.hidden_size;
5762            let i_size = cfg.intermediate_size;
5763            let gate_proj = h_size * i_size / weight_pack_factor;
5764            let up_proj = h_size * i_size / weight_pack_factor;
5765            let down_proj = i_size * h_size / weight_pack_factor;
5766
5767            input_layernorm
5768                + post_attention_layernorm
5769                + q_proj
5770                + k_proj
5771                + v_proj
5772                + o_proj
5773                + gate_proj
5774                + up_proj
5775                + down_proj
5776        };
5777        Ok(vec![
5778            per_layer_elems * dtype.size_in_bytes();
5779            cfg.text_config.num_hidden_layers
5780        ])
5781    }
5782
5783    fn num_layers(&self, config: &str) -> Result<usize> {
5784        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5785        let cfg = &cfg.text_config;
5786        Ok(cfg.num_hidden_layers)
5787    }
5788
5789    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5790        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5791        let cfg = &cfg.text_config;
5792
5793        let cfg = ModelConfigMetadata {
5794            max_seq_len: cfg.max_position_embeddings,
5795            num_layers: cfg.num_hidden_layers,
5796            hidden_size: cfg.hidden_size,
5797            num_kv_heads: cfg.num_key_value_heads,
5798            num_attn_heads: cfg.num_attention_heads,
5799            sliding_window: cfg.sliding_window,
5800            k_head_dim: cfg.head_dim,
5801            v_head_dim: cfg.head_dim,
5802            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5803        };
5804
5805        Ok(Box::new(cfg))
5806    }
5807
5808    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5809        Some(vec![NonMappedSubModel::Vision])
5810    }
5811}
5812
5813// ======================== Qwen3VLMoE Loader
5814
5815/// [`VisionLoader`] for a Qwen3VLMoE model.
5816///
5817/// [`VisionLoader`]: https://docs.rs/mistralrs/latest/mistralrs/struct.VisionLoader.html
5818pub struct Qwen3VLMoELoader;
5819
5820pub struct Qwen3VLMoEPrefixer;
5821
5822impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
5823    // No-op: With MessagesAction::Keep, the chat template handles image tokens
5824    // when it sees {"type": "image"} entries in the content.
5825}
5826
5827impl VisionModelLoader for Qwen3VLMoELoader {
5828    fn load(
5829        &self,
5830        config: &str,
5831        vb: ShardedVarBuilder,
5832        normal_loading_metadata: NormalLoadingMetadata,
5833        attention_mechanism: AttentionImplementation,
5834    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
5835        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5836        Ok(Box::new(Qwen3VLMoEModel::new(
5837            &cfg,
5838            vb,
5839            self.is_gptx(config),
5840            normal_loading_metadata,
5841            attention_mechanism,
5842        )?))
5843    }
5844    fn is_gptx(&self, _config: &str) -> bool {
5845        true
5846    }
5847    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5848        let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5849        Ok(Box::new(config))
5850    }
5851    fn get_processor(
5852        &self,
5853        _model_config: &str,
5854        _processor_config: Option<ProcessorConfig>,
5855        _preprocessor_config: PreProcessorConfig,
5856        max_edge: Option<u32>,
5857    ) -> Arc<dyn Processor + Send + Sync> {
5858        Arc::new(Qwen3VLMoEProcessor::new(max_edge))
5859    }
5860    fn supports_paged_attention(&self, _config: &str) -> bool {
5861        true
5862    }
5863    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5864        Arc::new(Qwen3VLMoEPrefixer)
5865    }
5866    fn modalities(&self, _config: &str) -> Result<Modalities> {
5867        Ok(Modalities {
5868            input: vec![SupportedModality::Text, SupportedModality::Vision],
5869            output: vec![SupportedModality::Text],
5870        })
5871    }
5872}
5873
5874impl IsqModelLoader for Qwen3VLMoELoader {
5875    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5876        Ok(vec![
5877            Regex::new(r"lm_head\.(weight|bias)$")?,
5878            // Attention
5879            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5880            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5881            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5882            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5883            // MLP (dense layers)
5884            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5885            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5886            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5887            // MoE router
5888            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
5889            // MoE experts - now unpacked into individual experts
5890            Regex::new(
5891                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
5892            )?,
5893            Regex::new(
5894                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
5895            )?,
5896            Regex::new(
5897                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
5898            )?,
5899        ])
5900    }
5901    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5902        self.isq_layer_regexes(config)
5903    }
5904}
5905
5906impl DeviceMappedModelLoader for Qwen3VLMoELoader {
5907    fn mapped_max_act_size_elems(
5908        &self,
5909        config: &str,
5910        params: &AutoDeviceMapParams,
5911    ) -> Result<usize> {
5912        let AutoDeviceMapParams::Vision {
5913            max_seq_len,
5914            max_batch_size,
5915            max_image_shape,
5916            max_num_images,
5917        } = params
5918        else {
5919            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5920        };
5921
5922        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5923
5924        // For images, grid_t=1. After spatial merging, grid_h and grid_w are reduced.
5925        let img_seq_len = {
5926            let cfg = &cfg.vision_config;
5927            // grid_t is 1 for images (temporal dimension is for video only)
5928            let grid_t = 1;
5929            // After patch embedding and spatial merge, the effective grid dimensions are reduced
5930            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5931            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5932            grid_t * grid_h * grid_w * max_num_images
5933        };
5934
5935        let max_text_attn = {
5936            let cfg = &cfg.text_config;
5937            // This model injects the vision information directly into the input embeddings
5938            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5939            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5940        };
5941
5942        Ok(max_text_attn)
5943    }
5944
5945    fn non_mapped_max_act_size_elems(
5946        &self,
5947        config: &str,
5948        params: &AutoDeviceMapParams,
5949    ) -> Result<usize> {
5950        let AutoDeviceMapParams::Vision {
5951            max_seq_len: _,
5952            max_batch_size,
5953            max_image_shape,
5954            max_num_images,
5955        } = params
5956        else {
5957            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
5958        };
5959
5960        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5961
5962        // For the vision encoder, before spatial merging
5963        let img_seq_len = {
5964            let cfg = &cfg.vision_config;
5965            // grid_t is 1 for images
5966            let grid_t = 1;
5967            let grid_h = max_image_shape.0 / cfg.patch_size;
5968            let grid_w = max_image_shape.1 / cfg.patch_size;
5969            grid_t * grid_h * grid_w
5970        };
5971
5972        let max_vision_attn = {
5973            let cfg = &cfg.vision_config;
5974            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5975        };
5976
5977        Ok(max_vision_attn)
5978    }
5979
5980    fn non_mapped_size_in_bytes(
5981        &self,
5982        config: &str,
5983        dtype: DType,
5984        weight_pack_factor: usize,
5985        _matformer_config: Option<&MatformerSliceConfig>,
5986    ) -> Result<usize> {
5987        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5988        let tie = cfg.tie_word_embeddings;
5989        let text_elems = {
5990            let cfg = &cfg.text_config;
5991            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5992            // If embeddings are tied and no packing, reuse weights -> no separate lm_head needed
5993            let lm_head = if !tie || weight_pack_factor != 1 {
5994                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5995            } else {
5996                0
5997            };
5998            let norm = cfg.hidden_size;
5999            embed_tokens + lm_head + norm
6000        };
6001
6002        let patch_merger = {
6003            let cfg = &cfg.vision_config;
6004            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6005
6006            let mlp0 = hidden_size * hidden_size + hidden_size;
6007            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
6008
6009            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6010
6011            mlp0 + mlp2 + ln_q
6012        };
6013
6014        let patch_embed = {
6015            let cfg = &cfg.vision_config;
6016            let conv_cfg = Conv3dConfig {
6017                stride: cfg.patch_size,
6018                ..Default::default()
6019            };
6020            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6021            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6022                * kernel_sizes[0]
6023                * kernel_sizes[1]
6024                * kernel_sizes[2]
6025        };
6026
6027        let encoder_layer = {
6028            let cfg = &cfg.vision_config;
6029            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6030            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6031
6032            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
6033            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6034            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6035
6036            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6037            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6038
6039            norm1 + norm2 + fc1 + fc2 + qkv + out
6040        };
6041
6042        let elems =
6043            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
6044
6045        Ok(elems * dtype.size_in_bytes())
6046    }
6047
6048    fn layer_sizes_in_bytes(
6049        &self,
6050        config: &str,
6051        dtype: DType,
6052        weight_pack_factor: usize,
6053        _matformer_config: Option<&MatformerSliceConfig>,
6054    ) -> Result<Vec<usize>> {
6055        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6056        let text_cfg = &cfg.text_config;
6057
6058        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6059
6060        for layer_idx in 0..text_cfg.num_hidden_layers {
6061            let input_layernorm = text_cfg.hidden_size;
6062            let post_attention_layernorm = text_cfg.hidden_size;
6063
6064            let size_in = text_cfg.hidden_size;
6065            let size_q = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6066                * text_cfg.num_attention_heads;
6067            let size_kv = (text_cfg.hidden_size / text_cfg.num_attention_heads)
6068                * text_cfg.num_key_value_heads;
6069            let q_proj = size_in * size_q / weight_pack_factor + size_q;
6070            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
6071            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
6072            let o_proj = size_q * size_in / weight_pack_factor;
6073
6074            // Check if this is a MoE layer
6075            let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
6076                && (text_cfg.num_experts > 0
6077                    && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
6078
6079            let mlp_elems = if is_moe {
6080                // MoE layer: gate + experts
6081                let gate = text_cfg.hidden_size * text_cfg.num_experts;
6082                let per_expert = {
6083                    let h_size = text_cfg.hidden_size;
6084                    let i_size = text_cfg.moe_intermediate_size;
6085                    let gate_proj = h_size * i_size / weight_pack_factor;
6086                    let up_proj = h_size * i_size / weight_pack_factor;
6087                    let down_proj = i_size * h_size / weight_pack_factor;
6088                    gate_proj + up_proj + down_proj
6089                };
6090                gate + per_expert * text_cfg.num_experts
6091            } else {
6092                // Dense MLP layer
6093                let h_size = text_cfg.hidden_size;
6094                let i_size = text_cfg.intermediate_size;
6095                let gate_proj = h_size * i_size / weight_pack_factor;
6096                let up_proj = h_size * i_size / weight_pack_factor;
6097                let down_proj = i_size * h_size / weight_pack_factor;
6098                gate_proj + up_proj + down_proj
6099            };
6100
6101            let per_layer_elems = input_layernorm
6102                + post_attention_layernorm
6103                + q_proj
6104                + k_proj
6105                + v_proj
6106                + o_proj
6107                + mlp_elems;
6108
6109            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6110        }
6111
6112        Ok(layer_sizes)
6113    }
6114
6115    fn num_layers(&self, config: &str) -> Result<usize> {
6116        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6117        let cfg = &cfg.text_config;
6118        Ok(cfg.num_hidden_layers)
6119    }
6120
6121    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6122        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6123        let cfg = &cfg.text_config;
6124
6125        let cfg = ModelConfigMetadata {
6126            max_seq_len: cfg.max_position_embeddings,
6127            num_layers: cfg.num_hidden_layers,
6128            hidden_size: cfg.hidden_size,
6129            num_kv_heads: cfg.num_key_value_heads,
6130            num_attn_heads: cfg.num_attention_heads,
6131            sliding_window: cfg.sliding_window,
6132            k_head_dim: cfg.head_dim,
6133            v_head_dim: cfg.head_dim,
6134            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
6135        };
6136
6137        Ok(Box::new(cfg))
6138    }
6139
6140    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6141        Some(vec![NonMappedSubModel::Vision])
6142    }
6143}