Skip to main content

rlx_models_core/
weight_loader.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Pluggable weight loader trait (plan #56).
17//!
18//! Borrowed from MAX's `max/python/max/graph/weights/` layout:
19//! `load_safetensors.py`, `load_gguf.py`, plus a `load.py` dispatcher
20//! that detects format from the file extension.
21//!
//! Two impls live here: safetensors (via the existing
//! [`WeightMap::from_file`]) and GGUF ([`GgufLoader`]). Adding a new
//! format = one new struct that implements [`WeightLoader`] plus a
//! registration with [`crate::weight_registry::register_weight_format`],
//! so that [`load_from_path`] can dispatch to it by file extension. The
//! model graph builders take any `&mut dyn WeightLoader` so they don't
//! care which format the weights came from.
28
29use anyhow::{Context, Result, anyhow};
30use rlx_gguf::MetaValue;
31use std::collections::HashSet;
32use std::path::Path;
33
34/// Walk the GGUF metadata for `{arch}.block_count -
35/// {arch}.nextn_predict_layers`. Returns `Some(main_layer_count)`
36/// when the file has explicit MTP metadata, else `None`.
37fn compute_mtp_layer_threshold(file: &rlx_gguf::GgufFile) -> Option<u32> {
38    let arch = file
39        .metadata
40        .get("general.architecture")
41        .and_then(MetaValue::as_str)?;
42    let block_count = file
43        .metadata
44        .get(&format!("{arch}.block_count"))
45        .and_then(MetaValue::as_u32)?;
46    let nextn = file
47        .metadata
48        .get(&format!("{arch}.nextn_predict_layers"))
49        .and_then(MetaValue::as_u32)?;
50    if nextn == 0 {
51        return None;
52    }
53    Some(block_count.saturating_sub(nextn))
54}
55
56use crate::gguf_resolve::resolve_gguf_tensor_name;
57use crate::gguf_support::gguf_architecture_str;
58use crate::weight_map::PackedWeightTensor;
59use crate::weight_map::WeightMap;
60use rlx_ir::quant::QuantScheme;
61
/// Map a Hugging Face / safetensors tensor name onto its GGUF /
/// llama.cpp spelling.
///
/// Knows the three top-level tensors (`model.embed_tokens.weight`,
/// `model.norm.weight`, `lm_head.weight`) and the LLaMA-family per-layer
/// tensors `model.layers.{i}.<tail>`, which become `blk.{i}.<gguf-tail>`
/// (Qwen3 shares these names). Returns `None` for anything else — the
/// caller should then treat the name as already-GGUF or as an
/// architecture-specific weight this mapper doesn't know about (e.g. MTP
/// heads — see [`is_mtp_weight`]).
///
/// The table mirrors llama.cpp's `gguf-py/gguf/tensor_mapping.py`. When
/// adding architectures, extend this function instead of forking it.
pub fn hf_to_gguf_name(hf: &str) -> Option<String> {
    let top_level = match hf {
        "model.embed_tokens.weight" => Some("token_embd.weight"),
        "model.norm.weight" => Some("output_norm.weight"),
        "lm_head.weight" => Some("output.weight"),
        _ => None,
    };
    if let Some(gguf) = top_level {
        return Some(gguf.to_owned());
    }

    // Per-layer: `model.layers.{i}.<tail>` → `blk.{i}.<gguf-tail>`.
    let (layer, tail) = hf.strip_prefix("model.layers.")?.split_once('.')?;
    let layer: usize = layer.parse().ok()?;
    let gguf_tail = match tail {
        "input_layernorm.weight" => "attn_norm.weight",
        // Llama-style: `post_attention_layernorm` is the pre-FFN norm.
        // Gemma 2 has four distinct norms and overrides this entry through
        // its dedicated `Gemma2GgufResolver`; keep Gemma names out of this
        // table so they don't collide with Llama callers.
        "post_attention_layernorm.weight" => "ffn_norm.weight",
        "self_attn.q_proj.weight" => "attn_q.weight",
        "self_attn.k_proj.weight" => "attn_k.weight",
        "self_attn.v_proj.weight" => "attn_v.weight",
        "self_attn.o_proj.weight" => "attn_output.weight",
        "self_attn.q_proj.bias" => "attn_q.bias",
        "self_attn.k_proj.bias" => "attn_k.bias",
        "self_attn.v_proj.bias" => "attn_v.bias",
        "self_attn.q_norm.weight" => "attn_q_norm.weight",
        "self_attn.k_norm.weight" => "attn_k_norm.weight",
        "mlp.gate_proj.weight" => "ffn_gate.weight",
        "mlp.up_proj.weight" => "ffn_up.weight",
        "mlp.down_proj.weight" => "ffn_down.weight",
        _ => return None,
    };
    Some(format!("blk.{layer}.{gguf_tail}"))
}
109
/// Reverse of [`hf_to_gguf_name`]: map a GGUF / llama.cpp tensor name
/// back to the safetensors / Hugging Face spelling.
///
/// Drain-style loaders (e.g. `Qwen3Generator::from_loader`) cache weights
/// by name and use this so the cache key matches what the builder will
/// later ask for. Returns `None` for names outside the LLaMA-family table.
pub fn gguf_to_hf_name(gguf: &str) -> Option<String> {
    match gguf {
        "token_embd.weight" => Some("model.embed_tokens.weight".to_owned()),
        "output_norm.weight" => Some("model.norm.weight".to_owned()),
        "output.weight" => Some("lm_head.weight".to_owned()),
        _ => {
            // Per-layer: `blk.{i}.<tail>` → `model.layers.{i}.<hf-tail>`.
            let (layer, tail) = gguf.strip_prefix("blk.")?.split_once('.')?;
            let layer = layer.parse::<usize>().ok()?;
            let hf_tail = match tail {
                "attn_norm.weight" => "input_layernorm.weight",
                "ffn_norm.weight" => "post_attention_layernorm.weight",
                "attn_q.weight" => "self_attn.q_proj.weight",
                "attn_k.weight" => "self_attn.k_proj.weight",
                "attn_v.weight" => "self_attn.v_proj.weight",
                "attn_output.weight" => "self_attn.o_proj.weight",
                "attn_q.bias" => "self_attn.q_proj.bias",
                "attn_k.bias" => "self_attn.k_proj.bias",
                "attn_v.bias" => "self_attn.v_proj.bias",
                "attn_q_norm.weight" => "self_attn.q_norm.weight",
                "attn_k_norm.weight" => "self_attn.k_norm.weight",
                "ffn_gate.weight" => "mlp.gate_proj.weight",
                "ffn_up.weight" => "mlp.up_proj.weight",
                "ffn_down.weight" => "mlp.down_proj.weight",
                _ => return None,
            };
            Some(format!("model.layers.{layer}.{hf_tail}"))
        }
    }
}
146
147/// Arch-aware variant of [`gguf_to_hf_name`]. Falls back to the
148/// generic mapping when `arch` doesn't carry overrides; for arches
149/// whose 1↔1 alias disagrees with the Llama convention (Gemma 2/3/4:
150/// `ffn_norm` is the pre-FFN norm, not the post-attention norm), it
151/// returns the arch-correct HF name instead. Used by drain-style
152/// `from_loader` paths to compute cache keys that match what the
153/// builder will ask for.
154pub fn gguf_to_hf_name_for_arch(gguf: &str, arch: &str) -> Option<String> {
155    if matches!(
156        arch,
157        "gemma2" | "gemma3" | "gemma3n" | "gemma4" | "gemma4moe"
158    ) {
159        match gguf {
160            "token_embd.weight" => return Some("model.embed_tokens.weight".into()),
161            "output_norm.weight" => return Some("model.norm.weight".into()),
162            "output.weight" => return Some("lm_head.weight".into()),
163            _ => {}
164        }
165        let rest = gguf.strip_prefix("blk.")?;
166        let dot = rest.find('.')?;
167        let (idx_str, tail_with_dot) = rest.split_at(dot);
168        let tail = &tail_with_dot[1..];
169        let idx: usize = idx_str.parse().ok()?;
170        let hf_tail = match tail {
171            "attn_norm.weight" => "input_layernorm.weight",
172            "post_attention_norm.weight" => "post_attention_layernorm.weight",
173            "ffn_norm.weight" => "pre_feedforward_layernorm.weight",
174            "post_ffw_norm.weight" => "post_feedforward_layernorm.weight",
175            "attn_q.weight" => "self_attn.q_proj.weight",
176            "attn_k.weight" => "self_attn.k_proj.weight",
177            "attn_v.weight" => "self_attn.v_proj.weight",
178            "attn_output.weight" => "self_attn.o_proj.weight",
179            "ffn_gate.weight" => "mlp.gate_proj.weight",
180            "ffn_up.weight" => "mlp.up_proj.weight",
181            "ffn_down.weight" => "mlp.down_proj.weight",
182            _ => return None,
183        };
184        return Some(format!("model.layers.{idx}.{hf_tail}"));
185    }
186    gguf_to_hf_name(gguf)
187}
188
/// True when `name` is a Gemma RMSNorm gain tensor.
///
/// Accepts the final norm (`output_norm.weight` / `model.norm.weight`)
/// and the four per-layer sandwich norms (attn / post_attention / ffn /
/// post_ffw), spelled either the GGUF way (`blk.N.<tail>`) or the HF way
/// (`model.layers.N.<tail>`). Drain order is undefined, so the call site
/// may see either convention. The layer-index segment is not validated.
///
/// NOTE(review): per-head `attn_q_norm` / `attn_k_norm` gains are not
/// matched here — confirm whether the Gemma variants that carry them
/// need the same treatment as the norms listed below.
fn is_gemma_norm_weight(name: &str) -> bool {
    const GGUF_LAYER_NORMS: &[&str] = &[
        "attn_norm.weight",
        "post_attention_norm.weight",
        "ffn_norm.weight",
        "post_ffw_norm.weight",
    ];
    const HF_LAYER_NORMS: &[&str] = &[
        "input_layernorm.weight",
        "post_attention_layernorm.weight",
        "pre_feedforward_layernorm.weight",
        "post_feedforward_layernorm.weight",
    ];

    if matches!(name, "output_norm.weight" | "model.norm.weight") {
        return true;
    }
    // `Some(is-listed)` when `name` looks like `<prefix><idx>.<tail>`,
    // `None` when it doesn't have that shape at all.
    let layer_tail_listed = |prefix: &str, tails: &[&str]| {
        name.strip_prefix(prefix)
            .and_then(|rest| rest.split_once('.'))
            .map(|(_, tail)| tails.iter().any(|&t| t == tail))
    };
    layer_tail_listed("blk.", GGUF_LAYER_NORMS)
        .or_else(|| layer_tail_listed("model.layers.", HF_LAYER_NORMS))
        .unwrap_or(false)
}
224
/// Name-only heuristic for Multi-Token Prediction head tensors: true when
/// the name starts with `mtp` or contains `mtp_` / `.mtp` anywhere. This
/// catches MTP variants that name their heads explicitly.
///
/// **Not sufficient on its own** for the common unsloth / DeepSeek-V3
/// layout, where MTP heads are stored as *extra `blk.N` layers* with N at
/// or beyond the main layer count. For that case use
/// [`GgufLoader::mtp_layer_threshold`] or the loader's `is_mtp_tensor`
/// method, which consult `{arch}.nextn_predict_layers` in the GGUF
/// metadata and classify trailing `blk.*` indices accordingly.
pub fn is_mtp_weight(name: &str) -> bool {
    const INFIX_MARKERS: [&str; 2] = ["mtp_", ".mtp"];
    name.starts_with("mtp") || INFIX_MARKERS.iter().any(|&marker| name.contains(marker))
}
240
241/// Map GGML storage type to RLX packed matmul scheme (K-quants only).
242pub fn ggml_type_to_quant_scheme(dtype: rlx_gguf::GgmlType) -> Option<QuantScheme> {
243    use rlx_gguf::GgmlType;
244    match dtype {
245        GgmlType::Q2K => Some(QuantScheme::GgufQ2K),
246        GgmlType::Q3K => Some(QuantScheme::GgufQ3K),
247        GgmlType::Q4K => Some(QuantScheme::GgufQ4K),
248        GgmlType::Q5K => Some(QuantScheme::GgufQ5K),
249        GgmlType::Q6K => Some(QuantScheme::GgufQ6K),
250        GgmlType::Q8K => Some(QuantScheme::GgufQ8K),
251        GgmlType::Q4_0 => Some(QuantScheme::GgufQ4_0),
252        GgmlType::Q8_0 => Some(QuantScheme::GgufQ8_0),
253        _ => None,
254    }
255}
256
257/// Common interface every weight format must satisfy. Mirrors the
258/// existing `WeightMap` API so the safetensors impl is a one-line
259/// adapter.
260///
261/// Register additional on-disk formats with [`crate::weight_registry::register_weight_format`].
262pub trait WeightLoader: Send {
263    /// Format id (`safetensors`, `gguf`, or a custom registration).
264    fn format_id(&self) -> &'static str {
265        "unknown"
266    }
267    /// Number of distinct weights in the file.
268    fn len(&self) -> usize;
269    fn is_empty(&self) -> bool {
270        self.len() == 0
271    }
272    /// Take the named tensor as `(f32_data, shape)`. Removes from the
273    /// loader so callers can detect "weights I never used."
274    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
275    /// Same as `take` but transposed (last two dims swapped). Most
276    /// safetensors weights are stored row-major-of-PyTorch
277    /// convention, which RLX's IR consumes column-major; this helper
278    /// is the convention-bridge.
279    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)>;
280    /// Take packed K-quant bytes when supported; default returns `None`.
281    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
282        let _ = key;
283        Ok(None)
284    }
285    /// Borrow packed bytes without marking taken (GGUF mmap path).
286    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
287        let _ = key;
288        None
289    }
290    /// Names that haven't been taken yet — useful for "did the
291    /// model use every weight?" hygiene checks.
292    fn remaining_keys(&self) -> Vec<String>;
293    /// Architecture name from the underlying file (`general.architecture`
294    /// for GGUF, `None` for safetensors). Drain-style consumers use this
295    /// to pick an arch-specific reverse name mapping when the canonical
296    /// HF name depends on the model family (e.g. Gemma 2's 4 norms per
297    /// layer don't share the Llama 2-norm reverse alias).
298    fn arch_hint(&self) -> Option<&str> {
299        None
300    }
301}
302
303impl WeightLoader for WeightMap {
304    fn format_id(&self) -> &'static str {
305        "safetensors"
306    }
307    fn len(&self) -> usize {
308        Self::len(self)
309    }
310    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
311        Self::take(self, key)
312    }
313    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
314        Self::take_transposed(self, key)
315    }
316    fn remaining_keys(&self) -> Vec<String> {
317        self.keys().map(|s| s.to_string()).collect()
318    }
319}
320
321/// Adapter that lets a HF-safetensors-backed [`WeightLoader`] satisfy
322/// requests phrased in GGUF-style names (`blk.N.attn_q.weight` etc.).
323///
324/// Builders like [`Qwen35Weights::from_loader`] address tensors using
325/// the GGUF / llama.cpp convention; the underlying safetensors file
326/// stores them under HF / PyTorch names (`model.layers.N.self_attn.q_proj.weight`).
327/// This wrapper:
328///
329/// 1. Tries the requested key verbatim (in case it's already-HF or
330///    the file was named GGUF-style).
331/// 2. Tries [`gguf_to_hf_name`] to translate the GGUF key → HF key.
332/// 3. Returns the underlying loader's error otherwise.
333pub struct HfTranslatingLoader<L: WeightLoader> {
334    inner: L,
335}
336
337impl<L: WeightLoader> HfTranslatingLoader<L> {
338    pub fn new(inner: L) -> Self {
339        Self { inner }
340    }
341    pub fn into_inner(self) -> L {
342        self.inner
343    }
344    pub fn inner(&self) -> &L {
345        &self.inner
346    }
347    pub fn inner_mut(&mut self) -> &mut L {
348        &mut self.inner
349    }
350}
351
352impl<L: WeightLoader> WeightLoader for HfTranslatingLoader<L> {
353    fn format_id(&self) -> &'static str {
354        self.inner.format_id()
355    }
356    fn len(&self) -> usize {
357        self.inner.len()
358    }
359    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
360        match self.inner.take(key) {
361            Ok(v) => Ok(v),
362            Err(_) => {
363                if let Some(hf) = gguf_to_hf_name(key) {
364                    return self.inner.take(&hf);
365                }
366                self.inner.take(key)
367            }
368        }
369    }
370    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
371        match self.inner.take_transposed(key) {
372            Ok(v) => Ok(v),
373            Err(_) => {
374                if let Some(hf) = gguf_to_hf_name(key) {
375                    return self.inner.take_transposed(&hf);
376                }
377                self.inner.take_transposed(key)
378            }
379        }
380    }
381    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
382        self.inner.take_packed(key)
383    }
384    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
385        self.inner.tensor_bytes_borrowed(key)
386    }
387    fn remaining_keys(&self) -> Vec<String> {
388        self.inner.remaining_keys()
389    }
390}
391
392/// Dispatch on the file extension via [`crate::weight_registry`].
393pub fn load_from_path(path: &str) -> Result<Box<dyn WeightLoader>> {
394    crate::weight_registry::open_weight_loader(Path::new(path))
395}
396
// ─── GGUF adapter ─────────────────────────────────────────────────
//
// Wraps `rlx_gguf::GgufFile` so it satisfies `WeightLoader`. Tracks
// taken keys in a side-set since `dequant_f32` borrows the file
// immutably; the alternative — pre-decoding every tensor at load
// time — defeats the point of GGUF's lazy block layout.

