1use std::path::Path;
7
8use crate::gguf::{GgufFile, MetadataValue};
9use crate::tensor::{DType, Tensor};
10
11use super::Architecture;
12use super::config::{ActivationType, ModelConfig, RopeConfig, RopeScalingType, RopeType};
13use super::deltanet::{BetaAlphaProjection, DeltaNetConfig, DeltaNetLayer};
14use super::mamba::{MambaConfig, MambaLayer};
15use super::error::{ModelError, ModelResult};
16use super::layers::{
17 Attention, AttentionLayer, FeedForward, FfnLayer, LayerNorm, Linear, NormLayer,
18 NoGateFeedForward, RMSNorm, TransformerLayer,
19};
20use super::bert::{BertLayer, BertModel};
21use super::llama::LlamaModel;
22use super::moe::{MoeConfig, MoeExpert, MoeLayer, MoeRouter};
23
24pub trait ModelSource {
30 fn config(&self) -> &ModelConfig;
32
33 fn config_mut(&mut self) -> &mut ModelConfig;
35
36 fn architecture(&self) -> Architecture;
38
39 fn load_tensor(&self, name: &str) -> ModelResult<Tensor>;
41
42 fn try_load_tensor(&self, name: &str) -> Option<Tensor>;
44}
45
46pub struct ModelLoader {
48 gguf: GgufFile,
50 architecture: Architecture,
52 config: ModelConfig,
54}
55
56impl ModelLoader {
57 pub fn load<P: AsRef<Path>>(path: P) -> ModelResult<Self> {
59 let gguf = GgufFile::open(path)?;
60
61 let arch_str = gguf
63 .data
64 .get_string("general.architecture")
65 .ok_or_else(|| ModelError::MissingMetadata("general.architecture".into()))?;
66
67 let architecture = Architecture::from_gguf_str(arch_str);
68
69 if matches!(architecture, Architecture::Unknown) {
70 return Err(ModelError::UnsupportedArchitecture(arch_str.to_string()));
71 }
72
73 let config = Self::parse_config(&gguf, &architecture)?;
75
76 Ok(Self {
77 gguf,
78 architecture,
79 config,
80 })
81 }
82
83 fn parse_config(gguf: &GgufFile, architecture: &Architecture) -> ModelResult<ModelConfig> {
85 let arch = architecture.as_str();
86
87 let get_u32 = |key: &str| -> ModelResult<u32> {
89 gguf.data
90 .get_u32(key)
91 .ok_or_else(|| ModelError::MissingMetadata(key.into()))
92 };
93
94 let get_f32_or =
96 |key: &str, default: f32| -> f32 { gguf.data.get_f32(key).unwrap_or(default) };
97
98 let vocab_size = get_u32(&format!("{}.vocab_size", arch))
101 .or_else(|_| get_u32("tokenizer.ggml.vocab_size"))
102 .map(|v| v as usize)
103 .unwrap_or_else(|_| {
104 if let Some(tokens) = gguf.data.metadata.get("tokenizer.ggml.tokens")
106 && let MetadataValue::Array(arr) = tokens
107 {
108 return arr.values.len();
109 }
110 if let Some(emb_info) = gguf.data.get_tensor("token_embd.weight") {
112 if emb_info.dims.len() == 2 {
114 return emb_info.dims[1] as usize;
115 }
116 }
117 32000
119 });
120
121 let hidden_size = get_u32(&format!("{}.embedding_length", arch))? as usize;
122
123 let num_layers = get_u32(&format!("{}.block_count", arch))? as usize;
124
125 let (num_heads, num_kv_heads, head_dim) =
127 if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
128 let nh = get_u32(&format!("{}.attention.head_count", arch)).unwrap_or(1) as usize;
129 let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
130 .unwrap_or(nh as u32) as usize;
131 let hd = get_u32(&format!("{}.attention.key_length", arch))
132 .unwrap_or_else(|_| (hidden_size / nh.max(1)) as u32) as usize;
133 (nh, nkv, hd)
134 } else {
135 let nh = get_u32(&format!("{}.attention.head_count", arch))? as usize;
136 let nkv = get_u32(&format!("{}.attention.head_count_kv", arch))
137 .unwrap_or(nh as u32) as usize;
138 let hd = get_u32(&format!("{}.attention.key_length", arch))
139 .map(|v| v as usize)
140 .unwrap_or(hidden_size / nh);
141 (nh, nkv, hd)
142 };
143
144 let intermediate_size = get_u32(&format!("{}.feed_forward_length", arch))
145 .unwrap_or_else(|_| {
146 if matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
147 hidden_size as u32 } else {
149 (hidden_size * 4 * 2 / 3) as u32
150 }
151 }) as usize;
152
153 let max_seq_len = get_u32(&format!("{}.context_length", arch)).unwrap_or(2048) as usize;
154
155 let norm_eps = gguf
156 .data
157 .get_f32(&format!("{}.attention.layer_norm_rms_epsilon", arch))
158 .or_else(|| gguf.data.get_f32(&format!("{}.attention.layer_norm_epsilon", arch)))
159 .unwrap_or(1e-5);
160
161 let freq_base = get_f32_or(&format!("{}.rope.freq_base", arch), 10000.0);
163 let freq_scale = get_f32_or(&format!("{}.rope.scale_linear", arch), 1.0);
164
165 let rope_type = match architecture {
169 Architecture::Qwen2
170 | Architecture::Qwen2Moe
171 | Architecture::Qwen3
172 | Architecture::Qwen35
173 | Architecture::Qwen35Moe
174 | Architecture::Qwen3Moe
175 | Architecture::Qwen3Next
176 | Architecture::GPTNeoX
177 | Architecture::Falcon
178 | Architecture::Phi
179 | Architecture::Phi2
180 | Architecture::Phi3
181 | Architecture::PhiMoe
182 | Architecture::GPTJ
183 | Architecture::StableLM
184 | Architecture::Gemma
185 | Architecture::Gemma2
186 | Architecture::Gemma3
187 | Architecture::Gemma3N
188 | Architecture::Gemma4
189 | Architecture::GemmaEmbedding => RopeType::NeoX,
190 _ => RopeType::Normal,
191 };
192
193 let num_experts = get_u32(&format!("{}.expert_count", arch)).unwrap_or(0) as usize;
195 let num_experts_per_token =
196 get_u32(&format!("{}.expert_used_count", arch)).unwrap_or(0) as usize;
197 let expert_intermediate_size =
198 get_u32(&format!("{}.expert_feed_forward_length", arch)).unwrap_or(0) as usize;
199
200 let key_length =
202 get_u32(&format!("{}.attention.key_length", arch)).unwrap_or(head_dim as u32) as usize;
203 let value_length = get_u32(&format!("{}.attention.value_length", arch))
204 .unwrap_or(head_dim as u32) as usize;
205
206 let rope_n_dims = get_u32(&format!("{}.rope.dimension_count", arch))
207 .unwrap_or(head_dim as u32) as usize;
208
209 let mrope_sections = if let Some(MetadataValue::Array(arr)) =
211 gguf.data.metadata.get(&format!("{}.rope.dimension_sections", arch))
212 {
213 let sections: Vec<usize> = arr.values.iter().filter_map(|v| match v {
214 MetadataValue::Int32(n) if *n > 0 => Some(*n as usize),
215 _ => None,
216 }).collect();
217 if sections.is_empty() { None } else { Some(sections) }
218 } else {
219 None
220 };
221
222 let rope_config = RopeConfig {
223 freq_base,
224 freq_scale,
225 n_dims: rope_n_dims,
226 scaling_type: RopeScalingType::None,
227 original_max_position_embeddings: max_seq_len,
228 rope_type,
229 mrope_sections,
230 };
231
232 let has_combined_qkv = architecture.has_combined_qkv();
234 let uses_layer_norm = architecture.uses_layer_norm();
235 let uses_gelu = architecture.uses_gelu();
236 let has_ffn_gate = !architecture.has_no_gate_ffn();
237
238 let attn_logit_softcap =
240 get_f32_or(&format!("{}.attn_logit_softcapping", arch), 0.0);
241 let final_logit_softcap =
242 get_f32_or(&format!("{}.final_logit_softcapping", arch), 0.0);
243 let sliding_window =
244 get_u32(&format!("{}.attention.sliding_window", arch)).unwrap_or(0) as usize;
245
246 let attention_bias = matches!(
248 architecture,
249 Architecture::Qwen
250 | Architecture::Qwen2
251 | Architecture::Qwen2Moe
252 | Architecture::Phi2
253 | Architecture::Phi3
254 | Architecture::PhiMoe
255 | Architecture::GPTNeoX
256 | Architecture::GPTJ
257 | Architecture::Falcon
258 | Architecture::BLOOM
259 | Architecture::MPT
260 | Architecture::OPT
261 | Architecture::GPT2
262 | Architecture::StableLM
263 | Architecture::Baichuan
264 );
265
266 let mlp_bias = matches!(
267 architecture,
268 Architecture::GPT2
269 | Architecture::GPTJ
270 | Architecture::GPTNeoX
271 | Architecture::BLOOM
272 | Architecture::OPT
273 | Architecture::StableLM
274 | Architecture::Phi2
275 | Architecture::Phi3
276 );
277
278 let use_parallel_residual = matches!(
281 architecture,
282 Architecture::GPTNeoX
283 | Architecture::GPTJ
284 | Architecture::StableLM
285 | Architecture::Phi
286 | Architecture::Phi2
287 | Architecture::CodeShell
288 );
289
290 let hidden_act = if architecture.uses_gelu() {
292 ActivationType::GELU
293 } else {
294 ActivationType::SiLU
295 };
296
297 let mut config = ModelConfig {
298 vocab_size,
299 hidden_size,
300 intermediate_size,
301 num_layers,
302 num_heads,
303 num_kv_heads,
304 head_dim,
305 max_seq_len,
306 norm_eps,
307 rope_config,
308 use_parallel_residual,
309 hidden_act,
310 attention_bias,
311 mlp_bias,
312 tie_word_embeddings: gguf
313 .data
314 .get_string("general.tie_word_embeddings")
315 .map(|s| s == "true")
316 .unwrap_or(false),
317 num_experts,
318 num_experts_per_token,
319 expert_intermediate_size,
320 key_length,
321 value_length,
322 ssm_d_inner: get_u32(&format!("{}.ssm.inner_size", arch)).unwrap_or(0) as usize,
323 ssm_d_state: get_u32(&format!("{}.ssm.state_size", arch)).unwrap_or(0) as usize,
324 ssm_n_group: {
325 let g = get_u32(&format!("{}.ssm.group_count", arch)).unwrap_or(0) as usize;
326 if g == 0 && matches!(architecture, Architecture::Mamba | Architecture::Mamba2) {
328 1
329 } else {
330 g
331 }
332 },
333 ssm_dt_rank: get_u32(&format!("{}.ssm.time_step_rank", arch)).unwrap_or(0) as usize,
334 ssm_conv_kernel: get_u32(&format!("{}.ssm.conv_kernel", arch)).unwrap_or(0) as usize,
335 attn_logit_softcap,
336 final_logit_softcap,
337 sliding_window,
338 has_combined_qkv,
339 uses_layer_norm,
340 uses_gelu,
341 has_ffn_gate,
342 attention_layer_configs: None,
343 kv_source_layer: None,
344 };
345
346 if architecture.has_heterogeneous_attention() {
352 let global_head_dim = config.head_dim; let global_kv_heads = config.num_kv_heads; let global_rope_freq_base = config.rope_config.freq_base; let global_rope_dims = if let Some(data) = gguf.tensor_data("rope_freqs.weight") {
365 let floats: &[f32] = bytemuck::cast_slice(data);
366 let active_pairs = floats.iter().filter(|&&v| v < 1e10).count();
367 active_pairs * 2
368 } else {
369 get_u32(&format!("{}.rope.dimension_count", arch))
370 .unwrap_or(global_head_dim as u32) as usize
371 };
372
373 let swa_head_dim =
375 get_u32(&format!("{}.attention.key_length_swa", arch))
376 .unwrap_or(global_head_dim as u32) as usize;
377 let swa_kv_heads =
378 get_u32(&format!("{}.attention.head_count_kv_swa", arch))
379 .unwrap_or(global_kv_heads as u32) as usize;
380 let swa_rope_freq_base =
381 get_f32_or(&format!("{}.rope.freq_base_swa", arch), global_rope_freq_base);
382 let swa_rope_dims =
383 get_u32(&format!("{}.rope.dimension_count_swa", arch))
384 .unwrap_or(swa_head_dim as u32) as usize;
385 let sliding_window = config.sliding_window;
386
387 let swa_pattern: Vec<bool> =
389 if let Some(MetadataValue::Array(arr)) =
390 gguf.data.metadata.get(&format!("{}.attention.sliding_window_pattern", arch))
391 {
392 arr.values
393 .iter()
394 .map(|v| matches!(v, MetadataValue::Bool(true)))
395 .collect()
396 } else {
397 (0..config.num_layers)
399 .map(|i| i % 6 != 5)
400 .collect()
401 };
402
403 let (swa_kv_heads, global_kv_heads) = match gguf
409 .data
410 .get_u32_array(&format!("{}.attention.head_count_kv", arch))
411 .filter(|v| v.len() == config.num_layers)
412 {
413 Some(per_layer) => {
414 let swa = swa_pattern
415 .iter()
416 .position(|&s| s)
417 .map(|i| per_layer[i] as usize)
418 .unwrap_or(swa_kv_heads);
419 let global = swa_pattern
420 .iter()
421 .position(|&s| !s)
422 .map(|i| per_layer[i] as usize)
423 .unwrap_or(global_kv_heads);
424 (swa, global)
425 }
426 None => (swa_kv_heads, global_kv_heads),
427 };
428 config.num_kv_heads = global_kv_heads;
430
431 config.attention_layer_configs =
432 Some(ModelConfig::build_attention_layer_configs_from_pattern(
433 &swa_pattern,
434 swa_head_dim,
435 swa_kv_heads,
436 swa_rope_freq_base,
437 swa_rope_dims,
438 sliding_window,
439 global_head_dim,
440 global_kv_heads,
441 global_rope_freq_base,
442 global_rope_dims,
443 ));
444
445 let shared_layers =
449 get_u32(&format!("{}.attention.shared_kv_layers", arch)).unwrap_or(0) as usize;
450 if shared_layers > 0 {
451 config.kv_source_layer = Some(ModelConfig::build_kv_source_mapping(
452 config.num_layers,
453 shared_layers,
454 config.attention_layer_configs.as_ref().unwrap(),
455 ));
456 }
457 }
458
459 Ok(config)
460 }
461
462 pub fn config(&self) -> &ModelConfig {
464 &self.config
465 }
466
467 pub fn config_mut(&mut self) -> &mut ModelConfig {
469 &mut self.config
470 }
471
472 pub fn architecture(&self) -> Architecture {
474 self.architecture
475 }
476
477 pub fn build_model(self) -> ModelResult<LlamaModel> {
483 build_llama_model(&self)
484 }
485
486 pub fn build_bert_model(self) -> ModelResult<BertModel> {
488 let token_embedding = self.load_tensor("token_embd.weight")?;
489
490 let position_embedding = self.try_load_tensor("position_embd.weight");
491 let token_type_embedding = self.try_load_tensor("token_types.weight");
492
493 let embed_norm = if let Some(w) = self.try_load_tensor("token_embd_norm.weight") {
495 if let Some(b) = self.try_load_tensor("token_embd_norm.bias") {
496 Some(NormLayer::Layer(LayerNorm::new(w, b, self.config.norm_eps)?))
497 } else {
498 Some(NormLayer::RMS(RMSNorm::new(w, self.config.norm_eps)?))
499 }
500 } else {
501 None
502 };
503
504 let mut layers = Vec::with_capacity(self.config.num_layers);
505 for i in 0..self.config.num_layers {
506 let prefix = format!("blk.{}", i);
507
508 let attn_norm_w = self
510 .try_load_tensor(&format!("{}.attn_output_norm.weight", prefix))
511 .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.weight", prefix)))
512 .ok_or_else(|| {
513 ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix))
514 })?;
515 let attn_norm_b = self
516 .try_load_tensor(&format!("{}.attn_output_norm.bias", prefix))
517 .or_else(|| self.try_load_tensor(&format!("{}.attn_norm.bias", prefix)));
518 let attn_norm = if let Some(b) = attn_norm_b {
519 NormLayer::Layer(LayerNorm::new(attn_norm_w, b, self.config.norm_eps)?)
520 } else {
521 NormLayer::RMS(RMSNorm::new(attn_norm_w, self.config.norm_eps)?)
522 };
523
524 let (wq, wk, wv) =
526 if let Some(qkv) = self.try_load_tensor(&format!("{}.attn_qkv.weight", prefix)) {
527 let num_heads = self.config.num_heads;
529 let head_dim = self.config.head_dim;
530 let hidden = self.config.hidden_size;
531 let q_size = num_heads * head_dim;
532 let k_size = num_heads * head_dim;
533 let v_size = num_heads * head_dim;
534
535 let qkv_f32 = if qkv.dtype() == DType::F32 {
536 qkv.as_f32()?.to_vec()
537 } else {
538 let backend = crate::backend::default_backend();
539 let mut deq = Tensor::zeros(vec![qkv.numel()], DType::F32);
540 backend
541 .dequantize(&qkv, &mut deq)
542 .map_err(|e| ModelError::ConfigError(format!("Dequant QKV: {}", e)))?;
543 deq.as_f32()?.to_vec()
544 };
545
546 let q_start = 0;
549 let k_start_off = q_size * hidden;
550 let v_start_off = (q_size + k_size) * hidden;
551
552 let qkv_bias = self.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
553 let (qb, kb, vb) = if let Some(ref b) = qkv_bias {
554 let bd = b.as_f32()?;
555 (
556 Some(Tensor::from_f32(&bd[..q_size], vec![q_size])?),
557 Some(Tensor::from_f32(
558 &bd[q_size..q_size + k_size],
559 vec![k_size],
560 )?),
561 Some(Tensor::from_f32(&bd[q_size + k_size..], vec![v_size])?),
562 )
563 } else {
564 (None, None, None)
565 };
566
567 (
568 Linear::new(
569 Tensor::from_f32(&qkv_f32[q_start..q_start + q_size * hidden], vec![hidden, q_size])?,
570 qb,
571 )?,
572 Linear::new(
573 Tensor::from_f32(&qkv_f32[k_start_off..k_start_off + k_size * hidden], vec![hidden, k_size])?,
574 kb,
575 )?,
576 Linear::new(
577 Tensor::from_f32(&qkv_f32[v_start_off..v_start_off + v_size * hidden], vec![hidden, v_size])?,
578 vb,
579 )?,
580 )
581 } else {
582 let qb = self.try_load_tensor(&format!("{}.attn_q.bias", prefix));
583 let kb = self.try_load_tensor(&format!("{}.attn_k.bias", prefix));
584 let vb = self.try_load_tensor(&format!("{}.attn_v.bias", prefix));
585 (
586 Linear::new(
587 self.load_tensor(&format!("{}.attn_q.weight", prefix))?,
588 qb,
589 )?,
590 Linear::new(
591 self.load_tensor(&format!("{}.attn_k.weight", prefix))?,
592 kb,
593 )?,
594 Linear::new(
595 self.load_tensor(&format!("{}.attn_v.weight", prefix))?,
596 vb,
597 )?,
598 )
599 };
600
601 let wo_bias = self.try_load_tensor(&format!("{}.attn_output.bias", prefix));
602 let wo = Linear::new(
603 self.load_tensor(&format!("{}.attn_output.weight", prefix))?,
604 wo_bias,
605 )?;
606
607 let ffn_norm_w = self
609 .try_load_tensor(&format!("{}.layer_output_norm.weight", prefix))
610 .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.weight", prefix)))
611 .ok_or_else(|| {
612 ModelError::MissingTensor(format!("{}.ffn_norm.weight", prefix))
613 })?;
614 let ffn_norm_b = self
615 .try_load_tensor(&format!("{}.layer_output_norm.bias", prefix))
616 .or_else(|| self.try_load_tensor(&format!("{}.ffn_norm.bias", prefix)));
617 let ffn_norm = if let Some(b) = ffn_norm_b {
618 NormLayer::Layer(LayerNorm::new(ffn_norm_w, b, self.config.norm_eps)?)
619 } else {
620 NormLayer::RMS(RMSNorm::new(ffn_norm_w, self.config.norm_eps)?)
621 };
622
623 let ffn_up_bias = self.try_load_tensor(&format!("{}.ffn_up.bias", prefix));
624 let ffn_up = Linear::new(
625 self.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
626 ffn_up_bias,
627 )?;
628 let ffn_down_bias = self.try_load_tensor(&format!("{}.ffn_down.bias", prefix));
629 let ffn_down = Linear::new(
630 self.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
631 ffn_down_bias,
632 )?;
633
634 layers.push(BertLayer {
635 attn_norm,
636 wq,
637 wk,
638 wv,
639 wo,
640 num_heads: self.config.num_heads,
641 head_dim: self.config.head_dim,
642 ffn_norm,
643 ffn_up,
644 ffn_down,
645 });
646 }
647
648 BertModel::new(
649 self.config,
650 token_embedding,
651 position_embedding,
652 token_type_embedding,
653 embed_norm,
654 layers,
655 self.architecture,
656 )
657 }
658
659 pub fn deltanet_config(&self) -> Option<DeltaNetConfig> {
662 deltanet_config_from_source(self)
663 }
664
665 pub fn recurrent_config(&self) -> Option<super::deltanet::RecurrentConfig> {
667 if !self.config.has_ssm() {
668 return None;
669 }
670 if matches!(self.architecture, Architecture::Mamba | Architecture::Mamba2) {
671 Some(super::deltanet::RecurrentConfig::Mamba(MambaConfig {
672 d_inner: self.config.ssm_d_inner,
673 d_state: self.config.ssm_d_state,
674 dt_rank: self.config.ssm_dt_rank,
675 conv_kernel: self.config.ssm_conv_kernel.max(1),
676 }))
677 } else if let Some(dn) = self.deltanet_config() {
678 Some(super::deltanet::RecurrentConfig::DeltaNet(dn))
679 } else {
680 None
681 }
682 }
683
684 fn gguf_try_load_tensor(&self, name: &str) -> Option<Tensor> {
687 let tensor_info = self.gguf.data.get_tensor(name)?;
688 let tensor_data = self.gguf.tensor_data(name)?;
689
690 let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
691 let dtype = DType::from(tensor_info.dtype);
692
693 Tensor::new(tensor_data.to_vec(), shape, dtype)
694 .ok()
695 .map(|mut t| {
696 t.set_name(name);
697 t
698 })
699 }
700
701 fn gguf_load_tensor(&self, name: &str) -> ModelResult<Tensor> {
704 let tensor_info = self
705 .gguf
706 .data
707 .get_tensor(name)
708 .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
709
710 let tensor_data = self
711 .gguf
712 .tensor_data(name)
713 .ok_or_else(|| ModelError::MissingTensor(name.into()))?;
714
715 let shape: Vec<usize> = tensor_info.dims.iter().map(|&d| d as usize).collect();
716 let dtype = DType::from(tensor_info.dtype);
717
718 let mut tensor = Tensor::new(tensor_data.to_vec(), shape, dtype)?;
722
723 tensor.set_name(name);
725
726 Ok(tensor)
727 }
728}
729
730impl ModelSource for ModelLoader {
731 fn config(&self) -> &ModelConfig {
732 &self.config
733 }
734
735 fn config_mut(&mut self) -> &mut ModelConfig {
736 &mut self.config
737 }
738
739 fn architecture(&self) -> Architecture {
740 self.architecture
741 }
742
743 fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
744 self.gguf_load_tensor(name)
745 }
746
747 fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
748 self.gguf_try_load_tensor(name)
749 }
750}
751
752pub fn build_llama_model(source: &dyn ModelSource) -> ModelResult<LlamaModel> {
765 let token_embedding = source.load_tensor("token_embd.weight")?;
767
768 let config = source.config();
770 let mut layers = Vec::with_capacity(config.num_layers);
771 for i in 0..config.num_layers {
772 let layer = load_transformer_layer(source, i)?;
773 layers.push(layer);
774 }
775
776 let recurrent_count = layers.iter().filter(|l| l.is_recurrent()).count();
778 if recurrent_count > 0 {
779 tracing::info!(
780 "Model has {}/{} DeltaNet recurrent layers",
781 recurrent_count,
782 layers.len()
783 );
784 }
785
786 let norm_weight =
788 apply_gemma_norm_weight_offset(source.load_tensor("output_norm.weight")?)?;
789 let norm = if let Some(bias) = source.try_load_tensor("output_norm.bias") {
790 NormLayer::Layer(LayerNorm::new(norm_weight, bias, config.norm_eps)?)
791 } else {
792 NormLayer::RMS(RMSNorm::new(norm_weight, config.norm_eps)?)
793 };
794
795 let output_bias = source.try_load_tensor("output.bias");
797 let output =
798 if config.tie_word_embeddings || source.try_load_tensor("output.weight").is_none() {
799 Linear::new(token_embedding.clone(), output_bias)?
800 } else {
801 let output_weight = source.load_tensor("output.weight")?;
802 Linear::new(output_weight, output_bias)?
803 };
804
805 let per_layer_token_embd = source.try_load_tensor("per_layer_token_embd.weight");
807 let per_layer_model_proj = source
808 .try_load_tensor("per_layer_model_proj.weight")
809 .map(|w| {
810 if w.dtype() != DType::F32 {
812 let backend = crate::backend::default_backend();
813 let mut deq = Tensor::zeros(vec![w.numel()], DType::F32);
814 backend
815 .dequantize(&w, &mut deq)
816 .map_err(|e| {
817 ModelError::ConfigError(format!(
818 "Failed to dequantize per_layer_model_proj: {e}"
819 ))
820 })?;
821 let shape = w.shape().to_vec();
822 let deq = deq.reshape(shape)?;
823 Linear::new(deq, None)
824 } else {
825 Linear::new(w, None)
826 }
827 })
828 .transpose()?;
829 let per_layer_proj_norm = source
830 .try_load_tensor("per_layer_proj_norm.weight")
831 .map(|w| RMSNorm::new(w, config.norm_eps))
832 .transpose()?;
833
834 let n_epl = per_layer_proj_norm
836 .as_ref()
837 .map(|n| n.hidden_size)
838 .unwrap_or(0);
839
840 if n_epl > 0 {
841 tracing::info!(
842 "Gemma 4 PLIE active: n_epl={}, n_layers={}, total_pl_dim={}",
843 n_epl,
844 config.num_layers,
845 n_epl * config.num_layers
846 );
847 }
848
849 LlamaModel::new(
850 config.clone(),
851 token_embedding,
852 layers,
853 norm,
854 output,
855 source.architecture(),
856 per_layer_token_embd,
857 per_layer_model_proj,
858 per_layer_proj_norm,
859 n_epl,
860 )
861}
862
863fn load_transformer_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<TransformerLayer> {
865 let prefix = format!("blk.{}", layer_idx);
866 let config = source.config();
867 let arch = source.architecture();
868 let is_mamba = matches!(arch, Architecture::Mamba | Architecture::Mamba2);
869
870 let attn_norm_weight = source
872 .try_load_tensor(&format!("{}.attn_norm.weight", prefix))
873 .or_else(|| source.try_load_tensor(&format!("{}.norm.weight", prefix)))
874 .ok_or_else(|| ModelError::MissingTensor(format!("{}.attn_norm.weight", prefix)))?;
875 let attn_norm_weight = apply_gemma_norm_weight_offset(attn_norm_weight)?;
876 let attn_norm_bias = source
877 .try_load_tensor(&format!("{}.attn_norm.bias", prefix))
878 .or_else(|| source.try_load_tensor(&format!("{}.norm.bias", prefix)));
879 let attn_norm = if let Some(bias) = attn_norm_bias {
880 NormLayer::Layer(LayerNorm::new(attn_norm_weight, bias, config.norm_eps)?)
881 } else {
882 NormLayer::RMS(RMSNorm::new(attn_norm_weight, config.norm_eps)?)
883 };
884
885 let attn_layer = load_attention_layer(source, layer_idx)?;
887
888 let ffn_norm_weight = source.try_load_tensor(&format!("{}.ffn_norm.weight", prefix));
901 let ffn_norm_bias = source.try_load_tensor(&format!("{}.ffn_norm.bias", prefix));
902 let post_attn_w =
903 source.try_load_tensor(&format!("{}.post_attention_norm.weight", prefix));
904 let post_attn_b = source.try_load_tensor(&format!("{}.post_attention_norm.bias", prefix));
905
906 let (ffn_norm, post_attn_norm) = if let Some(w) = ffn_norm_weight {
907 let w = apply_gemma_norm_weight_offset(w)?;
908 let ffn = if let Some(bias) = ffn_norm_bias {
909 NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
910 } else {
911 NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
912 };
913 let pan = post_attn_w
914 .map(|w| -> ModelResult<NormLayer> {
915 let w = apply_gemma_norm_weight_offset(w)?;
916 Ok(if let Some(bias) = post_attn_b {
917 NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
918 } else {
919 NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
920 })
921 })
922 .transpose()?;
923 (ffn, pan)
924 } else if let Some(w) = post_attn_w {
925 let w = apply_gemma_norm_weight_offset(w)?;
928 let ffn = if let Some(bias) = post_attn_b {
929 NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
930 } else {
931 NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
932 };
933 (ffn, None)
934 } else if is_mamba || config.use_parallel_residual {
935 let hidden = config.hidden_size;
939 let ffn = NormLayer::RMS(RMSNorm::new(
940 Tensor::from_f32(&vec![1.0f32; hidden], vec![hidden])?,
941 config.norm_eps,
942 )?);
943 (ffn, None)
944 } else {
945 return Err(ModelError::MissingTensor(format!(
946 "{}.ffn_norm.weight",
947 prefix
948 )));
949 };
950
951 let ffn_layer = if config.is_moe() {
953 load_moe_layer(source, layer_idx)?
954 } else if is_mamba
955 && source.try_load_tensor(&format!("{}.ffn_up.weight", prefix)).is_none()
956 {
957 FfnLayer::Identity
958 } else if !config.has_ffn_gate {
959 let up_tensor = source.load_tensor(&format!("{}.ffn_up.weight", prefix))?;
960 let up_out_dim = up_tensor.shape()[up_tensor.ndim() - 1];
961 let intermediate = config.intermediate_size;
962
963 if up_out_dim == 2 * intermediate {
964 let hidden = config.hidden_size;
966 let up_f32 = if up_tensor.dtype() == DType::F32 {
967 up_tensor.as_f32()?.to_vec()
968 } else {
969 let backend = crate::backend::default_backend();
970 let mut deq = Tensor::zeros(vec![up_tensor.numel()], DType::F32);
971 backend
972 .dequantize(&up_tensor, &mut deq)
973 .map_err(|e| ModelError::ConfigError(format!("Dequant ffn_up: {}", e)))?;
974 deq.as_f32()?.to_vec()
975 };
976
977 let gate_data = &up_f32[..hidden * intermediate];
978 let up_data = &up_f32[hidden * intermediate..];
979 let w_gate = Linear::new(
980 Tensor::from_f32(gate_data, vec![hidden, intermediate])?,
981 None,
982 )?;
983 let w_up = Linear::new(
984 Tensor::from_f32(up_data, vec![hidden, intermediate])?,
985 None,
986 )?;
987 let w_down = Linear::new(
988 source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
989 None,
990 )?;
991 let mut ffn = FeedForward::new(w_gate, w_up, w_down);
992 ffn.use_gelu = config.uses_gelu;
993 FfnLayer::Dense(ffn)
994 } else {
995 let w_up = Linear::new(
996 up_tensor,
997 source.try_load_tensor(&format!("{}.ffn_up.bias", prefix)),
998 )?;
999 let w_down = Linear::new(
1000 source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
1001 source.try_load_tensor(&format!("{}.ffn_down.bias", prefix)),
1002 )?;
1003 FfnLayer::NoGate(NoGateFeedForward::new(
1004 w_up,
1005 w_down,
1006 config.uses_gelu,
1007 ))
1008 }
1009 } else {
1010 let w_gate = Linear::new(
1011 source.load_tensor(&format!("{}.ffn_gate.weight", prefix))?,
1012 None,
1013 )?;
1014 let w_up = Linear::new(
1015 source.load_tensor(&format!("{}.ffn_up.weight", prefix))?,
1016 None,
1017 )?;
1018 let w_down = Linear::new(
1019 source.load_tensor(&format!("{}.ffn_down.weight", prefix))?,
1020 None,
1021 )?;
1022 let mut ffn = FeedForward::new(w_gate, w_up, w_down);
1023 ffn.use_gelu = config.uses_gelu;
1024 FfnLayer::Dense(ffn)
1025 };
1026
1027 let post_ffn_norm =
1029 if let Some(w) = source.try_load_tensor(&format!("{}.post_ffw_norm.weight", prefix)) {
1030 let w = apply_gemma_norm_weight_offset(w)?;
1031 let b = source.try_load_tensor(&format!("{}.post_ffw_norm.bias", prefix));
1032 Some(if let Some(bias) = b {
1033 NormLayer::Layer(LayerNorm::new(w, bias, config.norm_eps)?)
1034 } else {
1035 NormLayer::RMS(RMSNorm::new(w, config.norm_eps)?)
1036 })
1037 } else {
1038 None
1039 };
1040
1041 let rope_freq_base_override = config
1042 .attention_layer_configs
1043 .as_ref()
1044 .map(|cfgs| cfgs[layer_idx].rope_freq_base)
1045 .unwrap_or(0.0);
1046
1047 let plie_inp_gate = source
1049 .try_load_tensor(&format!("{}.inp_gate.weight", prefix))
1050 .map(|w| Linear::new(w, None))
1051 .transpose()?;
1052 let plie_proj = source
1053 .try_load_tensor(&format!("{}.proj.weight", prefix))
1054 .map(|w| Linear::new(w, None))
1055 .transpose()?;
1056 let plie_post_norm = source
1057 .try_load_tensor(&format!("{}.post_norm.weight", prefix))
1058 .map(|w| RMSNorm::new(w, config.norm_eps))
1059 .transpose()?;
1060 let layer_output_scale = source
1061 .try_load_tensor(&format!("{}.layer_output_scale.weight", prefix))
1062 .and_then(|t| {
1063 if t.dtype() == crate::tensor::DType::F32 {
1064 t.as_f32().ok().map(|d| d[0])
1065 } else {
1066 let raw = t.data();
1068 if raw.len() >= 2 {
1069 let bits = u16::from_le_bytes([raw[0], raw[1]]);
1070 Some(f32::from_bits((bits as u32) << 16))
1071 } else {
1072 None
1073 }
1074 }
1075 });
1076
1077 Ok(TransformerLayer {
1078 attn_norm,
1079 attn_layer,
1080 post_attn_norm,
1081 ffn_norm,
1082 ffn_layer,
1083 post_ffn_norm,
1084 layer_idx,
1085 use_parallel_residual: config.use_parallel_residual,
1086 rope_freq_base_override,
1087 plie_inp_gate,
1088 plie_proj,
1089 plie_post_norm,
1090 layer_output_scale,
1091 })
1092}
1093
1094fn load_attention_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<AttentionLayer> {
1096 let prefix = format!("blk.{}", layer_idx);
1097 let config = source.config();
1098
1099 if let Some(wq_weight) = source.try_load_tensor(&format!("{}.attn_q.weight", prefix)) {
1100 let attn = load_full_attention(source, layer_idx, wq_weight)?;
1102 Ok(AttentionLayer::FullAttention(attn))
1103 } else if let Some(qkv_weight) =
1104 source.try_load_tensor(&format!("{}.attn_qkv.weight", prefix))
1105 {
1106 if config.has_ssm() {
1107 let dn = load_deltanet_layer(source, layer_idx)?;
1109 Ok(AttentionLayer::DeltaNet(Box::new(dn)))
1110 } else {
1111 let attn = load_combined_qkv_attention(source, layer_idx, qkv_weight)?;
1113 Ok(AttentionLayer::FullAttention(attn))
1114 }
1115 } else if config.has_ssm()
1116 && source.try_load_tensor(&format!("{}.ssm_in.weight", prefix)).is_some()
1117 {
1118 let mamba = load_mamba_layer(source, layer_idx)?;
1120 Ok(AttentionLayer::Mamba(Box::new(mamba)))
1121 } else {
1122 Err(ModelError::MissingTensor(format!(
1123 "{}.attn_q.weight or {}.attn_qkv.weight or {}.ssm_in.weight",
1124 prefix, prefix, prefix
1125 )))
1126 }
1127}
1128
1129fn load_full_attention(
1131 source: &dyn ModelSource,
1132 layer_idx: usize,
1133 wq_weight: Tensor,
1134) -> ModelResult<Attention> {
1135 let prefix = format!("blk.{}", layer_idx);
1136 let config = source.config();
1137 let arch = source.architecture();
1138 let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
1139
1140 let (num_kv_heads, head_dim, kl, vl, rope_dims) =
1142 if let Some(ref layer_configs) = config.attention_layer_configs {
1143 let lc = &layer_configs[layer_idx];
1144 (lc.num_kv_heads, lc.head_dim, lc.head_dim, lc.head_dim, lc.rope_dims)
1145 } else {
1146 let kl = config.key_length;
1147 let vl = config.value_length;
1148 let rope_dims = config.rope_config.n_dims;
1149 (config.num_kv_heads, config.head_dim, kl, vl, rope_dims)
1150 };
1151
1152 let wq_bias = source.try_load_tensor(&format!("{}.attn_q.bias", prefix));
1153 let actual_q_out = wq_weight.shape()[1];
1154 let has_attention_gate = actual_q_out == config.num_heads * (kl + vl);
1155
1156 let wq = Linear::new(wq_weight, wq_bias)?;
1157
1158 let wk_bias = source.try_load_tensor(&format!("{}.attn_k.bias", prefix));
1159 let wk = Linear::new(
1160 source.load_tensor(&format!("{}.attn_k.weight", prefix))?,
1161 wk_bias,
1162 )?;
1163 let wv_bias = source.try_load_tensor(&format!("{}.attn_v.bias", prefix));
1169 let wv_weight = match source.try_load_tensor(&format!("{}.attn_v.weight", prefix)) {
1170 Some(w) => w,
1171 None => source.load_tensor(&format!("{}.attn_k.weight", prefix))?,
1172 };
1173 let wv = Linear::new(wv_weight, wv_bias)?;
1174 let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1175 let wo = Linear::new(
1176 source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1177 wo_bias,
1178 )?;
1179
1180 let mut attention = Attention::with_kv_dims(
1181 wq, wk, wv, wo,
1182 config.num_heads,
1183 num_kv_heads,
1184 head_dim,
1185 kl, vl, rope_dims,
1186 use_neox_rope,
1187 has_attention_gate,
1188 );
1189
1190 if arch.uses_qk_norm()
1191 && let (Some(q_norm_w), Some(k_norm_w)) = (
1192 source.try_load_tensor(&format!("{}.attn_q_norm.weight", prefix)),
1193 source.try_load_tensor(&format!("{}.attn_k_norm.weight", prefix)),
1194 )
1195 {
1196 let q_norm = RMSNorm::new(q_norm_w, config.norm_eps)?;
1197 let k_norm = RMSNorm::new(k_norm_w, config.norm_eps)?;
1198 attention.set_qk_norms(q_norm, k_norm);
1199 }
1200
1201 if config.attn_logit_softcap > 0.0 {
1202 attention.set_attn_logit_softcap(config.attn_logit_softcap);
1203 }
1204
1205 if matches!(arch, Architecture::Qwen3Next | Architecture::Qwen35Moe) {
1207 attention.set_rope_partial_at_end(true);
1208 }
1209
1210 if let Some(ref sections) = config.rope_config.mrope_sections {
1212 attention.mrope_sections = Some(sections.clone());
1213 }
1214
1215 if let Some(ref layer_configs) = config.attention_layer_configs {
1217 let lc = &layer_configs[layer_idx];
1218 if lc.sliding_window > 0 {
1219 attention.set_sliding_window(lc.sliding_window);
1220 }
1221 if lc.rope_dims < lc.head_dim {
1222 attention.set_rope_freq_dim(lc.head_dim);
1223 }
1224 attention.normalize_v = true;
1226 attention.scale = 1.0;
1234 }
1235
1236 Ok(attention)
1237}
1238
1239fn load_combined_qkv_attention(
1241 source: &dyn ModelSource,
1242 layer_idx: usize,
1243 qkv_weight: Tensor,
1244) -> ModelResult<Attention> {
1245 let prefix = format!("blk.{}", layer_idx);
1246 let config = source.config();
1247 let use_neox_rope = matches!(config.rope_config.rope_type, RopeType::NeoX);
1248 let kl = config.key_length;
1249 let vl = config.value_length;
1250 let rope_dims = config.rope_config.n_dims;
1251 let num_heads = config.num_heads;
1252 let num_kv_heads = config.num_kv_heads;
1253 let head_dim = config.head_dim;
1254
1255 let qkv_shape = qkv_weight.shape();
1257 let in_features = qkv_shape[0];
1258 let q_size = num_heads * head_dim;
1259 let k_size = num_kv_heads * head_dim;
1260 let v_size = num_kv_heads * head_dim;
1261
1262 let qkv_bias = source.try_load_tensor(&format!("{}.attn_qkv.bias", prefix));
1264
1265 if qkv_weight.dtype() == DType::F32 {
1266 let qkv_f32 = qkv_weight.as_f32()?;
1267
1268 let q_start = 0;
1272 let k_start = q_size * in_features;
1273 let v_start = (q_size + k_size) * in_features;
1274
1275 let q_tensor = Tensor::from_f32(
1276 &qkv_f32[q_start..q_start + q_size * in_features],
1277 vec![in_features, q_size],
1278 )?;
1279 let k_tensor = Tensor::from_f32(
1280 &qkv_f32[k_start..k_start + k_size * in_features],
1281 vec![in_features, k_size],
1282 )?;
1283 let v_tensor = Tensor::from_f32(
1284 &qkv_f32[v_start..v_start + v_size * in_features],
1285 vec![in_features, v_size],
1286 )?;
1287
1288 let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
1290 let b = bias.as_f32()?;
1291 let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
1292 let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
1293 let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
1294 (Some(qb), Some(kb), Some(vb))
1295 } else {
1296 (None, None, None)
1297 };
1298
1299 let wq = Linear::new(q_tensor, q_bias)?;
1300 let wk = Linear::new(k_tensor, k_bias)?;
1301 let wv = Linear::new(v_tensor, v_bias)?;
1302
1303 let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1304 let wo = Linear::new(
1305 source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1306 wo_bias,
1307 )?;
1308
1309 Ok(Attention::with_kv_dims(
1310 wq, wk, wv, wo,
1311 num_heads, num_kv_heads, head_dim,
1312 kl, vl, rope_dims,
1313 use_neox_rope, false,
1314 ))
1315 } else {
1316 let backend = crate::backend::default_backend();
1319 let numel = qkv_weight.numel();
1320 let mut dequant = Tensor::zeros(vec![numel], DType::F32);
1321 backend
1322 .dequantize(&qkv_weight, &mut dequant)
1323 .map_err(|e| ModelError::ConfigError(format!("Failed to dequantize QKV: {}", e)))?;
1324 let qkv_f32 = dequant.as_f32()?;
1325
1326 let q_start = 0;
1328 let k_start = q_size * in_features;
1329 let v_start = (q_size + k_size) * in_features;
1330
1331 let q_tensor = Tensor::from_f32(
1332 &qkv_f32[q_start..q_start + q_size * in_features],
1333 vec![in_features, q_size],
1334 )?;
1335 let k_tensor = Tensor::from_f32(
1336 &qkv_f32[k_start..k_start + k_size * in_features],
1337 vec![in_features, k_size],
1338 )?;
1339 let v_tensor = Tensor::from_f32(
1340 &qkv_f32[v_start..v_start + v_size * in_features],
1341 vec![in_features, v_size],
1342 )?;
1343
1344 let (q_bias, k_bias, v_bias) = if let Some(ref bias) = qkv_bias {
1345 let b = bias.as_f32()?;
1346 let qb = Tensor::from_f32(&b[..q_size], vec![q_size])?;
1347 let kb = Tensor::from_f32(&b[q_size..q_size + k_size], vec![k_size])?;
1348 let vb = Tensor::from_f32(&b[q_size + k_size..], vec![v_size])?;
1349 (Some(qb), Some(kb), Some(vb))
1350 } else {
1351 (None, None, None)
1352 };
1353
1354 let wq = Linear::new(q_tensor, q_bias)?;
1355 let wk = Linear::new(k_tensor, k_bias)?;
1356 let wv = Linear::new(v_tensor, v_bias)?;
1357
1358 let wo_bias = source.try_load_tensor(&format!("{}.attn_output.bias", prefix));
1359 let wo = Linear::new(
1360 source.load_tensor(&format!("{}.attn_output.weight", prefix))?,
1361 wo_bias,
1362 )?;
1363
1364 Ok(Attention::with_kv_dims(
1365 wq, wk, wv, wo,
1366 num_heads, num_kv_heads, head_dim,
1367 kl, vl, rope_dims,
1368 use_neox_rope, false,
1369 ))
1370 }
1371}
1372
1373fn load_deltanet_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<DeltaNetLayer> {
1375 let prefix = format!("blk.{}", layer_idx);
1376 let cfg = source.config();
1377
1378 let d_inner = cfg.ssm_d_inner;
1379 let d_state = cfg.ssm_d_state;
1380 let num_v_heads = cfg.ssm_dt_rank;
1381 let num_k_heads = cfg.ssm_n_group;
1382 let head_v_dim = d_inner / num_v_heads;
1383 let head_k_dim = d_state;
1384 let conv_kernel = cfg.ssm_conv_kernel;
1385 let q_dim = num_k_heads * head_k_dim;
1386 let k_dim = num_k_heads * head_k_dim;
1387 let qkv_dim = q_dim + k_dim + d_inner;
1388
1389 let dn_config = DeltaNetConfig {
1390 d_inner,
1391 d_state,
1392 num_v_heads,
1393 num_k_heads,
1394 head_v_dim,
1395 head_k_dim,
1396 conv_kernel,
1397 qkv_dim,
1398 };
1399
1400 let attn_qkv = Linear::new(
1401 source.load_tensor(&format!("{}.attn_qkv.weight", prefix))?,
1402 None,
1403 )?;
1404
1405 let attn_gate = Linear::new(
1406 source.load_tensor(&format!("{}.attn_gate.weight", prefix))?,
1407 None,
1408 )?;
1409
1410 let ssm_ba = if let Some(ba_weight) =
1411 source.try_load_tensor(&format!("{}.ssm_ba.weight", prefix))
1412 {
1413 BetaAlphaProjection::Combined(Linear::new(ba_weight, None)?)
1414 } else {
1415 let beta_w = source.load_tensor(&format!("{}.ssm_beta.weight", prefix))?;
1416 let alpha_w = source.load_tensor(&format!("{}.ssm_alpha.weight", prefix))?;
1417 BetaAlphaProjection::Separate {
1418 beta: Linear::new(beta_w, None)?,
1419 alpha: Linear::new(alpha_w, None)?,
1420 }
1421 };
1422
1423 let ssm_conv1d_weight = source.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
1424 let ssm_a = source.load_tensor(&format!("{}.ssm_a", prefix))?;
1425 let ssm_dt_bias = source.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
1426
1427 let ssm_norm_weight = source.load_tensor(&format!("{}.ssm_norm.weight", prefix))?;
1428 let ssm_norm = RMSNorm::new(ssm_norm_weight, cfg.norm_eps)?;
1429
1430 let ssm_out = Linear::new(
1431 source.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
1432 None,
1433 )?;
1434
1435 tracing::info!("Layer {}: loaded DeltaNet (d_inner={}, d_state={}, v_heads={}, k_heads={}, conv={})",
1436 layer_idx, d_inner, d_state, num_v_heads, num_k_heads, conv_kernel);
1437
1438 Ok(DeltaNetLayer {
1439 config: dn_config,
1440 attn_qkv,
1441 attn_gate,
1442 ssm_ba,
1443 ssm_conv1d_weight,
1444 ssm_a,
1445 ssm_dt_bias,
1446 ssm_norm,
1447 ssm_out,
1448 })
1449}
1450
1451fn load_mamba_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<MambaLayer> {
1455 let prefix = format!("blk.{}", layer_idx);
1456 let cfg = source.config();
1457
1458 let d_inner = cfg.ssm_d_inner;
1459 let d_state = cfg.ssm_d_state;
1460 let dt_rank = cfg.ssm_dt_rank;
1461 let conv_kernel = cfg.ssm_conv_kernel.max(1);
1462
1463 let mamba_config = MambaConfig {
1464 d_inner,
1465 d_state,
1466 dt_rank,
1467 conv_kernel,
1468 };
1469
1470 let ssm_in = Linear::new(
1471 source.load_tensor(&format!("{}.ssm_in.weight", prefix))?,
1472 None,
1473 )?;
1474
1475 let ssm_conv1d_weight = source.load_tensor(&format!("{}.ssm_conv1d.weight", prefix))?;
1476 let ssm_conv1d_bias = source.try_load_tensor(&format!("{}.ssm_conv1d.bias", prefix));
1477
1478 let ssm_x = Linear::new(
1479 source.load_tensor(&format!("{}.ssm_x.weight", prefix))?,
1480 None,
1481 )?;
1482
1483 let ssm_dt = Linear::new(
1484 source.load_tensor(&format!("{}.ssm_dt.weight", prefix))?,
1485 None,
1486 )?;
1487
1488 let ssm_dt_bias = source.load_tensor(&format!("{}.ssm_dt.bias", prefix))?;
1489 let ssm_a = source.load_tensor(&format!("{}.ssm_a", prefix))?;
1490 let ssm_d = source.try_load_tensor(&format!("{}.ssm_d", prefix));
1491
1492 let ssm_norm = match source.try_load_tensor(&format!("{}.ssm_norm.weight", prefix)) {
1493 Some(w) => Some(RMSNorm::new(w, cfg.norm_eps)?),
1494 None => None,
1495 };
1496
1497 let ssm_out = Linear::new(
1498 source.load_tensor(&format!("{}.ssm_out.weight", prefix))?,
1499 None,
1500 )?;
1501
1502 tracing::info!(
1503 "Layer {}: loaded Mamba SSM (d_inner={}, d_state={}, dt_rank={}, conv={})",
1504 layer_idx, d_inner, d_state, dt_rank, conv_kernel
1505 );
1506
1507 Ok(MambaLayer {
1508 ssm_in,
1509 ssm_conv1d_weight,
1510 ssm_conv1d_bias,
1511 ssm_x,
1512 ssm_dt,
1513 ssm_dt_bias,
1514 ssm_a,
1515 ssm_d,
1516 ssm_norm,
1517 ssm_out,
1518 config: mamba_config,
1519 })
1520}
1521
1522fn load_moe_layer(source: &dyn ModelSource, layer_idx: usize) -> ModelResult<FfnLayer> {
1524 let prefix = format!("blk.{}", layer_idx);
1525 let config = source.config();
1526 let num_experts = config.num_experts;
1527 let hidden_dim = config.hidden_size;
1528
1529 let expert_ffn_dim = if config.expert_intermediate_size > 0 {
1532 config.expert_intermediate_size
1533 } else {
1534 config.intermediate_size / config.num_experts_per_token
1535 };
1536
1537 let router_weight = source.load_tensor(&format!("{}.ffn_gate_inp.weight", prefix))?;
1539 let router = MoeRouter::from_weight(
1540 router_weight,
1541 config.num_experts_per_token,
1542 false, );
1544
1545 let gate_exps = source.load_tensor(&format!("{}.ffn_gate_exps.weight", prefix))?;
1548 let up_exps = source.load_tensor(&format!("{}.ffn_up_exps.weight", prefix))?;
1549 let down_exps = source.load_tensor(&format!("{}.ffn_down_exps.weight", prefix))?;
1550
1551 let mut experts = Vec::with_capacity(num_experts);
1552 for e in 0..num_experts {
1553 let mut gate_proj = extract_expert_tensor(&gate_exps, e)?;
1554 let mut up_proj = extract_expert_tensor(&up_exps, e)?;
1555 let mut down_proj = extract_expert_tensor(&down_exps, e)?;
1556
1557 gate_proj.set_name(format!("blk.{}.ffn_gate.{}.weight", layer_idx, e));
1558 up_proj.set_name(format!("blk.{}.ffn_up.{}.weight", layer_idx, e));
1559 down_proj.set_name(format!("blk.{}.ffn_down.{}.weight", layer_idx, e));
1560
1561 experts.push(MoeExpert {
1562 gate_proj,
1563 up_proj,
1564 down_proj,
1565 use_gelu: config.uses_gelu,
1566 });
1567 }
1568
1569 let mut shared_experts = Vec::new();
1571 if let (Some(mut gate_shexp), Some(mut up_shexp), Some(mut down_shexp)) = (
1572 source.try_load_tensor(&format!("{}.ffn_gate_shexp.weight", prefix)),
1573 source.try_load_tensor(&format!("{}.ffn_up_shexp.weight", prefix)),
1574 source.try_load_tensor(&format!("{}.ffn_down_shexp.weight", prefix)),
1575 ) {
1576 gate_shexp.set_name(format!("blk.{}.ffn_gate_shexp.0.weight", layer_idx));
1577 up_shexp.set_name(format!("blk.{}.ffn_up_shexp.0.weight", layer_idx));
1578 down_shexp.set_name(format!("blk.{}.ffn_down_shexp.0.weight", layer_idx));
1579 shared_experts.push(MoeExpert {
1580 gate_proj: gate_shexp,
1581 up_proj: up_shexp,
1582 down_proj: down_shexp,
1583 use_gelu: config.uses_gelu,
1584 });
1585 }
1586
1587 let shared_expert_gate =
1590 source.try_load_tensor(&format!("{}.ffn_gate_inp_shexp.weight", prefix))
1591 .map(|t| {
1592 if t.dtype() == DType::F32 {
1593 t
1594 } else {
1595 let raw = t.data();
1596 let f32_vals: Vec<f32> = match t.dtype() {
1597 DType::BF16 => {
1598 raw.chunks_exact(2)
1599 .map(|c| {
1600 let bits = u16::from_le_bytes([c[0], c[1]]);
1601 f32::from_bits((bits as u32) << 16)
1602 })
1603 .collect()
1604 }
1605 _ => {
1606 tracing::warn!("Unsupported dtype for shared expert gate, zeroing");
1607 vec![0.0f32; t.numel()]
1608 }
1609 };
1610 let shape = t.shape().to_vec();
1611 Tensor::from_f32(&f32_vals, shape).unwrap()
1612 }
1613 });
1614 if shared_expert_gate.is_some() {
1615 tracing::debug!("Layer {}: loaded shared expert gate", layer_idx);
1616 }
1617
1618 let num_shared = shared_experts.len();
1619 let moe_config = MoeConfig {
1620 num_experts,
1621 num_experts_per_token: config.num_experts_per_token,
1622 expert_hidden_dim: expert_ffn_dim,
1623 num_shared_experts: num_shared,
1624 aux_loss_coef: 0.0,
1625 normalize_router_logits: false,
1626 };
1627
1628 let mut moe_layer = MoeLayer::new(hidden_dim, moe_config);
1629 moe_layer.router = router;
1630 moe_layer.experts = experts;
1631 moe_layer.shared_experts = shared_experts;
1632 moe_layer.shared_expert_gate = shared_expert_gate;
1633
1634 Ok(FfnLayer::Moe(moe_layer))
1635}
1636
1637fn extract_expert_tensor(
1643 batched: &Tensor,
1644 expert_idx: usize,
1645) -> ModelResult<Tensor> {
1646 let shape = batched.shape();
1647 if shape.len() != 3 {
1648 return Err(ModelError::ConfigError(format!(
1649 "Expected 3D batched expert tensor, got shape {:?}",
1650 shape
1651 )));
1652 }
1653 let ne0 = shape[0];
1654 let ne1 = shape[1];
1655 let num_experts = shape[2];
1656 let expert_numel = ne0 * ne1;
1657
1658 if expert_idx >= num_experts {
1659 return Err(ModelError::ConfigError(format!(
1660 "Expert index {} out of bounds ({})",
1661 expert_idx, num_experts
1662 )));
1663 }
1664
1665 let per_expert_shape = vec![ne0, ne1];
1666
1667 if batched.dtype().is_quantized() {
1668 let block_size = batched.dtype().block_size();
1669 let block_bytes = batched.dtype().block_bytes();
1670
1671 if !expert_numel.is_multiple_of(block_size) {
1672 return Err(ModelError::ConfigError(format!(
1673 "Expert tensor elements ({}) not aligned to block size ({})",
1674 expert_numel, block_size
1675 )));
1676 }
1677
1678 let blocks_per_expert = expert_numel / block_size;
1679 let bytes_per_expert = blocks_per_expert * block_bytes;
1680 let byte_offset = expert_idx * bytes_per_expert;
1681
1682 let raw_data = batched.data();
1683 let expert_bytes = &raw_data[byte_offset..byte_offset + bytes_per_expert];
1684
1685 let mut tensor =
1686 Tensor::new(expert_bytes.to_vec(), per_expert_shape, batched.dtype())?;
1687 tensor.set_name(format!("expert.{}", expert_idx));
1688 Ok(tensor)
1689 } else {
1690 let f32_data = batched.as_f32()?;
1691 let offset = expert_idx * expert_numel;
1692 let expert_slice = &f32_data[offset..offset + expert_numel];
1693
1694 let mut tensor = Tensor::from_f32(expert_slice, per_expert_shape)?;
1695 tensor.set_name(format!("expert.{}", expert_idx));
1696 Ok(tensor)
1697 }
1698}
1699
/// Returns `weight` unchanged, wrapped in `Ok`.
///
/// NOTE(review): the name suggests an offset was meant to be applied to Gemma norm weights
/// here. The current body is a pure pass-through — presumably the stored weights already
/// include any such offset; confirm before relying on that.
fn apply_gemma_norm_weight_offset(weight: Tensor) -> ModelResult<Tensor> {
    // Identity: no adjustment is made to the weight values.
    Ok(weight)
}
1708
1709pub fn deltanet_config_from_source(source: &dyn ModelSource) -> Option<DeltaNetConfig> {
1712 let config = source.config();
1713 let arch = source.architecture();
1714 if !config.has_ssm()
1715 || matches!(arch, Architecture::Mamba | Architecture::Mamba2)
1716 {
1717 return None;
1718 }
1719 let d_inner = config.ssm_d_inner;
1720 let d_state = config.ssm_d_state;
1721 let num_v_heads = config.ssm_dt_rank;
1722 let num_k_heads = config.ssm_n_group.max(1);
1723 let head_v_dim = d_inner / num_v_heads.max(1);
1724 let head_k_dim = d_state;
1725 let conv_kernel = config.ssm_conv_kernel;
1726 let q_dim = num_k_heads * head_k_dim;
1727 let k_dim = num_k_heads * head_k_dim;
1728 let qkv_dim = q_dim + k_dim + d_inner;
1729
1730 Some(DeltaNetConfig {
1731 d_inner,
1732 d_state,
1733 num_v_heads,
1734 num_k_heads,
1735 head_v_dim,
1736 head_k_dim,
1737 conv_kernel,
1738 qkv_dim,
1739 })
1740}
1741
1742pub fn load_llama_model<P: AsRef<Path>>(path: P) -> ModelResult<LlamaModel> {
1746 let loader = ModelLoader::load(path)?;
1747
1748 if !loader.architecture().is_llama_like() {
1749 return Err(ModelError::UnsupportedArchitecture(
1750 loader.architecture().to_string(),
1751 ));
1752 }
1753
1754 loader.build_model()
1755}
1756
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::config::{AttentionLayerConfig, AttentionLayerType};
    use std::collections::HashMap;

    /// In-memory [`ModelSource`] that serves tensors from a name-keyed map.
    struct MockSource {
        tensors: HashMap<String, Tensor>,
        config: ModelConfig,
        arch: Architecture,
    }

    impl ModelSource for MockSource {
        fn config(&self) -> &ModelConfig {
            &self.config
        }

        fn config_mut(&mut self) -> &mut ModelConfig {
            &mut self.config
        }

        fn architecture(&self) -> Architecture {
            self.arch
        }

        fn load_tensor(&self, name: &str) -> ModelResult<Tensor> {
            match self.tensors.get(name) {
                Some(tensor) => Ok(tensor.clone()),
                None => Err(ModelError::MissingTensor(name.to_string())),
            }
        }

        fn try_load_tensor(&self, name: &str) -> Option<Tensor> {
            self.tensors.get(name).cloned()
        }
    }

    /// A layer without `attn_v.weight` must still load, with V taking the K weight's shape.
    #[test]
    fn test_load_full_attention_aliases_v_to_k_when_v_absent() {
        const HIDDEN: usize = 3840;
        const NUM_HEADS: usize = 16;
        const KV_HEADS: usize = 1;
        const HEAD_DIM: usize = 512;
        const KEY_LEN: usize = HEAD_DIM;
        const Q_OUT: usize = NUM_HEADS * KEY_LEN;

        let zeros = |rows: usize, cols: usize| Tensor::zeros(vec![rows, cols], DType::F32);

        // Only K and the output projection exist for layer 1; `attn_v.weight` is left out
        // on purpose.
        let tensors = HashMap::from([
            (
                "blk.1.attn_k.weight".to_string(),
                zeros(HIDDEN, KV_HEADS * KEY_LEN),
            ),
            (
                "blk.1.attn_output.weight".to_string(),
                zeros(NUM_HEADS * HEAD_DIM, HIDDEN),
            ),
        ]);

        let sliding_layer = AttentionLayerConfig {
            layer_type: AttentionLayerType::Sliding,
            head_dim: 256,
            num_kv_heads: 8,
            rope_freq_base: 10_000.0,
            rope_dims: 256,
            sliding_window: 1024,
        };
        // Layer 1 is the one exercised below.
        let global_layer = AttentionLayerConfig {
            layer_type: AttentionLayerType::Global,
            head_dim: 512,
            num_kv_heads: 1,
            rope_freq_base: 1_000_000.0,
            rope_dims: 128,
            sliding_window: 0,
        };

        let mut config = ModelConfig::default();
        config.num_heads = NUM_HEADS;
        config.hidden_size = HIDDEN;
        config.attn_logit_softcap = 0.0;
        config.attention_layer_configs = Some(vec![sliding_layer, global_layer]);

        let source = MockSource {
            tensors,
            config,
            arch: Architecture::Gemma4,
        };

        let attention = load_full_attention(&source, 1, zeros(HIDDEN, Q_OUT))
            .expect("global layer with tied K/V should load without attn_v");

        assert_eq!(
            attention.wv.out_features, attention.wk.out_features,
            "V projection must alias K projection when attn_v is absent"
        );
        assert_eq!(attention.wv.out_features, KV_HEADS * KEY_LEN);
    }

    /// Checks which architectures report themselves as LLaMA-like.
    #[test]
    fn test_architecture_detection() {
        for arch in [Architecture::Llama, Architecture::Mistral, Architecture::GPT2] {
            assert!(arch.is_llama_like());
        }
        for arch in [Architecture::Bert, Architecture::Mamba] {
            assert!(!arch.is_llama_like());
        }
    }
}