1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 device_map::DeviceMapper,
13 lora::{LoraConfig, Ordering},
14 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15 pipeline::{
16 isq::IsqModelLoader,
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel,
19 },
20 utils::varbuilder_utils::DeviceForLoadTensor,
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36 models,
37 xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43 #[allow(clippy::too_many_arguments)]
44 fn forward(
45 &self,
46 input_ids: &Tensor,
47 seqlen_offsets: &[usize],
48 context_lens: Vec<(usize, usize)>,
49 position_ids: Vec<usize>,
50 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51 flash_params: &FlashParams,
52 ) -> candle_core::Result<Tensor>;
53 #[allow(clippy::too_many_arguments)]
54 fn xlora_forward(
55 &self,
56 input_ids: &Tensor,
57 input_ids_full: &Tensor,
58 seqlen_offsets: &[usize],
59 seqlen_offsets_full: &[usize],
60 no_kv_cache: bool,
61 non_granular_state: &Option<NonGranularState>,
62 context_lens: Vec<(usize, usize)>,
63 position_ids: Vec<usize>,
64 flash_params: &FlashParams,
65 flash_params_full: &FlashParams,
66 ) -> candle_core::Result<Tensor>;
67 fn is_xlora(&self) -> bool;
68 fn device(&self) -> &Device;
69 fn cache(&self) -> &EitherCache;
70 fn cache_mut(&mut self) -> &mut EitherCache;
71 fn max_seq_len(&self) -> usize;
72 fn config(&self) -> &ModelConfigMetadata;
73}
74
75pub struct NormalLoadingMetadata {
77 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
79 pub loading_isq: bool,
81 pub real_device: Device,
83 pub multi_progress: Arc<MultiProgress>,
85 pub matformer_slicing_config: Option<MatformerSliceConfig>,
87}
88
89pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
90 fn load(
91 &self,
92 config: &str,
93 vb: ShardedVarBuilder,
94 normal_loading_metadata: NormalLoadingMetadata,
95 attention_mechanism: AttentionImplementation,
96 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
97 #[allow(clippy::too_many_arguments)]
98 fn load_xlora(
99 &self,
100 config: &str,
101 vb: ShardedVarBuilder,
102 lora_config: &[((String, String), LoraConfig)],
103 xlora_config: Option<XLoraConfig>,
104 xlora_ordering: Ordering,
105 normal_loading_metadata: NormalLoadingMetadata,
106 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
107 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
108 fn is_gptx(&self, config: &str) -> Result<bool>;
109 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
110 Ok(true)
111 }
112 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
113 fn get_device_for_tensor(
114 &self,
115 config: &str,
116 _mapper: &dyn DeviceMapper,
117 loading_isq: bool,
118 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
119 if loading_isq {
120 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
121 } else {
122 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
123 let num_layers = self.model_config(config)?.num_layers();
124 let closure = move |name: String| {
125 if let Some(captures) = re.captures(&name) {
126 captures
127 .get(1)
128 .and_then(|m| m.as_str().parse::<usize>().ok())
129 .map(|l| l.min(num_layers))
130 .map(DeviceForLoadTensor::Idx)
131 .unwrap_or(DeviceForLoadTensor::Base)
132 } else {
133 DeviceForLoadTensor::Base
134 }
135 };
136
137 Ok(Arc::new(closure))
138 }
139 }
140}
141
142#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
143#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
144pub enum NormalLoaderType {
146 #[serde(rename = "mistral")]
147 Mistral,
148 #[serde(rename = "gemma")]
149 Gemma,
150 #[serde(rename = "mixtral")]
151 Mixtral,
152 #[serde(rename = "llama")]
153 Llama,
154 #[serde(rename = "phi2")]
155 Phi2,
156 #[serde(rename = "phi3")]
157 Phi3,
158 #[serde(rename = "qwen2")]
159 Qwen2,
160 #[serde(rename = "gemma2")]
161 Gemma2,
162 #[serde(rename = "starcoder2")]
163 Starcoder2,
164 #[serde(rename = "phi3.5moe")]
165 Phi3_5MoE,
166 #[serde(rename = "deepseekv2")]
167 DeepSeekV2,
168 #[serde(rename = "deepseekv3")]
169 DeepSeekV3,
170 #[serde(rename = "qwen3")]
171 Qwen3,
172 #[serde(rename = "glm4")]
173 GLM4,
174 #[serde(rename = "glm4moelite")]
175 GLM4MoeLite,
176 #[serde(rename = "glm4moe")]
177 GLM4Moe,
178 #[serde(rename = "qwen3moe")]
179 Qwen3Moe,
180 #[serde(rename = "smollm3")]
181 SmolLm3,
182 #[serde(rename = "granitemoehybrid")]
183 GraniteMoeHybrid,
184 #[serde(rename = "gpt_oss")]
185 GptOss,
186}
187
188impl NormalLoaderType {
190 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
191 match name {
192 "MistralForCausalLM" => Ok(Self::Mistral),
193 "MixtralForCausalLM" => Ok(Self::Mixtral),
194 "GemmaForCausalLM" => Ok(Self::Gemma),
195 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
196 "PhiForCausalLM" => Ok(Self::Phi2),
197 "Phi3ForCausalLM" => Ok(Self::Phi3),
198 "LlamaForCausalLM" => Ok(Self::Llama),
199 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
200 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
201 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
202 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
203 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
204 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
205 "Glm4ForCausalLM" => Ok(Self::GLM4),
206 "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
207 "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
208 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
209 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
210 "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
211 "GptOssForCausalLM" => Ok(Self::GptOss),
212 other => anyhow::bail!(
213 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
214 ),
215 }
216 }
217}
218
219impl FromStr for NormalLoaderType {
220 type Err = String;
221 fn from_str(s: &str) -> Result<Self, Self::Err> {
222 match s {
223 "mistral" => Ok(Self::Mistral),
224 "gemma" => Ok(Self::Gemma),
225 "mixtral" => Ok(Self::Mixtral),
226 "llama" => Ok(Self::Llama),
227 "phi2" => Ok(Self::Phi2),
228 "phi3" => Ok(Self::Phi3),
229 "qwen2" => Ok(Self::Qwen2),
230 "gemma2" => Ok(Self::Gemma2),
231 "starcoder2" => Ok(Self::Starcoder2),
232 "phi3.5moe" => Ok(Self::Phi3_5MoE),
233 "deepseekv2" => Ok(Self::DeepSeekV2),
234 "deepseekv3" => Ok(Self::DeepSeekV3),
235 "qwen3" => Ok(Self::Qwen3),
236 "glm4" => Ok(Self::GLM4),
237 "glm4moelite" => Ok(Self::GLM4MoeLite),
238 "glm4moe" => Ok(Self::GLM4Moe),
239 "qwen3moe" => Ok(Self::Qwen3Moe),
240 "smollm3" => Ok(Self::SmolLm3),
241 "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
242 "gpt_oss" => Ok(Self::GptOss),
243 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`.")),
244 }
245 }
246}
247
248impl Display for NormalLoaderType {
249 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
250 match self {
251 Self::Gemma => write!(f, "gemma"),
252 Self::Gemma2 => write!(f, "gemma2"),
253 Self::Llama => write!(f, "llama"),
254 Self::Mistral => write!(f, "mistral"),
255 Self::Mixtral => write!(f, "mixtral"),
256 Self::Phi2 => write!(f, "phi2"),
257 Self::Phi3 => write!(f, "phi3"),
258 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
259 Self::Qwen2 => write!(f, "qwen2"),
260 Self::Starcoder2 => write!(f, "starcoder2"),
261 Self::DeepSeekV2 => write!(f, "deepseekv2"),
262 Self::DeepSeekV3 => write!(f, "deepseekv3"),
263 Self::Qwen3 => write!(f, "qwen3"),
264 Self::GLM4 => write!(f, "glm4"),
265 Self::GLM4MoeLite => write!(f, "glm4moelite"),
266 Self::GLM4Moe => write!(f, "glm4moe"),
267 Self::Qwen3Moe => write!(f, "qwen3moe"),
268 Self::SmolLm3 => write!(f, "smollm3"),
269 Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
270 Self::GptOss => write!(f, "gpt_oss"),
271 }
272 }
273}
274
275macro_rules! bias_if {
276 ($cond:expr, $size:expr) => {
277 if $cond {
278 $size
279 } else {
280 0
281 }
282 };
283}
284
285pub struct AutoNormalLoader;
287
288#[derive(Deserialize)]
289struct AutoNormalLoaderConfig {
290 architectures: Vec<String>,
291}
292
293impl AutoNormalLoader {
294 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
295 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
296 if auto_cfg.architectures.len() != 1 {
297 anyhow::bail!("Expected to have one name for `architectures` config field.")
298 }
299
300 let name = &auto_cfg.architectures[0];
301
302 let tp = NormalLoaderType::from_causal_lm_name(name)?;
303
304 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
305
306 match tp {
307 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
308 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
309 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
310 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
311 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
312 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
313 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
314 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
315 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
316 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
317 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
318 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
319 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
320 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
321 NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
322 NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
323 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
324 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
325 NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
326 NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
327 }
328 }
329}
330
331impl NormalModelLoader for AutoNormalLoader {
332 fn load(
333 &self,
334 config: &str,
335 vb: ShardedVarBuilder,
336 normal_loading_metadata: NormalLoadingMetadata,
337 attention_mechanism: AttentionImplementation,
338 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
339 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
340 }
341 fn load_xlora(
342 &self,
343 config: &str,
344 vb: ShardedVarBuilder,
345 lora_config: &[((String, String), LoraConfig)],
346 xlora_config: Option<XLoraConfig>,
347 xlora_ordering: Ordering,
348 normal_loading_metadata: NormalLoadingMetadata,
349 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
350 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
351 Self::get_loader(config)?.load_xlora(
352 config,
353 vb,
354 lora_config,
355 xlora_config,
356 xlora_ordering,
357 normal_loading_metadata,
358 preload_adapters,
359 )
360 }
361 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
362 Self::get_loader(config)?.get_config_repr(config)
363 }
364 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
365 Self::get_loader(config)?.supports_paged_attention(config)
366 }
367 fn is_gptx(&self, config: &str) -> Result<bool> {
368 Self::get_loader(config)?.is_gptx(config)
369 }
370}
371
372impl IsqModelLoader for AutoNormalLoader {
373 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
374 Self::get_loader(config)?.immediate_isq_predicates(config)
375 }
376 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
377 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
378 }
379 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
380 Self::get_loader(config)?.isq_layer_regexes(config)
381 }
382 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
383 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
384 }
385}
386
387impl DeviceMappedModelLoader for AutoNormalLoader {
388 fn non_mapped_size_in_bytes(
389 &self,
390 config: &str,
391 dtype: DType,
392 weight_pack_factor: usize,
393 _matformer_config: Option<&MatformerSliceConfig>,
394 ) -> Result<usize> {
395 Self::get_loader(config)?.non_mapped_size_in_bytes(
396 config,
397 dtype,
398 weight_pack_factor,
399 _matformer_config,
400 )
401 }
402 fn num_layers(&self, config: &str) -> Result<usize> {
403 Self::get_loader(config)?.num_layers(config)
404 }
405 fn layer_sizes_in_bytes(
406 &self,
407 config: &str,
408 dtype: DType,
409 weight_pack_factor: usize,
410 _matformer_config: Option<&MatformerSliceConfig>,
411 ) -> Result<Vec<usize>> {
412 Self::get_loader(config)?.layer_sizes_in_bytes(
413 config,
414 dtype,
415 weight_pack_factor,
416 _matformer_config,
417 )
418 }
419 fn mapped_max_act_size_elems(
420 &self,
421 config: &str,
422 params: &super::AutoDeviceMapParams,
423 ) -> Result<usize> {
424 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
425 }
426 fn non_mapped_max_act_size_elems(
427 &self,
428 _config: &str,
429 _params: &AutoDeviceMapParams,
430 ) -> Result<usize> {
431 Ok(0)
432 }
433 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
434 Self::get_loader(config)?.model_config(config)
435 }
436}
437
438pub struct MistralLoader;
441
442impl NormalModelLoader for MistralLoader {
443 fn load(
444 &self,
445 config: &str,
446 vb: ShardedVarBuilder,
447 normal_loading_metadata: NormalLoadingMetadata,
448 attention_mechanism: AttentionImplementation,
449 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
450 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
451 Ok(Box::new(models::mistral::Model::new(
452 &cfg,
453 vb,
454 self.is_gptx(config)?,
455 normal_loading_metadata,
456 attention_mechanism,
457 )?))
458 }
459 fn load_xlora(
460 &self,
461 config: &str,
462 vb: ShardedVarBuilder,
463 lora_config: &[((String, String), LoraConfig)],
464 xlora_config: Option<XLoraConfig>,
465 xlora_ordering: Ordering,
466 normal_loading_metadata: NormalLoadingMetadata,
467 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
468 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
469 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
470 Ok(Box::new(xlora_models::XLoraMistral::new(
471 &cfg,
472 vb,
473 lora_config,
474 xlora_config,
475 xlora_ordering,
476 self.is_gptx(config)?,
477 normal_loading_metadata,
478 preload_adapters,
479 )?))
480 }
481 fn is_gptx(&self, _: &str) -> Result<bool> {
482 Ok(true)
483 }
484 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
485 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
486 Ok(Box::new(cfg))
487 }
488}
489
490impl IsqModelLoader for MistralLoader {
491 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
492 Ok(vec![
493 Regex::new(r"lm_head\.(weight|bias)$")?,
494 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
496 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
497 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
498 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
499 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
501 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
502 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
503 ])
504 }
505 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
506 self.isq_layer_regexes(config)
507 }
508}
509
510impl DeviceMappedModelLoader for MistralLoader {
511 fn mapped_max_act_size_elems(
512 &self,
513 config: &str,
514 params: &AutoDeviceMapParams,
515 ) -> Result<usize> {
516 let AutoDeviceMapParams::Text {
517 max_seq_len,
518 max_batch_size,
519 } = params
520 else {
521 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
522 };
523
524 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
525
526 Ok(
527 max_batch_size
528 * cfg.num_attention_heads
529 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
530 )
531 }
532 fn non_mapped_max_act_size_elems(
533 &self,
534 _config: &str,
535 _params: &AutoDeviceMapParams,
536 ) -> Result<usize> {
537 Ok(0)
538 }
539
540 fn non_mapped_size_in_bytes(
541 &self,
542 config: &str,
543 dtype: DType,
544 weight_pack_factor: usize,
545 _matformer_config: Option<&MatformerSliceConfig>,
546 ) -> Result<usize> {
547 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
548
549 let elems = {
550 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
551 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
553 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
554 } else {
555 0
556 };
557 let norm = cfg.hidden_size;
558 embed_tokens + lm_head + norm
559 };
560 Ok(elems * dtype.size_in_bytes())
561 }
562
563 fn layer_sizes_in_bytes(
564 &self,
565 config: &str,
566 dtype: DType,
567 weight_pack_factor: usize,
568 _matformer_config: Option<&MatformerSliceConfig>,
569 ) -> Result<Vec<usize>> {
570 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
571
572 let per_layer_elems = {
573 let input_layernorm = cfg.hidden_size;
574 let post_attention_layernorm = cfg.hidden_size;
575
576 let size_in = cfg.hidden_size;
577 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
578 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
579 let q_proj = size_in * size_q / weight_pack_factor;
580 let k_proj = size_in * size_kv / weight_pack_factor;
581 let v_proj = size_in * size_kv / weight_pack_factor;
582 let o_proj = size_q * size_in / weight_pack_factor;
583
584 let h_size = cfg.hidden_size;
585 let i_size = cfg.intermediate_size;
586 let gate_proj = h_size * i_size / weight_pack_factor;
587 let up_proj = h_size * i_size / weight_pack_factor;
588 let down_proj = i_size * h_size / weight_pack_factor;
589
590 input_layernorm
591 + post_attention_layernorm
592 + q_proj
593 + k_proj
594 + v_proj
595 + o_proj
596 + gate_proj
597 + up_proj
598 + down_proj
599 };
600 Ok(vec![
601 per_layer_elems * dtype.size_in_bytes();
602 cfg.num_hidden_layers
603 ])
604 }
605
606 fn num_layers(&self, config: &str) -> Result<usize> {
607 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
608 Ok(cfg.num_hidden_layers)
609 }
610
611 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
612 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
613
614 let cfg = ModelConfigMetadata {
615 max_seq_len: cfg.max_position_embeddings,
616 num_layers: cfg.num_hidden_layers,
617 hidden_size: cfg.hidden_size,
618 num_kv_heads: cfg.num_key_value_heads,
619 num_attn_heads: cfg.num_attention_heads,
620 sliding_window: cfg.sliding_window,
621 k_head_dim: cfg.head_dim(),
622 v_head_dim: cfg.head_dim(),
623 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
624 };
625
626 Ok(Box::new(cfg))
627 }
628}
629
630pub struct GemmaLoader;
636
637impl NormalModelLoader for GemmaLoader {
638 fn load(
639 &self,
640 config: &str,
641 vb: ShardedVarBuilder,
642 normal_loading_metadata: NormalLoadingMetadata,
643 attention_mechanism: AttentionImplementation,
644 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
645 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
646
647 Ok(Box::new(models::gemma::Model::new(
648 &cfg,
649 vb,
650 self.is_gptx(config)?,
651 normal_loading_metadata,
652 attention_mechanism,
653 )?))
654 }
655 fn load_xlora(
656 &self,
657 config: &str,
658 vb: ShardedVarBuilder,
659 lora_config: &[((String, String), LoraConfig)],
660 xlora_config: Option<XLoraConfig>,
661 xlora_ordering: Ordering,
662 normal_loading_metadata: NormalLoadingMetadata,
663 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
664 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
665 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
666
667 Ok(Box::new(xlora_models::XLoraGemma::new(
668 &cfg,
669 vb,
670 lora_config,
671 xlora_config,
672 xlora_ordering,
673 self.is_gptx(config)?,
674 normal_loading_metadata,
675 preload_adapters,
676 )?))
677 }
678 fn is_gptx(&self, _: &str) -> Result<bool> {
679 Ok(true)
680 }
681 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
682 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
683 Ok(Box::new(cfg))
684 }
685}
686
687impl IsqModelLoader for GemmaLoader {
688 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
689 Ok(vec![
690 Regex::new(r"lm_head\.(weight|bias)$")?,
691 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
693 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
694 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
695 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
696 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
698 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
699 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
700 ])
701 }
702 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
703 self.isq_layer_regexes(config)
704 }
705}
706
707impl DeviceMappedModelLoader for GemmaLoader {
708 fn mapped_max_act_size_elems(
709 &self,
710 config: &str,
711 params: &AutoDeviceMapParams,
712 ) -> Result<usize> {
713 let AutoDeviceMapParams::Text {
714 max_seq_len,
715 max_batch_size,
716 } = params
717 else {
718 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
719 };
720
721 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
722
723 Ok(
724 max_batch_size
725 * cfg.num_attention_heads
726 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
727 )
728 }
729 fn non_mapped_max_act_size_elems(
730 &self,
731 _config: &str,
732 _params: &AutoDeviceMapParams,
733 ) -> Result<usize> {
734 Ok(0)
735 }
736
737 fn non_mapped_size_in_bytes(
738 &self,
739 config: &str,
740 dtype: DType,
741 weight_pack_factor: usize,
742 _matformer_config: Option<&MatformerSliceConfig>,
743 ) -> Result<usize> {
744 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
745
746 let elems = {
747 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
748 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
750 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
751 } else {
752 0
753 };
754 let norm = cfg.hidden_size;
755 embed_tokens + lm_head + norm
756 };
757 Ok(elems * dtype.size_in_bytes())
758 }
759
760 fn layer_sizes_in_bytes(
761 &self,
762 config: &str,
763 dtype: DType,
764 weight_pack_factor: usize,
765 _matformer_config: Option<&MatformerSliceConfig>,
766 ) -> Result<Vec<usize>> {
767 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
768
769 let per_layer_elems = {
770 let input_layernorm = cfg.hidden_size;
771 let post_attention_layernorm = cfg.hidden_size;
772
773 let size_in = cfg.hidden_size;
774 let size_q = cfg.head_dim * cfg.num_attention_heads;
775 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
776 let q_proj =
777 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
778 let k_proj =
779 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
780 let v_proj =
781 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
782 let o_proj =
783 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
784
785 let h_size = cfg.hidden_size;
786 let i_size = cfg.intermediate_size;
787 let gate_proj = h_size * i_size / weight_pack_factor;
788 let up_proj = h_size * i_size / weight_pack_factor;
789 let down_proj = i_size * h_size / weight_pack_factor;
790
791 input_layernorm
792 + post_attention_layernorm
793 + q_proj
794 + k_proj
795 + v_proj
796 + o_proj
797 + gate_proj
798 + up_proj
799 + down_proj
800 };
801 Ok(vec![
802 per_layer_elems * dtype.size_in_bytes();
803 cfg.num_hidden_layers
804 ])
805 }
806
807 fn num_layers(&self, config: &str) -> Result<usize> {
808 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
809 Ok(cfg.num_hidden_layers)
810 }
811
812 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
813 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
814
815 let cfg = ModelConfigMetadata {
816 max_seq_len: cfg.max_position_embeddings,
817 num_layers: cfg.num_hidden_layers,
818 hidden_size: cfg.hidden_size,
819 num_kv_heads: cfg.num_key_value_heads,
820 num_attn_heads: cfg.num_attention_heads,
821 sliding_window: None,
822 k_head_dim: cfg.head_dim,
823 v_head_dim: cfg.head_dim,
824 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
825 };
826
827 Ok(Box::new(cfg))
828 }
829}
830
831pub struct LlamaLoader;
837
838impl NormalModelLoader for LlamaLoader {
839 fn load(
840 &self,
841 config: &str,
842 vb: ShardedVarBuilder,
843 normal_loading_metadata: NormalLoadingMetadata,
844 attention_mechanism: AttentionImplementation,
845 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
846 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
847
848 Ok(Box::new(models::llama::Llama::new(
849 &cfg,
850 vb,
851 self.is_gptx(config)?,
852 normal_loading_metadata,
853 attention_mechanism,
854 )?))
855 }
856 fn load_xlora(
857 &self,
858 config: &str,
859 vb: ShardedVarBuilder,
860 lora_config: &[((String, String), LoraConfig)],
861 xlora_config: Option<XLoraConfig>,
862 xlora_ordering: Ordering,
863 normal_loading_metadata: NormalLoadingMetadata,
864 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
865 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
866 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
867
868 Ok(Box::new(xlora_models::XLoraLlama::new(
869 &cfg,
870 vb,
871 lora_config,
872 xlora_config,
873 xlora_ordering,
874 self.is_gptx(config)?,
875 normal_loading_metadata,
876 preload_adapters,
877 )?))
878 }
879 fn is_gptx(&self, _: &str) -> Result<bool> {
880 Ok(true)
881 }
882 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
883 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
884 Ok(Box::new(cfg))
885 }
886}
887
888impl IsqModelLoader for LlamaLoader {
889 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
890 Ok(vec![
891 Regex::new(r"lm_head\.(weight|bias)$")?,
892 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
894 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
895 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
896 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
897 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
899 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
900 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
901 ])
902 }
903 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
904 self.isq_layer_regexes(config)
905 }
906}
907
908impl DeviceMappedModelLoader for LlamaLoader {
909 fn mapped_max_act_size_elems(
910 &self,
911 config: &str,
912 params: &AutoDeviceMapParams,
913 ) -> Result<usize> {
914 let AutoDeviceMapParams::Text {
915 max_seq_len,
916 max_batch_size,
917 } = params
918 else {
919 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
920 };
921
922 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
923
924 Ok(
925 max_batch_size
926 * cfg.num_attention_heads
927 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
928 )
929 }
930 fn non_mapped_max_act_size_elems(
931 &self,
932 _config: &str,
933 _params: &AutoDeviceMapParams,
934 ) -> Result<usize> {
935 Ok(0)
936 }
937
938 fn non_mapped_size_in_bytes(
939 &self,
940 config: &str,
941 dtype: DType,
942 weight_pack_factor: usize,
943 _matformer_config: Option<&MatformerSliceConfig>,
944 ) -> Result<usize> {
945 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
946
947 let elems = {
948 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
949 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
951 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
952 } else {
953 0
954 };
955 let norm = cfg.hidden_size;
956 embed_tokens + lm_head + norm
957 };
958 Ok(elems * dtype.size_in_bytes())
959 }
960
961 fn layer_sizes_in_bytes(
962 &self,
963 config: &str,
964 dtype: DType,
965 weight_pack_factor: usize,
966 _matformer_config: Option<&MatformerSliceConfig>,
967 ) -> Result<Vec<usize>> {
968 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
969
970 let per_layer_elems = {
971 let input_layernorm = cfg.hidden_size;
972 let post_attention_layernorm = cfg.hidden_size;
973
974 let size_in = cfg.hidden_size;
975 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
976 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
977 let q_proj = size_in * size_q / weight_pack_factor;
978 let k_proj = size_in * size_kv / weight_pack_factor;
979 let v_proj = size_in * size_kv / weight_pack_factor;
980 let o_proj = size_q * size_in / weight_pack_factor;
981
982 let h_size = cfg.hidden_size;
983 let i_size = cfg.intermediate_size;
984 let gate_proj = h_size * i_size / weight_pack_factor;
985 let up_proj = h_size * i_size / weight_pack_factor;
986 let down_proj = i_size * h_size / weight_pack_factor;
987
988 input_layernorm
989 + post_attention_layernorm
990 + q_proj
991 + k_proj
992 + v_proj
993 + o_proj
994 + gate_proj
995 + up_proj
996 + down_proj
997 };
998 Ok(vec![
999 per_layer_elems * dtype.size_in_bytes();
1000 cfg.num_hidden_layers
1001 ])
1002 }
1003
1004 fn num_layers(&self, config: &str) -> Result<usize> {
1005 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1006
1007 Ok(cfg.num_hidden_layers)
1008 }
1009 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1010 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1011
1012 let cfg = ModelConfigMetadata {
1013 max_seq_len: cfg.max_position_embeddings,
1014 num_layers: cfg.num_hidden_layers,
1015 hidden_size: cfg.hidden_size,
1016 num_kv_heads: cfg.num_key_value_heads,
1017 num_attn_heads: cfg.num_attention_heads,
1018 sliding_window: None,
1019 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1020 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1021 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1022 };
1023
1024 Ok(Box::new(cfg))
1025 }
1026}
1027
1028pub struct MixtralLoader;
1031
1032impl NormalModelLoader for MixtralLoader {
1033 fn load(
1034 &self,
1035 config: &str,
1036 vb: ShardedVarBuilder,
1037 normal_loading_metadata: NormalLoadingMetadata,
1038 attention_mechanism: AttentionImplementation,
1039 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1040 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1041
1042 Ok(Box::new(models::mixtral::Model::new(
1043 &cfg,
1044 vb,
1045 self.is_gptx(config)?,
1046 normal_loading_metadata,
1047 attention_mechanism,
1048 )?))
1049 }
1050 fn load_xlora(
1051 &self,
1052 config: &str,
1053 vb: ShardedVarBuilder,
1054 lora_config: &[((String, String), LoraConfig)],
1055 xlora_config: Option<XLoraConfig>,
1056 xlora_ordering: Ordering,
1057 normal_loading_metadata: NormalLoadingMetadata,
1058 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1059 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1060 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1061
1062 Ok(Box::new(xlora_models::XLoraMixtral::new(
1063 &cfg,
1064 vb,
1065 lora_config,
1066 xlora_config,
1067 xlora_ordering,
1068 self.is_gptx(config)?,
1069 normal_loading_metadata,
1070 preload_adapters,
1071 )?))
1072 }
1073 fn is_gptx(&self, _: &str) -> Result<bool> {
1074 Ok(true)
1075 }
1076 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1077 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1078
1079 Ok(Box::new(cfg))
1080 }
1081}
1082
1083impl IsqModelLoader for MixtralLoader {
1084 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1085 Ok(vec![
1086 Regex::new(r"lm_head\.(weight|bias)$")?,
1087 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1089 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1090 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1091 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1092 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1094 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1095 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1096 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1097 ])
1098 }
1099 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1100 self.isq_layer_regexes(config)
1101 }
1102}
1103
1104impl DeviceMappedModelLoader for MixtralLoader {
1105 fn mapped_max_act_size_elems(
1106 &self,
1107 config: &str,
1108 params: &AutoDeviceMapParams,
1109 ) -> Result<usize> {
1110 let AutoDeviceMapParams::Text {
1111 max_seq_len,
1112 max_batch_size,
1113 } = params
1114 else {
1115 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1116 };
1117
1118 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1119
1120 Ok(
1121 max_batch_size
1122 * cfg.num_attention_heads
1123 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1124 )
1125 }
1126 fn non_mapped_max_act_size_elems(
1127 &self,
1128 _config: &str,
1129 _params: &AutoDeviceMapParams,
1130 ) -> Result<usize> {
1131 Ok(0)
1132 }
1133
1134 fn non_mapped_size_in_bytes(
1135 &self,
1136 config: &str,
1137 dtype: DType,
1138 weight_pack_factor: usize,
1139 _matformer_config: Option<&MatformerSliceConfig>,
1140 ) -> Result<usize> {
1141 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1142
1143 let elems = {
1144 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1145 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1147 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1148 } else {
1149 0
1150 };
1151 let norm = cfg.hidden_size;
1152 embed_tokens + lm_head + norm
1153 };
1154 Ok(elems * dtype.size_in_bytes())
1155 }
1156
1157 fn layer_sizes_in_bytes(
1158 &self,
1159 config: &str,
1160 dtype: DType,
1161 weight_pack_factor: usize,
1162 _matformer_config: Option<&MatformerSliceConfig>,
1163 ) -> Result<Vec<usize>> {
1164 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1165
1166 let per_layer_elems = {
1167 let input_layernorm = cfg.hidden_size;
1168 let post_attention_layernorm = cfg.hidden_size;
1169
1170 let size_in = cfg.hidden_size;
1171 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1172 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1173 let q_proj = size_in * size_q / weight_pack_factor;
1174 let k_proj = size_in * size_kv / weight_pack_factor;
1175 let v_proj = size_in * size_kv / weight_pack_factor;
1176 let o_proj = size_q * size_in / weight_pack_factor;
1177
1178 let moe_block = {
1179 let gate = cfg.hidden_size * cfg.num_local_experts;
1180 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1182 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1183 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1184 gate + cfg.num_local_experts * w1
1185 + cfg.num_local_experts * w2
1186 + cfg.num_local_experts * w3
1187 };
1188
1189 input_layernorm
1190 + post_attention_layernorm
1191 + q_proj
1192 + k_proj
1193 + v_proj
1194 + o_proj
1195 + moe_block
1196 };
1197 Ok(vec![
1198 per_layer_elems * dtype.size_in_bytes();
1199 cfg.num_hidden_layers
1200 ])
1201 }
1202
1203 fn num_layers(&self, config: &str) -> Result<usize> {
1204 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1205
1206 Ok(cfg.num_hidden_layers)
1207 }
1208
1209 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1210 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1211
1212 let cfg = ModelConfigMetadata {
1213 max_seq_len: cfg.max_position_embeddings,
1214 num_layers: cfg.num_hidden_layers,
1215 hidden_size: cfg.hidden_size,
1216 num_kv_heads: cfg.num_key_value_heads,
1217 num_attn_heads: cfg.num_attention_heads,
1218 sliding_window: cfg.sliding_window,
1219 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1220 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1221 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1222 };
1223
1224 Ok(Box::new(cfg))
1225 }
1226}
1227
1228pub struct Phi2Loader;
1234
1235impl NormalModelLoader for Phi2Loader {
1236 fn load(
1237 &self,
1238 config: &str,
1239 vb: ShardedVarBuilder,
1240 normal_loading_metadata: NormalLoadingMetadata,
1241 attention_mechanism: AttentionImplementation,
1242 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1243 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1244
1245 Ok(Box::new(models::phi2::Model::new(
1246 &cfg,
1247 vb,
1248 self.is_gptx(config)?,
1249 normal_loading_metadata,
1250 attention_mechanism,
1251 )?))
1252 }
1253 fn load_xlora(
1254 &self,
1255 config: &str,
1256 vb: ShardedVarBuilder,
1257 lora_config: &[((String, String), LoraConfig)],
1258 xlora_config: Option<XLoraConfig>,
1259 xlora_ordering: Ordering,
1260 normal_loading_metadata: NormalLoadingMetadata,
1261 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1262 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1263 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1264
1265 Ok(Box::new(xlora_models::XLoraPhi2::new(
1266 &cfg,
1267 vb,
1268 lora_config,
1269 xlora_config,
1270 xlora_ordering,
1271 self.is_gptx(config)?,
1272 normal_loading_metadata,
1273 preload_adapters,
1274 )?))
1275 }
1276 fn is_gptx(&self, _: &str) -> Result<bool> {
1277 Ok(true)
1278 }
1279 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1280 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1281
1282 Ok(Box::new(cfg))
1283 }
1284}
1285
1286impl IsqModelLoader for Phi2Loader {
1287 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1288 Ok(vec![
1289 Regex::new(r"lm_head\.(weight|bias)$")?,
1290 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1292 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1293 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1294 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1295 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1297 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1298 ])
1299 }
1300 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1301 self.isq_layer_regexes(config)
1302 }
1303}
1304
1305impl DeviceMappedModelLoader for Phi2Loader {
1306 fn mapped_max_act_size_elems(
1307 &self,
1308 config: &str,
1309 params: &AutoDeviceMapParams,
1310 ) -> Result<usize> {
1311 let AutoDeviceMapParams::Text {
1312 max_seq_len,
1313 max_batch_size,
1314 } = params
1315 else {
1316 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1317 };
1318
1319 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1320
1321 Ok(
1322 max_batch_size
1323 * cfg.num_attention_heads
1324 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1325 )
1326 }
1327 fn non_mapped_max_act_size_elems(
1328 &self,
1329 _config: &str,
1330 _params: &AutoDeviceMapParams,
1331 ) -> Result<usize> {
1332 Ok(0)
1333 }
1334
1335 fn non_mapped_size_in_bytes(
1336 &self,
1337 config: &str,
1338 dtype: DType,
1339 weight_pack_factor: usize,
1340 _matformer_config: Option<&MatformerSliceConfig>,
1341 ) -> Result<usize> {
1342 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1343
1344 let elems = {
1345 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1346 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1348 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1349 } else {
1350 0
1351 };
1352 let norm = cfg.hidden_size;
1353 embed_tokens + lm_head + norm
1354 };
1355 Ok(elems * dtype.size_in_bytes())
1356 }
1357
1358 fn layer_sizes_in_bytes(
1359 &self,
1360 config: &str,
1361 dtype: DType,
1362 weight_pack_factor: usize,
1363 _matformer_config: Option<&MatformerSliceConfig>,
1364 ) -> Result<Vec<usize>> {
1365 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1366
1367 let per_layer_elems = {
1368 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1369
1370 let size_in = cfg.hidden_size;
1371 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1372 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1373 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1374 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1375 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1376 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1377 let (q_norm, k_norm) = if cfg.qk_layernorm {
1378 (cfg.head_dim(), cfg.head_dim())
1379 } else {
1380 (0, 0)
1381 };
1382
1383 let h_size = cfg.hidden_size;
1384 let i_size = cfg.intermediate_size;
1385 let fc1 = h_size * i_size / weight_pack_factor;
1386 let fc2 = h_size * i_size / weight_pack_factor;
1387
1388 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1389 };
1390 Ok(vec![
1391 per_layer_elems * dtype.size_in_bytes();
1392 cfg.num_hidden_layers
1393 ])
1394 }
1395
1396 fn num_layers(&self, config: &str) -> Result<usize> {
1397 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1398
1399 Ok(cfg.num_hidden_layers)
1400 }
1401
1402 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1403 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1404
1405 let cfg = ModelConfigMetadata {
1406 max_seq_len: cfg.max_position_embeddings,
1407 num_layers: cfg.num_hidden_layers,
1408 hidden_size: cfg.hidden_size,
1409 num_kv_heads: cfg.num_key_value_heads(),
1410 num_attn_heads: cfg.num_attention_heads,
1411 sliding_window: None,
1412 k_head_dim: cfg.head_dim(),
1413 v_head_dim: cfg.head_dim(),
1414 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1415 };
1416
1417 Ok(Box::new(cfg))
1418 }
1419}
1420
1421pub struct Phi3Loader;
1427
1428impl NormalModelLoader for Phi3Loader {
1429 fn load(
1430 &self,
1431 config: &str,
1432 vb: ShardedVarBuilder,
1433 normal_loading_metadata: NormalLoadingMetadata,
1434 attention_mechanism: AttentionImplementation,
1435 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1436 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1437
1438 Ok(Box::new(models::phi3::Model::new(
1439 &cfg,
1440 vb,
1441 self.is_gptx(config)?,
1442 normal_loading_metadata,
1443 attention_mechanism,
1444 )?))
1445 }
1446 fn load_xlora(
1447 &self,
1448 config: &str,
1449 vb: ShardedVarBuilder,
1450 lora_config: &[((String, String), LoraConfig)],
1451 xlora_config: Option<XLoraConfig>,
1452 xlora_ordering: Ordering,
1453 normal_loading_metadata: NormalLoadingMetadata,
1454 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1455 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1456 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1457
1458 Ok(Box::new(xlora_models::XLoraPhi3::new(
1459 &cfg,
1460 vb,
1461 lora_config,
1462 xlora_config,
1463 xlora_ordering,
1464 self.is_gptx(config)?,
1465 normal_loading_metadata,
1466 preload_adapters,
1467 )?))
1468 }
1469 fn is_gptx(&self, _: &str) -> Result<bool> {
1470 Ok(true)
1471 }
1472 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1473 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1474
1475 Ok(Box::new(cfg))
1476 }
1477}
1478
1479impl IsqModelLoader for Phi3Loader {
1480 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1481 Ok(vec![
1482 Regex::new(r"lm_head\.(weight|bias)$")?,
1483 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1485 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1486 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1488 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1489 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1490 ])
1491 }
1492 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1493 self.isq_layer_regexes(config)
1494 }
1495}
1496
1497impl DeviceMappedModelLoader for Phi3Loader {
1498 fn mapped_max_act_size_elems(
1499 &self,
1500 config: &str,
1501 params: &AutoDeviceMapParams,
1502 ) -> Result<usize> {
1503 let AutoDeviceMapParams::Text {
1504 max_seq_len,
1505 max_batch_size,
1506 } = params
1507 else {
1508 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1509 };
1510
1511 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1512
1513 Ok(
1514 max_batch_size
1515 * cfg.num_attention_heads
1516 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1517 )
1518 }
1519 fn non_mapped_max_act_size_elems(
1520 &self,
1521 _config: &str,
1522 _params: &AutoDeviceMapParams,
1523 ) -> Result<usize> {
1524 Ok(0)
1525 }
1526
1527 fn non_mapped_size_in_bytes(
1528 &self,
1529 config: &str,
1530 dtype: DType,
1531 weight_pack_factor: usize,
1532 _matformer_config: Option<&MatformerSliceConfig>,
1533 ) -> Result<usize> {
1534 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1535
1536 let elems = {
1537 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1538 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1540 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1541 } else {
1542 0
1543 };
1544 let norm = cfg.hidden_size;
1545 embed_tokens + lm_head + norm
1546 };
1547 Ok(elems * dtype.size_in_bytes())
1548 }
1549
1550 fn layer_sizes_in_bytes(
1551 &self,
1552 config: &str,
1553 dtype: DType,
1554 weight_pack_factor: usize,
1555 _matformer_config: Option<&MatformerSliceConfig>,
1556 ) -> Result<Vec<usize>> {
1557 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1558
1559 let per_layer_elems = {
1560 let input_layernorm = cfg.hidden_size;
1561 let post_attention_layernorm = cfg.hidden_size;
1562
1563 let size_in = cfg.hidden_size;
1564 let head_dim = cfg.head_dim();
1565 let op_size =
1566 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1567 let qkv_proj = size_in * op_size / weight_pack_factor;
1568 let o_proj =
1569 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1570
1571 let h_size = cfg.hidden_size;
1572 let i_size = cfg.intermediate_size;
1573 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1574 let down_proj = h_size * i_size / weight_pack_factor;
1575
1576 input_layernorm
1577 + post_attention_layernorm
1578 + qkv_proj
1579 + o_proj
1580 + gate_up_proj
1581 + down_proj
1582 };
1583 Ok(vec![
1584 per_layer_elems * dtype.size_in_bytes();
1585 cfg.num_hidden_layers
1586 ])
1587 }
1588
1589 fn num_layers(&self, config: &str) -> Result<usize> {
1590 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1591
1592 Ok(cfg.num_hidden_layers)
1593 }
1594
1595 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1596 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1597
1598 let cfg = ModelConfigMetadata {
1599 max_seq_len: cfg.max_position_embeddings,
1600 num_layers: cfg.num_hidden_layers,
1601 hidden_size: cfg.hidden_size,
1602 num_kv_heads: cfg.num_key_value_heads,
1603 num_attn_heads: cfg.num_attention_heads,
1604 sliding_window: cfg.sliding_window,
1605 k_head_dim: cfg.head_dim(),
1606 v_head_dim: cfg.head_dim(),
1607 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1608 };
1609
1610 Ok(Box::new(cfg))
1611 }
1612}
1613
1614pub struct Qwen2Loader;
1620
1621impl NormalModelLoader for Qwen2Loader {
1622 fn load(
1623 &self,
1624 config: &str,
1625 vb: ShardedVarBuilder,
1626 normal_loading_metadata: NormalLoadingMetadata,
1627 attention_mechanism: AttentionImplementation,
1628 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1629 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1630
1631 Ok(Box::new(models::qwen2::Model::new(
1632 &cfg,
1633 vb,
1634 self.is_gptx(config)?,
1635 normal_loading_metadata,
1636 attention_mechanism,
1637 )?))
1638 }
1639 fn load_xlora(
1640 &self,
1641 _config: &str,
1642 _vb: ShardedVarBuilder,
1643 _lora_config: &[((String, String), LoraConfig)],
1644 _xlora_config: Option<XLoraConfig>,
1645 _xlora_ordering: Ordering,
1646 _normal_loading_metadata: NormalLoadingMetadata,
1647 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1648 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1649 todo!()
1650 }
1651 fn is_gptx(&self, _: &str) -> Result<bool> {
1652 Ok(true)
1653 }
1654 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1655 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1656
1657 Ok(Box::new(cfg))
1658 }
1659}
1660
1661impl IsqModelLoader for Qwen2Loader {
1662 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1663 Ok(vec![
1664 Regex::new(r"lm_head\.(weight|bias)$")?,
1665 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1667 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1668 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1669 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1670 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1672 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1673 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1674 ])
1675 }
1676 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1677 self.isq_layer_regexes(config)
1678 }
1679}
1680
1681impl DeviceMappedModelLoader for Qwen2Loader {
1682 fn mapped_max_act_size_elems(
1683 &self,
1684 config: &str,
1685 params: &AutoDeviceMapParams,
1686 ) -> Result<usize> {
1687 let AutoDeviceMapParams::Text {
1688 max_seq_len,
1689 max_batch_size,
1690 } = params
1691 else {
1692 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1693 };
1694
1695 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1696
1697 Ok(
1698 max_batch_size
1699 * cfg.num_attention_heads
1700 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1701 )
1702 }
1703 fn non_mapped_max_act_size_elems(
1704 &self,
1705 _config: &str,
1706 _params: &AutoDeviceMapParams,
1707 ) -> Result<usize> {
1708 Ok(0)
1709 }
1710
1711 fn non_mapped_size_in_bytes(
1712 &self,
1713 config: &str,
1714 dtype: DType,
1715 weight_pack_factor: usize,
1716 _matformer_config: Option<&MatformerSliceConfig>,
1717 ) -> Result<usize> {
1718 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1719
1720 let elems = {
1721 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1722 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1724 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1725 } else {
1726 0
1727 };
1728 let norm = cfg.hidden_size;
1729 embed_tokens + lm_head + norm
1730 };
1731 Ok(elems * dtype.size_in_bytes())
1732 }
1733
1734 fn layer_sizes_in_bytes(
1735 &self,
1736 config: &str,
1737 dtype: DType,
1738 weight_pack_factor: usize,
1739 _matformer_config: Option<&MatformerSliceConfig>,
1740 ) -> Result<Vec<usize>> {
1741 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1742
1743 let per_layer_elems = {
1744 let input_layernorm = cfg.hidden_size;
1745 let post_attention_layernorm = cfg.hidden_size;
1746
1747 let size_in = cfg.hidden_size;
1748 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1749 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1750 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1751 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1752 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1753 let o_proj = size_q * size_in / weight_pack_factor;
1754
1755 let h_size = cfg.hidden_size;
1756 let i_size = cfg.intermediate_size;
1757 let gate_proj = h_size * i_size / weight_pack_factor;
1758 let up_proj = h_size * i_size / weight_pack_factor;
1759 let down_proj = i_size * h_size / weight_pack_factor;
1760
1761 input_layernorm
1762 + post_attention_layernorm
1763 + q_proj
1764 + k_proj
1765 + v_proj
1766 + o_proj
1767 + gate_proj
1768 + up_proj
1769 + down_proj
1770 };
1771 Ok(vec![
1772 per_layer_elems * dtype.size_in_bytes();
1773 cfg.num_hidden_layers
1774 ])
1775 }
1776
1777 fn num_layers(&self, config: &str) -> Result<usize> {
1778 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1779
1780 Ok(cfg.num_hidden_layers)
1781 }
1782
1783 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1784 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1785
1786 let cfg = ModelConfigMetadata {
1787 max_seq_len: cfg.max_position_embeddings,
1788 num_layers: cfg.num_hidden_layers,
1789 hidden_size: cfg.hidden_size,
1790 num_kv_heads: cfg.num_key_value_heads,
1791 num_attn_heads: cfg.num_attention_heads,
1792 sliding_window: cfg.sliding_window,
1793 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1794 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1795 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1796 };
1797
1798 Ok(Box::new(cfg))
1799 }
1800}
1801
1802pub struct Gemma2Loader;
1808
1809impl NormalModelLoader for Gemma2Loader {
1810 fn load(
1811 &self,
1812 config: &str,
1813 vb: ShardedVarBuilder,
1814 normal_loading_metadata: NormalLoadingMetadata,
1815 attention_mechanism: AttentionImplementation,
1816 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1817 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1818
1819 Ok(Box::new(models::gemma2::Model::new(
1820 &cfg,
1821 vb,
1822 self.is_gptx(config)?,
1823 normal_loading_metadata,
1824 attention_mechanism,
1825 )?))
1826 }
1827 fn load_xlora(
1828 &self,
1829 config: &str,
1830 vb: ShardedVarBuilder,
1831 lora_config: &[((String, String), LoraConfig)],
1832 xlora_config: Option<XLoraConfig>,
1833 xlora_ordering: Ordering,
1834 normal_loading_metadata: NormalLoadingMetadata,
1835 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1836 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1837 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1838
1839 Ok(Box::new(xlora_models::XLoraGemma2::new(
1840 &cfg,
1841 vb,
1842 lora_config,
1843 xlora_config,
1844 xlora_ordering,
1845 self.is_gptx(config)?,
1846 normal_loading_metadata,
1847 preload_adapters,
1848 )?))
1849 }
1850 fn is_gptx(&self, _: &str) -> Result<bool> {
1851 Ok(true)
1852 }
1853 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1854 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1855
1856 Ok(Box::new(cfg))
1857 }
1858}
1859
1860impl IsqModelLoader for Gemma2Loader {
1861 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1862 Ok(vec![
1863 Regex::new(r"lm_head\.(weight|bias)$")?,
1864 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1866 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1867 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1868 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1869 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1871 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1872 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1873 ])
1874 }
1875 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1876 self.isq_layer_regexes(config)
1877 }
1878}
1879
1880impl DeviceMappedModelLoader for Gemma2Loader {
1881 fn mapped_max_act_size_elems(
1882 &self,
1883 config: &str,
1884 params: &AutoDeviceMapParams,
1885 ) -> Result<usize> {
1886 let AutoDeviceMapParams::Text {
1887 max_seq_len,
1888 max_batch_size,
1889 } = params
1890 else {
1891 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1892 };
1893
1894 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1895
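        // Peak mapped activation: the attention score matrix for one chunk,
        // i.e. batch * heads * min(seq_len, ATTENTION_CHUNK_SIZE)^2 elements.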
1896 Ok(
1897 max_batch_size
1898 * cfg.num_attention_heads
1899 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1900 )
1901 }
1902 fn non_mapped_max_act_size_elems(
1903 &self,
1904 _config: &str,
1905 _params: &AutoDeviceMapParams,
1906 ) -> Result<usize> {
1907 Ok(0)
1908 }
1909
1910 fn non_mapped_size_in_bytes(
1911 &self,
1912 config: &str,
1913 dtype: DType,
1914 weight_pack_factor: usize,
1915 _matformer_config: Option<&MatformerSliceConfig>,
1916 ) -> Result<usize> {
1917 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1918
1919 let elems = {
1920 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
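            // A separate lm_head is counted when the embeddings are untied, or whenever a
            // weight pack factor is in effect (tied weights are not shared in that case).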
1921 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1923 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1924 } else {
1925 0
1926 };
1927 let norm = cfg.hidden_size;
1928 embed_tokens + lm_head + norm
1929 };
1930 Ok(elems * dtype.size_in_bytes())
1931 }
1932
1933 fn layer_sizes_in_bytes(
1934 &self,
1935 config: &str,
1936 dtype: DType,
1937 weight_pack_factor: usize,
1938 _matformer_config: Option<&MatformerSliceConfig>,
1939 ) -> Result<Vec<usize>> {
1940 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1941
1942 let per_layer_elems = {
1943 let input_layernorm = cfg.hidden_size;
1944 let post_attention_layernorm = cfg.hidden_size;
1945
1946 let size_in = cfg.hidden_size;
1947 let size_q = cfg.head_dim * cfg.num_attention_heads;
1948 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1949 let q_proj =
1950 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1951 let k_proj =
1952 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1953 let v_proj =
1954 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1955 let o_proj =
1956 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1957
1958 let h_size = cfg.hidden_size;
1959 let i_size = cfg.intermediate_size;
1960 let gate_proj = h_size * i_size / weight_pack_factor;
1961 let up_proj = h_size * i_size / weight_pack_factor;
1962 let down_proj = i_size * h_size / weight_pack_factor;
1963
1964 input_layernorm
1965 + post_attention_layernorm
1966 + q_proj
1967 + k_proj
1968 + v_proj
1969 + o_proj
1970 + gate_proj
1971 + up_proj
1972 + down_proj
1973 };
1974 Ok(vec![
1975 per_layer_elems * dtype.size_in_bytes();
1976 cfg.num_hidden_layers
1977 ])
1978 }
1979
1980 fn num_layers(&self, config: &str) -> Result<usize> {
1981 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1982
1983 Ok(cfg.num_hidden_layers)
1984 }
1985 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1986 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1987
1988 let cfg = ModelConfigMetadata {
1989 max_seq_len: cfg.max_position_embeddings,
1990 num_layers: cfg.num_hidden_layers,
1991 hidden_size: cfg.hidden_size,
1992 num_kv_heads: cfg.num_key_value_heads,
1993 num_attn_heads: cfg.num_attention_heads,
1994            sliding_window: None,
1995            k_head_dim: cfg.head_dim,
1996            v_head_dim: cfg.head_dim,
1997 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1998 };
1999
2000 Ok(Box::new(cfg))
2001 }
2002}
2003
2004pub struct Starcoder2Loader;
2010
2011impl NormalModelLoader for Starcoder2Loader {
2012 fn load(
2013 &self,
2014 config: &str,
2015 vb: ShardedVarBuilder,
2016 normal_loading_metadata: NormalLoadingMetadata,
2017 attention_mechanism: AttentionImplementation,
2018 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2019 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2020
2021 Ok(Box::new(models::starcoder2::Model::new(
2022 &cfg,
2023 vb,
2024 self.is_gptx(config)?,
2025 normal_loading_metadata,
2026 attention_mechanism,
2027 )?))
2028 }
2029 fn load_xlora(
2030 &self,
2031 config: &str,
2032 vb: ShardedVarBuilder,
2033 lora_config: &[((String, String), LoraConfig)],
2034 xlora_config: Option<XLoraConfig>,
2035 xlora_ordering: Ordering,
2036 normal_loading_metadata: NormalLoadingMetadata,
2037 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2038 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2039 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2040
2041 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2042 &cfg,
2043 vb,
2044 lora_config,
2045 xlora_config,
2046 xlora_ordering,
2047 self.is_gptx(config)?,
2048 normal_loading_metadata,
2049 preload_adapters,
2050 )?))
2051 }
2052 fn is_gptx(&self, _: &str) -> Result<bool> {
2053 Ok(true)
2054 }
2055 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2056 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2057
2058 Ok(Box::new(cfg))
2059 }
2060}
2061
2062impl IsqModelLoader for Starcoder2Loader {
2063 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2064 Ok(vec![
2065 Regex::new(r"lm_head\.(weight|bias)$")?,
2066 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2068 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2069 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2070 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2071 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2073 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2074 ])
2075 }
2076 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2077 self.isq_layer_regexes(config)
2078 }
2079}
2080
2081impl DeviceMappedModelLoader for Starcoder2Loader {
2082 fn mapped_max_act_size_elems(
2083 &self,
2084 config: &str,
2085 params: &AutoDeviceMapParams,
2086 ) -> Result<usize> {
2087 let AutoDeviceMapParams::Text {
2088 max_seq_len,
2089 max_batch_size,
2090 } = params
2091 else {
2092 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2093 };
2094
2095 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2096
2097 Ok(
2098 max_batch_size
2099 * cfg.num_attention_heads
2100 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2101 )
2102 }
2103 fn non_mapped_max_act_size_elems(
2104 &self,
2105 _config: &str,
2106 _params: &AutoDeviceMapParams,
2107 ) -> Result<usize> {
2108 Ok(0)
2109 }
2110
2111 fn non_mapped_size_in_bytes(
2112 &self,
2113 config: &str,
2114 dtype: DType,
2115 weight_pack_factor: usize,
2116 _matformer_config: Option<&MatformerSliceConfig>,
2117 ) -> Result<usize> {
2118 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2119
2120 let elems = {
2121 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2122 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2124 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2125 } else {
2126 0
2127 };
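            // Starcoder2's final LayerNorm stores both weight and bias, hence 2 * hidden_size.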
2128 let norm = cfg.hidden_size + cfg.hidden_size;
2129 embed_tokens + lm_head + norm
2130 };
2131 Ok(elems * dtype.size_in_bytes())
2132 }
2133
2134 fn layer_sizes_in_bytes(
2135 &self,
2136 config: &str,
2137 dtype: DType,
2138 weight_pack_factor: usize,
2139 _matformer_config: Option<&MatformerSliceConfig>,
2140 ) -> Result<Vec<usize>> {
2141 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2142
2143 let per_layer_elems = {
2144 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2145 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2146
2147 let size_in = cfg.hidden_size;
2148 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2149 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2150 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2151 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2152 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2153 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2154
2155 let h_size = cfg.hidden_size;
2156 let i_size = cfg.intermediate_size;
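            // Starcoder2's MLP is two linear layers (fc1 / c_proj) rather than gate/up/down.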
2157 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2158 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2159
2160 input_layernorm
2161 + post_attention_layernorm
2162 + q_proj
2163 + k_proj
2164 + v_proj
2165 + o_proj
2166 + fc1
2167 + fc2
2168 };
2169 Ok(vec![
2170 per_layer_elems * dtype.size_in_bytes();
2171 cfg.num_hidden_layers
2172 ])
2173 }
2174
2175 fn num_layers(&self, config: &str) -> Result<usize> {
2176 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2177
2178 Ok(cfg.num_hidden_layers)
2179 }
2180
2181 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2182 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2183
2184 let cfg = ModelConfigMetadata {
2185 max_seq_len: cfg.max_position_embeddings,
2186 num_layers: cfg.num_hidden_layers,
2187 hidden_size: cfg.hidden_size,
2188 num_kv_heads: cfg.num_key_value_heads,
2189 num_attn_heads: cfg.num_attention_heads,
2190 sliding_window: cfg.sliding_window,
2191 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2192 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2193 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2194 };
2195
2196 Ok(Box::new(cfg))
2197 }
2198}
2199
2200pub struct Phi3_5MoELoader;
2206
2207impl NormalModelLoader for Phi3_5MoELoader {
2208 fn load(
2209 &self,
2210 config: &str,
2211 vb: ShardedVarBuilder,
2212 normal_loading_metadata: NormalLoadingMetadata,
2213 attention_mechanism: AttentionImplementation,
2214 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2215 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2216
2217 Ok(Box::new(models::phi3_5_moe::Model::new(
2218 &cfg,
2219 vb,
2220 self.is_gptx(config)?,
2221 normal_loading_metadata,
2222 attention_mechanism,
2223 )?))
2224 }
2225 fn load_xlora(
2226 &self,
2227 config: &str,
2228 vb: ShardedVarBuilder,
2229 lora_config: &[((String, String), LoraConfig)],
2230 xlora_config: Option<XLoraConfig>,
2231 xlora_ordering: Ordering,
2232 normal_loading_metadata: NormalLoadingMetadata,
2233 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2234 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2235 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2236
2237 Ok(Box::new(xlora_models::XLoraPhi3::new(
2238 &cfg,
2239 vb,
2240 lora_config,
2241 xlora_config,
2242 xlora_ordering,
2243 self.is_gptx(config)?,
2244 normal_loading_metadata,
2245 preload_adapters,
2246 )?))
2247 }
2248 fn is_gptx(&self, _: &str) -> Result<bool> {
2249 Ok(true)
2250 }
2251 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2252 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2253
2254 Ok(Box::new(cfg))
2255 }
2256}
2257
2258impl IsqModelLoader for Phi3_5MoELoader {
2259 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2260 Ok(vec![
2261 Regex::new(r"lm_head\.(weight|bias)$")?,
2262 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2264 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2265 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2266 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2267 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2269 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2270 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2271 ])
2272 }
2273 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2274 self.isq_layer_regexes(config)
2275 }
2276
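    // MoQE variant: restrict ISQ to lm_head and the expert MLP weights only.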
2277 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2278 Ok(vec![
2279 Regex::new(r"lm_head\.(weight|bias)$")?,
2280 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2282 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2283 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2284 ])
2285 }
2286 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2287 self.isq_layer_regexes_moqe(config)
2288 }
2289}
2290
2291impl DeviceMappedModelLoader for Phi3_5MoELoader {
2292 fn mapped_max_act_size_elems(
2293 &self,
2294 config: &str,
2295 params: &AutoDeviceMapParams,
2296 ) -> Result<usize> {
2297 let AutoDeviceMapParams::Text {
2298 max_seq_len,
2299 max_batch_size,
2300 } = params
2301 else {
2302 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2303 };
2304
2305 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2306
2307 Ok(
2308 max_batch_size
2309 * cfg.num_attention_heads
2310 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2311 )
2312 }
2313 fn non_mapped_max_act_size_elems(
2314 &self,
2315 _config: &str,
2316 _params: &AutoDeviceMapParams,
2317 ) -> Result<usize> {
2318 Ok(0)
2319 }
2320
2321 fn non_mapped_size_in_bytes(
2322 &self,
2323 config: &str,
2324 dtype: DType,
2325 weight_pack_factor: usize,
2326 _matformer_config: Option<&MatformerSliceConfig>,
2327 ) -> Result<usize> {
2328 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2329
2330 let elems = {
2331 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2332 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2334 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2335 } else {
2336 0
2337 };
2338 let norm = cfg.hidden_size;
2339 embed_tokens + lm_head + norm
2340 };
2341 Ok(elems * dtype.size_in_bytes())
2342 }
2343
2344 fn layer_sizes_in_bytes(
2345 &self,
2346 config: &str,
2347 dtype: DType,
2348 weight_pack_factor: usize,
2349 _matformer_config: Option<&MatformerSliceConfig>,
2350 ) -> Result<Vec<usize>> {
2351 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2352
2353 let per_layer_elems = {
2354 let input_layernorm = cfg.hidden_size;
2355 let post_attention_layernorm = cfg.hidden_size;
2356
2357 let size_in = cfg.hidden_size;
2358 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2359 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2360 let q_proj =
2361 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2362 let k_proj =
2363 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2364 let v_proj =
2365 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2366 let o_proj =
2367 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2368
2369 let moe_block = {
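                // Router gate (hidden_size x num_local_experts) plus w1/w2/w3 for every local expert.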
2370 let gate = cfg.hidden_size * cfg.num_local_experts;
2371 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2373 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2374 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2375 gate + cfg.num_local_experts * w1
2376 + cfg.num_local_experts * w2
2377 + cfg.num_local_experts * w3
2378 };
2379
2380 input_layernorm
2381 + post_attention_layernorm
2382 + q_proj
2383 + k_proj
2384 + v_proj
2385 + o_proj
2386 + moe_block
2387 };
2388 Ok(vec![
2389 per_layer_elems * dtype.size_in_bytes();
2390 cfg.num_hidden_layers
2391 ])
2392 }
2393
2394 fn num_layers(&self, config: &str) -> Result<usize> {
2395 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2396
2397 Ok(cfg.num_hidden_layers)
2398 }
2399
2400 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2401 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2402
2403 let cfg = ModelConfigMetadata {
2404 max_seq_len: cfg.max_position_embeddings,
2405 num_layers: cfg.num_hidden_layers,
2406 hidden_size: cfg.hidden_size,
2407 num_kv_heads: cfg.num_key_value_heads,
2408 num_attn_heads: cfg.num_attention_heads,
2409 sliding_window: cfg.sliding_window,
2410 k_head_dim: cfg.head_dim(),
2411 v_head_dim: cfg.head_dim(),
2412 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2413 };
2414
2415 Ok(Box::new(cfg))
2416 }
2417}
2418
2419pub struct DeepSeekV2Loader;
2423
2424impl NormalModelLoader for DeepSeekV2Loader {
2425 fn load(
2426 &self,
2427 config: &str,
2428 vb: ShardedVarBuilder,
2429 normal_loading_metadata: NormalLoadingMetadata,
2430 attention_mechanism: AttentionImplementation,
2431 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2432 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2433
2434 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2435 &cfg,
2436 vb,
2437 self.is_gptx(config)?,
2438 normal_loading_metadata,
2439 attention_mechanism,
2440 )?))
2441 }
2442 fn load_xlora(
2443 &self,
2444 _config: &str,
2445 _vb: ShardedVarBuilder,
2446 _lora_config: &[((String, String), LoraConfig)],
2447 _xlora_config: Option<XLoraConfig>,
2448 _xlora_ordering: Ordering,
2449 _normal_loading_metadata: NormalLoadingMetadata,
2450 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2451 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2452 todo!()
2453 }
2454 fn is_gptx(&self, _: &str) -> Result<bool> {
2455 Ok(true)
2456 }
2457 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2458 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2459 Ok(Box::new(cfg))
2460 }
2461}
2462
2463impl IsqModelLoader for DeepSeekV2Loader {
2464 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2465 let mut data = vec![
2466 Regex::new(r"lm_head\.(weight|bias)$")?,
2467 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2469 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2470 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2471 ];
2472 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2473 if cfg.q_lora_rank.is_some() {
2474 data.extend(vec![
2475 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2476 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2477 ]);
2478 } else {
2479 data.push(Regex::new(
2480 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2481 )?);
2482 }
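        // Layers at or beyond first_k_dense_replace that fall on the moe_layer_freq stride
        // use routed (and optionally shared) experts; all other layers have a dense MLP.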
2483 for layer_idx in 0..cfg.num_hidden_layers {
2484 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2485 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2486 }) {
2487 for i in 0..n_routed_experts {
2488 data.extend(vec![
2489 Regex::new(&format!(
2490 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2491 ))?,
2492 Regex::new(&format!(
2493 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2494 ))?,
2495 Regex::new(&format!(
2496 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2497 ))?,
2498 ]);
2499 }
2500 if cfg.n_shared_experts.is_some() {
2501 data.extend(vec![
2502 Regex::new(&format!(
2503 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2504 ))?,
2505 Regex::new(&format!(
2506 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2507 ))?,
2508 Regex::new(&format!(
2509 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2510 ))?,
2511 ]);
2512 }
2513 } else {
2514 data.extend(vec![
2515 Regex::new(&format!(
2516 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2517 ))?,
2518                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2519 Regex::new(&format!(
2520 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2521 ))?,
2522 ]);
2523 };
2524 }
2525 Ok(data)
2526 }
2527 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2528 self.isq_layer_regexes(config)
2529 }
2530
2531 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2532 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2533 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2534 for layer_idx in 0..cfg.num_hidden_layers {
2535 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2536 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2537 }) {
2538 for i in 0..n_routed_experts {
2539 data.extend(vec![
2540 Regex::new(&format!(
2541 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2542 ))?,
2543 Regex::new(&format!(
2544 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2545 ))?,
2546 Regex::new(&format!(
2547 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2548 ))?,
2549 ]);
2550 }
2551 if cfg.n_shared_experts.is_some() {
2552 data.extend(vec![
2553 Regex::new(&format!(
2554 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2555 ))?,
2556 Regex::new(&format!(
2557 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2558 ))?,
2559 Regex::new(&format!(
2560 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2561 ))?,
2562 ]);
2563 }
2564 } else {
2565 data.extend(vec![
2566 Regex::new(&format!(
2567 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2568 ))?,
2569                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2570 Regex::new(&format!(
2571 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2572 ))?,
2573 ]);
2574 };
2575 }
2576 Ok(data)
2577 }
2578 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2579 self.isq_layer_regexes_moqe(config)
2580 }
2581}
2582
2583impl DeviceMappedModelLoader for DeepSeekV2Loader {
2584 fn mapped_max_act_size_elems(
2585 &self,
2586 config: &str,
2587 params: &AutoDeviceMapParams,
2588 ) -> Result<usize> {
2589 let AutoDeviceMapParams::Text {
2590 max_seq_len,
2591 max_batch_size,
2592 } = params
2593 else {
2594 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2595 };
2596
2597 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2598
2599 Ok(
2600 max_batch_size
2601 * cfg.num_attention_heads
2602 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2603 )
2604 }
2605 fn non_mapped_max_act_size_elems(
2606 &self,
2607 _config: &str,
2608 _params: &AutoDeviceMapParams,
2609 ) -> Result<usize> {
2610 Ok(0)
2611 }
2612
2613 fn non_mapped_size_in_bytes(
2614 &self,
2615 config: &str,
2616 dtype: DType,
2617 weight_pack_factor: usize,
2618 _matformer_config: Option<&MatformerSliceConfig>,
2619 ) -> Result<usize> {
2620 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2621 let elems = {
2622 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2623 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2625 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2626 } else {
2627 0
2628 };
2629 let norm = cfg.hidden_size;
2630 embed_tokens + lm_head + norm
2631 };
2632 Ok(elems * dtype.size_in_bytes())
2633 }
2634
2635 fn layer_sizes_in_bytes(
2636 &self,
2637 config: &str,
2638 dtype: DType,
2639 weight_pack_factor: usize,
2640 _matformer_config: Option<&MatformerSliceConfig>,
2641 ) -> Result<Vec<usize>> {
2642 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2643 let mut per_layer_elems = Vec::new();
2644
2645 for layer_idx in 0..cfg.num_hidden_layers {
2646 let input_layernorm = cfg.hidden_size;
2647 let post_attention_layernorm = cfg.hidden_size;
2648
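            // MLA attention: q is either one dense projection or a low-rank (a, norm, b)
            // decomposition when q_lora_rank is set; k/v go through the compressed
            // kv_a_proj_with_mqa / kv_b_proj pair.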
2649 let q_proj = match cfg.q_lora_rank {
2650 Some(lora_rank) => {
2651 let a = cfg.hidden_size * lora_rank;
2652 let norm = lora_rank;
2653 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2654 a + norm + b
2655 }
2656 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2657 };
2658 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2659 / weight_pack_factor
2660 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2661 let kv_a_layernorm = cfg.kv_lora_rank;
2662 let kv_b_proj = cfg.kv_lora_rank
2663 * cfg.num_attention_heads
2664 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2665 / weight_pack_factor;
2666 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2667 / weight_pack_factor
2668 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2669
2670 let moe_block = {
2671 let mut sum = 0;
2672 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2673 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2674 }) {
2675 let h_size = cfg.hidden_size;
2676 let gate_proj =
2677 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2678 let up_proj =
2679 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2680 let down_proj =
2681 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2682 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2683 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2684 / weight_pack_factor;
2685 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2686 / weight_pack_factor;
2687 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2688 / weight_pack_factor;
2689 gate_proj + up_proj + down_proj
2690 } else {
2691 0
2692 };
2693 let gate_weight = n_routed_experts * cfg.hidden_size;
2694 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2695 } else {
2696 let h_size = cfg.hidden_size;
2697 let i_size = cfg.intermediate_size;
2698 let gate_proj = h_size * i_size / weight_pack_factor;
2699 let up_proj = h_size * i_size / weight_pack_factor;
2700 let down_proj = i_size * h_size / weight_pack_factor;
2701 sum += gate_proj + up_proj + down_proj;
2702 }
2703 sum
2704 };
2705
2706 per_layer_elems.push(
2707 input_layernorm
2708 + post_attention_layernorm
2709 + q_proj
2710 + kv_a_layernorm
2711 + kv_a_proj_with_mqa
2712 + kv_b_proj
2713 + o_proj
2714 + moe_block,
2715 );
2716 }
2717
2718 Ok(per_layer_elems
2719 .into_iter()
2720 .map(|x| x * dtype.size_in_bytes())
2721 .collect())
2722 }
2723
2724 fn num_layers(&self, config: &str) -> Result<usize> {
2725 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2726 Ok(cfg.num_hidden_layers)
2727 }
2728
2729 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2730 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2731
2732 let cfg = ModelConfigMetadata {
2733 max_seq_len: cfg.max_position_embeddings,
2734 num_layers: cfg.num_hidden_layers,
2735 hidden_size: cfg.hidden_size,
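            // MLA has no grouped KV heads; the cache is sized as if kv heads == attention heads.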
2736 num_kv_heads: cfg.num_attention_heads,
2737 num_attn_heads: cfg.num_attention_heads,
2738 sliding_window: None,
2739 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2740 v_head_dim: cfg.v_head_dim,
2741 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2742 };
2743
2744 Ok(Box::new(cfg))
2745 }
2746}
2747
2748pub struct DeepSeekV3Loader;
2752
2753impl NormalModelLoader for DeepSeekV3Loader {
2754 fn load(
2755 &self,
2756 config: &str,
2757 vb: ShardedVarBuilder,
2758 normal_loading_metadata: NormalLoadingMetadata,
2759 attention_mechanism: AttentionImplementation,
2760 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2761 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2762 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2763 &cfg,
2764 vb,
2765 self.is_gptx(config)?,
2766 normal_loading_metadata,
2767 attention_mechanism,
2768 )?))
2769 }
2770 fn load_xlora(
2771 &self,
2772 _config: &str,
2773 _vb: ShardedVarBuilder,
2774 _lora_config: &[((String, String), LoraConfig)],
2775 _xlora_config: Option<XLoraConfig>,
2776 _xlora_ordering: Ordering,
2777 _normal_loading_metadata: NormalLoadingMetadata,
2778 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2779 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2780 todo!()
2781 }
2782 fn is_gptx(&self, _: &str) -> Result<bool> {
2783 Ok(true)
2784 }
2785 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2786 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2787 Ok(Box::new(cfg))
2788 }
2789}
2790
2791impl IsqModelLoader for DeepSeekV3Loader {
2792 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2793 let mut data = vec![
2794 Regex::new(r"lm_head\.(weight|bias)$")?,
2795 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2797 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2798 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2799 ];
2800 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2801 if cfg.q_lora_rank.is_some() {
2802 data.extend(vec![
2803 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2804 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2805 ]);
2806 } else {
2807 data.push(Regex::new(
2808 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2809 )?);
2810 }
2811 for layer_idx in 0..cfg.num_hidden_layers {
2812 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2813 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2814 }) {
2815 for i in 0..n_routed_experts {
2816 data.extend(vec![
2817 Regex::new(&format!(
2818 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2819 ))?,
2820 Regex::new(&format!(
2821 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2822 ))?,
2823 Regex::new(&format!(
2824 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2825 ))?,
2826 ]);
2827 }
2828 if cfg.n_shared_experts.is_some() {
2829 data.extend(vec![
2830 Regex::new(&format!(
2831 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2832 ))?,
2833 Regex::new(&format!(
2834 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2835 ))?,
2836 Regex::new(&format!(
2837 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2838 ))?,
2839 ]);
2840 }
2841 } else {
2842 data.extend(vec![
2843 Regex::new(&format!(
2844 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2845 ))?,
2846                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2847 Regex::new(&format!(
2848 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2849 ))?,
2850 ]);
2851 };
2852 }
2853 Ok(data)
2854 }
2855 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2856 self.isq_layer_regexes(config)
2857 }
2858
2859 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2860 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2861 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2862 for layer_idx in 0..cfg.num_hidden_layers {
2863 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2864 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2865 }) {
2866 for i in 0..n_routed_experts {
2867 data.extend(vec![
2868 Regex::new(&format!(
2869 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2870 ))?,
2871 Regex::new(&format!(
2872 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2873 ))?,
2874 Regex::new(&format!(
2875 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2876 ))?,
2877 ]);
2878 }
2879 if cfg.n_shared_experts.is_some() {
2880 data.extend(vec![
2881 Regex::new(&format!(
2882 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2883 ))?,
2884 Regex::new(&format!(
2885 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2886 ))?,
2887 Regex::new(&format!(
2888 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2889 ))?,
2890 ]);
2891 }
2892 } else {
2893 data.extend(vec![
2894 Regex::new(&format!(
2895 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2896 ))?,
2897                    Regex::new(&format!(r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"))?,
2898 Regex::new(&format!(
2899 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2900 ))?,
2901 ]);
2902 };
2903 }
2904 Ok(data)
2905 }
2906 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2907 self.isq_layer_regexes_moqe(config)
2908 }
2909}
2910
2911impl DeviceMappedModelLoader for DeepSeekV3Loader {
2912 fn mapped_max_act_size_elems(
2913 &self,
2914 config: &str,
2915 params: &AutoDeviceMapParams,
2916 ) -> Result<usize> {
2917 let AutoDeviceMapParams::Text {
2918 max_seq_len,
2919 max_batch_size,
2920 } = params
2921 else {
2922 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2923 };
2924
2925 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2926
2927 Ok(
2928 max_batch_size
2929 * cfg.num_attention_heads
2930 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2931 )
2932 }
2933 fn non_mapped_max_act_size_elems(
2934 &self,
2935 _config: &str,
2936 _params: &AutoDeviceMapParams,
2937 ) -> Result<usize> {
2938 Ok(0)
2939 }
2940
2941 fn non_mapped_size_in_bytes(
2942 &self,
2943 config: &str,
2944 dtype: DType,
2945 weight_pack_factor: usize,
2946 _matformer_config: Option<&MatformerSliceConfig>,
2947 ) -> Result<usize> {
2948 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2949 let elems = {
2950 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2951 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2953 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2954 } else {
2955 0
2956 };
2957 let norm = cfg.hidden_size;
2958 embed_tokens + lm_head + norm
2959 };
2960 Ok(elems * dtype.size_in_bytes())
2961 }
2962
2963 fn layer_sizes_in_bytes(
2964 &self,
2965 config: &str,
2966 dtype: DType,
2967 weight_pack_factor: usize,
2968 _matformer_config: Option<&MatformerSliceConfig>,
2969 ) -> Result<Vec<usize>> {
2970 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2971 let mut per_layer_elems = Vec::new();
2972
2973 for layer_idx in 0..cfg.num_hidden_layers {
2974 let input_layernorm = cfg.hidden_size;
2975 let post_attention_layernorm = cfg.hidden_size;
2976
2977 let q_proj = match cfg.q_lora_rank {
2978 Some(lora_rank) => {
2979 let a = cfg.hidden_size * lora_rank;
2980 let norm = lora_rank;
2981 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2982 a + norm + b
2983 }
2984 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2985 };
2986 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2987 / weight_pack_factor
2988 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2989 let kv_a_layernorm = cfg.kv_lora_rank;
2990 let kv_b_proj = cfg.kv_lora_rank
2991 * cfg.num_attention_heads
2992 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2993 / weight_pack_factor;
2994 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2995 / weight_pack_factor
2996 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2997
2998 let moe_block = {
2999 let mut sum = 0;
3000 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3001 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3002 }) {
3003 let h_size = cfg.hidden_size;
3004 let gate_proj =
3005 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3006 let up_proj =
3007 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3008 let down_proj =
3009 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3010 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3011 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3012 / weight_pack_factor;
3013 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3014 / weight_pack_factor;
3015 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3016 / weight_pack_factor;
3017 gate_proj + up_proj + down_proj
3018 } else {
3019 0
3020 };
3021 let gate_weight = n_routed_experts * cfg.hidden_size;
3022 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3023 } else {
3024 let h_size = cfg.hidden_size;
3025 let i_size = cfg.intermediate_size;
3026 let gate_proj = h_size * i_size / weight_pack_factor;
3027 let up_proj = h_size * i_size / weight_pack_factor;
3028 let down_proj = i_size * h_size / weight_pack_factor;
3029 sum += gate_proj + up_proj + down_proj;
3030 }
3031 sum
3032 };
3033
3034 per_layer_elems.push(
3035 input_layernorm
3036 + post_attention_layernorm
3037 + q_proj
3038 + kv_a_layernorm
3039 + kv_a_proj_with_mqa
3040 + kv_b_proj
3041 + o_proj
3042 + moe_block,
3043 );
3044 }
3045
3046 Ok(per_layer_elems
3047 .into_iter()
3048 .map(|x| x * dtype.size_in_bytes())
3049 .collect())
3050 }
3051
3052 fn num_layers(&self, config: &str) -> Result<usize> {
3053 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3054 Ok(cfg.num_hidden_layers)
3055 }
3056
3057 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3058 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3059
3060 let cfg = ModelConfigMetadata {
3061 max_seq_len: cfg.max_position_embeddings,
3062 num_layers: cfg.num_hidden_layers,
3063 hidden_size: cfg.hidden_size,
3064 num_kv_heads: cfg.num_attention_heads,
3065 num_attn_heads: cfg.num_attention_heads,
3066 sliding_window: None,
3067 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3068 v_head_dim: cfg.v_head_dim,
3069 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3070 };
3071
3072 Ok(Box::new(cfg))
3073 }
3074}
3075
3076pub struct Qwen3Loader;
3080
3081impl NormalModelLoader for Qwen3Loader {
3082 fn load(
3083 &self,
3084 config: &str,
3085 vb: ShardedVarBuilder,
3086 normal_loading_metadata: NormalLoadingMetadata,
3087 attention_mechanism: AttentionImplementation,
3088 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3089 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3090
3091 Ok(Box::new(models::qwen3::Model::new(
3092 &cfg,
3093 vb,
3094 self.is_gptx(config)?,
3095 normal_loading_metadata,
3096 attention_mechanism,
3097 )?))
3098 }
3099 fn load_xlora(
3100 &self,
3101 _config: &str,
3102 _vb: ShardedVarBuilder,
3103 _lora_config: &[((String, String), LoraConfig)],
3104 _xlora_config: Option<XLoraConfig>,
3105 _xlora_ordering: Ordering,
3106 _normal_loading_metadata: NormalLoadingMetadata,
3107 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3108 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3109 todo!()
3110 }
3111 fn is_gptx(&self, _: &str) -> Result<bool> {
3112 Ok(true)
3113 }
3114 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3115 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3116
3117 Ok(Box::new(cfg))
3118 }
3119}
3120
3121impl IsqModelLoader for Qwen3Loader {
3122 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3123 Ok(vec![
3124 Regex::new(r"lm_head\.(weight|bias)$")?,
3125 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3127 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3128 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3129 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3130 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3132 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3133 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3134 ])
3135 }
3136 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3137 self.isq_layer_regexes(config)
3138 }
3139}
3140
3141impl DeviceMappedModelLoader for Qwen3Loader {
3142 fn mapped_max_act_size_elems(
3143 &self,
3144 config: &str,
3145 params: &AutoDeviceMapParams,
3146 ) -> Result<usize> {
3147 let AutoDeviceMapParams::Text {
3148 max_seq_len,
3149 max_batch_size,
3150 } = params
3151 else {
3152 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3153 };
3154
3155 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3156
3157 Ok(
3158 max_batch_size
3159 * cfg.num_attention_heads
3160 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3161 )
3162 }
3163 fn non_mapped_max_act_size_elems(
3164 &self,
3165 _config: &str,
3166 _params: &AutoDeviceMapParams,
3167 ) -> Result<usize> {
3168 Ok(0)
3169 }
3170
3171 fn non_mapped_size_in_bytes(
3172 &self,
3173 config: &str,
3174 dtype: DType,
3175 weight_pack_factor: usize,
3176 _matformer_config: Option<&MatformerSliceConfig>,
3177 ) -> Result<usize> {
3178 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3179 let elems = {
3180 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3181 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3183 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3184 } else {
3185 0
3186 };
3187 let norm = cfg.hidden_size;
3188 embed_tokens + lm_head + norm
3189 };
3190 Ok(elems * dtype.size_in_bytes())
3191 }
3192
3193 fn layer_sizes_in_bytes(
3194 &self,
3195 config: &str,
3196 dtype: DType,
3197 weight_pack_factor: usize,
3198 _matformer_config: Option<&MatformerSliceConfig>,
3199 ) -> Result<Vec<usize>> {
3200 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3201 let per_layer_elems = {
3202 let input_layernorm = cfg.hidden_size;
3203 let post_attention_layernorm = cfg.hidden_size;
3204
3205 let size_in = cfg.hidden_size;
3206 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3207 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3208 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3209 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3210 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3211 let o_proj = size_q * size_in / weight_pack_factor;
3212
3213 let h_size = cfg.hidden_size;
3214 let i_size = cfg.intermediate_size;
3215 let gate_proj = h_size * i_size / weight_pack_factor;
3216 let up_proj = h_size * i_size / weight_pack_factor;
3217 let down_proj = i_size * h_size / weight_pack_factor;
3218
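            // Qwen3 adds per-head RMSNorm on q and k; each is a single weight vector of length head_dim.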
3219 let q_norm = cfg.head_dim();
3220 let k_norm = cfg.head_dim();
3221
3222 input_layernorm
3223 + post_attention_layernorm
3224 + q_proj
3225 + k_proj
3226 + v_proj
3227 + o_proj
3228 + gate_proj
3229 + up_proj
3230 + down_proj
3231 + q_norm
3232 + k_norm
3233 };
3234 Ok(vec![
3235 per_layer_elems * dtype.size_in_bytes();
3236 cfg.num_hidden_layers
3237 ])
3238 }
3239
3240 fn num_layers(&self, config: &str) -> Result<usize> {
3241 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3242 Ok(cfg.num_hidden_layers)
3243 }
3244
3245 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3246 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3247
3248 let cfg = ModelConfigMetadata {
3249 max_seq_len: cfg.max_position_embeddings,
3250 num_layers: cfg.num_hidden_layers,
3251 hidden_size: cfg.hidden_size,
3252 num_kv_heads: cfg.num_key_value_heads,
3253 num_attn_heads: cfg.num_attention_heads,
3254 sliding_window: cfg.sliding_window,
3255            k_head_dim: cfg.head_dim(),
3256            v_head_dim: cfg.head_dim(),
3257 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3258 };
3259
3260 Ok(Box::new(cfg))
3261 }
3262}
3263
3264pub struct GLM4Loader;
3268
3269impl NormalModelLoader for GLM4Loader {
3270 fn load(
3271 &self,
3272 config: &str,
3273 vb: ShardedVarBuilder,
3274 normal_loading_metadata: NormalLoadingMetadata,
3275 attention_mechanism: AttentionImplementation,
3276 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3277 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3278
3279 Ok(Box::new(models::glm4::Model::new(
3280 &cfg,
3281 vb,
3282 self.is_gptx(config)?,
3283 normal_loading_metadata,
3284 attention_mechanism,
3285 )?))
3286 }
3287 fn load_xlora(
3288 &self,
3289 _config: &str,
3290 _vb: ShardedVarBuilder,
3291 _lora_config: &[((String, String), LoraConfig)],
3292 _xlora_config: Option<XLoraConfig>,
3293 _xlora_ordering: Ordering,
3294 _normal_loading_metadata: NormalLoadingMetadata,
3295 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3296 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3297 todo!()
3298 }
3299 fn is_gptx(&self, _: &str) -> Result<bool> {
3300 Ok(true)
3301 }
3302 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3303 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3304
3305 Ok(Box::new(cfg))
3306 }
3307}
3308
3309impl IsqModelLoader for GLM4Loader {
3310 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3311 Ok(vec![
3312 Regex::new(r"lm_head\.(weight|bias)$")?,
3313 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3315 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3316 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3317 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3318 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3320 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3321 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3322 ])
3323 }
3324 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3325 self.isq_layer_regexes(config)
3326 }
3327}
3328
3329impl DeviceMappedModelLoader for GLM4Loader {
3330 fn mapped_max_act_size_elems(
3331 &self,
3332 config: &str,
3333 params: &AutoDeviceMapParams,
3334 ) -> Result<usize> {
3335 let AutoDeviceMapParams::Text {
3336 max_seq_len,
3337 max_batch_size,
3338 } = params
3339 else {
3340 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3341 };
3342
3343 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3344
3345 Ok(
3346 max_batch_size
3347 * cfg.num_attention_heads
3348 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3349 )
3350 }
3351 fn non_mapped_max_act_size_elems(
3352 &self,
3353 _config: &str,
3354 _params: &AutoDeviceMapParams,
3355 ) -> Result<usize> {
3356 Ok(0)
3357 }
3358
3359 fn non_mapped_size_in_bytes(
3360 &self,
3361 config: &str,
3362 dtype: DType,
3363 weight_pack_factor: usize,
3364 _matformer_config: Option<&MatformerSliceConfig>,
3365 ) -> Result<usize> {
3366 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3367 let elems = {
3368 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3369 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3371 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3372 } else {
3373 0
3374 };
3375 let norm = cfg.hidden_size;
3376 embed_tokens + lm_head + norm
3377 };
3378 Ok(elems * dtype.size_in_bytes())
3379 }
3380
3381 fn layer_sizes_in_bytes(
3382 &self,
3383 config: &str,
3384 dtype: DType,
3385 weight_pack_factor: usize,
3386 _matformer_config: Option<&MatformerSliceConfig>,
3387 ) -> Result<Vec<usize>> {
3388 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3389 let per_layer_elems = {
3390 let input_layernorm = cfg.hidden_size;
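            // GLM4 layers also carry post_self_attn and post_mlp layernorms, hence the factor of 3 below.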
3391            let post_attention_layernorm = cfg.hidden_size * 3;
3392
3393            let size_in = cfg.hidden_size;
3394 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3395 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3396 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3397 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3398 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3399 let o_proj = size_q * size_in / weight_pack_factor;
3400
3401 let h_size = cfg.hidden_size;
3402 let i_size = cfg.intermediate_size;
3403 let gate_proj = h_size * i_size / weight_pack_factor;
3404 let up_proj = h_size * i_size / weight_pack_factor;
3405 let down_proj = i_size * h_size / weight_pack_factor;
3406
3407 input_layernorm
3408 + post_attention_layernorm
3409 + q_proj
3410 + k_proj
3411 + v_proj
3412 + o_proj
3413 + gate_proj
3414 + up_proj
3415 + down_proj
3416 };
3417 Ok(vec![
3418 per_layer_elems * dtype.size_in_bytes();
3419 cfg.num_hidden_layers
3420 ])
3421 }
3422
3423 fn num_layers(&self, config: &str) -> Result<usize> {
3424 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3425 Ok(cfg.num_hidden_layers)
3426 }
3427
3428 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3429 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3430
3431 let cfg = ModelConfigMetadata {
3432 max_seq_len: cfg.max_position_embeddings,
3433 num_layers: cfg.num_hidden_layers,
3434 hidden_size: cfg.hidden_size,
3435 num_kv_heads: cfg.num_key_value_heads,
3436 num_attn_heads: cfg.num_attention_heads,
3437 sliding_window: cfg.sliding_window,
3438 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3439 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3440 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3441 };
3442
3443 Ok(Box::new(cfg))
3444 }
3445}
3446
3447pub struct GLM4MoeLiteLoader;
3451
3452impl NormalModelLoader for GLM4MoeLiteLoader {
3453 fn load(
3454 &self,
3455 config: &str,
3456 vb: ShardedVarBuilder,
3457 normal_loading_metadata: NormalLoadingMetadata,
3458 attention_mechanism: AttentionImplementation,
3459 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3460 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3461 Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3462 &cfg,
3463 vb,
3464 self.is_gptx(config)?,
3465 normal_loading_metadata,
3466 attention_mechanism,
3467 )?))
3468 }
3469 fn load_xlora(
3470 &self,
3471 _config: &str,
3472 _vb: ShardedVarBuilder,
3473 _lora_config: &[((String, String), LoraConfig)],
3474 _xlora_config: Option<XLoraConfig>,
3475 _xlora_ordering: Ordering,
3476 _normal_loading_metadata: NormalLoadingMetadata,
3477 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3478 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3479 todo!()
3480 }
3481 fn is_gptx(&self, _: &str) -> Result<bool> {
3482 Ok(true)
3483 }
3484 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3485 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3486 Ok(Box::new(cfg))
3487 }
3488}
3489
3490impl IsqModelLoader for GLM4MoeLiteLoader {
3491 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3492 let mut data = vec![
3493 Regex::new(r"lm_head\.(weight|bias)$")?,
3494 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3496 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3497 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3498 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3500 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3501 ];
3502 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3503 for layer_idx in 0..cfg.num_hidden_layers {
3504 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3505 for i in 0..cfg.n_routed_experts {
3507 data.extend(vec![
3508 Regex::new(&format!(
3509 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3510 ))?,
3511 Regex::new(&format!(
3512 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3513 ))?,
3514 Regex::new(&format!(
3515 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3516 ))?,
3517 ]);
3518 }
3519 if cfg.n_shared_experts > 0 {
3520 data.extend(vec![
3521 Regex::new(&format!(
3522 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3523 ))?,
3524 Regex::new(&format!(
3525 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3526 ))?,
3527 Regex::new(&format!(
3528 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3529 ))?,
3530 ]);
3531 }
3532 } else {
3533 data.extend(vec![
3535 Regex::new(&format!(
3536 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3537 ))?,
3538 Regex::new(&format!(
3539 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3540 ))?,
3541 Regex::new(&format!(
3542 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3543 ))?,
3544 ]);
3545 };
3546 }
3547 Ok(data)
3548 }
3549 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3550 self.isq_layer_regexes(config)
3551 }
3552
3553 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3554 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3555 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3556 for layer_idx in 0..cfg.num_hidden_layers {
3557 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3558 for i in 0..cfg.n_routed_experts {
3560 data.extend(vec![
3561 Regex::new(&format!(
3562 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3563 ))?,
3564 Regex::new(&format!(
3565 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3566 ))?,
3567 Regex::new(&format!(
3568 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3569 ))?,
3570 ]);
3571 }
3572 if cfg.n_shared_experts > 0 {
3573 data.extend(vec![
3574 Regex::new(&format!(
3575 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3576 ))?,
3577 Regex::new(&format!(
3578 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3579 ))?,
3580 Regex::new(&format!(
3581 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3582 ))?,
3583 ]);
3584 }
3585 } else {
3586 data.extend(vec![
3588 Regex::new(&format!(
3589 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3590 ))?,
3591 Regex::new(&format!(
3592 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3593 ))?,
3594 Regex::new(&format!(
3595 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3596 ))?,
3597 ]);
3598 };
3599 }
3600 Ok(data)
3601 }
3602 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3603 self.isq_layer_regexes_moqe(config)
3604 }
3605}
3606
3607impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3608 fn mapped_max_act_size_elems(
3609 &self,
3610 config: &str,
3611 params: &AutoDeviceMapParams,
3612 ) -> Result<usize> {
3613 let AutoDeviceMapParams::Text {
3614 max_seq_len,
3615 max_batch_size,
3616 } = params
3617 else {
3618 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3619 };
3620
3621 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3622
3623 Ok(
3624 max_batch_size
3625 * cfg.num_attention_heads
3626 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3627 )
3628 }
3629 fn non_mapped_max_act_size_elems(
3630 &self,
3631 _config: &str,
3632 _params: &AutoDeviceMapParams,
3633 ) -> Result<usize> {
3634 Ok(0)
3635 }
3636
3637 fn non_mapped_size_in_bytes(
3638 &self,
3639 config: &str,
3640 dtype: DType,
3641 weight_pack_factor: usize,
3642 _matformer_config: Option<&MatformerSliceConfig>,
3643 ) -> Result<usize> {
3644 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3645 let elems = {
3646 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3647 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3649 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3650 } else {
3651 0
3652 };
3653 let norm = cfg.hidden_size;
3654 embed_tokens + lm_head + norm
3655 };
3656 Ok(elems * dtype.size_in_bytes())
3657 }
3658
3659 fn layer_sizes_in_bytes(
3660 &self,
3661 config: &str,
3662 dtype: DType,
3663 weight_pack_factor: usize,
3664 _matformer_config: Option<&MatformerSliceConfig>,
3665 ) -> Result<Vec<usize>> {
3666 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3667 let mut per_layer_elems = Vec::new();
3668
3669 for layer_idx in 0..cfg.num_hidden_layers {
3670 let input_layernorm = cfg.hidden_size;
3671 let post_attention_layernorm = cfg.hidden_size;
3672
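// Attention here is low-rank (MLA-style): q is factored through `q_lora_rank` and kv
// through `kv_lora_rank`, so the element counts below follow the factored A/B projection
// shapes rather than full hidden_size x hidden_size matrices.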
3673 let q_proj = {
3675 let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3676 let norm = cfg.q_lora_rank;
3677 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3678 / weight_pack_factor;
3679 a + norm + b
3680 };
3681 let kv_a_proj_with_mqa =
3682 cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3683 let kv_a_layernorm = cfg.kv_lora_rank;
3684 let kv_b_proj = cfg.kv_lora_rank
3685 * cfg.num_attention_heads
3686 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3687 / weight_pack_factor;
3688 let o_proj =
3689 cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3690
3691 let moe_block = {
3692 let mut sum = 0;
3693 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3694 let h_size = cfg.hidden_size;
3696 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3697 * cfg.n_routed_experts;
3698 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3699 * cfg.n_routed_experts;
3700 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3701 * cfg.n_routed_experts;
3702 let shared_experts = if cfg.n_shared_experts > 0 {
3703 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3704 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3705 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3706 gate_proj + up_proj + down_proj
3707 } else {
3708 0
3709 };
3710 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3711 let e_score_correction_bias = cfg.n_routed_experts;
3712 sum += gate_proj
3713 + up_proj
3714 + down_proj
3715 + shared_experts
3716 + gate_weight
3717 + e_score_correction_bias;
3718 } else {
3719 let h_size = cfg.hidden_size;
3721 let i_size = cfg.intermediate_size;
3722 let gate_proj = h_size * i_size / weight_pack_factor;
3723 let up_proj = h_size * i_size / weight_pack_factor;
3724 let down_proj = i_size * h_size / weight_pack_factor;
3725 sum += gate_proj + up_proj + down_proj;
3726 }
3727 sum
3728 };
3729
3730 per_layer_elems.push(
3731 input_layernorm
3732 + post_attention_layernorm
3733 + q_proj
3734 + kv_a_layernorm
3735 + kv_a_proj_with_mqa
3736 + kv_b_proj
3737 + o_proj
3738 + moe_block,
3739 );
3740 }
3741
3742 Ok(per_layer_elems
3743 .into_iter()
3744 .map(|x| x * dtype.size_in_bytes())
3745 .collect())
3746 }
3747
3748 fn num_layers(&self, config: &str) -> Result<usize> {
3749 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3750 Ok(cfg.num_hidden_layers)
3751 }
3752
3753 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3754 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3755
3756 let cfg = ModelConfigMetadata {
3757 max_seq_len: cfg.max_position_embeddings,
3758 num_layers: cfg.num_hidden_layers,
3759 hidden_size: cfg.hidden_size,
3760 num_kv_heads: cfg.num_attention_heads,
3761 num_attn_heads: cfg.num_attention_heads,
3762 sliding_window: None,
3763 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3764 v_head_dim: cfg.v_head_dim,
3765 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3766 };
3767
3768 Ok(Box::new(cfg))
3769 }
3770}
3771
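/// Loader for GLM-4 MoE text models, backed by `models::glm4_moe::Glm4Moe`.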
3772pub struct GLM4MoeLoader;
3776
3777impl NormalModelLoader for GLM4MoeLoader {
3778 fn load(
3779 &self,
3780 config: &str,
3781 vb: ShardedVarBuilder,
3782 normal_loading_metadata: NormalLoadingMetadata,
3783 attention_mechanism: AttentionImplementation,
3784 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3785 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3786 Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3787 &cfg,
3788 vb,
3789 self.is_gptx(config)?,
3790 normal_loading_metadata,
3791 attention_mechanism,
3792 )?))
3793 }
3794 fn load_xlora(
3795 &self,
3796 _config: &str,
3797 _vb: ShardedVarBuilder,
3798 _lora_config: &[((String, String), LoraConfig)],
3799 _xlora_config: Option<XLoraConfig>,
3800 _xlora_ordering: Ordering,
3801 _normal_loading_metadata: NormalLoadingMetadata,
3802 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3803 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3804 todo!()
3805 }
3806 fn is_gptx(&self, _: &str) -> Result<bool> {
3807 Ok(true)
3808 }
3809 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3810 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3811 Ok(Box::new(cfg))
3812 }
3813}
3814
3815impl IsqModelLoader for GLM4MoeLoader {
3816 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3817 let mut data = vec![
3818 Regex::new(r"lm_head\.(weight|bias)$")?,
3819 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3821 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3822 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3823 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3824 ];
3825 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3826 for layer_idx in 0..cfg.num_hidden_layers {
3827 if layer_idx >= cfg.first_k_dense_replace {
3828 for i in 0..cfg.n_routed_experts {
3830 data.extend(vec![
3831 Regex::new(&format!(
3832 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3833 ))?,
3834 Regex::new(&format!(
3835 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3836 ))?,
3837 Regex::new(&format!(
3838 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3839 ))?,
3840 ]);
3841 }
3842 if cfg.n_shared_experts > 0 {
3843 data.extend(vec![
3844 Regex::new(&format!(
3845 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3846 ))?,
3847 Regex::new(&format!(
3848 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3849 ))?,
3850 Regex::new(&format!(
3851 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3852 ))?,
3853 ]);
3854 }
3855 } else {
3856 data.extend(vec![
3858 Regex::new(&format!(
3859 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3860 ))?,
3861 Regex::new(&format!(
3862 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3863 ))?,
3864 Regex::new(&format!(
3865 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3866 ))?,
3867 ]);
3868 };
3869 }
3870 Ok(data)
3871 }
3872 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3873 self.isq_layer_regexes(config)
3874 }
3875
3876 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3877 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3878 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3879 for layer_idx in 0..cfg.num_hidden_layers {
3880 if layer_idx >= cfg.first_k_dense_replace {
3881 for i in 0..cfg.n_routed_experts {
3883 data.extend(vec![
3884 Regex::new(&format!(
3885 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3886 ))?,
3887 Regex::new(&format!(
3888 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3889 ))?,
3890 Regex::new(&format!(
3891 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3892 ))?,
3893 ]);
3894 }
3895 if cfg.n_shared_experts > 0 {
3896 data.extend(vec![
3897 Regex::new(&format!(
3898 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3899 ))?,
3900 Regex::new(&format!(
3901 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3902 ))?,
3903 Regex::new(&format!(
3904 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3905 ))?,
3906 ]);
3907 }
3908 } else {
3909 data.extend(vec![
3911 Regex::new(&format!(
3912 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3913 ))?,
3914 Regex::new(&format!(
3915 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3916 ))?,
3917 Regex::new(&format!(
3918 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3919 ))?,
3920 ]);
3921 };
3922 }
3923 Ok(data)
3924 }
3925 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3926 self.isq_layer_regexes_moqe(config)
3927 }
3928}
3929
3930impl DeviceMappedModelLoader for GLM4MoeLoader {
3931 fn mapped_max_act_size_elems(
3932 &self,
3933 config: &str,
3934 params: &AutoDeviceMapParams,
3935 ) -> Result<usize> {
3936 let AutoDeviceMapParams::Text {
3937 max_seq_len,
3938 max_batch_size,
3939 } = params
3940 else {
3941 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3942 };
3943
3944 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3945
3946 Ok(
3947 max_batch_size
3948 * cfg.num_attention_heads
3949 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3950 )
3951 }
3952 fn non_mapped_max_act_size_elems(
3953 &self,
3954 _config: &str,
3955 _params: &AutoDeviceMapParams,
3956 ) -> Result<usize> {
3957 Ok(0)
3958 }
3959
3960 fn non_mapped_size_in_bytes(
3961 &self,
3962 config: &str,
3963 dtype: DType,
3964 weight_pack_factor: usize,
3965 _matformer_config: Option<&MatformerSliceConfig>,
3966 ) -> Result<usize> {
3967 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3968 let elems = {
3969 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3970 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3971 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3972 } else {
3973 0
3974 };
3975 let norm = cfg.hidden_size;
3976 embed_tokens + lm_head + norm
3977 };
3978 Ok(elems * dtype.size_in_bytes())
3979 }
3980
3981 fn layer_sizes_in_bytes(
3982 &self,
3983 config: &str,
3984 dtype: DType,
3985 weight_pack_factor: usize,
3986 _matformer_config: Option<&MatformerSliceConfig>,
3987 ) -> Result<Vec<usize>> {
3988 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3989 let mut per_layer_elems = Vec::new();
3990
3991 let head_dim = cfg.head_dim();
3992 for layer_idx in 0..cfg.num_hidden_layers {
3993 let input_layernorm = cfg.hidden_size;
3994 let post_attention_layernorm = cfg.hidden_size;
3995
3996 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
3998 + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
3999 let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4000 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4001 let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4002 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4003 let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4004
4005 let qk_norm = if cfg.use_qk_norm {
4007 head_dim * 2
4008 } else {
4009 0
4010 };
4011
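// Layers before `first_k_dense_replace` use a dense MLP; later layers are MoE blocks
// with `n_routed_experts` routed experts, an optional set of shared experts, a router
// gate, and per-expert score-correction biases.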
4012 let moe_block = {
4013 let mut sum = 0;
4014 if layer_idx >= cfg.first_k_dense_replace {
4015 let h_size = cfg.hidden_size;
4017 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4018 * cfg.n_routed_experts;
4019 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4020 * cfg.n_routed_experts;
4021 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4022 * cfg.n_routed_experts;
4023 let shared_experts = if cfg.n_shared_experts > 0 {
4024 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4025 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4026 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4027 gate_proj + up_proj + down_proj
4028 } else {
4029 0
4030 };
4031 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4032 let e_score_correction_bias = cfg.n_routed_experts;
4033 sum += gate_proj
4034 + up_proj
4035 + down_proj
4036 + shared_experts
4037 + gate_weight
4038 + e_score_correction_bias;
4039 } else {
4040 let h_size = cfg.hidden_size;
4042 let i_size = cfg.intermediate_size;
4043 let gate_proj = h_size * i_size / weight_pack_factor;
4044 let up_proj = h_size * i_size / weight_pack_factor;
4045 let down_proj = i_size * h_size / weight_pack_factor;
4046 sum += gate_proj + up_proj + down_proj;
4047 }
4048 sum
4049 };
4050
4051 per_layer_elems.push(
4052 input_layernorm
4053 + post_attention_layernorm
4054 + q_proj
4055 + k_proj
4056 + v_proj
4057 + o_proj
4058 + qk_norm
4059 + moe_block,
4060 );
4061 }
4062
4063 Ok(per_layer_elems
4064 .into_iter()
4065 .map(|x| x * dtype.size_in_bytes())
4066 .collect())
4067 }
4068
4069 fn num_layers(&self, config: &str) -> Result<usize> {
4070 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4071 Ok(cfg.num_hidden_layers)
4072 }
4073
4074 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4075 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4076
4077 let head_dim = cfg.head_dim();
4078 let cfg = ModelConfigMetadata {
4079 max_seq_len: cfg.max_position_embeddings,
4080 num_layers: cfg.num_hidden_layers,
4081 hidden_size: cfg.hidden_size,
4082 num_kv_heads: cfg.num_key_value_heads,
4083 num_attn_heads: cfg.num_attention_heads,
4084 sliding_window: None,
4085 k_head_dim: head_dim,
4086 v_head_dim: head_dim,
4087 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4088 };
4089
4090 Ok(Box::new(cfg))
4091 }
4092}
4093
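/// Loader for Qwen3 MoE models, backed by `models::qwen3_moe::Model`.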
4094pub struct Qwen3MoELoader;
4098
4099impl NormalModelLoader for Qwen3MoELoader {
4100 fn load(
4101 &self,
4102 config: &str,
4103 vb: ShardedVarBuilder,
4104 normal_loading_metadata: NormalLoadingMetadata,
4105 attention_mechanism: AttentionImplementation,
4106 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4107 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4108
4109 Ok(Box::new(models::qwen3_moe::Model::new(
4110 &cfg,
4111 vb,
4112 self.is_gptx(config)?,
4113 normal_loading_metadata,
4114 attention_mechanism,
4115 )?))
4116 }
4117 fn load_xlora(
4118 &self,
4119 _config: &str,
4120 _vb: ShardedVarBuilder,
4121 _lora_config: &[((String, String), LoraConfig)],
4122 _xlora_config: Option<XLoraConfig>,
4123 _xlora_ordering: Ordering,
4124 _normal_loading_metadata: NormalLoadingMetadata,
4125 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4126 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4127 todo!()
4128 }
4129 fn is_gptx(&self, _: &str) -> Result<bool> {
4130 Ok(true)
4131 }
4132 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4133 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4134
4135 Ok(Box::new(cfg))
4136 }
4137}
4138
4139impl IsqModelLoader for Qwen3MoELoader {
4140 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4141 Ok(vec![
4142 Regex::new(r"lm_head\.(weight|bias)$")?,
4143 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4145 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4146 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4147 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4148 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4150 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4151 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4152 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4154 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4155 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4156 ])
4157 }
4158 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4159 self.isq_layer_regexes(config)
4160 }
4161 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4162 self.isq_layer_regexes_moqe(config)
4163 }
4164}
4165
4166impl DeviceMappedModelLoader for Qwen3MoELoader {
4167 fn mapped_max_act_size_elems(
4168 &self,
4169 config: &str,
4170 params: &AutoDeviceMapParams,
4171 ) -> Result<usize> {
4172 let AutoDeviceMapParams::Text {
4173 max_seq_len,
4174 max_batch_size,
4175 } = params
4176 else {
4177 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4178 };
4179
4180 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4181
4182 Ok(
4183 max_batch_size
4184 * cfg.num_attention_heads
4185 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4186 )
4187 }
4188 fn non_mapped_max_act_size_elems(
4189 &self,
4190 _config: &str,
4191 _params: &AutoDeviceMapParams,
4192 ) -> Result<usize> {
4193 Ok(0)
4194 }
4195
4196 fn non_mapped_size_in_bytes(
4197 &self,
4198 config: &str,
4199 dtype: DType,
4200 weight_pack_factor: usize,
4201 _matformer_config: Option<&MatformerSliceConfig>,
4202 ) -> Result<usize> {
4203 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4204 let elems = {
4205 let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4206 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4208 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4209 } else {
4210 0
4211 };
4212 let norm = cfg.hidden_size;
4213 embed_tokens + lm_head + norm
4214 };
4215 Ok(elems * dtype.size_in_bytes())
4216 }
4217
4218 fn layer_sizes_in_bytes(
4219 &self,
4220 config: &str,
4221 dtype: DType,
4222 weight_pack_factor: usize,
4223 _matformer_config: Option<&MatformerSliceConfig>,
4224 ) -> Result<Vec<usize>> {
4225 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4226
4227 let mut layer_sizes_in_bytes = Vec::new();
4228 for layer_idx in 0..cfg.num_hidden_layers {
4229 let input_layernorm = cfg.hidden_size;
4230 let post_attention_layernorm = cfg.hidden_size;
4231
4232 let size_in = cfg.hidden_size;
4233 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4234 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4235 let q_proj = size_in * size_q / weight_pack_factor;
4236 let k_proj = size_in * size_kv / weight_pack_factor;
4237 let v_proj = size_in * size_kv / weight_pack_factor;
4238 let o_proj = size_q * size_in / weight_pack_factor;
4239
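// A layer gets the MoE block when it is not in `mlp_only_layers`, `num_experts` is
// non-zero, and `(layer_idx + 1)` is a multiple of `decoder_sparse_step`; otherwise the
// dense MLP sizes apply.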
4240 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4241 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4242 {
4243 let gate_size = cfg.hidden_size * cfg.num_experts;
4244 let expert_size = {
4245 let h_size = cfg.hidden_size;
4246 let i_size = cfg.moe_intermediate_size;
4247 let gate_proj = h_size * i_size / weight_pack_factor;
4248 let up_proj = h_size * i_size / weight_pack_factor;
4249 let down_proj = i_size * h_size / weight_pack_factor;
4250 gate_proj + up_proj + down_proj
4251 };
4252 expert_size * cfg.num_experts + gate_size
4253 } else {
4254 let h_size = cfg.hidden_size;
4255 let i_size = cfg.intermediate_size;
4256 let gate_proj = h_size * i_size / weight_pack_factor;
4257 let up_proj = h_size * i_size / weight_pack_factor;
4258 let down_proj = i_size * h_size / weight_pack_factor;
4259 gate_proj + up_proj + down_proj
4260 };
4261
4262 let q_norm = cfg.head_dim();
4263 let k_norm = cfg.head_dim();
4264
4265 let size_elems = input_layernorm
4266 + post_attention_layernorm
4267 + q_proj
4268 + k_proj
4269 + v_proj
4270 + o_proj
4271 + mlp_size
4272 + q_norm
4273 + k_norm;
4274
4275 let size_in_bytes = size_elems * dtype.size_in_bytes();
4276 layer_sizes_in_bytes.push(size_in_bytes);
4277 }
4278
4279 Ok(layer_sizes_in_bytes)
4280 }
4281
4282 fn num_layers(&self, config: &str) -> Result<usize> {
4283 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4284 Ok(cfg.num_hidden_layers)
4285 }
4286
4287 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4288 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4289
4290 let cfg = ModelConfigMetadata {
4291 max_seq_len: cfg.max_position_embeddings,
4292 num_layers: cfg.num_hidden_layers,
4293 hidden_size: cfg.hidden_size,
4294 num_kv_heads: cfg.num_key_value_heads,
4295 num_attn_heads: cfg.num_attention_heads,
4296 sliding_window: cfg.sliding_window,
4297 k_head_dim: cfg.head_dim(),
4298 v_head_dim: cfg.head_dim(),
4299 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4300 };
4301
4302 Ok(Box::new(cfg))
4303 }
4304}
4305
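/// Loader for SmolLM3 models, backed by `models::smollm3::SmolLm3`.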
4306pub struct SmolLm3Loader;
4312
4313impl NormalModelLoader for SmolLm3Loader {
4314 fn load(
4315 &self,
4316 config: &str,
4317 vb: ShardedVarBuilder,
4318 normal_loading_metadata: NormalLoadingMetadata,
4319 attention_mechanism: AttentionImplementation,
4320 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4321 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4322
4323 Ok(Box::new(models::smollm3::SmolLm3::new(
4324 &cfg,
4325 vb,
4326 self.is_gptx(config)?,
4327 normal_loading_metadata,
4328 attention_mechanism,
4329 )?))
4330 }
4331 fn load_xlora(
4332 &self,
4333 _config: &str,
4334 _vb: ShardedVarBuilder,
4335 _lora_config: &[((String, String), LoraConfig)],
4336 _xlora_config: Option<XLoraConfig>,
4337 _xlora_ordering: Ordering,
4338 _normal_loading_metadata: NormalLoadingMetadata,
4339 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4340 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4341 todo!()
4342 }
4343 fn is_gptx(&self, _: &str) -> Result<bool> {
4344 Ok(true)
4345 }
4346 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4347 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4348 Ok(Box::new(cfg))
4349 }
4350}
4351
4352impl IsqModelLoader for SmolLm3Loader {
4353 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4354 Ok(vec![
4355 Regex::new(r"lm_head\.(weight|bias)$")?,
4356 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4358 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4359 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4360 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4361 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4363 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4364 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4365 ])
4366 }
4367 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4368 self.isq_layer_regexes(config)
4369 }
4370}
4371
4372impl DeviceMappedModelLoader for SmolLm3Loader {
4373 fn mapped_max_act_size_elems(
4374 &self,
4375 config: &str,
4376 params: &AutoDeviceMapParams,
4377 ) -> Result<usize> {
4378 let AutoDeviceMapParams::Text {
4379 max_seq_len,
4380 max_batch_size,
4381 } = params
4382 else {
4383 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4384 };
4385
4386 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4387
4388 Ok(
4389 max_batch_size
4390 * cfg.num_attention_heads
4391 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4392 )
4393 }
4394 fn non_mapped_max_act_size_elems(
4395 &self,
4396 _config: &str,
4397 _params: &AutoDeviceMapParams,
4398 ) -> Result<usize> {
4399 Ok(0)
4400 }
4401
4402 fn non_mapped_size_in_bytes(
4403 &self,
4404 config: &str,
4405 dtype: DType,
4406 weight_pack_factor: usize,
4407 _matformer_config: Option<&MatformerSliceConfig>,
4408 ) -> Result<usize> {
4409 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4410
4411 let elems = {
4412 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4413 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4415 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4416 } else {
4417 0
4418 };
4419 let norm = cfg.hidden_size;
4420 embed_tokens + lm_head + norm
4421 };
4422 Ok(elems * dtype.size_in_bytes())
4423 }
4424
4425 fn layer_sizes_in_bytes(
4426 &self,
4427 config: &str,
4428 dtype: DType,
4429 weight_pack_factor: usize,
4430 _matformer_config: Option<&MatformerSliceConfig>,
4431 ) -> Result<Vec<usize>> {
4432 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4433
4434 let per_layer_elems = {
4435 let input_layernorm = cfg.hidden_size;
4436 let post_attention_layernorm = cfg.hidden_size;
4437
4438 let size_in = cfg.hidden_size;
4439 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4440 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4441 let q_proj = size_in * size_q / weight_pack_factor;
4442 let k_proj = size_in * size_kv / weight_pack_factor;
4443 let v_proj = size_in * size_kv / weight_pack_factor;
4444 let o_proj = size_q * size_in / weight_pack_factor;
4445
4446 let h_size = cfg.hidden_size;
4447 let i_size = cfg.intermediate_size;
4448 let gate_proj = h_size * i_size / weight_pack_factor;
4449 let up_proj = h_size * i_size / weight_pack_factor;
4450 let down_proj = i_size * h_size / weight_pack_factor;
4451
4452 input_layernorm
4453 + post_attention_layernorm
4454 + q_proj
4455 + k_proj
4456 + v_proj
4457 + o_proj
4458 + gate_proj
4459 + up_proj
4460 + down_proj
4461 };
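// Every SmolLM3 decoder layer has the same shape, so the per-layer size is repeated
// `num_hidden_layers` times.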
4462 Ok(vec![
4463 per_layer_elems * dtype.size_in_bytes();
4464 cfg.num_hidden_layers
4465 ])
4466 }
4467
4468 fn num_layers(&self, config: &str) -> Result<usize> {
4469 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4470
4471 Ok(cfg.num_hidden_layers)
4472 }
4473 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4474 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4475
4476 let cfg = ModelConfigMetadata {
4477 max_seq_len: cfg.max_position_embeddings,
4478 num_layers: cfg.num_hidden_layers,
4479 hidden_size: cfg.hidden_size,
4480 num_kv_heads: cfg.num_key_value_heads,
4481 num_attn_heads: cfg.num_attention_heads,
4482 sliding_window: None,
4483 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4484 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4485 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4486 };
4487
4488 Ok(Box::new(cfg))
4489 }
4490}
4491
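/// Loader for Granite MoE Hybrid models, backed by `models::granite::GraniteMoeHybrid`.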
4492pub struct GraniteMoeHybridLoader;
4498
4499impl NormalModelLoader for GraniteMoeHybridLoader {
4500 fn load(
4501 &self,
4502 config: &str,
4503 vb: ShardedVarBuilder,
4504 normal_loading_metadata: NormalLoadingMetadata,
4505 attention_mechanism: AttentionImplementation,
4506 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4507 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4508
4509 Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4510 &cfg,
4511 vb,
4512 self.is_gptx(config)?,
4513 normal_loading_metadata,
4514 attention_mechanism,
4515 )?))
4516 }
4517 fn load_xlora(
4518 &self,
4519 _config: &str,
4520 _vb: ShardedVarBuilder,
4521 _lora_config: &[((String, String), LoraConfig)],
4522 _xlora_config: Option<XLoraConfig>,
4523 _xlora_ordering: Ordering,
4524 _normal_loading_metadata: NormalLoadingMetadata,
4525 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4526 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4527 todo!()
4528 }
4529 fn is_gptx(&self, _: &str) -> Result<bool> {
4530 Ok(true)
4531 }
4532 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4533 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4534 Ok(Box::new(cfg))
4535 }
4536}
4537
4538impl IsqModelLoader for GraniteMoeHybridLoader {
4539 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4540 Ok(vec![
4541 Regex::new(r"lm_head\.(weight|bias)$")?,
4542 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4544 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4545 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4546 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4547 Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4549 Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4550 ])
4551 }
4552 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4553 self.isq_layer_regexes(config)
4554 }
4555}
4556
4557impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4558 fn mapped_max_act_size_elems(
4559 &self,
4560 config: &str,
4561 params: &AutoDeviceMapParams,
4562 ) -> Result<usize> {
4563 let AutoDeviceMapParams::Text {
4564 max_seq_len,
4565 max_batch_size,
4566 } = params
4567 else {
4568 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4569 };
4570
4571 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4572
4573 Ok(
4574 max_batch_size
4575 * cfg.num_attention_heads
4576 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4577 )
4578 }
4579 fn non_mapped_max_act_size_elems(
4580 &self,
4581 _config: &str,
4582 _params: &AutoDeviceMapParams,
4583 ) -> Result<usize> {
4584 Ok(0)
4585 }
4586
4587 fn non_mapped_size_in_bytes(
4588 &self,
4589 config: &str,
4590 dtype: DType,
4591 weight_pack_factor: usize,
4592 _matformer_config: Option<&MatformerSliceConfig>,
4593 ) -> Result<usize> {
4594 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4595
4596 let elems = {
4597 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4598 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4600 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4601 } else {
4602 0
4603 };
4604 let norm = cfg.hidden_size;
4605 embed_tokens + lm_head + norm
4606 };
4607 Ok(elems * dtype.size_in_bytes())
4608 }
4609
4610 fn layer_sizes_in_bytes(
4611 &self,
4612 config: &str,
4613 dtype: DType,
4614 weight_pack_factor: usize,
4615 _matformer_config: Option<&MatformerSliceConfig>,
4616 ) -> Result<Vec<usize>> {
4617 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4618
4619 let per_layer_elems = {
4620 let input_layernorm = cfg.hidden_size;
4621 let post_attention_layernorm = cfg.hidden_size;
4622
4623 let size_in = cfg.hidden_size;
4624 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4625 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4626 let q_proj = size_in * size_q / weight_pack_factor;
4627 let k_proj = size_in * size_kv / weight_pack_factor;
4628 let v_proj = size_in * size_kv / weight_pack_factor;
4629 let o_proj = size_q * size_in / weight_pack_factor;
4630
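// The shared MLP's `input_linear` covers both halves of the gated activation (hence the
// factor of 2); `output_linear` is the down projection back to `hidden_size`.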
4631 let h_size = cfg.hidden_size;
4632 let shared_i_size = cfg.shared_intermediate_size();
4633 let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4635 let output_linear = shared_i_size * h_size / weight_pack_factor;
4636
4637 input_layernorm
4638 + post_attention_layernorm
4639 + q_proj
4640 + k_proj
4641 + v_proj
4642 + o_proj
4643 + input_linear
4644 + output_linear
4645 };
4646 Ok(vec![
4647 per_layer_elems * dtype.size_in_bytes();
4648 cfg.num_hidden_layers
4649 ])
4650 }
4651
4652 fn num_layers(&self, config: &str) -> Result<usize> {
4653 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4654
4655 Ok(cfg.num_hidden_layers)
4656 }
4657 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4658 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4659
4660 let cfg = ModelConfigMetadata {
4661 max_seq_len: cfg.max_position_embeddings,
4662 num_layers: cfg.num_hidden_layers,
4663 hidden_size: cfg.hidden_size,
4664 num_kv_heads: cfg.num_key_value_heads(),
4665 num_attn_heads: cfg.num_attention_heads,
4666 sliding_window: None,
4667 k_head_dim: cfg.head_dim(),
4668 v_head_dim: cfg.head_dim(),
4669 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4670 };
4671
4672 Ok(Box::new(cfg))
4673 }
4674}
4675
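/// Loader for GPT-OSS MoE models, backed by `models::gpt_oss::Model`. PagedAttention is
/// not supported for this architecture (see `supports_paged_attention`).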
4676pub struct GptOssLoader;
4682
4683impl NormalModelLoader for GptOssLoader {
4684 fn load(
4685 &self,
4686 config: &str,
4687 vb: ShardedVarBuilder,
4688 normal_loading_metadata: NormalLoadingMetadata,
4689 attention_mechanism: AttentionImplementation,
4690 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4691 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4692
4693 Ok(Box::new(models::gpt_oss::Model::new(
4694 &cfg,
4695 vb,
4696 self.is_gptx(config)?,
4697 normal_loading_metadata,
4698 attention_mechanism,
4699 )?))
4700 }
4701 fn load_xlora(
4702 &self,
4703 _config: &str,
4704 _vb: ShardedVarBuilder,
4705 _lora_config: &[((String, String), LoraConfig)],
4706 _xlora_config: Option<XLoraConfig>,
4707 _xlora_ordering: Ordering,
4708 _normal_loading_metadata: NormalLoadingMetadata,
4709 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4710 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4711 anyhow::bail!("GPT-OSS does not support X-LoRA")
4712 }
4713 fn is_gptx(&self, _: &str) -> Result<bool> {
4714 Ok(true)
4715 }
4716 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4717 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4718 Ok(Box::new(cfg))
4719 }
4720 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4721 Ok(false)
4722 }
4723}
4724
4725impl IsqModelLoader for GptOssLoader {
4726 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4727 Ok(vec![
4729 Regex::new(r"lm_head\.(weight|bias)$")?,
4730 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4732 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4733 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4734 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4735 ])
4736 }
4737 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4738 self.isq_layer_regexes(config)
4739 }
4740}
4741
4742impl DeviceMappedModelLoader for GptOssLoader {
4743 fn mapped_max_act_size_elems(
4744 &self,
4745 config: &str,
4746 params: &AutoDeviceMapParams,
4747 ) -> Result<usize> {
4748 let AutoDeviceMapParams::Text {
4749 max_seq_len,
4750 max_batch_size,
4751 } = params
4752 else {
4753 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4754 };
4755
4756 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4757
4758 Ok(
4759 max_batch_size
4760 * cfg.num_attention_heads
4761 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4762 )
4763 }
4764 fn non_mapped_max_act_size_elems(
4765 &self,
4766 _config: &str,
4767 _params: &AutoDeviceMapParams,
4768 ) -> Result<usize> {
4769 Ok(0)
4770 }
4771
4772 fn non_mapped_size_in_bytes(
4773 &self,
4774 config: &str,
4775 dtype: DType,
4776 weight_pack_factor: usize,
4777 _matformer_config: Option<&MatformerSliceConfig>,
4778 ) -> Result<usize> {
4779 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4780
4781 let elems = {
4782 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4783 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4784 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4785 } else {
4786 0
4787 };
4788 let norm = cfg.hidden_size;
4789 embed_tokens + lm_head + norm
4790 };
4791 Ok(elems * dtype.size_in_bytes())
4792 }
4793
4794 fn layer_sizes_in_bytes(
4795 &self,
4796 config: &str,
4797 dtype: DType,
4798 weight_pack_factor: usize,
4799 _matformer_config: Option<&MatformerSliceConfig>,
4800 ) -> Result<Vec<usize>> {
4801 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4802
4803 let per_layer_elems = {
4804 let input_layernorm = cfg.hidden_size;
4805 let post_attention_layernorm = cfg.hidden_size;
4806
4807 let size_in = cfg.hidden_size;
4808 let head_dim = cfg.head_dim();
4809 let size_q = head_dim * cfg.num_attention_heads;
4810 let size_kv = head_dim * cfg.num_key_value_heads;
4811 let q_proj =
4812 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4813 let k_proj =
4814 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4815 let v_proj =
4816 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4817 let o_proj =
4818 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4819
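// Expert weights are stored in MXFP4: two 4-bit values are packed per byte (divide the
// element count by 2), with one shared scale per 32-element block (divide by 32 for the
// scale tensors).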
4820 let mxfp4_pack = 2;
4825 let gate_up_proj_size =
4826 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4827 let down_proj_size =
4828 cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4829 let gate_up_scales =
4831 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4832 let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4833 let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4835 let down_bias = cfg.num_local_experts * cfg.hidden_size;
4836 let router = cfg.hidden_size * cfg.num_local_experts;
4838 let sinks = cfg.num_attention_heads;
4840
4841 input_layernorm
4842 + post_attention_layernorm
4843 + q_proj
4844 + k_proj
4845 + v_proj
4846 + o_proj
4847 + gate_up_proj_size
4848 + down_proj_size
4849 + gate_up_scales
4850 + down_scales
4851 + gate_up_bias
4852 + down_bias
4853 + router
4854 + sinks
4855 };
4856 Ok(vec![
4857 per_layer_elems * dtype.size_in_bytes();
4858 cfg.num_hidden_layers
4859 ])
4860 }
4861
4862 fn num_layers(&self, config: &str) -> Result<usize> {
4863 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4864
4865 Ok(cfg.num_hidden_layers)
4866 }
4867 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4868 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4869
4870 let head_dim = cfg.head_dim();
4871 let cfg = ModelConfigMetadata {
4872 max_seq_len: cfg.max_position_embeddings,
4873 num_layers: cfg.num_hidden_layers,
4874 hidden_size: cfg.hidden_size,
4875 num_kv_heads: cfg.num_key_value_heads,
4876 num_attn_heads: cfg.num_attention_heads,
4877 sliding_window: cfg.sliding_window,
4878 k_head_dim: head_dim,
4879 v_head_dim: head_dim,
4880 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4881 };
4882
4883 Ok(Box::new(cfg))
4884 }
4885}
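// Usage sketch (illustrative only, not an additional API): the pipeline drives a
// `NormalModelLoader` by handing it the raw `config.json` contents as a string, e.g.
//
//     let loader = GptOssLoader;
//     let num_layers = loader.num_layers(&config_json)?;
//     let metadata = loader.model_config(&config_json)?;
//
// where `config_json` is assumed to be a JSON document that deserializes into
// `crate::models::gpt_oss::Config`.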