/// [`WeightLoader`] over a parsed GGUF file. Tensors are decoded one at a
/// time on request; "taking" a tensor only records its resolved GGUF name
/// in `taken`.
pub struct GgufLoader {
    /// Parsed GGUF file (metadata + tensor table) that every lookup reads.
    file: rlx_gguf::GgufFile,
    /// `general.architecture` value, or `"unknown"` when absent. Selects
    /// the tensor-name resolver and arch-specific fix-ups.
    arch: String,
    /// Resolved GGUF tensor names already handed out by a `take*` call.
    taken: HashSet<String>,
    /// When true, `remaining_keys` / `len` / `take` treat MTP-head
    /// weights as ordinary tensors instead of hiding them. The base
    /// qwen3 builder ignores MTP tensors regardless — this flag
    /// only changes the *visibility* in the `WeightLoader` surface
    /// so downstream MTP-aware builders can iterate them through
    /// the standard drain pattern.
    include_mtp: bool,
    /// First `blk.N` index that belongs to an MTP head, computed
    /// from `{arch}.block_count - {arch}.nextn_predict_layers` at
    /// construction. `None` when that metadata is missing or
    /// `nextn_predict_layers` is 0 (= no MTP heads encoded as
    /// trailing blocks).
    mtp_layer_threshold: Option<u32>,
}
421
422impl GgufLoader {
423    pub fn from_file(path: &str) -> Result<Self> {
424        let file = crate::gguf_support::load_gguf_file(std::path::Path::new(path))?;
425        let arch = gguf_architecture_str(&file)
426            .unwrap_or("unknown")
427            .to_string();
428        let mtp_layer_threshold = compute_mtp_layer_threshold(&file);
429        Ok(Self {
430            file,
431            arch,
432            taken: HashSet::new(),
433            include_mtp: false,
434            mtp_layer_threshold,
435        })
436    }
437
438    pub fn architecture(&self) -> &str {
439        &self.arch
440    }
441
442    /// First `blk.N` index that the GGUF metadata reports as an MTP
443    /// head, derived from `{arch}.block_count -
444    /// {arch}.nextn_predict_layers`. `None` for files where the
445    /// `nextn_predict_layers` key is absent (= no MTP, or MTP is
446    /// encoded under a different naming scheme — fall back to
447    /// [`is_mtp_weight`] in that case).
448    pub fn mtp_layer_threshold(&self) -> Option<u32> {
449        self.mtp_layer_threshold
450    }
451
452    /// Borrow the underlying parsed `GgufFile` so callers (e.g. arch
453    /// builders that read `general.architecture`-specific keys)
454    /// don't have to re-parse 800+ tensor headers a second time.
455    pub fn file(&self) -> &rlx_gguf::GgufFile {
456        &self.file
457    }
458
459    /// Borrow the raw on-disk byte slice for a tensor without
460    /// marking it taken. Returns `None` if the key doesn't resolve
461    /// or the byte range is invalid. Used by the qwen35 packed-
462    /// upload path to stream K-quant bytes from mmap straight into
463    /// the compiled arena, skipping a per-tensor `Vec<u8>`
464    /// allocation (≈ 16 GB on Qwen3.6-27B Q4_K_M).
465    pub fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
466        let real = self.resolve(key).ok()?;
467        let t = self.file.get(&real)?;
468        self.file.tensor_bytes(t).ok()
469    }
470
471    /// Variant of [`Self::take_packed`] that returns only the
472    /// `(scheme, shape)` metadata without copying bytes. The caller
473    /// uploads bytes separately via [`Self::tensor_bytes_borrowed`]
474    /// after the graph is compiled — eliminates the per-tensor
475    /// `Vec<u8>` allocation. Marks the tensor taken on success;
476    /// returns `Ok(None)` for non-K-quant dtypes so the caller can
477    /// fall back to the dequant path.
478    pub fn take_packed_metadata(
479        &mut self,
480        key: &str,
481    ) -> Result<Option<(rlx_ir::quant::QuantScheme, Vec<usize>)>> {
482        let real = self.resolve(key)?;
483        if self.taken.contains(&real) {
484            return Err(anyhow!("weight already taken: {key} (→ {real})"));
485        }
486        if !self.include_mtp && self.is_mtp_tensor(&real) {
487            return Err(anyhow!(
488                "refusing to take MTP weight `{real}` without include_mtp(true)"
489            ));
490        }
491        let t = self
492            .file
493            .get(&real)
494            .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
495        let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
496            return Ok(None);
497        };
498        let mut shape = t.shape.clone();
499        shape.reverse();
500        self.taken.insert(real);
501        Ok(Some((scheme, shape)))
502    }
503
504    /// True if `name` is an MTP weight under this file's naming
505    /// scheme. Combines the substring heuristic ([`is_mtp_weight`])
506    /// with the model-aware `blk.N where N >= threshold` check.
507    pub fn is_mtp_tensor(&self, name: &str) -> bool {
508        if is_mtp_weight(name) {
509            return true;
510        }
511        if let Some(thresh) = self.mtp_layer_threshold {
512            if let Some(rest) = name.strip_prefix("blk.") {
513                if let Some(dot) = rest.find('.') {
514                    if let Ok(idx) = rest[..dot].parse::<u32>() {
515                        if idx >= thresh {
516                            return true;
517                        }
518                    }
519                }
520            }
521        }
522        false
523    }
524
525    /// Toggle MTP-weight visibility. With `include = true`, MTP
526    /// heads show up in `remaining_keys()` (and count toward `len()`)
527    /// — drain-style consumers like
528    /// `Qwen3Generator::from_loader` will then pull them into the
529    /// weights cache. Default off so non-MTP models behave exactly
530    /// as before. Call this before any `take()` / drain so the
531    /// inclusion choice is consistent across the load.
532    pub fn include_mtp(&mut self, include: bool) -> &mut Self {
533        self.include_mtp = include;
534        self
535    }
536
537    /// Take a tensor's **packed bytes** (no dequant), plus its
538    /// [`rlx_ir::quant::QuantScheme`] and safetensors-style shape.
539    /// Returns `None` when the tensor is stored uncompressed
540    /// (F32/F16/BF16) — caller should fall back to `take()` for
541    /// those.
542    ///
543    /// Used by the qwen3 builder's *packed-weights mode*: the LM
544    /// head + per-layer matmul weights stay in the arena as raw
545    /// K-quant bytes, and the graph emits
546    /// `Op::DequantMatMul { scheme }` instead of `Op::MatMul` for
547    /// them. Cuts the load-time memory footprint by ~7-9× on
548    /// Q4_K_M / Q6_K models — the unblocker for ≥14 B Qwen3 / Llama
549    /// GGUFs on commodity Macs.
550    pub fn take_packed(&mut self, key: &str) -> Result<Option<PackedWeightTensor>> {
551        let real = self.resolve(key)?;
552        if self.taken.contains(&real) {
553            return Err(anyhow!("weight already taken: {key} (→ {real})"));
554        }
555        if !self.include_mtp && self.is_mtp_tensor(&real) {
556            return Err(anyhow!(
557                "refusing to take MTP weight `{real}` without include_mtp(true)"
558            ));
559        }
560        let t = self
561            .file
562            .get(&real)
563            .ok_or_else(|| anyhow!("tensor missing: {real}"))?;
564        // Map ggml dtype → our QuantScheme. K-quants and Q4_0/Q8_0 can stay
565        // packed on CPU; Q4_1/Q5_* + uncompressed F32/F16/BF16 fall through
566        // F32/F16/BF16 fall through to the dequant path (return
567        // None — caller switches to `take`).
568        let Some(scheme) = ggml_type_to_quant_scheme(t.dtype) else {
569            return Ok(None);
570        };
571        let bytes = self
572            .file
573            .tensor_bytes(t)
574            .with_context(|| format!("read packed bytes for {real}"))?
575            .to_vec();
576        let mut shape = t.shape.clone();
577        // Match the safetensors-style shape ordering applied by
578        // `take` — GGUF stores innermost-first, safetensors stores
579        // outermost-first; byte layout is identical.
580        shape.reverse();
581        self.taken.insert(real);
582        Ok(Some((bytes, scheme, shape)))
583    }
584
585    /// Take a single MTP weight by name. Bypasses the `include_mtp`
586    /// filter so callers can grab specific heads without flipping
587    /// the global visibility. Returns an error if the name isn't a
588    /// recognized MTP weight (use [`take`] for non-MTP keys).
589    pub fn take_mtp(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
590        if !self.is_mtp_tensor(key) {
591            return Err(anyhow!("not an MTP weight under this file's scheme: {key}"));
592        }
593        if !self.file.tensors.contains_key(key) {
594            return Err(anyhow!("MTP weight not found in GGUF: {key}"));
595        }
596        if self.taken.contains(key) {
597            return Err(anyhow!("MTP weight already taken: {key}"));
598        }
599        let (data, raw_shape) = self.file.dequant_f32(key)?;
600        self.taken.insert(key.to_string());
601        let mut shape = raw_shape;
602        shape.reverse();
603        Ok((data, shape))
604    }
605}
606
607impl GgufLoader {
608    /// Resolve a caller-supplied key (HF or GGUF naming) to the
609    /// actual GGUF tensor name via registered architecture resolvers.
610    fn resolve(&self, key: &str) -> Result<String> {
611        resolve_gguf_tensor_name(&self.file, &self.arch, key)
612            .ok_or_else(|| anyhow!("weight not found in GGUF (arch={}): {key}", self.arch))
613    }
614}
615
616impl WeightLoader for GgufLoader {
617    fn format_id(&self) -> &'static str {
618        "gguf"
619    }
620    fn arch_hint(&self) -> Option<&str> {
621        Some(&self.arch)
622    }
623    fn take_packed(&mut self, key: &str) -> Result<Option<crate::weight_map::PackedWeightTensor>> {
624        self.take_packed(key)
625    }
626    fn tensor_bytes_borrowed(&self, key: &str) -> Option<&[u8]> {
627        GgufLoader::tensor_bytes_borrowed(self, key)
628    }
629    fn len(&self) -> usize {
630        self.file
631            .tensors
632            .keys()
633            .filter(|k| !self.taken.contains(*k) && (self.include_mtp || !self.is_mtp_tensor(k)))
634            .count()
635    }
636    fn take(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
637        let real = self.resolve(key)?;
638        if self.taken.contains(&real) {
639            return Err(anyhow!("weight already taken: {key} (→ {real})"));
640        }
641        if !self.include_mtp && self.is_mtp_tensor(&real) {
642            return Err(anyhow!(
643                "refusing to take MTP weight `{real}` without include_mtp(true); \
644                 use loader.take_mtp(...) for explicit MTP grabs or \
645                 loader.include_mtp(true) to include them in drains"
646            ));
647        }
648        let (mut data, raw_shape) = self.file.dequant_f32(&real)?;
649        self.taken.insert(real.clone());
650        // Gemma's GGUF conversion script bakes the `(1 + gamma)` offset
651        // into every norm weight (see llama.cpp `convert_hf_to_gguf.py`
652        // → `GemmaModel.modify_tensors`). HF/safetensors stores raw
653        // gamma. Subtract 1 here so the loader publishes the
654        // safetensors convention and `GemmaRmsNormStage`'s explicit
655        // `+1` stays correct for both sources. Without this the RMS
656        // gain is systematically inflated and logits skew (cosine
657        // ~0.7 vs llama.cpp; structural, not numerical noise).
658        if matches!(
659            self.arch.as_str(),
660            "gemma" | "gemma2" | "gemma3" | "gemma3n" | "gemma4"
661        ) && is_gemma_norm_weight(&real)
662        {
663            for v in data.iter_mut() {
664                *v -= 1.0;
665            }
666        }
667        // GGUF/ggml report tensor shapes innermost-first (`ne[0]` is
668        // the fastest-varying dim) while safetensors reports outermost-
669        // first. The actual byte layout is identical row-major — only
670        // the shape label is reversed. Reverse to match safetensors so
671        // existing builders work unchanged; no data movement.
672        let mut shape = raw_shape;
673        shape.reverse();
674        Ok((data, shape))
675    }
676    /// **BREAKING CHANGE in 0.2.0:** prior to 0.2.0 this method was
677    /// a no-op for GGUF (returned the bytes unchanged with the GGUF
678    /// shape label) which silently produced garbage logits when the
679    /// builder expected `[in, out]` row-major. From 0.2.0 onwards
680    /// `take` normalizes GGUF's reverse-shape convention so this
681    /// method matches the safetensors variant byte-for-byte.
682    /// Downstream code that explicitly worked around the old buggy
683    /// behavior (manually re-transposing the result) must drop that
684    /// workaround.
685    fn take_transposed(&mut self, key: &str) -> Result<(Vec<f32>, Vec<usize>)> {
686        // After the safetensors normalization in `take`, this matches
687        // the WeightMap implementation byte-for-byte.
688        let (data, shape) = self.take(key)?;
689        if shape.len() != 2 {
690            return Err(anyhow!("transpose requires 2D, got {shape:?}"));
691        }
692        let (rows, cols) = (shape[0], shape[1]);
693        let mut t = vec![0f32; data.len()];
694        for i in 0..rows {
695            for j in 0..cols {
696                t[j * rows + i] = data[i * cols + j];
697            }
698        }
699        Ok((t, vec![cols, rows]))
700    }
701    fn remaining_keys(&self) -> Vec<String> {
702        // MTP weights default to invisible — they belong to optional
703        // speculative heads and the base qwen3 builder ignores them.
704        // Callers wanting MTP-aware loading flip `include_mtp(true)`
705        // first, which surfaces them here.
706        self.file
707            .tensors
708            .keys()
709            .filter(|k| {
710                !self.taken.contains(k.as_str()) && (self.include_mtp || !self.is_mtp_tensor(k))
711            })
712            .cloned()
713            .collect()
714    }
715}
716
717impl GgufLoader {
718    /// Tensor names that look like MTP heads under this file's
719    /// scheme (combines the substring heuristic with the
720    /// model-aware `blk.N where N >= threshold` check — see
721    /// [`is_mtp_tensor`](Self::is_mtp_tensor)).
722    /// Returned unfiltered by `remaining_keys` so consumers wanting
723    /// to wire MTP can find them explicitly.
724    pub fn mtp_keys(&self) -> Vec<String> {
725        self.file
726            .tensors
727            .keys()
728            .filter(|k| self.is_mtp_tensor(k))
729            .cloned()
730            .collect()
731    }
732}
733
#[cfg(test)]
mod tests {
    use super::*;

