1use std::any::Any;
2use std::sync::atomic::AtomicUsize;
3use std::sync::Arc;
4use std::{fmt::Debug, str::FromStr};
5
6use anyhow::Result;
7use hanzo_ml::{DType, Device, Tensor, D};
8use hanzo_nn::Conv2dConfig;
9use hanzo_quant::log::once_log_debug;
10use hanzo_quant::ShardedVarBuilder;
11use image::{ColorType, DynamicImage};
12use itertools::Itertools;
13
14#[cfg(feature = "pyo3_macros")]
15use pyo3::pyclass;
16
17use regex::Regex;
18use serde::Deserialize;
19
20use self::minicpmo::{MiniCpmOConfig, MiniCpmOModel, MiniCpmOProcessor};
21
22use super::{DeviceMappedModelLoader, NonMappedSubModel, NormalLoadingMetadata};
23use crate::amoe::AnyMoeBaseModelMixin;
24use crate::attention::ATTENTION_CHUNK_SIZE;
25use crate::device_map::DeviceMapper;
26use crate::layers::Conv3dConfig;
27use crate::matformer::MatformerSliceConfig;
28use crate::paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata};
29use crate::pipeline::isq::IsqModelLoader;
30use crate::pipeline::loaders::AutoDeviceMapParams;
31use crate::pipeline::text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata};
32use crate::pipeline::{
33 EitherCache, IsqModel, Modalities, MultimodalPromptPrefixer, Processor, ProcessorCreator,
34 SupportedModality,
35};
36use crate::speculative::SpeculativeTargetMixin;
37use crate::utils::varbuilder_utils::DeviceForLoadTensor;
38use crate::vision_models::clip::ClipConfig;
39use crate::vision_models::gemma3::config::Gemma3Config;
40use crate::vision_models::gemma3::{Gemma3Model, Gemma3Processor};
41use crate::vision_models::gemma3n::config::{Gemma3nConfig, IntermediateSize};
42use crate::vision_models::gemma3n::{Gemma3nModel, Gemma3nProcessor};
43use crate::vision_models::gemma4::config::Gemma4Config;
44use crate::vision_models::gemma4::{Gemma4Model, Gemma4Processor};
45use crate::vision_models::idefics2::{Config as Idefics2Config, Idefics2};
46use crate::vision_models::idefics2_input_processor::Idefics2Processor;
47use crate::vision_models::idefics3::{Idefics3Config, Idefics3Model, Idefics3Processor};
48use crate::vision_models::image_processor::ImagePreProcessor;
49use crate::vision_models::inputs_processor::Phi4MMProcessor;
50use crate::vision_models::llama4::{
51 self, Llama4Config, Llama4ImageProcessor, Llama4Model, Llama4Processor,
52};
53use crate::vision_models::llava::config::Config as LLaVAConfig;
54use crate::vision_models::llava15::Model as LLaVA;
55use crate::vision_models::llava_inputs_processor::{self, LLaVAProcessor};
56use crate::vision_models::llava_next::Model as LLaVANext;
57use crate::vision_models::llava_next_inputs_processor::{self, LLaVANextProcessor};
58use crate::vision_models::mistral3::{Mistral3Config, Mistral3Model, Mistral3Processor};
59use crate::vision_models::mllama::{MLlamaConfig, MLlamaModel, MLlamaProcessor};
60use crate::vision_models::phi3::{Config as Phi3Config, Model as Phi3, PHI3V_CLIP_CONFIG};
61use crate::vision_models::phi3_inputs_processor::Phi3Processor;
62use crate::vision_models::phi4::{Phi4MMConfig, Phi4MMModel, PHI4_MM_VISION_CFG};
63use crate::vision_models::preprocessor_config::PreProcessorConfig;
64use crate::vision_models::processor_config::ProcessorConfig;
65use crate::vision_models::qwen2_5_vl::{
66 Config as Qwen2_5VLConfig, Qwen2_5VLModel, Qwen2_5VLProcessor,
67};
68use crate::vision_models::qwen2vl::{Config as Qwen2VLConfig, Qwen2VLModel, Qwen2VLProcessor};
69use crate::vision_models::qwen3_5::{Config as Qwen3_5Config, Qwen3_5Model, Qwen3_5Processor};
70use crate::vision_models::qwen3_5_moe::{
71 Config as Qwen3_5MoeConfig, Qwen3_5MoeModel, Qwen3_5MoeProcessor,
72};
73use crate::vision_models::qwen3_vl::{Config as Qwen3VLConfig, Qwen3VLModel, Qwen3VLProcessor};
74use crate::vision_models::qwen3_vl_moe::{
75 Config as Qwen3VLMoEConfig, Qwen3VLMoEModel, Qwen3VLMoEProcessor,
76};
77use crate::vision_models::voxtral::config::VoxtralConfig;
78use crate::vision_models::voxtral::{VoxtralModel, VoxtralProcessor};
79use crate::vision_models::{minicpmo, phi4};
80
/// Runtime interface of a loaded multimodal model.
///
/// Implementors must also implement [`IsqModel`], [`AnyMoeBaseModelMixin`] and
/// [`SpeculativeTargetMixin`].
pub trait MultimodalModel: IsqModel + AnyMoeBaseModelMixin + SpeculativeTargetMixin {
    /// Run a forward pass and return the resulting tensor.
    ///
    /// `model_specific_args` is a type-erased, per-architecture payload; a default
    /// value can be obtained from [`Self::default_model_specific_args`].
    // NOTE(review): `metadata` presumably carries the paged-attention KV caches together
    // with their input metadata, and `pixel_values` is `None` for text-only steps —
    // confirm against the implementors.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        pixel_values: Option<Tensor>,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        model_specific_args: Box<dyn Any>, metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// Device associated with this model.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Exclusive access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// Maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// Metadata describing the model configuration.
    fn config(&self) -> &ModelConfigMetadata;
    /// Owned, shareable copy of [`Self::config`]; the default clones the metadata into an `Arc`.
    fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
        Arc::new(self.config().clone())
    }
    /// Default value for the `model_specific_args` parameter of [`Self::forward`].
    fn default_model_specific_args(&self, input_ids: &Tensor) -> Box<dyn Any>;
    /// Optional pair of shared counters; the default implementation returns `None`.
    // NOTE(review): presumably encoder-cache hit/miss counters — confirm with implementors.
    fn encoder_cache_counters(&self) -> Option<(Arc<AtomicUsize>, Arc<AtomicUsize>)> {
        None
    }
    /// Hook for clearing per-model state; the default implementation does nothing.
    fn reset_model_specific_state(&self) {}
}
113
114pub trait MultimodalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
115 fn load(
116 &self,
117 config: &str,
118 vb: ShardedVarBuilder,
119 normal_loading_metadata: NormalLoadingMetadata,
120 attention_mechanism: AttentionImplementation,
121 ) -> Result<Box<dyn MultimodalModel + Send + Sync>>;
122 fn is_gptx(&self, config: &str) -> bool;
123 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
124 fn get_processor(
125 &self,
126 model_config: &str,
127 processor_config: Option<ProcessorConfig>,
128 preprocessor_config: PreProcessorConfig,
129 max_edge: Option<u32>,
130 ) -> Arc<dyn Processor + Send + Sync>;
131 fn supports_paged_attention(&self, config: &str) -> bool;
132 fn supports_prefix_cacher(&self, _config: &str) -> bool {
133 false
135 }
136 fn modalities(&self, config: &str) -> Result<Modalities>;
137 fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer>;
138 fn default_chat_template(&self, _config: &str) -> Option<String> {
143 None
144 }
145 fn default_bos_eos(&self, _config: &str) -> Option<(String, String)> {
149 None
150 }
151 fn get_device_for_tensor(
152 &self,
153 config: &str,
154 _mapper: &dyn DeviceMapper,
155 loading_isq: bool,
156 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
157 if loading_isq {
158 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
159 } else {
160 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
161 let num_layers = self.model_config(config)?.num_layers();
162 let closure = move |name: String| {
163 if let Some(captures) = re.captures(&name) {
164 captures
165 .get(1)
166 .and_then(|m| m.as_str().parse::<usize>().ok())
167 .map(|l| l.min(num_layers))
168 .map(DeviceForLoadTensor::Idx)
169 .unwrap_or(DeviceForLoadTensor::Base)
170 } else {
171 DeviceForLoadTensor::Base
172 }
173 };
174
175 Ok(Arc::new(closure))
176 }
177 }
178}
179
/// Multimodal architecture selector.
///
/// Each variant (de)serializes under the lowercase name given by its
/// `#[serde(rename = ...)]` attribute.
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
pub enum MultimodalLoaderType {
    /// Serialized as `phi3v`.
    #[serde(rename = "phi3v")]
    Phi3V,
    /// Serialized as `idefics2`.
    #[serde(rename = "idefics2")]
    Idefics2,
    /// Serialized as `llava_next`.
    #[serde(rename = "llava_next")]
    LLaVANext,
    /// Serialized as `llava`.
    #[serde(rename = "llava")]
    LLaVA,
    /// Serialized as `vllama`.
    #[serde(rename = "vllama")]
    VLlama,
    /// Serialized as `qwen2vl`.
    #[serde(rename = "qwen2vl")]
    Qwen2VL,
    /// Serialized as `idefics3`.
    #[serde(rename = "idefics3")]
    Idefics3,
    /// Serialized as `minicpmo`.
    #[serde(rename = "minicpmo")]
    MiniCpmO,
    /// Serialized as `phi4mm`.
    #[serde(rename = "phi4mm")]
    Phi4MM,
    /// Serialized as `qwen2_5vl`.
    #[serde(rename = "qwen2_5vl")]
    Qwen2_5VL,
    /// Serialized as `gemma3`.
    #[serde(rename = "gemma3")]
    Gemma3,
    /// Serialized as `mistral3`.
    #[serde(rename = "mistral3")]
    Mistral3,
    /// Serialized as `llama4`.
    #[serde(rename = "llama4")]
    Llama4,
    /// Serialized as `gemma3n`.
    #[serde(rename = "gemma3n")]
    Gemma3n,
    /// Serialized as `qwen3vl`.
    #[serde(rename = "qwen3vl")]
    Qwen3VL,
    /// Serialized as `qwen3vlmoe`.
    #[serde(rename = "qwen3vlmoe")]
    Qwen3VLMoE,
    /// Serialized as `qwen3_5`.
    #[serde(rename = "qwen3_5")]
    Qwen3_5,
    /// Serialized as `qwen3_5moe`.
    #[serde(rename = "qwen3_5moe")]
    Qwen3_5Moe,
    /// Serialized as `voxtral`.
    #[serde(rename = "voxtral")]
    Voxtral,
    /// Serialized as `gemma4`.
    #[serde(rename = "gemma4")]
    Gemma4,
}
225
226impl MultimodalLoaderType {
228 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
229 match name {
230 "Phi3VForCausalLM" => Ok(Self::Phi3V),
231 "Idefics2ForConditionalGeneration" => Ok(Self::Idefics2),
232 "LlavaNextForConditionalGeneration" => Ok(Self::LLaVANext),
233 "LlavaForConditionalGeneration" => Ok(Self::LLaVA),
234 "MllamaForConditionalGeneration" => Ok(Self::VLlama),
235 "Qwen2VLForConditionalGeneration" => Ok(Self::Qwen2VL),
236 "Idefics3ForConditionalGeneration" => Ok(Self::Idefics3),
237 "MiniCPMO" => Ok(Self::MiniCpmO),
238 "Phi4MMForCausalLM" => Ok(Self::Phi4MM),
239 "Qwen2_5_VLForConditionalGeneration" => Ok(Self::Qwen2_5VL),
240 "Gemma3ForConditionalGeneration" | "Gemma3ForCausalLM" => Ok(Self::Gemma3),
241 "Mistral3ForConditionalGeneration" => Ok(Self::Mistral3),
242 "Llama4ForConditionalGeneration" => Ok(Self::Llama4),
243 "Gemma3nForConditionalGeneration" => Ok(Self::Gemma3n),
244 "Gemma4ForConditionalGeneration" => Ok(Self::Gemma4),
245 "Qwen3VLForConditionalGeneration" => Ok(Self::Qwen3VL),
246 "Qwen3VLMoeForConditionalGeneration" => Ok(Self::Qwen3VLMoE),
247 "Qwen3_5ForConditionalGeneration" => Ok(Self::Qwen3_5),
248 "Qwen3_5MoeForConditionalGeneration" => Ok(Self::Qwen3_5Moe),
249 "VoxtralForConditionalGeneration"
250 | "VoxtralRealtimeForConditionalGeneration" => Ok(Self::Voxtral),
251 other => anyhow::bail!(
252 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
253 ),
254 }
255 }
256}
257
258impl FromStr for MultimodalLoaderType {
259 type Err = String;
260 fn from_str(s: &str) -> Result<Self, Self::Err> {
261 match s {
262 "phi3v" => Ok(Self::Phi3V),
263 "idefics2" => Ok(Self::Idefics2),
264 "llava_next" => Ok(Self::LLaVANext),
265 "llava" => Ok(Self::LLaVA),
266 "vllama" => Ok(Self::VLlama),
267 "qwen2vl" => Ok(Self::Qwen2VL),
268 "idefics3" => Ok(Self::Idefics3),
269 "minicpmo" => Ok(Self::MiniCpmO),
270 "phi4mm" => Ok(Self::Phi4MM),
271 "qwen2_5vl" => Ok(Self::Qwen2_5VL),
272 "gemma3" => Ok(Self::Gemma3),
273 "mistral3" => Ok(Self::Mistral3),
274 "llama4" => Ok(Self::Llama4),
275 "gemma3n" => Ok(Self::Gemma3n),
276 "gemma4" => Ok(Self::Gemma4),
277 "qwen3vl" => Ok(Self::Qwen3VL),
278 "qwen3vlmoe" => Ok(Self::Qwen3VLMoE),
279 "qwen3_5" => Ok(Self::Qwen3_5),
280 "qwen3_5moe" => Ok(Self::Qwen3_5Moe),
281 "voxtral" => Ok(Self::Voxtral),
282 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `phi3v`, `idefics2`, `llava_next`, `llava`, `vllama`, `qwen2vl`, `idefics3`, `minicpmo`, `phi4mm`, `qwen2_5vl`, `gemma3`, `mistral3`, `llama4`, `gemma3n`, `gemma4`, `qwen3vl`, `qwen3vlmoe`, `qwen3_5`, `qwen3_5moe`, `voxtral`.")),
283 }
284 }
285}
286
287impl std::fmt::Display for MultimodalLoaderType {
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 let name = match self {
290 MultimodalLoaderType::Phi3V => "phi3v",
291 MultimodalLoaderType::Idefics2 => "idefics2",
292 MultimodalLoaderType::LLaVANext => "llava_next",
293 MultimodalLoaderType::LLaVA => "llava",
294 MultimodalLoaderType::VLlama => "vllama",
295 MultimodalLoaderType::Qwen2VL => "qwen2vl",
296 MultimodalLoaderType::Idefics3 => "idefics3",
297 MultimodalLoaderType::MiniCpmO => "minicpmo",
298 MultimodalLoaderType::Phi4MM => "phi4mm",
299 MultimodalLoaderType::Qwen2_5VL => "qwen2_5vl",
300 MultimodalLoaderType::Gemma3 => "gemma3",
301 MultimodalLoaderType::Mistral3 => "mistral3",
302 MultimodalLoaderType::Llama4 => "llama4",
303 MultimodalLoaderType::Gemma3n => "gemma3n",
304 MultimodalLoaderType::Qwen3VL => "qwen3vl",
305 MultimodalLoaderType::Qwen3VLMoE => "qwen3vlmoe",
306 MultimodalLoaderType::Qwen3_5 => "qwen3_5",
307 MultimodalLoaderType::Qwen3_5Moe => "qwen3_5moe",
308 MultimodalLoaderType::Voxtral => "voxtral",
309 MultimodalLoaderType::Gemma4 => "gemma4",
310 };
311 write!(f, "{name}")
312 }
313}
314
/// Minimal view of a model config holding only the keys used to choose a loader.
#[derive(Deserialize)]
struct AutoMultimodalLoaderConfig {
    /// Model class names listed in the config; empty when the key is absent.
    #[serde(default)]
    architectures: Vec<String>,
    /// Raw value of the `multimodal` key, `None` when absent. Kept as untyped JSON.
    #[serde(default)]
    multimodal: Option<serde_json::Value>,
}
323
324pub struct AutoMultimodalLoader;
326
327impl AutoMultimodalLoader {
328 fn get_loader(config: &str) -> Result<Box<dyn MultimodalModelLoader>> {
329 let auto_cfg: AutoMultimodalLoaderConfig = serde_json::from_str(config)?;
330
331 if auto_cfg.multimodal.is_some() && auto_cfg.architectures.is_empty() {
333 once_log_debug("Automatic loader type determined to be `voxtral`");
334 return Ok(Box::new(VoxtralLoader));
335 }
336
337 if auto_cfg.architectures.len() != 1 {
338 anyhow::bail!("Expected exactly one architecture in config");
339 }
340
341 let name = &auto_cfg.architectures[0];
342 let tp = MultimodalLoaderType::from_causal_lm_name(name)?;
343
344 once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
345
346 Ok(match tp {
348 MultimodalLoaderType::Phi3V => Box::new(Phi3VLoader),
349 MultimodalLoaderType::Idefics2 => Box::new(Idefics2Loader),
350 MultimodalLoaderType::LLaVANext => Box::new(LLaVANextLoader),
351 MultimodalLoaderType::LLaVA => Box::new(LLaVALoader),
352 MultimodalLoaderType::VLlama => Box::new(VLlamaLoader),
353 MultimodalLoaderType::Qwen2VL => Box::new(Qwen2VLLoader),
354 MultimodalLoaderType::Idefics3 => Box::new(Idefics3Loader),
355 MultimodalLoaderType::MiniCpmO => Box::new(MiniCpmOLoader),
356 MultimodalLoaderType::Phi4MM => Box::new(Phi4MMLoader),
357 MultimodalLoaderType::Qwen2_5VL => Box::new(Qwen2_5VLLoader),
358 MultimodalLoaderType::Gemma3 => Box::new(Gemma3Loader),
359 MultimodalLoaderType::Mistral3 => Box::new(Mistral3Loader),
360 MultimodalLoaderType::Llama4 => Box::new(VLlama4Loader),
361 MultimodalLoaderType::Gemma3n => Box::new(Gemma3nLoader),
362 MultimodalLoaderType::Qwen3VL => Box::new(Qwen3VLLoader),
363 MultimodalLoaderType::Qwen3VLMoE => Box::new(Qwen3VLMoELoader),
364 MultimodalLoaderType::Qwen3_5 => Box::new(Qwen3_5Loader),
365 MultimodalLoaderType::Qwen3_5Moe => Box::new(Qwen3_5MoeLoader),
366 MultimodalLoaderType::Voxtral => Box::new(VoxtralLoader),
367 MultimodalLoaderType::Gemma4 => Box::new(Gemma4Loader),
368 })
369 }
370}
371
372impl MultimodalModelLoader for AutoMultimodalLoader {
373 fn load(
374 &self,
375 config: &str,
376 vb: ShardedVarBuilder,
377 normal_loading_metadata: NormalLoadingMetadata,
378 attention_mechanism: AttentionImplementation,
379 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
380 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
381 }
382
383 fn is_gptx(&self, config: &str) -> bool {
384 Self::get_loader(config)
385 .expect("AutoMultimodalLoader get_loader")
386 .is_gptx(config)
387 }
388
389 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
390 Self::get_loader(config)?.get_config_repr(config)
391 }
392
393 fn get_processor(
394 &self,
395 model_config: &str,
396 proc_cfg: Option<ProcessorConfig>,
397 preproc_cfg: PreProcessorConfig,
398 max_edge: Option<u32>,
399 ) -> Arc<dyn Processor + Send + Sync> {
400 Self::get_loader(model_config)
401 .expect("AutoMultimodalLoader get_loader")
402 .get_processor(model_config, proc_cfg, preproc_cfg, max_edge)
403 }
404
405 fn supports_paged_attention(&self, config: &str) -> bool {
406 Self::get_loader(config)
407 .expect("AutoMultimodalLoader")
408 .supports_paged_attention(config)
409 }
410
411 fn modalities(&self, config: &str) -> Result<Modalities> {
412 Self::get_loader(config)?.modalities(config)
413 }
414
415 fn supports_prefix_cacher(&self, config: &str) -> bool {
416 Self::get_loader(config)
417 .expect("AutoMultimodalLoader")
418 .supports_prefix_cacher(config)
419 }
420
421 fn prefixer(&self, config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
422 Self::get_loader(config)
423 .expect("AutoMultimodalLoader")
424 .prefixer(config)
425 }
426
427 fn default_chat_template(&self, config: &str) -> Option<String> {
428 Self::get_loader(config).ok()?.default_chat_template(config)
429 }
430
431 fn default_bos_eos(&self, config: &str) -> Option<(String, String)> {
432 Self::get_loader(config).ok()?.default_bos_eos(config)
433 }
434
435 fn get_device_for_tensor(
436 &self,
437 config: &str,
438 mapper: &dyn DeviceMapper,
439 loading_isq: bool,
440 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
441 Self::get_loader(config)?.get_device_for_tensor(config, mapper, loading_isq)
442 }
443}
444
445impl IsqModelLoader for AutoMultimodalLoader {
446 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
447 Self::get_loader(config)?.isq_layer_regexes(config)
448 }
449 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
450 Self::get_loader(config)?.immediate_isq_predicates(config)
451 }
452 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
453 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
454 }
455 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
456 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
457 }
458}
459
460impl DeviceMappedModelLoader for AutoMultimodalLoader {
461 fn mapped_max_act_size_elems(
462 &self,
463 config: &str,
464 params: &AutoDeviceMapParams,
465 ) -> Result<usize> {
466 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
467 }
468 fn non_mapped_max_act_size_elems(
469 &self,
470 config: &str,
471 params: &AutoDeviceMapParams,
472 ) -> Result<usize> {
473 Self::get_loader(config)?.non_mapped_max_act_size_elems(config, params)
474 }
475 fn non_mapped_size_in_bytes(
476 &self,
477 config: &str,
478 dtype: DType,
479 weight_pack_factor: usize,
480 _matformer_config: Option<&MatformerSliceConfig>,
481 ) -> Result<usize> {
482 Self::get_loader(config)?.non_mapped_size_in_bytes(
483 config,
484 dtype,
485 weight_pack_factor,
486 _matformer_config,
487 )
488 }
489 fn layer_sizes_in_bytes(
490 &self,
491 config: &str,
492 dtype: DType,
493 weight_pack_factor: usize,
494 _matformer_config: Option<&MatformerSliceConfig>,
495 ) -> Result<Vec<usize>> {
496 Self::get_loader(config)?.layer_sizes_in_bytes(
497 config,
498 dtype,
499 weight_pack_factor,
500 _matformer_config,
501 )
502 }
503 fn num_layers(&self, config: &str) -> Result<usize> {
504 Self::get_loader(config)?.num_layers(config)
505 }
506 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
507 Self::get_loader(config)?.model_config(config)
508 }
509}
510
// Expands to `$size` when `$cond` is true and to `0` otherwise.
// `$size` is only evaluated in the true case.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        match $cond {
            true => $size,
            false => 0,
        }
    };
}
520
521fn get_clip_vit_num_elems(cfg: &ClipConfig) -> usize {
522 let pre_layer_norm = cfg.hidden_size;
523 let final_layer_norm = cfg.hidden_size;
524
525 let num_patches = (cfg.image_size / cfg.patch_size).pow(2);
526 let num_positions = num_patches + 1;
527
528 let class_embedding = cfg.hidden_size;
529
530 let position_ids = num_positions;
531 let position_embedding = num_positions * cfg.hidden_size;
532
533 let conv2dconfig = Conv2dConfig {
534 stride: cfg.patch_size,
535 ..Default::default()
536 };
537 let patch_embedding =
538 cfg.num_channels * cfg.hidden_size / conv2dconfig.groups * cfg.patch_size * cfg.patch_size;
539
540 let encoder_layer_elems = {
541 let layer_norm1 = cfg.hidden_size;
542 let layer_norm2 = cfg.hidden_size;
543
544 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
545 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
546 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
547 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
548
549 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
550 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
551
552 layer_norm1 + layer_norm2 + q_proj + k_proj + v_proj + o_proj + fc1 + fc2
553 };
554
555 pre_layer_norm
556 + final_layer_norm
557 + class_embedding
558 + position_ids
559 + position_embedding
560 + patch_embedding
561 + cfg.num_hidden_layers * encoder_layer_elems
562}
563
/// [`MultimodalModelLoader`] for the Phi-3 vision architecture (`phi3v`).
pub struct Phi3VLoader;
570
571pub struct Phi3VPrefixer;
572
573impl MultimodalPromptPrefixer for Phi3VPrefixer {
574 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
575 format!(
577 "{}{prompt}",
578 image_indexes
579 .into_iter()
580 .map(|image_index| format!("<|image_{}|>", image_index + 1))
581 .join("")
582 )
583 }
584}
585
586impl MultimodalModelLoader for Phi3VLoader {
587 fn load(
588 &self,
589 config: &str,
590 vb: ShardedVarBuilder,
591 normal_loading_metadata: NormalLoadingMetadata,
592 attention_mechanism: AttentionImplementation,
593 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
594 let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
595 Ok(Box::new(Phi3::new(
596 &cfg,
597 vb,
598 self.is_gptx(config),
599 normal_loading_metadata,
600 attention_mechanism,
601 )?))
602 }
603 fn is_gptx(&self, _config: &str) -> bool {
604 true
605 }
606 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
607 let cfg: crate::vision_models::phi3::Config = serde_json::from_str(config)?;
608 Ok(Box::new(cfg))
609 }
610 fn get_processor(
611 &self,
612 _model_config: &str,
613 processor_config: Option<ProcessorConfig>,
614 preprocessor_config: PreProcessorConfig,
615 _max_edge: Option<u32>,
616 ) -> Arc<dyn Processor + Send + Sync> {
617 Phi3Processor::new_processor(processor_config, preprocessor_config)
618 }
619 fn supports_paged_attention(&self, _config: &str) -> bool {
620 true
621 }
622 fn supports_prefix_cacher(&self, _config: &str) -> bool {
623 true
624 }
625 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
626 Arc::new(Phi3VPrefixer)
627 }
628 fn modalities(&self, _config: &str) -> Result<Modalities> {
629 Ok(Modalities {
630 input: vec![SupportedModality::Text, SupportedModality::Vision],
631 output: vec![SupportedModality::Text],
632 })
633 }
634}
635
636impl IsqModelLoader for Phi3VLoader {
637 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
638 Ok(vec![
639 Regex::new(r"lm_head\.(weight|bias)$")?,
640 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
642 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
643 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
645 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
646 ])
647 }
648 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
649 self.isq_layer_regexes(config)
650 }
651}
652
653impl DeviceMappedModelLoader for Phi3VLoader {
654 fn mapped_max_act_size_elems(
655 &self,
656 config: &str,
657 params: &AutoDeviceMapParams,
658 ) -> Result<usize> {
659 let AutoDeviceMapParams::Multimodal {
661 max_seq_len,
662 max_batch_size,
663 max_image_shape: _,
664 max_num_images,
665 } = params
666 else {
667 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
668 };
669
670 let cfg: Phi3Config = serde_json::from_str(config)?;
671
672 let vcfg = &PHI3V_CLIP_CONFIG;
673
674 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
675 let img_seq_len = (num_patches + 1) * max_num_images;
676
677 let max_text_attn = {
678 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
680 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
681 };
682
683 Ok(max_text_attn)
684 }
685
686 fn non_mapped_max_act_size_elems(
687 &self,
688 config: &str,
689 params: &AutoDeviceMapParams,
690 ) -> Result<usize> {
691 let AutoDeviceMapParams::Multimodal {
693 max_seq_len: _,
694 max_batch_size,
695 max_image_shape: _,
696 max_num_images,
697 } = params
698 else {
699 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
700 };
701
702 let cfg: Phi3Config = serde_json::from_str(config)?;
703
704 let vcfg = &PHI3V_CLIP_CONFIG;
705
706 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
707 let img_seq_len = num_patches + 1;
708
709 let max_vision_attn = {
710 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
711 };
712
713 Ok(max_vision_attn)
714 }
715
716 fn non_mapped_size_in_bytes(
717 &self,
718 config: &str,
719 dtype: DType,
720 weight_pack_factor: usize,
721 _matformer_config: Option<&MatformerSliceConfig>,
722 ) -> Result<usize> {
723 let cfg: Phi3Config = serde_json::from_str(config)?;
724 let elems = {
725 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
726 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
728 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
729 } else {
730 0
731 };
732 let norm = cfg.hidden_size;
733
734 let image_embed = {
735 let projection_cls = cfg
736 .embd_layer
737 .projection_cls
738 .clone()
739 .unwrap_or("linear".to_string());
740 let with_learnable_separator =
741 cfg.embd_layer.with_learnable_separator.unwrap_or(false);
742 let use_hd_transform = cfg.embd_layer.use_hd_transform.unwrap_or(false);
743 let image_dim_out = cfg.img_processor.image_dim_out;
744
745 let proj = match (projection_cls.as_str(), use_hd_transform) {
746 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
747 ("mlp", true) => {
748 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
749 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
750 a + b
751 }
752 ("mlp", false) => {
753 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
754 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
755 a + b
756 }
757 _ => {
758 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
759 }
760 };
761
762 let (glb_gn, sub_gn) = if with_learnable_separator {
763 let glb_gn = image_dim_out * 4;
764 let sub_gn = image_dim_out * 4;
765 (glb_gn, sub_gn)
766 } else {
767 (0, 0)
768 };
769
770 let clip_vit = get_clip_vit_num_elems(&PHI3V_CLIP_CONFIG);
771
772 proj + glb_gn + sub_gn + clip_vit
773 };
774
775 embed_tokens + lm_head + norm + image_embed
776 };
777
778 Ok(elems * dtype.size_in_bytes())
779 }
780
781 fn layer_sizes_in_bytes(
782 &self,
783 config: &str,
784 dtype: DType,
785 weight_pack_factor: usize,
786 _matformer_config: Option<&MatformerSliceConfig>,
787 ) -> Result<Vec<usize>> {
788 let cfg: Phi3Config = serde_json::from_str(config)?;
789 let per_layer_elems = {
790 let input_layernorm = cfg.hidden_size;
791 let post_attention_layernorm = cfg.hidden_size;
792
793 let size_in = cfg.hidden_size;
794 let head_dim = cfg.head_dim();
795 let op_size =
796 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
797 let qkv_proj = size_in * op_size / weight_pack_factor;
798 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
799
800 let h_size = cfg.hidden_size;
801 let i_size = cfg.intermediate_size;
802 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
803 let down_proj = h_size * i_size / weight_pack_factor;
804
805 input_layernorm
806 + post_attention_layernorm
807 + qkv_proj
808 + o_proj
809 + gate_up_proj
810 + down_proj
811 };
812 Ok(vec![
813 per_layer_elems * dtype.size_in_bytes();
814 cfg.num_hidden_layers
815 ])
816 }
817
818 fn num_layers(&self, config: &str) -> Result<usize> {
819 let cfg: Phi3Config = serde_json::from_str(config)?;
820 Ok(cfg.num_hidden_layers)
821 }
822
823 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
824 let cfg: Phi3Config = serde_json::from_str(config)?;
825
826 let cfg = ModelConfigMetadata {
827 max_seq_len: cfg.max_position_embeddings,
828 num_layers: cfg.num_hidden_layers,
829 hidden_size: cfg.hidden_size,
830 num_kv_heads: cfg.num_key_value_heads,
831 num_attn_heads: cfg.num_attention_heads,
832 sliding_window: cfg.sliding_window,
833 k_head_dim: cfg.head_dim(),
834 v_head_dim: cfg.head_dim(),
835 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
836 };
837
838 Ok(Box::new(cfg))
839 }
840
841 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
842 Some(vec![NonMappedSubModel::Vision])
843 }
844}
845
/// [`MultimodalModelLoader`] for the Idefics 2 architecture (`idefics2`).
pub struct Idefics2Loader;
852
853pub struct Idefics2Prefixer;
854
855impl MultimodalPromptPrefixer for Idefics2Prefixer {
856 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
857 prompt.to_string()
859 }
860}
861
862impl MultimodalModelLoader for Idefics2Loader {
863 fn load(
864 &self,
865 config: &str,
866 vb: ShardedVarBuilder,
867 normal_loading_metadata: NormalLoadingMetadata,
868 attention_mechanism: AttentionImplementation,
869 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
870 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
871 Ok(Box::new(Idefics2::new(
872 &cfg,
873 vb,
874 self.is_gptx(config),
875 normal_loading_metadata,
876 attention_mechanism,
877 )?))
878 }
879 fn is_gptx(&self, _config: &str) -> bool {
880 true
881 }
882 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
883 let cfg: crate::vision_models::idefics2::Config = serde_json::from_str(config)?;
884 Ok(Box::new(cfg))
885 }
886 fn get_processor(
887 &self,
888 _model_config: &str,
889 processor_config: Option<ProcessorConfig>,
890 preprocessor_config: PreProcessorConfig,
891 max_edge: Option<u32>,
892 ) -> Arc<dyn Processor + Send + Sync> {
893 Arc::new(Idefics2Processor::new(
894 processor_config.unwrap(),
895 preprocessor_config,
896 max_edge,
897 ))
898 }
899 fn supports_paged_attention(&self, _config: &str) -> bool {
900 true
901 }
902 fn supports_prefix_cacher(&self, _config: &str) -> bool {
903 true
904 }
905 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
906 Arc::new(Idefics2Prefixer)
907 }
908 fn modalities(&self, _config: &str) -> Result<Modalities> {
909 Ok(Modalities {
910 input: vec![SupportedModality::Text, SupportedModality::Vision],
911 output: vec![SupportedModality::Text],
912 })
913 }
914}
915
916impl IsqModelLoader for Idefics2Loader {
917 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
918 Ok(vec![
919 Regex::new(r"lm_head\.(weight|bias)$")?,
920 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
922 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
923 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
924 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
925 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
927 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
928 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
929 ])
930 }
931 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
932 Ok(vec![
933 Regex::new(r"lm_head\.(weight|bias)$")?,
934 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
936 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
937 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
938 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
939 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
941 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
942 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
943 ])
944 }
945}
946
947impl DeviceMappedModelLoader for Idefics2Loader {
948 fn mapped_max_act_size_elems(
949 &self,
950 config: &str,
951 params: &AutoDeviceMapParams,
952 ) -> Result<usize> {
953 let AutoDeviceMapParams::Multimodal {
954 max_seq_len,
955 max_batch_size,
956 max_image_shape: _,
957 max_num_images,
958 } = params
959 else {
960 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
961 };
962
963 let cfg: Idefics2Config = serde_json::from_str(config)?;
964
965 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
966 let img_seq_len = (num_patches + 1) * max_num_images;
967
968 let max_text_attn = {
969 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
971 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
972 };
973
974 Ok(max_text_attn)
975 }
976
977 fn non_mapped_max_act_size_elems(
978 &self,
979 config: &str,
980 params: &AutoDeviceMapParams,
981 ) -> Result<usize> {
982 let AutoDeviceMapParams::Multimodal {
983 max_seq_len: _,
984 max_batch_size,
985 max_image_shape: _,
986 max_num_images,
987 } = params
988 else {
989 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
990 };
991
992 let cfg: Idefics2Config = serde_json::from_str(config)?;
993
994 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
995 let img_seq_len = num_patches + 1;
996
997 let max_vision_attn = {
998 let images_factor = 5;
1000
1001 (max_batch_size * images_factor * max_num_images)
1002 * cfg.vision_config.num_attention_heads
1003 * img_seq_len
1004 * img_seq_len
1005 };
1006
1007 Ok(max_vision_attn)
1008 }
1009
1010 fn non_mapped_size_in_bytes(
1011 &self,
1012 config: &str,
1013 dtype: DType,
1014 weight_pack_factor: usize,
1015 _matformer_config: Option<&MatformerSliceConfig>,
1016 ) -> Result<usize> {
1017 let cfg: Idefics2Config = serde_json::from_str(config)?;
1018 let text_elems = {
1019 let tie_word_embeddings = cfg.tie_word_embeddings;
1020 let cfg = &cfg.text_config;
1021
1022 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1023 let lm_head = if !tie_word_embeddings {
1024 cfg.hidden_size * cfg.vocab_size
1025 } else {
1026 0
1027 };
1028 let norm = cfg.hidden_size;
1029 embed_tokens + lm_head + norm
1030 };
1031
1032 let connector_elems = {
1033 let tcfg = &cfg.text_config;
1034 let vcfg = &cfg.vision_config;
1035 let gate_proj = vcfg.hidden_size * tcfg.intermediate_size;
1036 let up_proj = vcfg.hidden_size * tcfg.intermediate_size;
1037 let down_proj = tcfg.intermediate_size * tcfg.hidden_size;
1038
1039 let perceiver_elems = {
1040 let tcfg = &cfg.text_config;
1041 let pcfg = &cfg.perceiver_config;
1042
1043 let n_latents = pcfg.resampler_n_latents;
1044 let hidden_size = tcfg.hidden_size;
1045 let depth = pcfg.resampler_depth;
1046
1047 let norm = tcfg.hidden_size;
1048 let latents = n_latents * hidden_size;
1049
1050 let layer_elems = {
1051 let input_latents_norm = hidden_size;
1052 let input_context_norm = hidden_size;
1053 let post_attn_norm = hidden_size;
1054
1055 let num_heads = pcfg.resampler_n_heads;
1056 let head_dim = pcfg.resampler_head_dim;
1057 let num_key_value_heads = pcfg.num_key_value_heads;
1058
1059 let q_proj = hidden_size * num_heads * head_dim;
1060 let k_proj = hidden_size * num_key_value_heads * head_dim;
1061 let v_proj = hidden_size * num_key_value_heads * head_dim;
1062 let o_proj = num_heads * head_dim * hidden_size;
1063
1064 let gate_proj = hidden_size * hidden_size * 4;
1065 let up_proj = hidden_size * hidden_size * 4;
1066 let down_proj = hidden_size * 4 * hidden_size;
1067
1068 input_latents_norm
1069 + input_context_norm
1070 + post_attn_norm
1071 + q_proj
1072 + k_proj
1073 + v_proj
1074 + o_proj
1075 + gate_proj
1076 + up_proj
1077 + down_proj
1078 };
1079
1080 norm + latents + layer_elems * depth
1081 };
1082
1083 gate_proj + up_proj + down_proj + perceiver_elems
1084 };
1085
1086 let vision_transformer = {
1087 let cfg = &cfg.vision_config;
1088
1089 let post_layernorm = cfg.hidden_size;
1090
1091 let conv_config = Conv2dConfig {
1092 stride: cfg.patch_size,
1093 ..Default::default()
1094 };
1095 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
1096 * cfg.patch_size
1097 * cfg.patch_size;
1098
1099 let num_patches_per_side = cfg.image_size / cfg.patch_size;
1100 let num_patches = num_patches_per_side.pow(2);
1101 let position_embedding = num_patches * cfg.hidden_size;
1102
1103 let layer_elems = {
1104 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1105 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
1106
1107 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
1108 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
1109
1110 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1111 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1112 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1113 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
1114
1115 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
1116 };
1117
1118 post_layernorm + patch_embedding + position_embedding + layer_elems
1119 };
1120
1121 let elems = text_elems + connector_elems + vision_transformer;
1122
1123 Ok(elems * dtype.size_in_bytes())
1124 }
1125
1126 fn layer_sizes_in_bytes(
1127 &self,
1128 config: &str,
1129 dtype: DType,
1130 weight_pack_factor: usize,
1131 _matformer_config: Option<&MatformerSliceConfig>,
1132 ) -> Result<Vec<usize>> {
1133 let cfg: Idefics2Config = serde_json::from_str(config)?;
1134 let cfg = cfg.text_config;
1135 let per_layer_elems = {
1136 let input_layernorm = cfg.hidden_size;
1137 let post_attention_layernorm = cfg.hidden_size;
1138
1139 let size_in = cfg.hidden_size;
1140 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1141 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1142 let q_proj = size_in * size_q / weight_pack_factor;
1143 let k_proj = size_in * size_kv / weight_pack_factor;
1144 let v_proj = size_in * size_kv / weight_pack_factor;
1145 let o_proj = size_q * size_in / weight_pack_factor;
1146
1147 let h_size = cfg.hidden_size;
1148 let i_size = cfg.intermediate_size;
1149 let gate_proj = h_size * i_size / weight_pack_factor;
1150 let up_proj = h_size * i_size / weight_pack_factor;
1151 let down_proj = i_size * h_size / weight_pack_factor;
1152
1153 input_layernorm
1154 + post_attention_layernorm
1155 + q_proj
1156 + k_proj
1157 + v_proj
1158 + o_proj
1159 + gate_proj
1160 + up_proj
1161 + down_proj
1162 };
1163 Ok(vec![
1164 per_layer_elems * dtype.size_in_bytes();
1165 cfg.num_hidden_layers
1166 ])
1167 }
1168
1169 fn num_layers(&self, config: &str) -> Result<usize> {
1170 let cfg: Idefics2Config = serde_json::from_str(config)?;
1171 Ok(cfg.text_config.num_hidden_layers)
1172 }
1173 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1174 let cfg: Idefics2Config = serde_json::from_str(config)?;
1175 let cfg = &cfg.text_config;
1176
1177 let cfg = ModelConfigMetadata {
1178 max_seq_len: cfg.max_position_embeddings,
1179 num_layers: cfg.num_hidden_layers,
1180 hidden_size: cfg.hidden_size,
1181 num_kv_heads: cfg.num_key_value_heads,
1182 num_attn_heads: cfg.num_attention_heads,
1183 sliding_window: cfg.sliding_window,
1184 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1185 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1186 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1187 };
1188
1189 Ok(Box::new(cfg))
1190 }
1191
1192 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1193 Some(vec![NonMappedSubModel::Vision])
1194 }
1195}
1196
/// Loader for LLaVA-NeXT multimodal models.
///
/// Stateless marker type; its behaviour lives in the `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader` implementations.
pub struct LLaVANextLoader;
1203
1204pub struct LLaVANextPrefixer;
1205
1206impl MultimodalPromptPrefixer for LLaVANextPrefixer {
1207 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1208 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1209 }
1210}
1211
1212impl MultimodalModelLoader for LLaVANextLoader {
1213 fn load(
1214 &self,
1215 config: &str,
1216 vb: ShardedVarBuilder,
1217 normal_loading_metadata: NormalLoadingMetadata,
1218 attention_mechanism: AttentionImplementation,
1219 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1220 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1221 Ok(Box::new(LLaVANext::new(
1222 &cfg,
1223 vb,
1224 self.is_gptx(config),
1225 normal_loading_metadata,
1226 attention_mechanism,
1227 )?))
1228 }
1229 fn is_gptx(&self, _config: &str) -> bool {
1230 false
1231 }
1232 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1233 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1234 Ok(Box::new(cfg))
1235 }
1236 fn get_processor(
1237 &self,
1238 model_config: &str,
1239 _processor_config: Option<ProcessorConfig>,
1240 _preprocessor_config: PreProcessorConfig,
1241 _max_edge: Option<u32>,
1242 ) -> Arc<dyn Processor + Send + Sync> {
1243 Arc::new(LLaVANextProcessor::new(model_config))
1244 }
1245 fn supports_paged_attention(&self, _config: &str) -> bool {
1246 true
1247 }
1248 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1249 true
1250 }
1251 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1252 Arc::new(LLaVANextPrefixer)
1253 }
1254 fn modalities(&self, _config: &str) -> Result<Modalities> {
1255 Ok(Modalities {
1256 input: vec![SupportedModality::Text, SupportedModality::Vision],
1257 output: vec![SupportedModality::Text],
1258 })
1259 }
1260}
1261
1262impl IsqModelLoader for LLaVANextLoader {
1263 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1264 Ok(vec![
1265 Regex::new(r"lm_head\.(weight|bias)$")?,
1266 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1268 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1269 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1270 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1271 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1273 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1274 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1275 ])
1276 }
1277 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1278 Ok(vec![
1279 Regex::new(r"lm_head\.(weight|bias)$")?,
1280 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1282 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1283 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1284 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1285 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1287 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1288 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1289 ])
1290 }
1291}
1292
1293impl DeviceMappedModelLoader for LLaVANextLoader {
1294 fn mapped_max_act_size_elems(
1295 &self,
1296 config: &str,
1297 params: &AutoDeviceMapParams,
1298 ) -> Result<usize> {
1299 let AutoDeviceMapParams::Multimodal {
1300 max_seq_len,
1301 max_batch_size,
1302 max_image_shape,
1303 max_num_images,
1304 } = params
1305 else {
1306 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1307 };
1308
1309 let config: LLaVAConfig = serde_json::from_str(config)?;
1310
1311 #[allow(clippy::cast_possible_truncation)]
1312 let img_seq_len =
1313 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1314 &config,
1315 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1316 );
1317 let img_seq_len = img_seq_len * max_num_images;
1318
1319 let max_text_attn = {
1320 let cfg = &config.text_config;
1321 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1323
1324 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1325 };
1326
1327 Ok(max_text_attn)
1328 }
1329
1330 fn non_mapped_max_act_size_elems(
1331 &self,
1332 config: &str,
1333 params: &AutoDeviceMapParams,
1334 ) -> Result<usize> {
1335 let AutoDeviceMapParams::Multimodal {
1336 max_seq_len: _,
1337 max_batch_size,
1338 max_image_shape,
1339 max_num_images,
1340 } = params
1341 else {
1342 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1343 };
1344
1345 let config: LLaVAConfig = serde_json::from_str(config)?;
1346
1347 #[allow(clippy::cast_possible_truncation)]
1348 let img_seq_len =
1349 llava_next_inputs_processor::LLaVANextInputProcessor::get_num_image_tokens(
1350 &config,
1351 (max_image_shape.0 as u32, max_image_shape.1 as u32),
1352 );
1353
1354 let max_vision_attn = {
1355 (max_batch_size * max_num_images)
1356 * config.vision_config.num_attention_heads
1357 * img_seq_len
1358 * img_seq_len
1359 };
1360
1361 Ok(max_vision_attn)
1362 }
1363
1364 fn non_mapped_size_in_bytes(
1365 &self,
1366 config: &str,
1367 dtype: DType,
1368 weight_pack_factor: usize,
1369 _matformer_config: Option<&MatformerSliceConfig>,
1370 ) -> Result<usize> {
1371 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1372 let text_elems = {
1373 let cfg = &cfg.text_config;
1374 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1375 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1376 let norm = cfg.hidden_size;
1377 embed_tokens + lm_head + norm
1378 };
1379
1380 let image_newline = cfg.text_config.hidden_size;
1381 let mmproj = {
1382 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1383 + cfg.text_config.hidden_size;
1384 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1385 + cfg.text_config.hidden_size;
1386
1387 linear_1 + linear_2
1388 };
1389 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1390
1391 let elems = text_elems + image_newline + mmproj + vision_tower;
1392 Ok(elems * dtype.size_in_bytes())
1393 }
1394
1395 fn layer_sizes_in_bytes(
1396 &self,
1397 config: &str,
1398 dtype: DType,
1399 weight_pack_factor: usize,
1400 _matformer_config: Option<&MatformerSliceConfig>,
1401 ) -> Result<Vec<usize>> {
1402 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1403 let per_layer_elems = {
1404 let cfg = &cfg.text_config;
1405 let input_layernorm = cfg.hidden_size;
1406 let post_attention_layernorm = cfg.hidden_size;
1407
1408 let size_in = cfg.hidden_size;
1409 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1410 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1411 let q_proj = size_in * size_q / weight_pack_factor;
1412 let k_proj = size_in * size_kv / weight_pack_factor;
1413 let v_proj = size_in * size_kv / weight_pack_factor;
1414 let o_proj = size_q * size_in / weight_pack_factor;
1415
1416 let h_size = cfg.hidden_size;
1417 let i_size = cfg.intermediate_size;
1418 let gate_proj = h_size * i_size / weight_pack_factor;
1419 let up_proj = h_size * i_size / weight_pack_factor;
1420 let down_proj = i_size * h_size / weight_pack_factor;
1421
1422 input_layernorm
1423 + post_attention_layernorm
1424 + q_proj
1425 + k_proj
1426 + v_proj
1427 + o_proj
1428 + gate_proj
1429 + up_proj
1430 + down_proj
1431 };
1432 Ok(vec![
1433 per_layer_elems * dtype.size_in_bytes();
1434 cfg.text_config.num_hidden_layers
1435 ])
1436 }
1437
1438 fn num_layers(&self, config: &str) -> Result<usize> {
1439 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1440 Ok(cfg.text_config.num_hidden_layers)
1441 }
1442
1443 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1444 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1445 let cfg = &cfg.text_config;
1446
1447 let cfg = ModelConfigMetadata {
1448 max_seq_len: cfg.max_position_embeddings,
1449 num_layers: cfg.num_hidden_layers,
1450 hidden_size: cfg.hidden_size,
1451 num_kv_heads: cfg.num_key_value_heads,
1452 num_attn_heads: cfg.num_attention_heads,
1453 sliding_window: cfg.sliding_window,
1454 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1455 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1456 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1457 };
1458
1459 Ok(Box::new(cfg))
1460 }
1461
1462 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1463 Some(vec![NonMappedSubModel::Vision])
1464 }
1465}
1466
/// Loader for LLaVA multimodal models.
///
/// Stateless marker type; its behaviour lives in the `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader` implementations.
pub struct LLaVALoader;
1473
1474pub struct LLaVAPrefixer;
1475
1476impl MultimodalPromptPrefixer for LLaVAPrefixer {
1477 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1478 format!("{}{prompt}", "<image>".repeat(image_indexes.len()))
1479 }
1480}
1481
1482impl MultimodalModelLoader for LLaVALoader {
1483 fn load(
1484 &self,
1485 config: &str,
1486 vb: ShardedVarBuilder,
1487 normal_loading_metadata: NormalLoadingMetadata,
1488 attention_mechanism: AttentionImplementation,
1489 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1490 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1491 Ok(Box::new(LLaVA::new(
1492 &cfg,
1493 vb,
1494 self.is_gptx(config),
1495 normal_loading_metadata,
1496 attention_mechanism,
1497 )?))
1498 }
1499 fn is_gptx(&self, _config: &str) -> bool {
1500 false
1501 }
1502 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1503 let cfg: crate::vision_models::llava::config::Config = serde_json::from_str(config)?;
1504 Ok(Box::new(cfg))
1505 }
1506 fn get_processor(
1507 &self,
1508 model_config: &str,
1509 _processor_config: Option<ProcessorConfig>,
1510 _preprocessor_config: PreProcessorConfig,
1511 _max_edge: Option<u32>,
1512 ) -> Arc<dyn Processor + Send + Sync> {
1513 Arc::new(LLaVAProcessor::new(model_config))
1514 }
1515 fn supports_paged_attention(&self, _config: &str) -> bool {
1516 true
1517 }
1518 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1519 true
1520 }
1521 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1522 Arc::new(LLaVAPrefixer)
1523 }
1524 fn modalities(&self, _config: &str) -> Result<Modalities> {
1525 Ok(Modalities {
1526 input: vec![SupportedModality::Text, SupportedModality::Vision],
1527 output: vec![SupportedModality::Text],
1528 })
1529 }
1530}
1531
1532impl IsqModelLoader for LLaVALoader {
1533 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1534 Ok(vec![
1535 Regex::new(r"lm_head\.(weight|bias)$")?,
1536 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1538 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1539 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1540 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1541 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1543 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1544 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1545 ])
1546 }
1547 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
1548 Ok(vec![
1549 Regex::new(r"lm_head\.(weight|bias)$")?,
1550 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1552 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1553 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1554 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1555 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1557 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1558 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1559 ])
1560 }
1561}
1562
1563impl DeviceMappedModelLoader for LLaVALoader {
1564 fn mapped_max_act_size_elems(
1565 &self,
1566 config: &str,
1567 params: &AutoDeviceMapParams,
1568 ) -> Result<usize> {
1569 let AutoDeviceMapParams::Multimodal {
1570 max_seq_len,
1571 max_batch_size,
1572 max_image_shape: _,
1573 max_num_images,
1574 } = params
1575 else {
1576 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1577 };
1578
1579 let config: LLaVAConfig = serde_json::from_str(config)?;
1580
1581 let img_seq_len =
1582 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1583 let img_seq_len = img_seq_len * max_num_images;
1584
1585 let max_text_attn = {
1586 let cfg = &config.text_config;
1587 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
1589
1590 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
1591 };
1592
1593 Ok(max_text_attn)
1594 }
1595
1596 fn non_mapped_max_act_size_elems(
1597 &self,
1598 config: &str,
1599 params: &AutoDeviceMapParams,
1600 ) -> Result<usize> {
1601 let AutoDeviceMapParams::Multimodal {
1602 max_seq_len: _,
1603 max_batch_size,
1604 max_image_shape: _,
1605 max_num_images,
1606 } = params
1607 else {
1608 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1609 };
1610
1611 let config: LLaVAConfig = serde_json::from_str(config)?;
1612
1613 let img_seq_len =
1614 llava_inputs_processor::LLaVAInputProcessor::get_num_image_tokens(&config);
1615
1616 let max_vision_attn = {
1617 (max_batch_size * max_num_images)
1618 * config.vision_config.num_attention_heads
1619 * img_seq_len
1620 * img_seq_len
1621 };
1622
1623 Ok(max_vision_attn)
1624 }
1625
1626 fn non_mapped_size_in_bytes(
1627 &self,
1628 config: &str,
1629 dtype: DType,
1630 weight_pack_factor: usize,
1631 _matformer_config: Option<&MatformerSliceConfig>,
1632 ) -> Result<usize> {
1633 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1634 let text_elems = {
1635 let cfg = &cfg.text_config;
1636 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1637 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1638 let norm = cfg.hidden_size;
1639 embed_tokens + lm_head + norm
1640 };
1641
1642 let image_newline = cfg.text_config.hidden_size;
1643 let mmproj = {
1644 let linear_1 = cfg.vision_config.hidden_size * cfg.text_config.hidden_size
1645 + cfg.text_config.hidden_size;
1646 let linear_2 = cfg.text_config.hidden_size * cfg.text_config.hidden_size
1647 + cfg.text_config.hidden_size;
1648
1649 linear_1 + linear_2
1650 };
1651 let vision_tower = get_clip_vit_num_elems(&cfg.to_clip_config());
1652
1653 let elems = text_elems + image_newline + mmproj + vision_tower;
1654 Ok(elems * dtype.size_in_bytes())
1655 }
1656
1657 fn layer_sizes_in_bytes(
1658 &self,
1659 config: &str,
1660 dtype: DType,
1661 weight_pack_factor: usize,
1662 _matformer_config: Option<&MatformerSliceConfig>,
1663 ) -> Result<Vec<usize>> {
1664 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1665 let per_layer_elems = {
1666 let cfg = &cfg.text_config;
1667 let input_layernorm = cfg.hidden_size;
1668 let post_attention_layernorm = cfg.hidden_size;
1669
1670 let size_in = cfg.hidden_size;
1671 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1672 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1673 let q_proj = size_in * size_q / weight_pack_factor;
1674 let k_proj = size_in * size_kv / weight_pack_factor;
1675 let v_proj = size_in * size_kv / weight_pack_factor;
1676 let o_proj = size_q * size_in / weight_pack_factor;
1677
1678 let h_size = cfg.hidden_size;
1679 let i_size = cfg.intermediate_size;
1680 let gate_proj = h_size * i_size / weight_pack_factor;
1681 let up_proj = h_size * i_size / weight_pack_factor;
1682 let down_proj = i_size * h_size / weight_pack_factor;
1683
1684 input_layernorm
1685 + post_attention_layernorm
1686 + q_proj
1687 + k_proj
1688 + v_proj
1689 + o_proj
1690 + gate_proj
1691 + up_proj
1692 + down_proj
1693 };
1694 Ok(vec![
1695 per_layer_elems * dtype.size_in_bytes();
1696 cfg.text_config.num_hidden_layers
1697 ])
1698 }
1699
1700 fn num_layers(&self, config: &str) -> Result<usize> {
1701 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1702 Ok(cfg.text_config.num_hidden_layers)
1703 }
1704
1705 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1706 let cfg: LLaVAConfig = serde_json::from_str(config)?;
1707 let cfg = &cfg.text_config;
1708
1709 let cfg = ModelConfigMetadata {
1710 max_seq_len: cfg.max_position_embeddings,
1711 num_layers: cfg.num_hidden_layers,
1712 hidden_size: cfg.hidden_size,
1713 num_kv_heads: cfg.num_key_value_heads,
1714 num_attn_heads: cfg.num_attention_heads,
1715 sliding_window: cfg.sliding_window,
1716 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1717 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1718 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1719 };
1720
1721 Ok(Box::new(cfg))
1722 }
1723
1724 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
1725 Some(vec![NonMappedSubModel::Vision])
1726 }
1727}
1728
/// Loader for MLlama (vision Llama) multimodal models.
///
/// Stateless marker type; its behaviour lives in the `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader` implementations.
pub struct VLlamaLoader;
1735
1736pub struct VLlamaPrefixer;
1737
1738impl MultimodalPromptPrefixer for VLlamaPrefixer {
1739 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
1740 format!("{}{prompt}", "<|image|>".repeat(image_indexes.len()))
1741 }
1742}
1743
1744impl MultimodalModelLoader for VLlamaLoader {
1745 fn load(
1746 &self,
1747 config: &str,
1748 vb: ShardedVarBuilder,
1749 normal_loading_metadata: NormalLoadingMetadata,
1750 attention_mechanism: AttentionImplementation,
1751 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
1752 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1753 Ok(Box::new(MLlamaModel::new(
1754 &cfg,
1755 vb,
1756 self.is_gptx(config),
1757 normal_loading_metadata,
1758 attention_mechanism,
1759 )?))
1760 }
1761 fn is_gptx(&self, _config: &str) -> bool {
1762 true
1763 }
1764 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1765 let cfg: crate::vision_models::mllama::MLlamaConfig = serde_json::from_str(config)?;
1766 Ok(Box::new(cfg))
1767 }
1768 fn get_processor(
1769 &self,
1770 _model_config: &str,
1771 _processor_config: Option<ProcessorConfig>,
1772 _preprocessor_config: PreProcessorConfig,
1773 _max_edge: Option<u32>,
1774 ) -> Arc<dyn Processor + Send + Sync> {
1775 Arc::new(MLlamaProcessor::new())
1776 }
1777 fn supports_paged_attention(&self, _config: &str) -> bool {
1778 true
1779 }
1780 fn supports_prefix_cacher(&self, _config: &str) -> bool {
1781 true
1782 }
1783 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
1784 Arc::new(VLlamaPrefixer)
1785 }
1786 fn modalities(&self, _config: &str) -> Result<Modalities> {
1787 Ok(Modalities {
1788 input: vec![SupportedModality::Text, SupportedModality::Vision],
1789 output: vec![SupportedModality::Text],
1790 })
1791 }
1792}
1793
1794impl IsqModelLoader for VLlamaLoader {
1795 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
1796 let config: MLlamaConfig = serde_json::from_str(config)?;
1797 let cross_attn_layers = &config.text_config.cross_attention_layers;
1798 let transformer_layers =
1799 (0..config.text_config.num_hidden_layers).filter(|i| !cross_attn_layers.contains(i));
1800 let mut text_regexes = Vec::new();
1801 for layer in transformer_layers {
1802 text_regexes.extend(vec![
1803 Regex::new(&format!(
1805 r"language_model.model.layers\.{layer}\.self_attn\.q_proj\.(weight|bias)$"
1806 ))?,
1807 Regex::new(&format!(
1808 r"language_model.model.layers\.{layer}\.self_attn\.k_proj\.(weight|bias)$"
1809 ))?,
1810 Regex::new(&format!(
1811 r"language_model.model.layers\.{layer}\.self_attn\.v_proj\.(weight|bias)$"
1812 ))?,
1813 Regex::new(&format!(
1814 r"language_model.model.layers\.{layer}\.self_attn\.o_proj\.(weight|bias)$"
1815 ))?,
1816 Regex::new(&format!(
1818 r"language_model.model.layers\.{layer}\.mlp\.gate_proj\.(weight|bias)$"
1819 ))?,
1820 Regex::new(&format!(
1821 r"language_model.model.layers\.{layer}\.mlp\.up_proj\.(weight|bias)$"
1822 ))?,
1823 Regex::new(&format!(
1824 r"language_model.model.layers\.{layer}\.mlp\.down_proj\.(weight|bias)$"
1825 ))?,
1826 ]);
1827 }
1828 let vision_regexes = vec![
1829 Regex::new(
1831 r"vision_model.transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1832 )?,
1833 Regex::new(
1834 r"vision_model.transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1835 )?,
1836 Regex::new(
1837 r"vision_model.transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1838 )?,
1839 Regex::new(
1840 r"vision_model.transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1841 )?,
1842 Regex::new(
1844 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
1845 )?,
1846 Regex::new(
1847 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$",
1848 )?,
1849 Regex::new(
1850 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$",
1851 )?,
1852 Regex::new(
1853 r"vision_model.global_transformer.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$",
1854 )?,
1855 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1857 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1858 ];
1859
1860 Ok([text_regexes, vision_regexes].concat())
1861 }
1862 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1863 self.isq_layer_regexes(config)
1864 }
1865}
1866
1867impl DeviceMappedModelLoader for VLlamaLoader {
1868 fn mapped_max_act_size_elems(
1869 &self,
1870 config: &str,
1871 params: &AutoDeviceMapParams,
1872 ) -> Result<usize> {
1873 let AutoDeviceMapParams::Multimodal {
1874 max_seq_len,
1875 max_batch_size,
1876 max_image_shape: _,
1877 max_num_images,
1878 } = params
1879 else {
1880 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1881 };
1882
1883 let config: MLlamaConfig = serde_json::from_str(config)?;
1884
1885 let img_seq_len = {
1886 let cfg = &config.vision_config;
1887 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1888 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1889 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1890 };
1891 let img_seq_len = img_seq_len * max_num_images;
1892
1893 let max_cross_text_attn = {
1894 let cfg = &config.text_config;
1895 max_batch_size * cfg.num_attention_heads * img_seq_len * img_seq_len
1896 };
1897
1898 let max_self_text_attn = {
1899 let cfg = &config.text_config;
1900 max_batch_size * cfg.num_attention_heads * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)
1901 };
1902
1903 Ok(max_self_text_attn.max(max_cross_text_attn))
1904 }
1905
1906 fn non_mapped_max_act_size_elems(
1907 &self,
1908 config: &str,
1909 params: &AutoDeviceMapParams,
1910 ) -> Result<usize> {
1911 let AutoDeviceMapParams::Multimodal {
1912 max_seq_len: _,
1913 max_batch_size,
1914 max_image_shape: _,
1915 max_num_images,
1916 } = params
1917 else {
1918 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
1919 };
1920
1921 let config: MLlamaConfig = serde_json::from_str(config)?;
1922
1923 let img_seq_len = {
1924 let cfg = &config.vision_config;
1925 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1926 let num_padding_patches = (8 - (num_patches as isize % 8)) % 8;
1927 cfg.max_num_tiles * (num_patches as isize + num_padding_patches) as usize
1928 };
1929 let max_vision_attn = {
1930 let cfg = &config.vision_config;
1931 (max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len
1932 };
1933
1934 Ok(max_vision_attn)
1935 }
1936
1937 fn non_mapped_size_in_bytes(
1938 &self,
1939 config: &str,
1940 dtype: DType,
1941 weight_pack_factor: usize,
1942 _matformer_config: Option<&MatformerSliceConfig>,
1943 ) -> Result<usize> {
1944 let config: MLlamaConfig = serde_json::from_str(config)?;
1945 let text_elems = {
1946 let cfg = &config.text_config;
1947 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1948 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1950 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1951 } else {
1952 0
1953 };
1954 let norm = cfg.hidden_size;
1955 embed_tokens + lm_head + norm
1956 };
1957
1958 let vision_elems = {
1959 let cfg = &config.vision_config;
1960
1961 let conv_cfg = Conv2dConfig {
1962 stride: cfg.patch_size,
1963 ..Default::default()
1964 };
1965 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_cfg.groups
1966 * cfg.patch_size
1967 * cfg.patch_size;
1968
1969 let class_embedding = cfg.hidden_size;
1970
1971 let gated_positional_embedding = {
1972 let num_patches = (cfg.image_size / cfg.patch_size).pow(2) + 1;
1973 let embedding = num_patches * cfg.hidden_size;
1974 let tile_embedding = (cfg.max_aspect_ratio_id() + 1)
1975 * (cfg.max_num_tiles * num_patches * cfg.hidden_size);
1976
1977 embedding + tile_embedding
1978 };
1979
1980 let pre_tile_positional_embedding =
1981 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1982 let post_tile_positional_embedding =
1983 (cfg.max_aspect_ratio_id() + 1) * (cfg.max_num_tiles * cfg.hidden_size);
1984
1985 let layernorm_pre = cfg.hidden_size;
1986 let layernorm_post = cfg.hidden_size;
1987
1988 let encoder_layer = {
1989 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1990 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
1991
1992 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
1993 let q_proj =
1994 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1995 let k_proj =
1996 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1997 let v_proj =
1998 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
1999 let o_proj =
2000 cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor;
2001
2002 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
2003 + cfg.intermediate_size;
2004 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
2005 + cfg.hidden_size;
2006
2007 input_layernorm
2008 + post_attention_layernorm
2009 + q_proj
2010 + k_proj
2011 + v_proj
2012 + o_proj
2013 + fc1
2014 + fc2
2015 };
2016
2017 patch_embedding
2018 + class_embedding
2019 + gated_positional_embedding
2020 + pre_tile_positional_embedding
2021 + post_tile_positional_embedding
2022 + layernorm_pre
2023 + layernorm_post
2024 + encoder_layer * (cfg.num_hidden_layers + cfg.num_global_layers)
2025 };
2026
2027 let elems = text_elems + vision_elems;
2028 Ok(elems * dtype.size_in_bytes())
2029 }
2030
2031 fn layer_sizes_in_bytes(
2032 &self,
2033 config: &str,
2034 dtype: DType,
2035 weight_pack_factor: usize,
2036 _matformer_config: Option<&MatformerSliceConfig>,
2037 ) -> Result<Vec<usize>> {
2038 let config: MLlamaConfig = serde_json::from_str(config)?;
2039 let cfg = &config.text_config;
2040
2041 let mut layer_sizes = Vec::new();
2042
2043 for i in 0..cfg.num_hidden_layers {
2044 let weight_pack_factor = if cfg.cross_attention_layers.contains(&i) {
2045 1
2047 } else {
2048 weight_pack_factor
2049 };
2050
2051 let per_layer_elems = {
2052 let input_layernorm = cfg.hidden_size;
2053 let post_attention_layernorm = cfg.hidden_size;
2054
2055 let size_in = cfg.hidden_size;
2056 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2057 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2058 let q_proj = size_in * size_q / weight_pack_factor;
2059 let k_proj = size_in * size_kv / weight_pack_factor;
2060 let v_proj = size_in * size_kv / weight_pack_factor;
2061 let o_proj = size_q * size_in / weight_pack_factor;
2062
2063 let h_size = cfg.hidden_size;
2064 let i_size = cfg.intermediate_size;
2065 let gate_proj = h_size * i_size / weight_pack_factor;
2066 let up_proj = h_size * i_size / weight_pack_factor;
2067 let down_proj = i_size * h_size / weight_pack_factor;
2068
2069 input_layernorm
2070 + post_attention_layernorm
2071 + q_proj
2072 + k_proj
2073 + v_proj
2074 + o_proj
2075 + gate_proj
2076 + up_proj
2077 + down_proj
2078 };
2079
2080 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
2081 }
2082
2083 Ok(layer_sizes)
2084 }
2085
2086 fn num_layers(&self, config: &str) -> Result<usize> {
2087 let config: MLlamaConfig = serde_json::from_str(config)?;
2088 Ok(config.text_config.num_hidden_layers)
2089 }
2090
2091 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2092 let cfg: MLlamaConfig = serde_json::from_str(config)?;
2093 let cfg = &cfg.text_config;
2094
2095 let cfg = ModelConfigMetadata {
2096 max_seq_len: cfg.max_position_embeddings,
2097 num_layers: cfg.num_hidden_layers,
2098 hidden_size: cfg.hidden_size,
2099 num_kv_heads: cfg.num_key_value_heads,
2100 num_attn_heads: cfg.num_attention_heads,
2101 sliding_window: None,
2102 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2103 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2104 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2105 };
2106
2107 Ok(Box::new(cfg))
2108 }
2109
2110 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2111 Some(vec![NonMappedSubModel::Vision])
2112 }
2113}
2114
/// Loader for Qwen2-VL models.
///
/// A unit struct; its behaviour comes from the `MultimodalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations in this file.
pub struct Qwen2VLLoader;
2121
/// Prompt prefixer for Qwen2-VL: prepends the vision marker tokens to a prompt, once per image.
pub struct Qwen2VLPrefixer;
2123
2124impl MultimodalPromptPrefixer for Qwen2VLPrefixer {
2125 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2126 format!(
2127 "{}{prompt}",
2128 format!(
2129 "{}{}{}",
2130 Qwen2VLProcessor::VISION_START,
2131 Qwen2VLProcessor::IMAGE_PAD,
2132 Qwen2VLProcessor::VISION_END
2133 )
2134 .repeat(image_indexes.len())
2135 )
2136 }
2137}
2138
2139impl MultimodalModelLoader for Qwen2VLLoader {
2140 fn load(
2141 &self,
2142 config: &str,
2143 vb: ShardedVarBuilder,
2144 normal_loading_metadata: NormalLoadingMetadata,
2145 attention_mechanism: AttentionImplementation,
2146 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2147 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2148 Ok(Box::new(Qwen2VLModel::new(
2149 &cfg,
2150 vb,
2151 self.is_gptx(config),
2152 normal_loading_metadata,
2153 attention_mechanism,
2154 )?))
2155 }
2156 fn is_gptx(&self, _config: &str) -> bool {
2157 true
2158 }
2159 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2160 let config: Qwen2VLConfig = serde_json::from_str(config)?;
2161 Ok(Box::new(config))
2162 }
2163 fn get_processor(
2164 &self,
2165 _model_config: &str,
2166 _processor_config: Option<ProcessorConfig>,
2167 _preprocessor_config: PreProcessorConfig,
2168 max_edge: Option<u32>,
2169 ) -> Arc<dyn Processor + Send + Sync> {
2170 Arc::new(Qwen2VLProcessor::new(max_edge))
2171 }
2172 fn supports_paged_attention(&self, _config: &str) -> bool {
2173 false
2174 }
2175 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2176 Arc::new(Qwen2VLPrefixer)
2177 }
2178 fn modalities(&self, _config: &str) -> Result<Modalities> {
2179 Ok(Modalities {
2180 input: vec![SupportedModality::Text, SupportedModality::Vision],
2181 output: vec![SupportedModality::Text],
2182 })
2183 }
2184}
2185
2186impl IsqModelLoader for Qwen2VLLoader {
2187 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2188 Ok(vec![
2189 Regex::new(r"lm_head\.(weight|bias)$")?,
2190 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2192 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2193 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2194 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2195 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2197 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2198 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2199 ])
2200 }
2201 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2202 self.isq_layer_regexes(config)
2203 }
2204}
2205
2206impl DeviceMappedModelLoader for Qwen2VLLoader {
2207 fn mapped_max_act_size_elems(
2208 &self,
2209 config: &str,
2210 params: &AutoDeviceMapParams,
2211 ) -> Result<usize> {
2212 let AutoDeviceMapParams::Multimodal {
2213 max_seq_len,
2214 max_batch_size,
2215 max_image_shape,
2216 max_num_images,
2217 } = params
2218 else {
2219 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2220 };
2221
2222 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2223
2224 let img_seq_len = {
2226 let cfg = &cfg.vision_config;
2227 let grid_t = 1;
2229 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
2231 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
2232 grid_t * grid_h * grid_w * max_num_images
2233 };
2234
2235 let max_text_attn = {
2236 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2238 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
2239 };
2240
2241 Ok(max_text_attn)
2242 }
2243
2244 fn non_mapped_max_act_size_elems(
2245 &self,
2246 config: &str,
2247 params: &AutoDeviceMapParams,
2248 ) -> Result<usize> {
2249 let AutoDeviceMapParams::Multimodal {
2250 max_seq_len: _,
2251 max_batch_size,
2252 max_image_shape,
2253 max_num_images,
2254 } = params
2255 else {
2256 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2257 };
2258
2259 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2260
2261 let img_seq_len = {
2263 let cfg = &cfg.vision_config;
2264 let grid_t = 1;
2266 let grid_h = max_image_shape.0 / cfg.patch_size;
2267 let grid_w = max_image_shape.1 / cfg.patch_size;
2268 grid_t * grid_h * grid_w
2269 };
2270
2271 let max_vision_attn = {
2272 let cfg = &cfg.vision_config;
2273 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
2274 };
2275
2276 Ok(max_vision_attn)
2277 }
2278
2279 fn non_mapped_size_in_bytes(
2280 &self,
2281 config: &str,
2282 dtype: DType,
2283 weight_pack_factor: usize,
2284 _matformer_config: Option<&MatformerSliceConfig>,
2285 ) -> Result<usize> {
2286 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2287 let text_elems = {
2288 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2289 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2291 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2292 } else {
2293 0
2294 };
2295 let norm = cfg.hidden_size;
2296 embed_tokens + lm_head + norm
2297 };
2298
2299 let patch_merger = {
2300 let cfg = &cfg.vision_config;
2301 let hidden_size = cfg.embed_dim * cfg.spatial_merge_size.pow(2);
2302
2303 let mlp0 = hidden_size * hidden_size + hidden_size;
2304 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
2305
2306 let ln_q = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2307
2308 mlp0 + mlp2 + ln_q
2309 };
2310
2311 let patch_embed = {
2312 let cfg = &cfg.vision_config;
2313 let conv_cfg = Conv3dConfig {
2314 stride: cfg.patch_size,
2315 ..Default::default()
2316 };
2317 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
2318 cfg.in_channels * cfg.embed_dim / conv_cfg.groups
2319 * kernel_sizes[0]
2320 * kernel_sizes[1]
2321 * kernel_sizes[2]
2322 };
2323
2324 let encoder_layer = {
2325 let cfg = &cfg.vision_config;
2326 let norm1 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2327 let norm2 = cfg.embed_dim + bias_if!(true, cfg.embed_dim);
2328
2329 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
2330 let mlp_hidden_dim = (cfg.embed_dim as f64 * cfg.mlp_ratio) as usize;
2331 let fc1 = cfg.embed_dim * mlp_hidden_dim + mlp_hidden_dim;
2332 let fc2 = cfg.embed_dim * mlp_hidden_dim + cfg.embed_dim;
2333
2334 let qkv = cfg.embed_dim * cfg.embed_dim * 3 + cfg.embed_dim * 3;
2335 let out = cfg.embed_dim * cfg.embed_dim + cfg.embed_dim;
2336
2337 norm1 + norm2 + fc1 + fc2 + qkv + out
2338 };
2339
2340 let elems =
2341 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
2342
2343 Ok(elems * dtype.size_in_bytes())
2344 }
2345
2346 fn layer_sizes_in_bytes(
2347 &self,
2348 config: &str,
2349 dtype: DType,
2350 weight_pack_factor: usize,
2351 _matformer_config: Option<&MatformerSliceConfig>,
2352 ) -> Result<Vec<usize>> {
2353 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2354 let per_layer_elems = {
2355 let input_layernorm = cfg.hidden_size;
2356 let post_attention_layernorm = cfg.hidden_size;
2357
2358 let size_in = cfg.hidden_size;
2359 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2360 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2361 let q_proj = size_in * size_q / weight_pack_factor + size_q;
2362 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
2363 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
2364 let o_proj = size_q * size_in / weight_pack_factor;
2365
2366 let h_size = cfg.hidden_size;
2367 let i_size = cfg.intermediate_size;
2368 let gate_proj = h_size * i_size / weight_pack_factor;
2369 let up_proj = h_size * i_size / weight_pack_factor;
2370 let down_proj = i_size * h_size / weight_pack_factor;
2371
2372 input_layernorm
2373 + post_attention_layernorm
2374 + q_proj
2375 + k_proj
2376 + v_proj
2377 + o_proj
2378 + gate_proj
2379 + up_proj
2380 + down_proj
2381 };
2382 Ok(vec![
2383 per_layer_elems * dtype.size_in_bytes();
2384 cfg.num_hidden_layers
2385 ])
2386 }
2387
2388 fn num_layers(&self, config: &str) -> Result<usize> {
2389 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2390 Ok(cfg.num_hidden_layers)
2391 }
2392
2393 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2394 let cfg: Qwen2VLConfig = serde_json::from_str(config)?;
2395
2396 let cfg = ModelConfigMetadata {
2397 max_seq_len: cfg.max_position_embeddings,
2398 num_layers: cfg.num_hidden_layers,
2399 hidden_size: cfg.hidden_size,
2400 num_kv_heads: cfg.num_key_value_heads,
2401 num_attn_heads: cfg.num_attention_heads,
2402 sliding_window: cfg.sliding_window,
2403 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2404 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2405 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2406 };
2407
2408 Ok(Box::new(cfg))
2409 }
2410
2411 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2412 Some(vec![NonMappedSubModel::Vision])
2413 }
2414}
2415
/// Loader for Idefics 3 models.
///
/// A unit struct; its behaviour comes from the `MultimodalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations in this file.
pub struct Idefics3Loader;
2422
/// Prompt prefixer for Idefics 3; its `prefix_image` returns the prompt unchanged.
pub struct Idefics3Prefixer;
2424
2425impl MultimodalPromptPrefixer for Idefics3Prefixer {
2426 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
2427 prompt.to_string()
2429 }
2430}
2431
2432impl MultimodalModelLoader for Idefics3Loader {
2433 fn load(
2434 &self,
2435 config: &str,
2436 vb: ShardedVarBuilder,
2437 normal_loading_metadata: NormalLoadingMetadata,
2438 attention_mechanism: AttentionImplementation,
2439 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2440 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2441 Ok(Box::new(Idefics3Model::new(
2442 &cfg,
2443 vb,
2444 self.is_gptx(config),
2445 normal_loading_metadata,
2446 attention_mechanism,
2447 )?))
2448 }
2449 fn is_gptx(&self, _config: &str) -> bool {
2450 true
2451 }
2452 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2453 let cfg: crate::vision_models::idefics3::Idefics3Config = serde_json::from_str(config)?;
2454 Ok(Box::new(cfg))
2455 }
2456 fn get_processor(
2457 &self,
2458 _model_config: &str,
2459 processor_config: Option<ProcessorConfig>,
2460 preprocessor_config: PreProcessorConfig,
2461 max_edge: Option<u32>,
2462 ) -> Arc<dyn Processor + Send + Sync> {
2463 Arc::new(Idefics3Processor::new(
2464 processor_config.unwrap_or_default(),
2465 preprocessor_config,
2466 max_edge,
2467 ))
2468 }
2469 fn supports_paged_attention(&self, _config: &str) -> bool {
2470 true
2471 }
2472 fn supports_prefix_cacher(&self, _config: &str) -> bool {
2473 true
2474 }
2475 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2476 Arc::new(Idefics3Prefixer)
2477 }
2478 fn modalities(&self, _config: &str) -> Result<Modalities> {
2479 Ok(Modalities {
2480 input: vec![SupportedModality::Text, SupportedModality::Vision],
2481 output: vec![SupportedModality::Text],
2482 })
2483 }
2484}
2485
2486impl IsqModelLoader for Idefics3Loader {
2487 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2488 Ok(vec![
2489 Regex::new(r"lm_head\.(weight|bias)$")?,
2490 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2492 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2493 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2494 Regex::new(r"model.text_model.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2495 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2497 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2498 Regex::new(r"model.text_model.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2499 ])
2500 }
2501 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
2502 Ok(vec![
2503 Regex::new(r"lm_head\.(weight|bias)$")?,
2504 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2506 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2507 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2508 Regex::new(r"model\.text_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2509 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2511 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2512 Regex::new(r"model\.text_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2513 ])
2530 }
2531}
2532
2533impl DeviceMappedModelLoader for Idefics3Loader {
2534 fn mapped_max_act_size_elems(
2535 &self,
2536 config: &str,
2537 params: &AutoDeviceMapParams,
2538 ) -> Result<usize> {
2539 let AutoDeviceMapParams::Multimodal {
2540 max_seq_len,
2541 max_batch_size,
2542 max_image_shape: _,
2543 max_num_images,
2544 } = params
2545 else {
2546 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2547 };
2548
2549 let cfg: Idefics3Config = serde_json::from_str(config)?;
2550
2551 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2552 let img_seq_len = (num_patches + 1) * max_num_images;
2553
2554 let max_text_attn = {
2555 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2557 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2558 };
2559
2560 Ok(max_text_attn)
2561 }
2562
2563 fn non_mapped_max_act_size_elems(
2564 &self,
2565 config: &str,
2566 params: &AutoDeviceMapParams,
2567 ) -> Result<usize> {
2568 let AutoDeviceMapParams::Multimodal {
2569 max_seq_len: _,
2570 max_batch_size,
2571 max_image_shape: _,
2572 max_num_images,
2573 } = params
2574 else {
2575 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2576 };
2577
2578 let cfg: Idefics3Config = serde_json::from_str(config)?;
2579
2580 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2581 let img_seq_len = num_patches + 1;
2582
2583 let max_vision_attn = {
2584 let images_factor = 5;
2586
2587 (max_batch_size * images_factor * max_num_images)
2588 * cfg.vision_config.num_attention_heads
2589 * img_seq_len
2590 * img_seq_len
2591 };
2592
2593 Ok(max_vision_attn)
2594 }
2595
2596 fn non_mapped_size_in_bytes(
2597 &self,
2598 config: &str,
2599 dtype: DType,
2600 weight_pack_factor: usize,
2601 _matformer_config: Option<&MatformerSliceConfig>,
2602 ) -> Result<usize> {
2603 let cfg: Idefics3Config = serde_json::from_str(config)?;
2604 let text_elems = {
2605 let cfg = &cfg.text_config;
2606
2607 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2608 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2609 let norm = cfg.hidden_size;
2610 embed_tokens + lm_head + norm
2611 };
2612
2613 let connector_elems = {
2614 let in_dim = cfg.vision_config.hidden_size * cfg.scale_factor.pow(2);
2615 let out_dim = cfg.text_config.hidden_size;
2616
2617 in_dim * out_dim
2618 };
2619
2620 let vision_transformer = {
2621 let cfg = &cfg.vision_config;
2622
2623 let post_layernorm = cfg.hidden_size;
2624
2625 let conv_config = Conv2dConfig {
2626 stride: cfg.patch_size,
2627 ..Default::default()
2628 };
2629 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2630 * cfg.patch_size
2631 * cfg.patch_size;
2632
2633 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2634 let num_patches = num_patches_per_side.pow(2);
2635 let position_embedding = num_patches * cfg.hidden_size;
2636
2637 let layer_elems = {
2638 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2639 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2640
2641 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2642 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2643
2644 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2645 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2646 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2647 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2648
2649 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2650 };
2651
2652 post_layernorm
2653 + patch_embedding
2654 + position_embedding
2655 + layer_elems * cfg.num_hidden_layers
2656 };
2657
2658 let elems = text_elems + connector_elems + vision_transformer;
2659
2660 Ok(elems * dtype.size_in_bytes())
2661 }
2662
2663 fn layer_sizes_in_bytes(
2664 &self,
2665 config: &str,
2666 dtype: DType,
2667 weight_pack_factor: usize,
2668 _matformer_config: Option<&MatformerSliceConfig>,
2669 ) -> Result<Vec<usize>> {
2670 let cfg: Idefics3Config = serde_json::from_str(config)?;
2671 let cfg = cfg.text_config;
2672 let per_layer_elems = {
2673 let input_layernorm = cfg.hidden_size;
2674 let post_attention_layernorm = cfg.hidden_size;
2675
2676 let size_in = cfg.hidden_size;
2677 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2678 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2679 let q_proj = size_in * size_q / weight_pack_factor;
2680 let k_proj = size_in * size_kv / weight_pack_factor;
2681 let v_proj = size_in * size_kv / weight_pack_factor;
2682 let o_proj = size_q * size_in / weight_pack_factor;
2683
2684 let h_size = cfg.hidden_size;
2685 let i_size = cfg.intermediate_size;
2686 let gate_proj = h_size * i_size / weight_pack_factor;
2687 let up_proj = h_size * i_size / weight_pack_factor;
2688 let down_proj = i_size * h_size / weight_pack_factor;
2689
2690 input_layernorm
2691 + post_attention_layernorm
2692 + q_proj
2693 + k_proj
2694 + v_proj
2695 + o_proj
2696 + gate_proj
2697 + up_proj
2698 + down_proj
2699 };
2700 Ok(vec![
2701 per_layer_elems * dtype.size_in_bytes();
2702 cfg.num_hidden_layers
2703 ])
2704 }
2705
2706 fn num_layers(&self, config: &str) -> Result<usize> {
2707 let cfg: Idefics3Config = serde_json::from_str(config)?;
2708 Ok(cfg.text_config.num_hidden_layers)
2709 }
2710 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2711 let cfg: Idefics3Config = serde_json::from_str(config)?;
2712 let cfg = &cfg.text_config;
2713
2714 let cfg = ModelConfigMetadata {
2715 max_seq_len: cfg.max_position_embeddings,
2716 num_layers: cfg.num_hidden_layers,
2717 hidden_size: cfg.hidden_size,
2718 num_kv_heads: cfg.num_key_value_heads,
2719 num_attn_heads: cfg.num_attention_heads,
2720 sliding_window: None,
2721 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2722 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2723 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2724 };
2725
2726 Ok(Box::new(cfg))
2727 }
2728
2729 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
2730 Some(vec![NonMappedSubModel::Vision])
2731 }
2732}
2733
/// Loader for MiniCPM-O models.
///
/// A unit struct; its behaviour comes from the `MultimodalModelLoader`, `IsqModelLoader` and
/// `DeviceMappedModelLoader` implementations in this file.
pub struct MiniCpmOLoader;
2740
/// Prompt prefixer for MiniCPM-O: prepends one `(<image>./</image>)` marker per image.
pub struct MiniCpmOPrefixer;
2742
2743impl MultimodalPromptPrefixer for MiniCpmOPrefixer {
2744 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
2745 format!(
2746 "{}{prompt}",
2747 "(<image>./</image>)".repeat(image_indexes.len())
2748 )
2749 }
2750}
2751
2752impl MultimodalModelLoader for MiniCpmOLoader {
2753 fn load(
2754 &self,
2755 config: &str,
2756 vb: ShardedVarBuilder,
2757 normal_loading_metadata: NormalLoadingMetadata,
2758 attention_mechanism: AttentionImplementation,
2759 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
2760 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2761 Ok(Box::new(MiniCpmOModel::new(
2762 &cfg,
2763 vb,
2764 self.is_gptx(config),
2765 normal_loading_metadata,
2766 attention_mechanism,
2767 )?))
2768 }
2769 fn is_gptx(&self, _config: &str) -> bool {
2770 true
2771 }
2772 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2773 let cfg: crate::vision_models::minicpmo::MiniCpmOConfig = serde_json::from_str(config)?;
2774 Ok(Box::new(cfg))
2775 }
2776 fn get_processor(
2777 &self,
2778 _model_config: &str,
2779 processor_config: Option<ProcessorConfig>,
2780 preprocessor_config: PreProcessorConfig,
2781 max_edge: Option<u32>,
2782 ) -> Arc<dyn Processor + Send + Sync> {
2783 Arc::new(MiniCpmOProcessor::new(
2784 processor_config.unwrap_or_default(),
2785 preprocessor_config,
2786 max_edge,
2787 ))
2788 }
2789 fn supports_paged_attention(&self, _config: &str) -> bool {
2790 true
2791 }
2792 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
2793 Arc::new(MiniCpmOPrefixer)
2794 }
2795 fn modalities(&self, _config: &str) -> Result<Modalities> {
2796 Ok(Modalities {
2797 input: vec![SupportedModality::Text, SupportedModality::Vision],
2798 output: vec![SupportedModality::Text],
2799 })
2800 }
2801}
2802
2803impl IsqModelLoader for MiniCpmOLoader {
2804 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2805 Ok(vec![
2806 Regex::new(r"llm.lm_head\.(weight|bias)$")?,
2807 Regex::new(r"llm.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2809 Regex::new(r"llm.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2810 Regex::new(r"llm.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2811 Regex::new(r"llm.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2812 Regex::new(r"llm.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
2814 Regex::new(r"llm.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
2815 Regex::new(r"llm.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
2816 ])
2817 }
2818 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2819 self.isq_layer_regexes(config)
2820 }
2821}
2822
2823impl DeviceMappedModelLoader for MiniCpmOLoader {
2824 fn mapped_max_act_size_elems(
2825 &self,
2826 config: &str,
2827 params: &AutoDeviceMapParams,
2828 ) -> Result<usize> {
2829 let AutoDeviceMapParams::Multimodal {
2830 max_seq_len,
2831 max_batch_size,
2832 max_image_shape: _,
2833 max_num_images,
2834 } = params
2835 else {
2836 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2837 };
2838
2839 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2840
2841 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2842 let img_seq_len = (num_patches + 1) * max_num_images;
2843
2844 let max_text_attn = {
2845 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
2847 max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len
2848 };
2849
2850 Ok(max_text_attn)
2851 }
2852
2853 fn non_mapped_max_act_size_elems(
2854 &self,
2855 config: &str,
2856 params: &AutoDeviceMapParams,
2857 ) -> Result<usize> {
2858 let AutoDeviceMapParams::Multimodal {
2859 max_seq_len: _,
2860 max_batch_size,
2861 max_image_shape: _,
2862 max_num_images,
2863 } = params
2864 else {
2865 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
2866 };
2867
2868 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2869
2870 let num_patches = (cfg.vision_config.image_size / cfg.vision_config.patch_size).pow(2);
2871 let img_seq_len = num_patches + 1;
2872
2873 let max_vision_attn = {
2874 let images_factor = 5;
2876
2877 (max_batch_size * images_factor * max_num_images)
2878 * cfg.vision_config.num_attention_heads
2879 * img_seq_len
2880 * img_seq_len
2881 };
2882
2883 Ok(max_vision_attn)
2884 }
2885
2886 fn non_mapped_size_in_bytes(
2887 &self,
2888 config: &str,
2889 dtype: DType,
2890 weight_pack_factor: usize,
2891 _matformer_config: Option<&MatformerSliceConfig>,
2892 ) -> Result<usize> {
2893 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2894 let text_elems = {
2895 let cfg = &cfg.text_config;
2896
2897 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2898 let lm_head = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2899 let norm = cfg.hidden_size;
2900 embed_tokens + lm_head + norm
2901 };
2902
2903 let vision_transformer = {
2904 let cfg = &cfg.vision_config;
2905
2906 let post_layernorm = cfg.hidden_size;
2907
2908 let conv_config = Conv2dConfig {
2909 stride: cfg.patch_size,
2910 ..Default::default()
2911 };
2912 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
2913 * cfg.patch_size
2914 * cfg.patch_size;
2915
2916 let num_patches_per_side = cfg.image_size / cfg.patch_size;
2917 let num_patches = num_patches_per_side.pow(2);
2918 let position_embedding = num_patches * cfg.hidden_size;
2919
2920 let layer_elems = {
2921 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2922 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
2923
2924 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
2925 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
2926
2927 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2928 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2929 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2930 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
2931
2932 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
2933 };
2934
2935 post_layernorm
2936 + patch_embedding
2937 + position_embedding
2938 + layer_elems * cfg.num_hidden_layers
2939 };
2940
2941 let elems = text_elems + vision_transformer;
2942
2943 Ok(elems * dtype.size_in_bytes())
2944 }
2945
2946 fn layer_sizes_in_bytes(
2947 &self,
2948 config: &str,
2949 dtype: DType,
2950 weight_pack_factor: usize,
2951 _matformer_config: Option<&MatformerSliceConfig>,
2952 ) -> Result<Vec<usize>> {
2953 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2954 let cfg = cfg.text_config;
2955 let per_layer_elems = {
2956 let input_layernorm = cfg.hidden_size;
2957 let post_attention_layernorm = cfg.hidden_size;
2958
2959 let size_in = cfg.hidden_size;
2960 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2961 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2962 let q_proj = size_in * size_q / weight_pack_factor;
2963 let k_proj = size_in * size_kv / weight_pack_factor;
2964 let v_proj = size_in * size_kv / weight_pack_factor;
2965 let o_proj = size_q * size_in / weight_pack_factor;
2966
2967 let h_size = cfg.hidden_size;
2968 let i_size = cfg.intermediate_size;
2969 let gate_proj = h_size * i_size / weight_pack_factor;
2970 let up_proj = h_size * i_size / weight_pack_factor;
2971 let down_proj = i_size * h_size / weight_pack_factor;
2972
2973 input_layernorm
2974 + post_attention_layernorm
2975 + q_proj
2976 + k_proj
2977 + v_proj
2978 + o_proj
2979 + gate_proj
2980 + up_proj
2981 + down_proj
2982 };
2983 Ok(vec![
2984 per_layer_elems * dtype.size_in_bytes();
2985 cfg.num_hidden_layers
2986 ])
2987 }
2988
2989 fn num_layers(&self, config: &str) -> Result<usize> {
2990 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2991 Ok(cfg.text_config.num_hidden_layers)
2992 }
2993 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2994 let cfg: MiniCpmOConfig = serde_json::from_str(config)?;
2995 let cfg = &cfg.text_config;
2996
2997 let cfg = ModelConfigMetadata {
2998 max_seq_len: cfg.max_position_embeddings,
2999 num_layers: cfg.num_hidden_layers,
3000 hidden_size: cfg.hidden_size,
3001 num_kv_heads: cfg.num_key_value_heads,
3002 num_attn_heads: cfg.num_attention_heads,
3003 sliding_window: None,
3004 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3005 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3006 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3007 };
3008
3009 Ok(Box::new(cfg))
3010 }
3011}
3012
/// Loader for Phi-4 multimodal models.
///
/// Implements [`MultimodalModelLoader`], [`IsqModelLoader`] and [`DeviceMappedModelLoader`].
pub struct Phi4MMLoader;
3019
/// [`MultimodalPromptPrefixer`] that prepends numbered `<|image_N|>` and `<|audio_N|>` tags
/// to a prompt.
pub struct Phi4MMPrefixer;
3021
3022impl MultimodalPromptPrefixer for Phi4MMPrefixer {
3023 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3024 format!(
3027 "{}{prompt}",
3028 image_indexes
3029 .into_iter()
3030 .map(|image_index| format!("<|image_{}|>", image_index + 1))
3031 .join("")
3032 )
3033 }
3034 fn prefix_audio(&self, audio_indexes: Vec<usize>, prompt: &str) -> String {
3035 format!(
3038 "{}{prompt}",
3039 audio_indexes
3040 .into_iter()
3041 .map(|audio_index| format!("<|audio_{}|>", audio_index + 1))
3042 .join("")
3043 )
3044 }
3045}
3046
3047impl MultimodalModelLoader for Phi4MMLoader {
3048 fn load(
3049 &self,
3050 config: &str,
3051 vb: ShardedVarBuilder,
3052 normal_loading_metadata: NormalLoadingMetadata,
3053 attention_mechanism: AttentionImplementation,
3054 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3055 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
3056 Ok(Box::new(Phi4MMModel::new(
3057 &cfg,
3058 vb,
3059 self.is_gptx(config),
3060 normal_loading_metadata,
3061 attention_mechanism,
3062 )?))
3063 }
3064 fn is_gptx(&self, _config: &str) -> bool {
3065 true
3066 }
3067 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3068 let cfg: crate::vision_models::phi4::Phi4MMConfig = serde_json::from_str(config)?;
3069 Ok(Box::new(cfg))
3070 }
3071 fn get_processor(
3072 &self,
3073 _model_config: &str,
3074 processor_config: Option<ProcessorConfig>,
3075 preprocessor_config: PreProcessorConfig,
3076 _max_edge: Option<u32>,
3077 ) -> Arc<dyn Processor + Send + Sync> {
3078 Phi4MMProcessor::new_processor(processor_config, preprocessor_config)
3079 }
3080 fn supports_paged_attention(&self, _config: &str) -> bool {
3081 true
3082 }
3083 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3084 true
3085 }
3086 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3087 Arc::new(Phi4MMPrefixer)
3088 }
3089 fn modalities(&self, _config: &str) -> Result<Modalities> {
3090 Ok(Modalities {
3091 input: vec![
3092 SupportedModality::Text,
3093 SupportedModality::Vision,
3094 SupportedModality::Audio,
3095 ],
3096 output: vec![SupportedModality::Text],
3097 })
3098 }
3099}
3100
3101impl IsqModelLoader for Phi4MMLoader {
3102 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3103 Ok(vec![
3104 Regex::new(r"lm_head\.(weight|bias)$")?,
3105 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
3107 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3108 Regex::new(r"layers\.(\d+)\.mlp\.gate_up_proj\.(weight|bias)$")?,
3110 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3111 ])
3112 }
3113 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3114 self.isq_layer_regexes(config)
3115 }
3116}
3117
3118impl DeviceMappedModelLoader for Phi4MMLoader {
3119 fn mapped_max_act_size_elems(
3120 &self,
3121 config: &str,
3122 params: &AutoDeviceMapParams,
3123 ) -> Result<usize> {
3124 let AutoDeviceMapParams::Multimodal {
3126 max_seq_len,
3127 max_batch_size,
3128 max_image_shape: _,
3129 max_num_images,
3130 } = params
3131 else {
3132 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3133 };
3134
3135 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3136
3137 let vcfg = &PHI4_MM_VISION_CFG;
3138
3139 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3140 let img_seq_len = (num_patches + 1) * max_num_images;
3141
3142 let max_text_attn = {
3143 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3145 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3146 };
3147
3148 Ok(max_text_attn)
3149 }
3150
3151 fn non_mapped_max_act_size_elems(
3152 &self,
3153 _config: &str,
3154 params: &AutoDeviceMapParams,
3155 ) -> Result<usize> {
3156 let AutoDeviceMapParams::Multimodal {
3157 max_seq_len: _,
3158 max_batch_size,
3159 max_image_shape,
3160 max_num_images,
3161 } = params
3162 else {
3163 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3164 };
3165
3166 let vcfg = &PHI4_MM_VISION_CFG;
3167
3168 let num_patches = (vcfg.image_size / vcfg.patch_size).pow(2);
3169 let img_seq_len = num_patches + 1;
3170
3171 let max_batch_size = max_batch_size
3172 * (max_image_shape
3173 .0
3174 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3175 * max_image_shape
3176 .1
3177 .div_ceil(phi4::inputs_processor::DYHD_BASE_RESOLUTION)
3178 + 1);
3179
3180 let max_vision_attn = (max_batch_size * max_num_images)
3181 * vcfg.num_attention_heads
3182 * img_seq_len
3183 * img_seq_len;
3184 let max_qkv = 3
3185 * (max_batch_size
3186 * vcfg.num_attention_heads
3187 * img_seq_len
3188 * (vcfg.hidden_size / vcfg.num_attention_heads));
3189
3190 Ok(max_vision_attn + max_qkv)
3191 }
3192
3193 fn non_mapped_size_in_bytes(
3194 &self,
3195 config: &str,
3196 dtype: DType,
3197 weight_pack_factor: usize,
3198 _matformer_config: Option<&MatformerSliceConfig>,
3199 ) -> Result<usize> {
3200 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3201 let elems = {
3202 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3203 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3205 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3206 } else {
3207 0
3208 };
3209 let norm = cfg.hidden_size;
3210
3211 let image_embed = if let Some(img_embed) = &cfg.embd_layer.image_embd_layer {
3212 let projection_cls = img_embed
3213 .projection_cls
3214 .clone()
3215 .unwrap_or("linear".to_string());
3216 let with_learnable_separator = img_embed.with_learnable_separator.unwrap_or(false);
3217 let use_hd_transform = img_embed.use_hd_transform.unwrap_or(false);
3218 let image_dim_out = PHI4_MM_VISION_CFG.hidden_size;
3219
3220 let proj = match (projection_cls.as_str(), use_hd_transform) {
3221 ("linear", _) => image_dim_out * cfg.hidden_size + cfg.hidden_size,
3222 ("mlp", true) => {
3223 let a = (image_dim_out * 4) * cfg.hidden_size + cfg.hidden_size;
3224 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3225 a + b
3226 }
3227 ("mlp", false) => {
3228 let a = image_dim_out * cfg.hidden_size + cfg.hidden_size;
3229 let b = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3230 a + b
3231 }
3232 _ => {
3233 anyhow::bail!("projection_cls=`{projection_cls}` not implemented.");
3234 }
3235 };
3236
3237 let (glb_gn, sub_gn) = if with_learnable_separator {
3238 let glb_gn = image_dim_out * 4;
3239 let sub_gn = image_dim_out * 4;
3240 (glb_gn, sub_gn)
3241 } else {
3242 (0, 0)
3243 };
3244
3245 let vision_transformer = {
3246 let cfg = &PHI4_MM_VISION_CFG;
3247
3248 let post_layernorm = cfg.hidden_size;
3249
3250 let conv_config = Conv2dConfig {
3251 stride: cfg.patch_size,
3252 ..Default::default()
3253 };
3254 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3255 * cfg.patch_size
3256 * cfg.patch_size;
3257
3258 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3259 let num_patches = num_patches_per_side.pow(2);
3260 let position_embedding = num_patches * cfg.hidden_size;
3261
3262 let layer_elems = {
3263 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3264 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3265
3266 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3267 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3268
3269 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3270 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3271 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3272 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3273
3274 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3275 };
3276
3277 post_layernorm
3278 + patch_embedding
3279 + position_embedding
3280 + layer_elems * cfg.num_hidden_layers
3281 };
3282
3283 proj + glb_gn + sub_gn + vision_transformer
3284 } else {
3285 0
3286 };
3287
3288 embed_tokens + lm_head + norm + image_embed
3289 };
3290
3291 Ok(elems * dtype.size_in_bytes())
3292 }
3293
3294 fn layer_sizes_in_bytes(
3295 &self,
3296 config: &str,
3297 dtype: DType,
3298 weight_pack_factor: usize,
3299 _matformer_config: Option<&MatformerSliceConfig>,
3300 ) -> Result<Vec<usize>> {
3301 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3302 let per_layer_elems = {
3303 let input_layernorm = cfg.hidden_size;
3304 let post_attention_layernorm = cfg.hidden_size;
3305
3306 let size_in = cfg.hidden_size;
3307 let head_dim = cfg.head_dim();
3308 let op_size =
3309 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads() * head_dim;
3310 let qkv_proj = size_in * op_size / weight_pack_factor;
3311 let o_proj = (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor;
3312
3313 let h_size = cfg.hidden_size;
3314 let i_size = cfg.intermediate_size;
3315 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
3316 let down_proj = h_size * i_size / weight_pack_factor;
3317
3318 input_layernorm
3319 + post_attention_layernorm
3320 + qkv_proj
3321 + o_proj
3322 + gate_up_proj
3323 + down_proj
3324 };
3325 Ok(vec![
3326 per_layer_elems * dtype.size_in_bytes();
3327 cfg.num_hidden_layers
3328 ])
3329 }
3330
3331 fn num_layers(&self, config: &str) -> Result<usize> {
3332 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3333 Ok(cfg.num_hidden_layers)
3334 }
3335
3336 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3337 let cfg: Phi4MMConfig = serde_json::from_str(config)?;
3338
3339 let cfg = ModelConfigMetadata {
3340 max_seq_len: cfg.max_position_embeddings,
3341 num_layers: cfg.num_hidden_layers,
3342 hidden_size: cfg.hidden_size,
3343 num_kv_heads: cfg.num_key_value_heads(),
3344 num_attn_heads: cfg.num_attention_heads,
3345 sliding_window: cfg.sliding_window,
3346 k_head_dim: cfg.head_dim(),
3347 v_head_dim: cfg.head_dim(),
3348 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3349 };
3350
3351 Ok(Box::new(cfg))
3352 }
3353
3354 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3355 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
3356 }
3357}
3358
/// Loader for Qwen2.5-VL models.
///
/// Implements [`MultimodalModelLoader`], [`IsqModelLoader`] and [`DeviceMappedModelLoader`].
pub struct Qwen2_5VLLoader;
3365
/// [`MultimodalPromptPrefixer`] that prepends one vision-start / image-pad / vision-end
/// marker group per image to a prompt.
pub struct Qwen2_5VLPrefixer;
3367
3368impl MultimodalPromptPrefixer for Qwen2_5VLPrefixer {
3369 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
3370 format!(
3371 "{}{prompt}",
3372 format!(
3373 "{}{}{}",
3374 Qwen2_5VLProcessor::VISION_START,
3375 Qwen2_5VLProcessor::IMAGE_PAD,
3376 Qwen2_5VLProcessor::VISION_END
3377 )
3378 .repeat(image_indexes.len())
3379 )
3380 }
3381}
3382
3383impl MultimodalModelLoader for Qwen2_5VLLoader {
3384 fn load(
3385 &self,
3386 config: &str,
3387 vb: ShardedVarBuilder,
3388 normal_loading_metadata: NormalLoadingMetadata,
3389 attention_mechanism: AttentionImplementation,
3390 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3391 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3392 Ok(Box::new(Qwen2_5VLModel::new(
3393 &cfg,
3394 vb,
3395 self.is_gptx(config),
3396 normal_loading_metadata,
3397 attention_mechanism,
3398 )?))
3399 }
3400 fn is_gptx(&self, _config: &str) -> bool {
3401 true
3402 }
3403 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3404 let config: Qwen2_5VLConfig = serde_json::from_str(config)?;
3405 Ok(Box::new(config))
3406 }
3407 fn get_processor(
3408 &self,
3409 _model_config: &str,
3410 _processor_config: Option<ProcessorConfig>,
3411 _preprocessor_config: PreProcessorConfig,
3412 max_edge: Option<u32>,
3413 ) -> Arc<dyn Processor + Send + Sync> {
3414 Arc::new(Qwen2_5VLProcessor::new(max_edge))
3415 }
3416 fn supports_paged_attention(&self, _config: &str) -> bool {
3417 false
3418 }
3419 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3420 Arc::new(Qwen2_5VLPrefixer)
3421 }
3422 fn modalities(&self, _config: &str) -> Result<Modalities> {
3423 Ok(Modalities {
3424 input: vec![SupportedModality::Text, SupportedModality::Vision],
3425 output: vec![SupportedModality::Text],
3426 })
3427 }
3428}
3429
3430impl IsqModelLoader for Qwen2_5VLLoader {
3431 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3432 Ok(vec![
3433 Regex::new(r"lm_head\.(weight|bias)$")?,
3434 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3436 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3437 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3438 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3439 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3441 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3442 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3443 ])
3444 }
3445 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3446 self.isq_layer_regexes(config)
3447 }
3448}
3449
3450impl DeviceMappedModelLoader for Qwen2_5VLLoader {
3451 fn mapped_max_act_size_elems(
3452 &self,
3453 config: &str,
3454 params: &AutoDeviceMapParams,
3455 ) -> Result<usize> {
3456 let AutoDeviceMapParams::Multimodal {
3457 max_seq_len,
3458 max_batch_size,
3459 max_image_shape,
3460 max_num_images,
3461 } = params
3462 else {
3463 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3464 };
3465
3466 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3467
3468 let img_seq_len = {
3469 let cfg = &cfg.vision_config;
3470 let grid_t = max_num_images / cfg.temporal_patch_size;
3471 let grid_h = max_image_shape.0 / cfg.patch_size;
3472 let grid_w = max_image_shape.1 / cfg.patch_size;
3473 grid_t * grid_h * grid_w
3474 };
3475 let img_seq_len = img_seq_len * max_num_images;
3476
3477 let max_text_attn = {
3478 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3480 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
3481 };
3482
3483 Ok(max_text_attn)
3484 }
3485
3486 fn non_mapped_max_act_size_elems(
3487 &self,
3488 config: &str,
3489 params: &AutoDeviceMapParams,
3490 ) -> Result<usize> {
3491 let AutoDeviceMapParams::Multimodal {
3492 max_seq_len: _,
3493 max_batch_size,
3494 max_image_shape,
3495 max_num_images,
3496 } = params
3497 else {
3498 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3499 };
3500
3501 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3502
3503 let img_seq_len = {
3504 let cfg = &cfg.vision_config;
3505 let grid_t = max_num_images / cfg.temporal_patch_size;
3506 let grid_h = max_image_shape.0 / cfg.patch_size;
3507 let grid_w = max_image_shape.1 / cfg.patch_size;
3508 grid_t * grid_h * grid_w
3509 };
3510
3511 let max_vision_attn = {
3512 let cfg = &cfg.vision_config;
3513 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
3514 };
3515
3516 Ok(max_vision_attn)
3517 }
3518
3519 fn non_mapped_size_in_bytes(
3520 &self,
3521 config: &str,
3522 dtype: DType,
3523 weight_pack_factor: usize,
3524 _matformer_config: Option<&MatformerSliceConfig>,
3525 ) -> Result<usize> {
3526 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3527 let text_elems = {
3528 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3529 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3531 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3532 } else {
3533 0
3534 };
3535 let norm = cfg.hidden_size;
3536 embed_tokens + lm_head + norm
3537 };
3538
3539 let patch_merger = {
3540 let cfg = &cfg.vision_config;
3541 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
3542
3543 let mlp0 = hidden_size * hidden_size + hidden_size;
3544 let mlp2 = hidden_size * cfg.hidden_size + cfg.hidden_size;
3545
3546 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3547
3548 mlp0 + mlp2 + ln_q
3549 };
3550
3551 let patch_embed = {
3552 let cfg = &cfg.vision_config;
3553 let conv_cfg = Conv3dConfig {
3554 stride: cfg.patch_size,
3555 ..Default::default()
3556 };
3557 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
3558 cfg.in_chans * cfg.hidden_size / conv_cfg.groups
3559 * kernel_sizes[0]
3560 * kernel_sizes[1]
3561 * kernel_sizes[2]
3562 };
3563
3564 let encoder_layer = {
3565 let cfg = &cfg.vision_config;
3566 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3567 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3568
3569 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
3570 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3571 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
3572
3573 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
3574 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3575
3576 norm1 + norm2 + fc1 + fc2 + qkv + out
3577 };
3578
3579 let elems =
3580 text_elems + patch_merger + patch_embed + encoder_layer * cfg.vision_config.depth;
3581
3582 Ok(elems * dtype.size_in_bytes())
3583 }
3584
3585 fn layer_sizes_in_bytes(
3586 &self,
3587 config: &str,
3588 dtype: DType,
3589 weight_pack_factor: usize,
3590 _matformer_config: Option<&MatformerSliceConfig>,
3591 ) -> Result<Vec<usize>> {
3592 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3593 let per_layer_elems = {
3594 let input_layernorm = cfg.hidden_size;
3595 let post_attention_layernorm = cfg.hidden_size;
3596
3597 let size_in = cfg.hidden_size;
3598 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
3599 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
3600 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3601 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3602 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3603 let o_proj = size_q * size_in / weight_pack_factor;
3604
3605 let h_size = cfg.hidden_size;
3606 let i_size = cfg.intermediate_size;
3607 let gate_proj = h_size * i_size / weight_pack_factor;
3608 let up_proj = h_size * i_size / weight_pack_factor;
3609 let down_proj = i_size * h_size / weight_pack_factor;
3610
3611 input_layernorm
3612 + post_attention_layernorm
3613 + q_proj
3614 + k_proj
3615 + v_proj
3616 + o_proj
3617 + gate_proj
3618 + up_proj
3619 + down_proj
3620 };
3621 Ok(vec![
3622 per_layer_elems * dtype.size_in_bytes();
3623 cfg.num_hidden_layers
3624 ])
3625 }
3626
3627 fn num_layers(&self, config: &str) -> Result<usize> {
3628 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3629 Ok(cfg.num_hidden_layers)
3630 }
3631
3632 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3633 let cfg: Qwen2_5VLConfig = serde_json::from_str(config)?;
3634
3635 let cfg = ModelConfigMetadata {
3636 max_seq_len: cfg.max_position_embeddings,
3637 num_layers: cfg.num_hidden_layers,
3638 hidden_size: cfg.hidden_size,
3639 num_kv_heads: cfg.num_key_value_heads,
3640 num_attn_heads: cfg.num_attention_heads,
3641 sliding_window: cfg.sliding_window,
3642 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3643 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3644 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3645 };
3646
3647 Ok(Box::new(cfg))
3648 }
3649
3650 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3651 Some(vec![NonMappedSubModel::Vision])
3652 }
3653}
3654
/// Loader for Gemma 3 models, covering both the text-only and the with-vision config variants.
///
/// Implements [`MultimodalModelLoader`], [`IsqModelLoader`] and [`DeviceMappedModelLoader`].
pub struct Gemma3Loader;
3661
/// [`MultimodalPromptPrefixer`] for Gemma 3; it returns the prompt unchanged.
pub struct Gemma3Prefixer;
3663
3664impl MultimodalPromptPrefixer for Gemma3Prefixer {
3665 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
3666 prompt.to_string()
3667 }
3668}
3669
3670impl MultimodalModelLoader for Gemma3Loader {
3671 fn load(
3672 &self,
3673 config: &str,
3674 vb: ShardedVarBuilder,
3675 normal_loading_metadata: NormalLoadingMetadata,
3676 attention_mechanism: AttentionImplementation,
3677 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
3678 let cfg: Gemma3Config = serde_json::from_str(config)?;
3679 Ok(Box::new(Gemma3Model::new(
3680 &cfg,
3681 vb,
3682 self.is_gptx(config),
3683 normal_loading_metadata,
3684 attention_mechanism,
3685 )?))
3686 }
3687 fn is_gptx(&self, _config: &str) -> bool {
3688 true
3689 }
3690 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3691 let config: Gemma3Config = serde_json::from_str(config)?;
3692 Ok(Box::new(config))
3693 }
3694 fn get_processor(
3695 &self,
3696 config: &str,
3697 processor_config: Option<ProcessorConfig>,
3698 _preprocessor_config: PreProcessorConfig,
3699 _max_edge: Option<u32>,
3700 ) -> Arc<dyn Processor + Send + Sync> {
3701 let config: Gemma3Config = serde_json::from_str(config).unwrap();
3702 Arc::new(Gemma3Processor::new(
3704 processor_config.unwrap_or_default(),
3705 matches!(config, Gemma3Config::WithVision { .. }),
3706 ))
3707 }
3708 fn supports_paged_attention(&self, _config: &str) -> bool {
3709 true
3710 }
3711 fn supports_prefix_cacher(&self, _config: &str) -> bool {
3712 true
3713 }
3714 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
3715 Arc::new(Gemma3Prefixer)
3716 }
3717 fn modalities(&self, _config: &str) -> Result<Modalities> {
3718 Ok(Modalities {
3719 input: vec![SupportedModality::Text, SupportedModality::Vision],
3720 output: vec![SupportedModality::Text],
3721 })
3722 }
3723}
3724
3725impl IsqModelLoader for Gemma3Loader {
3726 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3727 Ok(vec![
3728 Regex::new(r"lm_head\.(weight|bias)$")?,
3729 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3731 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3732 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3733 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3734 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3736 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3737 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3738 ])
3739 }
3740 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
3741 Ok(vec![
3742 Regex::new(r"lm_head\.(weight|bias)$")?,
3743 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3745 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3746 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3747 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3748 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3750 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3751 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3752 ])
3753 }
3754}
3755
3756impl DeviceMappedModelLoader for Gemma3Loader {
3757 fn mapped_max_act_size_elems(
3758 &self,
3759 config: &str,
3760 params: &AutoDeviceMapParams,
3761 ) -> Result<usize> {
3762 let AutoDeviceMapParams::Multimodal {
3763 max_seq_len,
3764 max_batch_size,
3765 max_image_shape: _,
3766 max_num_images,
3767 } = params
3768 else {
3769 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3770 };
3771
3772 let cfg: Gemma3Config = serde_json::from_str(config)?;
3773
3774 match cfg {
3775 Gemma3Config::Text(text_config) => Ok(max_batch_size
3776 * text_config.num_attention_heads
3777 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2)),
3778 Gemma3Config::WithVision {
3779 text_config,
3780 vision_config,
3781 ..
3782 } => {
3783 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3784 let img_seq_len = (num_patches + 1) * max_num_images;
3785
3786 let max_text_attn = {
3787 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
3789 max_batch_size * text_config.num_attention_heads * max_seq_len * max_seq_len
3790 };
3791 Ok(max_text_attn)
3792 }
3793 }
3794 }
3795
3796 fn non_mapped_max_act_size_elems(
3797 &self,
3798 config: &str,
3799 params: &AutoDeviceMapParams,
3800 ) -> Result<usize> {
3801 let AutoDeviceMapParams::Multimodal {
3802 max_seq_len: _,
3803 max_batch_size,
3804 max_image_shape: _,
3805 max_num_images,
3806 } = params
3807 else {
3808 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
3809 };
3810
3811 let cfg: Gemma3Config = serde_json::from_str(config)?;
3812
3813 match cfg {
3814 Gemma3Config::WithVision { vision_config, .. } => {
3815 let num_patches = (vision_config.image_size / vision_config.patch_size).pow(2);
3816 let img_seq_len = num_patches + 1;
3817
3818 let max_vision_attn = {
3819 (max_batch_size * max_num_images)
3820 * vision_config.num_attention_heads
3821 * img_seq_len
3822 * img_seq_len
3823 };
3824
3825 Ok(max_vision_attn)
3826 }
3827 Gemma3Config::Text(_) => Ok(0),
3828 }
3829 }
3830
3831 fn non_mapped_size_in_bytes(
3832 &self,
3833 config: &str,
3834 dtype: DType,
3835 weight_pack_factor: usize,
3836 _matformer_config: Option<&MatformerSliceConfig>,
3837 ) -> Result<usize> {
3838 let cfg: Gemma3Config = serde_json::from_str(config)?;
3839
3840 let text_elems = {
3841 let cfg = match &cfg {
3842 Gemma3Config::Text(cfg) => cfg,
3843 Gemma3Config::WithVision { text_config, .. } => text_config,
3844 };
3845 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3846 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3848 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3849 } else {
3850 0
3851 };
3852 let norm = cfg.hidden_size;
3853 embed_tokens + lm_head + norm
3854 };
3855
3856 let vision_transformer = if let Gemma3Config::WithVision {
3857 vision_config: cfg, ..
3858 } = &cfg
3859 {
3860 let post_layernorm = cfg.hidden_size;
3861
3862 let conv_config = Conv2dConfig {
3863 stride: cfg.patch_size,
3864 ..Default::default()
3865 };
3866 let patch_embedding = cfg.num_channels * cfg.hidden_size / conv_config.groups
3867 * cfg.patch_size
3868 * cfg.patch_size;
3869
3870 let num_patches_per_side = cfg.image_size / cfg.patch_size;
3871 let num_patches = num_patches_per_side.pow(2);
3872 let position_embedding = num_patches * cfg.hidden_size;
3873
3874 let layer_elems = {
3875 let layer_norm_1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3876 let layer_norm_2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
3877
3878 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
3879 let fc2 = cfg.intermediate_size * cfg.hidden_size + cfg.hidden_size;
3880
3881 let q_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3882 let k_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3883 let v_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3884 let o_proj = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
3885
3886 layer_norm_1 + layer_norm_2 + fc1 + fc2 + q_proj + k_proj + v_proj + o_proj
3887 };
3888
3889 post_layernorm
3890 + patch_embedding
3891 + position_embedding
3892 + layer_elems * cfg.num_hidden_layers
3893 } else {
3894 0
3895 };
3896
3897 let elems = text_elems + vision_transformer;
3898
3899 Ok(elems * dtype.size_in_bytes())
3900 }
3901
3902 fn layer_sizes_in_bytes(
3903 &self,
3904 config: &str,
3905 dtype: DType,
3906 weight_pack_factor: usize,
3907 _matformer_config: Option<&MatformerSliceConfig>,
3908 ) -> Result<Vec<usize>> {
3909 let cfg: Gemma3Config = serde_json::from_str(config)?;
3910
3911 let txt_cfg = match &cfg {
3912 Gemma3Config::Text(cfg) => cfg,
3913 Gemma3Config::WithVision { text_config, .. } => text_config,
3914 };
3915 let per_layer_elems = {
3916 let cfg = txt_cfg;
3917
3918 let input_layernorm = cfg.hidden_size;
3919 let post_attention_layernorm = cfg.hidden_size;
3920
3921 let size_in = cfg.hidden_size;
3922 let size_q = cfg.head_dim * cfg.num_attention_heads;
3923 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
3924 let q_proj =
3925 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
3926 let k_proj =
3927 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3928 let v_proj =
3929 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
3930 let o_proj =
3931 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
3932
3933 let h_size = cfg.hidden_size;
3934 let i_size = cfg.intermediate_size;
3935 let gate_proj = h_size * i_size / weight_pack_factor;
3936 let up_proj = h_size * i_size / weight_pack_factor;
3937 let down_proj = i_size * h_size / weight_pack_factor;
3938
3939 input_layernorm
3940 + post_attention_layernorm
3941 + q_proj
3942 + k_proj
3943 + v_proj
3944 + o_proj
3945 + gate_proj
3946 + up_proj
3947 + down_proj
3948 };
3949 Ok(vec![
3950 per_layer_elems * dtype.size_in_bytes();
3951 txt_cfg.num_hidden_layers
3952 ])
3953 }
3954
3955 fn num_layers(&self, config: &str) -> Result<usize> {
3956 let cfg: Gemma3Config = serde_json::from_str(config)?;
3957
3958 let txt_cfg = match &cfg {
3959 Gemma3Config::Text(cfg) => cfg,
3960 Gemma3Config::WithVision { text_config, .. } => text_config,
3961 };
3962
3963 Ok(txt_cfg.num_hidden_layers)
3964 }
3965
3966 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3967 let cfg: Gemma3Config = serde_json::from_str(config)?;
3968
3969 let cfg = match &cfg {
3970 Gemma3Config::Text(cfg) => cfg,
3971 Gemma3Config::WithVision { text_config, .. } => text_config,
3972 };
3973
3974 let cfg = ModelConfigMetadata {
3975 max_seq_len: cfg.max_position_embeddings,
3976 num_layers: cfg.num_hidden_layers,
3977 hidden_size: cfg.hidden_size,
3978 num_kv_heads: cfg.num_key_value_heads,
3979 num_attn_heads: cfg.num_attention_heads,
3980 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3982 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3983 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3984 };
3985
3986 Ok(Box::new(cfg))
3987 }
3988
3989 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
3990 Some(vec![NonMappedSubModel::Vision])
3991 }
3992}
3993
/// Loader for Mistral 3 multimodal models.
///
/// Implements [`MultimodalModelLoader`] and [`IsqModelLoader`].
pub struct Mistral3Loader;
4000
/// [`MultimodalPromptPrefixer`] for Mistral 3; it returns the prompt unchanged.
pub struct Mistral3Prefixer;
4002
4003impl MultimodalPromptPrefixer for Mistral3Prefixer {
4004 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4005 prompt.to_string()
4006 }
4007}
4008
4009impl MultimodalModelLoader for Mistral3Loader {
4010 fn load(
4011 &self,
4012 config: &str,
4013 vb: ShardedVarBuilder,
4014 normal_loading_metadata: NormalLoadingMetadata,
4015 attention_mechanism: AttentionImplementation,
4016 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4017 let mut cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
4018 cfg.propagate_quantization_config();
4019 Ok(Box::new(Mistral3Model::new(
4020 &cfg,
4021 vb,
4022 self.is_gptx(config),
4023 normal_loading_metadata,
4024 attention_mechanism,
4025 )?))
4026 }
4027 fn is_gptx(&self, _config: &str) -> bool {
4028 true
4029 }
4030 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4031 let cfg: crate::vision_models::mistral3::Mistral3Config = serde_json::from_str(config)?;
4032 Ok(Box::new(cfg))
4033 }
4034 fn get_processor(
4035 &self,
4036 _model_config: &str,
4037 processor_config: Option<ProcessorConfig>,
4038 _preprocessor_config: PreProcessorConfig,
4039 _max_edge: Option<u32>,
4040 ) -> Arc<dyn Processor + Send + Sync> {
4041 Arc::new(Mistral3Processor::new(processor_config.unwrap_or_default()))
4042 }
4043 fn supports_paged_attention(&self, _config: &str) -> bool {
4044 true
4045 }
4046 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4047 true
4048 }
4049 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4050 Arc::new(Mistral3Prefixer)
4051 }
4052 fn modalities(&self, _config: &str) -> Result<Modalities> {
4053 Ok(Modalities {
4054 input: vec![SupportedModality::Text, SupportedModality::Vision],
4055 output: vec![SupportedModality::Text],
4056 })
4057 }
4058}
4059
4060impl IsqModelLoader for Mistral3Loader {
4061 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4062 Ok(vec![
4063 Regex::new(r"lm_head\.(weight|bias)$")?,
4064 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4066 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4067 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4068 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4069 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4071 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4072 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4073 ])
4074 }
4075 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4076 Ok(vec![
4077 Regex::new(r"lm_head\.(weight|bias)$")?,
4078 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4080 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4081 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4082 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4083 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4085 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4086 Regex::new(r"language_model\.model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4087 ])
4088 }
4089}
4090
4091#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4092impl DeviceMappedModelLoader for Mistral3Loader {
4093 fn mapped_max_act_size_elems(
4094 &self,
4095 config: &str,
4096 params: &AutoDeviceMapParams,
4097 ) -> Result<usize> {
4098 let cfg: Mistral3Config = serde_json::from_str(config)?;
4099 let vcfg = &cfg.vision_config;
4100 let tcfg = &cfg.text_config;
4101
4102 let AutoDeviceMapParams::Multimodal {
4103 max_seq_len,
4104 max_batch_size,
4105 max_image_shape: (mut height, mut width),
4106 max_num_images,
4107 } = params
4108 else {
4109 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4110 };
4111
4112 let img_seq_len = {
4113 let (max_height, max_width) = (1540, 1540);
4117 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4118 if ratio > 1. {
4119 height = (height as f64 / ratio).floor() as usize;
4120 width = (width as f64 / ratio).floor() as usize;
4121 }
4122
4123 let num_height_tokens = (height - 1) / vcfg.patch_size + 1;
4124 let num_width_tokens = (width - 1) / vcfg.patch_size + 1;
4125
4126 height = num_height_tokens * vcfg.patch_size;
4127 width = num_width_tokens * vcfg.patch_size;
4128
4129 let num_height_tokens = height / vcfg.patch_size;
4130 let num_width_tokens = width / vcfg.patch_size;
4131
4132 (num_width_tokens + 1) * num_height_tokens
4133 };
4134
4135 let max_seq_len = img_seq_len * max_num_images + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4137 Ok(max_batch_size * tcfg.num_attention_heads * max_seq_len * max_seq_len)
4138 }
4139
4140 fn non_mapped_max_act_size_elems(
4141 &self,
4142 config: &str,
4143 params: &AutoDeviceMapParams,
4144 ) -> Result<usize> {
4145 let cfg: Mistral3Config = serde_json::from_str(config)?;
4146 let cfg = &cfg.vision_config;
4147
4148 let AutoDeviceMapParams::Multimodal {
4149 max_seq_len: _,
4150 max_batch_size,
4151 max_image_shape: (mut height, mut width),
4152 max_num_images,
4153 } = params
4154 else {
4155 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4156 };
4157
4158 let img_seq_len = {
4159 let (max_height, max_width) = (1540, 1540);
4163 let ratio = (height as f64 / max_height as f64).max(width as f64 / max_width as f64);
4164 if ratio > 1. {
4165 height = (height as f64 / ratio).floor() as usize;
4166 width = (width as f64 / ratio).floor() as usize;
4167 }
4168
4169 let num_height_tokens = (height - 1) / cfg.patch_size + 1;
4170 let num_width_tokens = (width - 1) / cfg.patch_size + 1;
4171
4172 height = num_height_tokens * cfg.patch_size;
4173 width = num_width_tokens * cfg.patch_size;
4174
4175 let num_height_tokens = height / cfg.patch_size;
4176 let num_width_tokens = width / cfg.patch_size;
4177
4178 (num_width_tokens + 1) * num_height_tokens
4179 };
4180
4181 Ok((max_batch_size * max_num_images) * cfg.num_attention_heads * img_seq_len * img_seq_len)
4182 }
4183
4184 fn non_mapped_size_in_bytes(
4185 &self,
4186 config: &str,
4187 dtype: DType,
4188 weight_pack_factor: usize,
4189 _matformer_config: Option<&MatformerSliceConfig>,
4190 ) -> Result<usize> {
4191 let cfg: Mistral3Config = serde_json::from_str(config)?;
4192
4193 let text_elems = {
4194 let cfg = &cfg.text_config;
4195
4196 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4197 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4199 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4200 } else {
4201 0
4202 };
4203 let norm = cfg.hidden_size;
4204 embed_tokens + lm_head + norm
4205 };
4206
4207 let vision_elems = {
4208 let cfg = &cfg.vision_config;
4209
4210 let patch_embed = {
4211 let conv_cfg = Conv2dConfig {
4212 stride: cfg.patch_size,
4213 ..Default::default()
4214 };
4215 cfg.num_channels * cfg.hidden_size / conv_cfg.groups
4216 * cfg.patch_size
4217 * cfg.patch_size
4218 * cfg.patch_size
4219 };
4220 let ln_pre = cfg.hidden_size;
4221 let vision_layer = {
4222 let attn_norm = cfg.hidden_size;
4223 let ffn_norm = cfg.hidden_size;
4224
4225 let gate = cfg.hidden_size * cfg.intermediate_size;
4226 let up = cfg.hidden_size * cfg.intermediate_size;
4227 let down = cfg.hidden_size * cfg.intermediate_size;
4228
4229 let q = cfg.hidden_size * cfg.hidden_size;
4230 let k = cfg.hidden_size * cfg.hidden_size;
4231 let v = cfg.hidden_size * cfg.hidden_size;
4232 let o = cfg.hidden_size * cfg.hidden_size;
4233
4234 attn_norm + ffn_norm + gate + up + down + q + k + v + o
4235 };
4236
4237 patch_embed + ln_pre + vision_layer * cfg.num_hidden_layers
4238 };
4239
4240 let elems = text_elems + vision_elems;
4241
4242 Ok(elems * dtype.size_in_bytes())
4243 }
4244
4245 fn layer_sizes_in_bytes(
4246 &self,
4247 config: &str,
4248 dtype: DType,
4249 weight_pack_factor: usize,
4250 _matformer_config: Option<&MatformerSliceConfig>,
4251 ) -> Result<Vec<usize>> {
4252 let cfg: Mistral3Config = serde_json::from_str(config)?;
4253 let cfg = &cfg.text_config;
4254
4255 let per_layer_elems = {
4256 let input_layernorm = cfg.hidden_size;
4257 let post_attention_layernorm = cfg.hidden_size;
4258
4259 let size_in = cfg.hidden_size;
4260 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4261 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4262 let q_proj = size_in * size_q / weight_pack_factor;
4263 let k_proj = size_in * size_kv / weight_pack_factor;
4264 let v_proj = size_in * size_kv / weight_pack_factor;
4265 let o_proj = size_q * size_in / weight_pack_factor;
4266
4267 let h_size = cfg.hidden_size;
4268 let i_size = cfg.intermediate_size;
4269 let gate_proj = h_size * i_size / weight_pack_factor;
4270 let up_proj = h_size * i_size / weight_pack_factor;
4271 let down_proj = i_size * h_size / weight_pack_factor;
4272
4273 input_layernorm
4274 + post_attention_layernorm
4275 + q_proj
4276 + k_proj
4277 + v_proj
4278 + o_proj
4279 + gate_proj
4280 + up_proj
4281 + down_proj
4282 };
4283 Ok(vec![
4284 per_layer_elems * dtype.size_in_bytes();
4285 cfg.num_hidden_layers
4286 ])
4287 }
4288
4289 fn num_layers(&self, config: &str) -> Result<usize> {
4290 let cfg: Mistral3Config = serde_json::from_str(config)?;
4291 let cfg = &cfg.text_config;
4292 Ok(cfg.num_hidden_layers)
4293 }
4294
4295 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4296 let cfg: Mistral3Config = serde_json::from_str(config)?;
4297 let cfg = &cfg.text_config;
4298
4299 let cfg = ModelConfigMetadata {
4300 max_seq_len: cfg.max_position_embeddings,
4301 num_layers: cfg.num_hidden_layers,
4302 hidden_size: cfg.hidden_size,
4303 num_kv_heads: cfg.num_key_value_heads,
4304 num_attn_heads: cfg.num_attention_heads,
4305 sliding_window: cfg.sliding_window,
4306 k_head_dim: cfg.head_dim(),
4307 v_head_dim: cfg.head_dim(),
4308 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4309 };
4310
4311 Ok(Box::new(cfg))
4312 }
4313
4314 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4315 Some(vec![NonMappedSubModel::Vision])
4316 }
4317}
4318
4319pub struct VLlama4Loader;
4325
4326pub struct VLlama4Prefixer;
4327
4328impl MultimodalPromptPrefixer for VLlama4Prefixer {
4329 fn prefix_image(&self, image_indexes: Vec<usize>, prompt: &str) -> String {
4330 format!(
4331 "{}{prompt}",
4332 llama4::IMAGE_TOKEN.repeat(image_indexes.len())
4333 )
4334 }
4335}
4336
4337impl MultimodalModelLoader for VLlama4Loader {
4338 fn load(
4339 &self,
4340 config: &str,
4341 vb: ShardedVarBuilder,
4342 normal_loading_metadata: NormalLoadingMetadata,
4343 attention_mechanism: AttentionImplementation,
4344 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4345 let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4346 cfg.propagate_quantization_config();
4347 Ok(Box::new(Llama4Model::new(
4348 &cfg,
4349 vb,
4350 self.is_gptx(config),
4351 normal_loading_metadata,
4352 attention_mechanism,
4353 )?))
4354 }
4355 fn is_gptx(&self, _config: &str) -> bool {
4356 false
4357 }
4358 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4359 let mut cfg: crate::vision_models::llama4::Llama4Config = serde_json::from_str(config)?;
4360 cfg.propagate_quantization_config();
4361 Ok(Box::new(cfg))
4362 }
4363 fn get_processor(
4364 &self,
4365 _model_config: &str,
4366 processor_config: Option<ProcessorConfig>,
4367 _preprocessor_config: PreProcessorConfig,
4368 _max_edge: Option<u32>,
4369 ) -> Arc<dyn Processor + Send + Sync> {
4370 Arc::new(Llama4Processor::new(&processor_config.unwrap()))
4371 }
4372 fn supports_paged_attention(&self, _config: &str) -> bool {
4373 true
4374 }
4375 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4376 Arc::new(VLlama4Prefixer)
4377 }
4378 fn modalities(&self, _config: &str) -> Result<Modalities> {
4379 Ok(Modalities {
4380 input: vec![SupportedModality::Text, SupportedModality::Vision],
4381 output: vec![SupportedModality::Text],
4382 })
4383 }
4384}
4385
4386impl IsqModelLoader for VLlama4Loader {
4387 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4388 Ok(vec![
4389 Regex::new(r"lm_head\.(weight|bias)$")?,
4390 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4392 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4393 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4394 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4395 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_up_proj\.(weight|bias)$")?,
4397 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.gate_proj\.(weight|bias)$")?,
4398 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.up_proj\.(weight|bias)$")?,
4399 Regex::new(r"layers\.(\d+)\.feed_forward\.experts\.down_proj\.(weight|bias)$")?,
4400 Regex::new(r"layers\.(\d+)\.feed_forward\.router\.(weight|bias)$")?,
4401 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4402 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4403 Regex::new(r"layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$")?,
4404 Regex::new(r"layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$")?,
4406 Regex::new(r"layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$")?,
4407 Regex::new(r"layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$")?,
4408 ])
4409 }
4410 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4411 Ok(vec![
4412 Regex::new(r"lm_head\.(weight|bias)$")?,
4413 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4415 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4416 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4417 Regex::new(r"language_model\.model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4418 Regex::new(
4420 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_up_proj\.(weight|bias)$",
4421 )?,
4422 Regex::new(
4423 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
4424 )?,
4425 Regex::new(
4426 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.up_proj\.(weight|bias)$",
4427 )?,
4428 Regex::new(
4429 r"language_model\.model\.layers\.(\d+)\.feed_forward\.experts\.(\d+)\.down_proj\.(weight|bias)$",
4430 )?,
4431 Regex::new(
4432 r"language_model\.model\.layers\.(\d+)\.feed_forward\.router\.(weight|bias)$",
4433 )?,
4434 Regex::new(
4435 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4436 )?,
4437 Regex::new(
4438 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4439 )?,
4440 Regex::new(
4441 r"language_model\.model\.layers\.(\d+)\.feed_forward\.shared_expert\.(weight|bias)$",
4442 )?,
4443 Regex::new(
4445 r"language_model\.model\.layers\.(\d+)\.feed_forward\.gate_proj\.(weight|bias)$",
4446 )?,
4447 Regex::new(
4448 r"language_model\.model\.layers\.(\d+)\.feed_forward\.up_proj\.(weight|bias)$",
4449 )?,
4450 Regex::new(
4451 r"language_model\.model\.layers\.(\d+)\.feed_forward\.down_proj\.(weight|bias)$",
4452 )?,
4453 ])
4454 }
4455}
4456
4457impl VLlama4Loader {
4458 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
4461 fn run_dummy_processing(
4462 &self,
4463 cfg: &Llama4Config,
4464 height: usize,
4465 width: usize,
4466 max_num_images: usize,
4467 max_batch_size: usize,
4468 ) -> Result<(usize, usize)> {
4469 let cfg = &cfg.vision_config;
4470
4471 let img_processor =
4472 Llama4ImageProcessor::new(Some(cfg.patch_size), Some(cfg.pixel_shuffle_ratio));
4473 let image = DynamicImage::new(width as u32, height as u32, ColorType::Rgb8);
4474 let res = img_processor.preprocess(
4475 vec![image; max_num_images],
4476 vec![],
4477 &PreProcessorConfig::default(),
4478 &Device::Cpu,
4479 (max_batch_size, max_num_images),
4480 )?;
4481
4482 let pixels_batch_size = res.pixel_values.dim(0)?;
4483 let pixels_max_batch_size = pixels_batch_size * max_batch_size;
4484
4485 let (image_h, image_w) = (
4486 res.pixel_values.dim(D::Minus2).unwrap(),
4487 res.pixel_values.dim(D::Minus1).unwrap(),
4488 );
4489 let num_patches_per_chunk = (image_h / img_processor.patch_size)
4490 * (image_w / img_processor.patch_size)
4491 / img_processor.downsample_ratio;
4492
4493 Ok((
4494 pixels_max_batch_size,
4495 num_patches_per_chunk * pixels_max_batch_size,
4496 ))
4497 }
4498}
4499
4500impl DeviceMappedModelLoader for VLlama4Loader {
4501 fn mapped_max_act_size_elems(
4502 &self,
4503 config: &str,
4504 params: &AutoDeviceMapParams,
4505 ) -> Result<usize> {
4506 let AutoDeviceMapParams::Multimodal {
4507 max_seq_len,
4508 max_batch_size,
4509 max_image_shape: (height, width),
4510 max_num_images,
4511 } = params
4512 else {
4513 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4514 };
4515
4516 let cfg: Llama4Config = serde_json::from_str(config)?;
4517
4518 let (_pixels_batch_size, num_text_image_toks) =
4519 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4520
4521 let max_seq_len = max_seq_len.min(&ATTENTION_CHUNK_SIZE) + num_text_image_toks;
4522
4523 Ok(max_batch_size * cfg.text_config.num_attention_heads * max_seq_len * max_seq_len)
4524 }
4525 fn non_mapped_max_act_size_elems(
4526 &self,
4527 config: &str,
4528 params: &AutoDeviceMapParams,
4529 ) -> Result<usize> {
4530 let AutoDeviceMapParams::Multimodal {
4531 max_seq_len: _,
4532 max_batch_size,
4533 max_image_shape: (height, width),
4534 max_num_images,
4535 } = params
4536 else {
4537 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4538 };
4539
4540 let cfg: Llama4Config = serde_json::from_str(config)?;
4541
4542 let (pixels_batch_size, _num_text_image_toks) =
4543 self.run_dummy_processing(&cfg, *height, *width, *max_num_images, *max_batch_size)?;
4544 let max_seq_len = cfg.vision_config.num_patches();
4545
4546 Ok((max_batch_size * pixels_batch_size)
4547 * cfg.vision_config.num_attention_heads
4548 * max_seq_len
4549 * max_seq_len)
4550 }
4551
4552 fn non_mapped_size_in_bytes(
4553 &self,
4554 config: &str,
4555 dtype: DType,
4556 weight_pack_factor: usize,
4557 _matformer_config: Option<&MatformerSliceConfig>,
4558 ) -> Result<usize> {
4559 let cfg: Llama4Config = serde_json::from_str(config)?;
4560 let tcfg = &cfg.text_config;
4561
4562 let text_elems = {
4563 let embed_tokens = tcfg.hidden_size * tcfg.vocab_size / weight_pack_factor;
4564 let lm_head = if !tcfg.tie_word_embeddings {
4565 tcfg.hidden_size * tcfg.vocab_size
4566 } else {
4567 0
4568 };
4569 let norm = tcfg.hidden_size;
4570 embed_tokens + lm_head + norm
4571 };
4572
4573 let vision_elems = {
4574 let cfg = &cfg.vision_config;
4575
4576 let num_patches = cfg.num_patches();
4577
4578 let unfold_elems =
4579 (cfg.num_channels * cfg.patch_size * cfg.patch_size) * cfg.hidden_size;
4580 let class_embeddng_elems = cfg.hidden_size;
4581 let positional_embedding_vlm_elems = num_patches * cfg.hidden_size;
4582 let layernorm_pre_elems = cfg.hidden_size;
4583 let layernorm_post_elems = cfg.hidden_size;
4584
4585 let pixel_shuffle_elems = cfg.intermediate_size * cfg.projector_input_dim
4586 / weight_pack_factor
4587 + cfg.projector_input_dim * cfg.projector_output_dim / weight_pack_factor;
4588
4589 let encoder_layer = {
4590 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
4591 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
4592
4593 let head_dim = cfg.hidden_size / cfg.num_attention_heads;
4594 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4595 / weight_pack_factor
4596 + cfg.num_attention_heads * head_dim;
4597 let k_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4598 / weight_pack_factor
4599 + cfg.num_attention_heads * head_dim;
4600 let v_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4601 / weight_pack_factor
4602 + cfg.num_attention_heads * head_dim;
4603 let o_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim
4604 / weight_pack_factor
4605 + cfg.num_attention_heads * head_dim;
4606
4607 let fc1 = (cfg.hidden_size * cfg.intermediate_size) / weight_pack_factor
4608 + cfg.intermediate_size;
4609 let fc2 = (cfg.intermediate_size * cfg.hidden_size) / weight_pack_factor
4610 + cfg.hidden_size;
4611
4612 input_layernorm
4613 + post_attention_layernorm
4614 + q_proj
4615 + k_proj
4616 + v_proj
4617 + o_proj
4618 + fc1
4619 + fc2
4620 };
4621
4622 unfold_elems
4623 + class_embeddng_elems
4624 + positional_embedding_vlm_elems
4625 + layernorm_post_elems
4626 + layernorm_pre_elems
4627 + pixel_shuffle_elems
4628 + encoder_layer * cfg.num_hidden_layers
4629 };
4630
4631 let elems = text_elems + vision_elems;
4632
4633 Ok(elems * dtype.size_in_bytes())
4634 }
4635
4636 fn layer_sizes_in_bytes(
4637 &self,
4638 config: &str,
4639 dtype: DType,
4640 weight_pack_factor: usize,
4641 _matformer_config: Option<&MatformerSliceConfig>,
4642 ) -> Result<Vec<usize>> {
4643 let cfg: Llama4Config = serde_json::from_str(config)?;
4644 let tcfg = &cfg.text_config;
4645
4646 let mut per_layer_elems = Vec::new();
4647
4648 for layer_idx in 0..tcfg.num_hidden_layers {
4649 let input_layernorm = tcfg.hidden_size;
4650 let post_attention_layernorm = tcfg.hidden_size;
4651
4652 let size_in = tcfg.hidden_size;
4653 let size_q = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_attention_heads;
4654 let size_kv = (tcfg.hidden_size / tcfg.num_attention_heads) * tcfg.num_key_value_heads;
4655 let q_proj = size_in * size_q / weight_pack_factor;
4656 let k_proj = size_in * size_kv / weight_pack_factor;
4657 let v_proj = size_in * size_kv / weight_pack_factor;
4658 let o_proj = size_q * size_in / weight_pack_factor;
4659
4660 let use_moe = tcfg.moe_layers().contains(&layer_idx);
4661 let moe_block = if use_moe {
4662 let h_size = tcfg.hidden_size;
4663 let i_size = tcfg.intermediate_size;
4664 let gate_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4665 let up_proj = tcfg.num_local_experts * h_size * i_size / weight_pack_factor;
4666 let down_proj = tcfg.num_local_experts * i_size * h_size / weight_pack_factor;
4667
4668 gate_proj + up_proj + down_proj
4669 } else {
4670 let h_size = tcfg.hidden_size;
4671 let i_size = tcfg.intermediate_size_mlp;
4672 let gate_proj = h_size * i_size / weight_pack_factor;
4673 let up_proj = h_size * i_size / weight_pack_factor;
4674 let down_proj = i_size * h_size / weight_pack_factor;
4675
4676 gate_proj + up_proj + down_proj
4677 };
4678
4679 per_layer_elems.push(
4680 input_layernorm
4681 + post_attention_layernorm
4682 + q_proj
4683 + k_proj
4684 + v_proj
4685 + o_proj
4686 + moe_block,
4687 );
4688 }
4689
4690 Ok(per_layer_elems
4691 .into_iter()
4692 .map(|x| x * dtype.size_in_bytes())
4693 .collect())
4694 }
4695
4696 fn num_layers(&self, config: &str) -> Result<usize> {
4697 let cfg: Llama4Config = serde_json::from_str(config)?;
4698 Ok(cfg.text_config.num_hidden_layers)
4699 }
4700
4701 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4702 let cfg: Llama4Config = serde_json::from_str(config)?;
4703 let cfg = &cfg.text_config;
4704
4705 let cfg = ModelConfigMetadata {
4706 max_seq_len: cfg.max_position_embeddings,
4707 num_layers: cfg.num_hidden_layers,
4708 hidden_size: cfg.hidden_size,
4709 num_kv_heads: cfg.num_attention_heads,
4710 num_attn_heads: cfg.num_attention_heads,
4711 sliding_window: None,
4712 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4713 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4714 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4715 };
4716
4717 Ok(Box::new(cfg))
4718 }
4719
4720 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
4721 Some(vec![NonMappedSubModel::Vision])
4722 }
4723}
4724
4725pub struct Gemma3nLoader;
4731
4732#[allow(dead_code)]
4733pub struct Gemma3nPrefixer;
4734
4735impl MultimodalPromptPrefixer for Gemma3nPrefixer {
4736 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
4737 prompt.to_string()
4738 }
4739}
4740
4741impl MultimodalModelLoader for Gemma3nLoader {
4742 fn load(
4743 &self,
4744 config: &str,
4745 vb: ShardedVarBuilder,
4746 normal_loading_metadata: NormalLoadingMetadata,
4747 attention_mechanism: AttentionImplementation,
4748 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
4749 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4750 Ok(Box::new(Gemma3nModel::new(
4751 &cfg,
4752 vb,
4753 self.is_gptx(config),
4754 normal_loading_metadata,
4755 attention_mechanism,
4756 )?))
4757 }
4758 fn is_gptx(&self, _config: &str) -> bool {
4759 true
4760 }
4761 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4762 let config: Gemma3nConfig = serde_json::from_str(config)?;
4763 Ok(Box::new(config))
4764 }
4765 fn get_processor(
4766 &self,
4767 _config: &str,
4768 processor_config: Option<ProcessorConfig>,
4769 _preprocessor_config: PreProcessorConfig,
4770 _max_edge: Option<u32>,
4771 ) -> Arc<dyn Processor + Send + Sync> {
4772 Arc::new(Gemma3nProcessor::new(
4774 processor_config.unwrap_or_default(),
4775 true,
4776 ))
4777 }
4778 fn supports_paged_attention(&self, _config: &str) -> bool {
4779 false
4780 }
4781 fn supports_prefix_cacher(&self, _config: &str) -> bool {
4782 true
4783 }
4784 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
4785 Arc::new(Gemma3Prefixer)
4786 }
4787 fn modalities(&self, _config: &str) -> Result<Modalities> {
4788 Ok(Modalities {
4789 input: vec![
4790 SupportedModality::Text,
4791 SupportedModality::Vision,
4792 SupportedModality::Audio,
4793 ],
4794 output: vec![SupportedModality::Text],
4795 })
4796 }
4797}
4798
4799impl IsqModelLoader for Gemma3nLoader {
4800 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4801 Ok(vec![
4802 Regex::new(r"lm_head\.(weight|bias)$")?,
4803 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4805 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4806 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4807 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4808 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4810 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4811 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4812 Regex::new(r"conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$")?,
4814 Regex::new(r"conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$")?,
4815 Regex::new(r"conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$")?,
4816 Regex::new(
4817 r"conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4818 )?,
4819 Regex::new(r"conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4820 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$")?,
4822 Regex::new(r"conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$")?,
4823 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$")?,
4824 Regex::new(r"conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$")?,
4825 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$")?,
4827 Regex::new(r"conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$")?,
4828 Regex::new(r"subsample_conv_projection\.input_proj_linear\.(weight|bias)$")?,
4830 Regex::new(r"embed_vision\.embedding_projection\.(weight|bias)$")?,
4832 Regex::new(r"embed_audio\.embedding_projection\.(weight|bias)$")?,
4833 ])
4834 }
4835 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
4836 Ok(vec![
4837 Regex::new(r"lm_head\.(weight|bias)$")?,
4838 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4840 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4841 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4842 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4843 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4845 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4846 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4847 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
4849 Regex::new(r"model\.language_model\.altup_projections\.(\d+)\.(weight|bias)$")?,
4850 Regex::new(r"model\.language_model\.altup_unembed_projections\.(\d+)\.(weight|bias)$")?,
4851 Regex::new(
4853 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.q_proj\.(weight|bias)$",
4854 )?,
4855 Regex::new(
4856 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.k_proj\.(weight|bias)$",
4857 )?,
4858 Regex::new(
4859 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.v_proj\.(weight|bias)$",
4860 )?,
4861 Regex::new(
4862 r"model\.audio_tower\.conformer\.(\d+)\.attention\.attn\.relative_position_embedding\.pos_proj\.(weight|bias)$",
4863 )?,
4864 Regex::new(r"model\.audio_tower\.conformer\.(\d+)\.attention\.post\.(weight|bias)$")?,
4865 Regex::new(
4867 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_1\.(weight|bias)$",
4868 )?,
4869 Regex::new(
4870 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_start\.ffw_layer_2\.(weight|bias)$",
4871 )?,
4872 Regex::new(
4873 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_1\.(weight|bias)$",
4874 )?,
4875 Regex::new(
4876 r"model\.audio_tower\.conformer\.(\d+)\.ffw_layer_end\.ffw_layer_2\.(weight|bias)$",
4877 )?,
4878 Regex::new(
4880 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_start\.(weight|bias)$",
4881 )?,
4882 Regex::new(
4883 r"model\.audio_tower\.conformer\.(\d+)\.lconv1d\.linear_end\.(weight|bias)$",
4884 )?,
4885 Regex::new(
4887 r"model\.audio_tower\.subsample_conv_projection\.input_proj_linear\.(weight|bias)$",
4888 )?,
4889 Regex::new(r"model\.embed_vision\.embedding_projection\.(weight|bias)$")?,
4891 Regex::new(r"model\.embed_audio\.embedding_projection\.(weight|bias)$")?,
4892 ])
4893 }
4894}
4895
4896impl DeviceMappedModelLoader for Gemma3nLoader {
4897 fn mapped_max_act_size_elems(
4898 &self,
4899 config: &str,
4900 params: &AutoDeviceMapParams,
4901 ) -> Result<usize> {
4902 let AutoDeviceMapParams::Multimodal {
4903 max_seq_len,
4904 max_batch_size,
4905 max_image_shape: _,
4906 max_num_images,
4907 } = params
4908 else {
4909 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4910 };
4911
4912 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4913 let text_cfg = &cfg.text_config;
4914
4915 let mut total_seq_len = *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
4919
4920 {
4922 let msfa_spatial_size = 16; let vision_tokens_per_image = msfa_spatial_size * msfa_spatial_size; total_seq_len += vision_tokens_per_image * max_num_images;
4927 }
4928
4929 {
4931 let audio_tokens = cfg.audio_soft_tokens_per_image;
4934 total_seq_len += audio_tokens;
4935 }
4936
4937 let max_text_attn =
4939 max_batch_size * text_cfg.num_attention_heads * total_seq_len * total_seq_len;
4940
4941 Ok(max_text_attn)
4942 }
4943
4944 fn non_mapped_max_act_size_elems(
4945 &self,
4946 config: &str,
4947 params: &AutoDeviceMapParams,
4948 ) -> Result<usize> {
4949 let AutoDeviceMapParams::Multimodal {
4950 max_seq_len: _,
4951 max_batch_size,
4952 max_image_shape: _,
4953 max_num_images,
4954 } = params
4955 else {
4956 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
4957 };
4958
4959 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
4960
4961 let mut max_activation = 0;
4963
4964 {
4966 let vision_tower_act = {
4976 let num_heads = 16; let spatial_size = 24; let seq_len = spatial_size * spatial_size;
4982
4983 max_batch_size * max_num_images * num_heads * seq_len * seq_len
4985 };
4986
4987 let vision_embed_act = {
4989 let msfa_channels = 2048; let spatial_size = 16; let vision_features =
4993 max_batch_size * max_num_images * msfa_channels * spatial_size * spatial_size;
4994
4995 let projected = max_batch_size
4997 * max_num_images
4998 * spatial_size
4999 * spatial_size
5000 * cfg.text_config.hidden_size;
5001
5002 vision_features.max(projected)
5003 };
5004
5005 max_activation = max_activation.max(vision_tower_act).max(vision_embed_act);
5006 }
5007
5008 {
5010 let audio_cfg = &cfg.audio_config;
5011
5012 let max_audio_frames = 1280;
5017
5018 let subsample_factor: usize = audio_cfg
5019 .sscp_conv_stride_size
5020 .iter()
5021 .map(|stride| stride[0]) .product();
5023 let audio_seq_after_subsample = max_audio_frames / subsample_factor;
5024
5025 let audio_encoder_act = {
5027 let intermediate_size = audio_cfg.hidden_size * 4; max_batch_size * audio_seq_after_subsample * intermediate_size
5032 };
5033
5034 let audio_attn_act = {
5036 let chunk_size = audio_cfg.conf_attention_chunk_size;
5038 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5039 + audio_cfg.conf_attention_context_right;
5040
5041 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
5043
5044 max_batch_size
5045 * audio_cfg.conf_num_attention_heads
5046 * num_chunks
5047 * chunk_size
5048 * context_size
5049 };
5050
5051 max_activation = max_activation.max(audio_encoder_act).max(audio_attn_act);
5052 }
5053
5054 Ok(max_activation)
5055 }
5056
5057 fn non_mapped_size_in_bytes(
5058 &self,
5059 config: &str,
5060 dtype: DType,
5061 weight_pack_factor: usize,
5062 matformer_config: Option<&MatformerSliceConfig>,
5063 ) -> Result<usize> {
5064 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5065
5066 let text_cfg = if let Some(matformer_cfg) = matformer_config {
5068 use crate::device_map::DummyDeviceMapper;
5069 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5070
5071 let dummy_mapper = DummyDeviceMapper {
5072 nm_device: Device::Cpu,
5073 };
5074 let (adjusted_cfg, _, _, _, _) = handle_matformer_slicing(
5075 &cfg.text_config,
5076 &Some(matformer_cfg.clone()),
5077 &dummy_mapper,
5078 )?;
5079 adjusted_cfg
5080 } else {
5081 cfg.text_config.clone()
5082 };
5083
5084 let text_cfg = &text_cfg;
5085
5086 let text_elems = {
5088 let embed_tokens = text_cfg.hidden_size * text_cfg.vocab_size;
5090 let embed_tokens_per_layer = text_cfg.num_hidden_layers
5091 * text_cfg.hidden_size_per_layer_input
5092 * text_cfg.vocab_size_per_layer_input;
5093
5094 let lm_head = if !text_cfg.tie_word_embeddings || weight_pack_factor != 1 {
5096 text_cfg.hidden_size * text_cfg.vocab_size / weight_pack_factor
5097 } else {
5098 0
5099 };
5100
5101 let norm = text_cfg.hidden_size;
5103
5104 let altup_projections =
5106 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5107 / weight_pack_factor;
5108 let altup_unembed_projections =
5109 (text_cfg.altup_num_inputs - 1) * text_cfg.hidden_size * text_cfg.hidden_size
5110 / weight_pack_factor;
5111
5112 let per_layer_model_projection = text_cfg.num_hidden_layers
5114 * text_cfg.hidden_size
5115 * text_cfg.hidden_size_per_layer_input
5116 / weight_pack_factor;
5117 let per_layer_projection_norm = text_cfg.hidden_size;
5118
5119 embed_tokens
5120 + embed_tokens_per_layer
5121 + lm_head
5122 + norm
5123 + altup_projections
5124 + altup_unembed_projections
5125 + per_layer_model_projection
5126 + per_layer_projection_norm
5127 };
5128
5129 let vision_elems = {
5131 let multimodal_cfg = &cfg.vision_config;
5132 let vision_tower_elems = {
5136 use crate::vision_models::gemma3n::vision::{
5137 gemma3n_mobilenet_def, make_divisible, BlockType, INPUT_CHANNELS,
5138 MSFA_EXPANSION_RATIO, MSFA_IN_CHANNELS, MSFA_OUT_CHANNELS, STEM_KERNEL_SIZE,
5139 STEM_OUT_CHANNELS,
5140 };
5141
5142 let stem_conv =
5144 INPUT_CHANNELS * STEM_OUT_CHANNELS * STEM_KERNEL_SIZE * STEM_KERNEL_SIZE;
5145 let stem_norm = STEM_OUT_CHANNELS; let mut in_chs = STEM_OUT_CHANNELS;
5149 let mut total_elems = stem_conv + stem_norm;
5150
5151 let block_defs = gemma3n_mobilenet_def();
5153
5154 for stage_blocks in block_defs.iter() {
5155 for block_type in stage_blocks.iter() {
5156 match block_type {
5157 BlockType::EdgeResidual {
5158 out_channels,
5159 kernel_size,
5160 stride: _,
5161 expand_ratio,
5162 ..
5163 } => {
5164 #[allow(clippy::cast_precision_loss)]
5165 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5166 total_elems += in_chs * mid_chs * kernel_size * kernel_size; total_elems += mid_chs; total_elems += mid_chs * out_channels; total_elems += out_channels; in_chs = *out_channels;
5172 }
5173 BlockType::UniversalInvertedResidual {
5174 out_channels,
5175 start_kernel_size,
5176 mid_kernel_size,
5177 stride: _,
5178 expand_ratio,
5179 ..
5180 } => {
5181 #[allow(clippy::cast_precision_loss)]
5182 let mid_chs = make_divisible(in_chs as f64 * expand_ratio, 8);
5183 if *expand_ratio != 1.0 {
5185 total_elems += in_chs * mid_chs; total_elems += mid_chs; }
5188 if *start_kernel_size > 0 {
5189 total_elems += mid_chs * start_kernel_size * start_kernel_size; total_elems += mid_chs; }
5192 if *mid_kernel_size > 0 {
5193 total_elems += mid_chs * mid_kernel_size * mid_kernel_size; total_elems += mid_chs; }
5196 total_elems += mid_chs * out_channels; total_elems += out_channels; total_elems += out_channels; in_chs = *out_channels;
5200 }
5201 BlockType::MultiQueryAttention {
5202 num_heads,
5203 kv_dim,
5204 kv_stride: _,
5205 ..
5206 } => {
5207 let dw_kernel_size = 3; total_elems += in_chs; total_elems += in_chs * num_heads * kv_dim; total_elems += in_chs * kv_dim; total_elems += in_chs * dw_kernel_size * dw_kernel_size; total_elems += *kv_dim; total_elems += 1; total_elems += *kv_dim; total_elems += num_heads * kv_dim * in_chs; total_elems += in_chs; }
5219 }
5220 }
5221 }
5222
5223 let msfa_in = MSFA_IN_CHANNELS.iter().sum::<usize>();
5225 let msfa_out = MSFA_OUT_CHANNELS;
5226 #[allow(clippy::cast_precision_loss)]
5227 let msfa_mid = make_divisible(msfa_in as f64 * MSFA_EXPANSION_RATIO, 8);
5228
5229 total_elems += msfa_in * msfa_mid; total_elems += msfa_mid; total_elems += msfa_mid * msfa_out; total_elems += msfa_out; total_elems += msfa_out; total_elems
5237 };
5238
5239 let embed_vision_elems = {
5241 let embedding = multimodal_cfg.vocab_size * multimodal_cfg.hidden_size;
5243
5244 let hard_norm = multimodal_cfg.hidden_size;
5246 let soft_norm = multimodal_cfg.hidden_size;
5247
5248 let projection =
5250 multimodal_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5251
5252 let post_norm = text_cfg.hidden_size;
5254
5255 embedding + hard_norm + soft_norm + projection + post_norm
5256 };
5257
5258 vision_tower_elems + embed_vision_elems
5259 };
5260
5261 let audio_elems = {
5263 let audio_cfg = &cfg.audio_config;
5264
5265 let subsample_conv_projection_elems = {
5267 let mut conv_elems = 0;
5269
5270 let in_ch_0 = 1;
5272 let out_ch_0 = audio_cfg.sscp_conv_channel_size[0];
5273 let kernel_0 = &audio_cfg.sscp_conv_kernel_size[0];
5274 conv_elems += in_ch_0 * out_ch_0 * kernel_0[0] * kernel_0[1];
5275
5276 let in_ch_1 = out_ch_0;
5278 let out_ch_1 = audio_cfg.sscp_conv_channel_size[1];
5279 let kernel_1 = &audio_cfg.sscp_conv_kernel_size[1];
5280 conv_elems += in_ch_1 * out_ch_1 * kernel_1[0] * kernel_1[1];
5281
5282 let norm_0 = out_ch_0; let norm_1 = out_ch_1; let mut f_out = audio_cfg.input_feat_size;
5288 for i in 0..2 {
5289 let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
5290 let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
5291 let pad_left = 1;
5292 let pad_right = 1;
5293 f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
5294 }
5295 let input_proj_in_features = out_ch_1 * f_out;
5296 let input_proj_linear =
5297 input_proj_in_features * audio_cfg.hidden_size / weight_pack_factor;
5298
5299 conv_elems + norm_0 + norm_1 + input_proj_linear
5300 };
5301
5302 let conformer_elems = {
5304 let mut total = 0;
5305
5306 for _ in 0..audio_cfg.conf_num_hidden_layers {
5307 let attention_elems = {
5309 let pre_attn_norm = audio_cfg.hidden_size;
5311 let post_norm = audio_cfg.hidden_size;
5312
5313 let q_proj =
5315 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5316 let k_proj =
5317 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5318 let v_proj =
5319 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5320 let post =
5321 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5322
5323 let pos_proj =
5325 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5326 let per_dim_scale =
5327 audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads; let inv_timescales = audio_cfg.hidden_size / 2; let pos_indices = audio_cfg.conf_attention_context_left
5330 + audio_cfg.conf_attention_context_right
5331 + 1;
5332
5333 let chunk_size = audio_cfg.conf_attention_chunk_size;
5335 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
5336 + audio_cfg.conf_attention_context_right;
5337 let local_causal_valid_mask = chunk_size * context_size; let invalid_logits_tensor = 1; pre_attn_norm
5341 + post_norm
5342 + q_proj
5343 + k_proj
5344 + v_proj
5345 + post
5346 + pos_proj
5347 + per_dim_scale
5348 + inv_timescales
5349 + pos_indices
5350 + local_causal_valid_mask
5351 + invalid_logits_tensor
5352 };
5353
5354 let ffw_elems = {
5356 let intermediate_size = audio_cfg.hidden_size * 4;
5362
5363 let ffw_start = {
5364 let pre_norm = audio_cfg.hidden_size;
5365 let layer_1 =
5366 audio_cfg.hidden_size * intermediate_size / weight_pack_factor;
5367 let layer_2 =
5368 intermediate_size * audio_cfg.hidden_size / weight_pack_factor;
5369 let post_norm = audio_cfg.hidden_size;
5370 pre_norm + layer_1 + layer_2 + post_norm
5371 };
5372
5373 let ffw_end = ffw_start; ffw_start + ffw_end
5376 };
5377
5378 let lconv1d_elems = {
5380 let pre_layer_norm = audio_cfg.hidden_size;
5382 let conv_norm = audio_cfg.hidden_size;
5383
5384 let linear_start = audio_cfg.hidden_size * (audio_cfg.hidden_size * 2)
5386 / weight_pack_factor;
5387 let linear_end =
5388 audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor;
5389
5390 let depthwise = audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
5392
5393 pre_layer_norm + conv_norm + linear_start + linear_end + depthwise
5394 };
5395
5396 let block_norm = audio_cfg.hidden_size;
5398
5399 total += attention_elems + ffw_elems + lconv1d_elems + block_norm;
5400 }
5401
5402 total
5403 };
5404
5405 let embed_audio_elems = {
5407 let embedding = audio_cfg.vocab_size * audio_cfg.hidden_size;
5409
5410 let hard_embedding_norm = audio_cfg.hidden_size; let soft_embedding_norm = audio_cfg.hidden_size; let embedding_post_projection_norm = text_cfg.hidden_size; let embedding_projection =
5417 audio_cfg.hidden_size * text_cfg.hidden_size / weight_pack_factor;
5418
5419 embedding
5420 + hard_embedding_norm
5421 + soft_embedding_norm
5422 + embedding_post_projection_norm
5423 + embedding_projection
5424 };
5425
5426 subsample_conv_projection_elems + conformer_elems + embed_audio_elems
5427 };
5428
5429 let vision_dtype = if dtype == DType::F16 {
5430 DType::F32
5432 } else {
5433 dtype
5434 };
5435
5436 let total_elems = text_elems * dtype.size_in_bytes()
5437 + vision_elems * vision_dtype.size_in_bytes()
5438 + audio_elems * dtype.size_in_bytes();
5439
5440 Ok(total_elems)
5441 }
5442
5443 fn layer_sizes_in_bytes(
5444 &self,
5445 config: &str,
5446 dtype: DType,
5447 weight_pack_factor: usize,
5448 matformer_config: Option<&MatformerSliceConfig>,
5449 ) -> Result<Vec<usize>> {
5450 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5451
5452 let (text_cfg, _layer_rename_map, _layers_skipped) = if let Some(matformer_cfg) =
5454 matformer_config
5455 {
5456 use crate::device_map::DummyDeviceMapper;
5457 use crate::vision_models::gemma3n::text::handle_matformer_slicing;
5458
5459 let dummy_mapper = DummyDeviceMapper {
5460 nm_device: Device::Cpu,
5461 };
5462 let (adjusted_cfg, _, _, layer_rename_map, layers_skipped) = handle_matformer_slicing(
5463 &cfg.text_config,
5464 &Some(matformer_cfg.clone()),
5465 &dummy_mapper,
5466 )?;
5467 (adjusted_cfg, layer_rename_map, layers_skipped)
5468 } else {
5469 (cfg.text_config.clone(), None, None)
5470 };
5471
5472 let text_cfg = &text_cfg;
5473
5474 let mut layer_sizes = Vec::new();
5476
5477 for layer_idx in 0..text_cfg.num_hidden_layers {
5481 let per_layer_elems = {
5482 let input_layernorm = text_cfg.hidden_size;
5484 let post_attention_layernorm = text_cfg.hidden_size;
5485 let pre_feedforward_layernorm = text_cfg.hidden_size;
5486 let post_feedforward_layernorm = text_cfg.hidden_size;
5487 let post_per_layer_input_norm = text_cfg.hidden_size;
5488
5489 let size_in = text_cfg.hidden_size;
5491 let size_q = text_cfg.num_attention_heads * text_cfg.head_dim;
5492 let size_kv = text_cfg.num_key_value_heads * text_cfg.head_dim;
5493
5494 let q_proj = size_in * size_q / weight_pack_factor;
5495 let k_proj = size_in * size_kv / weight_pack_factor;
5496 let v_proj = size_in * size_kv / weight_pack_factor;
5497 let o_proj = size_q * size_in / weight_pack_factor;
5498
5499 let q_norm = text_cfg.head_dim;
5501 let k_norm = text_cfg.head_dim;
5502 let v_norm = text_cfg.head_dim; let intermediate_size = match &text_cfg.intermediate_size {
5506 IntermediateSize::Single(size) => *size,
5507 IntermediateSize::PerLayer(sizes) => sizes[layer_idx],
5508 IntermediateSize::Matformer(sizes, _) => sizes[layer_idx],
5509 };
5510 let gate_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5511 let up_proj = text_cfg.hidden_size * intermediate_size / weight_pack_factor;
5512 let down_proj = intermediate_size * text_cfg.hidden_size / weight_pack_factor;
5513
5514 let altup_elems = {
5516 let correct_output_scale = text_cfg.hidden_size;
5517 let correction_coefs = text_cfg.altup_num_inputs * text_cfg.altup_num_inputs;
5518 let prediction_coefs =
5519 text_cfg.altup_num_inputs * text_cfg.altup_num_inputs.pow(2);
5520 let modality_router = text_cfg.hidden_size * text_cfg.altup_num_inputs;
5521 let router_norm = text_cfg.hidden_size;
5522
5523 correct_output_scale
5524 + correction_coefs
5525 + prediction_coefs
5526 + modality_router
5527 + router_norm
5528 };
5529
5530 let laurel_elems = {
5532 let left = text_cfg.hidden_size * text_cfg.laurel_rank;
5533 let right = text_cfg.laurel_rank * text_cfg.hidden_size;
5534 let post_norm = text_cfg.hidden_size;
5535
5536 left + right + post_norm
5537 };
5538
5539 let per_layer_input_gate =
5541 text_cfg.hidden_size * text_cfg.hidden_size_per_layer_input;
5542 let per_layer_projection =
5543 text_cfg.hidden_size_per_layer_input * text_cfg.hidden_size;
5544
5545 input_layernorm
5546 + post_attention_layernorm
5547 + pre_feedforward_layernorm
5548 + post_feedforward_layernorm
5549 + post_per_layer_input_norm
5550 + q_proj
5551 + k_proj
5552 + v_proj
5553 + o_proj
5554 + q_norm
5555 + k_norm
5556 + v_norm
5557 + gate_proj
5558 + up_proj
5559 + down_proj
5560 + altup_elems
5561 + laurel_elems
5562 + per_layer_input_gate
5563 + per_layer_projection
5564 };
5565
5566 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5567 }
5568
5569 Ok(layer_sizes)
5570 }
5571
5572 fn num_layers(&self, config: &str) -> Result<usize> {
5573 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5574 Ok(cfg.text_config.num_hidden_layers)
5575 }
5576
5577 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5578 let cfg: Gemma3nConfig = serde_json::from_str(config)?;
5579 let cfg = cfg.text_config;
5580
5581 let cfg = ModelConfigMetadata {
5582 max_seq_len: cfg.max_position_embeddings,
5583 num_layers: cfg.num_hidden_layers,
5584 hidden_size: cfg.hidden_size,
5585 num_kv_heads: cfg.num_key_value_heads,
5586 num_attn_heads: cfg.num_attention_heads,
5587 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5589 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
5590 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5591 };
5592
5593 Ok(Box::new(cfg))
5594 }
5595
5596 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5597 Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
5598 }
5599}
5600
/// Loader for Qwen3VL models.
///
/// A stateless unit struct; in this file it implements `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader`.
pub struct Qwen3VLLoader;
5607
/// Prompt prefixer returned by `Qwen3VLLoader::prefixer`.
pub struct Qwen3VLPrefixer;

