use std::any::Any;
use std::sync::Arc;
use std::{fmt::Debug, str::FromStr};

use anyhow::Result;
use candle_core::{DType, Device, Tensor, D};
use candle_nn::Conv2dConfig;
use image::{ColorType, DynamicImage};
use itertools::Itertools;
use mistralrs_quant::log::once_log_info;
use mistralrs_quant::ShardedVarBuilder;

#[cfg(feature = "pyo3_macros")]
use pyo3::pyclass;

use regex::Regex;
use serde::Deserialize;

use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};

use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
use crate::amoe::AnyMoeBaseModelMixin;
use crate::attention::ATTENTION_CHUNK_SIZE;
use crate::device_map::DeviceMapper;
use crate::layers::Conv3dConfig;
use crate::matformer::MatformerSliceConfig;
use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
use crate::pipeline::isq::IsqModelLoader;
use crate::pipeline::loaders::AutoDeviceMapParams;
use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
use crate::pipeline::{
    EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
    SupportedModality,
};
use crate::utils::varbuilder_utils::DeviceForLoadTensor;
use crate::vision_models::clip::ClipConfig;
use crate::vision_models::gemma3::config::Gemma3Config;
use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
use crate::vision_models::idefics2_input_processor::Idefics2Processor;
use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
use crate::vision_models::image_processor::ImagePreProcessor;
use crate::vision_models::inputs_processor::Phi4MMProcessor;
use crate::vision_models::llama4::{
    self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
};
use crate::vision_models::llava::config::Config as LLaVAConfig;
use crate::vision_models::llava15::Model as LLaVA;
use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
use crate::vision_models::llava_next::Model as LLaVANext;
use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
use crate::vision_models::phi3_inputs_processor::Phi3Processor;
use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
use crate::vision_models::preprocessor_config::PreProcessorConfig;
use crate::vision_models::processor_config::ProcessorConfig;
use crate::vision_models::qwen2_5_vl::{
    Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
};
use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
use crate::vision_models::qwen3_vl_moe::{
    Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
};
use crate::vision_models::{minicpmo, phi4};

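/// A fully loaded vision model: the multimodal forward pass plus the cache and
/// config accessors the pipeline needs to drive it.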
pub trait VisionModel: IsqModel + AnyMoeBaseModelMixin {
    #[allow(clippy::too_many_arguments)]
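    // `pixel_values` and any image-specific entries in `model_specific_args`
    // (e.g. pixel attention masks or image sizes) are only provided for prompt
    // sequences that contain images.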
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> candle_core::Result<Tensor>;
    fn device(&self) -> &Device;
    fn cache(&self) -> &EitherCache;
    fn cache_mut(&mut self) -> &mut EitherCache;
    fn max_seq_len(&self) -> usize;
    fn config(&self) -> &ModelConfigMetadata;
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
}

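/// Builds a concrete [`VisionModel`] from a raw `config.json` string and
/// supplies the matching processor, prompt prefixer, and capability flags.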
pub trait VisionModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>>;
    fn is_gptx(&self, config: &str) -> bool;
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
    fn get_processor(
        &self,
        model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync>;
    fn supports_paged_attention(&self, config: &str) -> bool;
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        false
    }
    fn modalities(&self, config: &str) -> Result<Modalities>;
    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
    fn get_device_for_tensor(
        &self,
        config: &str,
        _mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        if loading_isq {
            Ok(Arc::new(|_| DeviceForLoadTensor::Base))
        } else {
            let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
            let num_layers = self.model_config(config)?.num_layers();
            let closure = move |name: String| {
                if let Some(captures) = re.captures(&name) {
                    captures
                        .get(1)
                        .and_then(|m| m.as_str().parse::<usize>().ok())
                        .map(|l| l.min(num_layers))
                        .map(DeviceForLoadTensor::Idx)
                        .unwrap_or(DeviceForLoadTensor::Base)
                } else {
                    DeviceForLoadTensor::Base
                }
            };

            Ok(Arc::new(closure))
        }
    }
}

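/// The architecture to load the vision model as.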
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
pub enum VisionLoaderType {
    #[serde(rename = "phi3v")]
    Phi3V,
    #[serde(rename = "idefics2")]
    Idefics2,
    #[serde(rename = "llava_next")]
    LLaVANext,
    #[serde(rename = "llava")]
    LLaVA,
    #[serde(rename = "vllama")]
    VLlama,
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    #[serde(rename = "idefics3")]
    Idefics3,
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    #[serde(rename = "phi4mm")]
    Phi4MM,
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    #[serde(rename = "gemma3")]
    Gemma3,
    #[serde(rename = "mistral3")]
    Mistral3,
    #[serde(rename = "llama4")]
    Llama4,
    #[serde(rename = "gemma3n")]
    Gemma3n,
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
    #[serde(rename = "qwen3vlmoe")]
    Qwen3VLMoE,
}

impl VisionLoaderType {
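    /// Determine the loader type from the class name in the `architectures`
    /// field of a Hugging Face `config.json`.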
    pub fn from_causal_lm_name(name: &str) -> Result<Self> {
        match name {
            "Phi3VForCausalLM" => Ok(Self::Phi3V),
            "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
            "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
            "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
            "MllamaForConditionalGeneration" => Ok(Self::VLlama),
            "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
            "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
            "MiniCPMO" => Ok(Self::MiniCpmO),
            "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
            "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
            "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
            "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
            "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
            "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
            "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
            "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
            other => anyhow::bail!(
                "Unsupported Hugging Face Transformers model class `{other}`. Please raise an issue."
            ),
        }
    }
}

impl FromStr for VisionLoaderType {
    type Err = String;
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        match s {
            "phi3v" => Ok(Self::Phi3V),
            "idefics2" => Ok(Self::Idefics2),
            "llava_next" => Ok(Self::LLaVANext),
            "llava" => Ok(Self::LLaVA),
            "vllama" => Ok(Self::VLlama),
            "qwen2vl" => Ok(Self::Qwen2VL),
            "idefics3" => Ok(Self::Idefics3),
            "minicpmo" => Ok(Self::MiniCpmO),
            "phi4mm" => Ok(Self::Phi4MM),
            "qwen2_5vl" => Ok(Self::Qwen2_5VL),
            "gemma3" => Ok(Self::Gemma3),
            "mistral3" => Ok(Self::Mistral3),
            "llama4" => Ok(Self::Llama4),
            "gemma3n" => Ok(Self::Gemma3n),
            "qwen3vl" => Ok(Self::Qwen3VL),
            "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
            a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `qwen3vl`, `qwen3vlmoe`.")),
        }
    }
}

impl std::fmt::Display for VisionLoaderType {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = match self {
            VisionLoaderType::Phi3V => "phi3v",
            VisionLoaderType::Idefics2 => "idefics2",
            VisionLoaderType::LLaVANext => "llava_next",
            VisionLoaderType::LLaVA => "llava",
            VisionLoaderType::VLlama => "vllama",
            VisionLoaderType::Qwen2VL => "qwen2vl",
            VisionLoaderType::Idefics3 => "idefics3",
            VisionLoaderType::MiniCpmO => "minicpmo",
            VisionLoaderType::Phi4MM => "phi4mm",
            VisionLoaderType::Qwen2_5VL => "qwen2_5vl",
            VisionLoaderType::Gemma3 => "gemma3",
            VisionLoaderType::Mistral3 => "mistral3",
            VisionLoaderType::Llama4 => "llama4",
            VisionLoaderType::Gemma3n => "gemma3n",
            VisionLoaderType::Qwen3VL => "qwen3vl",
            VisionLoaderType::Qwen3VLMoE => "qwen3vlmoe",
        };
        write!(f, "{name}")
    }
}

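/// Minimal view of a model's `config.json`: only the `architectures` field is
/// needed to pick a loader automatically.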
#[derive(Deserialize)]
struct AutoVisionLoaderConfig {
    architectures: Vec<String>,
}

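/// Loader that inspects the `architectures` field of the config and delegates
/// every call to the matching concrete [`VisionModelLoader`].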
pub struct AutoVisionLoader;

impl AutoVisionLoader {
    fn get_loader(config: &str) -> Result<Box<dyn VisionModelLoader>> {
        let auto_cfg: AutoVisionLoaderConfig = serde_json::from_str(config)?;
        if auto_cfg.architectures.len() != 1 {
            anyhow::bail!("Expected exactly one architecture in config");
        }

        let name = &auto_cfg.architectures[0];
        let tp = VisionLoaderType::from_causal_lm_name(name)?;

        once_log_info(format!("Automatic loader type determined to be `{tp}`"));

        Ok(match tp {
            VisionLoaderType::Phi3V => Box::new(Phi3VLoader),
            VisionLoaderType::Idefics2 => Box::new(Idefics2Loader),
            VisionLoaderType::LLaVANext => Box::new(LLaVANextLoader),
            VisionLoaderType::LLaVA => Box::new(LLaVALoader),
            VisionLoaderType::VLlama => Box::new(VLlamaLoader),
            VisionLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
            VisionLoaderType::Idefics3 => Box::new(Idefics3Loader),
            VisionLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
            VisionLoaderType::Phi4MM => Box::new(Phi4MMLoader),
            VisionLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
            VisionLoaderType::Gemma3 => Box::new(Gemma3Loader),
            VisionLoaderType::Mistral3 => Box::new(Mistral3Loader),
            VisionLoaderType::Llama4 => Box::new(VLlama4Loader),
            VisionLoaderType::Gemma3n => Box::new(Gemma3nLoader),
            VisionLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
            VisionLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
        })
    }
}

impl VisionModelLoader for AutoVisionLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
    }

    fn is_gptx(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader get_loader")
            .is_gptx(config)
    }

    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        Self::get_loader(config)?.get_config_repr(config)
    }

    fn get_processor(
        &self,
        model_config: &str,
        proc_cfg: Option<ProcessorConfig>,
        preproc_cfg: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Self::get_loader(model_config)
            .expect("AutoVisionLoader get_loader")
            .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
    }

    fn supports_paged_attention(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_paged_attention(config)
    }

    fn modalities(&self, config: &str) -> Result<Modalities> {
        Self::get_loader(config)?.modalities(config)
    }

    fn supports_prefix_cacher(&self, config: &str) -> bool {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .supports_prefix_cacher(config)
    }

    fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Self::get_loader(config)
            .expect("AutoVisionLoader")
            .prefixer(config)
    }

    fn get_device_for_tensor(
        &self,
        config: &str,
        mapper: &dyn DeviceMapper,
        loading_isq: bool,
    ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
        Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
    }
}

impl IsqModelLoader for AutoVisionLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.isq_layer_regexes(config)
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        Self::get_loader(config)?.immediate_isq_predicates(config)
    }
}

impl DeviceMappedModelLoader for AutoVisionLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
    }
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        Self::get_loader(config)?.non_mapped_size_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        Self::get_loader(config)?.layer_sizes_in_bytes(
            config,
            dtype,
            weight_pack_factor,
            _matformer_config,
        )
    }
    fn num_layers(&self, config: &str) -> Result<usize> {
        Self::get_loader(config)?.num_layers(config)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        Self::get_loader(config)?.model_config(config)
    }
}

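// Counts `$size` elements only when `$cond` holds, e.g. for an optional bias.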
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond {
            $size
        } else {
            0
        }
    };
}

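/// Parameter-element count of a CLIP ViT vision tower: class/position/patch
/// embeddings, the pre and final layer norms, and all encoder layers.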
fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
    let pre_layer_norm = cfg.hidden_size;
    let final_layer_norm = cfg.hidden_size;

    let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
    let num_positions = num_patches + 1;

    let class_embedding = cfg.hidden_size;

    let position_ids = num_positions;
    let position_embedding = num_positions * cfg.hidden_size;

    let conv2dconfig = Conv2dConfig {
        stride: cfg.patch_size,
        ..Default::default()
    };
    let patch_embedding =
        cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;

    let encoder_layer_elems = {
        let layer_norm1 = cfg.hidden_size;
        let layer_norm2 = cfg.hidden_size;

        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

        layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
    };

    pre_layer_norm
        + final_layer_norm
        + class_embedding
        + position_ids
        + position_embedding
        + patch_embedding
        + cfg.num_hidden_layers * encoder_layer_elems
}

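/// [`VisionModelLoader`] for a Phi 3 Vision model.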
pub struct Phi3VLoader;

pub struct Phi3VPrefixer;

impl MultimodalPromptPrefixer for Phi3VPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi3VLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(Phi3::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi3Processor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi3VPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi3VLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi3VLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi3Config = serde_json::from_str(config)?;

        let vcfg = &PHI3V_CLIP_CONFIG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = {
                let projection_cls = cfg
                    .embd_layer
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator =
                    cfg.embd_layer.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
                let image_dim_out = cfg.img_processor.image_dim_out;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);

                proj + glb_gn + sub_gn + clip_vit
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi3Config = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi3Config = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

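/// [`VisionModelLoader`] for an Idefics 2 model.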
pub struct Idefics2Loader;

pub struct Idefics2Prefixer;

impl MultimodalPromptPrefixer for Idefics2Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        // The chat template handles image prefixing for Idefics 2.
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics2Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics2::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics2Processor::new(
            processor_config.unwrap(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics2Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics2Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics2Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics2Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            // Idefics 2 image splitting yields 4 crops plus the original image,
            // hence the factor of 5.
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let text_elems = {
            let tie_word_embeddings = cfg.tie_word_embeddings;
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie_word_embeddings {
                cfg.hidden_size * cfg.vocab_size
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let tcfg = &cfg.text_config;
            let vcfg = &cfg.vision_config;
            let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
            let down_proj = tcfg.intermediate_size * tcfg.hidden_size;

            let perceiver_elems = {
                let tcfg = &cfg.text_config;
                let pcfg = &cfg.perceiver_config;

                let n_latents = pcfg.resampler_n_latents;
                let hidden_size = tcfg.hidden_size;
                let depth = pcfg.resampler_depth;

                let norm = tcfg.hidden_size;
                let latents = n_latents * hidden_size;

                let layer_elems = {
                    let input_latents_norm = hidden_size;
                    let input_context_norm = hidden_size;
                    let post_attn_norm = hidden_size;

                    let num_heads = pcfg.resampler_n_heads;
                    let head_dim = pcfg.resampler_head_dim;
                    let num_key_value_heads = pcfg.num_key_value_heads;

                    let q_proj = hidden_size * num_heads * head_dim;
                    let k_proj = hidden_size * num_key_value_heads * head_dim;
                    let v_proj = hidden_size * num_key_value_heads * head_dim;
                    let o_proj = num_heads * head_dim * hidden_size;

                    let gate_proj = hidden_size * hidden_size * 4;
                    let up_proj = hidden_size * hidden_size * 4;
                    let down_proj = hidden_size * 4 * hidden_size;

                    input_latents_norm
                        + input_context_norm
                        + post_attn_norm
                        + q_proj
                        + k_proj
                        + v_proj
                        + o_proj
                        + gate_proj
                        + up_proj
                        + down_proj
                };

                norm + latents + layer_elems * depth
            };

            gate_proj + up_proj + down_proj + perceiver_elems
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm + patch_embedding + position_embedding + layer_elems
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics2Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

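/// [`VisionModelLoader`] for a LLaVA-Next model.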
pub struct LLaVANextLoader;

pub struct LLaVANextPrefixer;

impl MultimodalPromptPrefixer for LLaVANextPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVANextLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVANext::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVANextProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVANextPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVANextLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVANextLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        #[allow(clippy::cast_possible_truncation)]
        let img_seq_len =
            llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
                &config,
                (max_image_shape.0 as u32, max_image_shape.1 as u32),
            );

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

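/// [`VisionModelLoader`] for a LLaVA model.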
pub struct LLaVALoader;

pub struct LLaVAPrefixer;

impl MultimodalPromptPrefixer for LLaVAPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
    }
}

impl VisionModelLoader for LLaVALoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(LLaVA::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(LLaVAProcessor::new(model_config))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(LLaVAPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for LLaVALoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            // Attention
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            // MLP
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for LLaVALoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let cfg = &config.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);

            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: LLaVAConfig = serde_json::from_str(config)?;

        let img_seq_len =
            llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);

        let max_vision_attn = {
            (max_batch_size * max_num_images)
                * config.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let image_newline = cfg.text_config.hidden_size;
        let mmproj = {
            let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;
            let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
                + cfg.text_config.hidden_size;

            linear_1 + linear_2
        };
        let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());

        let elems = text_elems + image_newline + mmproj + vision_tower;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: LLaVAConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

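/// [`VisionModelLoader`] for an MLlama (Llama 3.2 Vision) model.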
1646pub struct VLlamaLoader;
1652
1653pub struct VLlamaPrefixer;
1654
1655impl MultimodalPromptPrefixer for VLlamaPrefixer {
1656 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1657 format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1658 }
1659}
1660
1661impl VisionModelLoader for VLlamaLoader {
1662 fn load(
1663 &self,
1664 config: &str,
1665 vb: ShardedVarBuilder,
1666 normal_loading_metadata: NormalLoadingMetadata,
1667 attention_mechanism: AttentionImplementation,
1668 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
1669 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1670 Ok(Box::new(MLlamaModel::new(
1671 &cfg,
1672 vb,
1673 self.is_gptx(config),
1674 normal_loading_metadata,
1675 attention_mechanism,
1676 )?))
1677 }
1678 fn is_gptx(&self, _config: &str) -> bool {
1679 true
1680 }
1681 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1682 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1683 Ok(Box::new(cfg))
1684 }
1685 fn get_processor(
1686 &self,
1687 _model_config: &str,
1688 _processor_config: Option<ProcessorConfig>,
1689 _preprocessor_config: PreProcessorConfig,
1690 _max_edge: Option<u32>,
1691 ) -> Arc<dyn Processor + Send + Sync> {
1692 Arc::new(MLlamaProcessor::new())
1693 }
1694 fn supports_paged_attention(&self, _config: &str) -> bool {
1695 false
1696 }
1697 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1698 true
1699 }
1700 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1701 Arc::new(VLlamaPrefixer)
1702 }
1703 fn modalities(&self, _config: &str) -> Result<Modalities> {
1704 Ok(Modalities {
1705 input: vec![SupportedModality::Text, SupportedModality::Vision],
1706 output: vec![SupportedModality::Text],
1707 })
1708 }
1709}
1710
impl IsqModelLoader for VLlamaLoader {
    fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cross_attn_layers = &config.text_config.cross_attention_layers;
        let transformer_layers =
            (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
        let mut text_regexes = Vec::new();
        for layer in transformer_layers {
            text_regexes.extend(vec![
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
                ))?,
                Regex::new(&format!(
                    r"language_model\.model\.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
                ))?,
            ]);
        }
        let vision_regexes = vec![
            Regex::new(
                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"vision_model\.global_transformer\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
            )?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
        ];

        Ok([text_regexes, vision_regexes].concat())
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

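// Size and activation estimates used by the automatic device mapper. The `mapped_*`
// methods bound the per-layer text stack, which may be split across devices; the
// `non_mapped_*` methods bound the vision tower, which stays on a single device
// (see `non_mapped_sub_models` at the end of this impl). All counts are in elements
// and are converted to bytes via `dtype.size_in_bytes()`.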
impl DeviceMappedModelLoader for VLlamaLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

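        // Illustrative numbers (not taken from any particular config): with
        // image_size = 560, patch_size = 14 and max_num_tiles = 4, num_patches =
        // (560 / 14)^2 + 1 = 1601, the padding below rounds this up to a multiple of
        // 8 (1601 -> 1608), giving 4 * 1608 = 6432 positions per image.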
        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let img_seq_len = img_seq_len * max_num_images;

        let max_cross_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        let max_self_text_attn = {
            let cfg = &config.text_config;
            max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
        };

        Ok(max_self_text_attn.max(max_cross_text_attn))
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let config: MLlamaConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &config.vision_config;
            let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
            let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
            cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
        };
        let max_vision_attn = {
            let cfg = &config.vision_config;
            (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &config.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &config.vision_config;

            let conv_cfg = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                * cfg.patch_size
                * cfg.patch_size;

            let class_embedding = cfg.hidden_size;

            let gated_positional_embedding = {
                let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
                let embedding = num_patches * cfg.hidden_size;
                let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
                    * (cfg.max_num_tiles * num_patches * cfg.hidden_size);

                embedding + tile_embedding
            };

            let pre_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
            let post_tile_positional_embedding =
                (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);

            let layernorm_pre = cfg.hidden_size;
            let layernorm_post = cfg.hidden_size;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let k_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let v_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
                let o_proj =
                    cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            patch_embedding
                + class_embedding
                + gated_positional_embedding
                + pre_tile_positional_embedding
                + post_tile_positional_embedding
                + layernorm_pre
                + layernorm_post
                + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
        };

        let elems = text_elems + vision_elems;
        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &config.text_config;

        let mut layer_sizes = Vec::new();

        for i in 0..cfg.num_hidden_layers {
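            // Cross-attention layers are excluded from ISQ (see `isq_layer_regexes`
            // above), so their weights are counted at full size with a pack factor of 1.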
            let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
                1
            } else {
                weight_pack_factor
            };

            let per_layer_elems = {
                let input_layernorm = cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size;

                let size_in = cfg.hidden_size;
                let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
                let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let h_size = cfg.hidden_size;
                let i_size = cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + gate_proj
                    + up_proj
                    + down_proj
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let config: MLlamaConfig = serde_json::from_str(config)?;
        Ok(config.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MLlamaConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

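/// Loader for Qwen2-VL models: parses [`Qwen2VLConfig`] from the model's
/// `config.json`, constructs a [`Qwen2VLModel`], and processes inputs with
/// [`Qwen2VLProcessor`]. Paged attention is not supported for this family.
///
/// A minimal sketch of the trait surface; `config_json` is a placeholder for the
/// contents of the model's `config.json`:
/// ```ignore
/// let loader = Qwen2VLLoader;
/// let num_layers = loader.num_layers(&config_json)?;
/// let prompt = loader
///     .prefixer(&config_json)
///     .prefix_image(vec![0], "Describe this image.");
/// ```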
pub struct Qwen2VLLoader;

pub struct Qwen2VLPrefixer;

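// One `VISION_START`/`IMAGE_PAD`/`VISION_END` placeholder group is prepended per
// image (for two images the group is repeated twice before the original prompt);
// the input processor is then expected to expand each `IMAGE_PAD` into the actual
// run of image tokens.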
impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2VLProcessor::VISION_START,
                Qwen2VLProcessor::IMAGE_PAD,
                Qwen2VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
            grid_t * grid_h * grid_w * max_num_images
        };

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

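        // The patch merger folds `spatial_merge_size^2` neighbouring patch embeddings
        // into one token before projecting to the text hidden size; `mlp0` and `mlp2`
        // below count the two Linear layers of that MLP, weights plus biases.
        // Illustrative numbers: embed_dim = 1280 with spatial_merge_size = 2 gives a
        // merger width of 1280 * 4 = 5120.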
        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_channels * cfg.embed_dim / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
            let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);

            #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
            let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
            let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
            let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;

            let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
            let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

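/// Loader for Idefics 3 vision models: parses [`Idefics3Config`], constructs an
/// [`Idefics3Model`], and processes inputs with [`Idefics3Processor`]. Supports
/// both paged attention and the prefix cacher.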
pub struct Idefics3Loader;

pub struct Idefics3Prefixer;

impl MultimodalPromptPrefixer for Idefics3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Idefics3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(Idefics3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Idefics3Processor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Idefics3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Idefics3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Idefics3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Idefics3Config = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

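        // `images_factor` is a fixed headroom multiplier: Idefics 3 preprocessing can
        // split one source image into several sub-images plus a global view, so the
        // vision tower may see multiple crops per logical image. The value 5 is a
        // heuristic bound rather than a config-derived quantity.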
        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let connector_elems = {
            let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
            let out_dim = cfg.text_config.hidden_size;

            in_dim * out_dim
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + connector_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Idefics3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

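/// Loader for MiniCPM-O models: parses [`MiniCpmOConfig`], constructs a
/// [`MiniCpmOModel`], and processes inputs with [`MiniCpmOProcessor`]. Images are
/// referenced in prompts with `(<image>./</image>)` placeholders.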
pub struct MiniCpmOLoader;

pub struct MiniCpmOPrefixer;

impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            "(<image>./</image>)".repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for MiniCpmOLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(MiniCpmOModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(MiniCpmOProcessor::new(
            processor_config.unwrap_or_default(),
            preprocessor_config,
            max_edge,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(MiniCpmOPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for MiniCpmOLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"llm\.lm_head\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"llm\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for MiniCpmOLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;

        let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

        let max_vision_attn = {
            let images_factor = 5;

            (max_batch_size * images_factor * max_num_images)
                * cfg.vision_config.num_attention_heads
                * img_seq_len
                * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = {
            let cfg = &cfg.vision_config;

            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = cfg.text_config;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }
    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }
}

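/// Loader for Phi-4-multimodal models: parses [`Phi4MMConfig`], constructs a
/// [`Phi4MMModel`], and accepts text, vision, and audio inputs (see `modalities`
/// below). The vision and audio towers are never device-mapped.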
pub struct Phi4MMLoader;

pub struct Phi4MMPrefixer;

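// Phi-4-multimodal numbers its placeholders from 1, so image/audio index 0 becomes
// `<|image_1|>` / `<|audio_1|>`; hence the `+ 1` below.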
impl MultimodalPromptPrefixer for Phi4MMPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            image_indexes
                .into_iter()
                .map(|image_index| format!("<|image_{}|>", image_index + 1))
                .join("")
        )
    }
    fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            audio_indexes
                .into_iter()
                .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
                .join("")
        )
    }
}

impl VisionModelLoader for Phi4MMLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(Phi4MMModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Phi4MMPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Phi4MMLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Phi4MMLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = (num_patches + 1) * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        _config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let vcfg = &PHI4_MM_VISION_CFG;

        let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
        let img_seq_len = num_patches + 1;

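        // Dynamic-HD preprocessing tiles each image into base-resolution crops plus one
        // global view, so the effective batch size is scaled by
        // `ceil(H / base) * ceil(W / base) + 1`. Illustrative numbers, assuming a base
        // resolution of 448: a 896x1344 image yields 2 * 3 + 1 = 7 crops.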
        let max_batch_size = max_batch_size
            * (max_image_shape
                .0
                .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                * max_image_shape
                    .1
                    .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
                + 1);

        let max_vision_attn = (max_batch_size * max_num_images)
            * vcfg.num_attention_heads
            * img_seq_len
            * img_seq_len;
        let max_qkv = 3
            * (max_batch_size
                * vcfg.num_attention_heads
                * img_seq_len
                * (vcfg.hidden_size / vcfg.num_attention_heads));

        Ok(max_vision_attn + max_qkv)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;

            let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
                let projection_cls = img_embed
                    .projection_cls
                    .clone()
                    .unwrap_or("linear".to_string());
                let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
                let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
                let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;

                let proj = match (projection_cls.as_str(), use_hd_transform) {
                    ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
                    ("mlp", true) => {
                        let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    ("mlp", false) => {
                        let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
                        let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        a + b
                    }
                    _ => {
                        anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
                    }
                };

                let (glb_gn, sub_gn) = if with_learnable_separator {
                    let glb_gn = image_dim_out * 4;
                    let sub_gn = image_dim_out * 4;
                    (glb_gn, sub_gn)
                } else {
                    (0, 0)
                };

                let vision_transformer = {
                    let cfg = &PHI4_MM_VISION_CFG;

                    let post_layernorm = cfg.hidden_size;

                    let conv_config = Conv2dConfig {
                        stride: cfg.patch_size,
                        ..Default::default()
                    };
                    let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                        * cfg.patch_size
                        * cfg.patch_size;

                    let num_patches_per_side = cfg.image_size / cfg.patch_size;
                    let num_patches = num_patches_per_side.pow(2);
                    let position_embedding = num_patches * cfg.hidden_size;

                    let layer_elems = {
                        let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                        let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                        let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                        let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                        let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                        let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                        layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
                    };

                    post_layernorm
                        + patch_embedding
                        + position_embedding
                        + layer_elems * cfg.num_hidden_layers
                };

                proj + glb_gn + sub_gn + vision_transformer
            } else {
                0
            };

            embed_tokens + lm_head + norm + image_embed
        };

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let head_dim = cfg.head_dim();
            let op_size =
                cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
            let qkv_proj = size_in * op_size / weight_pack_factor;
            let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
            let down_proj = h_size * i_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + qkv_proj
                + o_proj
                + gate_up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Phi4MMConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads(),
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
}

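/// Loader for Qwen2.5-VL models: parses [`Qwen2_5VLConfig`], constructs a
/// [`Qwen2_5VLModel`], and processes inputs with [`Qwen2_5VLProcessor`]. Like
/// Qwen2-VL, this family does not support paged attention here.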
pub struct Qwen2_5VLLoader;

pub struct Qwen2_5VLPrefixer;

impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            format!(
                "{}{}{}",
                Qwen2_5VLProcessor::VISION_START,
                Qwen2_5VLProcessor::IMAGE_PAD,
                Qwen2_5VLProcessor::VISION_END
            )
            .repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for Qwen2_5VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen2_5VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen2_5VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen2_5VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Qwen2_5VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen2_5VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };
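        // `grid_t` above already scales with `max_num_images`, so multiplying by
        // `max_num_images` again below double-counts images; the result is a
        // conservative over-estimate, which is safe for an upper bound.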
        let img_seq_len = img_seq_len * max_num_images;

        let max_text_attn = {
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = max_num_images / cfg.temporal_patch_size;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let text_elems = {
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

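/// Loader for Gemma 3 models. [`Gemma3Config`] is an enum with `Text` and
/// `WithVision` variants, so this loader serves both text-only and multimodal
/// checkpoints; the device-mapping estimates below branch on that distinction.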
pub struct Gemma3Loader;

pub struct Gemma3Prefixer;

impl MultimodalPromptPrefixer for Gemma3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

3587impl VisionModelLoader for Gemma3Loader {
3588 fn load(
3589 &self,
3590 config: &str,
3591 vb: ShardedVarBuilder,
3592 normal_loading_metadata: NormalLoadingMetadata,
3593 attention_mechanism: AttentionImplementation,
3594 ) -> Result<Box<dyn VisionModel + Send + Sync>> {
3595 let cfg: Gemma3Config = serde_json::from_str(config)?;
3596 Ok(Box::new(Gemma3Model::new(
3597 &cfg,
3598 vb,
3599 self.is_gptx(config),
3600 normal_loading_metadata,
3601 attention_mechanism,
3602 )?))
3603 }
3604 fn is_gptx(&self, _config: &str) -> bool {
3605 true
3606 }
3607 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3608 let config: Gemma3Config = serde_json::from_str(config)?;
3609 Ok(Box::new(config))
3610 }
3611 fn get_processor(
3612 &self,
3613 config: &str,
3614 processor_config: Option<ProcessorConfig>,
3615 _preprocessor_config: PreProcessorConfig,
3616 _max_edge: Option<u32>,
3617 ) -> Arc<dyn Processor + Send + Sync> {
3618 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3619 Arc::new(Gemma3Processor::new(
3621 processor_config.unwrap_or_default(),
3622 matches!(config, Gemma3Config::WithVision { .. }),
3623 ))
3624 }
3625 fn supports_paged_attention(&self, _config: &str) -> bool {
3626 true
3627 }
3628 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3629 true
3630 }
3631 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3632 Arc::new(Gemma3Prefixer)
3633 }
3634 fn modalities(&self, _config: &str) -> Result<Modalities> {
3635 Ok(Modalities {
3636 input: vec![SupportedModality::Text, SupportedModality::Vision],
3637 output: vec![SupportedModality::Text],
3638 })
3639 }
3640}

impl IsqModelLoader for Gemma3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

impl DeviceMappedModelLoader for Gemma3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::Text(text_config) => Ok(max_batch_size
                * text_config.num_attention_heads
                * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
            Gemma3Config::WithVision {
                text_config,
                vision_config,
                ..
            } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = (num_patches + 1) * max_num_images;

                let max_text_attn = {
                    let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
                    max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
                };
                Ok(max_text_attn)
            }
        }
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3Config = serde_json::from_str(config)?;

        match cfg {
            Gemma3Config::WithVision { vision_config, .. } => {
                let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
                let img_seq_len = num_patches + 1;

                let max_vision_attn = {
                    (max_batch_size * max_num_images)
                        * vision_config.num_attention_heads
                        * img_seq_len
                        * img_seq_len
                };

                Ok(max_vision_attn)
            }
            Gemma3Config::Text(_) => Ok(0),
        }
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = match &cfg {
                Gemma3Config::Text(cfg) => cfg,
                Gemma3Config::WithVision { text_config, .. } => text_config,
            };
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_transformer = if let Gemma3Config::WithVision {
            vision_config: cfg, ..
        } = &cfg
        {
            let post_layernorm = cfg.hidden_size;

            let conv_config = Conv2dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
                * cfg.patch_size
                * cfg.patch_size;

            let num_patches_per_side = cfg.image_size / cfg.patch_size;
            let num_patches = num_patches_per_side.pow(2);
            let position_embedding = num_patches * cfg.hidden_size;

            let layer_elems = {
                let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
                let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

                let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
                let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;

                let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
                let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

                layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
            };

            post_layernorm
                + patch_embedding
                + position_embedding
                + layer_elems * cfg.num_hidden_layers
        } else {
            0
        };

        let elems = text_elems + vision_transformer;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };
        let per_layer_elems = {
            let cfg = txt_cfg;

            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = cfg.head_dim * cfg.num_attention_heads;
            let size_kv = cfg.head_dim * cfg.num_key_value_heads;
            let q_proj =
                size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
            let k_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let v_proj =
                size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
            let o_proj =
                size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            txt_cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let txt_cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        Ok(txt_cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Gemma3Config = serde_json::from_str(config)?;

        let cfg = match &cfg {
            Gemma3Config::Text(cfg) => cfg,
            Gemma3Config::WithVision { text_config, .. } => text_config,
        };

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

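/// Loader for Mistral 3 multimodal models: a vision tower paired with a
/// Mistral text backbone.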
pub struct Mistral3Loader;

pub struct Mistral3Prefixer;

impl MultimodalPromptPrefixer for Mistral3Prefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Mistral3Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut cfg: crate::vision_models::mistral3::Mistral3Config =
            serde_json::from_str(config)?;
        cfg.propagate_quantization_config();
        Ok(Box::new(Mistral3Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Mistral3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Mistral3Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
}

#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
impl DeviceMappedModelLoader for Mistral3Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let vcfg = &cfg.vision_config;
        let tcfg = &cfg.text_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            // Images are scaled down (preserving aspect ratio) so neither side
            // exceeds this bound, then snapped to whole patches.
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
            let num_width_tokens = (width - 1) / vcfg.patch_size + 1;

            height = num_height_tokens * vcfg.patch_size;
            width = num_width_tokens * vcfg.patch_size;

            let num_height_tokens = height / vcfg.patch_size;
            let num_width_tokens = width / vcfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
        Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.vision_config;

        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (mut height, mut width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let img_seq_len = {
            let (max_height, max_width) = (1540, 1540);
            let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
            if ratio > 1. {
                height = (height as f64 / ratio).floor() as usize;
                width = (width as f64 / ratio).floor() as usize;
            }

            let num_height_tokens = (height - 1) / cfg.patch_size + 1;
            let num_width_tokens = (width - 1) / cfg.patch_size + 1;

            height = num_height_tokens * cfg.patch_size;
            width = num_width_tokens * cfg.patch_size;

            let num_height_tokens = height / cfg.patch_size;
            let num_width_tokens = width / cfg.patch_size;

            (num_width_tokens + 1) * num_height_tokens
        };

        Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;

        let text_elems = {
            let cfg = &cfg.text_config;

            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let patch_embed = {
                let conv_cfg = Conv2dConfig {
                    stride: cfg.patch_size,
                    ..Default::default()
                };
                // Conv2d patch embedding: a patch_size x patch_size kernel per
                // (in-channel / group, out-channel) pair.
                cfg.num_channels * cfg.hidden_size / conv_cfg.groups
                    * cfg.patch_size
                    * cfg.patch_size
            };
            let ln_pre = cfg.hidden_size;
            let vision_layer = {
                let attn_norm = cfg.hidden_size;
                let ffn_norm = cfg.hidden_size;

                let gate = cfg.hidden_size * cfg.intermediate_size;
                let up = cfg.hidden_size * cfg.intermediate_size;
                let down = cfg.hidden_size * cfg.intermediate_size;

                let q = cfg.hidden_size * cfg.hidden_size;
                let k = cfg.hidden_size * cfg.hidden_size;
                let v = cfg.hidden_size * cfg.hidden_size;
                let o = cfg.hidden_size * cfg.hidden_size;

                attn_norm + ffn_norm + gate + up + down + q + k + v + o
            };

            patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let per_layer_elems = {
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Mistral3Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim(),
            v_head_dim: cfg.head_dim(),
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

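/// Loader for Llama 4 vision models. The text backbone interleaves dense MLP
/// and mixture-of-experts layers (see `moe_layers()` in the sizing code below).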
pub struct VLlama4Loader;

pub struct VLlama4Prefixer;

impl MultimodalPromptPrefixer for VLlama4Prefixer {
    fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
        format!(
            "{}{prompt}",
            llama4::IMAGE_TOKEN.repeat(image_indexes.len())
        )
    }
}

impl VisionModelLoader for VLlama4Loader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        cfg.propagate_quantization_config();
        Ok(Box::new(Llama4Model::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        false
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
        cfg.propagate_quantization_config();
        Ok(Box::new(cfg))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Llama4Processor::new(&processor_config.unwrap()))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(VLlama4Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for VLlama4Loader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.down_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
}

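// Llama 4's image token count depends on the processor's chunking, so the
// activation estimates below size it by running the real image processor on
// dummy RGB images of the requested shape.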
impl VLlama4Loader {
    #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
    fn run_dummy_processing(
        &self,
        cfg: &Llama4Config,
        height: usize,
        width: usize,
        max_num_images: usize,
        max_batch_size: usize,
    ) -> Result<(usize, usize)> {
        let cfg = &cfg.vision_config;

        let img_processor =
            Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
        let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
        let res = img_processor.preprocess(
            vec![image; max_num_images],
            vec![],
            &PreProcessorConfig::default(),
            &Device::Cpu,
            (max_batch_size, max_num_images),
        )?;

        let pixels_batch_size = res.pixel_values.dim(0)?;
        let pixels_max_batch_size = pixels_batch_size * max_batch_size;

        let (image_h, image_w) = (
            res.pixel_values.dim(D::Minus2).unwrap(),
            res.pixel_values.dim(D::Minus1).unwrap(),
        );
        let num_patches_per_chunk = (image_h / img_processor.patch_size)
            * (image_w / img_processor.patch_size)
            / img_processor.downsample_ratio;

        Ok((
            pixels_max_batch_size,
            num_patches_per_chunk * pixels_max_batch_size,
        ))
    }
}

impl DeviceMappedModelLoader for VLlama4Loader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (_pixels_batch_size, num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;

        let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;

        Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
    }
    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: (height, width),
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Llama4Config = serde_json::from_str(config)?;

        let (pixels_batch_size, _num_text_image_toks) =
            self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
        let max_seq_len = cfg.vision_config.num_patches();

        Ok((max_batch_size * pixels_batch_size)
            * cfg.vision_config.num_attention_heads
            * max_seq_len
            * max_seq_len)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let text_elems = {
            let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
            let lm_head = if !tcfg.tie_word_embeddings {
                tcfg.hidden_size * tcfg.vocab_size
            } else {
                0
            };
            let norm = tcfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let vision_elems = {
            let cfg = &cfg.vision_config;

            let num_patches = cfg.num_patches();

            let unfold_elems =
                (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
            let class_embedding_elems = cfg.hidden_size;
            let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
            let layernorm_pre_elems = cfg.hidden_size;
            let layernorm_post_elems = cfg.hidden_size;

            let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
                / weight_pack_factor
                + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;

            let encoder_layer = {
                let input_layernorm = cfg.hidden_size + cfg.hidden_size;
                let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;

                let head_dim = cfg.hidden_size / cfg.num_attention_heads;
                let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;
                let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
                    / weight_pack_factor
                    + cfg.num_attention_heads * head_dim;

                let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
                    + cfg.intermediate_size;
                let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
                    + cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + fc1
                    + fc2
            };

            unfold_elems
                + class_embedding_elems
                + positional_embedding_vlm_elems
                + layernorm_post_elems
                + layernorm_pre_elems
                + pixel_shuffle_elems
                + encoder_layer * cfg.num_hidden_layers
        };

        let elems = text_elems + vision_elems;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let tcfg = &cfg.text_config;

        let mut per_layer_elems = Vec::new();

        for layer_idx in 0..tcfg.num_hidden_layers {
            let input_layernorm = tcfg.hidden_size;
            let post_attention_layernorm = tcfg.hidden_size;

            let size_in = tcfg.hidden_size;
            let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
            let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor;
            let k_proj = size_in * size_kv / weight_pack_factor;
            let v_proj = size_in * size_kv / weight_pack_factor;
            let o_proj = size_q * size_in / weight_pack_factor;

            let use_moe = tcfg.moe_layers().contains(&layer_idx);
            let moe_block = if use_moe {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size;
                let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
                let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            } else {
                let h_size = tcfg.hidden_size;
                let i_size = tcfg.intermediate_size_mlp;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;

                gate_proj + up_proj + down_proj
            };

            per_layer_elems.push(
                input_layernorm
                    + post_attention_layernorm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + moe_block,
            );
        }

        Ok(per_layer_elems
            .into_iter()
            .map(|x| x * dtype.size_in_bytes())
            .collect())
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        Ok(cfg.text_config.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Llama4Config = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: None,
            k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

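/// Loader for Gemma 3n multimodal models, which pair the Gemma text backbone
/// with a MobileNet-style vision tower and a conformer audio encoder.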
pub struct Gemma3nLoader;

#[allow(dead_code)]
pub struct Gemma3nPrefixer;

impl MultimodalPromptPrefixer for Gemma3nPrefixer {
    fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
        prompt.to_string()
    }
}

impl VisionModelLoader for Gemma3nLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(Box::new(Gemma3nModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Gemma3nConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _config: &str,
        processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        _max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Gemma3nProcessor::new(
            processor_config.unwrap_or_default(),
            true,
        ))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        false
    }
    fn supports_prefix_cacher(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Gemma3Prefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![
                SupportedModality::Text,
                SupportedModality::Vision,
                SupportedModality::Audio,
            ],
            output: vec![SupportedModality::Text],
        })
    }
}

impl IsqModelLoader for Gemma3nLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
            Regex::new(
                r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
            )?,
            Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
            Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
            Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
            Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
            Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
            )?,
            Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
            )?,
            Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
            Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
        ])
    }
}

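// Device-mapping estimates for Gemma 3n. Several constants below (the 16x16
// MSFA token grid, the 24x24 vision-tower attention grid, and the 1280-frame
// audio cap) are sizing heuristics rather than values read from the config.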
impl DeviceMappedModelLoader for Gemma3nLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3nConfig = serde_json::from_str(config)?;
        let text_cfg = &cfg.text_config;

        let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);

        {
            // Vision soft tokens: the MSFA feature map is a 16x16 grid per image.
            let msfa_spatial_size = 16;
            let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size;
            total_seq_len += vision_tokens_per_image * max_num_images;
        }

        {
            // Audio soft tokens.
            let audio_tokens = cfg.audio_soft_tokens_per_image;
            total_seq_len += audio_tokens;
        }

        let max_text_attn =
            max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape: _,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let mut max_activation = 0;

        {
            let vision_tower_act = {
                // Attention inside the vision tower, estimated at a 24x24
                // spatial resolution with 16 heads.
                let num_heads = 16;
                let spatial_size = 24;
                let seq_len = spatial_size * spatial_size;

                max_batch_size * max_num_images * num_heads * seq_len * seq_len
            };

            let vision_embed_act = {
                // MSFA output features (2048 channels on a 16x16 grid), versus
                // the same tokens projected into the text hidden size.
                let msfa_channels = 2048;
                let spatial_size = 16;
                let vision_features =
                    max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;

                let projected = max_batch_size
                    * max_num_images
                    * spatial_size
                    * spatial_size
                    * cfg.text_config.hidden_size;

                vision_features.max(projected)
            };

            max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
        }

        {
            let audio_cfg = &cfg.audio_config;

            let max_audio_frames = 1280;

            let subsample_factor: usize = audio_cfg
                .sscp_conv_stride_size
                .iter()
                .map(|stride| stride[0])
                .product();
            let audio_seq_after_subsample = max_audio_frames / subsample_factor;

            let audio_encoder_act = {
                let intermediate_size = audio_cfg.hidden_size * 4;
                max_batch_size * audio_seq_after_subsample * intermediate_size
            };

            let audio_attn_act = {
                let chunk_size = audio_cfg.conf_attention_chunk_size;
                let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
                    + audio_cfg.conf_attention_context_right;

                let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);

                max_batch_size
                    * audio_cfg.conf_num_attention_heads
                    * num_chunks
                    * chunk_size
                    * context_size
            };

            max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
        }

        Ok(max_activation)
    }

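    // Weight-size estimate for everything outside the mapped text layers. The
    // text config is first adjusted for any Matformer slicing; the vision
    // tower and audio encoder are then summed from their architecture
    // definitions.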
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let text_cfg = if let Some(matformer_cfg) = matformer_config {
            use crate::device_map::DummyDeviceMapper;
            use crate::vision_models::gemma3n::text::handle_matformer_slicing;

            let dummy_mapper = DummyDeviceMapper {
                nm_device: Device::Cpu,
            };
            let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
                &cfg.text_config,
                &Some(matformer_cfg.clone()),
                &dummy_mapper,
            )?;
            adjusted_cfg
        } else {
            cfg.text_config.clone()
        };

        let text_cfg = &text_cfg;

        let text_elems = {
            let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
            let embed_tokens_per_layer = text_cfg.num_hidden_layers
                * text_cfg.hidden_size_per_layer_input
                * text_cfg.vocab_size_per_layer_input;

            let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
                text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
            } else {
                0
            };

            let norm = text_cfg.hidden_size;

            let altup_projections =
                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
                    / weight_pack_factor;
            let altup_unembed_projections =
                (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
                    / weight_pack_factor;

            let per_layer_model_projection = text_cfg.num_hidden_layers
                * text_cfg.hidden_size
                * text_cfg.hidden_size_per_layer_input
                / weight_pack_factor;
            let per_layer_projection_norm = text_cfg.hidden_size;

            embed_tokens
                + embed_tokens_per_layer
                + lm_head
                + norm
                + altup_projections
                + altup_unembed_projections
                + per_layer_model_projection
                + per_layer_projection_norm
        };

        let vision_elems = {
            let vision_cfg = &cfg.vision_config;
            let vision_tower_elems = {
                use crate::vision_models::gemma3n::vision::{
                    gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
                    MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
                    STEM_OUT_CHANNELS,
                };

                let stem_conv =
                    INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
                let stem_norm = STEM_OUT_CHANNELS;

                let mut in_chs = STEM_OUT_CHANNELS;
                let mut total_elems = stem_conv + stem_norm;

                let block_defs = gemma3n_mobilenet_def();

                for stage_blocks in block_defs.iter() {
                    for block_type in stage_blocks.iter() {
                        match block_type {
                            BlockType::EdgeResidual {
                                out_channels,
                                kernel_size,
                                stride: _,
                                expand_ratio,
                                ..
                            } => {
                                #[allow(clippy::cast_precision_loss)]
                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                total_elems += in_chs * mid_chs * kernel_size * kernel_size; // expansion conv
                                total_elems += mid_chs; // norm
                                total_elems += mid_chs * out_channels; // pointwise conv
                                total_elems += out_channels; // norm
                                in_chs = *out_channels;
                            }
                            BlockType::UniversalInvertedResidual {
                                out_channels,
                                start_kernel_size,
                                mid_kernel_size,
                                stride: _,
                                expand_ratio,
                                ..
                            } => {
                                #[allow(clippy::cast_precision_loss)]
                                let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
                                if *expand_ratio != 1.0 {
                                    total_elems += in_chs * mid_chs; // pointwise expansion
                                    total_elems += mid_chs; // norm
                                }
                                if *start_kernel_size > 0 {
                                    total_elems += mid_chs * start_kernel_size * start_kernel_size; // depthwise
                                    total_elems += mid_chs; // norm
                                }
                                if *mid_kernel_size > 0 {
                                    total_elems += mid_chs * mid_kernel_size * mid_kernel_size; // depthwise
                                    total_elems += mid_chs; // norm
                                }
                                total_elems += mid_chs * out_channels; // pointwise projection
                                total_elems += out_channels; // norm
                                total_elems += out_channels; // layer scale
                                in_chs = *out_channels;
                            }
                            BlockType::MultiQueryAttention {
                                num_heads,
                                kv_dim,
                                kv_stride: _,
                                ..
                            } => {
                                let dw_kernel_size = 3;
                                total_elems += in_chs; // norm
                                total_elems += in_chs * num_heads * kv_dim; // query proj
                                total_elems += in_chs * kv_dim; // kv proj
                                total_elems += in_chs * dw_kernel_size * dw_kernel_size; // depthwise conv
                                total_elems += *kv_dim;
                                total_elems += 1;
                                total_elems += *kv_dim;
                                total_elems += num_heads * kv_dim * in_chs; // output proj
                                total_elems += in_chs; // layer scale
                            }
                        }
                    }
                }

                let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
                let msfa_out = MSFA_OUT_CHANNELS;
                #[allow(clippy::cast_precision_loss)]
                let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);

                total_elems += msfa_in * msfa_mid; // ffn expand
                total_elems += msfa_mid; // norm
                total_elems += msfa_mid * msfa_out; // ffn project
                total_elems += msfa_out; // norm
                total_elems += msfa_out; // final norm

                total_elems
            };

            let embed_vision_elems = {
                let embedding = vision_cfg.vocab_size * vision_cfg.hidden_size;

                let hard_norm = vision_cfg.hidden_size;
                let soft_norm = vision_cfg.hidden_size;

                let projection = vision_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;

                let post_norm = text_cfg.hidden_size;

                embedding + hard_norm + soft_norm + projection + post_norm
            };

            vision_tower_elems + embed_vision_elems
        };

        let audio_elems = {
            let audio_cfg = &cfg.audio_config;

            let subsample_conv_projection_elems = {
                let mut conv_elems = 0;

                let in_ch_0 = 1;
                let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
                let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
                conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];

                let in_ch_1 = out_ch_0;
                let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
                let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
                conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];

                let norm_0 = out_ch_0;
                let norm_1 = out_ch_1;

                // Track the frequency dimension through both strided convs to
                // size the input projection.
                let mut f_out = audio_cfg.input_feat_size;
                for i in 0..2 {
                    let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
                    let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
                    let pad_left = 1;
                    let pad_right = 1;
                    f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
                }
                let input_proj_in_features = out_ch_1 * f_out;
                let input_proj_linear =
                    input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;

                conv_elems + norm_0 + norm_1 + input_proj_linear
            };

            let conformer_elems = {
                let mut total = 0;

                for _ in 0..audio_cfg.conf_num_hidden_layers {
                    let attention_elems = {
                        let pre_attn_norm = audio_cfg.hidden_size;
                        let post_norm = audio_cfg.hidden_size;

                        let q_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let k_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let v_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let post =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;

                        let pos_proj =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
                        let per_dim_scale =
                            audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads;
                        let inv_timescales = audio_cfg.hidden_size / 2;
                        let pos_indices = audio_cfg.conf_attention_context_left
                            + audio_cfg.conf_attention_context_right
                            + 1;

                        let chunk_size = audio_cfg.conf_attention_chunk_size;
                        let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
                            + audio_cfg.conf_attention_context_right;
                        let local_causal_valid_mask = chunk_size * context_size;
                        let invalid_logits_tensor = 1;

                        pre_attn_norm
                            + post_norm
                            + q_proj
                            + k_proj
                            + v_proj
                            + post
                            + pos_proj
                            + per_dim_scale
                            + inv_timescales
                            + pos_indices
                            + local_causal_valid_mask
                            + invalid_logits_tensor
                    };

                    let ffw_elems = {
                        let intermediate_size = audio_cfg.hidden_size * 4;

                        let ffw_start = {
                            let pre_norm = audio_cfg.hidden_size;
                            let layer_1 =
                                audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
                            let layer_2 =
                                intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
                            let post_norm = audio_cfg.hidden_size;
                            pre_norm + layer_1 + layer_2 + post_norm
                        };

                        let ffw_end = ffw_start;

                        ffw_start + ffw_end
                    };

                    let lconv1d_elems = {
                        let pre_layer_norm = audio_cfg.hidden_size;
                        let conv_norm = audio_cfg.hidden_size;

                        let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
                            / weight_pack_factor;
                        let linear_end =
                            audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;

                        let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;

                        pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
                    };

                    let block_norm = audio_cfg.hidden_size;

                    total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
                }

                total
            };

            let embed_audio_elems = {
                let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;

                let hard_embedding_norm = audio_cfg.hidden_size;
                let soft_embedding_norm = audio_cfg.hidden_size;
                let embedding_post_projection_norm = text_cfg.hidden_size;
                let embedding_projection =
                    audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;

                embedding
                    + hard_embedding_norm
                    + soft_embedding_norm
                    + embedding_post_projection_norm
                    + embedding_projection
            };

            subsample_conv_projection_elems + conformer_elems + embed_audio_elems
        };

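        // The vision tower contribution is counted in F32 whenever F16 is
        // requested; text and audio components use the requested dtype as-is.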
        let vision_dtype = if dtype == DType::F16 {
            DType::F32
        } else {
            dtype
        };

        let total_bytes = text_elems * dtype.size_in_bytes()
            + vision_elems * vision_dtype.size_in_bytes()
            + audio_elems * dtype.size_in_bytes();

        Ok(total_bytes)
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Gemma3nConfig = serde_json::from_str(config)?;

        let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
            matformer_config
        {
            use crate::device_map::DummyDeviceMapper;
            use crate::vision_models::gemma3n::text::handle_matformer_slicing;

            let dummy_mapper = DummyDeviceMapper {
                nm_device: Device::Cpu,
            };
            let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
                &cfg.text_config,
                &Some(matformer_cfg.clone()),
                &dummy_mapper,
            )?;
            (adjusted_cfg, layer_rename_map, layers_skipped)
        } else {
            (cfg.text_config.clone(), None, None)
        };

        let text_cfg = &text_cfg;

        let mut layer_sizes = Vec::new();

        for layer_idx in 0..text_cfg.num_hidden_layers {
            let per_layer_elems = {
                let input_layernorm = text_cfg.hidden_size;
                let post_attention_layernorm = text_cfg.hidden_size;
                let pre_feedforward_layernorm = text_cfg.hidden_size;
                let post_feedforward_layernorm = text_cfg.hidden_size;
                let post_per_layer_input_norm = text_cfg.hidden_size;

                let size_in = text_cfg.hidden_size;
                let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
                let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;

                let q_proj = size_in * size_q / weight_pack_factor;
                let k_proj = size_in * size_kv / weight_pack_factor;
                let v_proj = size_in * size_kv / weight_pack_factor;
                let o_proj = size_q * size_in / weight_pack_factor;

                let q_norm = text_cfg.head_dim;
                let k_norm = text_cfg.head_dim;
                let v_norm = text_cfg.head_dim;

                let intermediate_size = match &text_cfg.intermediate_size {
                    IntermediateSize::Single(size) => *size,
                    IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
                    IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
                };
                let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
                let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;

                let altup_elems = {
                    let correct_output_scale = text_cfg.hidden_size;
                    let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
                    let prediction_coefs =
                        text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
                    let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
                    let router_norm = text_cfg.hidden_size;

                    correct_output_scale
                        + correction_coefs
                        + prediction_coefs
                        + modality_router
                        + router_norm
                };

                let laurel_elems = {
                    let left = text_cfg.hidden_size * text_cfg.laurel_rank;
                    let right = text_cfg.laurel_rank * text_cfg.hidden_size;
                    let post_norm = text_cfg.hidden_size;

                    left + right + post_norm
                };

                let per_layer_input_gate =
                    text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
                let per_layer_projection =
                    text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;

                input_layernorm
                    + post_attention_layernorm
                    + pre_feedforward_layernorm
                    + post_feedforward_layernorm
                    + post_per_layer_input_norm
                    + q_proj
                    + k_proj
                    + v_proj
                    + o_proj
                    + q_norm
                    + k_norm
                    + v_norm
                    + gate_proj
                    + up_proj
                    + down_proj
                    + altup_elems
                    + laurel_elems
                    + per_layer_input_gate
                    + per_layer_projection
            };

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

5488 fn num_layers(&self, config: &str) -> Result<usize> {
5489 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5490 Ok(cfg.text_config.num_hidden_layers)
5491 }
5492
5493 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5494 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5495 let cfg = cfg.text_config;
5496
5497 let cfg = ModelConfigMetadata {
5498 max_seq_len: cfg.max_position_embeddings,
5499 num_layers: cfg.num_hidden_layers,
5500 hidden_size: cfg.hidden_size,
5501 num_kv_heads: cfg.num_key_value_heads,
5502 num_attn_heads: cfg.num_attention_heads,
5503 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5505 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5506 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5507 };
5508
5509 Ok(Box::new(cfg))
5510 }
5511
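    // The vision and audio towers stay on the primary (non-mapped) device rather
    // than being split across the device map.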
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
}

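/// Loader for the dense Qwen3-VL vision-language model.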
pub struct Qwen3VLLoader;

pub struct Qwen3VLPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLPrefixer {}

impl VisionModelLoader for Qwen3VLLoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

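// ISQ (in-situ quantization) targets the lm_head plus every attention and MLP
// projection in the language model; the vision tower is not matched here.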
impl IsqModelLoader for Qwen3VLLoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLLoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

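        // Image tokens seen by the language model: one merged token per
        // spatial_merge_size x spatial_merge_size block of patches, per image.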
        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
            grid_t * grid_h * grid_w * max_num_images
        };

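        // Worst-case text attention scratch: batch x heads x seq^2, with the text
        // portion of the sequence capped at the attention chunk size.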
        let max_text_attn = {
            let cfg = &cfg.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

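        // The vision tower attends over the raw patch grid (before spatial merging),
        // so its per-image sequence length is longer than on the text side.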
        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
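            // A separate lm_head is counted unless the embeddings are tied and
            // unpacked; a packed (quantized) head cannot alias the embedding matrix.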
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

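        // Patch merger: concatenates spatial_merge_size^2 patch embeddings and
        // projects back to hidden_size via a two-layer MLP with a pre-norm (ln_q).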
        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

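        // Patch embedding is a 3D convolution over (temporal, height, width)
        // patches; only the kernel weights are counted (no bias).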
        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

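    // All dense text decoder layers are identical, so a single per-layer size is
    // replicated num_hidden_layers times.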
    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let per_layer_elems = {
            let cfg = &cfg.text_config;
            let input_layernorm = cfg.hidden_size;
            let post_attention_layernorm = cfg.hidden_size;

            let size_in = cfg.hidden_size;
            let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
            let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

            let h_size = cfg.hidden_size;
            let i_size = cfg.intermediate_size;
            let gate_proj = h_size * i_size / weight_pack_factor;
            let up_proj = h_size * i_size / weight_pack_factor;
            let down_proj = i_size * h_size / weight_pack_factor;

            input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + gate_proj
                + up_proj
                + down_proj
        };
        Ok(vec![
            per_layer_elems * dtype.size_in_bytes();
            cfg.text_config.num_hidden_layers
        ])
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}

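/// Loader for the mixture-of-experts (MoE) variant of Qwen3-VL.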
pub struct Qwen3VLMoELoader;

pub struct Qwen3VLMoEPrefixer;

impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {}

impl VisionModelLoader for Qwen3VLMoELoader {
    fn load(
        &self,
        config: &str,
        vb: ShardedVarBuilder,
        normal_loading_metadata: NormalLoadingMetadata,
        attention_mechanism: AttentionImplementation,
    ) -> Result<Box<dyn VisionModel + Send + Sync>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        Ok(Box::new(Qwen3VLMoEModel::new(
            &cfg,
            vb,
            self.is_gptx(config),
            normal_loading_metadata,
            attention_mechanism,
        )?))
    }
    fn is_gptx(&self, _config: &str) -> bool {
        true
    }
    fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
        let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        Ok(Box::new(config))
    }
    fn get_processor(
        &self,
        _model_config: &str,
        _processor_config: Option<ProcessorConfig>,
        _preprocessor_config: PreProcessorConfig,
        max_edge: Option<u32>,
    ) -> Arc<dyn Processor + Send + Sync> {
        Arc::new(Qwen3VLMoEProcessor::new(max_edge))
    }
    fn supports_paged_attention(&self, _config: &str) -> bool {
        true
    }
    fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
        Arc::new(Qwen3VLMoEPrefixer)
    }
    fn modalities(&self, _config: &str) -> Result<Modalities> {
        Ok(Modalities {
            input: vec![SupportedModality::Text, SupportedModality::Vision],
            output: vec![SupportedModality::Text],
        })
    }
}

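// In addition to the dense attention/MLP projections, ISQ covers the MoE router
// gate and each expert's MLP projections.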
impl IsqModelLoader for Qwen3VLMoELoader {
    fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
        Ok(vec![
            Regex::new(r"lm_head\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
            Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
            )?,
            Regex::new(
                r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
            )?,
        ])
    }
    fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
        self.isq_layer_regexes(config)
    }
}

impl DeviceMappedModelLoader for Qwen3VLMoELoader {
    fn mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
            let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
            grid_t * grid_h * grid_w * max_num_images
        };

        let max_text_attn = {
            let cfg = &cfg.text_config;
            let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
            max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
        };

        Ok(max_text_attn)
    }

    fn non_mapped_max_act_size_elems(
        &self,
        config: &str,
        params: &AutoDeviceMapParams,
    ) -> Result<usize> {
        let AutoDeviceMapParams::Vision {
            max_seq_len: _,
            max_batch_size,
            max_image_shape,
            max_num_images,
        } = params
        else {
            anyhow::bail!("Expected vision AutoDeviceMapParams for this model!")
        };

        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;

        let img_seq_len = {
            let cfg = &cfg.vision_config;
            let grid_t = 1;
            let grid_h = max_image_shape.0 / cfg.patch_size;
            let grid_w = max_image_shape.1 / cfg.patch_size;
            grid_t * grid_h * grid_w
        };

        let max_vision_attn = {
            let cfg = &cfg.vision_config;
            (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
        };

        Ok(max_vision_attn)
    }

    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let tie = cfg.tie_word_embeddings;
        let text_elems = {
            let cfg = &cfg.text_config;
            let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
            let lm_head = if !tie || weight_pack_factor != 1 {
                cfg.hidden_size * cfg.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = cfg.hidden_size;
            embed_tokens + lm_head + norm
        };

        let patch_merger = {
            let cfg = &cfg.vision_config;
            let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);

            let mlp0 = hidden_size * hidden_size + hidden_size;
            let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;

            let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            mlp0 + mlp2 + ln_q
        };

        let patch_embed = {
            let cfg = &cfg.vision_config;
            let conv_cfg = Conv3dConfig {
                stride: cfg.patch_size,
                ..Default::default()
            };
            let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
            cfg.in_chans * cfg.hidden_size / conv_cfg.groups
                * kernel_sizes[0]
                * kernel_sizes[1]
                * kernel_sizes[2]
        };

        let encoder_layer = {
            let cfg = &cfg.vision_config;
            let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
            let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);

            let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
            let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;

            let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
            let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;

            norm1 + norm2 + fc1 + fc2 + qkv + out
        };

        let elems =
            text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;

        Ok(elems * dtype.size_in_bytes())
    }

    fn layer_sizes_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<Vec<usize>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let text_cfg = &cfg.text_config;

        let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);

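        // Layer sizes differ here: sparse (MoE) layers carry per-expert MLPs while
        // the rest use a single dense MLP, so each layer is sized individually.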
        for layer_idx in 0..text_cfg.num_hidden_layers {
            let input_layernorm = text_cfg.hidden_size;
            let post_attention_layernorm = text_cfg.hidden_size;

            let size_in = text_cfg.hidden_size;
            let size_q = (text_cfg.hidden_size / text_cfg.num_attention_heads)
                * text_cfg.num_attention_heads;
            let size_kv = (text_cfg.hidden_size / text_cfg.num_attention_heads)
                * text_cfg.num_key_value_heads;
            let q_proj = size_in * size_q / weight_pack_factor + size_q;
            let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
            let o_proj = size_q * size_in / weight_pack_factor;

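            // A layer is sparse when it is not forced dense via mlp_only_layers and
            // it falls on the decoder_sparse_step cadence.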
            let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
                && (text_cfg.num_experts > 0
                    && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);

            let mlp_elems = if is_moe {
                let gate = text_cfg.hidden_size * text_cfg.num_experts;
                let per_expert = {
                    let h_size = text_cfg.hidden_size;
                    let i_size = text_cfg.moe_intermediate_size;
                    let gate_proj = h_size * i_size / weight_pack_factor;
                    let up_proj = h_size * i_size / weight_pack_factor;
                    let down_proj = i_size * h_size / weight_pack_factor;
                    gate_proj + up_proj + down_proj
                };
                gate + per_expert * text_cfg.num_experts
            } else {
                let h_size = text_cfg.hidden_size;
                let i_size = text_cfg.intermediate_size;
                let gate_proj = h_size * i_size / weight_pack_factor;
                let up_proj = h_size * i_size / weight_pack_factor;
                let down_proj = i_size * h_size / weight_pack_factor;
                gate_proj + up_proj + down_proj
            };

            let per_layer_elems = input_layernorm
                + post_attention_layernorm
                + q_proj
                + k_proj
                + v_proj
                + o_proj
                + mlp_elems;

            layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
        }

        Ok(layer_sizes)
    }

    fn num_layers(&self, config: &str) -> Result<usize> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;
        Ok(cfg.num_hidden_layers)
    }

    fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
        let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
        let cfg = &cfg.text_config;

        let cfg = ModelConfigMetadata {
            max_seq_len: cfg.max_position_embeddings,
            num_layers: cfg.num_hidden_layers,
            hidden_size: cfg.hidden_size,
            num_kv_heads: cfg.num_key_value_heads,
            num_attn_heads: cfg.num_attention_heads,
            sliding_window: cfg.sliding_window,
            k_head_dim: cfg.head_dim,
            v_head_dim: cfg.head_dim,
            kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
        };

        Ok(Box::new(cfg))
    }

    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision])
    }
}