    /// Fixture file in the system temp dir that is deleted on drop, so a
    /// failing assertion does not leave stale files behind. The process
    /// id is folded into the name so two test binaries running at once
    /// do not overwrite each other's fixture.
    struct TempFile(std::path::PathBuf);

    impl TempFile {
        fn write(stem: &str, bytes: &[u8]) -> Self {
            let path = std::env::temp_dir().join(format!("{stem}_{}.gguf", std::process::id()));
            std::fs::write(&path, bytes).unwrap();
            TempFile(path)
        }

        fn path(&self) -> &str {
            self.0.to_str().unwrap()
        }
    }

    impl Drop for TempFile {
        fn drop(&mut self) {
            // Best-effort cleanup; a missing file is not a test failure.
            let _ = std::fs::remove_file(&self.0);
        }
    }

    /// Append a GGUF v3 header: magic, version, tensor count, KV count.
    fn push_header(buf: &mut Vec<u8>, n_tensors: u64, n_kv: u64) {
        buf.extend_from_slice(&rlx_gguf::GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes()); // version
        buf.extend_from_slice(&n_tensors.to_le_bytes());
        buf.extend_from_slice(&n_kv.to_le_bytes());
    }

    /// Append a GGUF string: u64 byte length followed by the raw bytes.
    fn push_gguf_str(buf: &mut Vec<u8>, s: &str) {
        buf.extend_from_slice(&(s.len() as u64).to_le_bytes());
        buf.extend_from_slice(s.as_bytes());
    }

