1use anyhow::{Context, Result, anyhow};
30use rlx_gguf::MetaValue;
31use std::collections::HashSet;
32use std::path::Path;
33
34fn compute_mtp_layer_threshold(file: &rlx_gguf::GgufFile) -> Option<u32> {
38 let arch = file
39 .metadata
40 .get("general.architecture")
41 .and_then(MetaValue::as_str)?;
42 let block_count = file
43 .metadata
44 .get(&format!("{arch}.block_count"))
45 .and_then(MetaValue::as_u32)?;
46 let nextn = file
47 .metadata
48 .get(&format!("{arch}.nextn_predict_layers"))
49 .and_then(MetaValue::as_u32)?;
50 if nextn == 0 {
51 return None;
52 }
53 Some(block_count.saturating_sub(nextn))
54}
55
56use crate::gguf_resolve::resolve_gguf_tensor_name;
57use crate::gguf_support::gguf_architecture_str;
58use crate::weight_map::PackedWeightTensor;
59use crate::weight_map::WeightMap;
60use rlx_ir::quant::QuantScheme;
61
/// Translates a HuggingFace-style tensor name into its GGUF counterpart.
///
/// Three top-level tensors are mapped directly; per-layer names of the form
/// `model.layers.{N}.{tail}` become `blk.{N}.{gguf_tail}`.
///
/// Returns `None` when the name lacks the `model.layers.` prefix, the layer
/// index does not parse as `usize`, or the tail is not in the table.
pub fn hf_to_gguf_name(hf: &str) -> Option<String> {
    const TOP_LEVEL: &[(&str, &str)] = &[
        ("model.embed_tokens.weight", "token_embd.weight"),
        ("model.norm.weight", "output_norm.weight"),
        ("lm_head.weight", "output.weight"),
    ];
    const PER_LAYER: &[(&str, &str)] = &[
        ("input_layernorm.weight", "attn_norm.weight"),
        ("post_attention_layernorm.weight", "ffn_norm.weight"),
        ("self_attn.q_proj.weight", "attn_q.weight"),
        ("self_attn.k_proj.weight", "attn_k.weight"),
        ("self_attn.v_proj.weight", "attn_v.weight"),
        ("self_attn.o_proj.weight", "attn_output.weight"),
        ("self_attn.q_proj.bias", "attn_q.bias"),
        ("self_attn.k_proj.bias", "attn_k.bias"),
        ("self_attn.v_proj.bias", "attn_v.bias"),
        ("self_attn.q_norm.weight", "attn_q_norm.weight"),
        ("self_attn.k_norm.weight", "attn_k_norm.weight"),
        ("mlp.gate_proj.weight", "ffn_gate.weight"),
        ("mlp.up_proj.weight", "ffn_up.weight"),
        ("mlp.down_proj.weight", "ffn_down.weight"),
    ];

    if let Some((_, gguf)) = TOP_LEVEL.iter().find(|(name, _)| *name == hf) {
        return Some(gguf.to_string());
    }

    // `model.layers.{N}.{tail}`: split at the first dot after the prefix.
    let (layer, tail) = hf.strip_prefix("model.layers.")?.split_once('.')?;
    // Parsing (rather than copying the text) normalises e.g. `007` to `7`.
    let layer: usize = layer.parse().ok()?;
    let (_, gguf_tail) = PER_LAYER.iter().find(|(name, _)| *name == tail)?;
    Some(format!("blk.{layer}.{gguf_tail}"))
}
109
/// Translates a GGUF tensor name into its HuggingFace-style counterpart.
///
/// This is the inverse of the HF-to-GGUF table: three top-level tensors are
/// mapped directly, and `blk.{N}.{tail}` becomes `model.layers.{N}.{hf_tail}`.
///
/// Returns `None` when the name lacks the `blk.` prefix, the layer index does
/// not parse as `usize`, or the tail is not recognised.
pub fn gguf_to_hf_name(gguf: &str) -> Option<String> {
    let top_level = match gguf {
        "token_embd.weight" => Some("model.embed_tokens.weight"),
        "output_norm.weight" => Some("model.norm.weight"),
        "output.weight" => Some("lm_head.weight"),
        _ => None,
    };
    if let Some(hf) = top_level {
        return Some(hf.to_owned());
    }

    // `blk.{N}.{tail}`: split at the first dot after the prefix.
    let (layer, tail) = gguf.strip_prefix("blk.")?.split_once('.')?;
    let layer = layer.parse::<usize>().ok()?;
    let hf_tail = match tail {
        "attn_norm.weight" => "input_layernorm.weight",
        "ffn_norm.weight" => "post_attention_layernorm.weight",
        "attn_q.weight" => "self_attn.q_proj.weight",
        "attn_k.weight" => "self_attn.k_proj.weight",
        "attn_v.weight" => "self_attn.v_proj.weight",
        "attn_output.weight" => "self_attn.o_proj.weight",
        "attn_q.bias" => "self_attn.q_proj.bias",
        "attn_k.bias" => "self_attn.k_proj.bias",
        "attn_v.bias" => "self_attn.v_proj.bias",
        "attn_q_norm.weight" => "self_attn.q_norm.weight",
        "attn_k_norm.weight" => "self_attn.k_norm.weight",
        "ffn_gate.weight" => "mlp.gate_proj.weight",
        "ffn_up.weight" => "mlp.up_proj.weight",
        "ffn_down.weight" => "mlp.down_proj.weight",
        _ => return None,
    };
    Some(format!("model.layers.{layer}.{hf_tail}"))
}
146
147pub fn gguf_to_hf_name_for_arch(gguf: &str, arch: &str) -> Option<String> {
155 if matches!(
156 arch,
157 "gemma2" | "gemma3" | "gemma3n" | "gemma4" | "gemma4moe"
158 ) {
159 match gguf {
160 "token_embd.weight" => return Some("model.embed_tokens.weight".into()),
161 "output_norm.weight" => return Some("model.norm.weight".into()),
162 "output.weight" => return Some("lm_head.weight".into()),
163 _ => {}
164 }
165 let rest = gguf.strip_prefix("blk.")?;
166 let dot = rest.find('.')?;
167 let (idx_str, tail_with_dot) = rest.split_at(dot);
168 let tail = &tail_with_dot[1..];
169 let idx: usize = idx_str.parse().ok()?;
170 let hf_tail = match tail {
171 "attn_norm.weight" => "input_layernorm.weight",
172 "post_attention_norm.weight" => "post_attention_layernorm.weight",
173 "ffn_norm.weight" => "pre_feedforward_layernorm.weight",
174 "post_ffw_norm.weight" => "post_feedforward_layernorm.weight",
175 "attn_q.weight" => "self_attn.q_proj.weight",
176 "attn_k.weight" => "self_attn.k_proj.weight",
177 "attn_v.weight" => "self_attn.v_proj.weight",
178 "attn_output.weight" => "self_attn.o_proj.weight",
179 "ffn_gate.weight" => "mlp.gate_proj.weight",
180 "ffn_up.weight" => "mlp.up_proj.weight",
181 "ffn_down.weight" => "mlp.down_proj.weight",
182 _ => return None,
183 };
184 return Some(format!("model.layers.{idx}.{hf_tail}"));
185 }
186 gguf_to_hf_name(gguf)
187}
188
/// Reports whether `name` is one of the norm weights in the Gemma tables.
///
/// Accepts both spellings: GGUF (`output_norm.weight`, `blk.{N}.*_norm.weight`)
/// and HuggingFace (`model.norm.weight`, `model.layers.{N}.*_layernorm.weight`).
/// The layer index itself is not validated; only the text after the first
/// dot following the prefix is compared.
fn is_gemma_norm_weight(name: &str) -> bool {
    const GGUF_LAYER_NORMS: &[&str] = &[
        "attn_norm.weight",
        "post_attention_norm.weight",
        "ffn_norm.weight",
        "post_ffw_norm.weight",
    ];
    const HF_LAYER_NORMS: &[&str] = &[
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
    ];

    /// Returns the part of `name` after `{prefix}{index}.`, if it has that form.
    fn layer_tail<'a>(name: &'a str, prefix: &str) -> Option<&'a str> {
        let (_, tail) = name.strip_prefix(prefix)?.split_once('.')?;
        Some(tail)
    }

    if matches!(name, "output_norm.weight" | "model.norm.weight") {
        return true;
    }
    if let Some(tail) = layer_tail(name, "blk.") {
        return GGUF_LAYER_NORMS.contains(&tail);
    }
    match layer_tail(name, "model.layers.") {
        Some(tail) => HF_LAYER_NORMS.contains(&tail),
        None => false,
    }
}
224
/// Name-based check for MTP tensors.
///
/// A name matches when it starts with `mtp`, or contains `mtp_` or `.mtp`
/// anywhere.
pub fn is_mtp_weight(name: &str) -> bool {
    const MARKERS: [&str; 2] = ["mtp_", ".mtp"];
    name.starts_with("mtp") || MARKERS.iter().any(|marker| name.contains(marker))
}
240
241pub fn ggml_type_to_quant_scheme(dtype: rlx_gguf::GgmlType) -> Option<QuantScheme> {
243 use rlx_gguf::GgmlType;
244 match dtype {
245 GgmlType::Q2K => Some(QuantScheme::GgufQ2K),
246 GgmlType::Q3K => Some(QuantScheme::GgufQ3K),
247 GgmlType::Q4K => Some(QuantScheme::GgufQ4K),
248 GgmlType::Q5K => Some(QuantScheme::GgufQ5K),
249 GgmlType::Q6K => Some(QuantScheme::GgufQ6K),
250 GgmlType::Q8K => Some(QuantScheme::GgufQ8K),
251 GgmlType::Q4_0 => Some(QuantScheme::GgufQ4_0),
252 GgmlType::Q8_0 => Some(QuantScheme::GgufQ8_0),
253 _ => None,
254 }
255}
256
257pub trait WeightLoader: Send {
263 fn format_id(&self) -> &'static str {
265 "unknown"
266 }
267 fn len(&self) -> usize;
269 fn is_empty(&self) -> bool {
270 self.len() == 0
271 }
272 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
275 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
280 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
282 let _ = key;
283 Ok(None)
284 }
285 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
287 let _ = key;
288 None
289 }
290 fn remaining_keys(&self) -> Vec<String>;
293 fn arch_hint(&self) -> Option<&str> {
299 None
300 }
301}
302
303impl WeightLoader for WeightMap {
304 fn format_id(&self) -> &'static str {
305 "safetensors"
306 }
307 fn len(&self) -> usize {
308 Self::len(self)
309 }
310 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
311 Self::take(self, key)
312 }
313 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
314 Self::take_transposed(self, key)
315 }
316 fn remaining_keys(&self) -> Vec<String> {
317 self.keys().map(|s| s.to_string()).collect()
318 }
319}
320
321pub struct HfTranslatingLoader<L: WeightLoader> {
334 inner: L,
335}
336
337impl<L: WeightLoader> HfTranslatingLoader<L> {
338 pub fn new(inner: L) -> Self {
339 Self { inner }
340 }
341 pub fn into_inner(self) -> L {
342 self.inner
343 }
344 pub fn inner(&self) -> &L {
345 &self.inner
346 }
347 pub fn inner_mut(&mut self) -> &mut L {
348 &mut self.inner
349 }
350}
351
352impl<L: WeightLoader> WeightLoader for HfTranslatingLoader<L> {
353 fn format_id(&self) -> &'static str {
354 self.inner.format_id()
355 }
356 fn len(&self) -> usize {
357 self.inner.len()
358 }
359 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
360 match self.inner.take(key) {
361 Ok(v) => Ok(v),
362 Err(_) => {
363 if let Some(hf) = gguf_to_hf_name(key) {
364 return self.inner.take(&hf);
365 }
366 self.inner.take(key)
367 }
368 }
369 }
370 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
371 match self.inner.take_transposed(key) {
372 Ok(v) => Ok(v),
373 Err(_) => {
374 if let Some(hf) = gguf_to_hf_name(key) {
375 return self.inner.take_transposed(&hf);
376 }
377 self.inner.take_transposed(key)
378 }
379 }
380 }
381 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
382 self.inner.take_packed(key)
383 }
384 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
385 self.inner.tensor_bytes_borrowed(key)
386 }
387 fn remaining_keys(&self) -> Vec<String> {
388 self.inner.remaining_keys()
389 }
390}
391
392pub fn load_from_path(path: &str) -> Result<Box<dyn WeightLoader>> {
394 crate::weight_registry::open_weight_loader(Path::new(path))
395}
396
397pub struct GgufLoader {
405 file: rlx_gguf::GgufFile,
406 arch: String,
407 taken: HashSet<String>,
408 include_mtp: bool,
415 mtp_layer_threshold: Option<u32>,
420}
421
422impl GgufLoader {
423 pub fn from_file(path: &str) -> Result<Self> {
424 let file = crate::gguf_support::load_gguf_file(std::path::Path::new(path))?;
425 let arch = gguf_architecture_str(&file)
426 .unwrap_or("unknown")
427 .to_string();
428 let mtp_layer_threshold = compute_mtp_layer_threshold(&file);
429 Ok(Self {
430 file,
431 arch,
432 taken: HashSet::new(),
433 include_mtp: false,
434 mtp_layer_threshold,
435 })
436 }
437
438 pub fn architecture(&self) -> &str {
439 &self.arch
440 }
441
442 pub fn mtp_layer_threshold(&self) -> Option<u32> {
449 self.mtp_layer_threshold
450 }
451
452 pub fn file(&self) -> &rlx_gguf::GgufFile {
456 &self.file
457 }
458
459 pub fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
466 let real = self.resolve(key).ok()?;
467 let t = self.file.get(&real)?;
468 self.file.tensor_bytes(t).ok()
469 }
470
471 pub fn take_packed_metadata(
479 &mut self,
480 key: &str,
481 ) -> Result<Option<(rlx_ir::quant::QuantScheme, Vec<usize>)>> {
482 let real = self.resolve(key)?;
483 if self.taken.contains(&real) {
484 return Err(anyhow!("weight already taken: {key} (→ {real})"));
485 }
486 if !self.include_mtp && self.is_mtp_tensor(&real) {
487 return Err(anyhow!(
488 "refusing to take MTP weight `{real}` without include_mtp(true)"
489 ));
490 }
491 let t = self
492 .file
493 .get(&real)
494 .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
495 let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
496 return Ok(None);
497 };
498 let mut shape = t.shape.clone();
499 shape.reverse();
500 self.taken.insert(real);
501 Ok(Some((scheme, shape)))
502 }
503
504 pub fn is_mtp_tensor(&self, name: &str) -> bool {
508 if is_mtp_weight(name) {
509 return true;
510 }
511 if let Some(thresh) = self.mtp_layer_threshold {
512 if let Some(rest) = name.strip_prefix("blk.") {
513 if let Some(dot) = rest.find('.') {
514 if let Ok(idx) = rest[..dot].parse::<u32>() {
515 if idx >= thresh {
516 return true;
517 }
518 }
519 }
520 }
521 }
522 false
523 }
524
525 pub fn include_mtp(&mut self, include: bool) -> &mut Self {
533 self.include_mtp = include;
534 self
535 }
536
537 pub fn take_packed(&mut self, key: &str) -> Result<Option<PackedWeightTensor>> {
551 let real = self.resolve(key)?;
552 if self.taken.contains(&real) {
553 return Err(anyhow!("weight already taken: {key} (→ {real})"));
554 }
555 if !self.include_mtp && self.is_mtp_tensor(&real) {
556 return Err(anyhow!(
557 "refusing to take MTP weight `{real}` without include_mtp(true)"
558 ));
559 }
560 let t = self
561 .file
562 .get(&real)
563 .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
564 let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
569 return Ok(None);
570 };
571 let bytes = self
572 .file
573 .tensor_bytes(t)
574 .with_context(|| format!("read packed bytes for {real}"))?
575 .to_vec();
576 let mut shape = t.shape.clone();
577 shape.reverse();
581 self.taken.insert(real);
582 Ok(Some((bytes, scheme, shape)))
583 }
584
585 pub fn take_mtp(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
590 if !self.is_mtp_tensor(key) {
591 return Err(anyhow!("not an MTP weight under this file's scheme: {key}"));
592 }
593 if !self.file.tensors.contains_key(key) {
594 return Err(anyhow!("MTP weight not found in GGUF: {key}"));
595 }
596 if self.taken.contains(key) {
597 return Err(anyhow!("MTP weight already taken: {key}"));
598 }
599 let (data, raw_shape) = self.file.dequant_f32(key)?;
600 self.taken.insert(key.to_string());
601 let mut shape = raw_shape;
602 shape.reverse();
603 Ok((data, shape))
604 }
605}
606
607impl GgufLoader {
608 fn resolve(&self, key: &str) -> Result<String> {
611 resolve_gguf_tensor_name(&self.file, &self.arch, key)
612 .ok_or_else(|| anyhow!("weight not found in GGUF (arch={}): {key}", self.arch))
613 }
614}
615
616impl WeightLoader for GgufLoader {
617 fn format_id(&self) -> &'static str {
618 "gguf"
619 }
620 fn arch_hint(&self) -> Option<&str> {
621 Some(&self.arch)
622 }
623 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
624 self.take_packed(key)
625 }
626 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
627 GgufLoader::tensor_bytes_borrowed(self, key)
628 }
629 fn len(&self) -> usize {
630 self.file
631 .tensors
632 .keys()
633 .filter(|k| !self.taken.contains(*k) && (self.include_mtp || !self.is_mtp_tensor(k)))
634 .count()
635 }
636 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
637 let real = self.resolve(key)?;
638 if self.taken.contains(&real) {
639 return Err(anyhow!("weight already taken: {key} (→ {real})"));
640 }
641 if !self.include_mtp && self.is_mtp_tensor(&real) {
642 return Err(anyhow!(
643 "refusing to take MTP weight `{real}` without include_mtp(true); \
644 use loader.take_mtp(...) for explicit MTP grabs or \
645 loader.include_mtp(true) to include them in drains"
646 ));
647 }
648 let (mut data, raw_shape) = self.file.dequant_f32(&real)?;
649 self.taken.insert(real.clone());
650 if matches!(
659 self.arch.as_str(),
660 "gemma" | "gemma2" | "gemma3" | "gemma3n" | "gemma4"
661 ) && is_gemma_norm_weight(&real)
662 {
663 for v in data.iter_mut() {
664 *v -= 1.0;
665 }
666 }
667 let mut shape = raw_shape;
673 shape.reverse();
674 Ok((data, shape))
675 }
676 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
686 let (data, shape) = self.take(key)?;
689 if shape.len() != 2 {
690 return Err(anyhow!("transpose requires 2D, got {shape:?}"));
691 }
692 let (rows, cols) = (shape[0], shape[1]);
693 let mut t = vec![0f32; data.len()];
694 for i in 0..rows {
695 for j in 0..cols {
696 t[j * rows + i] = data[i * cols + j];
697 }
698 }
699 Ok((t, vec![cols, rows]))
700 }
701 fn remaining_keys(&self) -> Vec<String> {
702 self.file
707 .tensors
708 .keys()
709 .filter(|k| {
710 !self.taken.contains(k.as_str()) && (self.include_mtp || !self.is_mtp_tensor(k))
711 })
712 .cloned()
713 .collect()
714 }
715}
716
717impl GgufLoader {
718 pub fn mtp_keys(&self) -> Vec<String> {
725 self.file
726 .tensors
727 .keys()
728 .filter(|k| self.is_mtp_tensor(k))
729 .cloned()
730 .collect()
731 }
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    // ---- helpers for hand-assembling tiny GGUF files ----------------------

    /// Appends a `u64` length prefix followed by the UTF-8 bytes of `text`.
    fn put_len_prefixed(buf: &mut Vec<u8>, text: &str) {
        buf.extend_from_slice(&(text.len() as u64).to_le_bytes());
        buf.extend_from_slice(text.as_bytes());
    }

    /// Appends the file header: magic, version 3, tensor count, KV count.
    fn put_header(buf: &mut Vec<u8>, tensor_count: u64, kv_count: u64) {
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&tensor_count.to_le_bytes());
        buf.extend_from_slice(&kv_count.to_le_bytes());
    }

    /// Appends a string-valued metadata entry (value type tag 8).
    fn put_string_kv(buf: &mut Vec<u8>, key: &str, value: &str) {
        put_len_prefixed(buf, key);
        buf.extend_from_slice(&8u32.to_le_bytes());
        put_len_prefixed(buf, value);
    }

    /// Appends a `u32`-valued metadata entry (value type tag 4).
    fn put_u32_kv(buf: &mut Vec<u8>, key: &str, value: u32) {
        put_len_prefixed(buf, key);
        buf.extend_from_slice(&4u32.to_le_bytes());
        buf.extend_from_slice(&value.to_le_bytes());
    }

    /// Appends a tensor-info record with dtype tag 0 (the tests store plain
    /// f32 data) and the given byte offset into the data section.
    fn put_tensor_info(buf: &mut Vec<u8>, name: &str, shape: &[usize], offset: u64) {
        put_len_prefixed(buf, name);
        buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
        for &dim in shape {
            buf.extend_from_slice(&(dim as u64).to_le_bytes());
        }
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&offset.to_le_bytes());
    }

    /// Zero-pads `buf` up to the default GGUF alignment.
    fn pad_to_alignment(buf: &mut Vec<u8>) {
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
    }

    // ---- tests ------------------------------------------------------------

    #[test]
    fn unknown_extension_errors() {
        let err = match load_from_path("/tmp/no-such-thing.bin") {
            Ok(_) => panic!("expected error"),
            Err(err) => err,
        };
        assert!(err.to_string().contains("unsupported"));
    }

    #[test]
    fn hf_to_gguf_top_level() {
        let cases = [
            ("model.embed_tokens.weight", "token_embd.weight"),
            ("model.norm.weight", "output_norm.weight"),
            ("lm_head.weight", "output.weight"),
        ];
        for (hf, gguf) in cases {
            assert_eq!(hf_to_gguf_name(hf).as_deref(), Some(gguf));
        }
    }

    #[test]
    fn hf_to_gguf_per_layer() {
        let cases = [
            (
                "model.layers.0.self_attn.q_proj.weight",
                "blk.0.attn_q.weight",
            ),
            (
                "model.layers.7.self_attn.o_proj.weight",
                "blk.7.attn_output.weight",
            ),
            (
                "model.layers.3.mlp.gate_proj.weight",
                "blk.3.ffn_gate.weight",
            ),
            (
                "model.layers.12.mlp.down_proj.weight",
                "blk.12.ffn_down.weight",
            ),
            (
                "model.layers.4.input_layernorm.weight",
                "blk.4.attn_norm.weight",
            ),
            (
                "model.layers.4.post_attention_layernorm.weight",
                "blk.4.ffn_norm.weight",
            ),
            (
                "model.layers.0.self_attn.q_norm.weight",
                "blk.0.attn_q_norm.weight",
            ),
        ];
        for (hf, gguf) in cases {
            assert_eq!(
                hf_to_gguf_name(hf).as_deref(),
                Some(gguf),
                "mismatch for {hf}"
            );
        }
    }

    #[test]
    fn hf_to_gguf_unknown_returns_none() {
        for name in [
            "model.layers.0.some_new_thing.weight",
            "model.layers.foo.input_layernorm.weight",
        ] {
            assert!(hf_to_gguf_name(name).is_none());
        }
    }

    #[test]
    fn mtp_detection() {
        for name in [
            "mtp_blk.0.attn_q.weight",
            "output_mtp_0.weight",
            "model.layers.0.mtp_head.weight",
        ] {
            assert!(is_mtp_weight(name));
        }
        for name in ["blk.0.attn_q.weight", "output.weight"] {
            assert!(!is_mtp_weight(name));
        }
    }

    #[test]
    fn ggml_q4_0_maps_to_packed_scheme() {
        use rlx_gguf::GgmlType;
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q4_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ4_0)
        );
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q8_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ8_0)
        );
    }

    #[test]
    fn gguf_loader_threshold_based_mtp_detection() {
        // 25 blocks with one trailing next-token layer, so the threshold is 24.
        let mut buf: Vec<u8> = Vec::new();
        put_header(&mut buf, 3, 3);
        put_string_kv(&mut buf, "general.architecture", "qwen35");
        put_u32_kv(&mut buf, "qwen35.block_count", 25);
        put_u32_kv(&mut buf, "qwen35.nextn_predict_layers", 1);

        // Three 4x4 f32 tensors, 64 bytes each, laid out back to back.
        put_tensor_info(&mut buf, "blk.0.attn_q.weight", &[4, 4], 0);
        put_tensor_info(&mut buf, "blk.24.attn_q.weight", &[4, 4], 64);
        put_tensor_info(&mut buf, "token_embd.weight", &[4, 4], 128);
        pad_to_alignment(&mut buf);
        for _ in 0..(3 * 16) {
            buf.extend_from_slice(&0.5f32.to_le_bytes());
        }

        let path = std::env::temp_dir().join("rlx_mtp_threshold_test.gguf");
        std::fs::write(&path, &buf).unwrap();
        let loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();

        assert_eq!(loader.mtp_layer_threshold(), Some(24));
        assert!(!loader.is_mtp_tensor("blk.0.attn_q.weight"));
        assert!(loader.is_mtp_tensor("blk.24.attn_q.weight"));
        assert!(!loader.is_mtp_tensor("token_embd.weight"));
        let mtp = loader.mtp_keys();
        assert_eq!(mtp, vec!["blk.24.attn_q.weight".to_string()]);

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn gguf_loader_resolves_hf_names_and_skips_mtp() {
        let t1: Vec<f32> = (0..12).map(|x| x as f32).collect();
        let t2: Vec<f32> = (100..116).map(|x| x as f32).collect();
        let t3: Vec<f32> = vec![0.5f32; 8];

        // (name, stored shape, values) in file order.
        let layout: Vec<(&str, Vec<usize>, &[f32])> = vec![
            ("token_embd.weight", vec![3, 4], t1.as_slice()),
            ("blk.0.attn_q.weight", vec![4, 4], t2.as_slice()),
            ("output_mtp_0.weight", vec![2, 4], t3.as_slice()),
        ];

        // No metadata entries at all, so the architecture is unknown.
        let mut buf: Vec<u8> = Vec::new();
        put_header(&mut buf, layout.len() as u64, 0);

        // Each tensor starts where the previous one's f32 payload ends.
        let mut elements_before = 0usize;
        for (name, shape, values) in &layout {
            put_tensor_info(&mut buf, name, shape, (elements_before * 4) as u64);
            elements_before += values.len();
        }
        pad_to_alignment(&mut buf);
        for (_, _, values) in &layout {
            for value in values.iter() {
                buf.extend_from_slice(&value.to_le_bytes());
            }
        }

        let path = std::env::temp_dir().join("rlx_test_qwen3_mini.gguf");
        std::fs::write(&path, &buf).unwrap();

        // Default loader: the MTP tensor is hidden from the count.
        let mut loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        assert_eq!(loader.len(), 2);

        let (out, shape) = loader
            .take("model.embed_tokens.weight")
            .expect("hf-named token_embd should resolve");
        assert_eq!(shape, vec![4, 3]);
        assert_eq!(out, t1);

        let (out, shape) = loader
            .take("model.layers.0.self_attn.q_proj.weight")
            .expect("hf-named attn_q should resolve");
        assert_eq!(shape, vec![4, 4]);
        assert_eq!(out, t2);

        assert_eq!(loader.remaining_keys(), Vec::<String>::new());
        assert_eq!(loader.mtp_keys(), vec!["output_mtp_0.weight".to_string()]);

        // With include_mtp(true) the MTP tensor becomes visible.
        let mut loader2 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        loader2.include_mtp(true);
        let visible: std::collections::HashSet<String> =
            loader2.remaining_keys().into_iter().collect();
        assert!(visible.contains("token_embd.weight"));
        assert!(visible.contains("blk.0.attn_q.weight"));
        assert!(
            visible.contains("output_mtp_0.weight"),
            "MTP weight should be visible with include_mtp(true)"
        );
        let (mtp_data, mtp_shape) = loader2.take_mtp("output_mtp_0.weight").unwrap();
        assert_eq!(mtp_shape, vec![4, 2]);
        assert_eq!(mtp_data, t3);

        // A plain take of an MTP tensor is refused by the guard.
        let mut loader3 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        let err = loader3.take("output_mtp_0.weight").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("include_mtp(true)"),
            "expected MTP guard error, got: {msg}"
        );

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn missing_gguf_file_errors() {
        let outcome = load_from_path("/tmp/no-such-thing-rlx-gguf-test.gguf");
        assert!(outcome.is_err());
    }
}