Skip to main content

rlx_models_core/
weight_loader.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pluggable weight loader trait (plan #56).
17//!
18//! Borrowed from MAX's `max/python/max/graph/weights/` layout:
19//! `load_safetensors.py`, `load_gguf.py`, plus a `load.py` dispatcher
20//! that detects format from the file extension.
21//!
//! Two loaders live in this module today: safetensors (via the existing
//! [`WeightMap::from_file`]) and GGUF ([`GgufLoader`]). Adding a new
//! format = one new struct that implements [`WeightLoader`] plus a
//! registration with `crate::weight_registry::register_weight_format`;
//! [`load_from_path`] dispatches on the file extension through that
//! registry. The model graph builders take any `&mut dyn WeightLoader`
//! so they don't care which format the weights came from.
28
29use anyhow::{Context, Result, anyhow};
30use rlx_gguf::MetaValue;
31use std::collections::HashSet;
32use std::path::Path;
33
34/// Walk the GGUF metadata for `{arch}.block_count -
35/// {arch}.nextn_predict_layers`. Returns `Some(main_layer_count)`
36/// when the file has explicit MTP metadata, else `None`.
37fn compute_mtp_layer_threshold(file: &rlx_gguf::GgufFile) -> Option<u32> {
38    let arch = file
39        .metadata
40        .get("general.architecture")
41        .and_then(MetaValue::as_str)?;
42    let block_count = file
43        .metadata
44        .get(&format!("{arch}.block_count"))
45        .and_then(MetaValue::as_u32)?;
46    let nextn = file
47        .metadata
48        .get(&format!("{arch}.nextn_predict_layers"))
49        .and_then(MetaValue::as_u32)?;
50    if nextn == 0 {
51        return None;
52    }
53    Some(block_count.saturating_sub(nextn))
54}
55
56use crate::gguf_resolve::resolve_gguf_tensor_name;
57use crate::gguf_support::gguf_architecture_str;
58use crate::weight_map::PackedWeightTensor;
59use crate::weight_map::WeightMap;
60use rlx_ir::quant::QuantScheme;
61
/// Map a Hugging Face / safetensors tensor name onto the GGUF
/// (llama.cpp) naming convention.
///
/// Handles the three top-level tensors (`model.embed_tokens`,
/// `model.norm`, `lm_head`) and the LLaMA-family per-layer tensors,
/// rewriting `model.layers.{i}.<hf-tail>` as `blk.{i}.<gguf-tail>`.
/// Returns `None` for anything outside that table. The caller should then
/// treat the name as already GGUF-style, or as an architecture-specific
/// tensor this mapper doesn't cover (e.g. MTP heads — see
/// [`is_mtp_weight`]).
///
/// The table mirrors llama.cpp's `gguf-py/gguf/tensor_mapping.py` for
/// the LLaMA-style architectures (Qwen3 shares it). Extend this
/// function rather than forking it when adding architectures.
pub fn hf_to_gguf_name(hf: &str) -> Option<String> {
    let top_level = match hf {
        "model.embed_tokens.weight" => Some("token_embd.weight"),
        "model.norm.weight" => Some("output_norm.weight"),
        "lm_head.weight" => Some("output.weight"),
        _ => None,
    };
    if let Some(gguf) = top_level {
        return Some(gguf.to_owned());
    }

    // Per-layer tensors: split `model.layers.{i}.<tail>` into index + tail.
    let (layer_str, tail) = hf.strip_prefix("model.layers.")?.split_once('.')?;
    let layer: usize = layer_str.parse().ok()?;
    let gguf_tail = match tail {
        "input_layernorm.weight" => "attn_norm.weight",
        // Llama-style: `post_attention_layernorm` is the pre-FFN norm.
        // Gemma 2 has four distinct norms and overrides this entry with a
        // dedicated `Gemma2GgufResolver`; keep Gemma 2 names out of here
        // or they would collide with Llama callers.
        "post_attention_layernorm.weight" => "ffn_norm.weight",
        "self_attn.q_proj.weight" => "attn_q.weight",
        "self_attn.k_proj.weight" => "attn_k.weight",
        "self_attn.v_proj.weight" => "attn_v.weight",
        "self_attn.o_proj.weight" => "attn_output.weight",
        "self_attn.q_proj.bias" => "attn_q.bias",
        "self_attn.k_proj.bias" => "attn_k.bias",
        "self_attn.v_proj.bias" => "attn_v.bias",
        "self_attn.q_norm.weight" => "attn_q_norm.weight",
        "self_attn.k_norm.weight" => "attn_k_norm.weight",
        "mlp.gate_proj.weight" => "ffn_gate.weight",
        "mlp.up_proj.weight" => "ffn_up.weight",
        "mlp.down_proj.weight" => "ffn_down.weight",
        _ => return None,
    };
    Some(format!("blk.{layer}.{gguf_tail}"))
}
109
/// Reverse of [`hf_to_gguf_name`]: map a GGUF / llama.cpp tensor name
/// back to the safetensors / Hugging Face convention.
///
/// Drain-style loaders (e.g. `Qwen3Generator::from_loader`) cache
/// weights by name and use this so the cache key matches what the
/// builder will request. Returns `None` for names outside the table.
pub fn gguf_to_hf_name(gguf: &str) -> Option<String> {
    // Non-layer tensors: (GGUF name, HF name).
    const TOP_LEVEL: [(&str, &str); 3] = [
        ("token_embd.weight", "model.embed_tokens.weight"),
        ("output_norm.weight", "model.norm.weight"),
        ("output.weight", "lm_head.weight"),
    ];
    // Per-layer suffixes: (GGUF tail, HF tail).
    const LAYER_TAILS: [(&str, &str); 14] = [
        ("attn_norm.weight", "input_layernorm.weight"),
        ("ffn_norm.weight", "post_attention_layernorm.weight"),
        ("attn_q.weight", "self_attn.q_proj.weight"),
        ("attn_k.weight", "self_attn.k_proj.weight"),
        ("attn_v.weight", "self_attn.v_proj.weight"),
        ("attn_output.weight", "self_attn.o_proj.weight"),
        ("attn_q.bias", "self_attn.q_proj.bias"),
        ("attn_k.bias", "self_attn.k_proj.bias"),
        ("attn_v.bias", "self_attn.v_proj.bias"),
        ("attn_q_norm.weight", "self_attn.q_norm.weight"),
        ("attn_k_norm.weight", "self_attn.k_norm.weight"),
        ("ffn_gate.weight", "mlp.gate_proj.weight"),
        ("ffn_up.weight", "mlp.up_proj.weight"),
        ("ffn_down.weight", "mlp.down_proj.weight"),
    ];

    if let Some(&(_, hf)) = TOP_LEVEL.iter().find(|(g, _)| *g == gguf) {
        return Some(hf.to_string());
    }
    let (block_str, tail) = gguf.strip_prefix("blk.")?.split_once('.')?;
    let block: usize = block_str.parse().ok()?;
    let &(_, hf_tail) = LAYER_TAILS.iter().find(|(g, _)| *g == tail)?;
    Some(format!("model.layers.{block}.{hf_tail}"))
}
146
147/// Arch-aware variant of [`gguf_to_hf_name`]. Falls back to the
148/// generic mapping when `arch` doesn't carry overrides; for arches
149/// whose 1↔1 alias disagrees with the Llama convention (Gemma 2/3/4:
150/// `ffn_norm` is the pre-FFN norm, not the post-attention norm), it
151/// returns the arch-correct HF name instead. Used by drain-style
152/// `from_loader` paths to compute cache keys that match what the
153/// builder will ask for.
154pub fn gguf_to_hf_name_for_arch(gguf: &str, arch: &str) -> Option<String> {
155    if matches!(
156        arch,
157        "gemma2" | "gemma3" | "gemma3n" | "gemma4" | "gemma4moe"
158    ) {
159        match gguf {
160            "token_embd.weight" => return Some("model.embed_tokens.weight".into()),
161            "output_norm.weight" => return Some("model.norm.weight".into()),
162            "output.weight" => return Some("lm_head.weight".into()),
163            _ => {}
164        }
165        let rest = gguf.strip_prefix("blk.")?;
166        let dot = rest.find('.')?;
167        let (idx_str, tail_with_dot) = rest.split_at(dot);
168        let tail = &tail_with_dot[1..];
169        let idx: usize = idx_str.parse().ok()?;
170        let hf_tail = match tail {
171            "attn_norm.weight" => "input_layernorm.weight",
172            "post_attention_norm.weight" => "post_attention_layernorm.weight",
173            "ffn_norm.weight" => "pre_feedforward_layernorm.weight",
174            "post_ffw_norm.weight" => "post_feedforward_layernorm.weight",
175            "attn_q.weight" => "self_attn.q_proj.weight",
176            "attn_k.weight" => "self_attn.k_proj.weight",
177            "attn_v.weight" => "self_attn.v_proj.weight",
178            "attn_output.weight" => "self_attn.o_proj.weight",
179            "ffn_gate.weight" => "mlp.gate_proj.weight",
180            "ffn_up.weight" => "mlp.up_proj.weight",
181            "ffn_down.weight" => "mlp.down_proj.weight",
182            _ => return None,
183        };
184        return Some(format!("model.layers.{idx}.{hf_tail}"));
185    }
186    gguf_to_hf_name(gguf)
187}
188
/// Match tensor names that hold a Gemma RMSNorm gain, i.e. tensors whose
/// GGUF export has the `(1 + gamma)` offset baked in.
///
/// Covers the following, in both GGUF-native and HF spellings (drain
/// order is undefined, so either convention may reach the call site):
/// - the final norm (`output_norm` / `model.norm`);
/// - the four per-layer sandwich norms (attn / post_attention / ffn /
///   post_ffw);
/// - the per-head QK norms (`attn_q_norm` / `attn_k_norm`, HF
///   `self_attn.q_norm` / `self_attn.k_norm`). llama.cpp's Gemma 3
///   converter shifts every tensor whose name ends in `norm.weight`, so
///   these carry the same offset. Without them here, `take()` returns the
///   QK norms in GGUF convention while every other norm is in safetensors
///   convention.
fn is_gemma_norm_weight(name: &str) -> bool {
    if name == "output_norm.weight" || name == "model.norm.weight" {
        return true;
    }
    if let Some(rest) = name
        .strip_prefix("blk.")
        .and_then(|r| r.split_once('.').map(|x| x.1))
    {
        return matches!(
            rest,
            "attn_norm.weight"
                | "post_attention_norm.weight"
                | "ffn_norm.weight"
                | "post_ffw_norm.weight"
                | "attn_q_norm.weight"
                | "attn_k_norm.weight"
        );
    }
    if let Some(rest) = name
        .strip_prefix("model.layers.")
        .and_then(|r| r.split_once('.').map(|x| x.1))
    {
        return matches!(
            rest,
            "input_layernorm.weight"
                | "post_attention_layernorm.weight"
                | "pre_feedforward_layernorm.weight"
                | "post_feedforward_layernorm.weight"
                | "self_attn.q_norm.weight"
                | "self_attn.k_norm.weight"
        );
    }
    false
}
224
/// Name-only heuristic for Multi-Token Prediction head tensors.
///
/// Returns `true` when the name starts with `mtp`, or contains `mtp_`
/// or `.mtp` (covers `mtp_*`, `*.mtp`, `output_mtp_*` and similar
/// explicit spellings).
///
/// **This alone is not sufficient** for the common unsloth /
/// DeepSeek-V3 layout, where MTP heads are stored as extra `blk.N`
/// layers past the main `block_count`. For that layout use
/// [`GgufLoader::mtp_layer_threshold`] or `GgufLoader::is_mtp_tensor`,
/// which read `{arch}.nextn_predict_layers` from the GGUF metadata and
/// classify trailing `blk.*` indices accordingly.
pub fn is_mtp_weight(name: &str) -> bool {
    const MARKERS: [&str; 2] = ["mtp_", ".mtp"];
    name.starts_with("mtp") || MARKERS.iter().any(|&marker| name.contains(marker))
}
240
241/// Map GGML storage type to RLX packed matmul scheme (K-quants only).
242pub fn ggml_type_to_quant_scheme(dtype: rlx_gguf::GgmlType) -> Option<QuantScheme> {
243    use rlx_gguf::GgmlType;
244    match dtype {
245        GgmlType::Q2K => Some(QuantScheme::GgufQ2K),
246        GgmlType::Q3K => Some(QuantScheme::GgufQ3K),
247        GgmlType::Q4K => Some(QuantScheme::GgufQ4K),
248        GgmlType::Q5K => Some(QuantScheme::GgufQ5K),
249        GgmlType::Q6K => Some(QuantScheme::GgufQ6K),
250        GgmlType::Q8K => Some(QuantScheme::GgufQ8K),
251        GgmlType::Q4_0 => Some(QuantScheme::GgufQ4_0),
252        GgmlType::Q8_0 => Some(QuantScheme::GgufQ8_0),
253        _ => None,
254    }
255}
256
257/// Whether [`QuantScheme`] may stay packed in `Op::DequantMatMul` graphs.
258///
259/// `GgufQ6K` requires `rlx-gguf` ≥ 0.2.1 with signed scale bytes in
260/// [`rlx_gguf::dequant_q6_k_block`] (crates.io 0.2.1 used `as f32` and
261/// skewed v_proj / down_proj). When the block path disagrees with
262/// [`rlx_gguf::dequant_q6_k`], callers fall back to F32
263/// [`WeightLoader::take_transposed`].
264pub fn dequant_matmul_supported(scheme: QuantScheme) -> bool {
265    match scheme {
266        QuantScheme::GgufQ6K => q6k_dequant_matmul_supported(),
267        _ => true,
268    }
269}
270
271fn q6k_dequant_matmul_supported() -> bool {
272    static OK: std::sync::OnceLock<bool> = std::sync::OnceLock::new();
273    *OK.get_or_init(probe_q6k_block_dequant)
274}
275
276/// Synthetic Q6_K block: d=1, scale byte 0xFF (i8 −1), q=1 at slot 0.
277fn probe_q6k_block_dequant() -> bool {
278    use rlx_gguf::{QK_K, dequant_q6_k, dequant_q6_k_block};
279    const BLK: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
280    let mut block = [0u8; BLK];
281    let sc_off = QK_K / 2 + QK_K / 4;
282    block[sc_off] = 0xFF;
283    block[0] = 0x21;
284    block[QK_K / 2] = 0x08;
285    block[BLK - 2..].copy_from_slice(&half::f16::ONE.to_le_bytes());
286
287    let mut out_block = [0f32; QK_K];
288    dequant_q6_k_block(&block, &mut out_block);
289    let full = match dequant_q6_k(&block, QK_K) {
290        Ok(v) => v,
291        Err(_) => return false,
292    };
293    (out_block[0] - full[0]).abs() < 1e-4
294}
295
296/// Common interface every weight format must satisfy. Mirrors the
297/// existing `WeightMap` API so the safetensors impl is a one-line
298/// adapter.
299///
300/// Register additional on-disk formats with [`crate::weight_registry::register_weight_format`].
301pub trait WeightLoader: Send {
302    /// Format id (`safetensors`, `gguf`, or a custom registration).
303    fn format_id(&self) -> &'static str {
304        "unknown"
305    }
306    /// Number of distinct weights in the file.
307    fn len(&self) -> usize;
308    fn is_empty(&self) -> bool {
309        self.len() == 0
310    }
311    /// Take the named tensor as `(f32_data, shape)`. Removes from the
312    /// loader so callers can detect "weights I never used."
313    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
314    /// Same as `take` but transposed (last two dims swapped). Most
315    /// safetensors weights are stored row-major-of-PyTorch
316    /// convention, which RLX's IR consumes column-major; this helper
317    /// is the convention-bridge.
318    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
319    /// Take packed K-quant bytes when supported; default returns `None`.
320    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
321        let _ = key;
322        Ok(None)
323    }
324    /// Borrow packed bytes without marking taken (GGUF mmap path).
325    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
326        let _ = key;
327        None
328    }
329    /// Names that haven't been taken yet — useful for "did the
330    /// model use every weight?" hygiene checks.
331    fn remaining_keys(&self) -> Vec<String>;
332    /// Architecture name from the underlying file (`general.architecture`
333    /// for GGUF, `None` for safetensors). Drain-style consumers use this
334    /// to pick an arch-specific reverse name mapping when the canonical
335    /// HF name depends on the model family (e.g. Gemma 2's 4 norms per
336    /// layer don't share the Llama 2-norm reverse alias).
337    fn arch_hint(&self) -> Option<&str> {
338        None
339    }
340}
341
342impl WeightLoader for WeightMap {
343    fn format_id(&self) -> &'static str {
344        "safetensors"
345    }
346    fn len(&self) -> usize {
347        Self::len(self)
348    }
349    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
350        Self::take(self, key)
351    }
352    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
353        Self::take_transposed(self, key)
354    }
355    fn remaining_keys(&self) -> Vec<String> {
356        self.keys().map(|s| s.to_string()).collect()
357    }
358}
359
360/// Adapter that lets a HF-safetensors-backed [`WeightLoader`] satisfy
361/// requests phrased in GGUF-style names (`blk.N.attn_q.weight` etc.).
362///
363/// Builders like [`Qwen35Weights::from_loader`] address tensors using
364/// the GGUF / llama.cpp convention; the underlying safetensors file
365/// stores them under HF / PyTorch names (`model.layers.N.self_attn.q_proj.weight`).
366/// This wrapper:
367///
368/// 1. Tries the requested key verbatim (in case it's already-HF or
369///    the file was named GGUF-style).
370/// 2. Tries [`gguf_to_hf_name`] to translate the GGUF key → HF key.
371/// 3. Returns the underlying loader's error otherwise.
372pub struct HfTranslatingLoader<L: WeightLoader> {
373    inner: L,
374}
375
376impl<L: WeightLoader> HfTranslatingLoader<L> {
377    pub fn new(inner: L) -> Self {
378        Self { inner }
379    }
380    pub fn into_inner(self) -> L {
381        self.inner
382    }
383    pub fn inner(&self) -> &L {
384        &self.inner
385    }
386    pub fn inner_mut(&mut self) -> &mut L {
387        &mut self.inner
388    }
389}
390
391impl<L: WeightLoader> WeightLoader for HfTranslatingLoader<L> {
392    fn format_id(&self) -> &'static str {
393        self.inner.format_id()
394    }
395    fn len(&self) -> usize {
396        self.inner.len()
397    }
398    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
399        match self.inner.take(key) {
400            Ok(v) => Ok(v),
401            Err(_) => {
402                if let Some(hf) = gguf_to_hf_name(key) {
403                    return self.inner.take(&hf);
404                }
405                self.inner.take(key)
406            }
407        }
408    }
409    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
410        match self.inner.take_transposed(key) {
411            Ok(v) => Ok(v),
412            Err(_) => {
413                if let Some(hf) = gguf_to_hf_name(key) {
414                    return self.inner.take_transposed(&hf);
415                }
416                self.inner.take_transposed(key)
417            }
418        }
419    }
420    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
421        self.inner.take_packed(key)
422    }
423    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
424        self.inner.tensor_bytes_borrowed(key)
425    }
426    fn remaining_keys(&self) -> Vec<String> {
427        self.inner.remaining_keys()
428    }
429}
430
431/// Dispatch on the file extension via [`crate::weight_registry`].
432pub fn load_from_path(path: &str) -> Result<Box<dyn WeightLoader>> {
433    crate::weight_registry::open_weight_loader(Path::new(path))
434}
435
436// ─── GGUF adapter ─────────────────────────────────────────────────
437//
438// Wraps `rlx_gguf::GgufFile` so it satisfies `WeightLoader`. Tracks
439// taken keys in a side-set since `dequant_f32` borrows the file
440// immutably; the alternative — pre-decoding every tensor at load
441// time — defeats the point of GGUF's lazy block layout.
442
443pub struct GgufLoader {
444    file: rlx_gguf::GgufFile,
445    arch: String,
446    taken: HashSet<String>,
447    /// When true, `remaining_keys` / `len` / `take` treat MTP-head
448    /// weights as ordinary tensors instead of hiding them. The base
449    /// qwen3 builder ignores MTP tensors regardless — this flag
450    /// only changes the *visibility* in the `WeightLoader` surface
451    /// so downstream MTP-aware builders can iterate them through
452    /// the standard drain pattern.
453    include_mtp: bool,
454    /// First `blk.N` index that belongs to an MTP head, computed
455    /// from `{arch}.block_count - {arch}.nextn_predict_layers` at
456    /// construction. `None` for files without the metadata key
457    /// (= no MTP heads encoded as trailing blocks).
458    mtp_layer_threshold: Option<u32>,
459}
460
461impl GgufLoader {
462    pub fn from_file(path: &str) -> Result<Self> {
463        let file = crate::gguf_support::load_gguf_file(std::path::Path::new(path))?;
464        let arch = gguf_architecture_str(&file)
465            .unwrap_or("unknown")
466            .to_string();
467        let mtp_layer_threshold = compute_mtp_layer_threshold(&file);
468        Ok(Self {
469            file,
470            arch,
471            taken: HashSet::new(),
472            include_mtp: false,
473            mtp_layer_threshold,
474        })
475    }
476
477    pub fn architecture(&self) -> &str {
478        &self.arch
479    }
480
481    /// First `blk.N` index that the GGUF metadata reports as an MTP
482    /// head, derived from `{arch}.block_count -
483    /// {arch}.nextn_predict_layers`. `None` for files where the
484    /// `nextn_predict_layers` key is absent (= no MTP, or MTP is
485    /// encoded under a different naming scheme — fall back to
486    /// [`is_mtp_weight`] in that case).
487    pub fn mtp_layer_threshold(&self) -> Option<u32> {
488        self.mtp_layer_threshold
489    }
490
491    /// Borrow the underlying parsed `GgufFile` so callers (e.g. arch
492    /// builders that read `general.architecture`-specific keys)
493    /// don't have to re-parse 800+ tensor headers a second time.
494    pub fn file(&self) -> &rlx_gguf::GgufFile {
495        &self.file
496    }
497
498    /// Borrow the raw on-disk byte slice for a tensor without
499    /// marking it taken. Returns `None` if the key doesn't resolve
500    /// or the byte range is invalid. Used by the qwen35 packed-
501    /// upload path to stream K-quant bytes from mmap straight into
502    /// the compiled arena, skipping a per-tensor `Vec<u8>`
503    /// allocation (≈ 16 GB on Qwen3.6-27B Q4_K_M).
504    pub fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
505        let real = self.resolve(key).ok()?;
506        let t = self.file.get(&real)?;
507        self.file.tensor_bytes(t).ok()
508    }
509
510    /// Variant of [`Self::take_packed`] that returns only the
511    /// `(scheme, shape)` metadata without copying bytes. The caller
512    /// uploads bytes separately via [`Self::tensor_bytes_borrowed`]
513    /// after the graph is compiled — eliminates the per-tensor
514    /// `Vec<u8>` allocation. Marks the tensor taken on success;
515    /// returns `Ok(None)` for non-K-quant dtypes so the caller can
516    /// fall back to the dequant path.
517    pub fn take_packed_metadata(
518        &mut self,
519        key: &str,
520    ) -> Result<Option<(rlx_ir::quant::QuantScheme, Vec<usize>)>> {
521        let real = self.resolve(key)?;
522        if self.taken.contains(&real) {
523            return Err(anyhow!("weight already taken: {key} (→ {real})"));
524        }
525        if !self.include_mtp && self.is_mtp_tensor(&real) {
526            return Err(anyhow!(
527                "refusing to take MTP weight `{real}` without include_mtp(true)"
528            ));
529        }
530        let t = self
531            .file
532            .get(&real)
533            .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
534        let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
535            return Ok(None);
536        };
537        if !dequant_matmul_supported(scheme) {
538            return Ok(None);
539        }
540        let mut shape = t.shape.clone();
541        shape.reverse();
542        self.taken.insert(real);
543        Ok(Some((scheme, shape)))
544    }
545
546    /// True if `name` is an MTP weight under this file's naming
547    /// scheme. Combines the substring heuristic ([`is_mtp_weight`])
548    /// with the model-aware `blk.N where N >= threshold` check.
549    pub fn is_mtp_tensor(&self, name: &str) -> bool {
550        if is_mtp_weight(name) {
551            return true;
552        }
553        if let Some(thresh) = self.mtp_layer_threshold {
554            if let Some(rest) = name.strip_prefix("blk.") {
555                if let Some(dot) = rest.find('.') {
556                    if let Ok(idx) = rest[..dot].parse::<u32>() {
557                        if idx >= thresh {
558                            return true;
559                        }
560                    }
561                }
562            }
563        }
564        false
565    }
566
567    /// Toggle MTP-weight visibility. With `include = true`, MTP
568    /// heads show up in `remaining_keys()` (and count toward `len()`)
569    /// — drain-style consumers like
570    /// `Qwen3Generator::from_loader` will then pull them into the
571    /// weights cache. Default off so non-MTP models behave exactly
572    /// as before. Call this before any `take()` / drain so the
573    /// inclusion choice is consistent across the load.
574    pub fn include_mtp(&mut self, include: bool) -> &mut Self {
575        self.include_mtp = include;
576        self
577    }
578
579    /// Take a tensor's **packed bytes** (no dequant), plus its
580    /// [`rlx_ir::quant::QuantScheme`] and safetensors-style shape.
581    /// Returns `None` when the tensor is stored uncompressed
582    /// (F32/F16/BF16) — caller should fall back to `take()` for
583    /// those.
584    ///
585    /// Used by the qwen3 builder's *packed-weights mode*: the LM
586    /// head + per-layer matmul weights stay in the arena as raw
587    /// K-quant bytes, and the graph emits
588    /// `Op::DequantMatMul { scheme }` instead of `Op::MatMul` for
589    /// them. Cuts the load-time memory footprint by ~7-9× on
590    /// Q4_K_M / Q6_K models — the unblocker for ≥14 B Qwen3 / Llama
591    /// GGUFs on commodity Macs.
592    pub fn take_packed(&mut self, key: &str) -> Result<Option<PackedWeightTensor>> {
593        let real = self.resolve(key)?;
594        if self.taken.contains(&real) {
595            return Err(anyhow!("weight already taken: {key} (→ {real})"));
596        }
597        if !self.include_mtp && self.is_mtp_tensor(&real) {
598            return Err(anyhow!(
599                "refusing to take MTP weight `{real}` without include_mtp(true)"
600            ));
601        }
602        let t = self
603            .file
604            .get(&real)
605            .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
606        // Map ggml dtype → our QuantScheme. K-quants and Q4_0/Q8_0 can stay
607        // packed on CPU; Q4_1/Q5_* + uncompressed F32/F16/BF16 fall through
608        // F32/F16/BF16 fall through to the dequant path (return
609        // None — caller switches to `take`).
610        let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
611            return Ok(None);
612        };
613        if !dequant_matmul_supported(scheme) {
614            return Ok(None);
615        };
616        let bytes = self
617            .file
618            .tensor_bytes(t)
619            .with_context(|| format!("read packed bytes for {real}"))?
620            .to_vec();
621        let mut shape = t.shape.clone();
622        // Match the safetensors-style shape ordering applied by
623        // `take` — GGUF stores innermost-first, safetensors stores
624        // outermost-first; byte layout is identical.
625        shape.reverse();
626        self.taken.insert(real);
627        Ok(Some((bytes, scheme, shape)))
628    }
629
630    /// Take a single MTP weight by name. Bypasses the `include_mtp`
631    /// filter so callers can grab specific heads without flipping
632    /// the global visibility. Returns an error if the name isn't a
633    /// recognized MTP weight (use [`take`] for non-MTP keys).
634    pub fn take_mtp(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
635        if !self.is_mtp_tensor(key) {
636            return Err(anyhow!("not an MTP weight under this file's scheme: {key}"));
637        }
638        if !self.file.tensors.contains_key(key) {
639            return Err(anyhow!("MTP weight not found in GGUF: {key}"));
640        }
641        if self.taken.contains(key) {
642            return Err(anyhow!("MTP weight already taken: {key}"));
643        }
644        let (data, raw_shape) = self.file.dequant_f32(key)?;
645        self.taken.insert(key.to_string());
646        let mut shape = raw_shape;
647        shape.reverse();
648        Ok((data, shape))
649    }
650}
651
652impl GgufLoader {
653    /// Resolve a caller-supplied key (HF or GGUF naming) to the
654    /// actual GGUF tensor name via registered architecture resolvers.
655    fn resolve(&self, key: &str) -> Result<String> {
656        resolve_gguf_tensor_name(&self.file, &self.arch, key)
657            .ok_or_else(|| anyhow!("weight not found in GGUF (arch={}): {key}", self.arch))
658    }
659}
660
661impl WeightLoader for GgufLoader {
662    fn format_id(&self) -> &'static str {
663        "gguf"
664    }
665    fn arch_hint(&self) -> Option<&str> {
666        Some(&self.arch)
667    }
668    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
669        self.take_packed(key)
670    }
671    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
672        GgufLoader::tensor_bytes_borrowed(self, key)
673    }
674    fn len(&self) -> usize {
675        self.file
676            .tensors
677            .keys()
678            .filter(|k| !self.taken.contains(*k) && (self.include_mtp || !self.is_mtp_tensor(k)))
679            .count()
680    }
681    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
682        let real = self.resolve(key)?;
683        if self.taken.contains(&real) {
684            return Err(anyhow!("weight already taken: {key} (→ {real})"));
685        }
686        if !self.include_mtp && self.is_mtp_tensor(&real) {
687            return Err(anyhow!(
688                "refusing to take MTP weight `{real}` without include_mtp(true); \
689                 use loader.take_mtp(...) for explicit MTP grabs or \
690                 loader.include_mtp(true) to include them in drains"
691            ));
692        }
693        let (mut data, raw_shape) = self.file.dequant_f32(&real)?;
694        self.taken.insert(real.clone());
695        // Gemma's GGUF conversion script bakes the `(1 + gamma)` offset
696        // into every norm weight (see llama.cpp `convert_hf_to_gguf.py`
697        // → `GemmaModel.modify_tensors`). HF/safetensors stores raw
698        // gamma. Subtract 1 here so the loader publishes the
699        // safetensors convention and `GemmaRmsNormStage`'s explicit
700        // `+1` stays correct for both sources. Without this the RMS
701        // gain is systematically inflated and logits skew (cosine
702        // ~0.7 vs llama.cpp; structural, not numerical noise).
703        if matches!(
704            self.arch.as_str(),
705            "gemma" | "gemma2" | "gemma3" | "gemma3n" | "gemma4"
706        ) && is_gemma_norm_weight(&real)
707        {
708            for v in data.iter_mut() {
709                *v -= 1.0;
710            }
711        }
712        // GGUF/ggml report tensor shapes innermost-first (`ne[0]` is
713        // the fastest-varying dim) while safetensors reports outermost-
714        // first. The actual byte layout is identical row-major — only
715        // the shape label is reversed. Reverse to match safetensors so
716        // existing builders work unchanged; no data movement.
717        let mut shape = raw_shape;
718        shape.reverse();
719        Ok((data, shape))
720    }
721    /// **BREAKING CHANGE in 0.2.0:** prior to 0.2.0 this method was
722    /// a no-op for GGUF (returned the bytes unchanged with the GGUF
723    /// shape label) which silently produced garbage logits when the
724    /// builder expected `[in, out]` row-major. From 0.2.0 onwards
725    /// `take` normalizes GGUF's reverse-shape convention so this
726    /// method matches the safetensors variant byte-for-byte.
727    /// Downstream code that explicitly worked around the old buggy
728    /// behavior (manually re-transposing the result) must drop that
729    /// workaround.
730    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
731        // After the safetensors normalization in `take`, this matches
732        // the WeightMap implementation byte-for-byte.
733        let (data, shape) = self.take(key)?;
734        if shape.len() != 2 {
735            return Err(anyhow!("transpose requires 2D, got {shape:?}"));
736        }
737        let (rows, cols) = (shape[0], shape[1]);
738        let mut t = vec![0f32; data.len()];
739        for i in 0..rows {
740            for j in 0..cols {
741                t[j * rows + i] = data[i * cols + j];
742            }
743        }
744        Ok((t, vec![cols, rows]))
745    }
746    fn remaining_keys(&self) -> Vec<String> {
747        // MTP weights default to invisible — they belong to optional
748        // speculative heads and the base qwen3 builder ignores them.
749        // Callers wanting MTP-aware loading flip `include_mtp(true)`
750        // first, which surfaces them here.
751        self.file
752            .tensors
753            .keys()
754            .filter(|k| {
755                !self.taken.contains(k.as_str()) && (self.include_mtp || !self.is_mtp_tensor(k))
756            })
757            .cloned()
758            .collect()
759    }
760}
761
762impl GgufLoader {
763    /// Tensor names that look like MTP heads under this file's
764    /// scheme (combines the substring heuristic with the
765    /// model-aware `blk.N where N >= threshold` check — see
766    /// [`is_mtp_tensor`](Self::is_mtp_tensor)).
767    /// Returned unfiltered by `remaining_keys` so consumers wanting
768    /// to wire MTP can find them explicitly.
769    pub fn mtp_keys(&self) -> Vec<String> {
770        self.file
771            .tensors
772            .keys()
773            .filter(|k| self.is_mtp_tensor(k))
774            .cloned()
775            .collect()
776    }
777}
778
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture path in the system temp dir that is deleted on drop, so
    /// synthetic GGUF files are cleaned up even when an assertion panics.
    /// The process id is folded into the name so concurrently running
    /// test binaries don't overwrite each other's fixtures.
    ///
    /// Declare it before any loader that reads the file: locals drop in
    /// reverse order, so the loaders are gone before the file is removed.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn new(stem: &str) -> Self {
            let name = format!("{stem}_{}.gguf", std::process::id());
            Self(std::env::temp_dir().join(name))
        }

        fn path(&self) -> &std::path::Path {
            &self.0
        }

        fn path_str(&self) -> &str {
            self.0
                .to_str()
                .expect("temp dir path should be valid UTF-8")
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; the file may not exist if the test
            // failed before writing it.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    #[test]
    fn unknown_extension_errors() {
        let r = load_from_path("/tmp/no-such-thing.bin");
        match r {
            Err(e) => assert!(e.to_string().contains("unsupported")),
            Ok(_) => panic!("expected error"),
        }
    }

    #[test]
    fn hf_to_gguf_top_level() {
        assert_eq!(
            hf_to_gguf_name("model.embed_tokens.weight").as_deref(),
            Some("token_embd.weight")
        );
        assert_eq!(
            hf_to_gguf_name("model.norm.weight").as_deref(),
            Some("output_norm.weight")
        );
        assert_eq!(
            hf_to_gguf_name("lm_head.weight").as_deref(),
            Some("output.weight")
        );
    }

    #[test]
    fn hf_to_gguf_per_layer() {
        let cases = [
            (
                "model.layers.0.self_attn.q_proj.weight",
                "blk.0.attn_q.weight",
            ),
            (
                "model.layers.7.self_attn.o_proj.weight",
                "blk.7.attn_output.weight",
            ),
            (
                "model.layers.3.mlp.gate_proj.weight",
                "blk.3.ffn_gate.weight",
            ),
            (
                "model.layers.12.mlp.down_proj.weight",
                "blk.12.ffn_down.weight",
            ),
            (
                "model.layers.4.input_layernorm.weight",
                "blk.4.attn_norm.weight",
            ),
            (
                "model.layers.4.post_attention_layernorm.weight",
                "blk.4.ffn_norm.weight",
            ),
            (
                "model.layers.0.self_attn.q_norm.weight",
                "blk.0.attn_q_norm.weight",
            ),
        ];
        for (hf, gguf) in cases {
            assert_eq!(
                hf_to_gguf_name(hf).as_deref(),
                Some(gguf),
                "mismatch for {hf}"
            );
        }
    }

    #[test]
    fn hf_to_gguf_unknown_returns_none() {
        assert!(hf_to_gguf_name("model.layers.0.some_new_thing.weight").is_none());
        assert!(hf_to_gguf_name("model.layers.foo.input_layernorm.weight").is_none());
    }

    #[test]
    fn mtp_detection() {
        assert!(is_mtp_weight("mtp_blk.0.attn_q.weight"));
        assert!(is_mtp_weight("output_mtp_0.weight"));
        assert!(is_mtp_weight("model.layers.0.mtp_head.weight"));
        assert!(!is_mtp_weight("blk.0.attn_q.weight"));
        assert!(!is_mtp_weight("output.weight"));
    }

    /// `ggml_type_to_quant_scheme` maps GGML `Q4_0` / `Q8_0` tensors onto
    /// the matching packed `GgufQ4_0` / `GgufQ8_0` quant schemes.
    #[test]
    fn ggml_q4_0_maps_to_packed_scheme() {
        use rlx_gguf::GgmlType;
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q4_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ4_0)
        );
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q8_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ8_0)
        );
    }

    #[test]
    fn q6k_dequant_matmul_follows_block_probe() {
        use rlx_ir::quant::QuantScheme;
        assert!(dequant_matmul_supported(QuantScheme::GgufQ4K));
        assert_eq!(
            dequant_matmul_supported(QuantScheme::GgufQ6K),
            super::probe_q6k_block_dequant()
        );
    }

    /// Build a tiny GGUF with `qwen35.block_count = 25` and
    /// `qwen35.nextn_predict_layers = 1`, then verify the loader's
    /// model-aware detector classifies `blk.24.*` as MTP while
    /// `blk.0.*` stays in the base model. This is the unsloth /
    /// DeepSeek-V3 convention — substring-based `is_mtp_weight`
    /// alone wouldn't catch it.
    #[test]
    fn gguf_loader_threshold_based_mtp_detection() {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&3u64.to_le_bytes()); // 3 tensors
        buf.extend_from_slice(&3u64.to_le_bytes()); // 3 KV
        // KV: general.architecture = "qwen35" (type 8 = string)
        let write_string = |buf: &mut Vec<u8>, k: &str, v: &str| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&8u32.to_le_bytes());
            buf.extend_from_slice(&(v.len() as u64).to_le_bytes());
            buf.extend_from_slice(v.as_bytes());
        };
        let write_u32 = |buf: &mut Vec<u8>, k: &str, v: u32| {
            buf.extend_from_slice(&(k.len() as u64).to_le_bytes());
            buf.extend_from_slice(k.as_bytes());
            buf.extend_from_slice(&4u32.to_le_bytes()); // type 4 = u32
            buf.extend_from_slice(&v.to_le_bytes());
        };
        write_string(&mut buf, "general.architecture", "qwen35");
        write_u32(&mut buf, "qwen35.block_count", 25);
        write_u32(&mut buf, "qwen35.nextn_predict_layers", 1);
        // Three tensors: blk.0.attn_q.weight (main), blk.24.attn_q.weight (MTP),
        // and token_embd.weight (always main).
        let write_tensor = |buf: &mut Vec<u8>, name: &str, shape: &[usize], off: u64| {
            buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
            buf.extend_from_slice(name.as_bytes());
            buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
            for &d in shape {
                buf.extend_from_slice(&(d as u64).to_le_bytes());
            }
            buf.extend_from_slice(&0u32.to_le_bytes()); // F32
            buf.extend_from_slice(&off.to_le_bytes());
        };
        write_tensor(&mut buf, "blk.0.attn_q.weight", &[4, 4], 0);
        write_tensor(&mut buf, "blk.24.attn_q.weight", &[4, 4], 64);
        write_tensor(&mut buf, "token_embd.weight", &[4, 4], 128);
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        // 3 × 4×4 f32 = 192 bytes of data.
        for _ in 0..(3 * 16) {
            buf.extend_from_slice(&0.5f32.to_le_bytes());
        }
        // Declared before the loader so it is dropped (and the file
        // removed) only after the loader is gone.
        let fixture = TempFile::new("rlx_mtp_threshold_test");
        std::fs::write(fixture.path(), &buf).unwrap();
        let loader = GgufLoader::from_file(fixture.path_str()).unwrap();

        assert_eq!(loader.mtp_layer_threshold(), Some(24));
        assert!(!loader.is_mtp_tensor("blk.0.attn_q.weight"));
        assert!(loader.is_mtp_tensor("blk.24.attn_q.weight"));
        assert!(!loader.is_mtp_tensor("token_embd.weight"));
        let mtp = loader.mtp_keys();
        assert_eq!(mtp, vec!["blk.24.attn_q.weight".to_string()]);
    }

    /// Synthesize a tiny GGUF file in memory with two GGUF-named
    /// tensors (`token_embd.weight` and `blk.0.attn_q.weight`) plus
    /// one MTP weight (`output_mtp_0.weight`). Then verify:
    ///   1. `take()` resolves the HF names via the mapper.
    ///   2. `remaining_keys()` hides the MTP weight.
    ///   3. `mtp_keys()` surfaces it for callers that opt in.
    #[test]
    fn gguf_loader_resolves_hf_names_and_skips_mtp() {
        let mut tensors = Vec::new();
        let mut data: Vec<f32> = Vec::new();

        // tensor #1: token_embd.weight, shape [3, 4], values 0..12
        let t1: Vec<f32> = (0..12).map(|x| x as f32).collect();
        tensors.push(("token_embd.weight", vec![3usize, 4], data.len()));
        data.extend_from_slice(&t1);

        // tensor #2: blk.0.attn_q.weight, shape [4, 4], values 100..116
        let t2: Vec<f32> = (100..116).map(|x| x as f32).collect();
        tensors.push(("blk.0.attn_q.weight", vec![4usize, 4], data.len()));
        data.extend_from_slice(&t2);

        // tensor #3: output_mtp_0.weight (MTP head) — present but skipped
        let t3: Vec<f32> = vec![0.5f32; 8];
        tensors.push(("output_mtp_0.weight", vec![2usize, 4], data.len()));
        data.extend_from_slice(&t3);

        // Build the GGUF byte stream by hand.
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes()); // version
        buf.extend_from_slice(&(tensors.len() as u64).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes()); // kv_count

        // Tensor info section.
        for (name, shape, _) in &tensors {
            buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
            buf.extend_from_slice(name.as_bytes());
            buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
            for &d in shape {
                buf.extend_from_slice(&(d as u64).to_le_bytes());
            }
            buf.extend_from_slice(&0u32.to_le_bytes()); // dtype = F32
            // Offset within the data segment — patched below.
            buf.extend_from_slice(&0u64.to_le_bytes());
        }
        // Align to DEFAULT_ALIGNMENT before data section.
        while !buf
            .len()
            .is_multiple_of(rlx_gguf::DEFAULT_ALIGNMENT as usize)
        {
            buf.push(0);
        }
        for v in &data {
            buf.extend_from_slice(&v.to_le_bytes());
        }
        // Patch the offsets we wrote as 0 above.
        let header: usize = 4 + 4 + 8 + 8; // magic + version + tensor_count + kv_count
        let mut cursor = header;
        for (name, shape, elem_off) in &tensors {
            let name_len_bytes = 8;
            let name_bytes = name.len();
            let n_dims_bytes = 4;
            let dims_bytes = shape.len() * 8;
            let dtype_bytes = 4;
            let off_bytes = 8;
            let info_size =
                name_len_bytes + name_bytes + n_dims_bytes + dims_bytes + dtype_bytes + off_bytes;
            let off_field_at = cursor + info_size - off_bytes;
            let final_off = (*elem_off * 4) as u64; // f32 byte offset within data segment
            buf[off_field_at..off_field_at + off_bytes].copy_from_slice(&final_off.to_le_bytes());
            cursor += info_size;
        }

        // Write to a temp file (GgufFile reads from a path). Declared
        // before every loader so the file outlives them all.
        let fixture = TempFile::new("rlx_test_qwen3_mini");
        std::fs::write(fixture.path(), &buf).unwrap();

        let mut loader = GgufLoader::from_file(fixture.path_str()).unwrap();
        // 3 tensors in the file, but the MTP head is hidden by default,
        // so the visible count is 2.
        assert_eq!(loader.len(), 2);

        // HF-name lookup resolves via the mapper. GGUF reports shapes
        // innermost-first while safetensors reports outermost-first;
        // byte layout is identical, only the shape label flips. The
        // synthetic GGUF here was built with shape `[3, 4]`, so the
        // loader hands back `[4, 3]` with the same bytes.
        let (out, shape) = loader
            .take("model.embed_tokens.weight")
            .expect("hf-named token_embd should resolve");
        assert_eq!(shape, vec![4, 3]);
        assert_eq!(&out, &t1);

        let (out, shape) = loader
            .take("model.layers.0.self_attn.q_proj.weight")
            .expect("hf-named attn_q should resolve");
        assert_eq!(shape, vec![4, 4]);
        assert_eq!(&out, &t2);

        // MTP weight stays out of remaining_keys, in mtp_keys.
        assert_eq!(loader.remaining_keys(), Vec::<String>::new());
        assert_eq!(loader.mtp_keys(), vec!["output_mtp_0.weight".to_string()]);

        // include_mtp(true): MTP weights become visible in
        // `remaining_keys`, and `take_mtp` returns them with the
        // safetensors-ordered shape.
        let mut loader2 = GgufLoader::from_file(fixture.path_str()).unwrap();
        loader2.include_mtp(true);
        let visible: std::collections::HashSet<String> =
            loader2.remaining_keys().into_iter().collect();
        assert!(visible.contains("token_embd.weight"));
        assert!(visible.contains("blk.0.attn_q.weight"));
        assert!(
            visible.contains("output_mtp_0.weight"),
            "MTP weight should be visible with include_mtp(true)"
        );
        let (mtp_data, mtp_shape) = loader2.take_mtp("output_mtp_0.weight").unwrap();
        assert_eq!(mtp_shape, vec![4, 2]);
        assert_eq!(mtp_data, t3);

        // include_mtp(false) — default — refuses MTP via take().
        let mut loader3 = GgufLoader::from_file(fixture.path_str()).unwrap();
        let err = loader3.take("output_mtp_0.weight").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("include_mtp(true)"),
            "expected MTP guard error, got: {msg}"
        );
    }

    #[test]
    fn missing_gguf_file_errors() {
        // .gguf is now a known extension → error comes from `open`,
        // not from the dispatcher.
        let r = load_from_path("/tmp/no-such-thing-rlx-gguf-test.gguf");
        assert!(r.is_err());
    }
}