    /// Append one F32 tensor-info record. `byte_off` is relative to the
    /// start of the aligned data section.
    fn push_f32_tensor_info(buf: &mut Vec<u8>, name: &str, shape: &[usize], byte_off: u64) {
        push_gguf_str(buf, name);
        buf.extend_from_slice(&(shape.len() as u32).to_le_bytes());
        for &d in shape {
            buf.extend_from_slice(&(d as u64).to_le_bytes());
        }
        buf.extend_from_slice(&0u32.to_le_bytes()); // dtype = F32
        buf.extend_from_slice(&byte_off.to_le_bytes());
    }

    /// Zero-pad `buf` to `DEFAULT_ALIGNMENT` so the data section that
    /// follows starts on an aligned boundary.
    fn pad_to_alignment(buf: &mut Vec<u8>) {
        let align = rlx_gguf::DEFAULT_ALIGNMENT as usize;
        while !buf.len().is_multiple_of(align) {
            buf.push(0);
        }
    }

    #[test]
    fn unknown_extension_errors() {
        let r = load_from_path("/tmp/no-such-thing.bin");
        match r {
            Err(e) => assert!(
                e.to_string().contains("unsupported"),
                "unexpected error: {e:#}"
            ),
            Ok(_) => panic!("expected error"),
        }
    }

