1use crate::config::Qwen35Config;
43use anyhow::{Context, Result, anyhow};
44use rlx_core::weight_loader::{GgufLoader, WeightLoader};
45use rlx_ir::quant::QuantScheme;
46
/// A projection matrix as returned by this loader.
///
/// Either fully materialised as dense `f32` values, or left in its quantized
/// GGUF encoding and referenced by tensor name (packed loading only).
#[derive(Debug, Clone)]
pub enum MatWeight {
    /// Dense values. Carries no shape of its own (`MatWeight::shape` returns an
    /// empty slice for this variant).
    F32(Vec<f32>),
    /// A quantized tensor that was not dequantized at load time.
    Packed {
        /// GGUF tensor name this weight was read from (e.g. `blk.0.attn_qkv.weight`).
        key: String,
        /// Quantization scheme of the stored tensor.
        scheme: QuantScheme,
        /// Logical shape. Expert tensors use `[n_expert, ne0, ne1]`; other
        /// tensors use whatever shape `GgufLoader::take_packed_metadata`
        /// reports, or reversed GGML order when built by the token-embedding
        /// peek. NOTE(review): confirm both 2-D conventions agree.
        shape: Vec<usize>,
    },
}
75
76impl MatWeight {
77 pub fn len(&self) -> usize {
78 match self {
79 MatWeight::F32(v) => v.len(),
80 MatWeight::Packed { shape, .. } => shape.iter().product(),
81 }
82 }
83 pub fn is_empty(&self) -> bool {
84 self.len() == 0
85 }
86 pub fn shape(&self) -> &[usize] {
89 match self {
90 MatWeight::F32(_) => &[],
91 MatWeight::Packed { shape, .. } => shape,
92 }
93 }
94 pub fn is_packed(&self) -> bool {
95 matches!(self, MatWeight::Packed { .. })
96 }
97 pub fn packed_key(&self) -> Option<&str> {
99 match self {
100 MatWeight::F32(_) => None,
101 MatWeight::Packed { key, .. } => Some(key.as_str()),
102 }
103 }
104}
105
/// Feed-forward sub-block of one layer.
///
/// `load_layer_ffn` picks `Moe` when `Qwen35Config::is_moe()` is true and
/// `Dense` otherwise.
#[derive(Debug, Clone)]
// `Moe` holds eight weights versus three for `Dense`. Boxing it would change
// the variant's public payload type, so the size difference is accepted.
#[allow(clippy::large_enum_variant)]
pub enum Qwen35LayerFfn {
    /// Three projections loaded from `blk.{il}.ffn_gate.weight`,
    /// `blk.{il}.ffn_up.weight` and `blk.{il}.ffn_down.weight`.
    Dense {
        gate: MatWeight,
        up: MatWeight,
        down: MatWeight,
    },
    /// Routed experts plus a shared expert.
    Moe(Qwen35MoeFfn),
}
117
/// Mixture-of-experts FFN weights of one layer (`blk.{il}.ffn_*` tensors).
#[derive(Debug, Clone)]
pub struct Qwen35MoeFfn {
    /// Expert router, from `ffn_gate_inp.weight`.
    pub router: MatWeight,
    /// Per-expert gate projections, laid out `[n_expert, ne0, ne1]`. Loaded from
    /// `ffn_gate_exps.weight`, or split out of `ffn_gate_up_exps.weight`.
    pub gate_exps: MatWeight,
    /// Per-expert up projections (same layout and sources as `gate_exps`).
    pub up_exps: MatWeight,
    /// Per-expert down projections, from `ffn_down_exps.weight`.
    pub down_exps: MatWeight,
    /// Shared-expert router, from `ffn_gate_inp_shexp.weight` (always dense).
    pub shared_router: Vec<f32>,
    /// Shared-expert gate, from `ffn_gate_shexp.weight`.
    pub shared_gate: MatWeight,
    /// Shared-expert up, from `ffn_up_shexp.weight`.
    pub shared_up: MatWeight,
    /// Shared-expert down, from `ffn_down_shexp.weight`.
    pub shared_down: MatWeight,
}
134
/// One non-MTP layer.
///
/// The loader makes layer `il` `FullAttn` when
/// `(il + 1) % full_attention_interval == 0`, and `Linear` otherwise.
#[derive(Debug, Clone)]
pub enum Qwen35TrunkLayer {
    Linear(Qwen35LinearLayer),
    FullAttn(Qwen35FullAttnLayer),
}
142
/// Weights of a non-full-attention trunk layer. Its mixer tensors are stored
/// under `blk.{il}.attn_qkv` / `attn_gate` / `ssm_*` keys.
///
/// NOTE(review): presumably the linear-attention mixer; confirm against the runtime.
#[derive(Debug, Clone)]
pub struct Qwen35LinearLayer {
    /// `attn_norm.weight`
    pub attn_norm: Vec<f32>,
    /// `post_attention_norm.weight`
    pub attn_post_norm: Vec<f32>,
    /// `attn_qkv.weight`
    pub attn_qkv: MatWeight,
    /// `attn_gate.weight`
    pub attn_gate: MatWeight,
    /// `ssm_conv1d.weight` (always dense)
    pub ssm_conv1d: Vec<f32>,
    /// `ssm_dt.bias`
    pub ssm_dt_bias: Vec<f32>,
    /// `ssm_a` (no `.weight` suffix in the GGUF key)
    pub ssm_a: Vec<f32>,
    /// `ssm_beta.weight`
    pub ssm_beta: MatWeight,
    /// `ssm_alpha.weight`
    pub ssm_alpha: MatWeight,
    /// `ssm_norm.weight`
    pub ssm_norm: Vec<f32>,
    /// `ssm_out.weight`
    pub ssm_out: MatWeight,
    /// Dense or MoE feed-forward block of this layer.
    pub ffn: Qwen35LayerFfn,
}
177
/// Weights of a full-attention layer (also the base of each MTP layer).
#[derive(Debug, Clone)]
pub struct Qwen35FullAttnLayer {
    /// `attn_norm.weight`
    pub attn_norm: Vec<f32>,
    /// `post_attention_norm.weight`
    pub attn_post_norm: Vec<f32>,
    /// Loaded from `attn_q.weight`. NOTE(review): the field name suggests
    /// that the query and an output gate share this projection; confirm.
    pub attn_q_gate: MatWeight,
    /// `attn_k.weight`
    pub attn_k: MatWeight,
    /// `attn_v.weight`
    pub attn_v: MatWeight,
    /// `attn_output.weight`
    pub attn_output: MatWeight,
    /// `attn_q_norm.weight`
    pub attn_q_norm: Vec<f32>,
    /// `attn_k_norm.weight`
    pub attn_k_norm: Vec<f32>,
    /// Dense or MoE feed-forward block of this layer.
    pub ffn: Qwen35LayerFfn,
}
195
/// Weights of one multi-token-prediction (`nextn`) layer.
#[derive(Debug, Clone)]
pub struct Qwen35MtpLayer {
    /// Full-attention block, loaded from the same `blk.{il}.*` keys as a trunk
    /// full-attention layer.
    pub base: Qwen35FullAttnLayer,
    /// `blk.{il}.nextn.eh_proj.weight`
    pub eh_proj: MatWeight,
    /// `blk.{il}.nextn.enorm.weight`
    pub enorm: Vec<f32>,
    /// `blk.{il}.nextn.hnorm.weight`
    pub hnorm: Vec<f32>,
    /// Optional `blk.{il}.nextn.embed_tokens.weight`; `None` if it cannot be loaded.
    pub embed_tokens: Option<MatWeight>,
    /// Optional `blk.{il}.nextn.shared_head_head.weight`; `None` if it cannot be loaded.
    pub shared_head_head: Option<MatWeight>,
    /// Optional `blk.{il}.nextn.shared_head_norm.weight`; `None` if it cannot be loaded.
    pub shared_head_norm: Option<Vec<f32>>,
}
218
/// All weights of a Qwen3.5 checkpoint, split into trunk and MTP layers.
#[derive(Debug, Clone)]
pub struct Qwen35Weights {
    /// `token_embd.weight`, always loaded as dense `f32`.
    pub token_embd: Vec<f32>,
    /// `output_norm.weight`
    pub output_norm: Vec<f32>,
    /// `output.weight`, or `None` if it could not be loaded. NOTE(review):
    /// presumably `None` means the LM head is tied to `token_embd`.
    pub output: Option<MatWeight>,
    /// Packed view of `token_embd.weight`. Set only by packed loading, and only
    /// when that tensor is Q4K/Q5K/Q6K/Q8K.
    pub token_embd_lm: Option<MatWeight>,
    /// Layers `0..num_hidden_layers - nextn_predict_layers`.
    pub trunk_layers: Vec<Qwen35TrunkLayer>,
    /// The trailing `nextn_predict_layers` layers.
    pub mtp_layers: Vec<Qwen35MtpLayer>,
}
238
239impl Qwen35Weights {
240 pub fn lm_vocab_size(&self, cfg: &Qwen35Config) -> usize {
243 if self.token_embd.is_empty() || cfg.hidden_size == 0 {
244 return cfg.vocab_size;
245 }
246 self.token_embd.len() / cfg.hidden_size
247 }
248}
249
250impl Qwen35Weights {
251 pub fn from_loader(loader: &mut dyn WeightLoader, cfg: &Qwen35Config) -> Result<Self> {
260 Self::from_loader_inner(loader, cfg, None)
261 }
262
263 pub fn from_loader_packed(loader: &mut GgufLoader, cfg: &Qwen35Config) -> Result<Self> {
271 let pack_via = loader as *mut GgufLoader;
275 Self::from_loader_inner(loader, cfg, Some(pack_via))
276 }
277
278 fn from_loader_inner(
279 loader: &mut dyn WeightLoader,
280 cfg: &Qwen35Config,
281 pack_via: Option<*mut GgufLoader>,
282 ) -> Result<Self> {
283 let n_layer = cfg.num_hidden_layers;
284 let nextn = cfg.nextn_predict_layers;
285 if nextn >= n_layer {
286 return Err(anyhow!(
287 "qwen35: nextn_predict_layers={nextn} must be < num_hidden_layers={n_layer}",
288 ));
289 }
290 let n_main = n_layer - nextn;
291 let interval = cfg.full_attention_interval.max(1);
292
293 let token_embd_lm = pack_via.and_then(|p| peek_gguf_packed_mat(p, "token_embd.weight"));
294 let token_embd = take_f32(loader, "token_embd.weight")?;
295 let output_norm = take_f32(loader, "output_norm.weight")?;
296 let output = take_mat(loader, "output.weight", pack_via).ok();
297
298 let mut trunk_layers = Vec::with_capacity(n_main);
299 for il in 0..n_main {
300 let is_full_attn = ((il + 1) % interval) == 0;
301 if is_full_attn {
302 trunk_layers.push(Qwen35TrunkLayer::FullAttn(load_full_attn_layer(
303 loader, il, cfg, pack_via,
304 )?));
305 } else {
306 trunk_layers.push(Qwen35TrunkLayer::Linear(load_linear_layer(
307 loader, il, cfg, pack_via,
308 )?));
309 }
310 }
311
312 let mut mtp_layers = Vec::with_capacity(nextn);
313 for il in n_main..n_layer {
314 mtp_layers.push(load_mtp_layer(loader, il, cfg, pack_via)?);
315 }
316
317 Ok(Self {
318 token_embd,
319 output_norm,
320 output,
321 token_embd_lm,
322 trunk_layers,
323 mtp_layers,
324 })
325 }
326}
327
328fn peek_gguf_packed_mat(loader: *mut GgufLoader, key: &str) -> Option<MatWeight> {
329 use rlx_gguf::GgmlType;
330 use rlx_ir::quant::QuantScheme;
331 let g = unsafe { &*loader };
332 let t = g.file().get(key)?;
333 let scheme = match t.dtype {
334 GgmlType::Q4K => QuantScheme::GgufQ4K,
335 GgmlType::Q5K => QuantScheme::GgufQ5K,
336 GgmlType::Q6K => QuantScheme::GgufQ6K,
337 GgmlType::Q8K => QuantScheme::GgufQ8K,
338 _ => return None,
339 };
340 let mut shape = t.shape.clone();
341 shape.reverse();
342 Some(MatWeight::Packed {
343 key: key.to_string(),
344 scheme,
345 shape,
346 })
347}
348
349fn take_f32(loader: &mut dyn WeightLoader, key: &str) -> Result<Vec<f32>> {
350 let (data, _shape) = loader
351 .take(key)
352 .with_context(|| format!("missing tensor: {key}"))?;
353 Ok(data)
354}
355
356fn take_mat(
363 loader: &mut dyn WeightLoader,
364 key: &str,
365 pack_via: Option<*mut GgufLoader>,
366) -> Result<MatWeight> {
367 if let Some(p) = pack_via {
368 let g: &mut GgufLoader = unsafe { &mut *p };
373 match g.take_packed_metadata(key) {
374 Ok(Some((scheme, shape))) => {
375 return Ok(MatWeight::Packed {
376 key: key.to_string(),
377 scheme,
378 shape,
379 });
380 }
381 Ok(None) => { }
382 Err(_e) => { }
383 }
384 }
385 let (data, _shape) = loader
386 .take(key)
387 .with_context(|| format!("missing tensor: {key}"))?;
388 Ok(MatWeight::F32(data))
389}
390
391fn take_expert_mat(
394 loader: &mut dyn WeightLoader,
395 key: &str,
396 pack_via: Option<*mut GgufLoader>,
397) -> Result<MatWeight> {
398 if let Some(p) = pack_via {
399 let g: &mut GgufLoader = unsafe { &mut *p };
400 if let Ok(Some((scheme, shape))) = g.take_packed_metadata(key) {
401 if shape.len() == 3 {
402 let n_expert = shape[2];
403 return Ok(MatWeight::Packed {
404 key: key.to_string(),
405 scheme,
406 shape: vec![n_expert, shape[0], shape[1]],
407 });
408 }
409 }
410 }
411 let (data, shape) = loader
412 .take(key)
413 .with_context(|| format!("missing MoE tensor: {key}"))?;
414 if shape.len() != 3 {
415 return Err(anyhow!(
416 "MoE tensor {key}: expected rank-3 GGML shape, got {shape:?}"
417 ));
418 }
419 let n_expert = shape[2];
420 let permuted = permute_ggml_expert_to_grouped(&data, shape[0], shape[1], n_expert);
421 Ok(MatWeight::F32(permuted))
422}
423
/// Reorder a GGML-ordered expert tensor into per-expert row-major blocks.
///
/// Input element `(i0, i1, e)` lives at `i0 + d0 * i1 + d0 * d1 * e`, so `i0`
/// varies fastest. Output element `(e, i0, i1)` lives at
/// `e * d0 * d1 + i0 * d1 + i1`. The output has `data.len()` elements; any
/// trailing values beyond `d0 * d1 * n_expert` stay zero.
///
/// # Panics
/// Panics if `data` holds fewer than `d0 * d1 * n_expert` values.
fn permute_ggml_expert_to_grouped(data: &[f32], d0: usize, d1: usize, n_expert: usize) -> Vec<f32> {
    let slab = d0 * d1;
    let mut grouped = vec![0f32; data.len()];
    for expert in 0..n_expert {
        let base = expert * slab;
        // Walk the source in storage order and scatter into the transposed slot.
        for i1 in 0..d1 {
            for i0 in 0..d0 {
                grouped[base + i0 * d1 + i1] = data[base + i1 * d0 + i0];
            }
        }
    }
    grouped
}
437
438fn load_layer_ffn(
439 loader: &mut dyn WeightLoader,
440 il: usize,
441 cfg: &Qwen35Config,
442 pack_via: Option<*mut GgufLoader>,
443) -> Result<Qwen35LayerFfn> {
444 let p = |suffix: &str| format!("blk.{il}.{suffix}");
445 if cfg.is_moe() {
446 Ok(Qwen35LayerFfn::Moe(load_moe_ffn(
447 loader, il, cfg, pack_via,
448 )?))
449 } else {
450 Ok(Qwen35LayerFfn::Dense {
451 gate: take_mat(loader, &p("ffn_gate.weight"), pack_via)?,
452 up: take_mat(loader, &p("ffn_up.weight"), pack_via)?,
453 down: take_mat(loader, &p("ffn_down.weight"), pack_via)?,
454 })
455 }
456}
457
458fn load_moe_ffn(
459 loader: &mut dyn WeightLoader,
460 il: usize,
461 cfg: &Qwen35Config,
462 pack_via: Option<*mut GgufLoader>,
463) -> Result<Qwen35MoeFfn> {
464 let p = |suffix: &str| format!("blk.{il}.{suffix}");
465 let router = take_mat(loader, &p("ffn_gate_inp.weight"), pack_via)?;
466 let down_exps = take_expert_mat(loader, &p("ffn_down_exps.weight"), pack_via)?;
467 let (gate_exps, up_exps) = match (
468 take_expert_mat(loader, &p("ffn_gate_exps.weight"), pack_via),
469 take_expert_mat(loader, &p("ffn_up_exps.weight"), pack_via),
470 ) {
471 (Ok(g), Ok(u)) => (g, u),
472 _ => {
473 let fused = take_expert_mat(loader, &p("ffn_gate_up_exps.weight"), pack_via)?;
474 split_fused_gate_up_exps(fused, cfg)?
475 }
476 };
477 Ok(Qwen35MoeFfn {
478 router,
479 gate_exps,
480 up_exps,
481 down_exps,
482 shared_router: take_f32(loader, &p("ffn_gate_inp_shexp.weight"))?,
483 shared_gate: take_mat(loader, &p("ffn_gate_shexp.weight"), pack_via)?,
484 shared_up: take_mat(loader, &p("ffn_up_shexp.weight"), pack_via)?,
485 shared_down: take_mat(loader, &p("ffn_down_shexp.weight"), pack_via)?,
486 })
487}
488
489fn split_fused_gate_up_exps(
491 fused: MatWeight,
492 cfg: &Qwen35Config,
493) -> Result<(MatWeight, MatWeight)> {
494 let MatWeight::F32(data) = fused else {
495 return Err(anyhow!(
496 "fused gate_up_exps must be F32 after take_expert_mat"
497 ));
498 };
499 let n_ff = cfg.expert_ffn_dim();
500 let n_embd = cfg.hidden_size;
501 let n_expert = cfg.num_experts;
502 let expected = 2 * n_ff * n_embd * n_expert;
503 if data.len() != expected {
504 return Err(anyhow!(
505 "fused gate_up_exps: len {} != 2*{n_ff}*{n_embd}*{n_expert}",
506 data.len()
507 ));
508 }
509 let expert_slab = 2 * n_ff * n_embd;
510 let half = n_ff * n_embd;
511 let mut gate = Vec::with_capacity(n_expert * half);
512 let mut up = Vec::with_capacity(n_expert * half);
513 for e in 0..n_expert {
514 let base = e * expert_slab;
515 gate.extend_from_slice(&data[base..base + half]);
516 up.extend_from_slice(&data[base + half..base + expert_slab]);
517 }
518 Ok((MatWeight::F32(gate), MatWeight::F32(up)))
519}
520
521fn load_linear_layer(
522 loader: &mut dyn WeightLoader,
523 il: usize,
524 cfg: &Qwen35Config,
525 pack_via: Option<*mut GgufLoader>,
526) -> Result<Qwen35LinearLayer> {
527 let p = |suffix: &str| format!("blk.{il}.{suffix}");
528 Ok(Qwen35LinearLayer {
529 attn_norm: take_f32(loader, &p("attn_norm.weight"))?,
530 attn_post_norm: take_f32(loader, &p("post_attention_norm.weight"))?,
531 attn_qkv: take_mat(loader, &p("attn_qkv.weight"), pack_via)?,
532 attn_gate: take_mat(loader, &p("attn_gate.weight"), pack_via)?,
533 ssm_conv1d: take_f32(loader, &p("ssm_conv1d.weight"))?,
534 ssm_dt_bias: take_f32(loader, &p("ssm_dt.bias"))?,
535 ssm_a: take_f32(loader, &p("ssm_a"))?,
536 ssm_beta: take_mat(loader, &p("ssm_beta.weight"), pack_via)?,
537 ssm_alpha: take_mat(loader, &p("ssm_alpha.weight"), pack_via)?,
538 ssm_norm: take_f32(loader, &p("ssm_norm.weight"))?,
539 ssm_out: take_mat(loader, &p("ssm_out.weight"), pack_via)?,
540 ffn: load_layer_ffn(loader, il, cfg, pack_via)?,
541 })
542}
543
544fn load_full_attn_layer(
545 loader: &mut dyn WeightLoader,
546 il: usize,
547 cfg: &Qwen35Config,
548 pack_via: Option<*mut GgufLoader>,
549) -> Result<Qwen35FullAttnLayer> {
550 let p = |suffix: &str| format!("blk.{il}.{suffix}");
551 Ok(Qwen35FullAttnLayer {
552 attn_norm: take_f32(loader, &p("attn_norm.weight"))?,
553 attn_post_norm: take_f32(loader, &p("post_attention_norm.weight"))?,
554 attn_q_gate: take_mat(loader, &p("attn_q.weight"), pack_via)?,
555 attn_k: take_mat(loader, &p("attn_k.weight"), pack_via)?,
556 attn_v: take_mat(loader, &p("attn_v.weight"), pack_via)?,
557 attn_output: take_mat(loader, &p("attn_output.weight"), pack_via)?,
558 attn_q_norm: take_f32(loader, &p("attn_q_norm.weight"))?,
559 attn_k_norm: take_f32(loader, &p("attn_k_norm.weight"))?,
560 ffn: load_layer_ffn(loader, il, cfg, pack_via)?,
561 })
562}
563
564fn load_mtp_layer(
565 loader: &mut dyn WeightLoader,
566 il: usize,
567 cfg: &Qwen35Config,
568 pack_via: Option<*mut GgufLoader>,
569) -> Result<Qwen35MtpLayer> {
570 let base = load_full_attn_layer(loader, il, cfg, pack_via)?;
571 let p = |suffix: &str| format!("blk.{il}.nextn.{suffix}");
572 let eh_proj = take_mat(loader, &p("eh_proj.weight"), pack_via)?;
573 let enorm = take_f32(loader, &p("enorm.weight"))?;
574 let hnorm = take_f32(loader, &p("hnorm.weight"))?;
575 let embed_tokens = take_mat(loader, &p("embed_tokens.weight"), pack_via).ok();
576 let shared_head_head = take_mat(loader, &p("shared_head_head.weight"), pack_via).ok();
577 let shared_head_norm = take_f32(loader, &p("shared_head_norm.weight")).ok();
578 Ok(Qwen35MtpLayer {
579 base,
580 eh_proj,
581 enorm,
582 hnorm,
583 embed_tokens,
584 shared_head_head,
585 shared_head_norm,
586 })
587}
588
// Loader tests. They drive `Qwen35Weights::from_loader` through an in-memory
// `WeightLoader`, so no GGUF file is needed and only the dense path runs.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// In-memory loader: a map from tensor key to `(data, shape)`.
    struct MockLoader {
        store: HashMap<String, (Vec<f32>, Vec<usize>)>,
    }

    impl WeightLoader for MockLoader {
        fn len(&self) -> usize {
            self.store.len()
        }
        // Removes the entry, so each key can be taken at most once.
        fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            self.store
                .remove(key)
                .ok_or_else(|| anyhow!("mock: missing key {key}"))
        }
        fn take_transposed(&mut self, _key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
            unimplemented!("mock loader: not used by qwen35 loader")
        }
        fn remaining_keys(&self) -> Vec<String> {
            self.store.keys().cloned().collect()
        }
    }

    /// Insert a one-element tensor (shape `[1]`) holding `marker` under `key`.
    fn populate(store: &mut HashMap<String, (Vec<f32>, Vec<usize>)>, key: &str, marker: f32) {
        store.insert(key.to_string(), (vec![marker], vec![1]));
    }

    /// Build a store holding every tensor the dense loader requires for `cfg`.
    ///
    /// - Layer kinds follow the same `(il + 1) % interval` rule as the loader.
    /// - Only dense FFN keys are inserted, so this assumes `cfg.is_moe()` is false.
    /// - Optional tensors (`output.weight` and the optional `nextn.*` ones) are
    ///   left out on purpose.
    /// - Marker values encode the layer kind: 10+il full attention, 100+il
    ///   linear, 1000+il MTP.
    fn build_synth_store(cfg: &Qwen35Config) -> HashMap<String, (Vec<f32>, Vec<usize>)> {
        let mut store = HashMap::new();
        populate(&mut store, "token_embd.weight", 1.0);
        populate(&mut store, "output_norm.weight", 2.0);
        let n_main = cfg.num_hidden_layers - cfg.nextn_predict_layers;
        let interval = cfg.full_attention_interval.max(1);

        for il in 0..n_main {
            let is_full_attn = ((il + 1) % interval) == 0;
            let p = |suf: &str| format!("blk.{il}.{suf}");
            if is_full_attn {
                for k in [
                    "attn_norm.weight",
                    "post_attention_norm.weight",
                    "attn_q.weight",
                    "attn_k.weight",
                    "attn_v.weight",
                    "attn_output.weight",
                    "attn_q_norm.weight",
                    "attn_k_norm.weight",
                    "ffn_gate.weight",
                    "ffn_down.weight",
                    "ffn_up.weight",
                ] {
                    populate(&mut store, &p(k), 10.0 + il as f32);
                }
            } else {
                for k in [
                    "attn_norm.weight",
                    "post_attention_norm.weight",
                    "attn_qkv.weight",
                    "attn_gate.weight",
                    "ssm_conv1d.weight",
                    "ssm_dt.bias",
                    "ssm_a",
                    "ssm_beta.weight",
                    "ssm_alpha.weight",
                    "ssm_norm.weight",
                    "ssm_out.weight",
                    "ffn_gate.weight",
                    "ffn_down.weight",
                    "ffn_up.weight",
                ] {
                    populate(&mut store, &p(k), 100.0 + il as f32);
                }
            }
        }

        // MTP layers: a full-attention base plus the required `nextn` tensors.
        for il in n_main..cfg.num_hidden_layers {
            let p = |suf: &str| format!("blk.{il}.{suf}");
            for k in [
                "attn_norm.weight",
                "post_attention_norm.weight",
                "attn_q.weight",
                "attn_k.weight",
                "attn_v.weight",
                "attn_output.weight",
                "attn_q_norm.weight",
                "attn_k_norm.weight",
                "ffn_gate.weight",
                "ffn_down.weight",
                "ffn_up.weight",
                "nextn.eh_proj.weight",
                "nextn.enorm.weight",
                "nextn.hnorm.weight",
            ] {
                populate(&mut store, &p(k), 1000.0 + il as f32);
            }
        }
        store
    }

    /// Small dense config with 6 layers, the last one reserved for MTP.
    ///
    /// With `full_attention_interval = 4`, trunk layer 3 is the only
    /// full-attention layer. `num_experts: 0` presumably makes `is_moe()`
    /// false (confirm in `Qwen35Config`).
    fn dummy_cfg() -> Qwen35Config {
        Qwen35Config {
            vocab_size: 0,
            hidden_size: 1024,
            intermediate_size: 3584,
            num_hidden_layers: 6,
            nextn_predict_layers: 1,
            num_attention_heads: 16,
            num_key_value_heads: 4,
            key_length: 128,
            value_length: 128,
            max_position_embeddings: 40_960,
            rms_norm_eps: 1e-6,
            rope_theta: 10_000_000.0,
            rope_dim_count: 64,
            rope_dim_sections: vec![],
            full_attention_interval: 4,
            ssm_conv_kernel: 4,
            ssm_group_count: 16,
            ssm_inner_size: 2048,
            ssm_state_size: 128,
            ssm_time_step_rank: 16,
            tie_word_embeddings: true,
            num_experts: 0,
            num_experts_used: 0,
            expert_ffn_size: 0,
            shared_expert_ffn_size: 0,
            expert_weights_scale: 1.0,
        }
    }

    /// Checks three things:
    /// - trunk layers follow the interval schedule;
    /// - the MTP layer loads its required tensors densely;
    /// - absent optional tensors come back as `None`.
    #[test]
    fn qwen35_weights_loader_classifies_layers_and_loads_mtp() {
        let cfg = dummy_cfg();
        let mut loader = MockLoader {
            store: build_synth_store(&cfg),
        };
        let w = Qwen35Weights::from_loader(&mut loader, &cfg).expect("load qwen35 weights");

        // 6 layers minus 1 MTP layer.
        assert_eq!(w.trunk_layers.len(), 5);
        for (i, layer) in w.trunk_layers.iter().enumerate() {
            let want_full = ((i + 1) % 4) == 0;
            match (want_full, layer) {
                (true, Qwen35TrunkLayer::FullAttn(_)) => {}
                (false, Qwen35TrunkLayer::Linear(_)) => {}
                _ => panic!(
                    "layer {i}: want_full={want_full}, got {:?}",
                    std::mem::discriminant(layer)
                ),
            }
        }

        assert_eq!(w.mtp_layers.len(), 1);
        let mtp = &w.mtp_layers[0];
        assert_eq!(mtp.eh_proj.len(), 1);
        // `from_loader` never packs, so every matrix is dense.
        assert!(matches!(mtp.eh_proj, MatWeight::F32(_)));
        assert_eq!(mtp.enorm.len(), 1);
        assert_eq!(mtp.hnorm.len(), 1);
        assert!(mtp.embed_tokens.is_none());
        assert!(mtp.shared_head_head.is_none());
        assert!(mtp.shared_head_norm.is_none());

        // `output.weight` is not in the store, so the optional LM head is `None`.
        assert!(w.output.is_none());
        assert_eq!(w.token_embd.len(), 1);
        assert_eq!(w.output_norm.len(), 1);
    }

    /// Removing a required tensor must produce an error that names the
    /// missing key. Layer 2 is a linear layer under `dummy_cfg`.
    #[test]
    fn qwen35_weights_loader_reports_missing_tensor_key() {
        let cfg = dummy_cfg();
        let mut store = build_synth_store(&cfg);
        store.remove("blk.2.ssm_conv1d.weight");
        let mut loader = MockLoader { store };
        let err = Qwen35Weights::from_loader(&mut loader, &cfg).expect_err("must error");
        // `{:#}` prints the whole anyhow context chain.
        let msg = format!("{err:#}");
        assert!(
            msg.contains("blk.2.ssm_conv1d.weight"),
            "error message must point at the missing key: {msg}"
        );
    }
}