1use anyhow::{Context, Result, anyhow};
30use rlx_gguf::MetaValue;
31use std::collections::HashSet;
32use std::path::Path;
33
34fn compute_mtp_layer_threshold(file: &rlx_gguf::GgufFile) -> Option<u32> {
38 let arch = file
39 .metadata
40 .get("general.architecture")
41 .and_then(MetaValue::as_str)?;
42 let block_count = file
43 .metadata
44 .get(&format!("{arch}.block_count"))
45 .and_then(MetaValue::as_u32)?;
46 let nextn = file
47 .metadata
48 .get(&format!("{arch}.nextn_predict_layers"))
49 .and_then(MetaValue::as_u32)?;
50 if nextn == 0 {
51 return None;
52 }
53 Some(block_count.saturating_sub(nextn))
54}
55
56use crate::gguf_resolve::resolve_gguf_tensor_name;
57use crate::gguf_support::gguf_architecture_str;
58use crate::weight_map::PackedWeightTensor;
59use crate::weight_map::WeightMap;
60use rlx_ir::quant::QuantScheme;
61
/// Translates a HuggingFace-style tensor name into its GGUF counterpart.
///
/// Handles the three top-level tensors (`model.embed_tokens.weight`,
/// `model.norm.weight`, `lm_head.weight`) and the per-layer names of the form
/// `model.layers.<idx>.<suffix>` for the suffixes listed in `PER_LAYER`.
/// Returns `None` for anything else, including a layer index that does not
/// parse as `usize`.
pub fn hf_to_gguf_name(hf: &str) -> Option<String> {
    /// Tensors that live outside any transformer block: (HF name, GGUF name).
    const TOP_LEVEL: [(&str, &str); 3] = [
        ("model.embed_tokens.weight", "token_embd.weight"),
        ("model.norm.weight", "output_norm.weight"),
        ("lm_head.weight", "output.weight"),
    ];
    /// Per-layer suffixes: (HF suffix, GGUF suffix).
    const PER_LAYER: [(&str, &str); 14] = [
        ("input_layernorm.weight", "attn_norm.weight"),
        ("post_attention_layernorm.weight", "ffn_norm.weight"),
        ("self_attn.q_proj.weight", "attn_q.weight"),
        ("self_attn.k_proj.weight", "attn_k.weight"),
        ("self_attn.v_proj.weight", "attn_v.weight"),
        ("self_attn.o_proj.weight", "attn_output.weight"),
        ("self_attn.q_proj.bias", "attn_q.bias"),
        ("self_attn.k_proj.bias", "attn_k.bias"),
        ("self_attn.v_proj.bias", "attn_v.bias"),
        ("self_attn.q_norm.weight", "attn_q_norm.weight"),
        ("self_attn.k_norm.weight", "attn_k_norm.weight"),
        ("mlp.gate_proj.weight", "ffn_gate.weight"),
        ("mlp.up_proj.weight", "ffn_up.weight"),
        ("mlp.down_proj.weight", "ffn_down.weight"),
    ];

    if let Some(&(_, gguf)) = TOP_LEVEL.iter().find(|(name, _)| *name == hf) {
        return Some(gguf.to_owned());
    }

    // Everything after `model.layers.` is `<idx>.<suffix>`; split at the first dot.
    let (layer, suffix) = hf.strip_prefix("model.layers.")?.split_once('.')?;
    let layer: usize = layer.parse().ok()?;
    let gguf_suffix = PER_LAYER
        .iter()
        .find_map(|&(hf_suffix, gguf_suffix)| (hf_suffix == suffix).then_some(gguf_suffix))?;

    Some(format!("blk.{layer}.{gguf_suffix}"))
}
109
/// Translates a GGUF tensor name into its HuggingFace-style counterpart.
///
/// Handles `token_embd.weight`, `output_norm.weight`, `output.weight` and the
/// per-layer names `blk.<idx>.<suffix>` for the suffixes recognised below.
/// Returns `None` for anything else, including a block index that does not
/// parse as `usize`.
pub fn gguf_to_hf_name(gguf: &str) -> Option<String> {
    /// HF name for a tensor that lives outside any block.
    fn top_level(name: &str) -> Option<&'static str> {
        Some(match name {
            "token_embd.weight" => "model.embed_tokens.weight",
            "output_norm.weight" => "model.norm.weight",
            "output.weight" => "lm_head.weight",
            _ => return None,
        })
    }

    /// HF suffix for a per-block GGUF suffix.
    fn layer_suffix(suffix: &str) -> Option<&'static str> {
        Some(match suffix {
            "attn_norm.weight" => "input_layernorm.weight",
            "ffn_norm.weight" => "post_attention_layernorm.weight",
            "attn_q.weight" => "self_attn.q_proj.weight",
            "attn_k.weight" => "self_attn.k_proj.weight",
            "attn_v.weight" => "self_attn.v_proj.weight",
            "attn_output.weight" => "self_attn.o_proj.weight",
            "attn_q.bias" => "self_attn.q_proj.bias",
            "attn_k.bias" => "self_attn.k_proj.bias",
            "attn_v.bias" => "self_attn.v_proj.bias",
            "attn_q_norm.weight" => "self_attn.q_norm.weight",
            "attn_k_norm.weight" => "self_attn.k_norm.weight",
            "ffn_gate.weight" => "mlp.gate_proj.weight",
            "ffn_up.weight" => "mlp.up_proj.weight",
            "ffn_down.weight" => "mlp.down_proj.weight",
            _ => return None,
        })
    }

    if let Some(hf) = top_level(gguf) {
        return Some(hf.to_owned());
    }

    // `blk.<idx>.<suffix>`: the index ends at the first dot after the prefix.
    let (block, suffix) = gguf.strip_prefix("blk.")?.split_once('.')?;
    let block: usize = block.parse().ok()?;
    let hf_suffix = layer_suffix(suffix)?;

    Some(format!("model.layers.{block}.{hf_suffix}"))
}
146
147pub fn gguf_to_hf_name_for_arch(gguf: &str, arch: &str) -> Option<String> {
155 if matches!(
156 arch,
157 "gemma2" | "gemma3" | "gemma3n" | "gemma4" | "gemma4moe"
158 ) {
159 match gguf {
160 "token_embd.weight" => return Some("model.embed_tokens.weight".into()),
161 "output_norm.weight" => return Some("model.norm.weight".into()),
162 "output.weight" => return Some("lm_head.weight".into()),
163 _ => {}
164 }
165 let rest = gguf.strip_prefix("blk.")?;
166 let dot = rest.find('.')?;
167 let (idx_str, tail_with_dot) = rest.split_at(dot);
168 let tail = &tail_with_dot[1..];
169 let idx: usize = idx_str.parse().ok()?;
170 let hf_tail = match tail {
171 "attn_norm.weight" => "input_layernorm.weight",
172 "post_attention_norm.weight" => "post_attention_layernorm.weight",
173 "ffn_norm.weight" => "pre_feedforward_layernorm.weight",
174 "post_ffw_norm.weight" => "post_feedforward_layernorm.weight",
175 "attn_q.weight" => "self_attn.q_proj.weight",
176 "attn_k.weight" => "self_attn.k_proj.weight",
177 "attn_v.weight" => "self_attn.v_proj.weight",
178 "attn_output.weight" => "self_attn.o_proj.weight",
179 "ffn_gate.weight" => "mlp.gate_proj.weight",
180 "ffn_up.weight" => "mlp.up_proj.weight",
181 "ffn_down.weight" => "mlp.down_proj.weight",
182 _ => return None,
183 };
184 return Some(format!("model.layers.{idx}.{hf_tail}"));
185 }
186 gguf_to_hf_name(gguf)
187}
188
/// Reports whether `name` is one of the norm weights recognised for Gemma models.
///
/// Accepts both naming schemes: the final norm (`output_norm.weight` /
/// `model.norm.weight`) and the four per-layer norms under `blk.<idx>.` or
/// `model.layers.<idx>.`. The `<idx>` segment is not validated; only the text
/// after the first dot following the prefix is compared.
fn is_gemma_norm_weight(name: &str) -> bool {
    const GGUF_LAYER_NORMS: [&str; 4] = [
        "attn_norm.weight",
        "post_attention_norm.weight",
        "ffn_norm.weight",
        "post_ffw_norm.weight",
    ];
    const HF_LAYER_NORMS: [&str; 4] = [
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
    ];

    if matches!(name, "output_norm.weight" | "model.norm.weight") {
        return true;
    }

    // Strips `<prefix><idx>.` and yields the remaining suffix, if the name has that shape.
    let suffix_after = |prefix: &str| {
        name.strip_prefix(prefix)
            .and_then(|rest| rest.split_once('.'))
            .map(|(_, suffix)| suffix)
    };

    if let Some(suffix) = suffix_after("blk.") {
        return GGUF_LAYER_NORMS.contains(&suffix);
    }
    match suffix_after("model.layers.") {
        Some(suffix) => HF_LAYER_NORMS.contains(&suffix),
        None => false,
    }
}
224
/// Name-based MTP check: true when `name` starts with `mtp`, or contains
/// `mtp_` or `.mtp` anywhere.
pub fn is_mtp_weight(name: &str) -> bool {
    if name.starts_with("mtp") {
        return true;
    }
    ["mtp_", ".mtp"]
        .iter()
        .any(|marker| name.contains(*marker))
}
240
241pub fn ggml_type_to_quant_scheme(dtype: rlx_gguf::GgmlType) -> Option<QuantScheme> {
243 use rlx_gguf::GgmlType;
244 match dtype {
245 GgmlType::Q2K => Some(QuantScheme::GgufQ2K),
246 GgmlType::Q3K => Some(QuantScheme::GgufQ3K),
247 GgmlType::Q4K => Some(QuantScheme::GgufQ4K),
248 GgmlType::Q5K => Some(QuantScheme::GgufQ5K),
249 GgmlType::Q6K => Some(QuantScheme::GgufQ6K),
250 GgmlType::Q8K => Some(QuantScheme::GgufQ8K),
251 GgmlType::Q4_0 => Some(QuantScheme::GgufQ4_0),
252 GgmlType::Q8_0 => Some(QuantScheme::GgufQ8_0),
253 _ => None,
254 }
255}
256
257pub fn dequant_matmul_supported(scheme: QuantScheme) -> bool {
265 match scheme {
266 QuantScheme::GgufQ6K => q6k_dequant_matmul_supported(),
267 _ => true,
268 }
269}
270
271fn q6k_dequant_matmul_supported() -> bool {
272 static OK: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
273 *OK.get_or_init(probe_q6k_block_dequant)
274}
275
276fn probe_q6k_block_dequant() -> bool {
278 use rlx_gguf::{QK_K, dequant_q6_k, dequant_q6_k_block};
279 const BLK: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
280 let mut block = [0u8; BLK];
281 let sc_off = QK_K / 2 + QK_K / 4;
282 block[sc_off] = 0xFF;
283 block[0] = 0x21;
284 block[QK_K / 2] = 0x08;
285 block[BLK - 2..].copy_from_slice(&half::f16::ONE.to_le_bytes());
286
287 let mut out_block = [0f32; QK_K];
288 dequant_q6_k_block(&block, &mut out_block);
289 let full = match dequant_q6_k(&block, QK_K) {
290 Ok(v) => v,
291 Err(_) => return false,
292 };
293 (out_block[0] - full[0]).abs() < 1e-4
294}
295
296pub trait WeightLoader: Send {
302 fn format_id(&self) -> &'static str {
304 "unknown"
305 }
306 fn len(&self) -> usize;
308 fn is_empty(&self) -> bool {
309 self.len() == 0
310 }
311 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
314 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
319 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
321 let _ = key;
322 Ok(None)
323 }
324 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
326 let _ = key;
327 None
328 }
329 fn remaining_keys(&self) -> Vec<String>;
332 fn arch_hint(&self) -> Option<&str> {
338 None
339 }
340}
341
342impl WeightLoader for WeightMap {
343 fn format_id(&self) -> &'static str {
344 "safetensors"
345 }
346 fn len(&self) -> usize {
347 Self::len(self)
348 }
349 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
350 Self::take(self, key)
351 }
352 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
353 Self::take_transposed(self, key)
354 }
355 fn remaining_keys(&self) -> Vec<String> {
356 self.keys().map(|s| s.to_string()).collect()
357 }
358}
359
360pub struct HfTranslatingLoader<L: WeightLoader> {
373 inner: L,
374}
375
376impl<L: WeightLoader> HfTranslatingLoader<L> {
377 pub fn new(inner: L) -> Self {
378 Self { inner }
379 }
380 pub fn into_inner(self) -> L {
381 self.inner
382 }
383 pub fn inner(&self) -> &L {
384 &self.inner
385 }
386 pub fn inner_mut(&mut self) -> &mut L {
387 &mut self.inner
388 }
389}
390
391impl<L: WeightLoader> WeightLoader for HfTranslatingLoader<L> {
392 fn format_id(&self) -> &'static str {
393 self.inner.format_id()
394 }
395 fn len(&self) -> usize {
396 self.inner.len()
397 }
398 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
399 match self.inner.take(key) {
400 Ok(v) => Ok(v),
401 Err(_) => {
402 if let Some(hf) = gguf_to_hf_name(key) {
403 return self.inner.take(&hf);
404 }
405 self.inner.take(key)
406 }
407 }
408 }
409 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
410 match self.inner.take_transposed(key) {
411 Ok(v) => Ok(v),
412 Err(_) => {
413 if let Some(hf) = gguf_to_hf_name(key) {
414 return self.inner.take_transposed(&hf);
415 }
416 self.inner.take_transposed(key)
417 }
418 }
419 }
420 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
421 self.inner.take_packed(key)
422 }
423 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
424 self.inner.tensor_bytes_borrowed(key)
425 }
426 fn remaining_keys(&self) -> Vec<String> {
427 self.inner.remaining_keys()
428 }
429}
430
431pub fn load_from_path(path: &str) -> Result<Box<dyn WeightLoader>> {
433 crate::weight_registry::open_weight_loader(Path::new(path))
434}
435
436pub struct GgufLoader {
444 file: rlx_gguf::GgufFile,
445 arch: String,
446 taken: HashSet<String>,
447 include_mtp: bool,
454 mtp_layer_threshold: Option<u32>,
459}
460
461impl GgufLoader {
462 pub fn from_file(path: &str) -> Result<Self> {
463 let file = crate::gguf_support::load_gguf_file(std::path::Path::new(path))?;
464 let arch = gguf_architecture_str(&file)
465 .unwrap_or("unknown")
466 .to_string();
467 let mtp_layer_threshold = compute_mtp_layer_threshold(&file);
468 Ok(Self {
469 file,
470 arch,
471 taken: HashSet::new(),
472 include_mtp: false,
473 mtp_layer_threshold,
474 })
475 }
476
477 pub fn architecture(&self) -> &str {
478 &self.arch
479 }
480
481 pub fn mtp_layer_threshold(&self) -> Option<u32> {
488 self.mtp_layer_threshold
489 }
490
491 pub fn file(&self) -> &rlx_gguf::GgufFile {
495 &self.file
496 }
497
498 pub fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
505 let real = self.resolve(key).ok()?;
506 let t = self.file.get(&real)?;
507 self.file.tensor_bytes(t).ok()
508 }
509
510 pub fn take_packed_metadata(
518 &mut self,
519 key: &str,
520 ) -> Result<Option<(rlx_ir::quant::QuantScheme, Vec<usize>)>> {
521 let real = self.resolve(key)?;
522 if self.taken.contains(&real) {
523 return Err(anyhow!("weight already taken: {key} (→ {real})"));
524 }
525 if !self.include_mtp && self.is_mtp_tensor(&real) {
526 return Err(anyhow!(
527 "refusing to take MTP weight `{real}` without include_mtp(true)"
528 ));
529 }
530 let t = self
531 .file
532 .get(&real)
533 .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
534 let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
535 return Ok(None);
536 };
537 if !dequant_matmul_supported(scheme) {
538 return Ok(None);
539 }
540 let mut shape = t.shape.clone();
541 shape.reverse();
542 self.taken.insert(real);
543 Ok(Some((scheme, shape)))
544 }
545
546 pub fn is_mtp_tensor(&self, name: &str) -> bool {
550 if is_mtp_weight(name) {
551 return true;
552 }
553 if let Some(thresh) = self.mtp_layer_threshold {
554 if let Some(rest) = name.strip_prefix("blk.") {
555 if let Some(dot) = rest.find('.') {
556 if let Ok(idx) = rest[..dot].parse::<u32>() {
557 if idx >= thresh {
558 return true;
559 }
560 }
561 }
562 }
563 }
564 false
565 }
566
567 pub fn include_mtp(&mut self, include: bool) -> &mut Self {
575 self.include_mtp = include;
576 self
577 }
578
579 pub fn take_packed(&mut self, key: &str) -> Result<Option<PackedWeightTensor>> {
593 let real = self.resolve(key)?;
594 if self.taken.contains(&real) {
595 return Err(anyhow!("weight already taken: {key} (→ {real})"));
596 }
597 if !self.include_mtp && self.is_mtp_tensor(&real) {
598 return Err(anyhow!(
599 "refusing to take MTP weight `{real}` without include_mtp(true)"
600 ));
601 }
602 let t = self
603 .file
604 .get(&real)
605 .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
606 let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
611 return Ok(None);
612 };
613 if !dequant_matmul_supported(scheme) {
614 return Ok(None);
615 };
616 let bytes = self
617 .file
618 .tensor_bytes(t)
619 .with_context(|| format!("read packed bytes for {real}"))?
620 .to_vec();
621 let mut shape = t.shape.clone();
622 shape.reverse();
626 self.taken.insert(real);
627 Ok(Some((bytes, scheme, shape)))
628 }
629
630 pub fn take_mtp(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
635 if !self.is_mtp_tensor(key) {
636 return Err(anyhow!("not an MTP weight under this file's scheme: {key}"));
637 }
638 if !self.file.tensors.contains_key(key) {
639 return Err(anyhow!("MTP weight not found in GGUF: {key}"));
640 }
641 if self.taken.contains(key) {
642 return Err(anyhow!("MTP weight already taken: {key}"));
643 }
644 let (data, raw_shape) = self.file.dequant_f32(key)?;
645 self.taken.insert(key.to_string());
646 let mut shape = raw_shape;
647 shape.reverse();
648 Ok((data, shape))
649 }
650}
651
652impl GgufLoader {
653 fn resolve(&self, key: &str) -> Result<String> {
656 resolve_gguf_tensor_name(&self.file, &self.arch, key)
657 .ok_or_else(|| anyhow!("weight not found in GGUF (arch={}): {key}", self.arch))
658 }
659}
660
661impl WeightLoader for GgufLoader {
662 fn format_id(&self) -> &'static str {
663 "gguf"
664 }
665 fn arch_hint(&self) -> Option<&str> {
666 Some(&self.arch)
667 }
668 fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
669 self.take_packed(key)
670 }
671 fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
672 GgufLoader::tensor_bytes_borrowed(self, key)
673 }
674 fn len(&self) -> usize {
675 self.file
676 .tensors
677 .keys()
678 .filter(|k| !self.taken.contains(*k) && (self.include_mtp || !self.is_mtp_tensor(k)))
679 .count()
680 }
681 fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
682 let real = self.resolve(key)?;
683 if self.taken.contains(&real) {
684 return Err(anyhow!("weight already taken: {key} (→ {real})"));
685 }
686 if !self.include_mtp && self.is_mtp_tensor(&real) {
687 return Err(anyhow!(
688 "refusing to take MTP weight `{real}` without include_mtp(true); \
689 use loader.take_mtp(...) for explicit MTP grabs or \
690 loader.include_mtp(true) to include them in drains"
691 ));
692 }
693 let (mut data, raw_shape) = self.file.dequant_f32(&real)?;
694 self.taken.insert(real.clone());
695 if matches!(
704 self.arch.as_str(),
705 "gemma" | "gemma2" | "gemma3" | "gemma3n" | "gemma4"
706 ) && is_gemma_norm_weight(&real)
707 {
708 for v in data.iter_mut() {
709 *v -= 1.0;
710 }
711 }
712 let mut shape = raw_shape;
718 shape.reverse();
719 Ok((data, shape))
720 }
721 fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
731 let (data, shape) = self.take(key)?;
734 if shape.len() != 2 {
735 return Err(anyhow!("transpose requires 2D, got {shape:?}"));
736 }
737 let (rows, cols) = (shape[0], shape[1]);
738 let mut t = vec![0f32; data.len()];
739 for i in 0..rows {
740 for j in 0..cols {
741 t[j * rows + i] = data[i * cols + j];
742 }
743 }
744 Ok((t, vec![cols, rows]))
745 }
746 fn remaining_keys(&self) -> Vec<String> {
747 self.file
752 .tensors
753 .keys()
754 .filter(|k| {
755 !self.taken.contains(k.as_str()) && (self.include_mtp || !self.is_mtp_tensor(k))
756 })
757 .cloned()
758 .collect()
759 }
760}
761
762impl GgufLoader {
763 pub fn mtp_keys(&self) -> Vec<String> {
770 self.file
771 .tensors
772 .keys()
773 .filter(|k| self.is_mtp_tensor(k))
774 .cloned()
775 .collect()
776 }
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Appends a little-endian `u64` length followed by the bytes of `s`.
    fn push_len_prefixed(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }

    /// Starts a GGUF byte buffer: magic, version 3, tensor count, KV count.
    fn gguf_header(tensor_count: u64, kv_count: u64) -> Vec<u8> {
        let mut buf = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&tensor_count.to_le_bytes());
        buf.extend_from_slice(&kv_count.to_le_bytes());
        buf
    }

    /// Appends a string-valued metadata entry (value type tag 8).
    fn push_kv_string(buf: &mut Vec<u8>, key: &str, value: &str) {
        push_len_prefixed(buf, key);
        buf.extend_from_slice(&8u32.to_le_bytes());
        push_len_prefixed(buf, value);
    }

    /// Appends a `u32`-valued metadata entry (value type tag 4).
    fn push_kv_u32(buf: &mut Vec<u8>, key: &str, value: u32) {
        push_len_prefixed(buf, key);
        buf.extend_from_slice(&4u32.to_le_bytes());
        buf.extend_from_slice(&value.to_le_bytes());
    }

    /// Appends one tensor-info record with dtype tag 0 (the tests store raw
    /// `f32` data) and the given byte offset into the data section.
    fn push_tensor_info(buf: &mut Vec<u8>, name: &str, shape: &[usize], byte_offset: u64) {
        push_len_prefixed(buf, name);
        buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
        for &dim in shape {
            buf.extend_from_slice(&(dim as u64).to_le_bytes());
        }
        buf.extend_from_slice(&0u32.to_le_bytes());
        buf.extend_from_slice(&byte_offset.to_le_bytes());
    }

    /// Zero-pads `buf` up to the next multiple of the default GGUF alignment.
    fn pad_to_alignment(buf: &mut Vec<u8>) {
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
    }

    /// Appends each value as four little-endian bytes.
    fn push_f32s(buf: &mut Vec<u8>, values: &[f32]) {
        for value in values {
            buf.extend_from_slice(&value.to_le_bytes());
        }
    }

    #[test]
    fn unknown_extension_errors() {
        let Err(e) = load_from_path("/tmp/no-such-thing.bin") else {
            panic!("expected error");
        };
        assert!(e.to_string().contains("unsupported"));
    }

    #[test]
    fn hf_to_gguf_top_level() {
        let pairs = [
            ("model.embed_tokens.weight", "token_embd.weight"),
            ("model.norm.weight", "output_norm.weight"),
            ("lm_head.weight", "output.weight"),
        ];
        for (hf, gguf) in pairs {
            assert_eq!(hf_to_gguf_name(hf).as_deref(), Some(gguf));
        }
    }

    #[test]
    fn hf_to_gguf_per_layer() {
        let cases = [
            (
                "model.layers.0.self_attn.q_proj.weight",
                "blk.0.attn_q.weight",
            ),
            (
                "model.layers.7.self_attn.o_proj.weight",
                "blk.7.attn_output.weight",
            ),
            (
                "model.layers.3.mlp.gate_proj.weight",
                "blk.3.ffn_gate.weight",
            ),
            (
                "model.layers.12.mlp.down_proj.weight",
                "blk.12.ffn_down.weight",
            ),
            (
                "model.layers.4.input_layernorm.weight",
                "blk.4.attn_norm.weight",
            ),
            (
                "model.layers.4.post_attention_layernorm.weight",
                "blk.4.ffn_norm.weight",
            ),
            (
                "model.layers.0.self_attn.q_norm.weight",
                "blk.0.attn_q_norm.weight",
            ),
        ];
        for &(hf, gguf) in cases.iter() {
            let translated = hf_to_gguf_name(hf);
            assert_eq!(translated.as_deref(), Some(gguf), "mismatch for {hf}");
        }
    }

    #[test]
    fn hf_to_gguf_unknown_returns_none() {
        assert_eq!(hf_to_gguf_name("model.layers.0.some_new_thing.weight"), None);
        assert_eq!(
            hf_to_gguf_name("model.layers.foo.input_layernorm.weight"),
            None
        );
    }

    #[test]
    fn mtp_detection() {
        for name in [
            "mtp_blk.0.attn_q.weight",
            "output_mtp_0.weight",
            "model.layers.0.mtp_head.weight",
        ] {
            assert!(is_mtp_weight(name));
        }
        for name in ["blk.0.attn_q.weight", "output.weight"] {
            assert!(!is_mtp_weight(name));
        }
    }

    #[test]
    fn ggml_q4_0_maps_to_packed_scheme() {
        use rlx_gguf::GgmlType;
        use rlx_ir::quant::QuantScheme;

        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q4_0),
            Some(QuantScheme::GgufQ4_0)
        );
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q8_0),
            Some(QuantScheme::GgufQ8_0)
        );
    }

    #[test]
    fn q6k_dequant_matmul_follows_block_probe() {
        use rlx_ir::quant::QuantScheme;

        assert!(dequant_matmul_supported(QuantScheme::GgufQ4K));
        let probed = super::probe_q6k_block_dequant();
        assert_eq!(dequant_matmul_supported(QuantScheme::GgufQ6K), probed);
    }

    #[test]
    fn gguf_loader_threshold_based_mtp_detection() {
        // Three tensors and three metadata entries.
        let mut buf = gguf_header(3, 3);
        push_kv_string(&mut buf, "general.architecture", "qwen35");
        push_kv_u32(&mut buf, "qwen35.block_count", 25);
        push_kv_u32(&mut buf, "qwen35.nextn_predict_layers", 1);

        // Each tensor is 4x4 f32 = 64 bytes, laid out back to back.
        push_tensor_info(&mut buf, "blk.0.attn_q.weight", &[4, 4], 0);
        push_tensor_info(&mut buf, "blk.24.attn_q.weight", &[4, 4], 64);
        push_tensor_info(&mut buf, "token_embd.weight", &[4, 4], 128);
        pad_to_alignment(&mut buf);
        push_f32s(&mut buf, &[0.5f32; 3 * 16]);

        let path = std::env::temp_dir().join("rlx_mtp_threshold_test.gguf");
        std::fs::write(&path, &buf).unwrap();
        let loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();

        // 25 blocks with 1 next-n layer: block 24 is the only MTP block.
        assert_eq!(loader.mtp_layer_threshold(), Some(24));
        assert!(!loader.is_mtp_tensor("blk.0.attn_q.weight"));
        assert!(loader.is_mtp_tensor("blk.24.attn_q.weight"));
        assert!(!loader.is_mtp_tensor("token_embd.weight"));
        assert_eq!(
            loader.mtp_keys(),
            vec!["blk.24.attn_q.weight".to_string()]
        );

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn gguf_loader_resolves_hf_names_and_skips_mtp() {
        let t1: Vec<f32> = (0..12).map(|x| x as f32).collect();
        let t2: Vec<f32> = (100..116).map(|x| x as f32).collect();
        let t3: Vec<f32> = vec![0.5f32; 8];

        // (name, stored shape, values); data is laid out in this order.
        let entries: [(&str, [usize; 2], &[f32]); 3] = [
            ("token_embd.weight", [3, 4], t1.as_slice()),
            ("blk.0.attn_q.weight", [4, 4], t2.as_slice()),
            ("output_mtp_0.weight", [2, 4], t3.as_slice()),
        ];

        // No metadata entries; tensor infos follow the header directly.
        let mut buf = gguf_header(entries.len() as u64, 0);
        let mut elems_before = 0usize;
        for (name, shape, values) in &entries {
            // Offsets are in bytes: four per preceding f32 element.
            push_tensor_info(&mut buf, name, shape, (elems_before * 4) as u64);
            elems_before += values.len();
        }
        pad_to_alignment(&mut buf);
        for (_, _, values) in &entries {
            push_f32s(&mut buf, values);
        }

        let path = std::env::temp_dir().join("rlx_test_qwen3_mini.gguf");
        std::fs::write(&path, &buf).unwrap();

        // Default loader: the MTP tensor is hidden from the count.
        let mut loader = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        assert_eq!(loader.len(), 2);

        let (out, shape) = loader
            .take("model.embed_tokens.weight")
            .expect("hf-named token_embd should resolve");
        assert_eq!(shape, vec![4, 3]);
        assert_eq!(out, t1);

        let (out, shape) = loader
            .take("model.layers.0.self_attn.q_proj.weight")
            .expect("hf-named attn_q should resolve");
        assert_eq!(shape, vec![4, 4]);
        assert_eq!(out, t2);

        assert_eq!(loader.remaining_keys(), Vec::<String>::new());
        assert_eq!(loader.mtp_keys(), vec!["output_mtp_0.weight".to_string()]);

        // With include_mtp(true) the MTP tensor is listed and can be taken explicitly.
        let mut loader2 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        loader2.include_mtp(true);
        let visible: std::collections::HashSet<String> =
            loader2.remaining_keys().into_iter().collect();
        assert!(visible.contains("token_embd.weight"));
        assert!(visible.contains("blk.0.attn_q.weight"));
        assert!(
            visible.contains("output_mtp_0.weight"),
            "MTP weight should be visible with include_mtp(true)"
        );
        let (mtp_data, mtp_shape) = loader2.take_mtp("output_mtp_0.weight").unwrap();
        assert_eq!(mtp_shape, vec![4, 2]);
        assert_eq!(mtp_data, t3);

        // Without include_mtp, a plain take of the MTP tensor is refused.
        let mut loader3 = GgufLoader::from_file(path.to_str().unwrap()).unwrap();
        let err = loader3.take("output_mtp_0.weight").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("include_mtp(true)"),
            "expected MTP guard error, got: {msg}"
        );

        std::fs::remove_file(&path).ok();
    }

    #[test]
    fn missing_gguf_file_errors() {
        assert!(load_from_path("/tmp/no-such-thing-rlx-gguf-test.gguf").is_err());
    }
}