    #[test]
    fn hf_to_gguf_top_level() {
        assert_eq!(
            hf_to_gguf_name("model.embed_tokens.weight").as_deref(),
            Some("token_embd.weight")
        );
        assert_eq!(
            hf_to_gguf_name("model.norm.weight").as_deref(),
            Some("output_norm.weight")
        );
        assert_eq!(
            hf_to_gguf_name("lm_head.weight").as_deref(),
            Some("output.weight")
        );
    }

    #[test]
    fn hf_to_gguf_per_layer() {
        let cases = [
            (
                "model.layers.0.self_attn.q_proj.weight",
                "blk.0.attn_q.weight",
            ),
            (
                "model.layers.7.self_attn.o_proj.weight",
                "blk.7.attn_output.weight",
            ),
            (
                "model.layers.3.mlp.gate_proj.weight",
                "blk.3.ffn_gate.weight",
            ),
            (
                "model.layers.12.mlp.down_proj.weight",
                "blk.12.ffn_down.weight",
            ),
            (
                "model.layers.4.input_layernorm.weight",
                "blk.4.attn_norm.weight",
            ),
            (
                "model.layers.4.post_attention_layernorm.weight",
                "blk.4.ffn_norm.weight",
            ),
            (
                "model.layers.0.self_attn.q_norm.weight",
                "blk.0.attn_q_norm.weight",
            ),
        ];
        for (hf, gguf) in cases {
            assert_eq!(
                hf_to_gguf_name(hf).as_deref(),
                Some(gguf),
                "mismatch for {hf}"
            );
        }
    }

