1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::speculative::SpeculativeTargetMixin;
11use crate::{
12 amoe::AnyMoeBaseModelMixin,
13 device_map::DeviceMapper,
14 lora::{LoraConfig, Ordering},
15 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
16 pipeline::{
17 isq::IsqModelLoader,
18 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
19 EitherCache, IsqModel,
20 },
21 utils::varbuilder_utils::DeviceForLoadTensor,
22 xlora_models::NonGranularState,
23};
24use anyhow::Result;
25use hanzo_ml::{DType, Device, Tensor};
26use hanzo_quant::log::once_log_debug;
27
28use hanzo_quant::ShardedVarBuilder;
29use indicatif::MultiProgress;
30#[cfg(feature = "pyo3_macros")]
31use pyo3::pyclass;
32
33use regex::Regex;
34use serde::Deserialize;
35
36use crate::{
37 models,
38 xlora_models::{self, XLoraConfig},
39};
40
41use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
42
/// Interface of the text models produced by the loaders in this file.
///
/// Implementors must also implement [`IsqModel`], [`AnyMoeBaseModelMixin`] and
/// [`SpeculativeTargetMixin`].
pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin + SpeculativeTargetMixin {
    /// Runs the model on `input_ids` and returns the resulting tensor.
    ///
    /// NOTE(review): the shapes and exact meaning of the arguments are defined by the
    /// implementors; `metadata` is presumably only supplied when paged attention is in
    /// use — confirm against callers.
    #[allow(clippy::too_many_arguments)]
    fn forward(
        &self,
        input_ids: &Tensor,
        seqlen_offsets: &[usize],
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
        flash_params: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// X-LoRA variant of [`Self::forward`], taking both the regular and the `*_full`
    /// inputs.
    ///
    /// NOTE(review): the relationship between the regular and `*_full` arguments is not
    /// visible here — see the X-LoRA model implementations.
    #[allow(clippy::too_many_arguments)]
    fn xlora_forward(
        &self,
        input_ids: &Tensor,
        input_ids_full: &Tensor,
        seqlen_offsets: &[usize],
        seqlen_offsets_full: &[usize],
        no_kv_cache: bool,
        non_granular_state: &Option<NonGranularState>,
        context_lens: Vec<(usize, usize)>,
        position_ids: Vec<usize>,
        flash_params: &FlashParams,
        flash_params_full: &FlashParams,
    ) -> hanzo_ml::Result<Tensor>;
    /// Whether this model is an X-LoRA model.
    fn is_xlora(&self) -> bool;
    /// The device associated with this model.
    fn device(&self) -> &Device;
    /// Shared access to the model's cache.
    fn cache(&self) -> &EitherCache;
    /// Exclusive access to the model's cache.
    fn cache_mut(&mut self) -> &mut EitherCache;
    /// Maximum sequence length reported by the model.
    fn max_seq_len(&self) -> usize;
    /// Configuration metadata of the model.
    fn config(&self) -> &ModelConfigMetadata;
    /// Returns a clone of [`Self::config`] behind an `Arc<dyn ModelConfigLike>`.
    fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
        Arc::new(self.config().clone())
    }
}
78
/// Values passed by a [`NormalModelLoader`] to the model constructor it calls from
/// `load` / `load_xlora`.
pub struct NormalLoadingMetadata {
    // Device mapper handed to the model (presumably decides per-layer placement —
    // confirm against `DeviceMapper`).
    pub mapper: Box<dyn DeviceMapper + Send + Sync>,
    // Whether the model is being loaded with ISQ (the same flag name is used by
    // `NormalModelLoader::get_device_for_tensor`).
    pub loading_isq: bool,
    // NOTE(review): presumably the non-CPU target device of the mapping — confirm.
    pub real_device: Device,
    // Shared `indicatif` progress-bar group (presumably for loading progress — confirm).
    pub multi_progress: Arc<MultiProgress>,
    // Optional Matformer slicing configuration.
    pub matformer_slicing_config: Option<MatformerSliceConfig>,
}
92
93pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
94 fn load(
95 &self,
96 config: &str,
97 vb: ShardedVarBuilder,
98 normal_loading_metadata: NormalLoadingMetadata,
99 attention_mechanism: AttentionImplementation,
100 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
101 #[allow(clippy::too_many_arguments)]
102 fn load_xlora(
103 &self,
104 config: &str,
105 vb: ShardedVarBuilder,
106 lora_config: &[((String, String), LoraConfig)],
107 xlora_config: Option<XLoraConfig>,
108 xlora_ordering: Ordering,
109 normal_loading_metadata: NormalLoadingMetadata,
110 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
111 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
112 fn is_gptx(&self, config: &str) -> Result<bool>;
113 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
114 Ok(true)
115 }
116 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
117 fn get_device_for_tensor(
118 &self,
119 config: &str,
120 _mapper: &dyn DeviceMapper,
121 loading_isq: bool,
122 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
123 if loading_isq {
124 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
125 } else {
126 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
127 let num_layers = self.model_config(config)?.num_layers();
128 let closure = move |name: String| {
129 if let Some(captures) = re.captures(&name) {
130 captures
131 .get(1)
132 .and_then(|m| m.as_str().parse::<usize>().ok())
133 .map(|l| l.min(num_layers))
134 .map(DeviceForLoadTensor::Idx)
135 .unwrap_or(DeviceForLoadTensor::Base)
136 } else {
137 DeviceForLoadTensor::Base
138 }
139 };
140
141 Ok(Arc::new(closure))
142 }
143 }
144}
145
/// The architectures a normal (text) model loader can be selected for.
///
/// Each variant (de)serializes under the lowercase name in its `serde(rename)`
/// attribute; the `FromStr` and `Display` impls in this file use the same names.
#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
pub enum NormalLoaderType {
    /// Serialized as `mistral`.
    #[serde(rename = "mistral")]
    Mistral,
    /// Serialized as `gemma`.
    #[serde(rename = "gemma")]
    Gemma,
    /// Serialized as `mixtral`.
    #[serde(rename = "mixtral")]
    Mixtral,
    /// Serialized as `llama`.
    #[serde(rename = "llama")]
    Llama,
    /// Serialized as `phi2`.
    #[serde(rename = "phi2")]
    Phi2,
    /// Serialized as `phi3`.
    #[serde(rename = "phi3")]
    Phi3,
    /// Serialized as `qwen2`.
    #[serde(rename = "qwen2")]
    Qwen2,
    /// Serialized as `gemma2`.
    #[serde(rename = "gemma2")]
    Gemma2,
    /// Serialized as `starcoder2`.
    #[serde(rename = "starcoder2")]
    Starcoder2,
    /// Serialized as `phi3.5moe`.
    #[serde(rename = "phi3.5moe")]
    Phi3_5MoE,
    /// Serialized as `deepseekv2`.
    #[serde(rename = "deepseekv2")]
    DeepSeekV2,
    /// Serialized as `deepseekv3`.
    #[serde(rename = "deepseekv3")]
    DeepSeekV3,
    /// Serialized as `qwen3`.
    #[serde(rename = "qwen3")]
    Qwen3,
    /// Serialized as `glm4`.
    #[serde(rename = "glm4")]
    GLM4,
    /// Serialized as `glm4moelite`.
    #[serde(rename = "glm4moelite")]
    GLM4MoeLite,
    /// Serialized as `glm4moe`.
    #[serde(rename = "glm4moe")]
    GLM4Moe,
    /// Serialized as `qwen3moe`.
    #[serde(rename = "qwen3moe")]
    Qwen3Moe,
    /// Serialized as `smollm3`.
    #[serde(rename = "smollm3")]
    SmolLm3,
    /// Serialized as `granitemoehybrid`.
    #[serde(rename = "granitemoehybrid")]
    GraniteMoeHybrid,
    /// Serialized as `gpt_oss`.
    #[serde(rename = "gpt_oss")]
    GptOss,
    /// Serialized as `qwen3next`.
    #[serde(rename = "qwen3next")]
    Qwen3Next,
}
193
194impl NormalLoaderType {
196 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
197 match name {
198 "MistralForCausalLM" => Ok(Self::Mistral),
199 "MixtralForCausalLM" => Ok(Self::Mixtral),
200 "GemmaForCausalLM" => Ok(Self::Gemma),
201 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
202 "PhiForCausalLM" => Ok(Self::Phi2),
203 "Phi3ForCausalLM" => Ok(Self::Phi3),
204 "LlamaForCausalLM" => Ok(Self::Llama),
205 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
206 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
207 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
208 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
209 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
210 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
211 "Glm4ForCausalLM" => Ok(Self::GLM4),
212 "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
213 "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
214 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
215 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
216 "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
217 "GptOssForCausalLM" => Ok(Self::GptOss),
218 "Qwen3NextForCausalLM" => Ok(Self::Qwen3Next),
219 other => anyhow::bail!(
220 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
221 ),
222 }
223 }
224}
225
226impl FromStr for NormalLoaderType {
227 type Err = String;
228 fn from_str(s: &str) -> Result<Self, Self::Err> {
229 match s {
230 "mistral" => Ok(Self::Mistral),
231 "gemma" => Ok(Self::Gemma),
232 "mixtral" => Ok(Self::Mixtral),
233 "llama" => Ok(Self::Llama),
234 "phi2" => Ok(Self::Phi2),
235 "phi3" => Ok(Self::Phi3),
236 "qwen2" => Ok(Self::Qwen2),
237 "gemma2" => Ok(Self::Gemma2),
238 "starcoder2" => Ok(Self::Starcoder2),
239 "phi3.5moe" => Ok(Self::Phi3_5MoE),
240 "deepseekv2" => Ok(Self::DeepSeekV2),
241 "deepseekv3" => Ok(Self::DeepSeekV3),
242 "qwen3" => Ok(Self::Qwen3),
243 "glm4" => Ok(Self::GLM4),
244 "glm4moelite" => Ok(Self::GLM4MoeLite),
245 "glm4moe" => Ok(Self::GLM4Moe),
246 "qwen3moe" => Ok(Self::Qwen3Moe),
247 "smollm3" => Ok(Self::SmolLm3),
248 "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
249 "gpt_oss" => Ok(Self::GptOss),
250 "qwen3next" => Ok(Self::Qwen3Next),
251 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`, `qwen3next`.")),
252 }
253 }
254}
255
256impl Display for NormalLoaderType {
257 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
258 match self {
259 Self::Gemma => write!(f, "gemma"),
260 Self::Gemma2 => write!(f, "gemma2"),
261 Self::Llama => write!(f, "llama"),
262 Self::Mistral => write!(f, "mistral"),
263 Self::Mixtral => write!(f, "mixtral"),
264 Self::Phi2 => write!(f, "phi2"),
265 Self::Phi3 => write!(f, "phi3"),
266 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
267 Self::Qwen2 => write!(f, "qwen2"),
268 Self::Starcoder2 => write!(f, "starcoder2"),
269 Self::DeepSeekV2 => write!(f, "deepseekv2"),
270 Self::DeepSeekV3 => write!(f, "deepseekv3"),
271 Self::Qwen3 => write!(f, "qwen3"),
272 Self::GLM4 => write!(f, "glm4"),
273 Self::GLM4MoeLite => write!(f, "glm4moelite"),
274 Self::GLM4Moe => write!(f, "glm4moe"),
275 Self::Qwen3Moe => write!(f, "qwen3moe"),
276 Self::SmolLm3 => write!(f, "smollm3"),
277 Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
278 Self::GptOss => write!(f, "gpt_oss"),
279 Self::Qwen3Next => write!(f, "qwen3next"),
280 }
281 }
282}
283
// Expands to `$size` when `$cond` holds and to `0` otherwise; used to count the
// elements of an optional bias vector.
macro_rules! bias_if {
    ($cond:expr, $size:expr) => {
        if $cond { $size } else { 0 }
    };
}
293
/// Loader that selects an architecture-specific loader from the `architectures` field
/// of the JSON config and delegates to it.
pub struct AutoNormalLoader;
296
/// The only part of the model config JSON that [`AutoNormalLoader`] reads itself.
#[derive(Deserialize)]
struct AutoNormalLoaderConfig {
    // Model class names; `AutoNormalLoader::get_loader` requires exactly one entry.
    architectures: Vec<String>,
}
301
302impl AutoNormalLoader {
303 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
304 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
305 if auto_cfg.architectures.len() != 1 {
306 anyhow::bail!("Expected to have one name for `architectures` config field.")
307 }
308
309 let name = &auto_cfg.architectures[0];
310
311 let tp = NormalLoaderType::from_causal_lm_name(name)?;
312
313 once_log_debug(format!("Automatic loader type determined to be `{tp}`"));
314
315 match tp {
316 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
317 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
318 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
319 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
320 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
321 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
322 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
323 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
324 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
325 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
326 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
327 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
328 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
329 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
330 NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
331 NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
332 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
333 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
334 NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
335 NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
336 NormalLoaderType::Qwen3Next => Ok(Box::new(Qwen3NextLoader)),
337 }
338 }
339}
340
341impl NormalModelLoader for AutoNormalLoader {
342 fn load(
343 &self,
344 config: &str,
345 vb: ShardedVarBuilder,
346 normal_loading_metadata: NormalLoadingMetadata,
347 attention_mechanism: AttentionImplementation,
348 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
349 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
350 }
351 fn load_xlora(
352 &self,
353 config: &str,
354 vb: ShardedVarBuilder,
355 lora_config: &[((String, String), LoraConfig)],
356 xlora_config: Option<XLoraConfig>,
357 xlora_ordering: Ordering,
358 normal_loading_metadata: NormalLoadingMetadata,
359 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
360 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
361 Self::get_loader(config)?.load_xlora(
362 config,
363 vb,
364 lora_config,
365 xlora_config,
366 xlora_ordering,
367 normal_loading_metadata,
368 preload_adapters,
369 )
370 }
371 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
372 Self::get_loader(config)?.get_config_repr(config)
373 }
374 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
375 Self::get_loader(config)?.supports_paged_attention(config)
376 }
377 fn is_gptx(&self, config: &str) -> Result<bool> {
378 Self::get_loader(config)?.is_gptx(config)
379 }
380}
381
382impl IsqModelLoader for AutoNormalLoader {
383 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
384 Self::get_loader(config)?.immediate_isq_predicates(config)
385 }
386 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
387 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
388 }
389 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
390 Self::get_loader(config)?.isq_layer_regexes(config)
391 }
392 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
393 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
394 }
395}
396
397impl DeviceMappedModelLoader for AutoNormalLoader {
398 fn non_mapped_size_in_bytes(
399 &self,
400 config: &str,
401 dtype: DType,
402 weight_pack_factor: usize,
403 _matformer_config: Option<&MatformerSliceConfig>,
404 ) -> Result<usize> {
405 Self::get_loader(config)?.non_mapped_size_in_bytes(
406 config,
407 dtype,
408 weight_pack_factor,
409 _matformer_config,
410 )
411 }
412 fn num_layers(&self, config: &str) -> Result<usize> {
413 Self::get_loader(config)?.num_layers(config)
414 }
415 fn layer_sizes_in_bytes(
416 &self,
417 config: &str,
418 dtype: DType,
419 weight_pack_factor: usize,
420 _matformer_config: Option<&MatformerSliceConfig>,
421 ) -> Result<Vec<usize>> {
422 Self::get_loader(config)?.layer_sizes_in_bytes(
423 config,
424 dtype,
425 weight_pack_factor,
426 _matformer_config,
427 )
428 }
429 fn mapped_max_act_size_elems(
430 &self,
431 config: &str,
432 params: &super::AutoDeviceMapParams,
433 ) -> Result<usize> {
434 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
435 }
436 fn non_mapped_max_act_size_elems(
437 &self,
438 _config: &str,
439 _params: &AutoDeviceMapParams,
440 ) -> Result<usize> {
441 Ok(0)
442 }
443 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
444 Self::get_loader(config)?.model_config(config)
445 }
446}
447
/// [`NormalModelLoader`] for Mistral models (config type `crate::models::mistral::Config`).
pub struct MistralLoader;
451
452impl NormalModelLoader for MistralLoader {
453 fn load(
454 &self,
455 config: &str,
456 vb: ShardedVarBuilder,
457 normal_loading_metadata: NormalLoadingMetadata,
458 attention_mechanism: AttentionImplementation,
459 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
460 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
461 Ok(Box::new(models::mistral::Model::new(
462 &cfg,
463 vb,
464 self.is_gptx(config)?,
465 normal_loading_metadata,
466 attention_mechanism,
467 )?))
468 }
469 fn load_xlora(
470 &self,
471 config: &str,
472 vb: ShardedVarBuilder,
473 lora_config: &[((String, String), LoraConfig)],
474 xlora_config: Option<XLoraConfig>,
475 xlora_ordering: Ordering,
476 normal_loading_metadata: NormalLoadingMetadata,
477 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
478 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
479 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
480 Ok(Box::new(xlora_models::XLoraMistral::new(
481 &cfg,
482 vb,
483 lora_config,
484 xlora_config,
485 xlora_ordering,
486 self.is_gptx(config)?,
487 normal_loading_metadata,
488 preload_adapters,
489 )?))
490 }
491 fn is_gptx(&self, _: &str) -> Result<bool> {
492 Ok(true)
493 }
494 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
495 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
496 Ok(Box::new(cfg))
497 }
498}
499
500impl IsqModelLoader for MistralLoader {
501 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
502 Ok(vec![
503 Regex::new(r"lm_head\.(weight|bias)$")?,
504 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
506 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
507 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
508 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
509 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
511 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
512 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
513 ])
514 }
515 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
516 self.isq_layer_regexes(config)
517 }
518}
519
520impl DeviceMappedModelLoader for MistralLoader {
521 fn mapped_max_act_size_elems(
522 &self,
523 config: &str,
524 params: &AutoDeviceMapParams,
525 ) -> Result<usize> {
526 let AutoDeviceMapParams::Text {
527 max_seq_len,
528 max_batch_size,
529 } = params
530 else {
531 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
532 };
533
534 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
535
536 Ok(
537 max_batch_size
538 * cfg.num_attention_heads
539 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
540 )
541 }
542 fn non_mapped_max_act_size_elems(
543 &self,
544 _config: &str,
545 _params: &AutoDeviceMapParams,
546 ) -> Result<usize> {
547 Ok(0)
548 }
549
550 fn non_mapped_size_in_bytes(
551 &self,
552 config: &str,
553 dtype: DType,
554 weight_pack_factor: usize,
555 _matformer_config: Option<&MatformerSliceConfig>,
556 ) -> Result<usize> {
557 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
558
559 let elems = {
560 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
561 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
563 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
564 } else {
565 0
566 };
567 let norm = cfg.hidden_size;
568 embed_tokens + lm_head + norm
569 };
570 Ok(elems * dtype.size_in_bytes())
571 }
572
573 fn layer_sizes_in_bytes(
574 &self,
575 config: &str,
576 dtype: DType,
577 weight_pack_factor: usize,
578 _matformer_config: Option<&MatformerSliceConfig>,
579 ) -> Result<Vec<usize>> {
580 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
581
582 let per_layer_elems = {
583 let input_layernorm = cfg.hidden_size;
584 let post_attention_layernorm = cfg.hidden_size;
585
586 let size_in = cfg.hidden_size;
587 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
588 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
589 let q_proj = size_in * size_q / weight_pack_factor;
590 let k_proj = size_in * size_kv / weight_pack_factor;
591 let v_proj = size_in * size_kv / weight_pack_factor;
592 let o_proj = size_q * size_in / weight_pack_factor;
593
594 let h_size = cfg.hidden_size;
595 let i_size = cfg.intermediate_size;
596 let gate_proj = h_size * i_size / weight_pack_factor;
597 let up_proj = h_size * i_size / weight_pack_factor;
598 let down_proj = i_size * h_size / weight_pack_factor;
599
600 input_layernorm
601 + post_attention_layernorm
602 + q_proj
603 + k_proj
604 + v_proj
605 + o_proj
606 + gate_proj
607 + up_proj
608 + down_proj
609 };
610 Ok(vec![
611 per_layer_elems * dtype.size_in_bytes();
612 cfg.num_hidden_layers
613 ])
614 }
615
616 fn num_layers(&self, config: &str) -> Result<usize> {
617 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
618 Ok(cfg.num_hidden_layers)
619 }
620
621 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
622 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
623
624 let cfg = ModelConfigMetadata {
625 max_seq_len: cfg.max_position_embeddings,
626 num_layers: cfg.num_hidden_layers,
627 hidden_size: cfg.hidden_size,
628 num_kv_heads: cfg.num_key_value_heads,
629 num_attn_heads: cfg.num_attention_heads,
630 sliding_window: cfg.sliding_window,
631 k_head_dim: cfg.head_dim(),
632 v_head_dim: cfg.head_dim(),
633 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
634 };
635
636 Ok(Box::new(cfg))
637 }
638}
639
/// [`NormalModelLoader`] for Gemma models (config type `crate::models::gemma::Config`).
pub struct GemmaLoader;
646
647impl NormalModelLoader for GemmaLoader {
648 fn load(
649 &self,
650 config: &str,
651 vb: ShardedVarBuilder,
652 normal_loading_metadata: NormalLoadingMetadata,
653 attention_mechanism: AttentionImplementation,
654 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
655 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
656
657 Ok(Box::new(models::gemma::Model::new(
658 &cfg,
659 vb,
660 self.is_gptx(config)?,
661 normal_loading_metadata,
662 attention_mechanism,
663 )?))
664 }
665 fn load_xlora(
666 &self,
667 config: &str,
668 vb: ShardedVarBuilder,
669 lora_config: &[((String, String), LoraConfig)],
670 xlora_config: Option<XLoraConfig>,
671 xlora_ordering: Ordering,
672 normal_loading_metadata: NormalLoadingMetadata,
673 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
674 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
675 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
676
677 Ok(Box::new(xlora_models::XLoraGemma::new(
678 &cfg,
679 vb,
680 lora_config,
681 xlora_config,
682 xlora_ordering,
683 self.is_gptx(config)?,
684 normal_loading_metadata,
685 preload_adapters,
686 )?))
687 }
688 fn is_gptx(&self, _: &str) -> Result<bool> {
689 Ok(true)
690 }
691 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
692 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
693 Ok(Box::new(cfg))
694 }
695}
696
697impl IsqModelLoader for GemmaLoader {
698 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
699 Ok(vec![
700 Regex::new(r"lm_head\.(weight|bias)$")?,
701 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
703 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
704 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
705 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
706 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
708 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
709 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
710 ])
711 }
712 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
713 self.isq_layer_regexes(config)
714 }
715}
716
717impl DeviceMappedModelLoader for GemmaLoader {
718 fn mapped_max_act_size_elems(
719 &self,
720 config: &str,
721 params: &AutoDeviceMapParams,
722 ) -> Result<usize> {
723 let AutoDeviceMapParams::Text {
724 max_seq_len,
725 max_batch_size,
726 } = params
727 else {
728 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
729 };
730
731 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
732
733 Ok(
734 max_batch_size
735 * cfg.num_attention_heads
736 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
737 )
738 }
739 fn non_mapped_max_act_size_elems(
740 &self,
741 _config: &str,
742 _params: &AutoDeviceMapParams,
743 ) -> Result<usize> {
744 Ok(0)
745 }
746
747 fn non_mapped_size_in_bytes(
748 &self,
749 config: &str,
750 dtype: DType,
751 weight_pack_factor: usize,
752 _matformer_config: Option<&MatformerSliceConfig>,
753 ) -> Result<usize> {
754 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
755
756 let elems = {
757 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
758 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
760 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
761 } else {
762 0
763 };
764 let norm = cfg.hidden_size;
765 embed_tokens + lm_head + norm
766 };
767 Ok(elems * dtype.size_in_bytes())
768 }
769
770 fn layer_sizes_in_bytes(
771 &self,
772 config: &str,
773 dtype: DType,
774 weight_pack_factor: usize,
775 _matformer_config: Option<&MatformerSliceConfig>,
776 ) -> Result<Vec<usize>> {
777 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
778
779 let per_layer_elems = {
780 let input_layernorm = cfg.hidden_size;
781 let post_attention_layernorm = cfg.hidden_size;
782
783 let size_in = cfg.hidden_size;
784 let size_q = cfg.head_dim * cfg.num_attention_heads;
785 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
786 let q_proj =
787 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
788 let k_proj =
789 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
790 let v_proj =
791 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
792 let o_proj =
793 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
794
795 let h_size = cfg.hidden_size;
796 let i_size = cfg.intermediate_size;
797 let gate_proj = h_size * i_size / weight_pack_factor;
798 let up_proj = h_size * i_size / weight_pack_factor;
799 let down_proj = i_size * h_size / weight_pack_factor;
800
801 input_layernorm
802 + post_attention_layernorm
803 + q_proj
804 + k_proj
805 + v_proj
806 + o_proj
807 + gate_proj
808 + up_proj
809 + down_proj
810 };
811 Ok(vec![
812 per_layer_elems * dtype.size_in_bytes();
813 cfg.num_hidden_layers
814 ])
815 }
816
817 fn num_layers(&self, config: &str) -> Result<usize> {
818 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
819 Ok(cfg.num_hidden_layers)
820 }
821
822 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
823 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
824
825 let cfg = ModelConfigMetadata {
826 max_seq_len: cfg.max_position_embeddings,
827 num_layers: cfg.num_hidden_layers,
828 hidden_size: cfg.hidden_size,
829 num_kv_heads: cfg.num_key_value_heads,
830 num_attn_heads: cfg.num_attention_heads,
831 sliding_window: None,
832 k_head_dim: cfg.head_dim,
833 v_head_dim: cfg.head_dim,
834 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
835 };
836
837 Ok(Box::new(cfg))
838 }
839}
840
/// [`NormalModelLoader`] for Llama models (config type `crate::models::llama::Config`).
pub struct LlamaLoader;
847
848impl NormalModelLoader for LlamaLoader {
849 fn load(
850 &self,
851 config: &str,
852 vb: ShardedVarBuilder,
853 normal_loading_metadata: NormalLoadingMetadata,
854 attention_mechanism: AttentionImplementation,
855 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
856 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
857
858 Ok(Box::new(models::llama::Llama::new(
859 &cfg,
860 vb,
861 self.is_gptx(config)?,
862 normal_loading_metadata,
863 attention_mechanism,
864 )?))
865 }
866 fn load_xlora(
867 &self,
868 config: &str,
869 vb: ShardedVarBuilder,
870 lora_config: &[((String, String), LoraConfig)],
871 xlora_config: Option<XLoraConfig>,
872 xlora_ordering: Ordering,
873 normal_loading_metadata: NormalLoadingMetadata,
874 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
875 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
876 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
877
878 Ok(Box::new(xlora_models::XLoraLlama::new(
879 &cfg,
880 vb,
881 lora_config,
882 xlora_config,
883 xlora_ordering,
884 self.is_gptx(config)?,
885 normal_loading_metadata,
886 preload_adapters,
887 )?))
888 }
889 fn is_gptx(&self, _: &str) -> Result<bool> {
890 Ok(true)
891 }
892 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
893 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
894 Ok(Box::new(cfg))
895 }
896}
897
898impl IsqModelLoader for LlamaLoader {
899 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
900 Ok(vec![
901 Regex::new(r"lm_head\.(weight|bias)$")?,
902 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
904 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
905 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
906 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
907 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
909 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
910 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
911 ])
912 }
913 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
914 self.isq_layer_regexes(config)
915 }
916}
917
918impl DeviceMappedModelLoader for LlamaLoader {
919 fn mapped_max_act_size_elems(
920 &self,
921 config: &str,
922 params: &AutoDeviceMapParams,
923 ) -> Result<usize> {
924 let AutoDeviceMapParams::Text {
925 max_seq_len,
926 max_batch_size,
927 } = params
928 else {
929 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
930 };
931
932 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
933
934 Ok(
935 max_batch_size
936 * cfg.num_attention_heads
937 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
938 )
939 }
940 fn non_mapped_max_act_size_elems(
941 &self,
942 _config: &str,
943 _params: &AutoDeviceMapParams,
944 ) -> Result<usize> {
945 Ok(0)
946 }
947
948 fn non_mapped_size_in_bytes(
949 &self,
950 config: &str,
951 dtype: DType,
952 weight_pack_factor: usize,
953 _matformer_config: Option<&MatformerSliceConfig>,
954 ) -> Result<usize> {
955 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
956
957 let elems = {
958 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
959 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
961 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
962 } else {
963 0
964 };
965 let norm = cfg.hidden_size;
966 embed_tokens + lm_head + norm
967 };
968 Ok(elems * dtype.size_in_bytes())
969 }
970
971 fn layer_sizes_in_bytes(
972 &self,
973 config: &str,
974 dtype: DType,
975 weight_pack_factor: usize,
976 _matformer_config: Option<&MatformerSliceConfig>,
977 ) -> Result<Vec<usize>> {
978 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
979
980 let per_layer_elems = {
981 let input_layernorm = cfg.hidden_size;
982 let post_attention_layernorm = cfg.hidden_size;
983
984 let size_in = cfg.hidden_size;
985 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
986 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
987 let q_proj = size_in * size_q / weight_pack_factor;
988 let k_proj = size_in * size_kv / weight_pack_factor;
989 let v_proj = size_in * size_kv / weight_pack_factor;
990 let o_proj = size_q * size_in / weight_pack_factor;
991
992 let h_size = cfg.hidden_size;
993 let i_size = cfg.intermediate_size;
994 let gate_proj = h_size * i_size / weight_pack_factor;
995 let up_proj = h_size * i_size / weight_pack_factor;
996 let down_proj = i_size * h_size / weight_pack_factor;
997
998 input_layernorm
999 + post_attention_layernorm
1000 + q_proj
1001 + k_proj
1002 + v_proj
1003 + o_proj
1004 + gate_proj
1005 + up_proj
1006 + down_proj
1007 };
1008 Ok(vec![
1009 per_layer_elems * dtype.size_in_bytes();
1010 cfg.num_hidden_layers
1011 ])
1012 }
1013
1014 fn num_layers(&self, config: &str) -> Result<usize> {
1015 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1016
1017 Ok(cfg.num_hidden_layers)
1018 }
1019 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1020 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1021
1022 let cfg = ModelConfigMetadata {
1023 max_seq_len: cfg.max_position_embeddings,
1024 num_layers: cfg.num_hidden_layers,
1025 hidden_size: cfg.hidden_size,
1026 num_kv_heads: cfg.num_key_value_heads,
1027 num_attn_heads: cfg.num_attention_heads,
1028 sliding_window: None,
1029 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1030 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1031 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1032 };
1033
1034 Ok(Box::new(cfg))
1035 }
1036}
1037
/// [`NormalModelLoader`] for Mixtral models (config type `crate::models::mixtral::Config`).
pub struct MixtralLoader;
1041
1042impl NormalModelLoader for MixtralLoader {
1043 fn load(
1044 &self,
1045 config: &str,
1046 vb: ShardedVarBuilder,
1047 normal_loading_metadata: NormalLoadingMetadata,
1048 attention_mechanism: AttentionImplementation,
1049 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1050 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1051
1052 Ok(Box::new(models::mixtral::Model::new(
1053 &cfg,
1054 vb,
1055 self.is_gptx(config)?,
1056 normal_loading_metadata,
1057 attention_mechanism,
1058 )?))
1059 }
1060 fn load_xlora(
1061 &self,
1062 config: &str,
1063 vb: ShardedVarBuilder,
1064 lora_config: &[((String, String), LoraConfig)],
1065 xlora_config: Option<XLoraConfig>,
1066 xlora_ordering: Ordering,
1067 normal_loading_metadata: NormalLoadingMetadata,
1068 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1069 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1070 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1071
1072 Ok(Box::new(xlora_models::XLoraMixtral::new(
1073 &cfg,
1074 vb,
1075 lora_config,
1076 xlora_config,
1077 xlora_ordering,
1078 self.is_gptx(config)?,
1079 normal_loading_metadata,
1080 preload_adapters,
1081 )?))
1082 }
1083 fn is_gptx(&self, _: &str) -> Result<bool> {
1084 Ok(true)
1085 }
1086 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1087 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1088
1089 Ok(Box::new(cfg))
1090 }
1091}
1092
1093impl IsqModelLoader for MixtralLoader {
1094 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1095 Ok(vec![
1096 Regex::new(r"lm_head\.(weight|bias)$")?,
1097 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1099 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1100 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1101 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1102 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1104 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1105 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1106 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1107 ])
1108 }
1109 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1110 self.isq_layer_regexes(config)
1111 }
1112}
1113
1114impl DeviceMappedModelLoader for MixtralLoader {
1115 fn mapped_max_act_size_elems(
1116 &self,
1117 config: &str,
1118 params: &AutoDeviceMapParams,
1119 ) -> Result<usize> {
1120 let AutoDeviceMapParams::Text {
1121 max_seq_len,
1122 max_batch_size,
1123 } = params
1124 else {
1125 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1126 };
1127
1128 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1129
1130 Ok(
1131 max_batch_size
1132 * cfg.num_attention_heads
1133 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1134 )
1135 }
1136 fn non_mapped_max_act_size_elems(
1137 &self,
1138 _config: &str,
1139 _params: &AutoDeviceMapParams,
1140 ) -> Result<usize> {
1141 Ok(0)
1142 }
1143
1144 fn non_mapped_size_in_bytes(
1145 &self,
1146 config: &str,
1147 dtype: DType,
1148 weight_pack_factor: usize,
1149 _matformer_config: Option<&MatformerSliceConfig>,
1150 ) -> Result<usize> {
1151 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1152
1153 let elems = {
1154 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1155 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1157 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1158 } else {
1159 0
1160 };
1161 let norm = cfg.hidden_size;
1162 embed_tokens + lm_head + norm
1163 };
1164 Ok(elems * dtype.size_in_bytes())
1165 }
1166
1167 fn layer_sizes_in_bytes(
1168 &self,
1169 config: &str,
1170 dtype: DType,
1171 weight_pack_factor: usize,
1172 _matformer_config: Option<&MatformerSliceConfig>,
1173 ) -> Result<Vec<usize>> {
1174 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1175
1176 let per_layer_elems = {
1177 let input_layernorm = cfg.hidden_size;
1178 let post_attention_layernorm = cfg.hidden_size;
1179
1180 let size_in = cfg.hidden_size;
1181 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1182 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1183 let q_proj = size_in * size_q / weight_pack_factor;
1184 let k_proj = size_in * size_kv / weight_pack_factor;
1185 let v_proj = size_in * size_kv / weight_pack_factor;
1186 let o_proj = size_q * size_in / weight_pack_factor;
1187
1188 let moe_block = {
1189 let gate = cfg.hidden_size * cfg.num_local_experts;
1190 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1192 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1193 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1194 gate + cfg.num_local_experts * w1
1195 + cfg.num_local_experts * w2
1196 + cfg.num_local_experts * w3
1197 };
1198
1199 input_layernorm
1200 + post_attention_layernorm
1201 + q_proj
1202 + k_proj
1203 + v_proj
1204 + o_proj
1205 + moe_block
1206 };
1207 Ok(vec![
1208 per_layer_elems * dtype.size_in_bytes();
1209 cfg.num_hidden_layers
1210 ])
1211 }
1212
1213 fn num_layers(&self, config: &str) -> Result<usize> {
1214 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1215
1216 Ok(cfg.num_hidden_layers)
1217 }
1218
1219 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1220 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1221
1222 let cfg = ModelConfigMetadata {
1223 max_seq_len: cfg.max_position_embeddings,
1224 num_layers: cfg.num_hidden_layers,
1225 hidden_size: cfg.hidden_size,
1226 num_kv_heads: cfg.num_key_value_heads,
1227 num_attn_heads: cfg.num_attention_heads,
1228 sliding_window: cfg.sliding_window,
1229 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1230 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1231 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1232 };
1233
1234 Ok(Box::new(cfg))
1235 }
1236}
1237
1238pub struct Phi2Loader;
1244
1245impl NormalModelLoader for Phi2Loader {
1246 fn load(
1247 &self,
1248 config: &str,
1249 vb: ShardedVarBuilder,
1250 normal_loading_metadata: NormalLoadingMetadata,
1251 attention_mechanism: AttentionImplementation,
1252 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1253 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1254
1255 Ok(Box::new(models::phi2::Model::new(
1256 &cfg,
1257 vb,
1258 self.is_gptx(config)?,
1259 normal_loading_metadata,
1260 attention_mechanism,
1261 )?))
1262 }
1263 fn load_xlora(
1264 &self,
1265 config: &str,
1266 vb: ShardedVarBuilder,
1267 lora_config: &[((String, String), LoraConfig)],
1268 xlora_config: Option<XLoraConfig>,
1269 xlora_ordering: Ordering,
1270 normal_loading_metadata: NormalLoadingMetadata,
1271 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1272 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1273 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1274
1275 Ok(Box::new(xlora_models::XLoraPhi2::new(
1276 &cfg,
1277 vb,
1278 lora_config,
1279 xlora_config,
1280 xlora_ordering,
1281 self.is_gptx(config)?,
1282 normal_loading_metadata,
1283 preload_adapters,
1284 )?))
1285 }
1286 fn is_gptx(&self, _: &str) -> Result<bool> {
1287 Ok(true)
1288 }
1289 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1290 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1291
1292 Ok(Box::new(cfg))
1293 }
1294}
1295
1296impl IsqModelLoader for Phi2Loader {
1297 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1298 Ok(vec![
1299 Regex::new(r"lm_head\.(weight|bias)$")?,
1300 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1302 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1303 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1304 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1305 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1307 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1308 ])
1309 }
1310 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1311 self.isq_layer_regexes(config)
1312 }
1313}
1314
1315impl DeviceMappedModelLoader for Phi2Loader {
1316 fn mapped_max_act_size_elems(
1317 &self,
1318 config: &str,
1319 params: &AutoDeviceMapParams,
1320 ) -> Result<usize> {
1321 let AutoDeviceMapParams::Text {
1322 max_seq_len,
1323 max_batch_size,
1324 } = params
1325 else {
1326 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1327 };
1328
1329 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1330
1331 Ok(
1332 max_batch_size
1333 * cfg.num_attention_heads
1334 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1335 )
1336 }
1337 fn non_mapped_max_act_size_elems(
1338 &self,
1339 _config: &str,
1340 _params: &AutoDeviceMapParams,
1341 ) -> Result<usize> {
1342 Ok(0)
1343 }
1344
1345 fn non_mapped_size_in_bytes(
1346 &self,
1347 config: &str,
1348 dtype: DType,
1349 weight_pack_factor: usize,
1350 _matformer_config: Option<&MatformerSliceConfig>,
1351 ) -> Result<usize> {
1352 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1353
1354 let elems = {
1355 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1356 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1358 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1359 } else {
1360 0
1361 };
1362 let norm = cfg.hidden_size;
1363 embed_tokens + lm_head + norm
1364 };
1365 Ok(elems * dtype.size_in_bytes())
1366 }
1367
1368 fn layer_sizes_in_bytes(
1369 &self,
1370 config: &str,
1371 dtype: DType,
1372 weight_pack_factor: usize,
1373 _matformer_config: Option<&MatformerSliceConfig>,
1374 ) -> Result<Vec<usize>> {
1375 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1376
1377 let per_layer_elems = {
1378 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1379
1380 let size_in = cfg.hidden_size;
1381 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1382 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1383 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1384 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1385 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1386 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1387 let (q_norm, k_norm) = if cfg.qk_layernorm {
1388 (cfg.head_dim(), cfg.head_dim())
1389 } else {
1390 (0, 0)
1391 };
1392
1393 let h_size = cfg.hidden_size;
1394 let i_size = cfg.intermediate_size;
1395 let fc1 = h_size * i_size / weight_pack_factor;
1396 let fc2 = h_size * i_size / weight_pack_factor;
1397
1398 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1399 };
1400 Ok(vec![
1401 per_layer_elems * dtype.size_in_bytes();
1402 cfg.num_hidden_layers
1403 ])
1404 }
1405
1406 fn num_layers(&self, config: &str) -> Result<usize> {
1407 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1408
1409 Ok(cfg.num_hidden_layers)
1410 }
1411
1412 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1413 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1414
1415 let cfg = ModelConfigMetadata {
1416 max_seq_len: cfg.max_position_embeddings,
1417 num_layers: cfg.num_hidden_layers,
1418 hidden_size: cfg.hidden_size,
1419 num_kv_heads: cfg.num_key_value_heads(),
1420 num_attn_heads: cfg.num_attention_heads,
1421 sliding_window: None,
1422 k_head_dim: cfg.head_dim(),
1423 v_head_dim: cfg.head_dim(),
1424 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1425 };
1426
1427 Ok(Box::new(cfg))
1428 }
1429}
1430
1431pub struct Phi3Loader;
1437
1438impl NormalModelLoader for Phi3Loader {
1439 fn load(
1440 &self,
1441 config: &str,
1442 vb: ShardedVarBuilder,
1443 normal_loading_metadata: NormalLoadingMetadata,
1444 attention_mechanism: AttentionImplementation,
1445 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1446 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1447
1448 Ok(Box::new(models::phi3::Model::new(
1449 &cfg,
1450 vb,
1451 self.is_gptx(config)?,
1452 normal_loading_metadata,
1453 attention_mechanism,
1454 )?))
1455 }
1456 fn load_xlora(
1457 &self,
1458 config: &str,
1459 vb: ShardedVarBuilder,
1460 lora_config: &[((String, String), LoraConfig)],
1461 xlora_config: Option<XLoraConfig>,
1462 xlora_ordering: Ordering,
1463 normal_loading_metadata: NormalLoadingMetadata,
1464 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1465 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1466 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1467
1468 Ok(Box::new(xlora_models::XLoraPhi3::new(
1469 &cfg,
1470 vb,
1471 lora_config,
1472 xlora_config,
1473 xlora_ordering,
1474 self.is_gptx(config)?,
1475 normal_loading_metadata,
1476 preload_adapters,
1477 )?))
1478 }
1479 fn is_gptx(&self, _: &str) -> Result<bool> {
1480 Ok(true)
1481 }
1482 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1483 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1484
1485 Ok(Box::new(cfg))
1486 }
1487}
1488
1489impl IsqModelLoader for Phi3Loader {
1490 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1491 Ok(vec![
1492 Regex::new(r"lm_head\.(weight|bias)$")?,
1493 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1495 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1496 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1498 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1499 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1500 ])
1501 }
1502 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1503 self.isq_layer_regexes(config)
1504 }
1505}
1506
1507impl DeviceMappedModelLoader for Phi3Loader {
1508 fn mapped_max_act_size_elems(
1509 &self,
1510 config: &str,
1511 params: &AutoDeviceMapParams,
1512 ) -> Result<usize> {
1513 let AutoDeviceMapParams::Text {
1514 max_seq_len,
1515 max_batch_size,
1516 } = params
1517 else {
1518 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1519 };
1520
1521 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1522
1523 Ok(
1524 max_batch_size
1525 * cfg.num_attention_heads
1526 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1527 )
1528 }
1529 fn non_mapped_max_act_size_elems(
1530 &self,
1531 _config: &str,
1532 _params: &AutoDeviceMapParams,
1533 ) -> Result<usize> {
1534 Ok(0)
1535 }
1536
1537 fn non_mapped_size_in_bytes(
1538 &self,
1539 config: &str,
1540 dtype: DType,
1541 weight_pack_factor: usize,
1542 _matformer_config: Option<&MatformerSliceConfig>,
1543 ) -> Result<usize> {
1544 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1545
1546 let elems = {
1547 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1548 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1550 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1551 } else {
1552 0
1553 };
1554 let norm = cfg.hidden_size;
1555 embed_tokens + lm_head + norm
1556 };
1557 Ok(elems * dtype.size_in_bytes())
1558 }
1559
1560 fn layer_sizes_in_bytes(
1561 &self,
1562 config: &str,
1563 dtype: DType,
1564 weight_pack_factor: usize,
1565 _matformer_config: Option<&MatformerSliceConfig>,
1566 ) -> Result<Vec<usize>> {
1567 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1568
1569 let per_layer_elems = {
1570 let input_layernorm = cfg.hidden_size;
1571 let post_attention_layernorm = cfg.hidden_size;
1572
1573 let size_in = cfg.hidden_size;
1574 let head_dim = cfg.head_dim();
1575 let op_size =
1576 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1577 let qkv_proj = size_in * op_size / weight_pack_factor;
1578 let o_proj =
1579 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1580
1581 let h_size = cfg.hidden_size;
1582 let i_size = cfg.intermediate_size;
1583 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1584 let down_proj = h_size * i_size / weight_pack_factor;
1585
1586 input_layernorm
1587 + post_attention_layernorm
1588 + qkv_proj
1589 + o_proj
1590 + gate_up_proj
1591 + down_proj
1592 };
1593 Ok(vec![
1594 per_layer_elems * dtype.size_in_bytes();
1595 cfg.num_hidden_layers
1596 ])
1597 }
1598
1599 fn num_layers(&self, config: &str) -> Result<usize> {
1600 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1601
1602 Ok(cfg.num_hidden_layers)
1603 }
1604
1605 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1606 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1607
1608 let cfg = ModelConfigMetadata {
1609 max_seq_len: cfg.max_position_embeddings,
1610 num_layers: cfg.num_hidden_layers,
1611 hidden_size: cfg.hidden_size,
1612 num_kv_heads: cfg.num_key_value_heads,
1613 num_attn_heads: cfg.num_attention_heads,
1614 sliding_window: cfg.sliding_window,
1615 k_head_dim: cfg.head_dim(),
1616 v_head_dim: cfg.head_dim(),
1617 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1618 };
1619
1620 Ok(Box::new(cfg))
1621 }
1622}
1623
1624pub struct Qwen2Loader;
1630
1631impl NormalModelLoader for Qwen2Loader {
1632 fn load(
1633 &self,
1634 config: &str,
1635 vb: ShardedVarBuilder,
1636 normal_loading_metadata: NormalLoadingMetadata,
1637 attention_mechanism: AttentionImplementation,
1638 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1639 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1640
1641 Ok(Box::new(models::qwen2::Model::new(
1642 &cfg,
1643 vb,
1644 self.is_gptx(config)?,
1645 normal_loading_metadata,
1646 attention_mechanism,
1647 )?))
1648 }
1649 fn load_xlora(
1650 &self,
1651 _config: &str,
1652 _vb: ShardedVarBuilder,
1653 _lora_config: &[((String, String), LoraConfig)],
1654 _xlora_config: Option<XLoraConfig>,
1655 _xlora_ordering: Ordering,
1656 _normal_loading_metadata: NormalLoadingMetadata,
1657 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1658 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1659 todo!()
1660 }
1661 fn is_gptx(&self, _: &str) -> Result<bool> {
1662 Ok(true)
1663 }
1664 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1665 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1666
1667 Ok(Box::new(cfg))
1668 }
1669}
1670
1671impl IsqModelLoader for Qwen2Loader {
1672 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1673 Ok(vec![
1674 Regex::new(r"lm_head\.(weight|bias)$")?,
1675 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1677 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1678 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1679 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1680 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1682 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1683 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1684 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
1686 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
1687 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
1688 ])
1689 }
1690 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1691 self.isq_layer_regexes(config)
1692 }
1693}
1694
1695impl DeviceMappedModelLoader for Qwen2Loader {
1696 fn mapped_max_act_size_elems(
1697 &self,
1698 config: &str,
1699 params: &AutoDeviceMapParams,
1700 ) -> Result<usize> {
1701 let AutoDeviceMapParams::Text {
1702 max_seq_len,
1703 max_batch_size,
1704 } = params
1705 else {
1706 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1707 };
1708
1709 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1710
1711 Ok(
1712 max_batch_size
1713 * cfg.num_attention_heads
1714 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1715 )
1716 }
1717 fn non_mapped_max_act_size_elems(
1718 &self,
1719 _config: &str,
1720 _params: &AutoDeviceMapParams,
1721 ) -> Result<usize> {
1722 Ok(0)
1723 }
1724
1725 fn non_mapped_size_in_bytes(
1726 &self,
1727 config: &str,
1728 dtype: DType,
1729 weight_pack_factor: usize,
1730 _matformer_config: Option<&MatformerSliceConfig>,
1731 ) -> Result<usize> {
1732 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1733
1734 let elems = {
1735 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1736 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1738 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1739 } else {
1740 0
1741 };
1742 let norm = cfg.hidden_size;
1743 embed_tokens + lm_head + norm
1744 };
1745 Ok(elems * dtype.size_in_bytes())
1746 }
1747
1748 fn layer_sizes_in_bytes(
1749 &self,
1750 config: &str,
1751 dtype: DType,
1752 weight_pack_factor: usize,
1753 _matformer_config: Option<&MatformerSliceConfig>,
1754 ) -> Result<Vec<usize>> {
1755 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1756
1757 let per_layer_elems = {
1758 let input_layernorm = cfg.hidden_size;
1759 let post_attention_layernorm = cfg.hidden_size;
1760
1761 let size_in = cfg.hidden_size;
1762 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1763 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1764 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1765 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1766 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1767 let o_proj = size_q * size_in / weight_pack_factor;
1768
1769 let h_size = cfg.hidden_size;
1770 let i_size = cfg.intermediate_size;
1771 let gate_proj = h_size * i_size / weight_pack_factor;
1772 let up_proj = h_size * i_size / weight_pack_factor;
1773 let down_proj = i_size * h_size / weight_pack_factor;
1774
1775 input_layernorm
1776 + post_attention_layernorm
1777 + q_proj
1778 + k_proj
1779 + v_proj
1780 + o_proj
1781 + gate_proj
1782 + up_proj
1783 + down_proj
1784 };
1785 Ok(vec![
1786 per_layer_elems * dtype.size_in_bytes();
1787 cfg.num_hidden_layers
1788 ])
1789 }
1790
1791 fn num_layers(&self, config: &str) -> Result<usize> {
1792 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1793
1794 Ok(cfg.num_hidden_layers)
1795 }
1796
1797 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1798 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1799
1800 let cfg = ModelConfigMetadata {
1801 max_seq_len: cfg.max_position_embeddings,
1802 num_layers: cfg.num_hidden_layers,
1803 hidden_size: cfg.hidden_size,
1804 num_kv_heads: cfg.num_key_value_heads,
1805 num_attn_heads: cfg.num_attention_heads,
1806 sliding_window: cfg.sliding_window,
1807 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1808 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1809 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1810 };
1811
1812 Ok(Box::new(cfg))
1813 }
1814}
1815
1816pub struct Gemma2Loader;
1822
1823impl NormalModelLoader for Gemma2Loader {
1824 fn load(
1825 &self,
1826 config: &str,
1827 vb: ShardedVarBuilder,
1828 normal_loading_metadata: NormalLoadingMetadata,
1829 attention_mechanism: AttentionImplementation,
1830 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1831 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1832
1833 Ok(Box::new(models::gemma2::Model::new(
1834 &cfg,
1835 vb,
1836 self.is_gptx(config)?,
1837 normal_loading_metadata,
1838 attention_mechanism,
1839 )?))
1840 }
1841 fn load_xlora(
1842 &self,
1843 config: &str,
1844 vb: ShardedVarBuilder,
1845 lora_config: &[((String, String), LoraConfig)],
1846 xlora_config: Option<XLoraConfig>,
1847 xlora_ordering: Ordering,
1848 normal_loading_metadata: NormalLoadingMetadata,
1849 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1850 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1851 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1852
1853 Ok(Box::new(xlora_models::XLoraGemma2::new(
1854 &cfg,
1855 vb,
1856 lora_config,
1857 xlora_config,
1858 xlora_ordering,
1859 self.is_gptx(config)?,
1860 normal_loading_metadata,
1861 preload_adapters,
1862 )?))
1863 }
1864 fn is_gptx(&self, _: &str) -> Result<bool> {
1865 Ok(true)
1866 }
1867 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1868 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1869
1870 Ok(Box::new(cfg))
1871 }
1872}
1873
1874impl IsqModelLoader for Gemma2Loader {
1875 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1876 Ok(vec![
1877 Regex::new(r"lm_head\.(weight|bias)$")?,
1878 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1880 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1881 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1882 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1883 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1885 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1886 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1887 ])
1888 }
1889 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1890 self.isq_layer_regexes(config)
1891 }
1892}
1893
1894impl DeviceMappedModelLoader for Gemma2Loader {
1895 fn mapped_max_act_size_elems(
1896 &self,
1897 config: &str,
1898 params: &AutoDeviceMapParams,
1899 ) -> Result<usize> {
1900 let AutoDeviceMapParams::Text {
1901 max_seq_len,
1902 max_batch_size,
1903 } = params
1904 else {
1905 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1906 };
1907
1908 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1909
1910 Ok(
1911 max_batch_size
1912 * cfg.num_attention_heads
1913 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1914 )
1915 }
1916 fn non_mapped_max_act_size_elems(
1917 &self,
1918 _config: &str,
1919 _params: &AutoDeviceMapParams,
1920 ) -> Result<usize> {
1921 Ok(0)
1922 }
1923
1924 fn non_mapped_size_in_bytes(
1925 &self,
1926 config: &str,
1927 dtype: DType,
1928 weight_pack_factor: usize,
1929 _matformer_config: Option<&MatformerSliceConfig>,
1930 ) -> Result<usize> {
1931 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1932
1933 let elems = {
1934 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1935 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1937 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1938 } else {
1939 0
1940 };
1941 let norm = cfg.hidden_size;
1942 embed_tokens + lm_head + norm
1943 };
1944 Ok(elems * dtype.size_in_bytes())
1945 }
1946
1947 fn layer_sizes_in_bytes(
1948 &self,
1949 config: &str,
1950 dtype: DType,
1951 weight_pack_factor: usize,
1952 _matformer_config: Option<&MatformerSliceConfig>,
1953 ) -> Result<Vec<usize>> {
1954 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1955
1956 let per_layer_elems = {
1957 let input_layernorm = cfg.hidden_size;
1958 let post_attention_layernorm = cfg.hidden_size;
1959
1960 let size_in = cfg.hidden_size;
1961 let size_q = cfg.head_dim * cfg.num_attention_heads;
1962 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1963 let q_proj =
1964 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1965 let k_proj =
1966 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1967 let v_proj =
1968 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1969 let o_proj =
1970 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1971
1972 let h_size = cfg.hidden_size;
1973 let i_size = cfg.intermediate_size;
1974 let gate_proj = h_size * i_size / weight_pack_factor;
1975 let up_proj = h_size * i_size / weight_pack_factor;
1976 let down_proj = i_size * h_size / weight_pack_factor;
1977
1978 input_layernorm
1979 + post_attention_layernorm
1980 + q_proj
1981 + k_proj
1982 + v_proj
1983 + o_proj
1984 + gate_proj
1985 + up_proj
1986 + down_proj
1987 };
1988 Ok(vec![
1989 per_layer_elems * dtype.size_in_bytes();
1990 cfg.num_hidden_layers
1991 ])
1992 }
1993
1994 fn num_layers(&self, config: &str) -> Result<usize> {
1995 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1996
1997 Ok(cfg.num_hidden_layers)
1998 }
1999 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2000 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
2001
2002 let cfg = ModelConfigMetadata {
2003 max_seq_len: cfg.max_position_embeddings,
2004 num_layers: cfg.num_hidden_layers,
2005 hidden_size: cfg.hidden_size,
2006 num_kv_heads: cfg.num_key_value_heads,
2007 num_attn_heads: cfg.num_attention_heads,
2008 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2010 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2011 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2012 };
2013
2014 Ok(Box::new(cfg))
2015 }
2016}
2017
2018pub struct Starcoder2Loader;
2024
2025impl NormalModelLoader for Starcoder2Loader {
2026 fn load(
2027 &self,
2028 config: &str,
2029 vb: ShardedVarBuilder,
2030 normal_loading_metadata: NormalLoadingMetadata,
2031 attention_mechanism: AttentionImplementation,
2032 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2033 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2034
2035 Ok(Box::new(models::starcoder2::Model::new(
2036 &cfg,
2037 vb,
2038 self.is_gptx(config)?,
2039 normal_loading_metadata,
2040 attention_mechanism,
2041 )?))
2042 }
2043 fn load_xlora(
2044 &self,
2045 config: &str,
2046 vb: ShardedVarBuilder,
2047 lora_config: &[((String, String), LoraConfig)],
2048 xlora_config: Option<XLoraConfig>,
2049 xlora_ordering: Ordering,
2050 normal_loading_metadata: NormalLoadingMetadata,
2051 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2052 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2053 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2054
2055 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2056 &cfg,
2057 vb,
2058 lora_config,
2059 xlora_config,
2060 xlora_ordering,
2061 self.is_gptx(config)?,
2062 normal_loading_metadata,
2063 preload_adapters,
2064 )?))
2065 }
2066 fn is_gptx(&self, _: &str) -> Result<bool> {
2067 Ok(true)
2068 }
2069 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2070 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2071
2072 Ok(Box::new(cfg))
2073 }
2074}
2075
2076impl IsqModelLoader for Starcoder2Loader {
2077 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2078 Ok(vec![
2079 Regex::new(r"lm_head\.(weight|bias)$")?,
2080 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2082 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2083 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2084 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2085 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2087 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2088 ])
2089 }
2090 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2091 self.isq_layer_regexes(config)
2092 }
2093}
2094
2095impl DeviceMappedModelLoader for Starcoder2Loader {
2096 fn mapped_max_act_size_elems(
2097 &self,
2098 config: &str,
2099 params: &AutoDeviceMapParams,
2100 ) -> Result<usize> {
2101 let AutoDeviceMapParams::Text {
2102 max_seq_len,
2103 max_batch_size,
2104 } = params
2105 else {
2106 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2107 };
2108
2109 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2110
2111 Ok(
2112 max_batch_size
2113 * cfg.num_attention_heads
2114 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2115 )
2116 }
2117 fn non_mapped_max_act_size_elems(
2118 &self,
2119 _config: &str,
2120 _params: &AutoDeviceMapParams,
2121 ) -> Result<usize> {
2122 Ok(0)
2123 }
2124
2125 fn non_mapped_size_in_bytes(
2126 &self,
2127 config: &str,
2128 dtype: DType,
2129 weight_pack_factor: usize,
2130 _matformer_config: Option<&MatformerSliceConfig>,
2131 ) -> Result<usize> {
2132 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2133
2134 let elems = {
2135 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2136 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2138 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2139 } else {
2140 0
2141 };
2142 let norm = cfg.hidden_size + cfg.hidden_size;
2143 embed_tokens + lm_head + norm
2144 };
2145 Ok(elems * dtype.size_in_bytes())
2146 }
2147
2148 fn layer_sizes_in_bytes(
2149 &self,
2150 config: &str,
2151 dtype: DType,
2152 weight_pack_factor: usize,
2153 _matformer_config: Option<&MatformerSliceConfig>,
2154 ) -> Result<Vec<usize>> {
2155 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2156
2157 let per_layer_elems = {
2158 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2159 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2160
2161 let size_in = cfg.hidden_size;
2162 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2163 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2164 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2165 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2166 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2167 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2168
2169 let h_size = cfg.hidden_size;
2170 let i_size = cfg.intermediate_size;
2171 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2172 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2173
2174 input_layernorm
2175 + post_attention_layernorm
2176 + q_proj
2177 + k_proj
2178 + v_proj
2179 + o_proj
2180 + fc1
2181 + fc2
2182 };
2183 Ok(vec![
2184 per_layer_elems * dtype.size_in_bytes();
2185 cfg.num_hidden_layers
2186 ])
2187 }
2188
2189 fn num_layers(&self, config: &str) -> Result<usize> {
2190 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2191
2192 Ok(cfg.num_hidden_layers)
2193 }
2194
2195 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2196 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2197
2198 let cfg = ModelConfigMetadata {
2199 max_seq_len: cfg.max_position_embeddings,
2200 num_layers: cfg.num_hidden_layers,
2201 hidden_size: cfg.hidden_size,
2202 num_kv_heads: cfg.num_key_value_heads,
2203 num_attn_heads: cfg.num_attention_heads,
2204 sliding_window: cfg.sliding_window,
2205 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2206 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2207 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2208 };
2209
2210 Ok(Box::new(cfg))
2211 }
2212}
2213
/// Unit-struct loader whose `load` and device-mapping impls in this file parse
/// `crate::models::phi3_5_moe::Config` and build `models::phi3_5_moe::Model`.
pub struct Phi3_5MoELoader;
2220
2221impl NormalModelLoader for Phi3_5MoELoader {
2222 fn load(
2223 &self,
2224 config: &str,
2225 vb: ShardedVarBuilder,
2226 normal_loading_metadata: NormalLoadingMetadata,
2227 attention_mechanism: AttentionImplementation,
2228 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2229 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2230
2231 Ok(Box::new(models::phi3_5_moe::Model::new(
2232 &cfg,
2233 vb,
2234 self.is_gptx(config)?,
2235 normal_loading_metadata,
2236 attention_mechanism,
2237 )?))
2238 }
2239 fn load_xlora(
2240 &self,
2241 config: &str,
2242 vb: ShardedVarBuilder,
2243 lora_config: &[((String, String), LoraConfig)],
2244 xlora_config: Option<XLoraConfig>,
2245 xlora_ordering: Ordering,
2246 normal_loading_metadata: NormalLoadingMetadata,
2247 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2248 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2249 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2250
2251 Ok(Box::new(xlora_models::XLoraPhi3::new(
2252 &cfg,
2253 vb,
2254 lora_config,
2255 xlora_config,
2256 xlora_ordering,
2257 self.is_gptx(config)?,
2258 normal_loading_metadata,
2259 preload_adapters,
2260 )?))
2261 }
2262 fn is_gptx(&self, _: &str) -> Result<bool> {
2263 Ok(true)
2264 }
2265 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2266 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2267
2268 Ok(Box::new(cfg))
2269 }
2270}
2271
2272impl IsqModelLoader for Phi3_5MoELoader {
2273 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2274 Ok(vec![
2275 Regex::new(r"lm_head\.(weight|bias)$")?,
2276 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2278 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2279 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2280 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2281 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2283 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2284 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2285 ])
2286 }
2287 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2288 self.isq_layer_regexes(config)
2289 }
2290
2291 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2292 Ok(vec![
2293 Regex::new(r"lm_head\.(weight|bias)$")?,
2294 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2296 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2297 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2298 ])
2299 }
2300 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2301 self.isq_layer_regexes_moqe(config)
2302 }
2303}
2304
2305impl DeviceMappedModelLoader for Phi3_5MoELoader {
2306 fn mapped_max_act_size_elems(
2307 &self,
2308 config: &str,
2309 params: &AutoDeviceMapParams,
2310 ) -> Result<usize> {
2311 let AutoDeviceMapParams::Text {
2312 max_seq_len,
2313 max_batch_size,
2314 } = params
2315 else {
2316 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2317 };
2318
2319 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2320
2321 Ok(
2322 max_batch_size
2323 * cfg.num_attention_heads
2324 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2325 )
2326 }
2327 fn non_mapped_max_act_size_elems(
2328 &self,
2329 _config: &str,
2330 _params: &AutoDeviceMapParams,
2331 ) -> Result<usize> {
2332 Ok(0)
2333 }
2334
2335 fn non_mapped_size_in_bytes(
2336 &self,
2337 config: &str,
2338 dtype: DType,
2339 weight_pack_factor: usize,
2340 _matformer_config: Option<&MatformerSliceConfig>,
2341 ) -> Result<usize> {
2342 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2343
2344 let elems = {
2345 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2346 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2348 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2349 } else {
2350 0
2351 };
2352 let norm = cfg.hidden_size;
2353 embed_tokens + lm_head + norm
2354 };
2355 Ok(elems * dtype.size_in_bytes())
2356 }
2357
2358 fn layer_sizes_in_bytes(
2359 &self,
2360 config: &str,
2361 dtype: DType,
2362 weight_pack_factor: usize,
2363 _matformer_config: Option<&MatformerSliceConfig>,
2364 ) -> Result<Vec<usize>> {
2365 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2366
2367 let per_layer_elems = {
2368 let input_layernorm = cfg.hidden_size;
2369 let post_attention_layernorm = cfg.hidden_size;
2370
2371 let size_in = cfg.hidden_size;
2372 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2373 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2374 let q_proj =
2375 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2376 let k_proj =
2377 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2378 let v_proj =
2379 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2380 let o_proj =
2381 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2382
2383 let moe_block = {
2384 let gate = cfg.hidden_size * cfg.num_local_experts;
2385 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2387 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2388 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2389 gate + cfg.num_local_experts * w1
2390 + cfg.num_local_experts * w2
2391 + cfg.num_local_experts * w3
2392 };
2393
2394 input_layernorm
2395 + post_attention_layernorm
2396 + q_proj
2397 + k_proj
2398 + v_proj
2399 + o_proj
2400 + moe_block
2401 };
2402 Ok(vec![
2403 per_layer_elems * dtype.size_in_bytes();
2404 cfg.num_hidden_layers
2405 ])
2406 }
2407
2408 fn num_layers(&self, config: &str) -> Result<usize> {
2409 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2410
2411 Ok(cfg.num_hidden_layers)
2412 }
2413
2414 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2415 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2416
2417 let cfg = ModelConfigMetadata {
2418 max_seq_len: cfg.max_position_embeddings,
2419 num_layers: cfg.num_hidden_layers,
2420 hidden_size: cfg.hidden_size,
2421 num_kv_heads: cfg.num_key_value_heads,
2422 num_attn_heads: cfg.num_attention_heads,
2423 sliding_window: cfg.sliding_window,
2424 k_head_dim: cfg.head_dim(),
2425 v_head_dim: cfg.head_dim(),
2426 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2427 };
2428
2429 Ok(Box::new(cfg))
2430 }
2431}
2432
/// Unit-struct loader whose trait impls in this file parse
/// `crate::models::deepseek2::DeepSeekV2Config` and build `models::deepseek2::DeepSeekV2`.
pub struct DeepSeekV2Loader;
2437
2438impl NormalModelLoader for DeepSeekV2Loader {
2439 fn load(
2440 &self,
2441 config: &str,
2442 vb: ShardedVarBuilder,
2443 normal_loading_metadata: NormalLoadingMetadata,
2444 attention_mechanism: AttentionImplementation,
2445 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2446 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2447
2448 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2449 &cfg,
2450 vb,
2451 self.is_gptx(config)?,
2452 normal_loading_metadata,
2453 attention_mechanism,
2454 )?))
2455 }
2456 fn load_xlora(
2457 &self,
2458 _config: &str,
2459 _vb: ShardedVarBuilder,
2460 _lora_config: &[((String, String), LoraConfig)],
2461 _xlora_config: Option<XLoraConfig>,
2462 _xlora_ordering: Ordering,
2463 _normal_loading_metadata: NormalLoadingMetadata,
2464 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2465 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2466 todo!()
2467 }
2468 fn is_gptx(&self, _: &str) -> Result<bool> {
2469 Ok(true)
2470 }
2471 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2472 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2473 Ok(Box::new(cfg))
2474 }
2475}
2476
2477impl IsqModelLoader for DeepSeekV2Loader {
2478 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2479 let mut data = vec![
2480 Regex::new(r"lm_head\.(weight|bias)$")?,
2481 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2483 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2484 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2485 ];
2486 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2487 if cfg.q_lora_rank.is_some() {
2488 data.extend(vec![
2489 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2490 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2491 ]);
2492 } else {
2493 data.push(Regex::new(
2494 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2495 )?);
2496 }
2497 for layer_idx in 0..cfg.num_hidden_layers {
2498 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2499 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2500 }) {
2501 for i in 0..n_routed_experts {
2502 data.extend(vec![
2503 Regex::new(&format!(
2504 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2505 ))?,
2506 Regex::new(&format!(
2507 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2508 ))?,
2509 Regex::new(&format!(
2510 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2511 ))?,
2512 ]);
2513 }
2514 if cfg.n_shared_experts.is_some() {
2515 data.extend(vec![
2516 Regex::new(&format!(
2517 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2518 ))?,
2519 Regex::new(&format!(
2520 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2521 ))?,
2522 Regex::new(&format!(
2523 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2524 ))?,
2525 ]);
2526 }
2527 } else {
2528 data.extend(vec![
2529 Regex::new(&format!(
2530 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2531 ))?,
2532 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2533 Regex::new(&format!(
2534 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2535 ))?,
2536 ]);
2537 };
2538 }
2539 Ok(data)
2540 }
2541 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2542 self.isq_layer_regexes(config)
2543 }
2544
2545 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2546 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2547 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2548 for layer_idx in 0..cfg.num_hidden_layers {
2549 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2550 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2551 }) {
2552 for i in 0..n_routed_experts {
2553 data.extend(vec![
2554 Regex::new(&format!(
2555 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2556 ))?,
2557 Regex::new(&format!(
2558 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2559 ))?,
2560 Regex::new(&format!(
2561 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2562 ))?,
2563 ]);
2564 }
2565 if cfg.n_shared_experts.is_some() {
2566 data.extend(vec![
2567 Regex::new(&format!(
2568 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2569 ))?,
2570 Regex::new(&format!(
2571 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2572 ))?,
2573 Regex::new(&format!(
2574 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2575 ))?,
2576 ]);
2577 }
2578 } else {
2579 data.extend(vec![
2580 Regex::new(&format!(
2581 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2582 ))?,
2583 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2584 Regex::new(&format!(
2585 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2586 ))?,
2587 ]);
2588 };
2589 }
2590 Ok(data)
2591 }
2592 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2593 self.isq_layer_regexes_moqe(config)
2594 }
2595}
2596
2597impl DeviceMappedModelLoader for DeepSeekV2Loader {
2598 fn mapped_max_act_size_elems(
2599 &self,
2600 config: &str,
2601 params: &AutoDeviceMapParams,
2602 ) -> Result<usize> {
2603 let AutoDeviceMapParams::Text {
2604 max_seq_len,
2605 max_batch_size,
2606 } = params
2607 else {
2608 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2609 };
2610
2611 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2612
2613 Ok(
2614 max_batch_size
2615 * cfg.num_attention_heads
2616 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2617 )
2618 }
2619 fn non_mapped_max_act_size_elems(
2620 &self,
2621 _config: &str,
2622 _params: &AutoDeviceMapParams,
2623 ) -> Result<usize> {
2624 Ok(0)
2625 }
2626
2627 fn non_mapped_size_in_bytes(
2628 &self,
2629 config: &str,
2630 dtype: DType,
2631 weight_pack_factor: usize,
2632 _matformer_config: Option<&MatformerSliceConfig>,
2633 ) -> Result<usize> {
2634 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2635 let elems = {
2636 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2637 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2639 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2640 } else {
2641 0
2642 };
2643 let norm = cfg.hidden_size;
2644 embed_tokens + lm_head + norm
2645 };
2646 Ok(elems * dtype.size_in_bytes())
2647 }
2648
2649 fn layer_sizes_in_bytes(
2650 &self,
2651 config: &str,
2652 dtype: DType,
2653 weight_pack_factor: usize,
2654 _matformer_config: Option<&MatformerSliceConfig>,
2655 ) -> Result<Vec<usize>> {
2656 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2657 let mut per_layer_elems = Vec::new();
2658
2659 for layer_idx in 0..cfg.num_hidden_layers {
2660 let input_layernorm = cfg.hidden_size;
2661 let post_attention_layernorm = cfg.hidden_size;
2662
2663 let q_proj = match cfg.q_lora_rank {
2664 Some(lora_rank) => {
2665 let a = cfg.hidden_size * lora_rank;
2666 let norm = lora_rank;
2667 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2668 a + norm + b
2669 }
2670 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2671 };
2672 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2673 / weight_pack_factor
2674 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2675 let kv_a_layernorm = cfg.kv_lora_rank;
2676 let kv_b_proj = cfg.kv_lora_rank
2677 * cfg.num_attention_heads
2678 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2679 / weight_pack_factor;
2680 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2681 / weight_pack_factor
2682 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2683
2684 let moe_block = {
2685 let mut sum = 0;
2686 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2687 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2688 }) {
2689 let h_size = cfg.hidden_size;
2690 let gate_proj =
2691 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2692 let up_proj =
2693 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2694 let down_proj =
2695 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2696 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2697 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2698 / weight_pack_factor;
2699 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2700 / weight_pack_factor;
2701 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2702 / weight_pack_factor;
2703 gate_proj + up_proj + down_proj
2704 } else {
2705 0
2706 };
2707 let gate_weight = n_routed_experts * cfg.hidden_size;
2708 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2709 } else {
2710 let h_size = cfg.hidden_size;
2711 let i_size = cfg.intermediate_size;
2712 let gate_proj = h_size * i_size / weight_pack_factor;
2713 let up_proj = h_size * i_size / weight_pack_factor;
2714 let down_proj = i_size * h_size / weight_pack_factor;
2715 sum += gate_proj + up_proj + down_proj;
2716 }
2717 sum
2718 };
2719
2720 per_layer_elems.push(
2721 input_layernorm
2722 + post_attention_layernorm
2723 + q_proj
2724 + kv_a_layernorm
2725 + kv_a_proj_with_mqa
2726 + kv_b_proj
2727 + o_proj
2728 + moe_block,
2729 );
2730 }
2731
2732 Ok(per_layer_elems
2733 .into_iter()
2734 .map(|x| x * dtype.size_in_bytes())
2735 .collect())
2736 }
2737
2738 fn num_layers(&self, config: &str) -> Result<usize> {
2739 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2740 Ok(cfg.num_hidden_layers)
2741 }
2742
2743 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2744 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2745
2746 let cfg = ModelConfigMetadata {
2747 max_seq_len: cfg.max_position_embeddings,
2748 num_layers: cfg.num_hidden_layers,
2749 hidden_size: cfg.hidden_size,
2750 num_kv_heads: cfg.num_attention_heads,
2751 num_attn_heads: cfg.num_attention_heads,
2752 sliding_window: None,
2753 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2754 v_head_dim: cfg.v_head_dim,
2755 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2756 };
2757
2758 Ok(Box::new(cfg))
2759 }
2760}
2761
/// Unit-struct loader whose trait impls in this file parse
/// `crate::models::deepseek3::DeepSeekV3Config` and build `models::deepseek3::DeepSeekV3`.
pub struct DeepSeekV3Loader;
2766
2767impl NormalModelLoader for DeepSeekV3Loader {
2768 fn load(
2769 &self,
2770 config: &str,
2771 vb: ShardedVarBuilder,
2772 normal_loading_metadata: NormalLoadingMetadata,
2773 attention_mechanism: AttentionImplementation,
2774 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2775 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2776 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2777 &cfg,
2778 vb,
2779 self.is_gptx(config)?,
2780 normal_loading_metadata,
2781 attention_mechanism,
2782 )?))
2783 }
2784 fn load_xlora(
2785 &self,
2786 _config: &str,
2787 _vb: ShardedVarBuilder,
2788 _lora_config: &[((String, String), LoraConfig)],
2789 _xlora_config: Option<XLoraConfig>,
2790 _xlora_ordering: Ordering,
2791 _normal_loading_metadata: NormalLoadingMetadata,
2792 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2793 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2794 todo!()
2795 }
2796 fn is_gptx(&self, _: &str) -> Result<bool> {
2797 Ok(true)
2798 }
2799 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2800 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2801 Ok(Box::new(cfg))
2802 }
2803}
2804
2805impl IsqModelLoader for DeepSeekV3Loader {
2806 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2807 let mut data = vec![
2808 Regex::new(r"lm_head\.(weight|bias)$")?,
2809 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2811 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2812 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2813 ];
2814 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2815 if cfg.q_lora_rank.is_some() {
2816 data.extend(vec![
2817 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2818 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2819 ]);
2820 } else {
2821 data.push(Regex::new(
2822 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2823 )?);
2824 }
2825 for layer_idx in 0..cfg.num_hidden_layers {
2826 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2827 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2828 }) {
2829 for i in 0..n_routed_experts {
2830 data.extend(vec![
2831 Regex::new(&format!(
2832 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2833 ))?,
2834 Regex::new(&format!(
2835 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2836 ))?,
2837 Regex::new(&format!(
2838 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2839 ))?,
2840 ]);
2841 }
2842 if cfg.n_shared_experts.is_some() {
2843 data.extend(vec![
2844 Regex::new(&format!(
2845 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2846 ))?,
2847 Regex::new(&format!(
2848 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2849 ))?,
2850 Regex::new(&format!(
2851 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2852 ))?,
2853 ]);
2854 }
2855 } else {
2856 data.extend(vec![
2857 Regex::new(&format!(
2858 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2859 ))?,
2860 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2861 Regex::new(&format!(
2862 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2863 ))?,
2864 ]);
2865 };
2866 }
2867 Ok(data)
2868 }
2869 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2870 self.isq_layer_regexes(config)
2871 }
2872
2873 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2874 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2875 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2876 for layer_idx in 0..cfg.num_hidden_layers {
2877 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2878 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2879 }) {
2880 for i in 0..n_routed_experts {
2881 data.extend(vec![
2882 Regex::new(&format!(
2883 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2884 ))?,
2885 Regex::new(&format!(
2886 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2887 ))?,
2888 Regex::new(&format!(
2889 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2890 ))?,
2891 ]);
2892 }
2893 if cfg.n_shared_experts.is_some() {
2894 data.extend(vec![
2895 Regex::new(&format!(
2896 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2897 ))?,
2898 Regex::new(&format!(
2899 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2900 ))?,
2901 Regex::new(&format!(
2902 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2903 ))?,
2904 ]);
2905 }
2906 } else {
2907 data.extend(vec![
2908 Regex::new(&format!(
2909 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2910 ))?,
2911 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2912 Regex::new(&format!(
2913 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2914 ))?,
2915 ]);
2916 };
2917 }
2918 Ok(data)
2919 }
2920 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2921 self.isq_layer_regexes_moqe(config)
2922 }
2923}
2924
2925impl DeviceMappedModelLoader for DeepSeekV3Loader {
2926 fn mapped_max_act_size_elems(
2927 &self,
2928 config: &str,
2929 params: &AutoDeviceMapParams,
2930 ) -> Result<usize> {
2931 let AutoDeviceMapParams::Text {
2932 max_seq_len,
2933 max_batch_size,
2934 } = params
2935 else {
2936 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2937 };
2938
2939 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2940
2941 Ok(
2942 max_batch_size
2943 * cfg.num_attention_heads
2944 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2945 )
2946 }
2947 fn non_mapped_max_act_size_elems(
2948 &self,
2949 _config: &str,
2950 _params: &AutoDeviceMapParams,
2951 ) -> Result<usize> {
2952 Ok(0)
2953 }
2954
2955 fn non_mapped_size_in_bytes(
2956 &self,
2957 config: &str,
2958 dtype: DType,
2959 weight_pack_factor: usize,
2960 _matformer_config: Option<&MatformerSliceConfig>,
2961 ) -> Result<usize> {
2962 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2963 let elems = {
2964 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2965 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2967 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2968 } else {
2969 0
2970 };
2971 let norm = cfg.hidden_size;
2972 embed_tokens + lm_head + norm
2973 };
2974 Ok(elems * dtype.size_in_bytes())
2975 }
2976
2977 fn layer_sizes_in_bytes(
2978 &self,
2979 config: &str,
2980 dtype: DType,
2981 weight_pack_factor: usize,
2982 _matformer_config: Option<&MatformerSliceConfig>,
2983 ) -> Result<Vec<usize>> {
2984 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2985 let mut per_layer_elems = Vec::new();
2986
2987 for layer_idx in 0..cfg.num_hidden_layers {
2988 let input_layernorm = cfg.hidden_size;
2989 let post_attention_layernorm = cfg.hidden_size;
2990
2991 let q_proj = match cfg.q_lora_rank {
2992 Some(lora_rank) => {
2993 let a = cfg.hidden_size * lora_rank;
2994 let norm = lora_rank;
2995 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2996 a + norm + b
2997 }
2998 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2999 };
3000 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3001 / weight_pack_factor
3002 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3003 let kv_a_layernorm = cfg.kv_lora_rank;
3004 let kv_b_proj = cfg.kv_lora_rank
3005 * cfg.num_attention_heads
3006 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3007 / weight_pack_factor;
3008 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3009 / weight_pack_factor
3010 + bias_if!(cfg.attention_bias, cfg.hidden_size);
3011
3012 let moe_block = {
3013 let mut sum = 0;
3014 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3015 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3016 }) {
3017 let h_size = cfg.hidden_size;
3018 let gate_proj =
3019 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3020 let up_proj =
3021 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3022 let down_proj =
3023 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3024 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3025 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3026 / weight_pack_factor;
3027 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3028 / weight_pack_factor;
3029 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3030 / weight_pack_factor;
3031 gate_proj + up_proj + down_proj
3032 } else {
3033 0
3034 };
3035 let gate_weight = n_routed_experts * cfg.hidden_size;
3036 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3037 } else {
3038 let h_size = cfg.hidden_size;
3039 let i_size = cfg.intermediate_size;
3040 let gate_proj = h_size * i_size / weight_pack_factor;
3041 let up_proj = h_size * i_size / weight_pack_factor;
3042 let down_proj = i_size * h_size / weight_pack_factor;
3043 sum += gate_proj + up_proj + down_proj;
3044 }
3045 sum
3046 };
3047
3048 per_layer_elems.push(
3049 input_layernorm
3050 + post_attention_layernorm
3051 + q_proj
3052 + kv_a_layernorm
3053 + kv_a_proj_with_mqa
3054 + kv_b_proj
3055 + o_proj
3056 + moe_block,
3057 );
3058 }
3059
3060 Ok(per_layer_elems
3061 .into_iter()
3062 .map(|x| x * dtype.size_in_bytes())
3063 .collect())
3064 }
3065
3066 fn num_layers(&self, config: &str) -> Result<usize> {
3067 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3068 Ok(cfg.num_hidden_layers)
3069 }
3070
3071 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3072 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3073
3074 let cfg = ModelConfigMetadata {
3075 max_seq_len: cfg.max_position_embeddings,
3076 num_layers: cfg.num_hidden_layers,
3077 hidden_size: cfg.hidden_size,
3078 num_kv_heads: cfg.num_attention_heads,
3079 num_attn_heads: cfg.num_attention_heads,
3080 sliding_window: None,
3081 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3082 v_head_dim: cfg.v_head_dim,
3083 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3084 };
3085
3086 Ok(Box::new(cfg))
3087 }
3088}
3089
3090pub struct Qwen3Loader;
3094
3095impl NormalModelLoader for Qwen3Loader {
3096 fn load(
3097 &self,
3098 config: &str,
3099 vb: ShardedVarBuilder,
3100 normal_loading_metadata: NormalLoadingMetadata,
3101 attention_mechanism: AttentionImplementation,
3102 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3103 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3104
3105 Ok(Box::new(models::qwen3::Model::new(
3106 &cfg,
3107 vb,
3108 self.is_gptx(config)?,
3109 normal_loading_metadata,
3110 attention_mechanism,
3111 )?))
3112 }
3113 fn load_xlora(
3114 &self,
3115 _config: &str,
3116 _vb: ShardedVarBuilder,
3117 _lora_config: &[((String, String), LoraConfig)],
3118 _xlora_config: Option<XLoraConfig>,
3119 _xlora_ordering: Ordering,
3120 _normal_loading_metadata: NormalLoadingMetadata,
3121 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3122 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3123 todo!()
3124 }
3125 fn is_gptx(&self, _: &str) -> Result<bool> {
3126 Ok(true)
3127 }
3128 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3129 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3130
3131 Ok(Box::new(cfg))
3132 }
3133}
3134
3135impl IsqModelLoader for Qwen3Loader {
3136 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3137 Ok(vec![
3138 Regex::new(r"lm_head\.(weight|bias)$")?,
3139 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3141 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3142 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3143 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3144 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3146 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3147 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3148 ])
3149 }
3150 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3151 self.isq_layer_regexes(config)
3152 }
3153}
3154
3155impl DeviceMappedModelLoader for Qwen3Loader {
3156 fn mapped_max_act_size_elems(
3157 &self,
3158 config: &str,
3159 params: &AutoDeviceMapParams,
3160 ) -> Result<usize> {
3161 let AutoDeviceMapParams::Text {
3162 max_seq_len,
3163 max_batch_size,
3164 } = params
3165 else {
3166 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3167 };
3168
3169 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3170
3171 Ok(
3172 max_batch_size
3173 * cfg.num_attention_heads
3174 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3175 )
3176 }
3177 fn non_mapped_max_act_size_elems(
3178 &self,
3179 _config: &str,
3180 _params: &AutoDeviceMapParams,
3181 ) -> Result<usize> {
3182 Ok(0)
3183 }
3184
3185 fn non_mapped_size_in_bytes(
3186 &self,
3187 config: &str,
3188 dtype: DType,
3189 weight_pack_factor: usize,
3190 _matformer_config: Option<&MatformerSliceConfig>,
3191 ) -> Result<usize> {
3192 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3193 let elems = {
3194 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3195 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3197 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3198 } else {
3199 0
3200 };
3201 let norm = cfg.hidden_size;
3202 embed_tokens + lm_head + norm
3203 };
3204 Ok(elems * dtype.size_in_bytes())
3205 }
3206
3207 fn layer_sizes_in_bytes(
3208 &self,
3209 config: &str,
3210 dtype: DType,
3211 weight_pack_factor: usize,
3212 _matformer_config: Option<&MatformerSliceConfig>,
3213 ) -> Result<Vec<usize>> {
3214 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3215 let per_layer_elems = {
3216 let input_layernorm = cfg.hidden_size;
3217 let post_attention_layernorm = cfg.hidden_size;
3218
3219 let size_in = cfg.hidden_size;
3220 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3221 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3222 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3223 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3224 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3225 let o_proj = size_q * size_in / weight_pack_factor;
3226
3227 let h_size = cfg.hidden_size;
3228 let i_size = cfg.intermediate_size;
3229 let gate_proj = h_size * i_size / weight_pack_factor;
3230 let up_proj = h_size * i_size / weight_pack_factor;
3231 let down_proj = i_size * h_size / weight_pack_factor;
3232
3233 let q_norm = cfg.head_dim();
3234 let k_norm = cfg.head_dim();
3235
3236 input_layernorm
3237 + post_attention_layernorm
3238 + q_proj
3239 + k_proj
3240 + v_proj
3241 + o_proj
3242 + gate_proj
3243 + up_proj
3244 + down_proj
3245 + q_norm
3246 + k_norm
3247 };
3248 Ok(vec![
3249 per_layer_elems * dtype.size_in_bytes();
3250 cfg.num_hidden_layers
3251 ])
3252 }
3253
3254 fn num_layers(&self, config: &str) -> Result<usize> {
3255 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3256 Ok(cfg.num_hidden_layers)
3257 }
3258
3259 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3260 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3261
3262 let cfg = ModelConfigMetadata {
3263 max_seq_len: cfg.max_position_embeddings,
3264 num_layers: cfg.num_hidden_layers,
3265 hidden_size: cfg.hidden_size,
3266 num_kv_heads: cfg.num_key_value_heads,
3267 num_attn_heads: cfg.num_attention_heads,
3268 sliding_window: cfg.sliding_window,
3269 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3270 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3271 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3272 };
3273
3274 Ok(Box::new(cfg))
3275 }
3276}
3277
3278pub struct GLM4Loader;
3282
3283impl NormalModelLoader for GLM4Loader {
3284 fn load(
3285 &self,
3286 config: &str,
3287 vb: ShardedVarBuilder,
3288 normal_loading_metadata: NormalLoadingMetadata,
3289 attention_mechanism: AttentionImplementation,
3290 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3291 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3292
3293 Ok(Box::new(models::glm4::Model::new(
3294 &cfg,
3295 vb,
3296 self.is_gptx(config)?,
3297 normal_loading_metadata,
3298 attention_mechanism,
3299 )?))
3300 }
3301 fn load_xlora(
3302 &self,
3303 _config: &str,
3304 _vb: ShardedVarBuilder,
3305 _lora_config: &[((String, String), LoraConfig)],
3306 _xlora_config: Option<XLoraConfig>,
3307 _xlora_ordering: Ordering,
3308 _normal_loading_metadata: NormalLoadingMetadata,
3309 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3310 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3311 todo!()
3312 }
3313 fn is_gptx(&self, _: &str) -> Result<bool> {
3314 Ok(true)
3315 }
3316 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3317 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3318
3319 Ok(Box::new(cfg))
3320 }
3321}
3322
3323impl IsqModelLoader for GLM4Loader {
3324 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3325 Ok(vec![
3326 Regex::new(r"lm_head\.(weight|bias)$")?,
3327 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3329 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3330 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3331 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3332 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3334 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3335 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3336 ])
3337 }
3338 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3339 self.isq_layer_regexes(config)
3340 }
3341}
3342
3343impl DeviceMappedModelLoader for GLM4Loader {
3344 fn mapped_max_act_size_elems(
3345 &self,
3346 config: &str,
3347 params: &AutoDeviceMapParams,
3348 ) -> Result<usize> {
3349 let AutoDeviceMapParams::Text {
3350 max_seq_len,
3351 max_batch_size,
3352 } = params
3353 else {
3354 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3355 };
3356
3357 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3358
3359 Ok(
3360 max_batch_size
3361 * cfg.num_attention_heads
3362 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3363 )
3364 }
3365 fn non_mapped_max_act_size_elems(
3366 &self,
3367 _config: &str,
3368 _params: &AutoDeviceMapParams,
3369 ) -> Result<usize> {
3370 Ok(0)
3371 }
3372
3373 fn non_mapped_size_in_bytes(
3374 &self,
3375 config: &str,
3376 dtype: DType,
3377 weight_pack_factor: usize,
3378 _matformer_config: Option<&MatformerSliceConfig>,
3379 ) -> Result<usize> {
3380 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3381 let elems = {
3382 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3383 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3385 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3386 } else {
3387 0
3388 };
3389 let norm = cfg.hidden_size;
3390 embed_tokens + lm_head + norm
3391 };
3392 Ok(elems * dtype.size_in_bytes())
3393 }
3394
3395 fn layer_sizes_in_bytes(
3396 &self,
3397 config: &str,
3398 dtype: DType,
3399 weight_pack_factor: usize,
3400 _matformer_config: Option<&MatformerSliceConfig>,
3401 ) -> Result<Vec<usize>> {
3402 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3403 let per_layer_elems = {
3404 let input_layernorm = cfg.hidden_size;
3405 let post_attention_layernorm = cfg.hidden_size * 3; let size_in = cfg.hidden_size;
3408 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3409 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3410 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3411 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3412 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3413 let o_proj = size_q * size_in / weight_pack_factor;
3414
3415 let h_size = cfg.hidden_size;
3416 let i_size = cfg.intermediate_size;
3417 let gate_proj = h_size * i_size / weight_pack_factor;
3418 let up_proj = h_size * i_size / weight_pack_factor;
3419 let down_proj = i_size * h_size / weight_pack_factor;
3420
3421 input_layernorm
3422 + post_attention_layernorm
3423 + q_proj
3424 + k_proj
3425 + v_proj
3426 + o_proj
3427 + gate_proj
3428 + up_proj
3429 + down_proj
3430 };
3431 Ok(vec![
3432 per_layer_elems * dtype.size_in_bytes();
3433 cfg.num_hidden_layers
3434 ])
3435 }
3436
3437 fn num_layers(&self, config: &str) -> Result<usize> {
3438 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3439 Ok(cfg.num_hidden_layers)
3440 }
3441
3442 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3443 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3444
3445 let cfg = ModelConfigMetadata {
3446 max_seq_len: cfg.max_position_embeddings,
3447 num_layers: cfg.num_hidden_layers,
3448 hidden_size: cfg.hidden_size,
3449 num_kv_heads: cfg.num_key_value_heads,
3450 num_attn_heads: cfg.num_attention_heads,
3451 sliding_window: cfg.sliding_window,
3452 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3453 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3454 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3455 };
3456
3457 Ok(Box::new(cfg))
3458 }
3459}
3460
3461pub struct GLM4MoeLiteLoader;
3465
3466impl NormalModelLoader for GLM4MoeLiteLoader {
3467 fn load(
3468 &self,
3469 config: &str,
3470 vb: ShardedVarBuilder,
3471 normal_loading_metadata: NormalLoadingMetadata,
3472 attention_mechanism: AttentionImplementation,
3473 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3474 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3475 Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3476 &cfg,
3477 vb,
3478 self.is_gptx(config)?,
3479 normal_loading_metadata,
3480 attention_mechanism,
3481 )?))
3482 }
3483 fn load_xlora(
3484 &self,
3485 _config: &str,
3486 _vb: ShardedVarBuilder,
3487 _lora_config: &[((String, String), LoraConfig)],
3488 _xlora_config: Option<XLoraConfig>,
3489 _xlora_ordering: Ordering,
3490 _normal_loading_metadata: NormalLoadingMetadata,
3491 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3492 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3493 todo!()
3494 }
3495 fn is_gptx(&self, _: &str) -> Result<bool> {
3496 Ok(true)
3497 }
3498 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3499 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3500 Ok(Box::new(cfg))
3501 }
3502}
3503
3504impl IsqModelLoader for GLM4MoeLiteLoader {
3505 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3506 let mut data = vec![
3507 Regex::new(r"lm_head\.(weight|bias)$")?,
3508 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3510 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3511 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3512 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3514 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3515 ];
3516 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3517 for layer_idx in 0..cfg.num_hidden_layers {
3518 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3519 for i in 0..cfg.n_routed_experts {
3521 data.extend(vec![
3522 Regex::new(&format!(
3523 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3524 ))?,
3525 Regex::new(&format!(
3526 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3527 ))?,
3528 Regex::new(&format!(
3529 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3530 ))?,
3531 ]);
3532 }
3533 if cfg.n_shared_experts > 0 {
3534 data.extend(vec![
3535 Regex::new(&format!(
3536 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3537 ))?,
3538 Regex::new(&format!(
3539 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3540 ))?,
3541 Regex::new(&format!(
3542 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3543 ))?,
3544 ]);
3545 }
3546 } else {
3547 data.extend(vec![
3549 Regex::new(&format!(
3550 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3551 ))?,
3552 Regex::new(&format!(
3553 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3554 ))?,
3555 Regex::new(&format!(
3556 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3557 ))?,
3558 ]);
3559 };
3560 }
3561 Ok(data)
3562 }
3563 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3564 self.isq_layer_regexes(config)
3565 }
3566
3567 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3568 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3569 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3570 for layer_idx in 0..cfg.num_hidden_layers {
3571 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3572 for i in 0..cfg.n_routed_experts {
3574 data.extend(vec![
3575 Regex::new(&format!(
3576 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3577 ))?,
3578 Regex::new(&format!(
3579 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3580 ))?,
3581 Regex::new(&format!(
3582 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3583 ))?,
3584 ]);
3585 }
3586 if cfg.n_shared_experts > 0 {
3587 data.extend(vec![
3588 Regex::new(&format!(
3589 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3590 ))?,
3591 Regex::new(&format!(
3592 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3593 ))?,
3594 Regex::new(&format!(
3595 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3596 ))?,
3597 ]);
3598 }
3599 } else {
3600 data.extend(vec![
3602 Regex::new(&format!(
3603 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3604 ))?,
3605 Regex::new(&format!(
3606 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3607 ))?,
3608 Regex::new(&format!(
3609 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3610 ))?,
3611 ]);
3612 };
3613 }
3614 Ok(data)
3615 }
3616 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3617 self.isq_layer_regexes_moqe(config)
3618 }
3619}
3620
3621impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3622 fn mapped_max_act_size_elems(
3623 &self,
3624 config: &str,
3625 params: &AutoDeviceMapParams,
3626 ) -> Result<usize> {
3627 let AutoDeviceMapParams::Text {
3628 max_seq_len,
3629 max_batch_size,
3630 } = params
3631 else {
3632 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3633 };
3634
3635 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3636
3637 Ok(
3638 max_batch_size
3639 * cfg.num_attention_heads
3640 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3641 )
3642 }
3643 fn non_mapped_max_act_size_elems(
3644 &self,
3645 _config: &str,
3646 _params: &AutoDeviceMapParams,
3647 ) -> Result<usize> {
3648 Ok(0)
3649 }
3650
3651 fn non_mapped_size_in_bytes(
3652 &self,
3653 config: &str,
3654 dtype: DType,
3655 weight_pack_factor: usize,
3656 _matformer_config: Option<&MatformerSliceConfig>,
3657 ) -> Result<usize> {
3658 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3659 let elems = {
3660 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3661 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3663 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3664 } else {
3665 0
3666 };
3667 let norm = cfg.hidden_size;
3668 embed_tokens + lm_head + norm
3669 };
3670 Ok(elems * dtype.size_in_bytes())
3671 }
3672
3673 fn layer_sizes_in_bytes(
3674 &self,
3675 config: &str,
3676 dtype: DType,
3677 weight_pack_factor: usize,
3678 _matformer_config: Option<&MatformerSliceConfig>,
3679 ) -> Result<Vec<usize>> {
3680 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3681 let mut per_layer_elems = Vec::new();
3682
3683 for layer_idx in 0..cfg.num_hidden_layers {
3684 let input_layernorm = cfg.hidden_size;
3685 let post_attention_layernorm = cfg.hidden_size;
3686
3687 let q_proj = {
3689 let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3690 let norm = cfg.q_lora_rank;
3691 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3692 / weight_pack_factor;
3693 a + norm + b
3694 };
3695 let kv_a_proj_with_mqa =
3696 cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3697 let kv_a_layernorm = cfg.kv_lora_rank;
3698 let kv_b_proj = cfg.kv_lora_rank
3699 * cfg.num_attention_heads
3700 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3701 / weight_pack_factor;
3702 let o_proj =
3703 cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3704
3705 let moe_block = {
3706 let mut sum = 0;
3707 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3708 let h_size = cfg.hidden_size;
3710 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3711 * cfg.n_routed_experts;
3712 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3713 * cfg.n_routed_experts;
3714 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3715 * cfg.n_routed_experts;
3716 let shared_experts = if cfg.n_shared_experts > 0 {
3717 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3718 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3719 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3720 gate_proj + up_proj + down_proj
3721 } else {
3722 0
3723 };
3724 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3725 let e_score_correction_bias = cfg.n_routed_experts;
3726 sum += gate_proj
3727 + up_proj
3728 + down_proj
3729 + shared_experts
3730 + gate_weight
3731 + e_score_correction_bias;
3732 } else {
3733 let h_size = cfg.hidden_size;
3735 let i_size = cfg.intermediate_size;
3736 let gate_proj = h_size * i_size / weight_pack_factor;
3737 let up_proj = h_size * i_size / weight_pack_factor;
3738 let down_proj = i_size * h_size / weight_pack_factor;
3739 sum += gate_proj + up_proj + down_proj;
3740 }
3741 sum
3742 };
3743
3744 per_layer_elems.push(
3745 input_layernorm
3746 + post_attention_layernorm
3747 + q_proj
3748 + kv_a_layernorm
3749 + kv_a_proj_with_mqa
3750 + kv_b_proj
3751 + o_proj
3752 + moe_block,
3753 );
3754 }
3755
3756 Ok(per_layer_elems
3757 .into_iter()
3758 .map(|x| x * dtype.size_in_bytes())
3759 .collect())
3760 }
3761
3762 fn num_layers(&self, config: &str) -> Result<usize> {
3763 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3764 Ok(cfg.num_hidden_layers)
3765 }
3766
3767 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3768 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3769
3770 let cfg = ModelConfigMetadata {
3771 max_seq_len: cfg.max_position_embeddings,
3772 num_layers: cfg.num_hidden_layers,
3773 hidden_size: cfg.hidden_size,
3774 num_kv_heads: cfg.num_attention_heads,
3775 num_attn_heads: cfg.num_attention_heads,
3776 sliding_window: None,
3777 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3778 v_head_dim: cfg.v_head_dim,
3779 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3780 };
3781
3782 Ok(Box::new(cfg))
3783 }
3784}
3785
3786pub struct GLM4MoeLoader;
3790
3791impl NormalModelLoader for GLM4MoeLoader {
3792 fn load(
3793 &self,
3794 config: &str,
3795 vb: ShardedVarBuilder,
3796 normal_loading_metadata: NormalLoadingMetadata,
3797 attention_mechanism: AttentionImplementation,
3798 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3799 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3800 Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3801 &cfg,
3802 vb,
3803 self.is_gptx(config)?,
3804 normal_loading_metadata,
3805 attention_mechanism,
3806 )?))
3807 }
3808 fn load_xlora(
3809 &self,
3810 _config: &str,
3811 _vb: ShardedVarBuilder,
3812 _lora_config: &[((String, String), LoraConfig)],
3813 _xlora_config: Option<XLoraConfig>,
3814 _xlora_ordering: Ordering,
3815 _normal_loading_metadata: NormalLoadingMetadata,
3816 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3817 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3818 todo!()
3819 }
3820 fn is_gptx(&self, _: &str) -> Result<bool> {
3821 Ok(true)
3822 }
3823 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3824 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3825 Ok(Box::new(cfg))
3826 }
3827}
3828
3829impl IsqModelLoader for GLM4MoeLoader {
3830 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3831 let mut data = vec![
3832 Regex::new(r"lm_head\.(weight|bias)$")?,
3833 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3835 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3836 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3837 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3838 ];
3839 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3840 for layer_idx in 0..cfg.num_hidden_layers {
3841 if layer_idx >= cfg.first_k_dense_replace {
3842 for i in 0..cfg.n_routed_experts {
3844 data.extend(vec![
3845 Regex::new(&format!(
3846 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3847 ))?,
3848 Regex::new(&format!(
3849 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3850 ))?,
3851 Regex::new(&format!(
3852 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3853 ))?,
3854 ]);
3855 }
3856 if cfg.n_shared_experts > 0 {
3857 data.extend(vec![
3858 Regex::new(&format!(
3859 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3860 ))?,
3861 Regex::new(&format!(
3862 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3863 ))?,
3864 Regex::new(&format!(
3865 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3866 ))?,
3867 ]);
3868 }
3869 } else {
3870 data.extend(vec![
3872 Regex::new(&format!(
3873 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3874 ))?,
3875 Regex::new(&format!(
3876 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3877 ))?,
3878 Regex::new(&format!(
3879 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3880 ))?,
3881 ]);
3882 };
3883 }
3884 Ok(data)
3885 }
3886 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3887 self.isq_layer_regexes(config)
3888 }
3889
3890 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3891 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3892 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3893 for layer_idx in 0..cfg.num_hidden_layers {
3894 if layer_idx >= cfg.first_k_dense_replace {
3895 for i in 0..cfg.n_routed_experts {
3897 data.extend(vec![
3898 Regex::new(&format!(
3899 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3900 ))?,
3901 Regex::new(&format!(
3902 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3903 ))?,
3904 Regex::new(&format!(
3905 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3906 ))?,
3907 ]);
3908 }
3909 if cfg.n_shared_experts > 0 {
3910 data.extend(vec![
3911 Regex::new(&format!(
3912 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3913 ))?,
3914 Regex::new(&format!(
3915 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3916 ))?,
3917 Regex::new(&format!(
3918 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3919 ))?,
3920 ]);
3921 }
3922 } else {
3923 data.extend(vec![
3925 Regex::new(&format!(
3926 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3927 ))?,
3928 Regex::new(&format!(
3929 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3930 ))?,
3931 Regex::new(&format!(
3932 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3933 ))?,
3934 ]);
3935 };
3936 }
3937 Ok(data)
3938 }
3939 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3940 self.isq_layer_regexes_moqe(config)
3941 }
3942}
3943
3944impl DeviceMappedModelLoader for GLM4MoeLoader {
3945 fn mapped_max_act_size_elems(
3946 &self,
3947 config: &str,
3948 params: &AutoDeviceMapParams,
3949 ) -> Result<usize> {
3950 let AutoDeviceMapParams::Text {
3951 max_seq_len,
3952 max_batch_size,
3953 } = params
3954 else {
3955 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3956 };
3957
3958 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3959
3960 Ok(
3961 max_batch_size
3962 * cfg.num_attention_heads
3963 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3964 )
3965 }
3966 fn non_mapped_max_act_size_elems(
3967 &self,
3968 _config: &str,
3969 _params: &AutoDeviceMapParams,
3970 ) -> Result<usize> {
3971 Ok(0)
3972 }
3973
3974 fn non_mapped_size_in_bytes(
3975 &self,
3976 config: &str,
3977 dtype: DType,
3978 weight_pack_factor: usize,
3979 _matformer_config: Option<&MatformerSliceConfig>,
3980 ) -> Result<usize> {
3981 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3982 let elems = {
3983 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3984 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3985 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3986 } else {
3987 0
3988 };
3989 let norm = cfg.hidden_size;
3990 embed_tokens + lm_head + norm
3991 };
3992 Ok(elems * dtype.size_in_bytes())
3993 }
3994
3995 fn layer_sizes_in_bytes(
3996 &self,
3997 config: &str,
3998 dtype: DType,
3999 weight_pack_factor: usize,
4000 _matformer_config: Option<&MatformerSliceConfig>,
4001 ) -> Result<Vec<usize>> {
4002 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4003 let mut per_layer_elems = Vec::new();
4004
4005 let head_dim = cfg.head_dim();
4006 for layer_idx in 0..cfg.num_hidden_layers {
4007 let input_layernorm = cfg.hidden_size;
4008 let post_attention_layernorm = cfg.hidden_size;
4009
4010 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
4012 + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
4013 let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4014 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4015 let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4016 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4017 let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4018
4019 let qk_norm = if cfg.use_qk_norm {
4021 head_dim * 2 } else {
4023 0
4024 };
4025
4026 let moe_block = {
4027 let mut sum = 0;
4028 if layer_idx >= cfg.first_k_dense_replace {
4029 let h_size = cfg.hidden_size;
4031 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4032 * cfg.n_routed_experts;
4033 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4034 * cfg.n_routed_experts;
4035 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4036 * cfg.n_routed_experts;
4037 let shared_experts = if cfg.n_shared_experts > 0 {
4038 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4039 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4040 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4041 gate_proj + up_proj + down_proj
4042 } else {
4043 0
4044 };
4045 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4046 let e_score_correction_bias = cfg.n_routed_experts;
4047 sum += gate_proj
4048 + up_proj
4049 + down_proj
4050 + shared_experts
4051 + gate_weight
4052 + e_score_correction_bias;
4053 } else {
4054 let h_size = cfg.hidden_size;
4056 let i_size = cfg.intermediate_size;
4057 let gate_proj = h_size * i_size / weight_pack_factor;
4058 let up_proj = h_size * i_size / weight_pack_factor;
4059 let down_proj = i_size * h_size / weight_pack_factor;
4060 sum += gate_proj + up_proj + down_proj;
4061 }
4062 sum
4063 };
4064
4065 per_layer_elems.push(
4066 input_layernorm
4067 + post_attention_layernorm
4068 + q_proj
4069 + k_proj
4070 + v_proj
4071 + o_proj
4072 + qk_norm
4073 + moe_block,
4074 );
4075 }
4076
4077 Ok(per_layer_elems
4078 .into_iter()
4079 .map(|x| x * dtype.size_in_bytes())
4080 .collect())
4081 }
4082
4083 fn num_layers(&self, config: &str) -> Result<usize> {
4084 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4085 Ok(cfg.num_hidden_layers)
4086 }
4087
4088 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4089 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4090
4091 let head_dim = cfg.head_dim();
4092 let cfg = ModelConfigMetadata {
4093 max_seq_len: cfg.max_position_embeddings,
4094 num_layers: cfg.num_hidden_layers,
4095 hidden_size: cfg.hidden_size,
4096 num_kv_heads: cfg.num_key_value_heads,
4097 num_attn_heads: cfg.num_attention_heads,
4098 sliding_window: None,
4099 k_head_dim: head_dim,
4100 v_head_dim: head_dim,
4101 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4102 };
4103
4104 Ok(Box::new(cfg))
4105 }
4106}
4107
4108pub struct Qwen3MoELoader;
4112
4113impl NormalModelLoader for Qwen3MoELoader {
4114 fn load(
4115 &self,
4116 config: &str,
4117 vb: ShardedVarBuilder,
4118 normal_loading_metadata: NormalLoadingMetadata,
4119 attention_mechanism: AttentionImplementation,
4120 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4121 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4122
4123 Ok(Box::new(models::qwen3_moe::Model::new(
4124 &cfg,
4125 vb,
4126 self.is_gptx(config)?,
4127 normal_loading_metadata,
4128 attention_mechanism,
4129 )?))
4130 }
4131 fn load_xlora(
4132 &self,
4133 _config: &str,
4134 _vb: ShardedVarBuilder,
4135 _lora_config: &[((String, String), LoraConfig)],
4136 _xlora_config: Option<XLoraConfig>,
4137 _xlora_ordering: Ordering,
4138 _normal_loading_metadata: NormalLoadingMetadata,
4139 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4140 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4141 todo!()
4142 }
4143 fn is_gptx(&self, _: &str) -> Result<bool> {
4144 Ok(true)
4145 }
4146 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4147 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4148
4149 Ok(Box::new(cfg))
4150 }
4151}
4152
4153impl IsqModelLoader for Qwen3MoELoader {
4154 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4155 Ok(vec![
4156 Regex::new(r"lm_head\.(weight|bias)$")?,
4157 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4159 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4160 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4161 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4162 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4164 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4165 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4166 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4168 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4169 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4170 ])
4171 }
4172 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4173 self.isq_layer_regexes(config)
4174 }
4175 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4176 self.isq_layer_regexes_moqe(config)
4177 }
4178}
4179
4180impl DeviceMappedModelLoader for Qwen3MoELoader {
4181 fn mapped_max_act_size_elems(
4182 &self,
4183 config: &str,
4184 params: &AutoDeviceMapParams,
4185 ) -> Result<usize> {
4186 let AutoDeviceMapParams::Text {
4187 max_seq_len,
4188 max_batch_size,
4189 } = params
4190 else {
4191 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4192 };
4193
4194 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4195
4196 Ok(
4197 max_batch_size
4198 * cfg.num_attention_heads
4199 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4200 )
4201 }
4202 fn non_mapped_max_act_size_elems(
4203 &self,
4204 _config: &str,
4205 _params: &AutoDeviceMapParams,
4206 ) -> Result<usize> {
4207 Ok(0)
4208 }
4209
4210 fn non_mapped_size_in_bytes(
4211 &self,
4212 config: &str,
4213 dtype: DType,
4214 weight_pack_factor: usize,
4215 _matformer_config: Option<&MatformerSliceConfig>,
4216 ) -> Result<usize> {
4217 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4218 let elems = {
4219 let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4220 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4222 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4223 } else {
4224 0
4225 };
4226 let norm = cfg.hidden_size;
4227 embed_tokens + lm_head + norm
4228 };
4229 Ok(elems * dtype.size_in_bytes())
4230 }
4231
4232 fn layer_sizes_in_bytes(
4233 &self,
4234 config: &str,
4235 dtype: DType,
4236 weight_pack_factor: usize,
4237 _matformer_config: Option<&MatformerSliceConfig>,
4238 ) -> Result<Vec<usize>> {
4239 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4240
4241 let mut layer_sizes_in_bytes = Vec::new();
4242 for layer_idx in 0..cfg.num_hidden_layers {
4243 let input_layernorm = cfg.hidden_size;
4244 let post_attention_layernorm = cfg.hidden_size;
4245
4246 let size_in = cfg.hidden_size;
4247 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4248 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4249 let q_proj = size_in * size_q / weight_pack_factor;
4250 let k_proj = size_in * size_kv / weight_pack_factor;
4251 let v_proj = size_in * size_kv / weight_pack_factor;
4252 let o_proj = size_q * size_in / weight_pack_factor;
4253
4254 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4255 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4256 {
4257 let gate_size = cfg.hidden_size * cfg.num_experts;
4258 let expert_size = {
4259 let h_size = cfg.hidden_size;
4260 let i_size = cfg.moe_intermediate_size;
4261 let gate_proj = h_size * i_size / weight_pack_factor;
4262 let up_proj = h_size * i_size / weight_pack_factor;
4263 let down_proj = i_size * h_size / weight_pack_factor;
4264 gate_proj + up_proj + down_proj
4265 };
4266 expert_size * cfg.num_experts + gate_size
4267 } else {
4268 let h_size = cfg.hidden_size;
4269 let i_size = cfg.intermediate_size;
4270 let gate_proj = h_size * i_size / weight_pack_factor;
4271 let up_proj = h_size * i_size / weight_pack_factor;
4272 let down_proj = i_size * h_size / weight_pack_factor;
4273 gate_proj + up_proj + down_proj
4274 };
4275
4276 let q_norm = cfg.head_dim();
4277 let k_norm = cfg.head_dim();
4278
4279 let size_elems = input_layernorm
4280 + post_attention_layernorm
4281 + q_proj
4282 + k_proj
4283 + v_proj
4284 + o_proj
4285 + mlp_size
4286 + q_norm
4287 + k_norm;
4288
4289 let size_in_bytes = size_elems * dtype.size_in_bytes();
4290 layer_sizes_in_bytes.push(size_in_bytes);
4291 }
4292
4293 Ok(layer_sizes_in_bytes)
4294 }
4295
4296 fn num_layers(&self, config: &str) -> Result<usize> {
4297 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4298 Ok(cfg.num_hidden_layers)
4299 }
4300
4301 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4302 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4303
4304 let cfg = ModelConfigMetadata {
4305 max_seq_len: cfg.max_position_embeddings,
4306 num_layers: cfg.num_hidden_layers,
4307 hidden_size: cfg.hidden_size,
4308 num_kv_heads: cfg.num_key_value_heads,
4309 num_attn_heads: cfg.num_attention_heads,
4310 sliding_window: cfg.sliding_window,
4311 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4312 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4313 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4314 };
4315
4316 Ok(Box::new(cfg))
4317 }
4318}
4319
4320pub struct SmolLm3Loader;
4326
4327impl NormalModelLoader for SmolLm3Loader {
4328 fn load(
4329 &self,
4330 config: &str,
4331 vb: ShardedVarBuilder,
4332 normal_loading_metadata: NormalLoadingMetadata,
4333 attention_mechanism: AttentionImplementation,
4334 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4335 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4336
4337 Ok(Box::new(models::smollm3::SmolLm3::new(
4338 &cfg,
4339 vb,
4340 self.is_gptx(config)?,
4341 normal_loading_metadata,
4342 attention_mechanism,
4343 )?))
4344 }
4345 fn load_xlora(
4346 &self,
4347 _config: &str,
4348 _vb: ShardedVarBuilder,
4349 _lora_config: &[((String, String), LoraConfig)],
4350 _xlora_config: Option<XLoraConfig>,
4351 _xlora_ordering: Ordering,
4352 _normal_loading_metadata: NormalLoadingMetadata,
4353 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4354 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4355 todo!()
4356 }
4357 fn is_gptx(&self, _: &str) -> Result<bool> {
4358 Ok(true)
4359 }
4360 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4361 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4362 Ok(Box::new(cfg))
4363 }
4364}
4365
4366impl IsqModelLoader for SmolLm3Loader {
4367 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4368 Ok(vec![
4369 Regex::new(r"lm_head\.(weight|bias)$")?,
4370 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4372 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4373 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4374 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4375 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4377 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4378 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4379 ])
4380 }
4381 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4382 self.isq_layer_regexes(config)
4383 }
4384}
4385
4386impl DeviceMappedModelLoader for SmolLm3Loader {
4387 fn mapped_max_act_size_elems(
4388 &self,
4389 config: &str,
4390 params: &AutoDeviceMapParams,
4391 ) -> Result<usize> {
4392 let AutoDeviceMapParams::Text {
4393 max_seq_len,
4394 max_batch_size,
4395 } = params
4396 else {
4397 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4398 };
4399
4400 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4401
4402 Ok(
4403 max_batch_size
4404 * cfg.num_attention_heads
4405 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4406 )
4407 }
4408 fn non_mapped_max_act_size_elems(
4409 &self,
4410 _config: &str,
4411 _params: &AutoDeviceMapParams,
4412 ) -> Result<usize> {
4413 Ok(0)
4414 }
4415
4416 fn non_mapped_size_in_bytes(
4417 &self,
4418 config: &str,
4419 dtype: DType,
4420 weight_pack_factor: usize,
4421 _matformer_config: Option<&MatformerSliceConfig>,
4422 ) -> Result<usize> {
4423 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4424
4425 let elems = {
4426 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4427 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4429 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4430 } else {
4431 0
4432 };
4433 let norm = cfg.hidden_size;
4434 embed_tokens + lm_head + norm
4435 };
4436 Ok(elems * dtype.size_in_bytes())
4437 }
4438
4439 fn layer_sizes_in_bytes(
4440 &self,
4441 config: &str,
4442 dtype: DType,
4443 weight_pack_factor: usize,
4444 _matformer_config: Option<&MatformerSliceConfig>,
4445 ) -> Result<Vec<usize>> {
4446 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4447
4448 let per_layer_elems = {
4449 let input_layernorm = cfg.hidden_size;
4450 let post_attention_layernorm = cfg.hidden_size;
4451
4452 let size_in = cfg.hidden_size;
4453 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4454 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4455 let q_proj = size_in * size_q / weight_pack_factor;
4456 let k_proj = size_in * size_kv / weight_pack_factor;
4457 let v_proj = size_in * size_kv / weight_pack_factor;
4458 let o_proj = size_q * size_in / weight_pack_factor;
4459
4460 let h_size = cfg.hidden_size;
4461 let i_size = cfg.intermediate_size;
4462 let gate_proj = h_size * i_size / weight_pack_factor;
4463 let up_proj = h_size * i_size / weight_pack_factor;
4464 let down_proj = i_size * h_size / weight_pack_factor;
4465
4466 input_layernorm
4467 + post_attention_layernorm
4468 + q_proj
4469 + k_proj
4470 + v_proj
4471 + o_proj
4472 + gate_proj
4473 + up_proj
4474 + down_proj
4475 };
4476 Ok(vec![
4477 per_layer_elems * dtype.size_in_bytes();
4478 cfg.num_hidden_layers
4479 ])
4480 }
4481
4482 fn num_layers(&self, config: &str) -> Result<usize> {
4483 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4484
4485 Ok(cfg.num_hidden_layers)
4486 }
4487 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4488 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4489
4490 let cfg = ModelConfigMetadata {
4491 max_seq_len: cfg.max_position_embeddings,
4492 num_layers: cfg.num_hidden_layers,
4493 hidden_size: cfg.hidden_size,
4494 num_kv_heads: cfg.num_key_value_heads,
4495 num_attn_heads: cfg.num_attention_heads,
4496 sliding_window: None,
4497 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4498 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4499 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4500 };
4501
4502 Ok(Box::new(cfg))
4503 }
4504}
4505
4506pub struct GraniteMoeHybridLoader;
4512
4513impl NormalModelLoader for GraniteMoeHybridLoader {
4514 fn load(
4515 &self,
4516 config: &str,
4517 vb: ShardedVarBuilder,
4518 normal_loading_metadata: NormalLoadingMetadata,
4519 attention_mechanism: AttentionImplementation,
4520 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4521 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4522
4523 Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4524 &cfg,
4525 vb,
4526 self.is_gptx(config)?,
4527 normal_loading_metadata,
4528 attention_mechanism,
4529 )?))
4530 }
4531 fn load_xlora(
4532 &self,
4533 _config: &str,
4534 _vb: ShardedVarBuilder,
4535 _lora_config: &[((String, String), LoraConfig)],
4536 _xlora_config: Option<XLoraConfig>,
4537 _xlora_ordering: Ordering,
4538 _normal_loading_metadata: NormalLoadingMetadata,
4539 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4540 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4541 anyhow::bail!("GraniteMoeHybrid does not support X-LoRA")
4542 }
4543 fn is_gptx(&self, _: &str) -> Result<bool> {
4544 Ok(true)
4545 }
4546 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4547 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4548 Ok(Box::new(cfg))
4549 }
4550 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4551 Ok(true)
4552 }
4553}
4554
4555impl IsqModelLoader for GraniteMoeHybridLoader {
4556 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4557 Ok(vec![
4558 Regex::new(r"lm_head\.(weight|bias)$")?,
4559 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4561 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4562 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4563 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4564 Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4566 Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4567 ])
4568 }
4569 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4570 self.isq_layer_regexes(config)
4571 }
4572}
4573
4574impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4575 fn mapped_max_act_size_elems(
4576 &self,
4577 config: &str,
4578 params: &AutoDeviceMapParams,
4579 ) -> Result<usize> {
4580 let AutoDeviceMapParams::Text {
4581 max_seq_len,
4582 max_batch_size,
4583 } = params
4584 else {
4585 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4586 };
4587
4588 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4589
4590 Ok(
4591 max_batch_size
4592 * cfg.num_attention_heads
4593 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4594 )
4595 }
4596 fn non_mapped_max_act_size_elems(
4597 &self,
4598 _config: &str,
4599 _params: &AutoDeviceMapParams,
4600 ) -> Result<usize> {
4601 Ok(0)
4602 }
4603
4604 fn non_mapped_size_in_bytes(
4605 &self,
4606 config: &str,
4607 dtype: DType,
4608 weight_pack_factor: usize,
4609 _matformer_config: Option<&MatformerSliceConfig>,
4610 ) -> Result<usize> {
4611 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4612
4613 let elems = {
4614 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4615 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4617 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4618 } else {
4619 0
4620 };
4621 let norm = cfg.hidden_size;
4622 embed_tokens + lm_head + norm
4623 };
4624 Ok(elems * dtype.size_in_bytes())
4625 }
4626
4627 fn layer_sizes_in_bytes(
4628 &self,
4629 config: &str,
4630 dtype: DType,
4631 weight_pack_factor: usize,
4632 _matformer_config: Option<&MatformerSliceConfig>,
4633 ) -> Result<Vec<usize>> {
4634 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4635
4636 let per_layer_elems = {
4637 let input_layernorm = cfg.hidden_size;
4638 let post_attention_layernorm = cfg.hidden_size;
4639
4640 let size_in = cfg.hidden_size;
4641 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4642 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4643 let q_proj = size_in * size_q / weight_pack_factor;
4644 let k_proj = size_in * size_kv / weight_pack_factor;
4645 let v_proj = size_in * size_kv / weight_pack_factor;
4646 let o_proj = size_q * size_in / weight_pack_factor;
4647
4648 let h_size = cfg.hidden_size;
4649 let shared_i_size = cfg.shared_intermediate_size();
4650 let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4652 let output_linear = shared_i_size * h_size / weight_pack_factor;
4653
4654 input_layernorm
4655 + post_attention_layernorm
4656 + q_proj
4657 + k_proj
4658 + v_proj
4659 + o_proj
4660 + input_linear
4661 + output_linear
4662 };
4663 Ok(vec![
4664 per_layer_elems * dtype.size_in_bytes();
4665 cfg.num_hidden_layers
4666 ])
4667 }
4668
4669 fn num_layers(&self, config: &str) -> Result<usize> {
4670 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4671
4672 Ok(cfg.num_hidden_layers)
4673 }
4674 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4675 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4676
4677 let cfg = ModelConfigMetadata {
4678 max_seq_len: cfg.max_position_embeddings,
4679 num_layers: cfg.num_hidden_layers,
4680 hidden_size: cfg.hidden_size,
4681 num_kv_heads: cfg.num_key_value_heads(),
4682 num_attn_heads: cfg.num_attention_heads,
4683 sliding_window: None,
4684 k_head_dim: cfg.head_dim(),
4685 v_head_dim: cfg.head_dim(),
4686 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4687 };
4688
4689 Ok(Box::new(cfg))
4690 }
4691}
4692
4693pub struct GptOssLoader;
4699
4700impl NormalModelLoader for GptOssLoader {
4701 fn load(
4702 &self,
4703 config: &str,
4704 vb: ShardedVarBuilder,
4705 normal_loading_metadata: NormalLoadingMetadata,
4706 attention_mechanism: AttentionImplementation,
4707 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4708 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4709
4710 Ok(Box::new(models::gpt_oss::Model::new(
4711 &cfg,
4712 vb,
4713 self.is_gptx(config)?,
4714 normal_loading_metadata,
4715 attention_mechanism,
4716 )?))
4717 }
4718 fn load_xlora(
4719 &self,
4720 _config: &str,
4721 _vb: ShardedVarBuilder,
4722 _lora_config: &[((String, String), LoraConfig)],
4723 _xlora_config: Option<XLoraConfig>,
4724 _xlora_ordering: Ordering,
4725 _normal_loading_metadata: NormalLoadingMetadata,
4726 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4727 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4728 anyhow::bail!("GPT-OSS does not support X-LoRA")
4729 }
4730 fn is_gptx(&self, _: &str) -> Result<bool> {
4731 Ok(true)
4732 }
4733 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4734 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4735 Ok(Box::new(cfg))
4736 }
4737}
4738
4739impl IsqModelLoader for GptOssLoader {
4740 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4741 Ok(vec![
4743 Regex::new(r"lm_head\.(weight|bias)$")?,
4744 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4746 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4747 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4748 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4749 ])
4750 }
4751 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4752 self.isq_layer_regexes(config)
4753 }
4754}
4755
4756impl DeviceMappedModelLoader for GptOssLoader {
4757 fn mapped_max_act_size_elems(
4758 &self,
4759 config: &str,
4760 params: &AutoDeviceMapParams,
4761 ) -> Result<usize> {
4762 let AutoDeviceMapParams::Text {
4763 max_seq_len,
4764 max_batch_size,
4765 } = params
4766 else {
4767 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4768 };
4769
4770 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4771
4772 Ok(
4773 max_batch_size
4774 * cfg.num_attention_heads
4775 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4776 )
4777 }
4778 fn non_mapped_max_act_size_elems(
4779 &self,
4780 _config: &str,
4781 _params: &AutoDeviceMapParams,
4782 ) -> Result<usize> {
4783 Ok(0)
4784 }
4785
4786 fn non_mapped_size_in_bytes(
4787 &self,
4788 config: &str,
4789 dtype: DType,
4790 weight_pack_factor: usize,
4791 _matformer_config: Option<&MatformerSliceConfig>,
4792 ) -> Result<usize> {
4793 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4794
4795 let elems = {
4796 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4797 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4798 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4799 } else {
4800 0
4801 };
4802 let norm = cfg.hidden_size;
4803 embed_tokens + lm_head + norm
4804 };
4805 Ok(elems * dtype.size_in_bytes())
4806 }
4807
4808 fn layer_sizes_in_bytes(
4809 &self,
4810 config: &str,
4811 dtype: DType,
4812 weight_pack_factor: usize,
4813 _matformer_config: Option<&MatformerSliceConfig>,
4814 ) -> Result<Vec<usize>> {
4815 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4816
4817 let per_layer_elems = {
4818 let input_layernorm = cfg.hidden_size;
4819 let post_attention_layernorm = cfg.hidden_size;
4820
4821 let size_in = cfg.hidden_size;
4822 let head_dim = cfg.head_dim();
4823 let size_q = head_dim * cfg.num_attention_heads;
4824 let size_kv = head_dim * cfg.num_key_value_heads;
4825 let q_proj =
4826 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4827 let k_proj =
4828 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4829 let v_proj =
4830 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4831 let o_proj =
4832 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4833
4834 let mxfp4_pack = 2;
4839 let gate_up_proj_size =
4840 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4841 let down_proj_size =
4842 cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4843 let gate_up_scales =
4845 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4846 let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4847 let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4849 let down_bias = cfg.num_local_experts * cfg.hidden_size;
4850 let router = cfg.hidden_size * cfg.num_local_experts;
4852 let sinks = cfg.num_attention_heads;
4854
4855 input_layernorm
4856 + post_attention_layernorm
4857 + q_proj
4858 + k_proj
4859 + v_proj
4860 + o_proj
4861 + gate_up_proj_size
4862 + down_proj_size
4863 + gate_up_scales
4864 + down_scales
4865 + gate_up_bias
4866 + down_bias
4867 + router
4868 + sinks
4869 };
4870 Ok(vec![
4871 per_layer_elems * dtype.size_in_bytes();
4872 cfg.num_hidden_layers
4873 ])
4874 }
4875
4876 fn num_layers(&self, config: &str) -> Result<usize> {
4877 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4878
4879 Ok(cfg.num_hidden_layers)
4880 }
4881 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4882 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4883
4884 let head_dim = cfg.head_dim();
4885 let cfg = ModelConfigMetadata {
4886 max_seq_len: cfg.max_position_embeddings,
4887 num_layers: cfg.num_hidden_layers,
4888 hidden_size: cfg.hidden_size,
4889 num_kv_heads: cfg.num_key_value_heads,
4890 num_attn_heads: cfg.num_attention_heads,
4891 sliding_window: cfg.sliding_window,
4892 k_head_dim: head_dim,
4893 v_head_dim: head_dim,
4894 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4895 };
4896
4897 Ok(Box::new(cfg))
4898 }
4899}
4900
4901pub struct Qwen3NextLoader;
4907
4908impl NormalModelLoader for Qwen3NextLoader {
4909 fn load(
4910 &self,
4911 config: &str,
4912 vb: ShardedVarBuilder,
4913 normal_loading_metadata: NormalLoadingMetadata,
4914 attention_mechanism: AttentionImplementation,
4915 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4916 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4917
4918 Ok(Box::new(models::qwen3_next::Model::new(
4919 &cfg,
4920 vb,
4921 self.is_gptx(config)?,
4922 normal_loading_metadata,
4923 attention_mechanism,
4924 )?))
4925 }
4926 fn load_xlora(
4927 &self,
4928 _config: &str,
4929 _vb: ShardedVarBuilder,
4930 _lora_config: &[((String, String), LoraConfig)],
4931 _xlora_config: Option<XLoraConfig>,
4932 _xlora_ordering: Ordering,
4933 _normal_loading_metadata: NormalLoadingMetadata,
4934 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4935 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4936 anyhow::bail!("Qwen3Next does not support X-LoRA")
4937 }
4938 fn is_gptx(&self, _: &str) -> Result<bool> {
4939 Ok(true)
4940 }
4941 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4942 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4943 Ok(Box::new(cfg))
4944 }
4945 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4946 Ok(true)
4947 }
4948}
4949
4950impl IsqModelLoader for Qwen3NextLoader {
4951 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4952 Ok(vec![
4953 Regex::new(r"lm_head\.(weight|bias)$")?,
4954 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4955 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4956 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4957 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4958 Regex::new(r"layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$")?,
4959 Regex::new(
4960 r"layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4961 )?,
4962 Regex::new(
4963 r"layers\.(\d+)\.mlp\.shared_expert\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4964 )?,
4965 ])
4966 }
4967 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4968 self.isq_layer_regexes(config)
4969 }
4970}
4971
4972impl DeviceMappedModelLoader for Qwen3NextLoader {
4973 fn mapped_max_act_size_elems(
4974 &self,
4975 config: &str,
4976 params: &AutoDeviceMapParams,
4977 ) -> Result<usize> {
4978 let AutoDeviceMapParams::Text {
4979 max_seq_len,
4980 max_batch_size,
4981 } = params
4982 else {
4983 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4984 };
4985
4986 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4987
4988 Ok(
4989 max_batch_size
4990 * cfg.num_attention_heads
4991 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4992 )
4993 }
4994 fn non_mapped_max_act_size_elems(
4995 &self,
4996 _config: &str,
4997 _params: &AutoDeviceMapParams,
4998 ) -> Result<usize> {
4999 Ok(0)
5000 }
5001
5002 fn non_mapped_size_in_bytes(
5003 &self,
5004 config: &str,
5005 dtype: DType,
5006 weight_pack_factor: usize,
5007 _matformer_config: Option<&MatformerSliceConfig>,
5008 ) -> Result<usize> {
5009 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5010
5011 let elems = {
5012 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5013 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
5014 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5015 } else {
5016 0
5017 };
5018 let norm = cfg.hidden_size;
5019 embed_tokens + lm_head + norm
5020 };
5021 Ok(elems * dtype.size_in_bytes())
5022 }
5023
5024 fn layer_sizes_in_bytes(
5025 &self,
5026 config: &str,
5027 dtype: DType,
5028 weight_pack_factor: usize,
5029 _matformer_config: Option<&MatformerSliceConfig>,
5030 ) -> Result<Vec<usize>> {
5031 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5032 let layer_types = cfg.layer_types();
5033 let mut layer_sizes = Vec::with_capacity(cfg.num_hidden_layers);
5034
5035 for layer_type in &layer_types {
5036 let input_layernorm = cfg.hidden_size;
5037 let post_attention_layernorm = cfg.hidden_size;
5038
5039 let attn_elems = match layer_type {
5040 crate::models::qwen3_next::LayerType::FullAttention => {
5041 let hidden = cfg.hidden_size;
5042 let q_dim = cfg.head_dim * cfg.num_attention_heads;
5043 let kv_dim = cfg.head_dim * cfg.num_key_value_heads;
5044 let q_proj = hidden * q_dim * 2 / weight_pack_factor;
5045 let k_proj = hidden * kv_dim / weight_pack_factor;
5046 let v_proj = hidden * kv_dim / weight_pack_factor;
5047 let o_proj = q_dim * hidden / weight_pack_factor;
5048 let q_norm = cfg.head_dim;
5049 let k_norm = cfg.head_dim;
5050 q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
5051 }
5052 crate::models::qwen3_next::LayerType::LinearAttention => {
5053 let hidden = cfg.hidden_size;
5054 let key_dim = cfg.linear_key_dim();
5055 let value_dim = cfg.linear_value_dim();
5056 let conv_dim = cfg.linear_conv_dim();
5057 let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
5059 let in_proj_ba = hidden * (cfg.linear_num_value_heads * 2);
5061 let out_proj = value_dim * hidden / weight_pack_factor;
5062 let conv1d = conv_dim * cfg.linear_conv_kernel_dim;
5063 let dt_bias = cfg.linear_num_value_heads;
5064 let a_log = cfg.linear_num_value_heads;
5065 let norm = cfg.linear_value_head_dim;
5066 in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
5067 }
5068 };
5069
5070 let moe_gate = cfg.hidden_size * cfg.num_experts;
5071 let shared_expert =
5072 3 * cfg.hidden_size * cfg.shared_expert_intermediate_size / weight_pack_factor;
5073 let routed_experts = cfg.num_experts * 3 * cfg.hidden_size * cfg.moe_intermediate_size
5074 / weight_pack_factor;
5075
5076 let per_layer_elems = input_layernorm
5077 + post_attention_layernorm
5078 + attn_elems
5079 + moe_gate
5080 + shared_expert
5081 + routed_experts;
5082
5083 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5084 }
5085
5086 Ok(layer_sizes)
5087 }
5088
5089 fn num_layers(&self, config: &str) -> Result<usize> {
5090 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5091 Ok(cfg.num_hidden_layers)
5092 }
5093 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5094 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5095
5096 let cfg = ModelConfigMetadata {
5097 max_seq_len: cfg.max_position_embeddings,
5098 num_layers: cfg.num_hidden_layers,
5099 hidden_size: cfg.hidden_size,
5100 num_kv_heads: cfg.num_key_value_heads,
5101 num_attn_heads: cfg.num_attention_heads,
5102 sliding_window: None,
5103 k_head_dim: cfg.head_dim,
5104 v_head_dim: cfg.head_dim,
5105 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5106 };
5107
5108 Ok(Box::new(cfg))
5109 }
5110}