// The impl body is empty, so every `MultimodalPromptPrefixer` method falls back to the
// trait's default implementation.
// NOTE(review): presumably no prompt rewriting is needed for this model; confirm
// against the trait defaults.
impl MultimodalPromptPrefixer for Qwen3VLPrefixer {
}
5614
5615impl MultimodalModelLoader for Qwen3VLLoader {
5616 fn load(
5617 &self,
5618 config: &str,
5619 vb: ShardedVarBuilder,
5620 normal_loading_metadata: NormalLoadingMetadata,
5621 attention_mechanism: AttentionImplementation,
5622 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
5623 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5624 Ok(Box::new(Qwen3VLModel::new(
5625 &cfg,
5626 vb,
5627 self.is_gptx(config),
5628 normal_loading_metadata,
5629 attention_mechanism,
5630 )?))
5631 }
5632 fn is_gptx(&self, _config: &str) -> bool {
5633 true
5634 }
5635 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5636 let config: Qwen3VLConfig = serde_json::from_str(config)?;
5637 Ok(Box::new(config))
5638 }
5639 fn get_processor(
5640 &self,
5641 _model_config: &str,
5642 _processor_config: Option<ProcessorConfig>,
5643 _preprocessor_config: PreProcessorConfig,
5644 max_edge: Option<u32>,
5645 ) -> Arc<dyn Processor + Send + Sync> {
5646 Arc::new(Qwen3VLProcessor::new(max_edge))
5647 }
5648 fn supports_paged_attention(&self, _config: &str) -> bool {
5649 true
5650 }
5651 fn supports_prefix_cacher(&self, _config: &str) -> bool {
5652 true
5653 }
5654 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5655 Arc::new(Qwen3VLPrefixer)
5656 }
5657 fn modalities(&self, _config: &str) -> Result<Modalities> {
5658 Ok(Modalities {
5659 input: vec![SupportedModality::Text, SupportedModality::Vision],
5660 output: vec![SupportedModality::Text],
5661 })
5662 }
5663}
5664
5665impl IsqModelLoader for Qwen3VLLoader {
5666 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5667 Ok(vec![
5668 Regex::new(r"lm_head\.(weight|bias)$")?,
5669 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5671 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5672 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5673 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5674 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5676 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5677 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
5678 ])
5679 }
5680 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
5681 self.isq_layer_regexes(config)
5682 }
5683}
5684
5685impl DeviceMappedModelLoader for Qwen3VLLoader {
5686 fn mapped_max_act_size_elems(
5687 &self,
5688 config: &str,
5689 params: &AutoDeviceMapParams,
5690 ) -> Result<usize> {
5691 let AutoDeviceMapParams::Multimodal {
5692 max_seq_len,
5693 max_batch_size,
5694 max_image_shape,
5695 max_num_images,
5696 } = params
5697 else {
5698 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
5699 };
5700
5701 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5702
5703 let img_seq_len = {
5705 let cfg = &cfg.vision_config;
5706 let grid_t = 1;
5708 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
5710 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
5711 grid_t * grid_h * grid_w * max_num_images
5712 };
5713
5714 let max_text_attn = {
5715 let cfg = &cfg.text_config;
5716 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
5718 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
5719 };
5720
5721 Ok(max_text_attn)
5722 }
5723
5724 fn non_mapped_max_act_size_elems(
5725 &self,
5726 config: &str,
5727 params: &AutoDeviceMapParams,
5728 ) -> Result<usize> {
5729 let AutoDeviceMapParams::Multimodal {
5730 max_seq_len: _,
5731 max_batch_size,
5732 max_image_shape,
5733 max_num_images,
5734 } = params
5735 else {
5736 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
5737 };
5738
5739 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5740
5741 let img_seq_len = {
5743 let cfg = &cfg.vision_config;
5744 let grid_t = 1;
5746 let grid_h = max_image_shape.0 / cfg.patch_size;
5747 let grid_w = max_image_shape.1 / cfg.patch_size;
5748 grid_t * grid_h * grid_w
5749 };
5750
5751 let max_vision_attn = {
5752 let cfg = &cfg.vision_config;
5753 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
5754 };
5755
5756 Ok(max_vision_attn)
5757 }
5758
5759 fn non_mapped_size_in_bytes(
5760 &self,
5761 config: &str,
5762 dtype: DType,
5763 weight_pack_factor: usize,
5764 _matformer_config: Option<&MatformerSliceConfig>,
5765 ) -> Result<usize> {
5766 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5767 let tie = cfg.tie_word_embeddings;
5768 let text_elems = {
5769 let cfg = &cfg.text_config;
5770 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5771 let lm_head = if !tie || weight_pack_factor != 1 {
5773 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5774 } else {
5775 0
5776 };
5777 let norm = cfg.hidden_size;
5778 embed_tokens + lm_head + norm
5779 };
5780
5781 let (patch_merger, deepstack_mergers) = {
5782 let cfg = &cfg.vision_config;
5783 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
5784
5785 let mlp0 = hidden_size * hidden_size + hidden_size;
5786 let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
5787
5788 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5790 let merger = mlp0 + mlp2 + ln_q;
5791
5792 let ds_ln = hidden_size + bias_if!(true, hidden_size);
5794 let ds_merger = mlp0 + mlp2 + ds_ln;
5795 let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
5796
5797 (merger, deepstack)
5798 };
5799
5800 let patch_embed = {
5801 let cfg = &cfg.vision_config;
5802 let conv_cfg = Conv3dConfig {
5803 stride: cfg.patch_size,
5804 ..Default::default()
5805 };
5806 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
5807 let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
5808 * kernel_sizes[0]
5809 * kernel_sizes[1]
5810 * kernel_sizes[2];
5811 let bias = cfg.hidden_size;
5812 weight + bias
5813 };
5814
5815 let pos_embed = {
5816 let cfg = &cfg.vision_config;
5817 cfg.num_position_embeddings * cfg.hidden_size
5818 };
5819
5820 let encoder_layer = {
5821 let cfg = &cfg.vision_config;
5822 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5823 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
5824
5825 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
5826 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
5827 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
5828
5829 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
5830 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
5831
5832 norm1 + norm2 + fc1 + fc2 + qkv + out
5833 };
5834
5835 let elems = text_elems
5836 + patch_merger
5837 + deepstack_mergers
5838 + patch_embed
5839 + pos_embed
5840 + encoder_layer * cfg.vision_config.depth;
5841
5842 Ok(elems * dtype.size_in_bytes())
5843 }
5844
5845 fn layer_sizes_in_bytes(
5846 &self,
5847 config: &str,
5848 dtype: DType,
5849 weight_pack_factor: usize,
5850 _matformer_config: Option<&MatformerSliceConfig>,
5851 ) -> Result<Vec<usize>> {
5852 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5853 let per_layer_elems = {
5854 let cfg = &cfg.text_config;
5855 let input_layernorm = cfg.hidden_size;
5856 let post_attention_layernorm = cfg.hidden_size;
5857
5858 let size_in = cfg.hidden_size;
5859 let size_q = cfg.head_dim * cfg.num_attention_heads;
5860 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
5861 let q_proj = size_in * size_q / weight_pack_factor;
5862 let k_proj = size_in * size_kv / weight_pack_factor;
5863 let v_proj = size_in * size_kv / weight_pack_factor;
5864 let o_proj = size_q * size_in / weight_pack_factor;
5865
5866 let q_norm = cfg.head_dim;
5867 let k_norm = cfg.head_dim;
5868
5869 let h_size = cfg.hidden_size;
5870 let i_size = cfg.intermediate_size;
5871 let gate_proj = h_size * i_size / weight_pack_factor;
5872 let up_proj = h_size * i_size / weight_pack_factor;
5873 let down_proj = i_size * h_size / weight_pack_factor;
5874
5875 input_layernorm
5876 + post_attention_layernorm
5877 + q_proj
5878 + k_proj
5879 + v_proj
5880 + o_proj
5881 + q_norm
5882 + k_norm
5883 + gate_proj
5884 + up_proj
5885 + down_proj
5886 };
5887 Ok(vec![
5888 per_layer_elems * dtype.size_in_bytes();
5889 cfg.text_config.num_hidden_layers
5890 ])
5891 }
5892
5893 fn num_layers(&self, config: &str) -> Result<usize> {
5894 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5895 let cfg = &cfg.text_config;
5896 Ok(cfg.num_hidden_layers)
5897 }
5898
5899 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5900 let cfg: Qwen3VLConfig = serde_json::from_str(config)?;
5901 let cfg = &cfg.text_config;
5902
5903 let cfg = ModelConfigMetadata {
5904 max_seq_len: cfg.max_position_embeddings,
5905 num_layers: cfg.num_hidden_layers,
5906 hidden_size: cfg.hidden_size,
5907 num_kv_heads: cfg.num_key_value_heads,
5908 num_attn_heads: cfg.num_attention_heads,
5909 sliding_window: cfg.sliding_window,
5910 k_head_dim: cfg.head_dim,
5911 v_head_dim: cfg.head_dim,
5912 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5913 };
5914
5915 Ok(Box::new(cfg))
5916 }
5917
5918 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
5919 Some(vec![NonMappedSubModel::Vision])
5920 }
5921}
5922
/// Loader for Qwen3VL mixture-of-experts models.
///
/// A stateless unit struct; in this file it implements `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader`.
pub struct Qwen3VLMoELoader;
5929
/// Prompt prefixer returned by `Qwen3VLMoELoader::prefixer`.
pub struct Qwen3VLMoEPrefixer;