    #[test]
    fn hf_to_gguf_unknown_returns_none() {
        assert!(hf_to_gguf_name("model.layers.0.some_new_thing.weight").is_none());
        assert!(hf_to_gguf_name("model.layers.foo.input_layernorm.weight").is_none());
    }

    #[test]
    fn mtp_detection() {
        assert!(is_mtp_weight("mtp_blk.0.attn_q.weight"));
        assert!(is_mtp_weight("output_mtp_0.weight"));
        assert!(is_mtp_weight("model.layers.0.mtp_head.weight"));
        assert!(!is_mtp_weight("blk.0.attn_q.weight"));
        assert!(!is_mtp_weight("output.weight"));
    }

    /// GGML `Q4_0` / `Q8_0` tensor types map onto the matching packed
    /// GGUF quant schemes.
    #[test]
    fn ggml_q4_0_maps_to_packed_scheme() {
        use rlx_gguf::GgmlType;
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q4_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ4_0)
        );
        assert_eq!(
            ggml_type_to_quant_scheme(GgmlType::Q8_0),
            Some(rlx_ir::quant::QuantScheme::GgufQ8_0)
        );
    }

    /// Build a tiny GGUF with `qwen35.block_count = 25` and
    /// `qwen35.nextn_predict_layers = 1`, then verify the loader's
    /// model-aware detector classifies `blk.24.*` as MTP while
    /// `blk.0.*` stays in the base model. This is the unsloth /
    /// DeepSeek-V3 convention — substring-based `is_mtp_weight`
    /// alone wouldn't catch it.
    #[test]
    fn gguf_loader_threshold_based_mtp_detection() {
        let mut buf: Vec<u8> = Vec::new();
        push_header(&mut buf, 3, 3); // 3 tensors, 3 KV pairs

        // KV entry layout: key string, u32 value type, then the value.
        let push_kv_string = |buf: &mut Vec<u8>, k: &str, v: &str| {
            push_gguf_str(buf, k);
            buf.extend_from_slice(&8u32.to_le_bytes()); // type 8 = string
            push_gguf_str(buf, v);
        };
        let push_kv_u32 = |buf: &mut Vec<u8>, k: &str, v: u32| {
            push_gguf_str(buf, k);
            buf.extend_from_slice(&4u32.to_le_bytes()); // type 4 = u32
            buf.extend_from_slice(&v.to_le_bytes());
        };
        push_kv_string(&mut buf, "general.architecture", "qwen35");
        push_kv_u32(&mut buf, "qwen35.block_count", 25);
        push_kv_u32(&mut buf, "qwen35.nextn_predict_layers", 1);

        // Three tensors: blk.0.attn_q.weight (main), blk.24.attn_q.weight (MTP),
        // and token_embd.weight (always main). Each is 4×4 f32 = 64 bytes.
        push_f32_tensor_info(&mut buf, "blk.0.attn_q.weight", &[4, 4], 0);
        push_f32_tensor_info(&mut buf, "blk.24.attn_q.weight", &[4, 4], 64);
        push_f32_tensor_info(&mut buf, "token_embd.weight", &[4, 4], 128);
        pad_to_alignment(&mut buf);
        // 3 × 4×4 f32 = 192 bytes of data.
        for _ in 0..(3 * 16) {
            buf.extend_from_slice(&0.5f32.to_le_bytes());
        }

        let tmp = TempFile::write("rlx_mtp_threshold_test", &buf);
        let loader = GgufLoader::from_file(tmp.path()).unwrap();

        assert_eq!(loader.mtp_layer_threshold(), Some(24));
        assert!(!loader.is_mtp_tensor("blk.0.attn_q.weight"));
        assert!(loader.is_mtp_tensor("blk.24.attn_q.weight"));
        assert!(!loader.is_mtp_tensor("token_embd.weight"));
        let mtp = loader.mtp_keys();
        assert_eq!(mtp, vec!["blk.24.attn_q.weight".to_string()]);
    }

    /// Synthesize a tiny GGUF file in memory with two GGUF-named
    /// tensors (`token_embd.weight` and `blk.0.attn_q.weight`) plus
    /// one MTP weight (`output_mtp_0.weight`). Then verify:
    ///   1. `take()` resolves the HF names via the mapper.
    ///   2. `remaining_keys()` hides the MTP weight.
    ///   3. `mtp_keys()` surfaces it for callers that opt in.
    #[test]
    fn gguf_loader_resolves_hf_names_and_skips_mtp() {
        // (name, shape, element offset into `data`)
        let mut tensors: Vec<(&str, Vec<usize>, usize)> = Vec::new();
        let mut data: Vec<f32> = Vec::new();

        // tensor #1: token_embd.weight, shape [3, 4], values 0..12
        let t1: Vec<f32> = (0..12).map(|x| x as f32).collect();
        tensors.push(("token_embd.weight", vec![3, 4], data.len()));
        data.extend_from_slice(&t1);

        // tensor #2: blk.0.attn_q.weight, shape [4, 4], values 100..116
        let t2: Vec<f32> = (100..116).map(|x| x as f32).collect();
        tensors.push(("blk.0.attn_q.weight", vec![4, 4], data.len()));
        data.extend_from_slice(&t2);

        // tensor #3: output_mtp_0.weight (MTP head) — present but skipped
        let t3: Vec<f32> = vec![0.5f32; 8];
        tensors.push(("output_mtp_0.weight", vec![2, 4], data.len()));
        data.extend_from_slice(&t3);

        // Build the GGUF byte stream by hand. Data offsets are known up
        // front, so each tensor-info record gets its final offset
        // directly (bytes relative to the data section start).
        let mut buf: Vec<u8> = Vec::new();
        push_header(&mut buf, tensors.len() as u64, 0);
        for (name, shape, elem_off) in &tensors {
            let byte_off = (*elem_off * std::mem::size_of::<f32>()) as u64;
            push_f32_tensor_info(&mut buf, name, shape, byte_off);
        }
        pad_to_alignment(&mut buf);
        for v in &data {
            buf.extend_from_slice(&v.to_le_bytes());
        }

        // GgufFile reads from a path, so stage the bytes in a temp file.
        let tmp = TempFile::write("rlx_test_qwen3_mini", &buf);

        let mut loader = GgufLoader::from_file(tmp.path()).unwrap();
        // Pre-MTP: 3 tensors total, MTP is hidden so visible count = 2.
        assert_eq!(loader.len(), 2);

        // HF-name lookup resolves via the mapper. GGUF reports shapes
        // innermost-first while safetensors reports outermost-first;
        // byte layout is identical, only the shape label flips. The
        // synthetic GGUF here was built with shape `[3, 4]`, so the
        // loader hands back `[4, 3]` with the same bytes.
        let (out, shape) = loader
            .take("model.embed_tokens.weight")
            .expect("hf-named token_embd should resolve");
        assert_eq!(shape, vec![4, 3]);
        assert_eq!(&out, &t1);

        let (out, shape) = loader
            .take("model.layers.0.self_attn.q_proj.weight")
            .expect("hf-named attn_q should resolve");
        assert_eq!(shape, vec![4, 4]);
        assert_eq!(&out, &t2);

        // MTP weight stays out of remaining_keys, in mtp_keys.
        assert_eq!(loader.remaining_keys(), Vec::<String>::new());
        assert_eq!(loader.mtp_keys(), vec!["output_mtp_0.weight".to_string()]);

        // include_mtp(true): MTP weights become visible in
        // remaining_keys + drainable via take(), and `take_mtp`
        // works for explicit grabs without the flag.
        let mut loader2 = GgufLoader::from_file(tmp.path()).unwrap();
        loader2.include_mtp(true);
        let visible: std::collections::HashSet<String> =
            loader2.remaining_keys().into_iter().collect();
        assert!(visible.contains("token_embd.weight"));
        assert!(visible.contains("blk.0.attn_q.weight"));
        assert!(
            visible.contains("output_mtp_0.weight"),
            "MTP weight should be visible with include_mtp(true)"
        );
        let (mtp_data, mtp_shape) = loader2.take_mtp("output_mtp_0.weight").unwrap();
        assert_eq!(mtp_shape, vec![4, 2]);
        assert_eq!(mtp_data, t3);

        // include_mtp(false) — default — refuses MTP via take().
        let mut loader3 = GgufLoader::from_file(tmp.path()).unwrap();
        let err = loader3.take("output_mtp_0.weight").unwrap_err();
        let msg = format!("{err:#}");
        assert!(
            msg.contains("include_mtp(true)"),
            "expected MTP guard error, got: {msg}"
        );
    }

    #[test]
    fn missing_gguf_file_errors() {
        // .gguf is now a known extension → error comes from `open`,
        // not from the dispatcher.
        let r = load_from_path("/tmp/no-such-thing-rlx-gguf-test.gguf");
        assert!(r.is_err());
    }
}