1use std::{
2 collections::HashMap,
3 fmt::{Debug, Display},
4 str::FromStr,
5 sync::Arc,
6};
7
8use crate::{attention::ATTENTION_CHUNK_SIZE, matformer::MatformerSliceConfig};
9
10use crate::{
11 amoe::AnyMoeBaseModelMixin,
12 device_map::DeviceMapper,
13 lora::{LoraConfig, Ordering},
14 paged_attention::{AttentionImplementation, ModelConfigLike, ModelConfigMetadata},
15 pipeline::{
16 isq::IsqModelLoader,
17 text_models_inputs_processor::{FlashParams, PagedAttentionInputMetadata},
18 EitherCache, IsqModel,
19 },
20 utils::varbuilder_utils::DeviceForLoadTensor,
21 xlora_models::NonGranularState,
22};
23use anyhow::Result;
24use candle_core::{DType, Device, Tensor};
25use mistralrs_quant::log::once_log_info;
26
27use indicatif::MultiProgress;
28use mistralrs_quant::ShardedVarBuilder;
29#[cfg(feature = "pyo3_macros")]
30use pyo3::pyclass;
31
32use regex::Regex;
33use serde::Deserialize;
34
35use crate::{
36 models,
37 xlora_models::{self, XLoraConfig},
38};
39
40use super::{AutoDeviceMapParams, DeviceMappedModelLoader};
41
42pub trait NormalModel: IsqModel + AnyMoeBaseModelMixin {
43 #[allow(clippy::too_many_arguments)]
44 fn forward(
45 &self,
46 input_ids: &Tensor,
47 seqlen_offsets: &[usize],
48 context_lens: Vec<(usize, usize)>,
49 position_ids: Vec<usize>,
50 metadata: Option<(Vec<(Tensor, Tensor)>, &PagedAttentionInputMetadata)>,
51 flash_params: &FlashParams,
52 ) -> candle_core::Result<Tensor>;
53 #[allow(clippy::too_many_arguments)]
54 fn xlora_forward(
55 &self,
56 input_ids: &Tensor,
57 input_ids_full: &Tensor,
58 seqlen_offsets: &[usize],
59 seqlen_offsets_full: &[usize],
60 no_kv_cache: bool,
61 non_granular_state: &Option<NonGranularState>,
62 context_lens: Vec<(usize, usize)>,
63 position_ids: Vec<usize>,
64 flash_params: &FlashParams,
65 flash_params_full: &FlashParams,
66 ) -> candle_core::Result<Tensor>;
67 fn is_xlora(&self) -> bool;
68 fn device(&self) -> &Device;
69 fn cache(&self) -> &EitherCache;
70 fn cache_mut(&mut self) -> &mut EitherCache;
71 fn max_seq_len(&self) -> usize;
72 fn config(&self) -> &ModelConfigMetadata;
73 fn model_config(&self) -> Arc<dyn ModelConfigLike + Send + Sync> {
74 Arc::new(self.config().clone())
75 }
76}
77
78pub struct NormalLoadingMetadata {
80 pub mapper: Box<dyn DeviceMapper + Send + Sync>,
82 pub loading_isq: bool,
84 pub real_device: Device,
86 pub multi_progress: Arc<MultiProgress>,
88 pub matformer_slicing_config: Option<MatformerSliceConfig>,
90}
91
92pub trait NormalModelLoader: IsqModelLoader + Send + Sync + DeviceMappedModelLoader {
93 fn load(
94 &self,
95 config: &str,
96 vb: ShardedVarBuilder,
97 normal_loading_metadata: NormalLoadingMetadata,
98 attention_mechanism: AttentionImplementation,
99 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
100 #[allow(clippy::too_many_arguments)]
101 fn load_xlora(
102 &self,
103 config: &str,
104 vb: ShardedVarBuilder,
105 lora_config: &[((String, String), LoraConfig)],
106 xlora_config: Option<XLoraConfig>,
107 xlora_ordering: Ordering,
108 normal_loading_metadata: NormalLoadingMetadata,
109 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
110 ) -> Result<Box<dyn NormalModel + Send + Sync>>;
111 fn is_gptx(&self, config: &str) -> Result<bool>;
112 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
113 Ok(true)
114 }
115 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>>;
116 fn get_device_for_tensor(
117 &self,
118 config: &str,
119 _mapper: &dyn DeviceMapper,
120 loading_isq: bool,
121 ) -> Result<Arc<dyn Fn(String) -> DeviceForLoadTensor + Send + Sync + 'static>> {
122 if loading_isq {
123 Ok(Arc::new(|_| DeviceForLoadTensor::Base))
124 } else {
125 let re = Regex::new(r"\.layers\.(\d+)\.").unwrap();
126 let num_layers = self.model_config(config)?.num_layers();
127 let closure = move |name: String| {
128 if let Some(captures) = re.captures(&name) {
129 captures
130 .get(1)
131 .and_then(|m| m.as_str().parse::<usize>().ok())
132 .map(|l| l.min(num_layers))
133 .map(DeviceForLoadTensor::Idx)
134 .unwrap_or(DeviceForLoadTensor::Base)
135 } else {
136 DeviceForLoadTensor::Base
137 }
138 };
139
140 Ok(Arc::new(closure))
141 }
142 }
143}
144
145#[cfg_attr(feature = "pyo3_macros", pyclass(eq, eq_int))]
146#[derive(Clone, Debug, Deserialize, serde::Serialize, PartialEq)]
147pub enum NormalLoaderType {
149 #[serde(rename = "mistral")]
150 Mistral,
151 #[serde(rename = "gemma")]
152 Gemma,
153 #[serde(rename = "mixtral")]
154 Mixtral,
155 #[serde(rename = "llama")]
156 Llama,
157 #[serde(rename = "phi2")]
158 Phi2,
159 #[serde(rename = "phi3")]
160 Phi3,
161 #[serde(rename = "qwen2")]
162 Qwen2,
163 #[serde(rename = "gemma2")]
164 Gemma2,
165 #[serde(rename = "starcoder2")]
166 Starcoder2,
167 #[serde(rename = "phi3.5moe")]
168 Phi3_5MoE,
169 #[serde(rename = "deepseekv2")]
170 DeepSeekV2,
171 #[serde(rename = "deepseekv3")]
172 DeepSeekV3,
173 #[serde(rename = "qwen3")]
174 Qwen3,
175 #[serde(rename = "glm4")]
176 GLM4,
177 #[serde(rename = "glm4moelite")]
178 GLM4MoeLite,
179 #[serde(rename = "glm4moe")]
180 GLM4Moe,
181 #[serde(rename = "qwen3moe")]
182 Qwen3Moe,
183 #[serde(rename = "smollm3")]
184 SmolLm3,
185 #[serde(rename = "granitemoehybrid")]
186 GraniteMoeHybrid,
187 #[serde(rename = "gpt_oss")]
188 GptOss,
189 #[serde(rename = "qwen3next")]
190 Qwen3Next,
191}
192
193impl NormalLoaderType {
195 pub fn from_causal_lm_name(name: &str) -> Result<Self> {
196 match name {
197 "MistralForCausalLM" => Ok(Self::Mistral),
198 "MixtralForCausalLM" => Ok(Self::Mixtral),
199 "GemmaForCausalLM" => Ok(Self::Gemma),
200 "Gemma2ForCausalLM" => Ok(Self::Gemma2),
201 "PhiForCausalLM" => Ok(Self::Phi2),
202 "Phi3ForCausalLM" => Ok(Self::Phi3),
203 "LlamaForCausalLM" => Ok(Self::Llama),
204 "Qwen2ForCausalLM" => Ok(Self::Qwen2),
205 "Starcoder2ForCausalLM" => Ok(Self::Starcoder2),
206 "PhiMoEForCausalLM" => Ok(Self::Phi3_5MoE),
207 "DeepseekV2ForCausalLM" => Ok(Self::DeepSeekV2),
208 "DeepseekV3ForCausalLM" => Ok(Self::DeepSeekV3),
209 "Qwen3ForCausalLM" => Ok(Self::Qwen3),
210 "Glm4ForCausalLM" => Ok(Self::GLM4),
211 "Glm4MoeLiteForCausalLM" => Ok(Self::GLM4MoeLite),
212 "Glm4MoeForCausalLM" => Ok(Self::GLM4Moe),
213 "Qwen3MoeForCausalLM" => Ok(Self::Qwen3Moe),
214 "SmolLM3ForCausalLM" => Ok(Self::SmolLm3),
215 "GraniteMoeHybridForCausalLM" => Ok(Self::GraniteMoeHybrid),
216 "GptOssForCausalLM" => Ok(Self::GptOss),
217 "Qwen3NextForCausalLM" => Ok(Self::Qwen3Next),
218 other => anyhow::bail!(
219 "Unsupported Hugging Face Transformers -CausalLM model class `{other}`. Please raise an issue."
220 ),
221 }
222 }
223}
224
225impl FromStr for NormalLoaderType {
226 type Err = String;
227 fn from_str(s: &str) -> Result<Self, Self::Err> {
228 match s {
229 "mistral" => Ok(Self::Mistral),
230 "gemma" => Ok(Self::Gemma),
231 "mixtral" => Ok(Self::Mixtral),
232 "llama" => Ok(Self::Llama),
233 "phi2" => Ok(Self::Phi2),
234 "phi3" => Ok(Self::Phi3),
235 "qwen2" => Ok(Self::Qwen2),
236 "gemma2" => Ok(Self::Gemma2),
237 "starcoder2" => Ok(Self::Starcoder2),
238 "phi3.5moe" => Ok(Self::Phi3_5MoE),
239 "deepseekv2" => Ok(Self::DeepSeekV2),
240 "deepseekv3" => Ok(Self::DeepSeekV3),
241 "qwen3" => Ok(Self::Qwen3),
242 "glm4" => Ok(Self::GLM4),
243 "glm4moelite" => Ok(Self::GLM4MoeLite),
244 "glm4moe" => Ok(Self::GLM4Moe),
245 "qwen3moe" => Ok(Self::Qwen3Moe),
246 "smollm3" => Ok(Self::SmolLm3),
247 "granitemoehybrid" => Ok(Self::GraniteMoeHybrid),
248 "gpt_oss" => Ok(Self::GptOss),
249 "qwen3next" => Ok(Self::Qwen3Next),
250 a => Err(format!("Unknown architecture `{a}`. Possible architectures: `mistral`, `gemma`, `mixtral`, `llama`, `phi2`, `phi3`, `qwen2`, `gemma2`, `starcoder2`, `phi3.5moe`, `deepseekv2`, `deepseekv3`, `qwen3`, `glm4`, `glm4moelite`, `glm4moe`, `qwen3moe`, `smollm3`, `granitemoehybrid`, `gpt_oss`, `qwen3next`.")),
251 }
252 }
253}
254
255impl Display for NormalLoaderType {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 match self {
258 Self::Gemma => write!(f, "gemma"),
259 Self::Gemma2 => write!(f, "gemma2"),
260 Self::Llama => write!(f, "llama"),
261 Self::Mistral => write!(f, "mistral"),
262 Self::Mixtral => write!(f, "mixtral"),
263 Self::Phi2 => write!(f, "phi2"),
264 Self::Phi3 => write!(f, "phi3"),
265 Self::Phi3_5MoE => write!(f, "phi3.5moe"),
266 Self::Qwen2 => write!(f, "qwen2"),
267 Self::Starcoder2 => write!(f, "starcoder2"),
268 Self::DeepSeekV2 => write!(f, "deepseekv2"),
269 Self::DeepSeekV3 => write!(f, "deepseekv3"),
270 Self::Qwen3 => write!(f, "qwen3"),
271 Self::GLM4 => write!(f, "glm4"),
272 Self::GLM4MoeLite => write!(f, "glm4moelite"),
273 Self::GLM4Moe => write!(f, "glm4moe"),
274 Self::Qwen3Moe => write!(f, "qwen3moe"),
275 Self::SmolLm3 => write!(f, "smollm3"),
276 Self::GraniteMoeHybrid => write!(f, "granitemoehybrid"),
277 Self::GptOss => write!(f, "gpt_oss"),
278 Self::Qwen3Next => write!(f, "qwen3next"),
279 }
280 }
281}
282
283macro_rules! bias_if {
284 ($cond:expr, $size:expr) => {
285 if $cond {
286 $size
287 } else {
288 0
289 }
290 };
291}
292
293pub struct AutoNormalLoader;
295
296#[derive(Deserialize)]
297struct AutoNormalLoaderConfig {
298 architectures: Vec<String>,
299}
300
301impl AutoNormalLoader {
302 fn get_loader(config: &str) -> Result<Box<dyn NormalModelLoader>> {
303 let auto_cfg: AutoNormalLoaderConfig = serde_json::from_str(config)?;
304 if auto_cfg.architectures.len() != 1 {
305 anyhow::bail!("Expected to have one name for `architectures` config field.")
306 }
307
308 let name = &auto_cfg.architectures[0];
309
310 let tp = NormalLoaderType::from_causal_lm_name(name)?;
311
312 once_log_info(format!("Automatic loader type determined to be `{tp}`"));
313
314 match tp {
315 NormalLoaderType::Mistral => Ok(Box::new(MistralLoader)),
316 NormalLoaderType::Gemma => Ok(Box::new(GemmaLoader)),
317 NormalLoaderType::Llama => Ok(Box::new(LlamaLoader)),
318 NormalLoaderType::Mixtral => Ok(Box::new(MixtralLoader)),
319 NormalLoaderType::Phi2 => Ok(Box::new(Phi2Loader)),
320 NormalLoaderType::Phi3 => Ok(Box::new(Phi3Loader)),
321 NormalLoaderType::Qwen2 => Ok(Box::new(Qwen2Loader)),
322 NormalLoaderType::Gemma2 => Ok(Box::new(Gemma2Loader)),
323 NormalLoaderType::Starcoder2 => Ok(Box::new(Starcoder2Loader)),
324 NormalLoaderType::Phi3_5MoE => Ok(Box::new(Phi3_5MoELoader)),
325 NormalLoaderType::DeepSeekV2 => Ok(Box::new(DeepSeekV2Loader)),
326 NormalLoaderType::DeepSeekV3 => Ok(Box::new(DeepSeekV3Loader)),
327 NormalLoaderType::Qwen3 => Ok(Box::new(Qwen3Loader)),
328 NormalLoaderType::GLM4 => Ok(Box::new(GLM4Loader)),
329 NormalLoaderType::GLM4MoeLite => Ok(Box::new(GLM4MoeLiteLoader)),
330 NormalLoaderType::GLM4Moe => Ok(Box::new(GLM4MoeLoader)),
331 NormalLoaderType::Qwen3Moe => Ok(Box::new(Qwen3MoELoader)),
332 NormalLoaderType::SmolLm3 => Ok(Box::new(SmolLm3Loader)),
333 NormalLoaderType::GraniteMoeHybrid => Ok(Box::new(GraniteMoeHybridLoader)),
334 NormalLoaderType::GptOss => Ok(Box::new(GptOssLoader)),
335 NormalLoaderType::Qwen3Next => Ok(Box::new(Qwen3NextLoader)),
336 }
337 }
338}
339
340impl NormalModelLoader for AutoNormalLoader {
341 fn load(
342 &self,
343 config: &str,
344 vb: ShardedVarBuilder,
345 normal_loading_metadata: NormalLoadingMetadata,
346 attention_mechanism: AttentionImplementation,
347 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
348 Self::get_loader(config)?.load(config, vb, normal_loading_metadata, attention_mechanism)
349 }
350 fn load_xlora(
351 &self,
352 config: &str,
353 vb: ShardedVarBuilder,
354 lora_config: &[((String, String), LoraConfig)],
355 xlora_config: Option<XLoraConfig>,
356 xlora_ordering: Ordering,
357 normal_loading_metadata: NormalLoadingMetadata,
358 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
359 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
360 Self::get_loader(config)?.load_xlora(
361 config,
362 vb,
363 lora_config,
364 xlora_config,
365 xlora_ordering,
366 normal_loading_metadata,
367 preload_adapters,
368 )
369 }
370 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
371 Self::get_loader(config)?.get_config_repr(config)
372 }
373 fn supports_paged_attention(&self, config: &str) -> Result<bool> {
374 Self::get_loader(config)?.supports_paged_attention(config)
375 }
376 fn is_gptx(&self, config: &str) -> Result<bool> {
377 Self::get_loader(config)?.is_gptx(config)
378 }
379}
380
381impl IsqModelLoader for AutoNormalLoader {
382 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
383 Self::get_loader(config)?.immediate_isq_predicates(config)
384 }
385 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
386 Self::get_loader(config)?.immediate_isq_predicates_moqe(config)
387 }
388 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
389 Self::get_loader(config)?.isq_layer_regexes(config)
390 }
391 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
392 Self::get_loader(config)?.isq_layer_regexes_moqe(config)
393 }
394}
395
396impl DeviceMappedModelLoader for AutoNormalLoader {
397 fn non_mapped_size_in_bytes(
398 &self,
399 config: &str,
400 dtype: DType,
401 weight_pack_factor: usize,
402 _matformer_config: Option<&MatformerSliceConfig>,
403 ) -> Result<usize> {
404 Self::get_loader(config)?.non_mapped_size_in_bytes(
405 config,
406 dtype,
407 weight_pack_factor,
408 _matformer_config,
409 )
410 }
411 fn num_layers(&self, config: &str) -> Result<usize> {
412 Self::get_loader(config)?.num_layers(config)
413 }
414 fn layer_sizes_in_bytes(
415 &self,
416 config: &str,
417 dtype: DType,
418 weight_pack_factor: usize,
419 _matformer_config: Option<&MatformerSliceConfig>,
420 ) -> Result<Vec<usize>> {
421 Self::get_loader(config)?.layer_sizes_in_bytes(
422 config,
423 dtype,
424 weight_pack_factor,
425 _matformer_config,
426 )
427 }
428 fn mapped_max_act_size_elems(
429 &self,
430 config: &str,
431 params: &super::AutoDeviceMapParams,
432 ) -> Result<usize> {
433 Self::get_loader(config)?.mapped_max_act_size_elems(config, params)
434 }
435 fn non_mapped_max_act_size_elems(
436 &self,
437 _config: &str,
438 _params: &AutoDeviceMapParams,
439 ) -> Result<usize> {
440 Ok(0)
441 }
442 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
443 Self::get_loader(config)?.model_config(config)
444 }
445}
446
447pub struct MistralLoader;
450
451impl NormalModelLoader for MistralLoader {
452 fn load(
453 &self,
454 config: &str,
455 vb: ShardedVarBuilder,
456 normal_loading_metadata: NormalLoadingMetadata,
457 attention_mechanism: AttentionImplementation,
458 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
459 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
460 Ok(Box::new(models::mistral::Model::new(
461 &cfg,
462 vb,
463 self.is_gptx(config)?,
464 normal_loading_metadata,
465 attention_mechanism,
466 )?))
467 }
468 fn load_xlora(
469 &self,
470 config: &str,
471 vb: ShardedVarBuilder,
472 lora_config: &[((String, String), LoraConfig)],
473 xlora_config: Option<XLoraConfig>,
474 xlora_ordering: Ordering,
475 normal_loading_metadata: NormalLoadingMetadata,
476 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
477 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
478 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
479 Ok(Box::new(xlora_models::XLoraMistral::new(
480 &cfg,
481 vb,
482 lora_config,
483 xlora_config,
484 xlora_ordering,
485 self.is_gptx(config)?,
486 normal_loading_metadata,
487 preload_adapters,
488 )?))
489 }
490 fn is_gptx(&self, _: &str) -> Result<bool> {
491 Ok(true)
492 }
493 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
494 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
495 Ok(Box::new(cfg))
496 }
497}
498
499impl IsqModelLoader for MistralLoader {
500 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
501 Ok(vec![
502 Regex::new(r"lm_head\.(weight|bias)$")?,
503 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
505 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
506 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
507 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
508 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
510 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
511 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
512 ])
513 }
514 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
515 self.isq_layer_regexes(config)
516 }
517}
518
519impl DeviceMappedModelLoader for MistralLoader {
520 fn mapped_max_act_size_elems(
521 &self,
522 config: &str,
523 params: &AutoDeviceMapParams,
524 ) -> Result<usize> {
525 let AutoDeviceMapParams::Text {
526 max_seq_len,
527 max_batch_size,
528 } = params
529 else {
530 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
531 };
532
533 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
534
535 Ok(
536 max_batch_size
537 * cfg.num_attention_heads
538 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
539 )
540 }
541 fn non_mapped_max_act_size_elems(
542 &self,
543 _config: &str,
544 _params: &AutoDeviceMapParams,
545 ) -> Result<usize> {
546 Ok(0)
547 }
548
549 fn non_mapped_size_in_bytes(
550 &self,
551 config: &str,
552 dtype: DType,
553 weight_pack_factor: usize,
554 _matformer_config: Option<&MatformerSliceConfig>,
555 ) -> Result<usize> {
556 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
557
558 let elems = {
559 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
560 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
562 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
563 } else {
564 0
565 };
566 let norm = cfg.hidden_size;
567 embed_tokens + lm_head + norm
568 };
569 Ok(elems * dtype.size_in_bytes())
570 }
571
572 fn layer_sizes_in_bytes(
573 &self,
574 config: &str,
575 dtype: DType,
576 weight_pack_factor: usize,
577 _matformer_config: Option<&MatformerSliceConfig>,
578 ) -> Result<Vec<usize>> {
579 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
580
581 let per_layer_elems = {
582 let input_layernorm = cfg.hidden_size;
583 let post_attention_layernorm = cfg.hidden_size;
584
585 let size_in = cfg.hidden_size;
586 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
587 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
588 let q_proj = size_in * size_q / weight_pack_factor;
589 let k_proj = size_in * size_kv / weight_pack_factor;
590 let v_proj = size_in * size_kv / weight_pack_factor;
591 let o_proj = size_q * size_in / weight_pack_factor;
592
593 let h_size = cfg.hidden_size;
594 let i_size = cfg.intermediate_size;
595 let gate_proj = h_size * i_size / weight_pack_factor;
596 let up_proj = h_size * i_size / weight_pack_factor;
597 let down_proj = i_size * h_size / weight_pack_factor;
598
599 input_layernorm
600 + post_attention_layernorm
601 + q_proj
602 + k_proj
603 + v_proj
604 + o_proj
605 + gate_proj
606 + up_proj
607 + down_proj
608 };
609 Ok(vec![
610 per_layer_elems * dtype.size_in_bytes();
611 cfg.num_hidden_layers
612 ])
613 }
614
615 fn num_layers(&self, config: &str) -> Result<usize> {
616 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
617 Ok(cfg.num_hidden_layers)
618 }
619
620 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
621 let cfg: crate::models::mistral::Config = serde_json::from_str(config)?;
622
623 let cfg = ModelConfigMetadata {
624 max_seq_len: cfg.max_position_embeddings,
625 num_layers: cfg.num_hidden_layers,
626 hidden_size: cfg.hidden_size,
627 num_kv_heads: cfg.num_key_value_heads,
628 num_attn_heads: cfg.num_attention_heads,
629 sliding_window: cfg.sliding_window,
630 k_head_dim: cfg.head_dim(),
631 v_head_dim: cfg.head_dim(),
632 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
633 };
634
635 Ok(Box::new(cfg))
636 }
637}
638
639pub struct GemmaLoader;
645
646impl NormalModelLoader for GemmaLoader {
647 fn load(
648 &self,
649 config: &str,
650 vb: ShardedVarBuilder,
651 normal_loading_metadata: NormalLoadingMetadata,
652 attention_mechanism: AttentionImplementation,
653 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
654 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
655
656 Ok(Box::new(models::gemma::Model::new(
657 &cfg,
658 vb,
659 self.is_gptx(config)?,
660 normal_loading_metadata,
661 attention_mechanism,
662 )?))
663 }
664 fn load_xlora(
665 &self,
666 config: &str,
667 vb: ShardedVarBuilder,
668 lora_config: &[((String, String), LoraConfig)],
669 xlora_config: Option<XLoraConfig>,
670 xlora_ordering: Ordering,
671 normal_loading_metadata: NormalLoadingMetadata,
672 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
673 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
674 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
675
676 Ok(Box::new(xlora_models::XLoraGemma::new(
677 &cfg,
678 vb,
679 lora_config,
680 xlora_config,
681 xlora_ordering,
682 self.is_gptx(config)?,
683 normal_loading_metadata,
684 preload_adapters,
685 )?))
686 }
687 fn is_gptx(&self, _: &str) -> Result<bool> {
688 Ok(true)
689 }
690 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
691 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
692 Ok(Box::new(cfg))
693 }
694}
695
696impl IsqModelLoader for GemmaLoader {
697 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
698 Ok(vec![
699 Regex::new(r"lm_head\.(weight|bias)$")?,
700 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
702 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
703 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
704 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
705 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
707 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
708 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
709 ])
710 }
711 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
712 self.isq_layer_regexes(config)
713 }
714}
715
716impl DeviceMappedModelLoader for GemmaLoader {
717 fn mapped_max_act_size_elems(
718 &self,
719 config: &str,
720 params: &AutoDeviceMapParams,
721 ) -> Result<usize> {
722 let AutoDeviceMapParams::Text {
723 max_seq_len,
724 max_batch_size,
725 } = params
726 else {
727 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
728 };
729
730 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
731
732 Ok(
733 max_batch_size
734 * cfg.num_attention_heads
735 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
736 )
737 }
738 fn non_mapped_max_act_size_elems(
739 &self,
740 _config: &str,
741 _params: &AutoDeviceMapParams,
742 ) -> Result<usize> {
743 Ok(0)
744 }
745
746 fn non_mapped_size_in_bytes(
747 &self,
748 config: &str,
749 dtype: DType,
750 weight_pack_factor: usize,
751 _matformer_config: Option<&MatformerSliceConfig>,
752 ) -> Result<usize> {
753 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
754
755 let elems = {
756 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
757 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
759 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
760 } else {
761 0
762 };
763 let norm = cfg.hidden_size;
764 embed_tokens + lm_head + norm
765 };
766 Ok(elems * dtype.size_in_bytes())
767 }
768
769 fn layer_sizes_in_bytes(
770 &self,
771 config: &str,
772 dtype: DType,
773 weight_pack_factor: usize,
774 _matformer_config: Option<&MatformerSliceConfig>,
775 ) -> Result<Vec<usize>> {
776 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
777
778 let per_layer_elems = {
779 let input_layernorm = cfg.hidden_size;
780 let post_attention_layernorm = cfg.hidden_size;
781
782 let size_in = cfg.hidden_size;
783 let size_q = cfg.head_dim * cfg.num_attention_heads;
784 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
785 let q_proj =
786 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
787 let k_proj =
788 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
789 let v_proj =
790 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
791 let o_proj =
792 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
793
794 let h_size = cfg.hidden_size;
795 let i_size = cfg.intermediate_size;
796 let gate_proj = h_size * i_size / weight_pack_factor;
797 let up_proj = h_size * i_size / weight_pack_factor;
798 let down_proj = i_size * h_size / weight_pack_factor;
799
800 input_layernorm
801 + post_attention_layernorm
802 + q_proj
803 + k_proj
804 + v_proj
805 + o_proj
806 + gate_proj
807 + up_proj
808 + down_proj
809 };
810 Ok(vec![
811 per_layer_elems * dtype.size_in_bytes();
812 cfg.num_hidden_layers
813 ])
814 }
815
816 fn num_layers(&self, config: &str) -> Result<usize> {
817 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
818 Ok(cfg.num_hidden_layers)
819 }
820
821 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
822 let cfg: crate::models::gemma::Config = serde_json::from_str(config)?;
823
824 let cfg = ModelConfigMetadata {
825 max_seq_len: cfg.max_position_embeddings,
826 num_layers: cfg.num_hidden_layers,
827 hidden_size: cfg.hidden_size,
828 num_kv_heads: cfg.num_key_value_heads,
829 num_attn_heads: cfg.num_attention_heads,
830 sliding_window: None,
831 k_head_dim: cfg.head_dim,
832 v_head_dim: cfg.head_dim,
833 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
834 };
835
836 Ok(Box::new(cfg))
837 }
838}
839
840pub struct LlamaLoader;
846
847impl NormalModelLoader for LlamaLoader {
848 fn load(
849 &self,
850 config: &str,
851 vb: ShardedVarBuilder,
852 normal_loading_metadata: NormalLoadingMetadata,
853 attention_mechanism: AttentionImplementation,
854 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
855 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
856
857 Ok(Box::new(models::llama::Llama::new(
858 &cfg,
859 vb,
860 self.is_gptx(config)?,
861 normal_loading_metadata,
862 attention_mechanism,
863 )?))
864 }
865 fn load_xlora(
866 &self,
867 config: &str,
868 vb: ShardedVarBuilder,
869 lora_config: &[((String, String), LoraConfig)],
870 xlora_config: Option<XLoraConfig>,
871 xlora_ordering: Ordering,
872 normal_loading_metadata: NormalLoadingMetadata,
873 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
874 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
875 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
876
877 Ok(Box::new(xlora_models::XLoraLlama::new(
878 &cfg,
879 vb,
880 lora_config,
881 xlora_config,
882 xlora_ordering,
883 self.is_gptx(config)?,
884 normal_loading_metadata,
885 preload_adapters,
886 )?))
887 }
888 fn is_gptx(&self, _: &str) -> Result<bool> {
889 Ok(true)
890 }
891 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
892 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
893 Ok(Box::new(cfg))
894 }
895}
896
897impl IsqModelLoader for LlamaLoader {
898 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
899 Ok(vec![
900 Regex::new(r"lm_head\.(weight|bias)$")?,
901 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
903 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
904 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
905 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
906 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
908 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
909 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
910 ])
911 }
912 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
913 self.isq_layer_regexes(config)
914 }
915}
916
917impl DeviceMappedModelLoader for LlamaLoader {
918 fn mapped_max_act_size_elems(
919 &self,
920 config: &str,
921 params: &AutoDeviceMapParams,
922 ) -> Result<usize> {
923 let AutoDeviceMapParams::Text {
924 max_seq_len,
925 max_batch_size,
926 } = params
927 else {
928 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
929 };
930
931 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
932
933 Ok(
934 max_batch_size
935 * cfg.num_attention_heads
936 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
937 )
938 }
939 fn non_mapped_max_act_size_elems(
940 &self,
941 _config: &str,
942 _params: &AutoDeviceMapParams,
943 ) -> Result<usize> {
944 Ok(0)
945 }
946
947 fn non_mapped_size_in_bytes(
948 &self,
949 config: &str,
950 dtype: DType,
951 weight_pack_factor: usize,
952 _matformer_config: Option<&MatformerSliceConfig>,
953 ) -> Result<usize> {
954 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
955
956 let elems = {
957 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
958 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
960 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
961 } else {
962 0
963 };
964 let norm = cfg.hidden_size;
965 embed_tokens + lm_head + norm
966 };
967 Ok(elems * dtype.size_in_bytes())
968 }
969
970 fn layer_sizes_in_bytes(
971 &self,
972 config: &str,
973 dtype: DType,
974 weight_pack_factor: usize,
975 _matformer_config: Option<&MatformerSliceConfig>,
976 ) -> Result<Vec<usize>> {
977 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
978
979 let per_layer_elems = {
980 let input_layernorm = cfg.hidden_size;
981 let post_attention_layernorm = cfg.hidden_size;
982
983 let size_in = cfg.hidden_size;
984 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
985 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
986 let q_proj = size_in * size_q / weight_pack_factor;
987 let k_proj = size_in * size_kv / weight_pack_factor;
988 let v_proj = size_in * size_kv / weight_pack_factor;
989 let o_proj = size_q * size_in / weight_pack_factor;
990
991 let h_size = cfg.hidden_size;
992 let i_size = cfg.intermediate_size;
993 let gate_proj = h_size * i_size / weight_pack_factor;
994 let up_proj = h_size * i_size / weight_pack_factor;
995 let down_proj = i_size * h_size / weight_pack_factor;
996
997 input_layernorm
998 + post_attention_layernorm
999 + q_proj
1000 + k_proj
1001 + v_proj
1002 + o_proj
1003 + gate_proj
1004 + up_proj
1005 + down_proj
1006 };
1007 Ok(vec![
1008 per_layer_elems * dtype.size_in_bytes();
1009 cfg.num_hidden_layers
1010 ])
1011 }
1012
1013 fn num_layers(&self, config: &str) -> Result<usize> {
1014 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1015
1016 Ok(cfg.num_hidden_layers)
1017 }
1018 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1019 let cfg: crate::models::llama::Config = serde_json::from_str(config)?;
1020
1021 let cfg = ModelConfigMetadata {
1022 max_seq_len: cfg.max_position_embeddings,
1023 num_layers: cfg.num_hidden_layers,
1024 hidden_size: cfg.hidden_size,
1025 num_kv_heads: cfg.num_key_value_heads,
1026 num_attn_heads: cfg.num_attention_heads,
1027 sliding_window: None,
1028 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1029 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1030 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1031 };
1032
1033 Ok(Box::new(cfg))
1034 }
1035}
1036
1037pub struct MixtralLoader;
1040
1041impl NormalModelLoader for MixtralLoader {
1042 fn load(
1043 &self,
1044 config: &str,
1045 vb: ShardedVarBuilder,
1046 normal_loading_metadata: NormalLoadingMetadata,
1047 attention_mechanism: AttentionImplementation,
1048 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1049 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1050
1051 Ok(Box::new(models::mixtral::Model::new(
1052 &cfg,
1053 vb,
1054 self.is_gptx(config)?,
1055 normal_loading_metadata,
1056 attention_mechanism,
1057 )?))
1058 }
1059 fn load_xlora(
1060 &self,
1061 config: &str,
1062 vb: ShardedVarBuilder,
1063 lora_config: &[((String, String), LoraConfig)],
1064 xlora_config: Option<XLoraConfig>,
1065 xlora_ordering: Ordering,
1066 normal_loading_metadata: NormalLoadingMetadata,
1067 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1068 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1069 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1070
1071 Ok(Box::new(xlora_models::XLoraMixtral::new(
1072 &cfg,
1073 vb,
1074 lora_config,
1075 xlora_config,
1076 xlora_ordering,
1077 self.is_gptx(config)?,
1078 normal_loading_metadata,
1079 preload_adapters,
1080 )?))
1081 }
1082 fn is_gptx(&self, _: &str) -> Result<bool> {
1083 Ok(true)
1084 }
1085 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1086 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1087
1088 Ok(Box::new(cfg))
1089 }
1090}
1091
1092impl IsqModelLoader for MixtralLoader {
1093 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1094 Ok(vec![
1095 Regex::new(r"lm_head\.(weight|bias)$")?,
1096 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1098 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1099 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1100 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1101 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.gate\.(weight|bias)$")?,
1103 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
1104 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
1105 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
1106 ])
1107 }
1108 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1109 self.isq_layer_regexes(config)
1110 }
1111}
1112
1113impl DeviceMappedModelLoader for MixtralLoader {
1114 fn mapped_max_act_size_elems(
1115 &self,
1116 config: &str,
1117 params: &AutoDeviceMapParams,
1118 ) -> Result<usize> {
1119 let AutoDeviceMapParams::Text {
1120 max_seq_len,
1121 max_batch_size,
1122 } = params
1123 else {
1124 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1125 };
1126
1127 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1128
1129 Ok(
1130 max_batch_size
1131 * cfg.num_attention_heads
1132 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1133 )
1134 }
1135 fn non_mapped_max_act_size_elems(
1136 &self,
1137 _config: &str,
1138 _params: &AutoDeviceMapParams,
1139 ) -> Result<usize> {
1140 Ok(0)
1141 }
1142
1143 fn non_mapped_size_in_bytes(
1144 &self,
1145 config: &str,
1146 dtype: DType,
1147 weight_pack_factor: usize,
1148 _matformer_config: Option<&MatformerSliceConfig>,
1149 ) -> Result<usize> {
1150 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1151
1152 let elems = {
1153 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1154 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1156 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1157 } else {
1158 0
1159 };
1160 let norm = cfg.hidden_size;
1161 embed_tokens + lm_head + norm
1162 };
1163 Ok(elems * dtype.size_in_bytes())
1164 }
1165
1166 fn layer_sizes_in_bytes(
1167 &self,
1168 config: &str,
1169 dtype: DType,
1170 weight_pack_factor: usize,
1171 _matformer_config: Option<&MatformerSliceConfig>,
1172 ) -> Result<Vec<usize>> {
1173 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1174
1175 let per_layer_elems = {
1176 let input_layernorm = cfg.hidden_size;
1177 let post_attention_layernorm = cfg.hidden_size;
1178
1179 let size_in = cfg.hidden_size;
1180 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1181 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1182 let q_proj = size_in * size_q / weight_pack_factor;
1183 let k_proj = size_in * size_kv / weight_pack_factor;
1184 let v_proj = size_in * size_kv / weight_pack_factor;
1185 let o_proj = size_q * size_in / weight_pack_factor;
1186
1187 let moe_block = {
1188 let gate = cfg.hidden_size * cfg.num_local_experts;
1189 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1191 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1192 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
1193 gate + cfg.num_local_experts * w1
1194 + cfg.num_local_experts * w2
1195 + cfg.num_local_experts * w3
1196 };
1197
1198 input_layernorm
1199 + post_attention_layernorm
1200 + q_proj
1201 + k_proj
1202 + v_proj
1203 + o_proj
1204 + moe_block
1205 };
1206 Ok(vec![
1207 per_layer_elems * dtype.size_in_bytes();
1208 cfg.num_hidden_layers
1209 ])
1210 }
1211
1212 fn num_layers(&self, config: &str) -> Result<usize> {
1213 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1214
1215 Ok(cfg.num_hidden_layers)
1216 }
1217
1218 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1219 let cfg: crate::models::mixtral::Config = serde_json::from_str(config)?;
1220
1221 let cfg = ModelConfigMetadata {
1222 max_seq_len: cfg.max_position_embeddings,
1223 num_layers: cfg.num_hidden_layers,
1224 hidden_size: cfg.hidden_size,
1225 num_kv_heads: cfg.num_key_value_heads,
1226 num_attn_heads: cfg.num_attention_heads,
1227 sliding_window: cfg.sliding_window,
1228 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1229 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1230 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1231 };
1232
1233 Ok(Box::new(cfg))
1234 }
1235}
1236
1237pub struct Phi2Loader;
1243
1244impl NormalModelLoader for Phi2Loader {
1245 fn load(
1246 &self,
1247 config: &str,
1248 vb: ShardedVarBuilder,
1249 normal_loading_metadata: NormalLoadingMetadata,
1250 attention_mechanism: AttentionImplementation,
1251 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1252 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1253
1254 Ok(Box::new(models::phi2::Model::new(
1255 &cfg,
1256 vb,
1257 self.is_gptx(config)?,
1258 normal_loading_metadata,
1259 attention_mechanism,
1260 )?))
1261 }
1262 fn load_xlora(
1263 &self,
1264 config: &str,
1265 vb: ShardedVarBuilder,
1266 lora_config: &[((String, String), LoraConfig)],
1267 xlora_config: Option<XLoraConfig>,
1268 xlora_ordering: Ordering,
1269 normal_loading_metadata: NormalLoadingMetadata,
1270 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1271 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1272 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1273
1274 Ok(Box::new(xlora_models::XLoraPhi2::new(
1275 &cfg,
1276 vb,
1277 lora_config,
1278 xlora_config,
1279 xlora_ordering,
1280 self.is_gptx(config)?,
1281 normal_loading_metadata,
1282 preload_adapters,
1283 )?))
1284 }
1285 fn is_gptx(&self, _: &str) -> Result<bool> {
1286 Ok(true)
1287 }
1288 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1289 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1290
1291 Ok(Box::new(cfg))
1292 }
1293}
1294
1295impl IsqModelLoader for Phi2Loader {
1296 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1297 Ok(vec![
1298 Regex::new(r"lm_head\.(weight|bias)$")?,
1299 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1301 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1302 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1303 Regex::new(r"layers\.(\d+)\.self_attn\.dense\.(weight|bias)$")?,
1304 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
1306 Regex::new(r"layers\.(\d+)\.mlp\.fc2\.(weight|bias)$")?,
1307 ])
1308 }
1309 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1310 self.isq_layer_regexes(config)
1311 }
1312}
1313
1314impl DeviceMappedModelLoader for Phi2Loader {
1315 fn mapped_max_act_size_elems(
1316 &self,
1317 config: &str,
1318 params: &AutoDeviceMapParams,
1319 ) -> Result<usize> {
1320 let AutoDeviceMapParams::Text {
1321 max_seq_len,
1322 max_batch_size,
1323 } = params
1324 else {
1325 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1326 };
1327
1328 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1329
1330 Ok(
1331 max_batch_size
1332 * cfg.num_attention_heads
1333 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1334 )
1335 }
1336 fn non_mapped_max_act_size_elems(
1337 &self,
1338 _config: &str,
1339 _params: &AutoDeviceMapParams,
1340 ) -> Result<usize> {
1341 Ok(0)
1342 }
1343
1344 fn non_mapped_size_in_bytes(
1345 &self,
1346 config: &str,
1347 dtype: DType,
1348 weight_pack_factor: usize,
1349 _matformer_config: Option<&MatformerSliceConfig>,
1350 ) -> Result<usize> {
1351 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1352
1353 let elems = {
1354 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1355 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1357 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1358 } else {
1359 0
1360 };
1361 let norm = cfg.hidden_size;
1362 embed_tokens + lm_head + norm
1363 };
1364 Ok(elems * dtype.size_in_bytes())
1365 }
1366
1367 fn layer_sizes_in_bytes(
1368 &self,
1369 config: &str,
1370 dtype: DType,
1371 weight_pack_factor: usize,
1372 _matformer_config: Option<&MatformerSliceConfig>,
1373 ) -> Result<Vec<usize>> {
1374 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1375
1376 let per_layer_elems = {
1377 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
1378
1379 let size_in = cfg.hidden_size;
1380 let size_q = cfg.head_dim() * cfg.num_attention_heads;
1381 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
1382 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1383 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1384 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1385 let o_proj = size_q * size_in / weight_pack_factor + size_in;
1386 let (q_norm, k_norm) = if cfg.qk_layernorm {
1387 (cfg.head_dim(), cfg.head_dim())
1388 } else {
1389 (0, 0)
1390 };
1391
1392 let h_size = cfg.hidden_size;
1393 let i_size = cfg.intermediate_size;
1394 let fc1 = h_size * i_size / weight_pack_factor;
1395 let fc2 = h_size * i_size / weight_pack_factor;
1396
1397 input_layernorm + q_proj + k_proj + v_proj + o_proj + q_norm + k_norm + fc1 + fc2
1398 };
1399 Ok(vec![
1400 per_layer_elems * dtype.size_in_bytes();
1401 cfg.num_hidden_layers
1402 ])
1403 }
1404
1405 fn num_layers(&self, config: &str) -> Result<usize> {
1406 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1407
1408 Ok(cfg.num_hidden_layers)
1409 }
1410
1411 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1412 let cfg: crate::models::phi2::Config = serde_json::from_str(config)?;
1413
1414 let cfg = ModelConfigMetadata {
1415 max_seq_len: cfg.max_position_embeddings,
1416 num_layers: cfg.num_hidden_layers,
1417 hidden_size: cfg.hidden_size,
1418 num_kv_heads: cfg.num_key_value_heads(),
1419 num_attn_heads: cfg.num_attention_heads,
1420 sliding_window: None,
1421 k_head_dim: cfg.head_dim(),
1422 v_head_dim: cfg.head_dim(),
1423 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1424 };
1425
1426 Ok(Box::new(cfg))
1427 }
1428}
1429
1430pub struct Phi3Loader;
1436
1437impl NormalModelLoader for Phi3Loader {
1438 fn load(
1439 &self,
1440 config: &str,
1441 vb: ShardedVarBuilder,
1442 normal_loading_metadata: NormalLoadingMetadata,
1443 attention_mechanism: AttentionImplementation,
1444 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1445 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1446
1447 Ok(Box::new(models::phi3::Model::new(
1448 &cfg,
1449 vb,
1450 self.is_gptx(config)?,
1451 normal_loading_metadata,
1452 attention_mechanism,
1453 )?))
1454 }
1455 fn load_xlora(
1456 &self,
1457 config: &str,
1458 vb: ShardedVarBuilder,
1459 lora_config: &[((String, String), LoraConfig)],
1460 xlora_config: Option<XLoraConfig>,
1461 xlora_ordering: Ordering,
1462 normal_loading_metadata: NormalLoadingMetadata,
1463 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1464 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1465 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1466
1467 Ok(Box::new(xlora_models::XLoraPhi3::new(
1468 &cfg,
1469 vb,
1470 lora_config,
1471 xlora_config,
1472 xlora_ordering,
1473 self.is_gptx(config)?,
1474 normal_loading_metadata,
1475 preload_adapters,
1476 )?))
1477 }
1478 fn is_gptx(&self, _: &str) -> Result<bool> {
1479 Ok(true)
1480 }
1481 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1482 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1483
1484 Ok(Box::new(cfg))
1485 }
1486}
1487
1488impl IsqModelLoader for Phi3Loader {
1489 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1490 Ok(vec![
1491 Regex::new(r"lm_head\.(weight|bias)$")?,
1492 Regex::new(r"layers\.(\d+)\.self_attn\.qkv_proj\.(weight|bias)$")?,
1494 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1495 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1497 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1498 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1499 ])
1500 }
1501 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1502 self.isq_layer_regexes(config)
1503 }
1504}
1505
1506impl DeviceMappedModelLoader for Phi3Loader {
1507 fn mapped_max_act_size_elems(
1508 &self,
1509 config: &str,
1510 params: &AutoDeviceMapParams,
1511 ) -> Result<usize> {
1512 let AutoDeviceMapParams::Text {
1513 max_seq_len,
1514 max_batch_size,
1515 } = params
1516 else {
1517 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1518 };
1519
1520 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1521
1522 Ok(
1523 max_batch_size
1524 * cfg.num_attention_heads
1525 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1526 )
1527 }
1528 fn non_mapped_max_act_size_elems(
1529 &self,
1530 _config: &str,
1531 _params: &AutoDeviceMapParams,
1532 ) -> Result<usize> {
1533 Ok(0)
1534 }
1535
1536 fn non_mapped_size_in_bytes(
1537 &self,
1538 config: &str,
1539 dtype: DType,
1540 weight_pack_factor: usize,
1541 _matformer_config: Option<&MatformerSliceConfig>,
1542 ) -> Result<usize> {
1543 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1544
1545 let elems = {
1546 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1547 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1549 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1550 } else {
1551 0
1552 };
1553 let norm = cfg.hidden_size;
1554 embed_tokens + lm_head + norm
1555 };
1556 Ok(elems * dtype.size_in_bytes())
1557 }
1558
1559 fn layer_sizes_in_bytes(
1560 &self,
1561 config: &str,
1562 dtype: DType,
1563 weight_pack_factor: usize,
1564 _matformer_config: Option<&MatformerSliceConfig>,
1565 ) -> Result<Vec<usize>> {
1566 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1567
1568 let per_layer_elems = {
1569 let input_layernorm = cfg.hidden_size;
1570 let post_attention_layernorm = cfg.hidden_size;
1571
1572 let size_in = cfg.hidden_size;
1573 let head_dim = cfg.head_dim();
1574 let op_size =
1575 cfg.num_attention_heads * head_dim + 2 * cfg.num_key_value_heads * head_dim;
1576 let qkv_proj = size_in * op_size / weight_pack_factor;
1577 let o_proj =
1578 (cfg.num_attention_heads * head_dim) * size_in / weight_pack_factor + size_in;
1579
1580 let h_size = cfg.hidden_size;
1581 let i_size = cfg.intermediate_size;
1582 let gate_up_proj = h_size * (2 * i_size) / weight_pack_factor;
1583 let down_proj = h_size * i_size / weight_pack_factor;
1584
1585 input_layernorm
1586 + post_attention_layernorm
1587 + qkv_proj
1588 + o_proj
1589 + gate_up_proj
1590 + down_proj
1591 };
1592 Ok(vec![
1593 per_layer_elems * dtype.size_in_bytes();
1594 cfg.num_hidden_layers
1595 ])
1596 }
1597
1598 fn num_layers(&self, config: &str) -> Result<usize> {
1599 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1600
1601 Ok(cfg.num_hidden_layers)
1602 }
1603
1604 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1605 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
1606
1607 let cfg = ModelConfigMetadata {
1608 max_seq_len: cfg.max_position_embeddings,
1609 num_layers: cfg.num_hidden_layers,
1610 hidden_size: cfg.hidden_size,
1611 num_kv_heads: cfg.num_key_value_heads,
1612 num_attn_heads: cfg.num_attention_heads,
1613 sliding_window: cfg.sliding_window,
1614 k_head_dim: cfg.head_dim(),
1615 v_head_dim: cfg.head_dim(),
1616 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1617 };
1618
1619 Ok(Box::new(cfg))
1620 }
1621}
1622
1623pub struct Qwen2Loader;
1629
1630impl NormalModelLoader for Qwen2Loader {
1631 fn load(
1632 &self,
1633 config: &str,
1634 vb: ShardedVarBuilder,
1635 normal_loading_metadata: NormalLoadingMetadata,
1636 attention_mechanism: AttentionImplementation,
1637 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1638 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1639
1640 Ok(Box::new(models::qwen2::Model::new(
1641 &cfg,
1642 vb,
1643 self.is_gptx(config)?,
1644 normal_loading_metadata,
1645 attention_mechanism,
1646 )?))
1647 }
1648 fn load_xlora(
1649 &self,
1650 _config: &str,
1651 _vb: ShardedVarBuilder,
1652 _lora_config: &[((String, String), LoraConfig)],
1653 _xlora_config: Option<XLoraConfig>,
1654 _xlora_ordering: Ordering,
1655 _normal_loading_metadata: NormalLoadingMetadata,
1656 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1657 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1658 todo!()
1659 }
1660 fn is_gptx(&self, _: &str) -> Result<bool> {
1661 Ok(true)
1662 }
1663 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1664 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1665
1666 Ok(Box::new(cfg))
1667 }
1668}
1669
1670impl IsqModelLoader for Qwen2Loader {
1671 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1672 Ok(vec![
1673 Regex::new(r"lm_head\.(weight|bias)$")?,
1674 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1676 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1677 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1678 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1679 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1681 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1682 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1683 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
1685 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
1686 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
1687 ])
1688 }
1689 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1690 self.isq_layer_regexes(config)
1691 }
1692}
1693
1694impl DeviceMappedModelLoader for Qwen2Loader {
1695 fn mapped_max_act_size_elems(
1696 &self,
1697 config: &str,
1698 params: &AutoDeviceMapParams,
1699 ) -> Result<usize> {
1700 let AutoDeviceMapParams::Text {
1701 max_seq_len,
1702 max_batch_size,
1703 } = params
1704 else {
1705 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1706 };
1707
1708 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1709
1710 Ok(
1711 max_batch_size
1712 * cfg.num_attention_heads
1713 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1714 )
1715 }
1716 fn non_mapped_max_act_size_elems(
1717 &self,
1718 _config: &str,
1719 _params: &AutoDeviceMapParams,
1720 ) -> Result<usize> {
1721 Ok(0)
1722 }
1723
1724 fn non_mapped_size_in_bytes(
1725 &self,
1726 config: &str,
1727 dtype: DType,
1728 weight_pack_factor: usize,
1729 _matformer_config: Option<&MatformerSliceConfig>,
1730 ) -> Result<usize> {
1731 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1732
1733 let elems = {
1734 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1735 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1737 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1738 } else {
1739 0
1740 };
1741 let norm = cfg.hidden_size;
1742 embed_tokens + lm_head + norm
1743 };
1744 Ok(elems * dtype.size_in_bytes())
1745 }
1746
1747 fn layer_sizes_in_bytes(
1748 &self,
1749 config: &str,
1750 dtype: DType,
1751 weight_pack_factor: usize,
1752 _matformer_config: Option<&MatformerSliceConfig>,
1753 ) -> Result<Vec<usize>> {
1754 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1755
1756 let per_layer_elems = {
1757 let input_layernorm = cfg.hidden_size;
1758 let post_attention_layernorm = cfg.hidden_size;
1759
1760 let size_in = cfg.hidden_size;
1761 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
1762 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
1763 let q_proj = size_in * size_q / weight_pack_factor + size_q;
1764 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
1765 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
1766 let o_proj = size_q * size_in / weight_pack_factor;
1767
1768 let h_size = cfg.hidden_size;
1769 let i_size = cfg.intermediate_size;
1770 let gate_proj = h_size * i_size / weight_pack_factor;
1771 let up_proj = h_size * i_size / weight_pack_factor;
1772 let down_proj = i_size * h_size / weight_pack_factor;
1773
1774 input_layernorm
1775 + post_attention_layernorm
1776 + q_proj
1777 + k_proj
1778 + v_proj
1779 + o_proj
1780 + gate_proj
1781 + up_proj
1782 + down_proj
1783 };
1784 Ok(vec![
1785 per_layer_elems * dtype.size_in_bytes();
1786 cfg.num_hidden_layers
1787 ])
1788 }
1789
1790 fn num_layers(&self, config: &str) -> Result<usize> {
1791 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1792
1793 Ok(cfg.num_hidden_layers)
1794 }
1795
1796 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1797 let cfg: crate::models::qwen2::Config = serde_json::from_str(config)?;
1798
1799 let cfg = ModelConfigMetadata {
1800 max_seq_len: cfg.max_position_embeddings,
1801 num_layers: cfg.num_hidden_layers,
1802 hidden_size: cfg.hidden_size,
1803 num_kv_heads: cfg.num_key_value_heads,
1804 num_attn_heads: cfg.num_attention_heads,
1805 sliding_window: cfg.sliding_window,
1806 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1807 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
1808 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
1809 };
1810
1811 Ok(Box::new(cfg))
1812 }
1813}
1814
1815pub struct Gemma2Loader;
1821
1822impl NormalModelLoader for Gemma2Loader {
1823 fn load(
1824 &self,
1825 config: &str,
1826 vb: ShardedVarBuilder,
1827 normal_loading_metadata: NormalLoadingMetadata,
1828 attention_mechanism: AttentionImplementation,
1829 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1830 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1831
1832 Ok(Box::new(models::gemma2::Model::new(
1833 &cfg,
1834 vb,
1835 self.is_gptx(config)?,
1836 normal_loading_metadata,
1837 attention_mechanism,
1838 )?))
1839 }
1840 fn load_xlora(
1841 &self,
1842 config: &str,
1843 vb: ShardedVarBuilder,
1844 lora_config: &[((String, String), LoraConfig)],
1845 xlora_config: Option<XLoraConfig>,
1846 xlora_ordering: Ordering,
1847 normal_loading_metadata: NormalLoadingMetadata,
1848 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
1849 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
1850 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1851
1852 Ok(Box::new(xlora_models::XLoraGemma2::new(
1853 &cfg,
1854 vb,
1855 lora_config,
1856 xlora_config,
1857 xlora_ordering,
1858 self.is_gptx(config)?,
1859 normal_loading_metadata,
1860 preload_adapters,
1861 )?))
1862 }
1863 fn is_gptx(&self, _: &str) -> Result<bool> {
1864 Ok(true)
1865 }
1866 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
1867 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1868
1869 Ok(Box::new(cfg))
1870 }
1871}
1872
1873impl IsqModelLoader for Gemma2Loader {
1874 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
1875 Ok(vec![
1876 Regex::new(r"lm_head\.(weight|bias)$")?,
1877 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
1879 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
1880 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
1881 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
1882 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
1884 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
1885 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
1886 ])
1887 }
1888 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
1889 self.isq_layer_regexes(config)
1890 }
1891}
1892
1893impl DeviceMappedModelLoader for Gemma2Loader {
1894 fn mapped_max_act_size_elems(
1895 &self,
1896 config: &str,
1897 params: &AutoDeviceMapParams,
1898 ) -> Result<usize> {
1899 let AutoDeviceMapParams::Text {
1900 max_seq_len,
1901 max_batch_size,
1902 } = params
1903 else {
1904 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
1905 };
1906
1907 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1908
1909 Ok(
1910 max_batch_size
1911 * cfg.num_attention_heads
1912 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
1913 )
1914 }
1915 fn non_mapped_max_act_size_elems(
1916 &self,
1917 _config: &str,
1918 _params: &AutoDeviceMapParams,
1919 ) -> Result<usize> {
1920 Ok(0)
1921 }
1922
1923 fn non_mapped_size_in_bytes(
1924 &self,
1925 config: &str,
1926 dtype: DType,
1927 weight_pack_factor: usize,
1928 _matformer_config: Option<&MatformerSliceConfig>,
1929 ) -> Result<usize> {
1930 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1931
1932 let elems = {
1933 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
1934 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
1936 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
1937 } else {
1938 0
1939 };
1940 let norm = cfg.hidden_size;
1941 embed_tokens + lm_head + norm
1942 };
1943 Ok(elems * dtype.size_in_bytes())
1944 }
1945
1946 fn layer_sizes_in_bytes(
1947 &self,
1948 config: &str,
1949 dtype: DType,
1950 weight_pack_factor: usize,
1951 _matformer_config: Option<&MatformerSliceConfig>,
1952 ) -> Result<Vec<usize>> {
1953 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1954
1955 let per_layer_elems = {
1956 let input_layernorm = cfg.hidden_size;
1957 let post_attention_layernorm = cfg.hidden_size;
1958
1959 let size_in = cfg.hidden_size;
1960 let size_q = cfg.head_dim * cfg.num_attention_heads;
1961 let size_kv = cfg.head_dim * cfg.num_key_value_heads;
1962 let q_proj =
1963 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
1964 let k_proj =
1965 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1966 let v_proj =
1967 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
1968 let o_proj =
1969 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
1970
1971 let h_size = cfg.hidden_size;
1972 let i_size = cfg.intermediate_size;
1973 let gate_proj = h_size * i_size / weight_pack_factor;
1974 let up_proj = h_size * i_size / weight_pack_factor;
1975 let down_proj = i_size * h_size / weight_pack_factor;
1976
1977 input_layernorm
1978 + post_attention_layernorm
1979 + q_proj
1980 + k_proj
1981 + v_proj
1982 + o_proj
1983 + gate_proj
1984 + up_proj
1985 + down_proj
1986 };
1987 Ok(vec![
1988 per_layer_elems * dtype.size_in_bytes();
1989 cfg.num_hidden_layers
1990 ])
1991 }
1992
1993 fn num_layers(&self, config: &str) -> Result<usize> {
1994 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
1995
1996 Ok(cfg.num_hidden_layers)
1997 }
1998 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
1999 let cfg: crate::models::gemma2::Config = serde_json::from_str(config)?;
2000
2001 let cfg = ModelConfigMetadata {
2002 max_seq_len: cfg.max_position_embeddings,
2003 num_layers: cfg.num_hidden_layers,
2004 hidden_size: cfg.hidden_size,
2005 num_kv_heads: cfg.num_key_value_heads,
2006 num_attn_heads: cfg.num_attention_heads,
2007 sliding_window: None, k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2009 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2010 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2011 };
2012
2013 Ok(Box::new(cfg))
2014 }
2015}
2016
2017pub struct Starcoder2Loader;
2023
2024impl NormalModelLoader for Starcoder2Loader {
2025 fn load(
2026 &self,
2027 config: &str,
2028 vb: ShardedVarBuilder,
2029 normal_loading_metadata: NormalLoadingMetadata,
2030 attention_mechanism: AttentionImplementation,
2031 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2032 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2033
2034 Ok(Box::new(models::starcoder2::Model::new(
2035 &cfg,
2036 vb,
2037 self.is_gptx(config)?,
2038 normal_loading_metadata,
2039 attention_mechanism,
2040 )?))
2041 }
2042 fn load_xlora(
2043 &self,
2044 config: &str,
2045 vb: ShardedVarBuilder,
2046 lora_config: &[((String, String), LoraConfig)],
2047 xlora_config: Option<XLoraConfig>,
2048 xlora_ordering: Ordering,
2049 normal_loading_metadata: NormalLoadingMetadata,
2050 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2051 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2052 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2053
2054 Ok(Box::new(xlora_models::XLoraStarcoder2::new(
2055 &cfg,
2056 vb,
2057 lora_config,
2058 xlora_config,
2059 xlora_ordering,
2060 self.is_gptx(config)?,
2061 normal_loading_metadata,
2062 preload_adapters,
2063 )?))
2064 }
2065 fn is_gptx(&self, _: &str) -> Result<bool> {
2066 Ok(true)
2067 }
2068 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2069 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2070
2071 Ok(Box::new(cfg))
2072 }
2073}
2074
2075impl IsqModelLoader for Starcoder2Loader {
2076 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2077 Ok(vec![
2078 Regex::new(r"lm_head\.(weight|bias)$")?,
2079 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2081 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2082 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2083 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2084 Regex::new(r"layers\.(\d+)\.mlp\.fc1\.(weight|bias)$")?,
2086 Regex::new(r"layers\.(\d+)\.mlp\.c_proj\.(weight|bias)$")?,
2087 ])
2088 }
2089 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2090 self.isq_layer_regexes(config)
2091 }
2092}
2093
2094impl DeviceMappedModelLoader for Starcoder2Loader {
2095 fn mapped_max_act_size_elems(
2096 &self,
2097 config: &str,
2098 params: &AutoDeviceMapParams,
2099 ) -> Result<usize> {
2100 let AutoDeviceMapParams::Text {
2101 max_seq_len,
2102 max_batch_size,
2103 } = params
2104 else {
2105 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2106 };
2107
2108 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2109
2110 Ok(
2111 max_batch_size
2112 * cfg.num_attention_heads
2113 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2114 )
2115 }
2116 fn non_mapped_max_act_size_elems(
2117 &self,
2118 _config: &str,
2119 _params: &AutoDeviceMapParams,
2120 ) -> Result<usize> {
2121 Ok(0)
2122 }
2123
2124 fn non_mapped_size_in_bytes(
2125 &self,
2126 config: &str,
2127 dtype: DType,
2128 weight_pack_factor: usize,
2129 _matformer_config: Option<&MatformerSliceConfig>,
2130 ) -> Result<usize> {
2131 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2132
2133 let elems = {
2134 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2135 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2137 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2138 } else {
2139 0
2140 };
2141 let norm = cfg.hidden_size + cfg.hidden_size;
2142 embed_tokens + lm_head + norm
2143 };
2144 Ok(elems * dtype.size_in_bytes())
2145 }
2146
2147 fn layer_sizes_in_bytes(
2148 &self,
2149 config: &str,
2150 dtype: DType,
2151 weight_pack_factor: usize,
2152 _matformer_config: Option<&MatformerSliceConfig>,
2153 ) -> Result<Vec<usize>> {
2154 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2155
2156 let per_layer_elems = {
2157 let input_layernorm = cfg.hidden_size + cfg.hidden_size;
2158 let post_attention_layernorm = cfg.hidden_size + cfg.hidden_size;
2159
2160 let size_in = cfg.hidden_size;
2161 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
2162 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
2163 let q_proj = size_in * size_q / weight_pack_factor + bias_if!(cfg.use_bias, size_q);
2164 let k_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2165 let v_proj = size_in * size_kv / weight_pack_factor + bias_if!(cfg.use_bias, size_kv);
2166 let o_proj = size_q * size_in / weight_pack_factor + bias_if!(cfg.use_bias, size_in);
2167
2168 let h_size = cfg.hidden_size;
2169 let i_size = cfg.intermediate_size;
2170 let fc1 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, i_size);
2171 let fc2 = h_size * i_size / weight_pack_factor + bias_if!(cfg.use_bias, h_size);
2172
2173 input_layernorm
2174 + post_attention_layernorm
2175 + q_proj
2176 + k_proj
2177 + v_proj
2178 + o_proj
2179 + fc1
2180 + fc2
2181 };
2182 Ok(vec![
2183 per_layer_elems * dtype.size_in_bytes();
2184 cfg.num_hidden_layers
2185 ])
2186 }
2187
2188 fn num_layers(&self, config: &str) -> Result<usize> {
2189 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2190
2191 Ok(cfg.num_hidden_layers)
2192 }
2193
2194 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2195 let cfg: crate::models::starcoder2::Config = serde_json::from_str(config)?;
2196
2197 let cfg = ModelConfigMetadata {
2198 max_seq_len: cfg.max_position_embeddings,
2199 num_layers: cfg.num_hidden_layers,
2200 hidden_size: cfg.hidden_size,
2201 num_kv_heads: cfg.num_key_value_heads,
2202 num_attn_heads: cfg.num_attention_heads,
2203 sliding_window: cfg.sliding_window,
2204 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2205 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
2206 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2207 };
2208
2209 Ok(Box::new(cfg))
2210 }
2211}
2212
2213pub struct Phi3_5MoELoader;
2219
2220impl NormalModelLoader for Phi3_5MoELoader {
2221 fn load(
2222 &self,
2223 config: &str,
2224 vb: ShardedVarBuilder,
2225 normal_loading_metadata: NormalLoadingMetadata,
2226 attention_mechanism: AttentionImplementation,
2227 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2228 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2229
2230 Ok(Box::new(models::phi3_5_moe::Model::new(
2231 &cfg,
2232 vb,
2233 self.is_gptx(config)?,
2234 normal_loading_metadata,
2235 attention_mechanism,
2236 )?))
2237 }
2238 fn load_xlora(
2239 &self,
2240 config: &str,
2241 vb: ShardedVarBuilder,
2242 lora_config: &[((String, String), LoraConfig)],
2243 xlora_config: Option<XLoraConfig>,
2244 xlora_ordering: Ordering,
2245 normal_loading_metadata: NormalLoadingMetadata,
2246 preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2247 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2248 let cfg: crate::models::phi3::Config = serde_json::from_str(config)?;
2249
2250 Ok(Box::new(xlora_models::XLoraPhi3::new(
2251 &cfg,
2252 vb,
2253 lora_config,
2254 xlora_config,
2255 xlora_ordering,
2256 self.is_gptx(config)?,
2257 normal_loading_metadata,
2258 preload_adapters,
2259 )?))
2260 }
2261 fn is_gptx(&self, _: &str) -> Result<bool> {
2262 Ok(true)
2263 }
2264 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2265 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2266
2267 Ok(Box::new(cfg))
2268 }
2269}
2270
2271impl IsqModelLoader for Phi3_5MoELoader {
2272 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
2273 Ok(vec![
2274 Regex::new(r"lm_head\.(weight|bias)$")?,
2275 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
2277 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
2278 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
2279 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2280 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2282 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2283 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2284 ])
2285 }
2286 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2287 self.isq_layer_regexes(config)
2288 }
2289
2290 fn isq_layer_regexes_moqe(&self, _config: &str) -> Result<Vec<Regex>> {
2291 Ok(vec![
2292 Regex::new(r"lm_head\.(weight|bias)$")?,
2293 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w1\.(weight|bias)$")?,
2295 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w2\.(weight|bias)$")?,
2296 Regex::new(r"layers\.(\d+)\.block_sparse_moe\.experts\.(\d+)\.w3\.(weight|bias)$")?,
2297 ])
2298 }
2299 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2300 self.isq_layer_regexes_moqe(config)
2301 }
2302}
2303
2304impl DeviceMappedModelLoader for Phi3_5MoELoader {
2305 fn mapped_max_act_size_elems(
2306 &self,
2307 config: &str,
2308 params: &AutoDeviceMapParams,
2309 ) -> Result<usize> {
2310 let AutoDeviceMapParams::Text {
2311 max_seq_len,
2312 max_batch_size,
2313 } = params
2314 else {
2315 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2316 };
2317
2318 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2319
2320 Ok(
2321 max_batch_size
2322 * cfg.num_attention_heads
2323 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2324 )
2325 }
2326 fn non_mapped_max_act_size_elems(
2327 &self,
2328 _config: &str,
2329 _params: &AutoDeviceMapParams,
2330 ) -> Result<usize> {
2331 Ok(0)
2332 }
2333
2334 fn non_mapped_size_in_bytes(
2335 &self,
2336 config: &str,
2337 dtype: DType,
2338 weight_pack_factor: usize,
2339 _matformer_config: Option<&MatformerSliceConfig>,
2340 ) -> Result<usize> {
2341 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2342
2343 let elems = {
2344 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2345 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2347 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2348 } else {
2349 0
2350 };
2351 let norm = cfg.hidden_size;
2352 embed_tokens + lm_head + norm
2353 };
2354 Ok(elems * dtype.size_in_bytes())
2355 }
2356
2357 fn layer_sizes_in_bytes(
2358 &self,
2359 config: &str,
2360 dtype: DType,
2361 weight_pack_factor: usize,
2362 _matformer_config: Option<&MatformerSliceConfig>,
2363 ) -> Result<Vec<usize>> {
2364 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2365
2366 let per_layer_elems = {
2367 let input_layernorm = cfg.hidden_size;
2368 let post_attention_layernorm = cfg.hidden_size;
2369
2370 let size_in = cfg.hidden_size;
2371 let size_q = cfg.head_dim() * cfg.num_attention_heads;
2372 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
2373 let q_proj =
2374 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
2375 let k_proj =
2376 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2377 let v_proj =
2378 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
2379 let o_proj =
2380 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
2381
2382 let moe_block = {
2383 let gate = cfg.hidden_size * cfg.num_local_experts;
2384 let w1 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2386 let w2 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2387 let w3 = cfg.hidden_size * cfg.intermediate_size / weight_pack_factor;
2388 gate + cfg.num_local_experts * w1
2389 + cfg.num_local_experts * w2
2390 + cfg.num_local_experts * w3
2391 };
2392
2393 input_layernorm
2394 + post_attention_layernorm
2395 + q_proj
2396 + k_proj
2397 + v_proj
2398 + o_proj
2399 + moe_block
2400 };
2401 Ok(vec![
2402 per_layer_elems * dtype.size_in_bytes();
2403 cfg.num_hidden_layers
2404 ])
2405 }
2406
2407 fn num_layers(&self, config: &str) -> Result<usize> {
2408 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2409
2410 Ok(cfg.num_hidden_layers)
2411 }
2412
2413 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2414 let cfg: crate::models::phi3_5_moe::Config = serde_json::from_str(config)?;
2415
2416 let cfg = ModelConfigMetadata {
2417 max_seq_len: cfg.max_position_embeddings,
2418 num_layers: cfg.num_hidden_layers,
2419 hidden_size: cfg.hidden_size,
2420 num_kv_heads: cfg.num_key_value_heads,
2421 num_attn_heads: cfg.num_attention_heads,
2422 sliding_window: cfg.sliding_window,
2423 k_head_dim: cfg.head_dim(),
2424 v_head_dim: cfg.head_dim(),
2425 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2426 };
2427
2428 Ok(Box::new(cfg))
2429 }
2430}
2431
2432pub struct DeepSeekV2Loader;
2436
2437impl NormalModelLoader for DeepSeekV2Loader {
2438 fn load(
2439 &self,
2440 config: &str,
2441 vb: ShardedVarBuilder,
2442 normal_loading_metadata: NormalLoadingMetadata,
2443 attention_mechanism: AttentionImplementation,
2444 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2445 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2446
2447 Ok(Box::new(models::deepseek2::DeepSeekV2::new(
2448 &cfg,
2449 vb,
2450 self.is_gptx(config)?,
2451 normal_loading_metadata,
2452 attention_mechanism,
2453 )?))
2454 }
2455 fn load_xlora(
2456 &self,
2457 _config: &str,
2458 _vb: ShardedVarBuilder,
2459 _lora_config: &[((String, String), LoraConfig)],
2460 _xlora_config: Option<XLoraConfig>,
2461 _xlora_ordering: Ordering,
2462 _normal_loading_metadata: NormalLoadingMetadata,
2463 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2464 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2465 todo!()
2466 }
2467 fn is_gptx(&self, _: &str) -> Result<bool> {
2468 Ok(true)
2469 }
2470 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2471 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2472 Ok(Box::new(cfg))
2473 }
2474}
2475
2476impl IsqModelLoader for DeepSeekV2Loader {
2477 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2478 let mut data = vec![
2479 Regex::new(r"lm_head\.(weight|bias)$")?,
2480 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2482 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2483 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2484 ];
2485 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2486 if cfg.q_lora_rank.is_some() {
2487 data.extend(vec![
2488 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2489 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2490 ]);
2491 } else {
2492 data.push(Regex::new(
2493 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2494 )?);
2495 }
2496 for layer_idx in 0..cfg.num_hidden_layers {
2497 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2498 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2499 }) {
2500 for i in 0..n_routed_experts {
2501 data.extend(vec![
2502 Regex::new(&format!(
2503 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2504 ))?,
2505 Regex::new(&format!(
2506 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2507 ))?,
2508 Regex::new(&format!(
2509 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2510 ))?,
2511 ]);
2512 }
2513 if cfg.n_shared_experts.is_some() {
2514 data.extend(vec![
2515 Regex::new(&format!(
2516 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2517 ))?,
2518 Regex::new(&format!(
2519 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2520 ))?,
2521 Regex::new(&format!(
2522 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2523 ))?,
2524 ]);
2525 }
2526 } else {
2527 data.extend(vec![
2528 Regex::new(&format!(
2529 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2530 ))?,
2531 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2532 Regex::new(&format!(
2533 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2534 ))?,
2535 ]);
2536 };
2537 }
2538 Ok(data)
2539 }
2540 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2541 self.isq_layer_regexes(config)
2542 }
2543
2544 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2545 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2546 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2547 for layer_idx in 0..cfg.num_hidden_layers {
2548 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2549 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2550 }) {
2551 for i in 0..n_routed_experts {
2552 data.extend(vec![
2553 Regex::new(&format!(
2554 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2555 ))?,
2556 Regex::new(&format!(
2557 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2558 ))?,
2559 Regex::new(&format!(
2560 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2561 ))?,
2562 ]);
2563 }
2564 if cfg.n_shared_experts.is_some() {
2565 data.extend(vec![
2566 Regex::new(&format!(
2567 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2568 ))?,
2569 Regex::new(&format!(
2570 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2571 ))?,
2572 Regex::new(&format!(
2573 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2574 ))?,
2575 ]);
2576 }
2577 } else {
2578 data.extend(vec![
2579 Regex::new(&format!(
2580 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2581 ))?,
2582 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2583 Regex::new(&format!(
2584 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2585 ))?,
2586 ]);
2587 };
2588 }
2589 Ok(data)
2590 }
2591 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2592 self.isq_layer_regexes_moqe(config)
2593 }
2594}
2595
2596impl DeviceMappedModelLoader for DeepSeekV2Loader {
2597 fn mapped_max_act_size_elems(
2598 &self,
2599 config: &str,
2600 params: &AutoDeviceMapParams,
2601 ) -> Result<usize> {
2602 let AutoDeviceMapParams::Text {
2603 max_seq_len,
2604 max_batch_size,
2605 } = params
2606 else {
2607 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2608 };
2609
2610 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2611
2612 Ok(
2613 max_batch_size
2614 * cfg.num_attention_heads
2615 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2616 )
2617 }
2618 fn non_mapped_max_act_size_elems(
2619 &self,
2620 _config: &str,
2621 _params: &AutoDeviceMapParams,
2622 ) -> Result<usize> {
2623 Ok(0)
2624 }
2625
2626 fn non_mapped_size_in_bytes(
2627 &self,
2628 config: &str,
2629 dtype: DType,
2630 weight_pack_factor: usize,
2631 _matformer_config: Option<&MatformerSliceConfig>,
2632 ) -> Result<usize> {
2633 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2634 let elems = {
2635 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2636 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2638 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2639 } else {
2640 0
2641 };
2642 let norm = cfg.hidden_size;
2643 embed_tokens + lm_head + norm
2644 };
2645 Ok(elems * dtype.size_in_bytes())
2646 }
2647
2648 fn layer_sizes_in_bytes(
2649 &self,
2650 config: &str,
2651 dtype: DType,
2652 weight_pack_factor: usize,
2653 _matformer_config: Option<&MatformerSliceConfig>,
2654 ) -> Result<Vec<usize>> {
2655 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2656 let mut per_layer_elems = Vec::new();
2657
2658 for layer_idx in 0..cfg.num_hidden_layers {
2659 let input_layernorm = cfg.hidden_size;
2660 let post_attention_layernorm = cfg.hidden_size;
2661
2662 let q_proj = match cfg.q_lora_rank {
2663 Some(lora_rank) => {
2664 let a = cfg.hidden_size * lora_rank;
2665 let norm = lora_rank;
2666 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2667 a + norm + b
2668 }
2669 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2670 };
2671 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
2672 / weight_pack_factor
2673 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
2674 let kv_a_layernorm = cfg.kv_lora_rank;
2675 let kv_b_proj = cfg.kv_lora_rank
2676 * cfg.num_attention_heads
2677 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
2678 / weight_pack_factor;
2679 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
2680 / weight_pack_factor
2681 + bias_if!(cfg.attention_bias, cfg.hidden_size);
2682
2683 let moe_block = {
2684 let mut sum = 0;
2685 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2686 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2687 }) {
2688 let h_size = cfg.hidden_size;
2689 let gate_proj =
2690 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2691 let up_proj =
2692 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
2693 let down_proj =
2694 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
2695 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
2696 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2697 / weight_pack_factor;
2698 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
2699 / weight_pack_factor;
2700 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
2701 / weight_pack_factor;
2702 gate_proj + up_proj + down_proj
2703 } else {
2704 0
2705 };
2706 let gate_weight = n_routed_experts * cfg.hidden_size;
2707 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
2708 } else {
2709 let h_size = cfg.hidden_size;
2710 let i_size = cfg.intermediate_size;
2711 let gate_proj = h_size * i_size / weight_pack_factor;
2712 let up_proj = h_size * i_size / weight_pack_factor;
2713 let down_proj = i_size * h_size / weight_pack_factor;
2714 sum += gate_proj + up_proj + down_proj;
2715 }
2716 sum
2717 };
2718
2719 per_layer_elems.push(
2720 input_layernorm
2721 + post_attention_layernorm
2722 + q_proj
2723 + kv_a_layernorm
2724 + kv_a_proj_with_mqa
2725 + kv_b_proj
2726 + o_proj
2727 + moe_block,
2728 );
2729 }
2730
2731 Ok(per_layer_elems
2732 .into_iter()
2733 .map(|x| x * dtype.size_in_bytes())
2734 .collect())
2735 }
2736
2737 fn num_layers(&self, config: &str) -> Result<usize> {
2738 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2739 Ok(cfg.num_hidden_layers)
2740 }
2741
2742 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
2743 let cfg: crate::models::deepseek2::DeepSeekV2Config = serde_json::from_str(config)?;
2744
2745 let cfg = ModelConfigMetadata {
2746 max_seq_len: cfg.max_position_embeddings,
2747 num_layers: cfg.num_hidden_layers,
2748 hidden_size: cfg.hidden_size,
2749 num_kv_heads: cfg.num_attention_heads,
2750 num_attn_heads: cfg.num_attention_heads,
2751 sliding_window: None,
2752 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
2753 v_head_dim: cfg.v_head_dim,
2754 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
2755 };
2756
2757 Ok(Box::new(cfg))
2758 }
2759}
2760
2761pub struct DeepSeekV3Loader;
2765
2766impl NormalModelLoader for DeepSeekV3Loader {
2767 fn load(
2768 &self,
2769 config: &str,
2770 vb: ShardedVarBuilder,
2771 normal_loading_metadata: NormalLoadingMetadata,
2772 attention_mechanism: AttentionImplementation,
2773 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2774 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2775 Ok(Box::new(models::deepseek3::DeepSeekV3::new(
2776 &cfg,
2777 vb,
2778 self.is_gptx(config)?,
2779 normal_loading_metadata,
2780 attention_mechanism,
2781 )?))
2782 }
2783 fn load_xlora(
2784 &self,
2785 _config: &str,
2786 _vb: ShardedVarBuilder,
2787 _lora_config: &[((String, String), LoraConfig)],
2788 _xlora_config: Option<XLoraConfig>,
2789 _xlora_ordering: Ordering,
2790 _normal_loading_metadata: NormalLoadingMetadata,
2791 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
2792 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
2793 todo!()
2794 }
2795 fn is_gptx(&self, _: &str) -> Result<bool> {
2796 Ok(true)
2797 }
2798 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
2799 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2800 Ok(Box::new(cfg))
2801 }
2802}
2803
2804impl IsqModelLoader for DeepSeekV3Loader {
2805 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
2806 let mut data = vec![
2807 Regex::new(r"lm_head\.(weight|bias)$")?,
2808 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
2810 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
2811 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
2812 ];
2813 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2814 if cfg.q_lora_rank.is_some() {
2815 data.extend(vec![
2816 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
2817 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
2818 ]);
2819 } else {
2820 data.push(Regex::new(
2821 r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$",
2822 )?);
2823 }
2824 for layer_idx in 0..cfg.num_hidden_layers {
2825 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2826 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2827 }) {
2828 for i in 0..n_routed_experts {
2829 data.extend(vec![
2830 Regex::new(&format!(
2831 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2832 ))?,
2833 Regex::new(&format!(
2834 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2835 ))?,
2836 Regex::new(&format!(
2837 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2838 ))?,
2839 ]);
2840 }
2841 if cfg.n_shared_experts.is_some() {
2842 data.extend(vec![
2843 Regex::new(&format!(
2844 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2845 ))?,
2846 Regex::new(&format!(
2847 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2848 ))?,
2849 Regex::new(&format!(
2850 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2851 ))?,
2852 ]);
2853 }
2854 } else {
2855 data.extend(vec![
2856 Regex::new(&format!(
2857 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2858 ))?,
2859 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2860 Regex::new(&format!(
2861 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2862 ))?,
2863 ]);
2864 };
2865 }
2866 Ok(data)
2867 }
2868 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
2869 self.isq_layer_regexes(config)
2870 }
2871
2872 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2873 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
2874 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2875 for layer_idx in 0..cfg.num_hidden_layers {
2876 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
2877 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
2878 }) {
2879 for i in 0..n_routed_experts {
2880 data.extend(vec![
2881 Regex::new(&format!(
2882 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
2883 ))?,
2884 Regex::new(&format!(
2885 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
2886 ))?,
2887 Regex::new(&format!(
2888 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
2889 ))?,
2890 ]);
2891 }
2892 if cfg.n_shared_experts.is_some() {
2893 data.extend(vec![
2894 Regex::new(&format!(
2895 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
2896 ))?,
2897 Regex::new(&format!(
2898 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
2899 ))?,
2900 Regex::new(&format!(
2901 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
2902 ))?,
2903 ]);
2904 }
2905 } else {
2906 data.extend(vec![
2907 Regex::new(&format!(
2908 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
2909 ))?,
2910 Regex::new(&format!(r"layers.{layer_idx}.mlp\.up_proj\.(weight|bias)$"))?,
2911 Regex::new(&format!(
2912 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
2913 ))?,
2914 ]);
2915 };
2916 }
2917 Ok(data)
2918 }
2919 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
2920 self.isq_layer_regexes_moqe(config)
2921 }
2922}
2923
2924impl DeviceMappedModelLoader for DeepSeekV3Loader {
2925 fn mapped_max_act_size_elems(
2926 &self,
2927 config: &str,
2928 params: &AutoDeviceMapParams,
2929 ) -> Result<usize> {
2930 let AutoDeviceMapParams::Text {
2931 max_seq_len,
2932 max_batch_size,
2933 } = params
2934 else {
2935 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
2936 };
2937
2938 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2939
2940 Ok(
2941 max_batch_size
2942 * cfg.num_attention_heads
2943 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
2944 )
2945 }
2946 fn non_mapped_max_act_size_elems(
2947 &self,
2948 _config: &str,
2949 _params: &AutoDeviceMapParams,
2950 ) -> Result<usize> {
2951 Ok(0)
2952 }
2953
2954 fn non_mapped_size_in_bytes(
2955 &self,
2956 config: &str,
2957 dtype: DType,
2958 weight_pack_factor: usize,
2959 _matformer_config: Option<&MatformerSliceConfig>,
2960 ) -> Result<usize> {
2961 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2962 let elems = {
2963 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
2964 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
2966 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
2967 } else {
2968 0
2969 };
2970 let norm = cfg.hidden_size;
2971 embed_tokens + lm_head + norm
2972 };
2973 Ok(elems * dtype.size_in_bytes())
2974 }
2975
2976 fn layer_sizes_in_bytes(
2977 &self,
2978 config: &str,
2979 dtype: DType,
2980 weight_pack_factor: usize,
2981 _matformer_config: Option<&MatformerSliceConfig>,
2982 ) -> Result<Vec<usize>> {
2983 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
2984 let mut per_layer_elems = Vec::new();
2985
2986 for layer_idx in 0..cfg.num_hidden_layers {
2987 let input_layernorm = cfg.hidden_size;
2988 let post_attention_layernorm = cfg.hidden_size;
2989
2990 let q_proj = match cfg.q_lora_rank {
2991 Some(lora_rank) => {
2992 let a = cfg.hidden_size * lora_rank;
2993 let norm = lora_rank;
2994 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * lora_rank;
2995 a + norm + b
2996 }
2997 None => (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.hidden_size,
2998 };
2999 let kv_a_proj_with_mqa = cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim)
3000 / weight_pack_factor
3001 + bias_if!(cfg.attention_bias, cfg.kv_lora_rank + cfg.qk_rope_head_dim);
3002 let kv_a_layernorm = cfg.kv_lora_rank;
3003 let kv_b_proj = cfg.kv_lora_rank
3004 * cfg.num_attention_heads
3005 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3006 / weight_pack_factor;
3007 let o_proj = cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size
3008 / weight_pack_factor
3009 + bias_if!(cfg.attention_bias, cfg.hidden_size);
3010
3011 let moe_block = {
3012 let mut sum = 0;
3013 if let Some(n_routed_experts) = cfg.n_routed_experts.filter(|_| {
3014 layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0
3015 }) {
3016 let h_size = cfg.hidden_size;
3017 let gate_proj =
3018 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3019 let up_proj =
3020 h_size * cfg.moe_intermediate_size / weight_pack_factor * n_routed_experts;
3021 let down_proj =
3022 cfg.moe_intermediate_size * h_size / weight_pack_factor * n_routed_experts;
3023 let shared_experts = if let Some(n_shared_experts) = cfg.n_shared_experts {
3024 let gate_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3025 / weight_pack_factor;
3026 let up_proj = h_size * (cfg.intermediate_size * n_shared_experts)
3027 / weight_pack_factor;
3028 let down_proj = (cfg.intermediate_size * n_shared_experts) * h_size
3029 / weight_pack_factor;
3030 gate_proj + up_proj + down_proj
3031 } else {
3032 0
3033 };
3034 let gate_weight = n_routed_experts * cfg.hidden_size;
3035 sum += gate_proj + up_proj + down_proj + shared_experts + gate_weight;
3036 } else {
3037 let h_size = cfg.hidden_size;
3038 let i_size = cfg.intermediate_size;
3039 let gate_proj = h_size * i_size / weight_pack_factor;
3040 let up_proj = h_size * i_size / weight_pack_factor;
3041 let down_proj = i_size * h_size / weight_pack_factor;
3042 sum += gate_proj + up_proj + down_proj;
3043 }
3044 sum
3045 };
3046
3047 per_layer_elems.push(
3048 input_layernorm
3049 + post_attention_layernorm
3050 + q_proj
3051 + kv_a_layernorm
3052 + kv_a_proj_with_mqa
3053 + kv_b_proj
3054 + o_proj
3055 + moe_block,
3056 );
3057 }
3058
3059 Ok(per_layer_elems
3060 .into_iter()
3061 .map(|x| x * dtype.size_in_bytes())
3062 .collect())
3063 }
3064
3065 fn num_layers(&self, config: &str) -> Result<usize> {
3066 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3067 Ok(cfg.num_hidden_layers)
3068 }
3069
3070 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3071 let cfg: crate::models::deepseek3::DeepSeekV3Config = serde_json::from_str(config)?;
3072
3073 let cfg = ModelConfigMetadata {
3074 max_seq_len: cfg.max_position_embeddings,
3075 num_layers: cfg.num_hidden_layers,
3076 hidden_size: cfg.hidden_size,
3077 num_kv_heads: cfg.num_attention_heads,
3078 num_attn_heads: cfg.num_attention_heads,
3079 sliding_window: None,
3080 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3081 v_head_dim: cfg.v_head_dim,
3082 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3083 };
3084
3085 Ok(Box::new(cfg))
3086 }
3087}
3088
3089pub struct Qwen3Loader;
3093
3094impl NormalModelLoader for Qwen3Loader {
3095 fn load(
3096 &self,
3097 config: &str,
3098 vb: ShardedVarBuilder,
3099 normal_loading_metadata: NormalLoadingMetadata,
3100 attention_mechanism: AttentionImplementation,
3101 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3102 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3103
3104 Ok(Box::new(models::qwen3::Model::new(
3105 &cfg,
3106 vb,
3107 self.is_gptx(config)?,
3108 normal_loading_metadata,
3109 attention_mechanism,
3110 )?))
3111 }
3112 fn load_xlora(
3113 &self,
3114 _config: &str,
3115 _vb: ShardedVarBuilder,
3116 _lora_config: &[((String, String), LoraConfig)],
3117 _xlora_config: Option<XLoraConfig>,
3118 _xlora_ordering: Ordering,
3119 _normal_loading_metadata: NormalLoadingMetadata,
3120 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3121 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3122 todo!()
3123 }
3124 fn is_gptx(&self, _: &str) -> Result<bool> {
3125 Ok(true)
3126 }
3127 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3128 let cfg: crate::models::qwen3::Config = serde_json::from_str(config)?;
3129
3130 Ok(Box::new(cfg))
3131 }
3132}
3133
3134impl IsqModelLoader for Qwen3Loader {
3135 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3136 Ok(vec![
3137 Regex::new(r"lm_head\.(weight|bias)$")?,
3138 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3140 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3141 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3142 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3143 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3145 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3146 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3147 ])
3148 }
3149 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3150 self.isq_layer_regexes(config)
3151 }
3152}
3153
3154impl DeviceMappedModelLoader for Qwen3Loader {
3155 fn mapped_max_act_size_elems(
3156 &self,
3157 config: &str,
3158 params: &AutoDeviceMapParams,
3159 ) -> Result<usize> {
3160 let AutoDeviceMapParams::Text {
3161 max_seq_len,
3162 max_batch_size,
3163 } = params
3164 else {
3165 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3166 };
3167
3168 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3169
3170 Ok(
3171 max_batch_size
3172 * cfg.num_attention_heads
3173 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3174 )
3175 }
3176 fn non_mapped_max_act_size_elems(
3177 &self,
3178 _config: &str,
3179 _params: &AutoDeviceMapParams,
3180 ) -> Result<usize> {
3181 Ok(0)
3182 }
3183
3184 fn non_mapped_size_in_bytes(
3185 &self,
3186 config: &str,
3187 dtype: DType,
3188 weight_pack_factor: usize,
3189 _matformer_config: Option<&MatformerSliceConfig>,
3190 ) -> Result<usize> {
3191 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3192 let elems = {
3193 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3194 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3196 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3197 } else {
3198 0
3199 };
3200 let norm = cfg.hidden_size;
3201 embed_tokens + lm_head + norm
3202 };
3203 Ok(elems * dtype.size_in_bytes())
3204 }
3205
3206 fn layer_sizes_in_bytes(
3207 &self,
3208 config: &str,
3209 dtype: DType,
3210 weight_pack_factor: usize,
3211 _matformer_config: Option<&MatformerSliceConfig>,
3212 ) -> Result<Vec<usize>> {
3213 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3214 let per_layer_elems = {
3215 let input_layernorm = cfg.hidden_size;
3216 let post_attention_layernorm = cfg.hidden_size;
3217
3218 let size_in = cfg.hidden_size;
3219 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3220 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3221 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3222 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3223 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3224 let o_proj = size_q * size_in / weight_pack_factor;
3225
3226 let h_size = cfg.hidden_size;
3227 let i_size = cfg.intermediate_size;
3228 let gate_proj = h_size * i_size / weight_pack_factor;
3229 let up_proj = h_size * i_size / weight_pack_factor;
3230 let down_proj = i_size * h_size / weight_pack_factor;
3231
3232 let q_norm = cfg.head_dim();
3233 let k_norm = cfg.head_dim();
3234
3235 input_layernorm
3236 + post_attention_layernorm
3237 + q_proj
3238 + k_proj
3239 + v_proj
3240 + o_proj
3241 + gate_proj
3242 + up_proj
3243 + down_proj
3244 + q_norm
3245 + k_norm
3246 };
3247 Ok(vec![
3248 per_layer_elems * dtype.size_in_bytes();
3249 cfg.num_hidden_layers
3250 ])
3251 }
3252
3253 fn num_layers(&self, config: &str) -> Result<usize> {
3254 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3255 Ok(cfg.num_hidden_layers)
3256 }
3257
3258 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3259 let cfg: models::qwen3::Config = serde_json::from_str(config)?;
3260
3261 let cfg = ModelConfigMetadata {
3262 max_seq_len: cfg.max_position_embeddings,
3263 num_layers: cfg.num_hidden_layers,
3264 hidden_size: cfg.hidden_size,
3265 num_kv_heads: cfg.num_key_value_heads,
3266 num_attn_heads: cfg.num_attention_heads,
3267 sliding_window: cfg.sliding_window,
3268 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3269 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3270 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3271 };
3272
3273 Ok(Box::new(cfg))
3274 }
3275}
3276
3277pub struct GLM4Loader;
3281
3282impl NormalModelLoader for GLM4Loader {
3283 fn load(
3284 &self,
3285 config: &str,
3286 vb: ShardedVarBuilder,
3287 normal_loading_metadata: NormalLoadingMetadata,
3288 attention_mechanism: AttentionImplementation,
3289 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3290 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3291
3292 Ok(Box::new(models::glm4::Model::new(
3293 &cfg,
3294 vb,
3295 self.is_gptx(config)?,
3296 normal_loading_metadata,
3297 attention_mechanism,
3298 )?))
3299 }
3300 fn load_xlora(
3301 &self,
3302 _config: &str,
3303 _vb: ShardedVarBuilder,
3304 _lora_config: &[((String, String), LoraConfig)],
3305 _xlora_config: Option<XLoraConfig>,
3306 _xlora_ordering: Ordering,
3307 _normal_loading_metadata: NormalLoadingMetadata,
3308 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3309 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3310 todo!()
3311 }
3312 fn is_gptx(&self, _: &str) -> Result<bool> {
3313 Ok(true)
3314 }
3315 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3316 let cfg: crate::models::glm4::Config = serde_json::from_str(config)?;
3317
3318 Ok(Box::new(cfg))
3319 }
3320}
3321
3322impl IsqModelLoader for GLM4Loader {
3323 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
3324 Ok(vec![
3325 Regex::new(r"lm_head\.(weight|bias)$")?,
3326 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3328 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3329 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3330 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3331 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
3333 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
3334 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
3335 ])
3336 }
3337 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3338 self.isq_layer_regexes(config)
3339 }
3340}
3341
3342impl DeviceMappedModelLoader for GLM4Loader {
3343 fn mapped_max_act_size_elems(
3344 &self,
3345 config: &str,
3346 params: &AutoDeviceMapParams,
3347 ) -> Result<usize> {
3348 let AutoDeviceMapParams::Text {
3349 max_seq_len,
3350 max_batch_size,
3351 } = params
3352 else {
3353 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3354 };
3355
3356 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3357
3358 Ok(
3359 max_batch_size
3360 * cfg.num_attention_heads
3361 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3362 )
3363 }
3364 fn non_mapped_max_act_size_elems(
3365 &self,
3366 _config: &str,
3367 _params: &AutoDeviceMapParams,
3368 ) -> Result<usize> {
3369 Ok(0)
3370 }
3371
3372 fn non_mapped_size_in_bytes(
3373 &self,
3374 config: &str,
3375 dtype: DType,
3376 weight_pack_factor: usize,
3377 _matformer_config: Option<&MatformerSliceConfig>,
3378 ) -> Result<usize> {
3379 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3380 let elems = {
3381 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3382 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3384 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3385 } else {
3386 0
3387 };
3388 let norm = cfg.hidden_size;
3389 embed_tokens + lm_head + norm
3390 };
3391 Ok(elems * dtype.size_in_bytes())
3392 }
3393
3394 fn layer_sizes_in_bytes(
3395 &self,
3396 config: &str,
3397 dtype: DType,
3398 weight_pack_factor: usize,
3399 _matformer_config: Option<&MatformerSliceConfig>,
3400 ) -> Result<Vec<usize>> {
3401 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3402 let per_layer_elems = {
3403 let input_layernorm = cfg.hidden_size;
3404 let post_attention_layernorm = cfg.hidden_size * 3; let size_in = cfg.hidden_size;
3407 let size_q = cfg.head_dim() * cfg.num_attention_heads;
3408 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
3409 let q_proj = size_in * size_q / weight_pack_factor + size_q;
3410 let k_proj = size_in * size_kv / weight_pack_factor + size_kv;
3411 let v_proj = size_in * size_kv / weight_pack_factor + size_kv;
3412 let o_proj = size_q * size_in / weight_pack_factor;
3413
3414 let h_size = cfg.hidden_size;
3415 let i_size = cfg.intermediate_size;
3416 let gate_proj = h_size * i_size / weight_pack_factor;
3417 let up_proj = h_size * i_size / weight_pack_factor;
3418 let down_proj = i_size * h_size / weight_pack_factor;
3419
3420 input_layernorm
3421 + post_attention_layernorm
3422 + q_proj
3423 + k_proj
3424 + v_proj
3425 + o_proj
3426 + gate_proj
3427 + up_proj
3428 + down_proj
3429 };
3430 Ok(vec![
3431 per_layer_elems * dtype.size_in_bytes();
3432 cfg.num_hidden_layers
3433 ])
3434 }
3435
3436 fn num_layers(&self, config: &str) -> Result<usize> {
3437 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3438 Ok(cfg.num_hidden_layers)
3439 }
3440
3441 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3442 let cfg: models::glm4::Config = serde_json::from_str(config)?;
3443
3444 let cfg = ModelConfigMetadata {
3445 max_seq_len: cfg.max_position_embeddings,
3446 num_layers: cfg.num_hidden_layers,
3447 hidden_size: cfg.hidden_size,
3448 num_kv_heads: cfg.num_key_value_heads,
3449 num_attn_heads: cfg.num_attention_heads,
3450 sliding_window: cfg.sliding_window,
3451 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3452 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
3453 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3454 };
3455
3456 Ok(Box::new(cfg))
3457 }
3458}
3459
3460pub struct GLM4MoeLiteLoader;
3464
3465impl NormalModelLoader for GLM4MoeLiteLoader {
3466 fn load(
3467 &self,
3468 config: &str,
3469 vb: ShardedVarBuilder,
3470 normal_loading_metadata: NormalLoadingMetadata,
3471 attention_mechanism: AttentionImplementation,
3472 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3473 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3474 Ok(Box::new(models::glm4_moe_lite::Glm4MoeLite::new(
3475 &cfg,
3476 vb,
3477 self.is_gptx(config)?,
3478 normal_loading_metadata,
3479 attention_mechanism,
3480 )?))
3481 }
3482 fn load_xlora(
3483 &self,
3484 _config: &str,
3485 _vb: ShardedVarBuilder,
3486 _lora_config: &[((String, String), LoraConfig)],
3487 _xlora_config: Option<XLoraConfig>,
3488 _xlora_ordering: Ordering,
3489 _normal_loading_metadata: NormalLoadingMetadata,
3490 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3491 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3492 todo!()
3493 }
3494 fn is_gptx(&self, _: &str) -> Result<bool> {
3495 Ok(true)
3496 }
3497 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3498 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3499 Ok(Box::new(cfg))
3500 }
3501}
3502
3503impl IsqModelLoader for GLM4MoeLiteLoader {
3504 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3505 let mut data = vec![
3506 Regex::new(r"lm_head\.(weight|bias)$")?,
3507 Regex::new(r"layers\.(\d+)\.self_attn\.kv_a_proj_with_mqa\.(weight|bias)$")?,
3509 Regex::new(r"layers\.(\d+)\.self_attn\.kv_b_proj\.(weight|bias)$")?,
3510 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3511 Regex::new(r"layers\.(\d+)\.self_attn\.q_a_proj\.(weight|bias)$")?,
3513 Regex::new(r"layers\.(\d+)\.self_attn\.q_b_proj\.(weight|bias)$")?,
3514 ];
3515 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3516 for layer_idx in 0..cfg.num_hidden_layers {
3517 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3518 for i in 0..cfg.n_routed_experts {
3520 data.extend(vec![
3521 Regex::new(&format!(
3522 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3523 ))?,
3524 Regex::new(&format!(
3525 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3526 ))?,
3527 Regex::new(&format!(
3528 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3529 ))?,
3530 ]);
3531 }
3532 if cfg.n_shared_experts > 0 {
3533 data.extend(vec![
3534 Regex::new(&format!(
3535 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3536 ))?,
3537 Regex::new(&format!(
3538 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3539 ))?,
3540 Regex::new(&format!(
3541 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3542 ))?,
3543 ]);
3544 }
3545 } else {
3546 data.extend(vec![
3548 Regex::new(&format!(
3549 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3550 ))?,
3551 Regex::new(&format!(
3552 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3553 ))?,
3554 Regex::new(&format!(
3555 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3556 ))?,
3557 ]);
3558 };
3559 }
3560 Ok(data)
3561 }
3562 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3563 self.isq_layer_regexes(config)
3564 }
3565
3566 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3567 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3568 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3569 for layer_idx in 0..cfg.num_hidden_layers {
3570 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3571 for i in 0..cfg.n_routed_experts {
3573 data.extend(vec![
3574 Regex::new(&format!(
3575 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3576 ))?,
3577 Regex::new(&format!(
3578 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3579 ))?,
3580 Regex::new(&format!(
3581 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3582 ))?,
3583 ]);
3584 }
3585 if cfg.n_shared_experts > 0 {
3586 data.extend(vec![
3587 Regex::new(&format!(
3588 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3589 ))?,
3590 Regex::new(&format!(
3591 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3592 ))?,
3593 Regex::new(&format!(
3594 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3595 ))?,
3596 ]);
3597 }
3598 } else {
3599 data.extend(vec![
3601 Regex::new(&format!(
3602 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3603 ))?,
3604 Regex::new(&format!(
3605 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3606 ))?,
3607 Regex::new(&format!(
3608 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3609 ))?,
3610 ]);
3611 };
3612 }
3613 Ok(data)
3614 }
3615 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3616 self.isq_layer_regexes_moqe(config)
3617 }
3618}
3619
3620impl DeviceMappedModelLoader for GLM4MoeLiteLoader {
3621 fn mapped_max_act_size_elems(
3622 &self,
3623 config: &str,
3624 params: &AutoDeviceMapParams,
3625 ) -> Result<usize> {
3626 let AutoDeviceMapParams::Text {
3627 max_seq_len,
3628 max_batch_size,
3629 } = params
3630 else {
3631 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3632 };
3633
3634 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3635
3636 Ok(
3637 max_batch_size
3638 * cfg.num_attention_heads
3639 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3640 )
3641 }
3642 fn non_mapped_max_act_size_elems(
3643 &self,
3644 _config: &str,
3645 _params: &AutoDeviceMapParams,
3646 ) -> Result<usize> {
3647 Ok(0)
3648 }
3649
3650 fn non_mapped_size_in_bytes(
3651 &self,
3652 config: &str,
3653 dtype: DType,
3654 weight_pack_factor: usize,
3655 _matformer_config: Option<&MatformerSliceConfig>,
3656 ) -> Result<usize> {
3657 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3658 let elems = {
3659 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3660 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3662 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3663 } else {
3664 0
3665 };
3666 let norm = cfg.hidden_size;
3667 embed_tokens + lm_head + norm
3668 };
3669 Ok(elems * dtype.size_in_bytes())
3670 }
3671
3672 fn layer_sizes_in_bytes(
3673 &self,
3674 config: &str,
3675 dtype: DType,
3676 weight_pack_factor: usize,
3677 _matformer_config: Option<&MatformerSliceConfig>,
3678 ) -> Result<Vec<usize>> {
3679 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3680 let mut per_layer_elems = Vec::new();
3681
3682 for layer_idx in 0..cfg.num_hidden_layers {
3683 let input_layernorm = cfg.hidden_size;
3684 let post_attention_layernorm = cfg.hidden_size;
3685
3686 let q_proj = {
3688 let a = cfg.hidden_size * cfg.q_lora_rank / weight_pack_factor;
3689 let norm = cfg.q_lora_rank;
3690 let b = (cfg.num_attention_heads * cfg.q_head_dim()) * cfg.q_lora_rank
3691 / weight_pack_factor;
3692 a + norm + b
3693 };
3694 let kv_a_proj_with_mqa =
3695 cfg.hidden_size * (cfg.kv_lora_rank + cfg.qk_rope_head_dim) / weight_pack_factor;
3696 let kv_a_layernorm = cfg.kv_lora_rank;
3697 let kv_b_proj = cfg.kv_lora_rank
3698 * cfg.num_attention_heads
3699 * (cfg.q_head_dim() - cfg.qk_rope_head_dim + cfg.v_head_dim)
3700 / weight_pack_factor;
3701 let o_proj =
3702 cfg.num_attention_heads * cfg.v_head_dim * cfg.hidden_size / weight_pack_factor;
3703
3704 let moe_block = {
3705 let mut sum = 0;
3706 if layer_idx >= cfg.first_k_dense_replace && layer_idx % cfg.moe_layer_freq == 0 {
3707 let h_size = cfg.hidden_size;
3709 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3710 * cfg.n_routed_experts;
3711 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
3712 * cfg.n_routed_experts;
3713 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
3714 * cfg.n_routed_experts;
3715 let shared_experts = if cfg.n_shared_experts > 0 {
3716 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3717 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
3718 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
3719 gate_proj + up_proj + down_proj
3720 } else {
3721 0
3722 };
3723 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
3724 let e_score_correction_bias = cfg.n_routed_experts;
3725 sum += gate_proj
3726 + up_proj
3727 + down_proj
3728 + shared_experts
3729 + gate_weight
3730 + e_score_correction_bias;
3731 } else {
3732 let h_size = cfg.hidden_size;
3734 let i_size = cfg.intermediate_size;
3735 let gate_proj = h_size * i_size / weight_pack_factor;
3736 let up_proj = h_size * i_size / weight_pack_factor;
3737 let down_proj = i_size * h_size / weight_pack_factor;
3738 sum += gate_proj + up_proj + down_proj;
3739 }
3740 sum
3741 };
3742
3743 per_layer_elems.push(
3744 input_layernorm
3745 + post_attention_layernorm
3746 + q_proj
3747 + kv_a_layernorm
3748 + kv_a_proj_with_mqa
3749 + kv_b_proj
3750 + o_proj
3751 + moe_block,
3752 );
3753 }
3754
3755 Ok(per_layer_elems
3756 .into_iter()
3757 .map(|x| x * dtype.size_in_bytes())
3758 .collect())
3759 }
3760
3761 fn num_layers(&self, config: &str) -> Result<usize> {
3762 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3763 Ok(cfg.num_hidden_layers)
3764 }
3765
3766 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
3767 let cfg: crate::models::glm4_moe_lite::Glm4MoeLiteConfig = serde_json::from_str(config)?;
3768
3769 let cfg = ModelConfigMetadata {
3770 max_seq_len: cfg.max_position_embeddings,
3771 num_layers: cfg.num_hidden_layers,
3772 hidden_size: cfg.hidden_size,
3773 num_kv_heads: cfg.num_attention_heads,
3774 num_attn_heads: cfg.num_attention_heads,
3775 sliding_window: None,
3776 k_head_dim: cfg.qk_rope_head_dim + cfg.qk_nope_head_dim,
3777 v_head_dim: cfg.v_head_dim,
3778 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
3779 };
3780
3781 Ok(Box::new(cfg))
3782 }
3783}
3784
3785pub struct GLM4MoeLoader;
3789
3790impl NormalModelLoader for GLM4MoeLoader {
3791 fn load(
3792 &self,
3793 config: &str,
3794 vb: ShardedVarBuilder,
3795 normal_loading_metadata: NormalLoadingMetadata,
3796 attention_mechanism: AttentionImplementation,
3797 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3798 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3799 Ok(Box::new(models::glm4_moe::Glm4Moe::new(
3800 &cfg,
3801 vb,
3802 self.is_gptx(config)?,
3803 normal_loading_metadata,
3804 attention_mechanism,
3805 )?))
3806 }
3807 fn load_xlora(
3808 &self,
3809 _config: &str,
3810 _vb: ShardedVarBuilder,
3811 _lora_config: &[((String, String), LoraConfig)],
3812 _xlora_config: Option<XLoraConfig>,
3813 _xlora_ordering: Ordering,
3814 _normal_loading_metadata: NormalLoadingMetadata,
3815 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
3816 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
3817 todo!()
3818 }
3819 fn is_gptx(&self, _: &str) -> Result<bool> {
3820 Ok(true)
3821 }
3822 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
3823 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3824 Ok(Box::new(cfg))
3825 }
3826}
3827
3828impl IsqModelLoader for GLM4MoeLoader {
3829 fn isq_layer_regexes(&self, config: &str) -> Result<Vec<Regex>> {
3830 let mut data = vec![
3831 Regex::new(r"lm_head\.(weight|bias)$")?,
3832 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
3834 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
3835 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
3836 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
3837 ];
3838 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3839 for layer_idx in 0..cfg.num_hidden_layers {
3840 if layer_idx >= cfg.first_k_dense_replace {
3841 for i in 0..cfg.n_routed_experts {
3843 data.extend(vec![
3844 Regex::new(&format!(
3845 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3846 ))?,
3847 Regex::new(&format!(
3848 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3849 ))?,
3850 Regex::new(&format!(
3851 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3852 ))?,
3853 ]);
3854 }
3855 if cfg.n_shared_experts > 0 {
3856 data.extend(vec![
3857 Regex::new(&format!(
3858 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3859 ))?,
3860 Regex::new(&format!(
3861 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3862 ))?,
3863 Regex::new(&format!(
3864 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3865 ))?,
3866 ]);
3867 }
3868 } else {
3869 data.extend(vec![
3871 Regex::new(&format!(
3872 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3873 ))?,
3874 Regex::new(&format!(
3875 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3876 ))?,
3877 Regex::new(&format!(
3878 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3879 ))?,
3880 ]);
3881 };
3882 }
3883 Ok(data)
3884 }
3885 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
3886 self.isq_layer_regexes(config)
3887 }
3888
3889 fn isq_layer_regexes_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3890 let mut data = vec![Regex::new(r"lm_head\.(weight|bias)$")?];
3891 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3892 for layer_idx in 0..cfg.num_hidden_layers {
3893 if layer_idx >= cfg.first_k_dense_replace {
3894 for i in 0..cfg.n_routed_experts {
3896 data.extend(vec![
3897 Regex::new(&format!(
3898 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.gate_proj\.(weight|bias)$"
3899 ))?,
3900 Regex::new(&format!(
3901 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.up_proj\.(weight|bias)$"
3902 ))?,
3903 Regex::new(&format!(
3904 r"layers\.{layer_idx}\.mlp\.experts\.{i}\.down_proj\.(weight|bias)$"
3905 ))?,
3906 ]);
3907 }
3908 if cfg.n_shared_experts > 0 {
3909 data.extend(vec![
3910 Regex::new(&format!(
3911 r"layers\.{layer_idx}\.mlp\.shared_experts\.gate_proj\.(weight|bias)$"
3912 ))?,
3913 Regex::new(&format!(
3914 r"layers\.{layer_idx}\.mlp\.shared_experts\.up_proj\.(weight|bias)$"
3915 ))?,
3916 Regex::new(&format!(
3917 r"layers\.{layer_idx}\.mlp\.shared_experts\.down_proj\.(weight|bias)$"
3918 ))?,
3919 ]);
3920 }
3921 } else {
3922 data.extend(vec![
3924 Regex::new(&format!(
3925 r"layers\.{layer_idx}\.mlp\.gate_proj\.(weight|bias)$"
3926 ))?,
3927 Regex::new(&format!(
3928 r"layers\.{layer_idx}\.mlp\.up_proj\.(weight|bias)$"
3929 ))?,
3930 Regex::new(&format!(
3931 r"layers\.{layer_idx}\.mlp\.down_proj\.(weight|bias)$"
3932 ))?,
3933 ]);
3934 };
3935 }
3936 Ok(data)
3937 }
3938 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
3939 self.isq_layer_regexes_moqe(config)
3940 }
3941}
3942
3943impl DeviceMappedModelLoader for GLM4MoeLoader {
3944 fn mapped_max_act_size_elems(
3945 &self,
3946 config: &str,
3947 params: &AutoDeviceMapParams,
3948 ) -> Result<usize> {
3949 let AutoDeviceMapParams::Text {
3950 max_seq_len,
3951 max_batch_size,
3952 } = params
3953 else {
3954 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
3955 };
3956
3957 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3958
3959 Ok(
3960 max_batch_size
3961 * cfg.num_attention_heads
3962 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
3963 )
3964 }
3965 fn non_mapped_max_act_size_elems(
3966 &self,
3967 _config: &str,
3968 _params: &AutoDeviceMapParams,
3969 ) -> Result<usize> {
3970 Ok(0)
3971 }
3972
3973 fn non_mapped_size_in_bytes(
3974 &self,
3975 config: &str,
3976 dtype: DType,
3977 weight_pack_factor: usize,
3978 _matformer_config: Option<&MatformerSliceConfig>,
3979 ) -> Result<usize> {
3980 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
3981 let elems = {
3982 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
3983 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
3984 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
3985 } else {
3986 0
3987 };
3988 let norm = cfg.hidden_size;
3989 embed_tokens + lm_head + norm
3990 };
3991 Ok(elems * dtype.size_in_bytes())
3992 }
3993
3994 fn layer_sizes_in_bytes(
3995 &self,
3996 config: &str,
3997 dtype: DType,
3998 weight_pack_factor: usize,
3999 _matformer_config: Option<&MatformerSliceConfig>,
4000 ) -> Result<Vec<usize>> {
4001 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4002 let mut per_layer_elems = Vec::new();
4003
4004 let head_dim = cfg.head_dim();
4005 for layer_idx in 0..cfg.num_hidden_layers {
4006 let input_layernorm = cfg.hidden_size;
4007 let post_attention_layernorm = cfg.hidden_size;
4008
4009 let q_proj = cfg.hidden_size * cfg.num_attention_heads * head_dim / weight_pack_factor
4011 + bias_if!(cfg.attention_bias, cfg.num_attention_heads * head_dim);
4012 let k_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4013 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4014 let v_proj = cfg.hidden_size * cfg.num_key_value_heads * head_dim / weight_pack_factor
4015 + bias_if!(cfg.attention_bias, cfg.num_key_value_heads * head_dim);
4016 let o_proj = cfg.num_attention_heads * head_dim * cfg.hidden_size / weight_pack_factor;
4017
4018 let qk_norm = if cfg.use_qk_norm {
4020 head_dim * 2 } else {
4022 0
4023 };
4024
4025 let moe_block = {
4026 let mut sum = 0;
4027 if layer_idx >= cfg.first_k_dense_replace {
4028 let h_size = cfg.hidden_size;
4030 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4031 * cfg.n_routed_experts;
4032 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor
4033 * cfg.n_routed_experts;
4034 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor
4035 * cfg.n_routed_experts;
4036 let shared_experts = if cfg.n_shared_experts > 0 {
4037 let gate_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4038 let up_proj = h_size * cfg.moe_intermediate_size / weight_pack_factor;
4039 let down_proj = cfg.moe_intermediate_size * h_size / weight_pack_factor;
4040 gate_proj + up_proj + down_proj
4041 } else {
4042 0
4043 };
4044 let gate_weight = cfg.n_routed_experts * cfg.hidden_size;
4045 let e_score_correction_bias = cfg.n_routed_experts;
4046 sum += gate_proj
4047 + up_proj
4048 + down_proj
4049 + shared_experts
4050 + gate_weight
4051 + e_score_correction_bias;
4052 } else {
4053 let h_size = cfg.hidden_size;
4055 let i_size = cfg.intermediate_size;
4056 let gate_proj = h_size * i_size / weight_pack_factor;
4057 let up_proj = h_size * i_size / weight_pack_factor;
4058 let down_proj = i_size * h_size / weight_pack_factor;
4059 sum += gate_proj + up_proj + down_proj;
4060 }
4061 sum
4062 };
4063
4064 per_layer_elems.push(
4065 input_layernorm
4066 + post_attention_layernorm
4067 + q_proj
4068 + k_proj
4069 + v_proj
4070 + o_proj
4071 + qk_norm
4072 + moe_block,
4073 );
4074 }
4075
4076 Ok(per_layer_elems
4077 .into_iter()
4078 .map(|x| x * dtype.size_in_bytes())
4079 .collect())
4080 }
4081
4082 fn num_layers(&self, config: &str) -> Result<usize> {
4083 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4084 Ok(cfg.num_hidden_layers)
4085 }
4086
4087 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4088 let cfg: crate::models::glm4_moe::Glm4MoeConfig = serde_json::from_str(config)?;
4089
4090 let head_dim = cfg.head_dim();
4091 let cfg = ModelConfigMetadata {
4092 max_seq_len: cfg.max_position_embeddings,
4093 num_layers: cfg.num_hidden_layers,
4094 hidden_size: cfg.hidden_size,
4095 num_kv_heads: cfg.num_key_value_heads,
4096 num_attn_heads: cfg.num_attention_heads,
4097 sliding_window: None,
4098 k_head_dim: head_dim,
4099 v_head_dim: head_dim,
4100 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4101 };
4102
4103 Ok(Box::new(cfg))
4104 }
4105}
4106
4107pub struct Qwen3MoELoader;
4111
4112impl NormalModelLoader for Qwen3MoELoader {
4113 fn load(
4114 &self,
4115 config: &str,
4116 vb: ShardedVarBuilder,
4117 normal_loading_metadata: NormalLoadingMetadata,
4118 attention_mechanism: AttentionImplementation,
4119 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4120 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4121
4122 Ok(Box::new(models::qwen3_moe::Model::new(
4123 &cfg,
4124 vb,
4125 self.is_gptx(config)?,
4126 normal_loading_metadata,
4127 attention_mechanism,
4128 )?))
4129 }
4130 fn load_xlora(
4131 &self,
4132 _config: &str,
4133 _vb: ShardedVarBuilder,
4134 _lora_config: &[((String, String), LoraConfig)],
4135 _xlora_config: Option<XLoraConfig>,
4136 _xlora_ordering: Ordering,
4137 _normal_loading_metadata: NormalLoadingMetadata,
4138 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4139 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4140 todo!()
4141 }
4142 fn is_gptx(&self, _: &str) -> Result<bool> {
4143 Ok(true)
4144 }
4145 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4146 let cfg: crate::models::qwen3_moe::Config = serde_json::from_str(config)?;
4147
4148 Ok(Box::new(cfg))
4149 }
4150}
4151
4152impl IsqModelLoader for Qwen3MoELoader {
4153 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4154 Ok(vec![
4155 Regex::new(r"lm_head\.(weight|bias)$")?,
4156 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4158 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4159 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4160 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4161 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4163 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4164 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4165 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.gate_proj\.(weight|bias)$")?,
4167 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.up_proj\.(weight|bias)$")?,
4168 Regex::new(r"layers\.(\d+)\.mlp\.experts\.(\d+)\.down_proj\.(weight|bias)$")?,
4169 ])
4170 }
4171 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4172 self.isq_layer_regexes(config)
4173 }
4174 fn immediate_isq_predicates_moqe(&self, config: &str) -> Result<Vec<Regex>> {
4175 self.isq_layer_regexes_moqe(config)
4176 }
4177}
4178
4179impl DeviceMappedModelLoader for Qwen3MoELoader {
4180 fn mapped_max_act_size_elems(
4181 &self,
4182 config: &str,
4183 params: &AutoDeviceMapParams,
4184 ) -> Result<usize> {
4185 let AutoDeviceMapParams::Text {
4186 max_seq_len,
4187 max_batch_size,
4188 } = params
4189 else {
4190 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4191 };
4192
4193 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4194
4195 Ok(
4196 max_batch_size
4197 * cfg.num_attention_heads
4198 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4199 )
4200 }
4201 fn non_mapped_max_act_size_elems(
4202 &self,
4203 _config: &str,
4204 _params: &AutoDeviceMapParams,
4205 ) -> Result<usize> {
4206 Ok(0)
4207 }
4208
4209 fn non_mapped_size_in_bytes(
4210 &self,
4211 config: &str,
4212 dtype: DType,
4213 weight_pack_factor: usize,
4214 _matformer_config: Option<&MatformerSliceConfig>,
4215 ) -> Result<usize> {
4216 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4217 let elems = {
4218 let embed_tokens = cfg.hidden_size * cfg.vocab_size;
4219 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4221 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4222 } else {
4223 0
4224 };
4225 let norm = cfg.hidden_size;
4226 embed_tokens + lm_head + norm
4227 };
4228 Ok(elems * dtype.size_in_bytes())
4229 }
4230
4231 fn layer_sizes_in_bytes(
4232 &self,
4233 config: &str,
4234 dtype: DType,
4235 weight_pack_factor: usize,
4236 _matformer_config: Option<&MatformerSliceConfig>,
4237 ) -> Result<Vec<usize>> {
4238 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4239
4240 let mut layer_sizes_in_bytes = Vec::new();
4241 for layer_idx in 0..cfg.num_hidden_layers {
4242 let input_layernorm = cfg.hidden_size;
4243 let post_attention_layernorm = cfg.hidden_size;
4244
4245 let size_in = cfg.hidden_size;
4246 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4247 let size_kv = cfg.head_dim() * cfg.num_key_value_heads;
4248 let q_proj = size_in * size_q / weight_pack_factor;
4249 let k_proj = size_in * size_kv / weight_pack_factor;
4250 let v_proj = size_in * size_kv / weight_pack_factor;
4251 let o_proj = size_q * size_in / weight_pack_factor;
4252
4253 let mlp_size = if !cfg.mlp_only_layers.contains(&layer_idx)
4254 && (cfg.num_experts > 0 && (layer_idx + 1) % cfg.decoder_sparse_step == 0)
4255 {
4256 let gate_size = cfg.hidden_size * cfg.num_experts;
4257 let expert_size = {
4258 let h_size = cfg.hidden_size;
4259 let i_size = cfg.moe_intermediate_size;
4260 let gate_proj = h_size * i_size / weight_pack_factor;
4261 let up_proj = h_size * i_size / weight_pack_factor;
4262 let down_proj = i_size * h_size / weight_pack_factor;
4263 gate_proj + up_proj + down_proj
4264 };
4265 expert_size * cfg.num_experts + gate_size
4266 } else {
4267 let h_size = cfg.hidden_size;
4268 let i_size = cfg.intermediate_size;
4269 let gate_proj = h_size * i_size / weight_pack_factor;
4270 let up_proj = h_size * i_size / weight_pack_factor;
4271 let down_proj = i_size * h_size / weight_pack_factor;
4272 gate_proj + up_proj + down_proj
4273 };
4274
4275 let q_norm = cfg.head_dim();
4276 let k_norm = cfg.head_dim();
4277
4278 let size_elems = input_layernorm
4279 + post_attention_layernorm
4280 + q_proj
4281 + k_proj
4282 + v_proj
4283 + o_proj
4284 + mlp_size
4285 + q_norm
4286 + k_norm;
4287
4288 let size_in_bytes = size_elems * dtype.size_in_bytes();
4289 layer_sizes_in_bytes.push(size_in_bytes);
4290 }
4291
4292 Ok(layer_sizes_in_bytes)
4293 }
4294
4295 fn num_layers(&self, config: &str) -> Result<usize> {
4296 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4297 Ok(cfg.num_hidden_layers)
4298 }
4299
4300 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4301 let cfg: models::qwen3_moe::Config = serde_json::from_str(config)?;
4302
4303 let cfg = ModelConfigMetadata {
4304 max_seq_len: cfg.max_position_embeddings,
4305 num_layers: cfg.num_hidden_layers,
4306 hidden_size: cfg.hidden_size,
4307 num_kv_heads: cfg.num_key_value_heads,
4308 num_attn_heads: cfg.num_attention_heads,
4309 sliding_window: cfg.sliding_window,
4310 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4311 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4312 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4313 };
4314
4315 Ok(Box::new(cfg))
4316 }
4317}
4318
4319pub struct SmolLm3Loader;
4325
4326impl NormalModelLoader for SmolLm3Loader {
4327 fn load(
4328 &self,
4329 config: &str,
4330 vb: ShardedVarBuilder,
4331 normal_loading_metadata: NormalLoadingMetadata,
4332 attention_mechanism: AttentionImplementation,
4333 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4334 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4335
4336 Ok(Box::new(models::smollm3::SmolLm3::new(
4337 &cfg,
4338 vb,
4339 self.is_gptx(config)?,
4340 normal_loading_metadata,
4341 attention_mechanism,
4342 )?))
4343 }
4344 fn load_xlora(
4345 &self,
4346 _config: &str,
4347 _vb: ShardedVarBuilder,
4348 _lora_config: &[((String, String), LoraConfig)],
4349 _xlora_config: Option<XLoraConfig>,
4350 _xlora_ordering: Ordering,
4351 _normal_loading_metadata: NormalLoadingMetadata,
4352 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4353 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4354 todo!()
4355 }
4356 fn is_gptx(&self, _: &str) -> Result<bool> {
4357 Ok(true)
4358 }
4359 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4360 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4361 Ok(Box::new(cfg))
4362 }
4363}
4364
4365impl IsqModelLoader for SmolLm3Loader {
4366 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4367 Ok(vec![
4368 Regex::new(r"lm_head\.(weight|bias)$")?,
4369 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4371 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4372 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4373 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4374 Regex::new(r"layers\.(\d+)\.mlp\.gate_proj\.(weight|bias)$")?,
4376 Regex::new(r"layers\.(\d+)\.mlp\.up_proj\.(weight|bias)$")?,
4377 Regex::new(r"layers\.(\d+)\.mlp\.down_proj\.(weight|bias)$")?,
4378 ])
4379 }
4380 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4381 self.isq_layer_regexes(config)
4382 }
4383}
4384
4385impl DeviceMappedModelLoader for SmolLm3Loader {
4386 fn mapped_max_act_size_elems(
4387 &self,
4388 config: &str,
4389 params: &AutoDeviceMapParams,
4390 ) -> Result<usize> {
4391 let AutoDeviceMapParams::Text {
4392 max_seq_len,
4393 max_batch_size,
4394 } = params
4395 else {
4396 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4397 };
4398
4399 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4400
4401 Ok(
4402 max_batch_size
4403 * cfg.num_attention_heads
4404 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4405 )
4406 }
4407 fn non_mapped_max_act_size_elems(
4408 &self,
4409 _config: &str,
4410 _params: &AutoDeviceMapParams,
4411 ) -> Result<usize> {
4412 Ok(0)
4413 }
4414
4415 fn non_mapped_size_in_bytes(
4416 &self,
4417 config: &str,
4418 dtype: DType,
4419 weight_pack_factor: usize,
4420 _matformer_config: Option<&MatformerSliceConfig>,
4421 ) -> Result<usize> {
4422 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4423
4424 let elems = {
4425 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4426 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4428 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4429 } else {
4430 0
4431 };
4432 let norm = cfg.hidden_size;
4433 embed_tokens + lm_head + norm
4434 };
4435 Ok(elems * dtype.size_in_bytes())
4436 }
4437
4438 fn layer_sizes_in_bytes(
4439 &self,
4440 config: &str,
4441 dtype: DType,
4442 weight_pack_factor: usize,
4443 _matformer_config: Option<&MatformerSliceConfig>,
4444 ) -> Result<Vec<usize>> {
4445 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4446
4447 let per_layer_elems = {
4448 let input_layernorm = cfg.hidden_size;
4449 let post_attention_layernorm = cfg.hidden_size;
4450
4451 let size_in = cfg.hidden_size;
4452 let size_q = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_attention_heads;
4453 let size_kv = (cfg.hidden_size / cfg.num_attention_heads) * cfg.num_key_value_heads;
4454 let q_proj = size_in * size_q / weight_pack_factor;
4455 let k_proj = size_in * size_kv / weight_pack_factor;
4456 let v_proj = size_in * size_kv / weight_pack_factor;
4457 let o_proj = size_q * size_in / weight_pack_factor;
4458
4459 let h_size = cfg.hidden_size;
4460 let i_size = cfg.intermediate_size;
4461 let gate_proj = h_size * i_size / weight_pack_factor;
4462 let up_proj = h_size * i_size / weight_pack_factor;
4463 let down_proj = i_size * h_size / weight_pack_factor;
4464
4465 input_layernorm
4466 + post_attention_layernorm
4467 + q_proj
4468 + k_proj
4469 + v_proj
4470 + o_proj
4471 + gate_proj
4472 + up_proj
4473 + down_proj
4474 };
4475 Ok(vec![
4476 per_layer_elems * dtype.size_in_bytes();
4477 cfg.num_hidden_layers
4478 ])
4479 }
4480
4481 fn num_layers(&self, config: &str) -> Result<usize> {
4482 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4483
4484 Ok(cfg.num_hidden_layers)
4485 }
4486 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4487 let cfg: crate::models::smollm3::Config = serde_json::from_str(config)?;
4488
4489 let cfg = ModelConfigMetadata {
4490 max_seq_len: cfg.max_position_embeddings,
4491 num_layers: cfg.num_hidden_layers,
4492 hidden_size: cfg.hidden_size,
4493 num_kv_heads: cfg.num_key_value_heads,
4494 num_attn_heads: cfg.num_attention_heads,
4495 sliding_window: None,
4496 k_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4497 v_head_dim: cfg.hidden_size / cfg.num_attention_heads,
4498 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4499 };
4500
4501 Ok(Box::new(cfg))
4502 }
4503}
4504
4505pub struct GraniteMoeHybridLoader;
4511
4512impl NormalModelLoader for GraniteMoeHybridLoader {
4513 fn load(
4514 &self,
4515 config: &str,
4516 vb: ShardedVarBuilder,
4517 normal_loading_metadata: NormalLoadingMetadata,
4518 attention_mechanism: AttentionImplementation,
4519 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4520 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4521
4522 Ok(Box::new(models::granite::GraniteMoeHybrid::new(
4523 &cfg,
4524 vb,
4525 self.is_gptx(config)?,
4526 normal_loading_metadata,
4527 attention_mechanism,
4528 )?))
4529 }
4530 fn load_xlora(
4531 &self,
4532 _config: &str,
4533 _vb: ShardedVarBuilder,
4534 _lora_config: &[((String, String), LoraConfig)],
4535 _xlora_config: Option<XLoraConfig>,
4536 _xlora_ordering: Ordering,
4537 _normal_loading_metadata: NormalLoadingMetadata,
4538 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4539 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4540 anyhow::bail!("GraniteMoeHybrid does not support X-LoRA")
4541 }
4542 fn is_gptx(&self, _: &str) -> Result<bool> {
4543 Ok(true)
4544 }
4545 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4546 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4547 Ok(Box::new(cfg))
4548 }
4549 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4550 Ok(true)
4551 }
4552}
4553
4554impl IsqModelLoader for GraniteMoeHybridLoader {
4555 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4556 Ok(vec![
4557 Regex::new(r"lm_head\.(weight|bias)$")?,
4558 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4560 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4561 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4562 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4563 Regex::new(r"layers\.(\d+)\.shared_mlp\.input_linear\.(weight|bias)$")?,
4565 Regex::new(r"layers\.(\d+)\.shared_mlp\.output_linear\.(weight|bias)$")?,
4566 ])
4567 }
4568 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4569 self.isq_layer_regexes(config)
4570 }
4571}
4572
4573impl DeviceMappedModelLoader for GraniteMoeHybridLoader {
4574 fn mapped_max_act_size_elems(
4575 &self,
4576 config: &str,
4577 params: &AutoDeviceMapParams,
4578 ) -> Result<usize> {
4579 let AutoDeviceMapParams::Text {
4580 max_seq_len,
4581 max_batch_size,
4582 } = params
4583 else {
4584 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4585 };
4586
4587 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4588
4589 Ok(
4590 max_batch_size
4591 * cfg.num_attention_heads
4592 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4593 )
4594 }
4595 fn non_mapped_max_act_size_elems(
4596 &self,
4597 _config: &str,
4598 _params: &AutoDeviceMapParams,
4599 ) -> Result<usize> {
4600 Ok(0)
4601 }
4602
4603 fn non_mapped_size_in_bytes(
4604 &self,
4605 config: &str,
4606 dtype: DType,
4607 weight_pack_factor: usize,
4608 _matformer_config: Option<&MatformerSliceConfig>,
4609 ) -> Result<usize> {
4610 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4611
4612 let elems = {
4613 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4614 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4616 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4617 } else {
4618 0
4619 };
4620 let norm = cfg.hidden_size;
4621 embed_tokens + lm_head + norm
4622 };
4623 Ok(elems * dtype.size_in_bytes())
4624 }
4625
4626 fn layer_sizes_in_bytes(
4627 &self,
4628 config: &str,
4629 dtype: DType,
4630 weight_pack_factor: usize,
4631 _matformer_config: Option<&MatformerSliceConfig>,
4632 ) -> Result<Vec<usize>> {
4633 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4634
4635 let per_layer_elems = {
4636 let input_layernorm = cfg.hidden_size;
4637 let post_attention_layernorm = cfg.hidden_size;
4638
4639 let size_in = cfg.hidden_size;
4640 let size_q = cfg.head_dim() * cfg.num_attention_heads;
4641 let size_kv = cfg.head_dim() * cfg.num_key_value_heads();
4642 let q_proj = size_in * size_q / weight_pack_factor;
4643 let k_proj = size_in * size_kv / weight_pack_factor;
4644 let v_proj = size_in * size_kv / weight_pack_factor;
4645 let o_proj = size_q * size_in / weight_pack_factor;
4646
4647 let h_size = cfg.hidden_size;
4648 let shared_i_size = cfg.shared_intermediate_size();
4649 let input_linear = h_size * shared_i_size * 2 / weight_pack_factor;
4651 let output_linear = shared_i_size * h_size / weight_pack_factor;
4652
4653 input_layernorm
4654 + post_attention_layernorm
4655 + q_proj
4656 + k_proj
4657 + v_proj
4658 + o_proj
4659 + input_linear
4660 + output_linear
4661 };
4662 Ok(vec![
4663 per_layer_elems * dtype.size_in_bytes();
4664 cfg.num_hidden_layers
4665 ])
4666 }
4667
4668 fn num_layers(&self, config: &str) -> Result<usize> {
4669 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4670
4671 Ok(cfg.num_hidden_layers)
4672 }
4673 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4674 let cfg: crate::models::granite::Config = serde_json::from_str(config)?;
4675
4676 let cfg = ModelConfigMetadata {
4677 max_seq_len: cfg.max_position_embeddings,
4678 num_layers: cfg.num_hidden_layers,
4679 hidden_size: cfg.hidden_size,
4680 num_kv_heads: cfg.num_key_value_heads(),
4681 num_attn_heads: cfg.num_attention_heads,
4682 sliding_window: None,
4683 k_head_dim: cfg.head_dim(),
4684 v_head_dim: cfg.head_dim(),
4685 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4686 };
4687
4688 Ok(Box::new(cfg))
4689 }
4690}
4691
4692pub struct GptOssLoader;
4698
4699impl NormalModelLoader for GptOssLoader {
4700 fn load(
4701 &self,
4702 config: &str,
4703 vb: ShardedVarBuilder,
4704 normal_loading_metadata: NormalLoadingMetadata,
4705 attention_mechanism: AttentionImplementation,
4706 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4707 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4708
4709 Ok(Box::new(models::gpt_oss::Model::new(
4710 &cfg,
4711 vb,
4712 self.is_gptx(config)?,
4713 normal_loading_metadata,
4714 attention_mechanism,
4715 )?))
4716 }
4717 fn load_xlora(
4718 &self,
4719 _config: &str,
4720 _vb: ShardedVarBuilder,
4721 _lora_config: &[((String, String), LoraConfig)],
4722 _xlora_config: Option<XLoraConfig>,
4723 _xlora_ordering: Ordering,
4724 _normal_loading_metadata: NormalLoadingMetadata,
4725 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4726 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4727 anyhow::bail!("GPT-OSS does not support X-LoRA")
4728 }
4729 fn is_gptx(&self, _: &str) -> Result<bool> {
4730 Ok(true)
4731 }
4732 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4733 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4734 Ok(Box::new(cfg))
4735 }
4736}
4737
4738impl IsqModelLoader for GptOssLoader {
4739 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4740 Ok(vec![
4742 Regex::new(r"lm_head\.(weight|bias)$")?,
4743 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4745 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4746 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4747 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4748 ])
4749 }
4750 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4751 self.isq_layer_regexes(config)
4752 }
4753}
4754
4755impl DeviceMappedModelLoader for GptOssLoader {
4756 fn mapped_max_act_size_elems(
4757 &self,
4758 config: &str,
4759 params: &AutoDeviceMapParams,
4760 ) -> Result<usize> {
4761 let AutoDeviceMapParams::Text {
4762 max_seq_len,
4763 max_batch_size,
4764 } = params
4765 else {
4766 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4767 };
4768
4769 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4770
4771 Ok(
4772 max_batch_size
4773 * cfg.num_attention_heads
4774 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4775 )
4776 }
4777 fn non_mapped_max_act_size_elems(
4778 &self,
4779 _config: &str,
4780 _params: &AutoDeviceMapParams,
4781 ) -> Result<usize> {
4782 Ok(0)
4783 }
4784
4785 fn non_mapped_size_in_bytes(
4786 &self,
4787 config: &str,
4788 dtype: DType,
4789 weight_pack_factor: usize,
4790 _matformer_config: Option<&MatformerSliceConfig>,
4791 ) -> Result<usize> {
4792 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4793
4794 let elems = {
4795 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
4796 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
4797 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
4798 } else {
4799 0
4800 };
4801 let norm = cfg.hidden_size;
4802 embed_tokens + lm_head + norm
4803 };
4804 Ok(elems * dtype.size_in_bytes())
4805 }
4806
4807 fn layer_sizes_in_bytes(
4808 &self,
4809 config: &str,
4810 dtype: DType,
4811 weight_pack_factor: usize,
4812 _matformer_config: Option<&MatformerSliceConfig>,
4813 ) -> Result<Vec<usize>> {
4814 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4815
4816 let per_layer_elems = {
4817 let input_layernorm = cfg.hidden_size;
4818 let post_attention_layernorm = cfg.hidden_size;
4819
4820 let size_in = cfg.hidden_size;
4821 let head_dim = cfg.head_dim();
4822 let size_q = head_dim * cfg.num_attention_heads;
4823 let size_kv = head_dim * cfg.num_key_value_heads;
4824 let q_proj =
4825 size_in * size_q / weight_pack_factor + bias_if!(cfg.attention_bias, size_q);
4826 let k_proj =
4827 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4828 let v_proj =
4829 size_in * size_kv / weight_pack_factor + bias_if!(cfg.attention_bias, size_kv);
4830 let o_proj =
4831 size_q * size_in / weight_pack_factor + bias_if!(cfg.attention_bias, size_in);
4832
4833 let mxfp4_pack = 2;
4838 let gate_up_proj_size =
4839 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / mxfp4_pack;
4840 let down_proj_size =
4841 cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / mxfp4_pack;
4842 let gate_up_scales =
4844 cfg.num_local_experts * cfg.intermediate_size * 2 * cfg.hidden_size / 32;
4845 let down_scales = cfg.num_local_experts * cfg.hidden_size * cfg.intermediate_size / 32;
4846 let gate_up_bias = cfg.num_local_experts * cfg.intermediate_size * 2;
4848 let down_bias = cfg.num_local_experts * cfg.hidden_size;
4849 let router = cfg.hidden_size * cfg.num_local_experts;
4851 let sinks = cfg.num_attention_heads;
4853
4854 input_layernorm
4855 + post_attention_layernorm
4856 + q_proj
4857 + k_proj
4858 + v_proj
4859 + o_proj
4860 + gate_up_proj_size
4861 + down_proj_size
4862 + gate_up_scales
4863 + down_scales
4864 + gate_up_bias
4865 + down_bias
4866 + router
4867 + sinks
4868 };
4869 Ok(vec![
4870 per_layer_elems * dtype.size_in_bytes();
4871 cfg.num_hidden_layers
4872 ])
4873 }
4874
4875 fn num_layers(&self, config: &str) -> Result<usize> {
4876 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4877
4878 Ok(cfg.num_hidden_layers)
4879 }
4880 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
4881 let cfg: crate::models::gpt_oss::Config = serde_json::from_str(config)?;
4882
4883 let head_dim = cfg.head_dim();
4884 let cfg = ModelConfigMetadata {
4885 max_seq_len: cfg.max_position_embeddings,
4886 num_layers: cfg.num_hidden_layers,
4887 hidden_size: cfg.hidden_size,
4888 num_kv_heads: cfg.num_key_value_heads,
4889 num_attn_heads: cfg.num_attention_heads,
4890 sliding_window: cfg.sliding_window,
4891 k_head_dim: head_dim,
4892 v_head_dim: head_dim,
4893 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
4894 };
4895
4896 Ok(Box::new(cfg))
4897 }
4898}
4899
4900pub struct Qwen3NextLoader;
4906
4907impl NormalModelLoader for Qwen3NextLoader {
4908 fn load(
4909 &self,
4910 config: &str,
4911 vb: ShardedVarBuilder,
4912 normal_loading_metadata: NormalLoadingMetadata,
4913 attention_mechanism: AttentionImplementation,
4914 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4915 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4916
4917 Ok(Box::new(models::qwen3_next::Model::new(
4918 &cfg,
4919 vb,
4920 self.is_gptx(config)?,
4921 normal_loading_metadata,
4922 attention_mechanism,
4923 )?))
4924 }
4925 fn load_xlora(
4926 &self,
4927 _config: &str,
4928 _vb: ShardedVarBuilder,
4929 _lora_config: &[((String, String), LoraConfig)],
4930 _xlora_config: Option<XLoraConfig>,
4931 _xlora_ordering: Ordering,
4932 _normal_loading_metadata: NormalLoadingMetadata,
4933 _preload_adapters: &Option<HashMap<String, (ShardedVarBuilder, LoraConfig)>>,
4934 ) -> Result<Box<dyn NormalModel + Send + Sync>> {
4935 anyhow::bail!("Qwen3Next does not support X-LoRA")
4936 }
4937 fn is_gptx(&self, _: &str) -> Result<bool> {
4938 Ok(true)
4939 }
4940 fn get_config_repr(&self, config: &str) -> Result<Box<dyn Debug>> {
4941 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4942 Ok(Box::new(cfg))
4943 }
4944 fn supports_paged_attention(&self, _config: &str) -> Result<bool> {
4945 Ok(true)
4946 }
4947}
4948
4949impl IsqModelLoader for Qwen3NextLoader {
4950 fn isq_layer_regexes(&self, _config: &str) -> Result<Vec<Regex>> {
4951 Ok(vec![
4952 Regex::new(r"lm_head\.(weight|bias)$")?,
4953 Regex::new(r"layers\.(\d+)\.self_attn\.q_proj\.(weight|bias)$")?,
4954 Regex::new(r"layers\.(\d+)\.self_attn\.k_proj\.(weight|bias)$")?,
4955 Regex::new(r"layers\.(\d+)\.self_attn\.v_proj\.(weight|bias)$")?,
4956 Regex::new(r"layers\.(\d+)\.self_attn\.o_proj\.(weight|bias)$")?,
4957 Regex::new(r"layers\.(\d+)\.linear_attn\.out_proj\.(weight|bias)$")?,
4958 Regex::new(
4959 r"layers\.(\d+)\.mlp\.experts\.(\d+)\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4960 )?,
4961 Regex::new(
4962 r"layers\.(\d+)\.mlp\.shared_expert\.(gate_proj|up_proj|down_proj)\.(weight|bias)$",
4963 )?,
4964 ])
4965 }
4966 fn immediate_isq_predicates(&self, config: &str) -> Result<Vec<Regex>> {
4967 self.isq_layer_regexes(config)
4968 }
4969}
4970
4971impl DeviceMappedModelLoader for Qwen3NextLoader {
4972 fn mapped_max_act_size_elems(
4973 &self,
4974 config: &str,
4975 params: &AutoDeviceMapParams,
4976 ) -> Result<usize> {
4977 let AutoDeviceMapParams::Text {
4978 max_seq_len,
4979 max_batch_size,
4980 } = params
4981 else {
4982 anyhow::bail!("Expected text AutoDeviceMapParams for this model!")
4983 };
4984
4985 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
4986
4987 Ok(
4988 max_batch_size
4989 * cfg.num_attention_heads
4990 * max_seq_len.min(&ATTENTION_CHUNK_SIZE).pow(2),
4991 )
4992 }
4993 fn non_mapped_max_act_size_elems(
4994 &self,
4995 _config: &str,
4996 _params: &AutoDeviceMapParams,
4997 ) -> Result<usize> {
4998 Ok(0)
4999 }
5000
5001 fn non_mapped_size_in_bytes(
5002 &self,
5003 config: &str,
5004 dtype: DType,
5005 weight_pack_factor: usize,
5006 _matformer_config: Option<&MatformerSliceConfig>,
5007 ) -> Result<usize> {
5008 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5009
5010 let elems = {
5011 let embed_tokens = cfg.hidden_size * cfg.vocab_size / weight_pack_factor;
5012 let lm_head = if !cfg.tie_word_embeddings || weight_pack_factor != 1 {
5013 cfg.hidden_size * cfg.vocab_size / weight_pack_factor
5014 } else {
5015 0
5016 };
5017 let norm = cfg.hidden_size;
5018 embed_tokens + lm_head + norm
5019 };
5020 Ok(elems * dtype.size_in_bytes())
5021 }
5022
5023 fn layer_sizes_in_bytes(
5024 &self,
5025 config: &str,
5026 dtype: DType,
5027 weight_pack_factor: usize,
5028 _matformer_config: Option<&MatformerSliceConfig>,
5029 ) -> Result<Vec<usize>> {
5030 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5031 let layer_types = cfg.layer_types();
5032 let mut layer_sizes = Vec::with_capacity(cfg.num_hidden_layers);
5033
5034 for layer_type in &layer_types {
5035 let input_layernorm = cfg.hidden_size;
5036 let post_attention_layernorm = cfg.hidden_size;
5037
5038 let attn_elems = match layer_type {
5039 crate::models::qwen3_next::LayerType::FullAttention => {
5040 let hidden = cfg.hidden_size;
5041 let q_dim = cfg.head_dim * cfg.num_attention_heads;
5042 let kv_dim = cfg.head_dim * cfg.num_key_value_heads;
5043 let q_proj = hidden * q_dim * 2 / weight_pack_factor;
5044 let k_proj = hidden * kv_dim / weight_pack_factor;
5045 let v_proj = hidden * kv_dim / weight_pack_factor;
5046 let o_proj = q_dim * hidden / weight_pack_factor;
5047 let q_norm = cfg.head_dim;
5048 let k_norm = cfg.head_dim;
5049 q_proj + k_proj + v_proj + o_proj + q_norm + k_norm
5050 }
5051 crate::models::qwen3_next::LayerType::LinearAttention => {
5052 let hidden = cfg.hidden_size;
5053 let key_dim = cfg.linear_key_dim();
5054 let value_dim = cfg.linear_value_dim();
5055 let conv_dim = cfg.linear_conv_dim();
5056 let in_proj_qkvz = hidden * (key_dim * 2 + value_dim * 2);
5058 let in_proj_ba = hidden * (cfg.linear_num_value_heads * 2);
5060 let out_proj = value_dim * hidden / weight_pack_factor;
5061 let conv1d = conv_dim * cfg.linear_conv_kernel_dim;
5062 let dt_bias = cfg.linear_num_value_heads;
5063 let a_log = cfg.linear_num_value_heads;
5064 let norm = cfg.linear_value_head_dim;
5065 in_proj_qkvz + in_proj_ba + out_proj + conv1d + dt_bias + a_log + norm
5066 }
5067 };
5068
5069 let moe_gate = cfg.hidden_size * cfg.num_experts;
5070 let shared_expert =
5071 3 * cfg.hidden_size * cfg.shared_expert_intermediate_size / weight_pack_factor;
5072 let routed_experts = cfg.num_experts * 3 * cfg.hidden_size * cfg.moe_intermediate_size
5073 / weight_pack_factor;
5074
5075 let per_layer_elems = input_layernorm
5076 + post_attention_layernorm
5077 + attn_elems
5078 + moe_gate
5079 + shared_expert
5080 + routed_experts;
5081
5082 layer_sizes.push(per_layer_elems * dtype.size_in_bytes());
5083 }
5084
5085 Ok(layer_sizes)
5086 }
5087
5088 fn num_layers(&self, config: &str) -> Result<usize> {
5089 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5090 Ok(cfg.num_hidden_layers)
5091 }
5092 fn model_config(&self, config: &str) -> Result<Box<dyn ModelConfigLike>> {
5093 let cfg: crate::models::qwen3_next::Config = serde_json::from_str(config)?;
5094
5095 let cfg = ModelConfigMetadata {
5096 max_seq_len: cfg.max_position_embeddings,
5097 num_layers: cfg.num_hidden_layers,
5098 hidden_size: cfg.hidden_size,
5099 num_kv_heads: cfg.num_key_value_heads,
5100 num_attn_heads: cfg.num_attention_heads,
5101 sliding_window: None,
5102 k_head_dim: cfg.head_dim,
5103 v_head_dim: cfg.head_dim,
5104 kv_cache_layout: crate::paged_attention::KvCacheLayout::Standard,
5105 };
5106
5107 Ok(Box::new(cfg))
5108 }
5109}