// The impl body is empty, so every `MultimodalPromptPrefixer` method falls back to the
// trait's default implementation.
// NOTE(review): presumably no prompt rewriting is needed for this model; confirm
// against the trait defaults.
impl MultimodalPromptPrefixer for Qwen3VLMoEPrefixer {
}
5936
5937impl MultimodalModelLoader for Qwen3VLMoELoader {
5938 fn load(
5939 &self,
5940 config: &str,
5941 vb: ShardedVarBuilder,
5942 normal_loading_metadata: NormalLoadingMetadata,
5943 attention_mechanism: AttentionImplementation,
5944 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
5945 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5946 Ok(Box::new(Qwen3VLMoEModel::new(
5947 &cfg,
5948 vb,
5949 self.is_gptx(config),
5950 normal_loading_metadata,
5951 attention_mechanism,
5952 )?))
5953 }
5954 fn is_gptx(&self, _config: &str) -> bool {
5955 true
5956 }
5957 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
5958 let config: Qwen3VLMoEConfig = serde_json::from_str(config)?;
5959 Ok(Box::new(config))
5960 }
5961 fn get_processor(
5962 &self,
5963 _model_config: &str,
5964 _processor_config: Option<ProcessorConfig>,
5965 _preprocessor_config: PreProcessorConfig,
5966 max_edge: Option<u32>,
5967 ) -> Arc<dyn Processor + Send + Sync> {
5968 Arc::new(Qwen3VLMoEProcessor::new(max_edge))
5969 }
5970 fn supports_paged_attention(&self, _config: &str) -> bool {
5971 true
5972 }
5973 fn supports_prefix_cacher(&self, _config: &str) -> bool {
5974 true
5975 }
5976 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
5977 Arc::new(Qwen3VLMoEPrefixer)
5978 }
5979 fn modalities(&self, _config: &str) -> Result<Modalities> {
5980 Ok(Modalities {
5981 input: vec![SupportedModality::Text, SupportedModality::Vision],
5982 output: vec![SupportedModality::Text],
5983 })
5984 }
5985}
5986
5987impl IsqModelLoader for Qwen3VLMoELoader {
5988 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
5989 Ok(vec![
5990 Regex::new(r"lm_head\.(weight|bias)$")?,
5991 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
5993 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
5994 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
5995 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
5996 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
5998 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
5999 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6000 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
6002 Regex::new(
6004 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6005 )?,
6006 Regex::new(
6007 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6008 )?,
6009 Regex::new(
6010 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6011 )?,
6012 ])
6013 }
6014 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6015 self.isq_layer_regexes(config)
6016 }
6017 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
6018 Ok(vec![
6019 Regex::new(r"lm_head\.(weight|bias)$")?,
6020 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
6022 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
6023 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6024 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate\.(weight|bias)$")?,
6026 Regex::new(
6028 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6029 )?,
6030 Regex::new(
6031 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6032 )?,
6033 Regex::new(
6034 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6035 )?,
6036 ])
6037 }
6038 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
6039 self.isq_layer_regexes_moqe(config)
6040 }
6041}
6042
6043impl DeviceMappedModelLoader for Qwen3VLMoELoader {
6044 fn mapped_max_act_size_elems(
6045 &self,
6046 config: &str,
6047 params: &AutoDeviceMapParams,
6048 ) -> Result<usize> {
6049 let AutoDeviceMapParams::Multimodal {
6050 max_seq_len,
6051 max_batch_size,
6052 max_image_shape,
6053 max_num_images,
6054 } = params
6055 else {
6056 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6057 };
6058
6059 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6060
6061 let img_seq_len = {
6063 let cfg = &cfg.vision_config;
6064 let grid_t = 1;
6066 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6068 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6069 grid_t * grid_h * grid_w * max_num_images
6070 };
6071
6072 let max_text_attn = {
6073 let cfg = &cfg.text_config;
6074 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6076 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6077 };
6078
6079 Ok(max_text_attn)
6080 }
6081
6082 fn non_mapped_max_act_size_elems(
6083 &self,
6084 config: &str,
6085 params: &AutoDeviceMapParams,
6086 ) -> Result<usize> {
6087 let AutoDeviceMapParams::Multimodal {
6088 max_seq_len: _,
6089 max_batch_size,
6090 max_image_shape,
6091 max_num_images,
6092 } = params
6093 else {
6094 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6095 };
6096
6097 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6098
6099 let img_seq_len = {
6101 let cfg = &cfg.vision_config;
6102 let grid_t = 1;
6104 let grid_h = max_image_shape.0 / cfg.patch_size;
6105 let grid_w = max_image_shape.1 / cfg.patch_size;
6106 grid_t * grid_h * grid_w
6107 };
6108
6109 let max_vision_attn = {
6110 let cfg = &cfg.vision_config;
6111 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6112 };
6113
6114 Ok(max_vision_attn)
6115 }
6116
6117 fn non_mapped_size_in_bytes(
6118 &self,
6119 config: &str,
6120 dtype: DType,
6121 weight_pack_factor: usize,
6122 _matformer_config: Option<&MatformerSliceConfig>,
6123 ) -> Result<usize> {
6124 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6125 let tie = cfg.tie_word_embeddings;
6126 let text_elems = {
6127 let cfg = &cfg.text_config;
6128 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6129 let lm_head = if !tie || weight_pack_factor != 1 {
6131 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6132 } else {
6133 0
6134 };
6135 let norm = cfg.hidden_size;
6136 embed_tokens + lm_head + norm
6137 };
6138
6139 let (patch_merger, deepstack_mergers) = {
6140 let cfg = &cfg.vision_config;
6141 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6142
6143 let mlp0 = hidden_size * hidden_size + hidden_size;
6144 let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6145
6146 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6148 let merger = mlp0 + mlp2 + ln_q;
6149
6150 let ds_ln = hidden_size + bias_if!(true, hidden_size);
6152 let ds_merger = mlp0 + mlp2 + ds_ln;
6153 let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6154
6155 (merger, deepstack)
6156 };
6157
6158 let patch_embed = {
6159 let cfg = &cfg.vision_config;
6160 let conv_cfg = Conv3dConfig {
6161 stride: cfg.patch_size,
6162 ..Default::default()
6163 };
6164 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6165 let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6166 * kernel_sizes[0]
6167 * kernel_sizes[1]
6168 * kernel_sizes[2];
6169 let bias = cfg.hidden_size;
6170 weight + bias
6171 };
6172
6173 let pos_embed = {
6174 let cfg = &cfg.vision_config;
6175 cfg.num_position_embeddings * cfg.hidden_size
6176 };
6177
6178 let encoder_layer = {
6179 let cfg = &cfg.vision_config;
6180 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6181 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6182
6183 #[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
6184 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6185 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6186
6187 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6188 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6189
6190 norm1 + norm2 + fc1 + fc2 + qkv + out
6191 };
6192
6193 let elems = text_elems
6194 + patch_merger
6195 + deepstack_mergers
6196 + patch_embed
6197 + pos_embed
6198 + encoder_layer * cfg.vision_config.depth;
6199
6200 Ok(elems * dtype.size_in_bytes())
6201 }
6202
6203 fn layer_sizes_in_bytes(
6204 &self,
6205 config: &str,
6206 dtype: DType,
6207 weight_pack_factor: usize,
6208 _matformer_config: Option<&MatformerSliceConfig>,
6209 ) -> Result<Vec<usize>> {
6210 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6211 let text_cfg = &cfg.text_config;
6212
6213 let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6214
6215 for layer_idx in 0..text_cfg.num_hidden_layers {
6216 let input_layernorm = text_cfg.hidden_size;
6217 let post_attention_layernorm = text_cfg.hidden_size;
6218
6219 let size_in = text_cfg.hidden_size;
6220 let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6221 let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6222 let q_proj = size_in * size_q / weight_pack_factor;
6223 let k_proj = size_in * size_kv / weight_pack_factor;
6224 let v_proj = size_in * size_kv / weight_pack_factor;
6225 let o_proj = size_q * size_in / weight_pack_factor;
6226
6227 let q_norm = text_cfg.head_dim;
6228 let k_norm = text_cfg.head_dim;
6229
6230 let is_moe = !text_cfg.mlp_only_layers.contains(&layer_idx)
6232 && (text_cfg.num_experts > 0
6233 && (layer_idx + 1) % text_cfg.decoder_sparse_step == 0);
6234
6235 let mlp_elems = if is_moe {
6236 let gate = text_cfg.hidden_size * text_cfg.num_experts;
6238 let per_expert = {
6239 let h_size = text_cfg.hidden_size;
6240 let i_size = text_cfg.moe_intermediate_size;
6241 let gate_proj = h_size * i_size / weight_pack_factor;
6242 let up_proj = h_size * i_size / weight_pack_factor;
6243 let down_proj = i_size * h_size / weight_pack_factor;
6244 gate_proj + up_proj + down_proj
6245 };
6246 gate + per_expert * text_cfg.num_experts
6247 } else {
6248 let h_size = text_cfg.hidden_size;
6250 let i_size = text_cfg.intermediate_size;
6251 let gate_proj = h_size * i_size / weight_pack_factor;
6252 let up_proj = h_size * i_size / weight_pack_factor;
6253 let down_proj = i_size * h_size / weight_pack_factor;
6254 gate_proj + up_proj + down_proj
6255 };
6256
6257 let per_layer_elems = input_layernorm
6258 + post_attention_layernorm
6259 + q_proj
6260 + k_proj
6261 + v_proj
6262 + o_proj
6263 + q_norm
6264 + k_norm
6265 + mlp_elems;
6266
6267 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6268 }
6269
6270 Ok(layer_sizes)
6271 }
6272
6273 fn num_layers(&self, config: &str) -> Result<usize> {
6274 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6275 let cfg = &cfg.text_config;
6276 Ok(cfg.num_hidden_layers)
6277 }
6278
6279 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6280 let cfg: Qwen3VLMoEConfig = serde_json::from_str(config)?;
6281 let cfg = &cfg.text_config;
6282
6283 let cfg = ModelConfigMetadata {
6284 max_seq_len: cfg.max_position_embeddings,
6285 num_layers: cfg.num_hidden_layers,
6286 hidden_size: cfg.hidden_size,
6287 num_kv_heads: cfg.num_key_value_heads,
6288 num_attn_heads: cfg.num_attention_heads,
6289 sliding_window: cfg.sliding_window,
6290 k_head_dim: cfg.head_dim,
6291 v_head_dim: cfg.head_dim,
6292 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
6293 };
6294
6295 Ok(Box::new(cfg))
6296 }
6297
6298 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6299 Some(vec![NonMappedSubModel::Vision])
6300 }
6301}
6302
/// Field-less loader type for `Qwen3_5Model`. It carries no state; behavior comes from
/// its `MultimodalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader` impls.
6303pub struct Qwen3_5Loader;
6309
/// Field-less prompt prefixer type; `Qwen3_5Loader::prefixer` returns it wrapped in an `Arc`.
6310pub struct Qwen3_5Prefixer;
6311
// Empty impl: no `MultimodalPromptPrefixer` method is overridden here, so whatever
// provided (default) implementations the trait defines apply unchanged.
6312impl MultimodalPromptPrefixer for Qwen3_5Prefixer {
6313 }
6316
6317impl MultimodalModelLoader for Qwen3_5Loader {
6318 fn load(
6319 &self,
6320 config: &str,
6321 vb: ShardedVarBuilder,
6322 normal_loading_metadata: NormalLoadingMetadata,
6323 attention_mechanism: AttentionImplementation,
6324 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
6325 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6326 Ok(Box::new(Qwen3_5Model::new(
6327 &cfg,
6328 vb,
6329 self.is_gptx(config),
6330 normal_loading_metadata,
6331 attention_mechanism,
6332 )?))
6333 }
6334 fn is_gptx(&self, _config: &str) -> bool {
6335 true
6336 }
6337 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
6338 let config: Qwen3_5Config = serde_json::from_str(config)?;
6339 Ok(Box::new(config))
6340 }
6341 fn get_processor(
6342 &self,
6343 _model_config: &str,
6344 _processor_config: Option<ProcessorConfig>,
6345 _preprocessor_config: PreProcessorConfig,
6346 max_edge: Option<u32>,
6347 ) -> Arc<dyn Processor + Send + Sync> {
6348 Arc::new(Qwen3_5Processor::new(max_edge))
6349 }
6350 fn supports_paged_attention(&self, _config: &str) -> bool {
6351 true
6352 }
6353 fn supports_prefix_cacher(&self, _config: &str) -> bool {
6354 true
6355 }
6356 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
6357 Arc::new(Qwen3_5Prefixer)
6358 }
6359 fn modalities(&self, _config: &str) -> Result<Modalities> {
6360 Ok(Modalities {
6361 input: vec![SupportedModality::Text, SupportedModality::Vision],
6362 output: vec![SupportedModality::Text],
6363 })
6364 }
6365}
6366
6367impl IsqModelLoader for Qwen3_5Loader {
6368 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
6369 Ok(vec![
6370 Regex::new(r"lm_head\.(weight|bias)$")?,
6371 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
6373 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
6374 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
6375 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
6376 Regex::new(
6378 r"model\.language_model\.layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$",
6379 )?,
6380 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
6382 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
6383 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
6384 ])
6385 }
6386 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6387 self.isq_layer_regexes(config)
6388 }
6389}
6390
6391impl DeviceMappedModelLoader for Qwen3_5Loader {
6392 fn mapped_max_act_size_elems(
6393 &self,
6394 config: &str,
6395 params: &AutoDeviceMapParams,
6396 ) -> Result<usize> {
6397 let AutoDeviceMapParams::Multimodal {
6398 max_seq_len,
6399 max_batch_size,
6400 max_image_shape,
6401 max_num_images,
6402 } = params
6403 else {
6404 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6405 };
6406
6407 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6408
6409 let img_seq_len = {
6410 let cfg = &cfg.vision_config;
6411 let grid_t = 1;
6412 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6413 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6414 grid_t * grid_h * grid_w * max_num_images
6415 };
6416
6417 let max_text_attn = {
6418 let cfg = &cfg.text_config;
6419 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6420 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6421 };
6422
6423 Ok(max_text_attn)
6424 }
6425
6426 fn non_mapped_max_act_size_elems(
6427 &self,
6428 config: &str,
6429 params: &AutoDeviceMapParams,
6430 ) -> Result<usize> {
6431 let AutoDeviceMapParams::Multimodal {
6432 max_seq_len: _,
6433 max_batch_size,
6434 max_image_shape,
6435 max_num_images,
6436 } = params
6437 else {
6438 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6439 };
6440
6441 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6442
6443 let img_seq_len = {
6444 let cfg = &cfg.vision_config;
6445 let grid_t = 1;
6446 let grid_h = max_image_shape.0 / cfg.patch_size;
6447 let grid_w = max_image_shape.1 / cfg.patch_size;
6448 grid_t * grid_h * grid_w
6449 };
6450
6451 let max_vision_attn = {
6452 let cfg = &cfg.vision_config;
6453 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6454 };
6455
6456 Ok(max_vision_attn)
6457 }
6458
6459 fn non_mapped_size_in_bytes(
6460 &self,
6461 config: &str,
6462 dtype: DType,
6463 weight_pack_factor: usize,
6464 _matformer_config: Option<&MatformerSliceConfig>,
6465 ) -> Result<usize> {
6466 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6467 let tie = cfg.tie_word_embeddings;
6468 let text_elems = {
6469 let cfg = &cfg.text_config;
6470 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6471 let lm_head = if !tie || weight_pack_factor != 1 {
6472 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6473 } else {
6474 0
6475 };
6476 let norm = cfg.hidden_size;
6477 embed_tokens + lm_head + norm
6478 };
6479
6480 let (patch_merger, deepstack_mergers) = {
6481 let cfg = &cfg.vision_config;
6482 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6483
6484 let mlp0 = hidden_size * hidden_size + hidden_size;
6485 let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6486
6487 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6488 let merger = mlp0 + mlp2 + ln_q;
6489
6490 let ds_ln = hidden_size + bias_if!(true, hidden_size);
6491 let ds_merger = mlp0 + mlp2 + ds_ln;
6492 let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6493
6494 (merger, deepstack)
6495 };
6496
6497 let patch_embed = {
6498 let cfg = &cfg.vision_config;
6499 let conv_cfg = Conv3dConfig {
6500 stride: cfg.patch_size,
6501 ..Default::default()
6502 };
6503 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6504 let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6505 * kernel_sizes[0]
6506 * kernel_sizes[1]
6507 * kernel_sizes[2];
6508 let bias = cfg.hidden_size;
6509 weight + bias
6510 };
6511
6512 let pos_embed = {
6513 let cfg = &cfg.vision_config;
6514 cfg.num_position_embeddings * cfg.hidden_size
6515 };
6516
6517 let encoder_layer = {
6518 let cfg = &cfg.vision_config;
6519 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6520 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6521
6522 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6523 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6524
6525 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6526 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6527
6528 norm1 + norm2 + fc1 + fc2 + qkv + out
6529 };
6530
6531 let elems = text_elems
6532 + patch_merger
6533 + deepstack_mergers
6534 + patch_embed
6535 + pos_embed
6536 + encoder_layer * cfg.vision_config.depth;
6537
6538 Ok(elems * dtype.size_in_bytes())
6539 }
6540
6541 fn layer_sizes_in_bytes(
6542 &self,
6543 config: &str,
6544 dtype: DType,
6545 weight_pack_factor: usize,
6546 _matformer_config: Option<&MatformerSliceConfig>,
6547 ) -> Result<Vec<usize>> {
6548 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6549 let text_cfg = &cfg.text_config;
6550 let layer_types = text_cfg.layer_types();
6551
6552 let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6553
6554 for layer_type in &layer_types {
6555 let input_layernorm = text_cfg.hidden_size;
6556 let post_attention_layernorm = text_cfg.hidden_size;
6557
6558 let attn_elems = match layer_type {
6559 crate::vision_models::qwen3_5::config::LayerType::FullAttention => {
6560 let size_in = text_cfg.hidden_size;
6561 let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6562 let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6563 let q_proj = size_in * size_q * 2 / weight_pack_factor;
6564 let k_proj = size_in * size_kv / weight_pack_factor;
6565 let v_proj = size_in * size_kv / weight_pack_factor;
6566 let o_proj = size_q * size_in / weight_pack_factor;
6567 let q_norm = text_cfg.head_dim;
6568 let k_norm = text_cfg.head_dim;
6569 q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
6570 }
6571 crate::vision_models::qwen3_5::config::LayerType::LinearAttention => {
6572 let hidden = text_cfg.hidden_size;
6573 let key_dim = text_cfg.linear_key_dim();
6574 let value_dim = text_cfg.linear_value_dim();
6575 let conv_dim = text_cfg.linear_conv_dim();
6576 let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
6578 let in_proj_ba = hidden * (text_cfg.linear_num_value_heads * 2);
6580 let out_proj = value_dim * hidden / weight_pack_factor;
6581 let conv1d = conv_dim * text_cfg.linear_conv_kernel_dim;
6582 let dt_bias = text_cfg.linear_num_value_heads;
6583 let a_log = text_cfg.linear_num_value_heads;
6584 let norm = text_cfg.linear_value_head_dim;
6586 in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
6587 }
6588 };
6589
6590 let mlp_elems = {
6592 let h_size = text_cfg.hidden_size;
6593 let i_size = text_cfg.intermediate_size;
6594 let gate_proj = h_size * i_size / weight_pack_factor;
6595 let up_proj = h_size * i_size / weight_pack_factor;
6596 let down_proj = i_size * h_size / weight_pack_factor;
6597 gate_proj + up_proj + down_proj
6598 };
6599
6600 let per_layer_elems =
6601 input_layernorm + post_attention_layernorm + attn_elems + mlp_elems;
6602
6603 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6604 }
6605
6606 Ok(layer_sizes)
6607 }
6608
6609 fn num_layers(&self, config: &str) -> Result<usize> {
6610 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6611 Ok(cfg.text_config.num_hidden_layers)
6612 }
6613
6614 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
6615 let cfg: Qwen3_5Config = serde_json::from_str(config)?;
6616 let cfg = &cfg.text_config;
6617
6618 let cfg = ModelConfigMetadata {
6619 max_seq_len: cfg.max_position_embeddings,
6620 num_layers: cfg.num_hidden_layers,
6621 hidden_size: cfg.hidden_size,
6622 num_kv_heads: cfg.num_key_value_heads,
6623 num_attn_heads: cfg.num_attention_heads,
6624 sliding_window: None,
6625 k_head_dim: cfg.head_dim,
6626 v_head_dim: cfg.head_dim,
6627 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
6628 };
6629
6630 Ok(Box::new(cfg))
6631 }
6632
6633 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
6634 Some(vec![NonMappedSubModel::Vision])
6635 }
6636}
6637
/// Field-less loader type for `Qwen3_5MoeModel`. It carries no state; behavior comes from
/// its `MultimodalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader` impls.
6638pub struct Qwen3_5MoeLoader;
6644
/// Field-less prompt prefixer type; `Qwen3_5MoeLoader::prefixer` returns it wrapped in an `Arc`.
6645pub struct Qwen3_5MoePrefixer;
6646
// Empty impl: no `MultimodalPromptPrefixer` method is overridden here, so whatever
// provided (default) implementations the trait defines apply unchanged.
6647impl MultimodalPromptPrefixer for Qwen3_5MoePrefixer {
6648 }
6651
6652impl MultimodalModelLoader for Qwen3_5MoeLoader {
6653 fn load(
6654 &self,
6655 config: &str,
6656 vb: ShardedVarBuilder,
6657 normal_loading_metadata: NormalLoadingMetadata,
6658 attention_mechanism: AttentionImplementation,
6659 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
6660 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6661 Ok(Box::new(Qwen3_5MoeModel::new(
6662 &cfg,
6663 vb,
6664 self.is_gptx(config),
6665 normal_loading_metadata,
6666 attention_mechanism,
6667 )?))
6668 }
6669 fn is_gptx(&self, _config: &str) -> bool {
6670 true
6671 }
6672 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
6673 let config: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6674 Ok(Box::new(config))
6675 }
6676 fn get_processor(
6677 &self,
6678 _model_config: &str,
6679 _processor_config: Option<ProcessorConfig>,
6680 _preprocessor_config: PreProcessorConfig,
6681 max_edge: Option<u32>,
6682 ) -> Arc<dyn Processor + Send + Sync> {
6683 Arc::new(Qwen3_5MoeProcessor::new(max_edge))
6684 }
6685 fn supports_paged_attention(&self, _config: &str) -> bool {
6686 true
6687 }
6688 fn supports_prefix_cacher(&self, _config: &str) -> bool {
6689 true
6690 }
6691 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
6692 Arc::new(Qwen3_5MoePrefixer)
6693 }
6694 fn modalities(&self, _config: &str) -> Result<Modalities> {
6695 Ok(Modalities {
6696 input: vec![SupportedModality::Text, SupportedModality::Vision],
6697 output: vec![SupportedModality::Text],
6698 })
6699 }
6700}
6701
6702impl IsqModelLoader for Qwen3_5MoeLoader {
6703 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
6704 Ok(vec![
6705 Regex::new(r"lm_head\.(weight|bias)$")?,
6706 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
6708 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
6709 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
6710 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
6711 Regex::new(
6713 r"model\.language_model\.layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$",
6714 )?,
6715 Regex::new(
6717 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6718 )?,
6719 Regex::new(
6720 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6721 )?,
6722 Regex::new(
6723 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6724 )?,
6725 Regex::new(
6727 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.gate_proj\.(weight|bias)$",
6728 )?,
6729 Regex::new(
6730 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.up_proj\.(weight|bias)$",
6731 )?,
6732 Regex::new(
6733 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.down_proj\.(weight|bias)$",
6734 )?,
6735 ])
6736 }
6737 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
6738 self.isq_layer_regexes(config)
6739 }
6740 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
6741 Ok(vec![
6742 Regex::new(r"lm_head\.(weight|bias)$")?,
6743 Regex::new(
6745 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$",
6746 )?,
6747 Regex::new(
6748 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$",
6749 )?,
6750 Regex::new(
6751 r"model\.language_model\.layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$",
6752 )?,
6753 Regex::new(
6755 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.gate_proj\.(weight|bias)$",
6756 )?,
6757 Regex::new(
6758 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.up_proj\.(weight|bias)$",
6759 )?,
6760 Regex::new(
6761 r"model\.language_model\.layers\.(\d+)\.mlp\.shared_expert\.down_proj\.(weight|bias)$",
6762 )?,
6763 ])
6764 }
6765 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
6766 self.isq_layer_regexes_moqe(config)
6767 }
6768}
6769
6770impl DeviceMappedModelLoader for Qwen3_5MoeLoader {
6771 fn mapped_max_act_size_elems(
6772 &self,
6773 config: &str,
6774 params: &AutoDeviceMapParams,
6775 ) -> Result<usize> {
6776 let AutoDeviceMapParams::Multimodal {
6777 max_seq_len,
6778 max_batch_size,
6779 max_image_shape,
6780 max_num_images,
6781 } = params
6782 else {
6783 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6784 };
6785
6786 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6787
6788 let img_seq_len = {
6789 let cfg = &cfg.vision_config;
6790 let grid_t = 1;
6791 let grid_h = (max_image_shape.0 / cfg.patch_size) / cfg.spatial_merge_size;
6792 let grid_w = (max_image_shape.1 / cfg.patch_size) / cfg.spatial_merge_size;
6793 grid_t * grid_h * grid_w * max_num_images
6794 };
6795
6796 let max_text_attn = {
6797 let cfg = &cfg.text_config;
6798 let max_seq_len = img_seq_len + max_seq_len.min(&ATTENTION_CHUNK_SIZE);
6799 max_batch_size * cfg.num_attention_heads * max_seq_len * max_seq_len
6800 };
6801
6802 Ok(max_text_attn)
6803 }
6804
6805 fn non_mapped_max_act_size_elems(
6806 &self,
6807 config: &str,
6808 params: &AutoDeviceMapParams,
6809 ) -> Result<usize> {
6810 let AutoDeviceMapParams::Multimodal {
6811 max_seq_len: _,
6812 max_batch_size,
6813 max_image_shape,
6814 max_num_images,
6815 } = params
6816 else {
6817 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
6818 };
6819
6820 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6821
6822 let img_seq_len = {
6823 let cfg = &cfg.vision_config;
6824 let grid_t = 1;
6825 let grid_h = max_image_shape.0 / cfg.patch_size;
6826 let grid_w = max_image_shape.1 / cfg.patch_size;
6827 grid_t * grid_h * grid_w
6828 };
6829
6830 let max_vision_attn = {
6831 let cfg = &cfg.vision_config;
6832 (max_batch_size * max_num_images) * cfg.num_heads * img_seq_len * img_seq_len
6833 };
6834
6835 Ok(max_vision_attn)
6836 }
6837
6838 fn non_mapped_size_in_bytes(
6839 &self,
6840 config: &str,
6841 dtype: DType,
6842 weight_pack_factor: usize,
6843 _matformer_config: Option<&MatformerSliceConfig>,
6844 ) -> Result<usize> {
6845 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6846 let tie = cfg.tie_word_embeddings;
6847 let text_elems = {
6848 let cfg = &cfg.text_config;
6849 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
6850 let lm_head = if !tie || weight_pack_factor != 1 {
6851 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
6852 } else {
6853 0
6854 };
6855 let norm = cfg.hidden_size;
6856 embed_tokens + lm_head + norm
6857 };
6858
6859 let (patch_merger, deepstack_mergers) = {
6860 let cfg = &cfg.vision_config;
6861 let hidden_size = cfg.hidden_size * cfg.spatial_merge_size.pow(2);
6862
6863 let mlp0 = hidden_size * hidden_size + hidden_size;
6864 let mlp2 = hidden_size * cfg.out_hidden_size + cfg.out_hidden_size;
6865
6866 let ln_q = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6867 let merger = mlp0 + mlp2 + ln_q;
6868
6869 let ds_ln = hidden_size + bias_if!(true, hidden_size);
6870 let ds_merger = mlp0 + mlp2 + ds_ln;
6871 let deepstack = cfg.deepstack_visual_indexes.len() * ds_merger;
6872
6873 (merger, deepstack)
6874 };
6875
6876 let patch_embed = {
6877 let cfg = &cfg.vision_config;
6878 let conv_cfg = Conv3dConfig {
6879 stride: cfg.patch_size,
6880 ..Default::default()
6881 };
6882 let kernel_sizes = [cfg.temporal_patch_size, cfg.patch_size, cfg.patch_size];
6883 let weight = cfg.in_chans * cfg.hidden_size / conv_cfg.groups
6884 * kernel_sizes[0]
6885 * kernel_sizes[1]
6886 * kernel_sizes[2];
6887 let bias = cfg.hidden_size;
6888 weight + bias
6889 };
6890
6891 let pos_embed = {
6892 let cfg = &cfg.vision_config;
6893 cfg.num_position_embeddings * cfg.hidden_size
6894 };
6895
6896 let encoder_layer = {
6897 let cfg = &cfg.vision_config;
6898 let norm1 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6899 let norm2 = cfg.hidden_size + bias_if!(true, cfg.hidden_size);
6900
6901 let fc1 = cfg.hidden_size * cfg.intermediate_size + cfg.intermediate_size;
6902 let fc2 = cfg.hidden_size * cfg.intermediate_size + cfg.hidden_size;
6903
6904 let qkv = cfg.hidden_size * cfg.hidden_size * 3 + cfg.hidden_size * 3;
6905 let out = cfg.hidden_size * cfg.hidden_size + cfg.hidden_size;
6906
6907 norm1 + norm2 + fc1 + fc2 + qkv + out
6908 };
6909
6910 let elems = text_elems
6911 + patch_merger
6912 + deepstack_mergers
6913 + patch_embed
6914 + pos_embed
6915 + encoder_layer * cfg.vision_config.depth;
6916
6917 Ok(elems * dtype.size_in_bytes())
6918 }
6919
6920 fn layer_sizes_in_bytes(
6921 &self,
6922 config: &str,
6923 dtype: DType,
6924 weight_pack_factor: usize,
6925 _matformer_config: Option<&MatformerSliceConfig>,
6926 ) -> Result<Vec<usize>> {
6927 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
6928 let text_cfg = &cfg.text_config;
6929 let layer_types = text_cfg.layer_types();
6930
6931 let mut layer_sizes = Vec::with_capacity(text_cfg.num_hidden_layers);
6932
6933 for layer_type in &layer_types {
6934 let input_layernorm = text_cfg.hidden_size;
6935 let post_attention_layernorm = text_cfg.hidden_size;
6936
6937 let attn_elems = match layer_type {
6938 crate::vision_models::qwen3_5_moe::config::LayerType::FullAttention => {
6939 let size_in = text_cfg.hidden_size;
6940 let size_q = text_cfg.head_dim * text_cfg.num_attention_heads;
6941 let size_kv = text_cfg.head_dim * text_cfg.num_key_value_heads;
6942 let q_proj = size_in * size_q * 2 / weight_pack_factor;
6943 let k_proj = size_in * size_kv / weight_pack_factor;
6944 let v_proj = size_in * size_kv / weight_pack_factor;
6945 let o_proj = size_q * size_in / weight_pack_factor;
6946 let q_norm = text_cfg.head_dim;
6947 let k_norm = text_cfg.head_dim;
6948 q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
6949 }
6950 crate::vision_models::qwen3_5_moe::config::LayerType::LinearAttention => {
6951 let hidden = text_cfg.hidden_size;
6952 let key_dim = text_cfg.linear_key_dim();
6953 let value_dim = text_cfg.linear_value_dim();
6954 let conv_dim = text_cfg.linear_conv_dim();
6955 let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
6957 let in_proj_ba = hidden * (text_cfg.linear_num_value_heads * 2);
6959 let out_proj = value_dim * hidden / weight_pack_factor;
6961 let conv1d = conv_dim * text_cfg.linear_conv_kernel_dim;
6963 let dt_bias = text_cfg.linear_num_value_heads;
6965 let a_log = text_cfg.linear_num_value_heads;
6966 let norm = text_cfg.linear_value_head_dim;
6968 in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
6969 }
6970 };
6971
6972 let moe_elems = {
6974 let gate = text_cfg.hidden_size * text_cfg.num_experts;
6975 let per_expert = {
6976 let h_size = text_cfg.hidden_size;
6977 let i_size = text_cfg.moe_intermediate_size;
6978 let gate_proj = h_size * i_size / weight_pack_factor;
6979 let up_proj = h_size * i_size / weight_pack_factor;
6980 let down_proj = i_size * h_size / weight_pack_factor;
6981 gate_proj + up_proj + down_proj
6982 };
6983 let shared_expert = {
6984 let h_size = text_cfg.hidden_size;
6985 let i_size = text_cfg.shared_expert_intermediate_size;
6986 let gate_proj = h_size * i_size / weight_pack_factor;
6987 let up_proj = h_size * i_size / weight_pack_factor;
6988 let down_proj = i_size * h_size / weight_pack_factor;
6989 gate_proj + up_proj + down_proj
6990 };
6991 let shared_expert_gate = text_cfg.hidden_size;
6992 gate + per_expert * text_cfg.num_experts + shared_expert + shared_expert_gate
6993 };
6994
6995 let per_layer_elems =
6996 input_layernorm + post_attention_layernorm + attn_elems + moe_elems;
6997
6998 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
6999 }
7000
7001 Ok(layer_sizes)
7002 }
7003
7004 fn num_layers(&self, config: &str) -> Result<usize> {
7005 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
7006 Ok(cfg.text_config.num_hidden_layers)
7007 }
7008
7009 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7010 let cfg: Qwen3_5MoeConfig = serde_json::from_str(config)?;
7011 let cfg = &cfg.text_config;
7012
7013 let cfg = ModelConfigMetadata {
7014 max_seq_len: cfg.max_position_embeddings,
7015 num_layers: cfg.num_hidden_layers,
7016 hidden_size: cfg.hidden_size,
7017 num_kv_heads: cfg.num_key_value_heads,
7018 num_attn_heads: cfg.num_attention_heads,
7019 sliding_window: None,
7020 k_head_dim: cfg.head_dim,
7021 v_head_dim: cfg.head_dim,
7022 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7023 };
7024
7025 Ok(Box::new(cfg))
7026 }
7027
7028 fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
7029 Some(vec![NonMappedSubModel::Vision])
7030 }
7031}
7032
/// Field-less loader type for `VoxtralModel`. It carries no state; behavior comes from
/// its `MultimodalModelLoader`, `IsqModelLoader` and `DeviceMappedModelLoader` impls.
7033pub struct VoxtralLoader;
7039
7040pub struct VoxtralPrefixer;
7041
7042impl MultimodalPromptPrefixer for VoxtralPrefixer {
7043 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
7044 prompt.to_string()
7045 }
7046}
7047
7048impl MultimodalModelLoader for VoxtralLoader {
7049 fn load(
7050 &self,
7051 config: &str,
7052 vb: ShardedVarBuilder,
7053 normal_loading_metadata: NormalLoadingMetadata,
7054 attention_mechanism: AttentionImplementation,
7055 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
7056 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7057 Ok(Box::new(VoxtralModel::new(
7058 &cfg,
7059 vb,
7060 self.is_gptx(config),
7061 normal_loading_metadata,
7062 attention_mechanism,
7063 )?))
7064 }
7065 fn is_gptx(&self, _config: &str) -> bool {
7066 true
7067 }
7068 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
7069 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7070 Ok(Box::new(cfg))
7071 }
7072 fn get_processor(
7073 &self,
7074 model_config: &str,
7075 _processor_config: Option<ProcessorConfig>,
7076 _preprocessor_config: PreProcessorConfig,
7077 _max_edge: Option<u32>,
7078 ) -> Arc<dyn Processor + Send + Sync> {
7079 let cfg: VoxtralConfig =
7080 serde_json::from_str(model_config).expect("Failed to parse VoxtralConfig");
7081 Arc::new(VoxtralProcessor::new(&cfg))
7082 }
7083 fn supports_paged_attention(&self, _config: &str) -> bool {
7084 false
7085 }
7086 fn supports_prefix_cacher(&self, _config: &str) -> bool {
7087 false
7088 }
7089 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
7090 Arc::new(VoxtralPrefixer)
7091 }
7092 fn modalities(&self, _config: &str) -> Result<Modalities> {
7093 Ok(Modalities {
7094 input: vec![SupportedModality::Text, SupportedModality::Audio],
7095 output: vec![SupportedModality::Text],
7096 })
7097 }
7098 fn default_chat_template(&self, _config: &str) -> Option<String> {
7099 Some("{{ bos_token }}{% for message in messages %}{% if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}{{ raise_exception('Conversation roles must alternate user/assistant/user/assistant/...') }}{% endif %}{% if message['role'] == 'user' %}{{ '[INST] ' + message['content'] + ' [/INST]' }}{% elif message['role'] == 'assistant' %}{{ message['content'] + eos_token + ' ' }}{% else %}{{ raise_exception('Only user and assistant roles are supported!') }}{% endif %}{% endfor %}".to_string())
7101 }
7102 fn default_bos_eos(&self, _config: &str) -> Option<(String, String)> {
7103 Some(("<s>".to_string(), "</s>".to_string()))
7105 }
7106}
7107
7108impl IsqModelLoader for VoxtralLoader {
7109 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
7110 Ok(vec![
7111 Regex::new(r"lm_head\.(weight|bias)$")?,
7113 Regex::new(r"layers\.(\d+)\.attention\.wq\.(weight|bias)$")?,
7115 Regex::new(r"layers\.(\d+)\.attention\.wk\.(weight|bias)$")?,
7116 Regex::new(r"layers\.(\d+)\.attention\.wv\.(weight|bias)$")?,
7117 Regex::new(r"layers\.(\d+)\.attention\.wo\.(weight|bias)$")?,
7118 Regex::new(r"layers\.(\d+)\.feed_forward\.w1\.(weight|bias)$")?,
7120 Regex::new(r"layers\.(\d+)\.feed_forward\.w3\.(weight|bias)$")?,
7121 Regex::new(r"layers\.(\d+)\.feed_forward\.w2\.(weight|bias)$")?,
7122 ])
7123 }
7124 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
7125 Ok(vec![
7126 Regex::new(r"tok_embeddings\.(weight|bias)$")?,
7127 Regex::new(r"layers\.(\d+)\.attention\.wq\.(weight|bias)$")?,
7129 Regex::new(r"layers\.(\d+)\.attention\.wk\.(weight|bias)$")?,
7130 Regex::new(r"layers\.(\d+)\.attention\.wv\.(weight|bias)$")?,
7131 Regex::new(r"layers\.(\d+)\.attention\.wo\.(weight|bias)$")?,
7132 Regex::new(r"layers\.(\d+)\.feed_forward\.w1\.(weight|bias)$")?,
7134 Regex::new(r"layers\.(\d+)\.feed_forward\.w3\.(weight|bias)$")?,
7135 Regex::new(r"layers\.(\d+)\.feed_forward\.w2\.(weight|bias)$")?,
7136 ])
7137 }
7138}
7139
7140#[allow(clippy::cast_possible_truncation, clippy::cast_precision_loss)]
7141impl DeviceMappedModelLoader for VoxtralLoader {
7142 fn mapped_max_act_size_elems(
7143 &self,
7144 config: &str,
7145 params: &AutoDeviceMapParams,
7146 ) -> Result<usize> {
7147 let AutoDeviceMapParams::Multimodal {
7148 max_seq_len,
7149 max_batch_size,
7150 ..
7151 } = params
7152 else {
7153 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7154 };
7155
7156 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7157
7158 let max_audio_tokens = 375;
7161 let total_seq = max_audio_tokens + *max_seq_len.min(&ATTENTION_CHUNK_SIZE);
7162 Ok(max_batch_size * cfg.n_heads * total_seq * total_seq)
7163 }
7164
7165 fn non_mapped_max_act_size_elems(
7166 &self,
7167 config: &str,
7168 params: &AutoDeviceMapParams,
7169 ) -> Result<usize> {
7170 let AutoDeviceMapParams::Multimodal { max_batch_size, .. } = params else {
7171 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7172 };
7173
7174 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7175 let enc = &cfg.multimodal.whisper_model_args.encoder_args;
7176 let max_enc_seq = 3000usize;
7179 Ok(max_batch_size * enc.n_heads * max_enc_seq * max_enc_seq)
7180 }
7181
7182 fn non_mapped_size_in_bytes(
7183 &self,
7184 config: &str,
7185 dtype: DType,
7186 _weight_pack_factor: usize,
7187 _matformer_config: Option<&MatformerSliceConfig>,
7188 ) -> Result<usize> {
7189 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7190 let enc = &cfg.multimodal.whisper_model_args.encoder_args;
7191 let ds = &cfg.multimodal.whisper_model_args.downsample_args;
7192
7193 let elem = dtype.size_in_bytes();
7194
7195 let conv1 = enc.dim * enc.audio_encoding_args.num_mel_bins * 3 + enc.dim; let conv2 = enc.dim * enc.dim * 3 + enc.dim;
7198
7199 let enc_attn_per_layer = 4 * enc.dim * enc.dim; let enc_mlp_per_layer = 3 * enc.dim * enc.hidden_dim; let enc_norm_per_layer = 2 * enc.dim; let enc_layers =
7204 enc.n_layers * (enc_attn_per_layer + enc_mlp_per_layer + enc_norm_per_layer);
7205 let enc_final_norm = enc.dim;
7206
7207 let adapter_in_features = enc.dim * ds.downsample_factor;
7209 let adapter = adapter_in_features * cfg.dim + cfg.dim + cfg.dim * cfg.dim + cfg.dim;
7210
7211 let total_encoder = conv1 + conv2 + enc_layers + enc_final_norm + adapter;
7212
7213 let embeddings = cfg.vocab_size * cfg.dim;
7215
7216 Ok((total_encoder + embeddings) * elem)
7217 }
7218
7219 fn layer_sizes_in_bytes(
7220 &self,
7221 config: &str,
7222 dtype: DType,
7223 weight_pack_factor: usize,
7224 _matformer_config: Option<&MatformerSliceConfig>,
7225 ) -> Result<Vec<usize>> {
7226 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7227 let elem = dtype.size_in_bytes();
7228
7229 let attn = (cfg.dim * cfg.n_heads * cfg.head_dim
7230 + cfg.dim * cfg.n_kv_heads * cfg.head_dim
7231 + cfg.dim * cfg.n_kv_heads * cfg.head_dim
7232 + cfg.n_heads * cfg.head_dim * cfg.dim)
7233 / weight_pack_factor;
7234 let mlp = (cfg.dim * cfg.hidden_dim + cfg.hidden_dim * cfg.dim + cfg.dim * cfg.hidden_dim)
7235 / weight_pack_factor;
7236 let norms = 2 * cfg.dim; let per_layer = (attn + mlp + norms) * elem;
7239
7240 Ok(vec![per_layer; cfg.n_layers])
7241 }
7242
7243 fn num_layers(&self, config: &str) -> Result<usize> {
7244 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7245 Ok(cfg.n_layers)
7246 }
7247
7248 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7249 let cfg: VoxtralConfig = serde_json::from_str(config)?;
7250
7251 let cfg = ModelConfigMetadata {
7252 max_seq_len: cfg.model_max_length,
7253 num_layers: cfg.n_layers,
7254 hidden_size: cfg.dim,
7255 num_kv_heads: cfg.n_kv_heads,
7256 num_attn_heads: cfg.n_heads,
7257 sliding_window: cfg.sliding_window,
7258 k_head_dim: cfg.head_dim,
7259 v_head_dim: cfg.head_dim,
7260 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7261 };
7262
7263 Ok(Box::new(cfg))
7264 }
7265}
7266
/// Loader for Gemma 4 models.
///
/// A stateless unit struct: all behavior comes from its `MultimodalModelLoader`,
/// `IsqModelLoader` and `DeviceMappedModelLoader` implementations, which parse a
/// `Gemma4Config` from the JSON configuration string they are given.
pub struct Gemma4Loader;
7270
/// Prompt prefixer used for Gemma 4; its `MultimodalPromptPrefixer` implementation returns
/// prompts unchanged.
// NOTE(review): the `dead_code` allowance has no stated reason, and this type is constructed
// by `Gemma4Loader::prefixer` — confirm whether the attribute is still needed.
#[allow(dead_code)]
pub struct Gemma4Prefixer;
7273
7274impl MultimodalPromptPrefixer for Gemma4Prefixer {
7275 fn prefix_image(&self, _image_indexes: Vec<usize>, prompt: &str) -> String {
7276 prompt.to_string()
7277 }
7278 fn prefix_video(&self, _video_indexes: Vec<usize>, prompt: &str) -> String {
7279 prompt.to_string()
7280 }
7281}
7282
7283impl MultimodalModelLoader for Gemma4Loader {
7284 fn load(
7285 &self,
7286 config: &str,
7287 vb: ShardedVarBuilder,
7288 normal_loading_metadata: NormalLoadingMetadata,
7289 attention_mechanism: AttentionImplementation,
7290 ) -> Result<Box<dyn MultimodalModel + Send + Sync>> {
7291 let cfg: Gemma4Config = serde_json::from_str(config)?;
7292 Ok(Box::new(Gemma4Model::new(
7293 &cfg,
7294 vb,
7295 self.is_gptx(config),
7296 normal_loading_metadata,
7297 attention_mechanism,
7298 )?))
7299 }
7300 fn is_gptx(&self, _config: &str) -> bool {
7301 true
7302 }
7303 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
7304 let config: Gemma4Config = serde_json::from_str(config)?;
7305 Ok(Box::new(config))
7306 }
7307 fn get_processor(
7308 &self,
7309 config: &str,
7310 processor_config: Option<ProcessorConfig>,
7311 _preprocessor_config: PreProcessorConfig,
7312 _max_edge: Option<u32>,
7313 ) -> Arc<dyn Processor + Send + Sync> {
7314 let cfg: Gemma4Config = serde_json::from_str(config).expect("Failed to parse Gemma4Config");
7315 Arc::new(Gemma4Processor::new(
7316 processor_config.unwrap_or_default(),
7317 cfg.vision_config.patch_size,
7318 cfg.vision_config.pooling_kernel_size,
7319 cfg.vision_config.default_output_length,
7320 true,
7321 cfg.audio_config.is_some(),
7322 ))
7323 }
7324 fn supports_paged_attention(&self, _config: &str) -> bool {
7325 true
7326 }
7327 fn supports_prefix_cacher(&self, _config: &str) -> bool {
7328 true
7329 }
7330 fn prefixer(&self, _config: &str) -> Arc<dyn MultimodalPromptPrefixer> {
7331 Arc::new(Gemma4Prefixer)
7332 }
7333 fn modalities(&self, config: &str) -> Result<Modalities> {
7334 let cfg: Gemma4Config = serde_json::from_str(config)?;
7335 let mut input = vec![
7336 SupportedModality::Text,
7337 SupportedModality::Vision,
7338 SupportedModality::Video,
7339 ];
7340 if cfg.audio_config.is_some() {
7341 input.push(SupportedModality::Audio);
7342 }
7343 Ok(Modalities {
7344 input,
7345 output: vec![SupportedModality::Text],
7346 })
7347 }
7348}
7349
7350impl IsqModelLoader for Gemma4Loader {
7351 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
7352 Ok(vec![
7354 Regex::new(r"lm_head\.(weight|bias)$")?,
7355 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
7356 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
7357 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
7358 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
7359 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
7360 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
7361 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
7362 Regex::new(r"layers\.(\d+)\.moe\.gate_up_proj\.weight$")?,
7363 Regex::new(r"layers\.(\d+)\.moe\.down_proj\.weight$")?,
7364 Regex::new(r"layers\.(\d+)\.experts\.gate_up_proj\.weight$")?,
7365 Regex::new(r"layers\.(\d+)\.experts\.down_proj\.weight$")?,
7366 Regex::new(r"per_layer_model_projection\.(weight|bias)$")?,
7367 Regex::new(r"layers\.(\d+)\.per_layer_input_gate\.(weight|bias)$")?,
7368 Regex::new(r"layers\.(\d+)\.per_layer_projection\.(weight|bias)$")?,
7369 ])
7370 }
7371 fn immediate_isq_predicates(&self, _config: &str) -> Result<Vec<Regex>> {
7372 Ok(vec![
7373 Regex::new(r"lm_head\.(weight|bias)$")?,
7374 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
7375 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
7376 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
7377 Regex::new(r"model\.language_model\.layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
7378 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
7379 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
7380 Regex::new(r"model\.language_model\.layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
7381 Regex::new(r"model\.language_model\.layers\.(\d+)\.moe\.gate_up_proj\.weight$")?,
7382 Regex::new(r"model\.language_model\.layers\.(\d+)\.moe\.down_proj\.weight$")?,
7383 Regex::new(r"model\.language_model\.layers\.(\d+)\.experts\.gate_up_proj\.weight$")?,
7384 Regex::new(r"model\.language_model\.layers\.(\d+)\.experts\.down_proj\.weight$")?,
7385 Regex::new(r"model\.language_model\.per_layer_model_projection\.(weight|bias)$")?,
7386 Regex::new(
7387 r"model\.language_model\.layers\.(\d+)\.per_layer_input_gate\.(weight|bias)$",
7388 )?,
7389 Regex::new(
7390 r"model\.language_model\.layers\.(\d+)\.per_layer_projection\.(weight|bias)$",
7391 )?,
7392 ])
7393 }
7394}
7395
7396impl DeviceMappedModelLoader for Gemma4Loader {
7397 fn mapped_max_act_size_elems(
7398 &self,
7399 config: &str,
7400 params: &AutoDeviceMapParams,
7401 ) -> Result<usize> {
7402 let AutoDeviceMapParams::Multimodal {
7403 max_seq_len,
7404 max_batch_size,
7405 max_image_shape: _,
7406 max_num_images,
7407 } = params
7408 else {
7409 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7410 };
7411
7412 let cfg: Gemma4Config = serde_json::from_str(config)?;
7413 let tc = &cfg.text_config;
7414
7415 let vision_tokens_per_image = cfg.vision_soft_tokens_per_image.unwrap_or(280);
7416 let audio_tokens = if cfg.audio_config.is_some() { 750 } else { 0 };
7417 let total_seq_len = *max_seq_len + vision_tokens_per_image * max_num_images + audio_tokens;
7418 let max_text_attn = max_batch_size * tc.num_attention_heads * total_seq_len * total_seq_len;
7419
7420 Ok(max_text_attn)
7421 }
7422
7423 fn non_mapped_max_act_size_elems(
7424 &self,
7425 config: &str,
7426 params: &AutoDeviceMapParams,
7427 ) -> Result<usize> {
7428 let AutoDeviceMapParams::Multimodal {
7429 max_seq_len: _,
7430 max_batch_size,
7431 max_image_shape: _,
7432 max_num_images,
7433 } = params
7434 else {
7435 anyhow::bail!("Expected multimodal AutoDeviceMapParams for this model!")
7436 };
7437
7438 let cfg: Gemma4Config = serde_json::from_str(config)?;
7439 let vc = &cfg.vision_config;
7440
7441 let max_patches =
7442 vc.default_output_length * vc.pooling_kernel_size * vc.pooling_kernel_size;
7443 let max_vision_attn =
7444 max_batch_size * max_num_images * vc.num_attention_heads * max_patches * max_patches;
7445 let max_vision_hidden = max_batch_size
7446 * max_num_images
7447 * max_patches
7448 * vc.hidden_size.max(vc.intermediate_size);
7449
7450 let max_audio_activation = cfg.audio_config.as_ref().map_or(0, |audio_cfg| {
7451 let subsample_factor: usize = audio_cfg
7452 .sscp_conv_stride_size
7453 .iter()
7454 .map(|stride| stride[0])
7455 .product();
7456 let max_audio_frames = 750 * subsample_factor.max(1);
7457 let audio_seq_after_subsample = max_audio_frames / subsample_factor.max(1);
7458
7459 let audio_encoder_act = audio_seq_after_subsample * (audio_cfg.hidden_size * 4);
7460 let chunk_size = audio_cfg.conf_attention_chunk_size;
7461 let context_size = chunk_size + audio_cfg.conf_attention_context_left - 1
7462 + audio_cfg.conf_attention_context_right;
7463 let num_chunks = audio_seq_after_subsample.div_ceil(chunk_size);
7464 let audio_attn_act =
7465 audio_cfg.conf_num_attention_heads * num_chunks * chunk_size * context_size;
7466
7467 max_batch_size * audio_encoder_act.max(audio_attn_act)
7468 });
7469
7470 Ok(max_vision_attn
7471 .max(max_vision_hidden)
7472 .max(max_audio_activation))
7473 }
7474
    /// Estimates the byte size of the Gemma 4 weights that live outside the device-mapped
    /// decoder layers.
    ///
    /// Three element counts are accumulated and converted to bytes separately:
    /// * text: token embeddings, optional `lm_head`, final norm and the per-layer-input
    ///   embedding/projection terms;
    /// * vision: patch embedder, encoder layers and the vision-to-text embedding;
    /// * audio (only when an audio configuration is present): subsampling conv projection,
    ///   conformer blocks, optional output projection and the audio-to-text embedding.
    ///
    /// Terms divided by `weight_pack_factor` are the ones this estimate treats as packed.
    /// The matformer configuration is unused.
    ///
    /// # Errors
    /// Fails when `config` does not deserialize into a `Gemma4Config`.
    ///
    /// # Panics
    /// Indexes the first two entries of the audio `sscp_conv_*` tables, so it panics if an
    /// audio configuration provides fewer than two of them.
    fn non_mapped_size_in_bytes(
        &self,
        config: &str,
        dtype: DType,
        weight_pack_factor: usize,
        _matformer_config: Option<&MatformerSliceConfig>,
    ) -> Result<usize> {
        let cfg: Gemma4Config = serde_json::from_str(config)?;
        let tc = &cfg.text_config;
        let vc = &cfg.vision_config;

        let text_elems = {
            let embed_tokens = tc.hidden_size * tc.vocab_size;
            // `lm_head` is counted separately unless embeddings are tied AND no weight
            // packing is in effect.
            let lm_head = if !tc.tie_word_embeddings || weight_pack_factor != 1 {
                tc.hidden_size * tc.vocab_size / weight_pack_factor
            } else {
                0
            };
            let norm = tc.hidden_size;

            // Per-layer-input terms only contribute when a positive
            // `hidden_size_per_layer_input` is configured.
            let ple_dim = tc.hidden_size_per_layer_input.unwrap_or(0);
            let ple_vocab = tc.vocab_size_per_layer_input.unwrap_or(tc.vocab_size);
            let embed_tokens_per_layer = if ple_dim > 0 {
                ple_vocab * tc.num_hidden_layers * ple_dim
            } else {
                0
            };
            let per_layer_model_projection = if ple_dim > 0 {
                tc.hidden_size * tc.num_hidden_layers * ple_dim / weight_pack_factor
            } else {
                0
            };
            let per_layer_projection_norm = ple_dim;

            embed_tokens
                + lm_head
                + norm
                + embed_tokens_per_layer
                + per_layer_model_projection
                + per_layer_projection_norm
        };

        // Element count of a single vision encoder layer: packed projection terms plus
        // unpacked norm terms.
        let vision_layer_elems = {
            let quantized = vc.hidden_size * vc.num_attention_heads * vc.head_dim
                + 3 * (vc.hidden_size * vc.num_key_value_heads * vc.head_dim)
                + 2 * (vc.hidden_size * vc.intermediate_size)
                + vc.intermediate_size * vc.hidden_size;
            let norms = 2 * vc.head_dim + 4 * vc.hidden_size;
            quantized / weight_pack_factor + norms
        };
        let vision_elems = {
            // The factor 3 is presumably the number of image channels — confirm.
            let patch_embed = vc.patch_size * vc.patch_size * 3 * vc.hidden_size;
            let position_embedding_table = 2 * vc.position_embedding_size * vc.hidden_size;
            let patch_embedder = patch_embed / weight_pack_factor + position_embedding_table;
            let encoder = vc.num_hidden_layers * vision_layer_elems;
            let embed_vision = vc.hidden_size * tc.hidden_size / weight_pack_factor;

            patch_embedder + encoder + embed_vision
        };

        let audio_elems = cfg.audio_config.as_ref().map_or(0, |audio_cfg| {
            // Feature width remaining after the two subsampling convs, each applied with
            // one unit of padding on both sides.
            let mut f_out = audio_cfg.input_feat_size;
            for i in 0..2 {
                let kernel_w = audio_cfg.sscp_conv_kernel_size[i][1];
                let stride_w = audio_cfg.sscp_conv_stride_size[i][1];
                let pad_left = 1;
                let pad_right = 1;
                f_out = (f_out + pad_left + pad_right + stride_w - kernel_w) / stride_w;
            }

            let subsample_conv_projection = {
                let conv_0 = audio_cfg.sscp_conv_channel_size[0]
                    * audio_cfg.sscp_conv_kernel_size[0][0]
                    * audio_cfg.sscp_conv_kernel_size[0][1];
                let conv_1 = audio_cfg.sscp_conv_channel_size[0]
                    * audio_cfg.sscp_conv_channel_size[1]
                    * audio_cfg.sscp_conv_kernel_size[1][0]
                    * audio_cfg.sscp_conv_kernel_size[1][1];
                let norms =
                    audio_cfg.sscp_conv_channel_size[0] + audio_cfg.sscp_conv_channel_size[1];
                let input_proj =
                    audio_cfg.sscp_conv_channel_size[1] * f_out * audio_cfg.hidden_size
                        / weight_pack_factor;
                conv_0 + conv_1 + norms + input_proj
            };

            // Element count of one conformer block: attention, two feed-forward modules,
            // the conv module and one trailing `hidden_size` term.
            let conformer_block = {
                let attention = 5 * (audio_cfg.hidden_size * audio_cfg.hidden_size)
                    / weight_pack_factor
                    + 2 * audio_cfg.hidden_size
                    + audio_cfg.hidden_size / audio_cfg.conf_num_attention_heads
                    + audio_cfg.hidden_size / 2
                    + (audio_cfg.conf_attention_context_left
                        + audio_cfg.conf_attention_context_right
                        + 1)
                    + (audio_cfg.conf_attention_chunk_size
                        * (audio_cfg.conf_attention_chunk_size
                            + audio_cfg.conf_attention_context_left
                            - 1
                            + audio_cfg.conf_attention_context_right))
                    + 1;
                let ffw = 2
                    * (2 * audio_cfg.hidden_size
                        + 2 * (audio_cfg.hidden_size * (audio_cfg.hidden_size * 4))
                            / weight_pack_factor);
                let conv = 2 * audio_cfg.hidden_size
                    + audio_cfg.hidden_size * (audio_cfg.hidden_size * 2) / weight_pack_factor
                    + audio_cfg.hidden_size * audio_cfg.hidden_size / weight_pack_factor
                    + audio_cfg.hidden_size * audio_cfg.conf_conv_kernel_size;
                attention + ffw + conv + audio_cfg.hidden_size
            };

            // The output projection is optional; when absent, the audio embedding maps
            // directly from `hidden_size`.
            let output_proj = audio_cfg.output_proj_dims.map_or(0, |output_dim| {
                audio_cfg.hidden_size * output_dim / weight_pack_factor + output_dim
            });
            let audio_embed_hidden = audio_cfg.output_proj_dims.unwrap_or(audio_cfg.hidden_size);
            let embed_audio = audio_embed_hidden * tc.hidden_size / weight_pack_factor;

            subsample_conv_projection
                + audio_cfg.conf_num_hidden_layers * conformer_block
                + output_proj
                + embed_audio
        });

        // Vision elements are sized as F32 when the requested dtype is F16 — presumably
        // because the vision tower is kept in F32 in that case; confirm against the model.
        let vision_dtype = if dtype == DType::F16 {
            DType::F32
        } else {
            dtype
        };

        Ok(text_elems * dtype.size_in_bytes()
            + vision_elems * vision_dtype.size_in_bytes()
            + audio_elems * dtype.size_in_bytes())
    }
7609
7610 fn layer_sizes_in_bytes(
7611 &self,
7612 config: &str,
7613 dtype: DType,
7614 weight_pack_factor: usize,
7615 _matformer_config: Option<&MatformerSliceConfig>,
7616 ) -> Result<Vec<usize>> {
7617 let cfg: Gemma4Config = serde_json::from_str(config)?;
7618 let tc = &cfg.text_config;
7619 let sizes: Vec<usize> = (0..tc.num_hidden_layers)
7620 .map(|layer_idx| {
7621 let is_sliding = {
7622 let is_last = layer_idx == tc.num_hidden_layers - 1;
7623 !is_last && (layer_idx + 1) % tc.sliding_window_pattern != 0
7624 };
7625 let hd = if is_sliding {
7626 tc.head_dim
7627 } else {
7628 tc.global_head_dim
7629 };
7630 let nkv = if is_sliding {
7631 tc.num_key_value_heads
7632 } else {
7633 tc.num_global_key_value_heads
7634 .unwrap_or(tc.num_key_value_heads)
7635 };
7636 let use_k_eq_v = tc.attention_k_eq_v && !is_sliding;
7637
7638 let mut attn = tc.hidden_size * tc.num_attention_heads * hd
7639 + tc.hidden_size * nkv * hd
7640 + tc.num_attention_heads * hd * tc.hidden_size;
7641 if !use_k_eq_v {
7642 attn += tc.hidden_size * nkv * hd;
7643 }
7644 attn += 2 * hd;
7645
7646 let mlp = 3 * tc.hidden_size * tc.intermediate_size;
7647
7648 let moe = if tc.enable_moe_block {
7649 let ne = tc.num_experts.unwrap_or(0);
7650 let ei = tc.expert_intermediate_size().unwrap_or(0);
7651 ne * tc.hidden_size * ei * 2
7652 + ne * ei * tc.hidden_size
7653 + ne
7654 + ne * tc.hidden_size
7655 + tc.hidden_size
7656 + 3 * tc.hidden_size
7657 } else {
7658 0
7659 };
7660
7661 let ple = if tc.hidden_size_per_layer_input.unwrap_or(0) > 0 {
7662 let pd = tc.hidden_size_per_layer_input.unwrap();
7663 tc.hidden_size * pd + pd * tc.hidden_size + tc.hidden_size
7664 } else {
7665 0
7666 };
7667
7668 let norms = 4 * tc.hidden_size + 1;
7669
7670 (attn + mlp + moe + ple + norms) * dtype.size_in_bytes() / weight_pack_factor
7671 })
7672 .collect();
7673 Ok(sizes)
7674 }
7675
7676 fn num_layers(&self, config: &str) -> Result<usize> {
7677 let cfg: Gemma4Config = serde_json::from_str(config)?;
7678 Ok(cfg.text_config.num_hidden_layers)
7679 }
7680
    /// Lists the sub-models reported as non-mapped for this loader: vision and audio.
    ///
    /// Always returns `Some`, regardless of whether a given checkpoint actually carries an
    /// audio configuration.
    fn non_mapped_sub_models(&self) -> Option<Vec<NonMappedSubModel>> {
        Some(vec![NonMappedSubModel::Vision, NonMappedSubModel::Audio])
    }
7684
7685 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
7686 let cfg: Gemma4Config = serde_json::from_str(config)?;
7687 let tc = &cfg.text_config;
7688
7689 let cfg = ModelConfigMetadata {
7690 max_seq_len: tc.max_position_embeddings,
7691 num_layers: tc.num_hidden_layers,
7692 hidden_size: tc.hidden_size,
7693 num_kv_heads: tc.num_key_value_heads,
7694 num_attn_heads: tc.num_attention_heads,
7695 sliding_window: Some(tc.sliding_window),
7696 k_head_dim: tc.global_head_dim,
7697 v_head_dim: tc.global_head_dim,
7698 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
7699 };
7700
7701 Ok(Box::new(cfg))
7702 }
7703}