Skip to main content

ferrum_quantization/gguf/
loader.rs

1//! `GgufLoader<B>`: implements `WeightLoader<B>` against a GGUF file.
2//!
3//! Bridges the model layer (which addresses weights by ferrum's HuggingFace-
4//! style names) to the on-disk GGUF format (llama.cpp's `blk.{i}.attn_q.weight`
5//! shorthand). Three responsibilities:
6//!
7//!   1. **Name translation** — delegates to `gguf::names::ferrum_to_gguf`
8//!   2. **Tensor materialisation** — uses Phase 1A's `GgufFile::read_tensor`
9//!      then dequant on CPU into `B::Buffer` for `load_tensor`, or wraps
10//!      the QTensor in `GgufLinear<B>` for `load_linear`.
11//!   3. **Fusion** — reproduces the `qkv_proj` / `gate_up_proj` shims the
12//!      model expects: q/k/v split tensors are concatenated row-wise into
13//!      a single fused weight before the Linear is built.
14//!
//! Q4_K / Q6_K weights that qualify (rank-2, no bias tensor) are handed to
//! `QuantLinear` as raw block bytes so they stay quantised in backend
//! memory; everything else (`load_tensor`, other dtypes, bias-bearing
//! projections) goes through eager dequant-to-fp32 on the CPU. The public
//! `WeightLoader<B>` API is the same either way.
19
20use std::path::Path;
21use std::sync::Arc;
22
23use candle_core::Device;
24use ferrum_kernels::backend::Backend;
25use ferrum_types::{FerrumError, Result};
26
27use crate::config::QuantConfig;
28use crate::gguf::file::GgufFile;
29use crate::gguf::linear::GgufLinear;
30use crate::gguf::names::{ferrum_to_gguf, gate_up_split_parts, qkv_split_parts};
31use crate::loader::WeightLoader;
32use crate::traits::Linear;
33
/// Backend-generic weight loader for GGUF files.
///
/// Build with [`GgufLoader::open`], or with [`GgufLoader::from_file`] when
/// an already-opened [`GgufFile`] should be shared. The underlying file
/// stays mmap'd for the lifetime of the loader so per-tensor reads only do
/// byte slicing, not file I/O.
pub struct GgufLoader<B: Backend> {
    /// Parsed GGUF file; held behind an `Arc` so several loaders can share it.
    gguf: Arc<GgufFile>,
    /// Decode device for `QTensor::dequantize`. We always use CPU here:
    /// the dequant is followed by `B::from_slice`, which uploads to the
    /// backend's preferred memory. Going through Metal/CUDA candle paths
    /// would add a cross-allocator hop with no benefit (Phase 1D revisits).
    decode_device: Device,
    /// Binds the loader to backend `B` without storing a value of that type.
    _marker: std::marker::PhantomData<B>,
}
48
49impl<B: Backend> GgufLoader<B> {
50    /// Open and parse a `.gguf` file. Tensor payloads stay on disk (mmap'd)
51    /// until each `load_tensor` / `load_linear` call.
52    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
53        let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
54        Ok(Self {
55            gguf: Arc::new(gguf),
56            decode_device: Device::Cpu,
57            _marker: std::marker::PhantomData,
58        })
59    }
60
61    /// Build from an already-opened [`GgufFile`] (test helper, also useful
62    /// when several loaders share the same mmap).
63    pub fn from_file(gguf: Arc<GgufFile>) -> Self {
64        Self {
65            gguf,
66            decode_device: Device::Cpu,
67            _marker: std::marker::PhantomData,
68        }
69    }
70
71    /// Direct access to the underlying file — exposes metadata + tensor
72    /// descriptor lookups for callers that need them (e.g. a config helper
73    /// that reads `general.architecture` and `<arch>.block_count`).
74    pub fn gguf(&self) -> &GgufFile {
75        &self.gguf
76    }
77
78    // ── Internals ────────────────────────────────────────────────────────
79
80    /// Look up a ferrum-named tensor in the GGUF, returning the GGUF tensor
81    /// name on success.
82    fn locate(&self, ferrum_name: &str) -> Result<String> {
83        let gguf_name = ferrum_to_gguf(ferrum_name).ok_or_else(|| {
84            FerrumError::model(format!(
85                "GgufLoader: unrecognised tensor name '{ferrum_name}' (no GGUF mapping)"
86            ))
87        })?;
88        if !self.gguf.has_tensor(&gguf_name) {
89            return Err(FerrumError::model(format!(
90                "GgufLoader: tensor '{ferrum_name}' (mapped to '{gguf_name}') not present in GGUF"
91            )));
92        }
93        Ok(gguf_name)
94    }
95
96    /// Read a quantized tensor and dequantize to fp32 row-major. Used by
97    /// both `load_tensor` (raw buffer) and the fusion path (concat sources).
98    fn read_dequant(&self, gguf_name: &str) -> Result<Vec<f32>> {
99        let qt = self
100            .gguf
101            .read_tensor(gguf_name, &self.decode_device)
102            .map_err(candle_to_ferrum)?;
103        let dense = qt
104            .dequantize(&self.decode_device)
105            .map_err(candle_to_ferrum)?;
106        let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
107        flat.to_vec1::<f32>().map_err(candle_to_ferrum)
108    }
109
110    /// Look up a tensor's `[rows, cols]` (2-D) without reading the payload.
111    /// Errors if the tensor isn't 2-D — fusion needs row counts to compute
112    /// the combined output dim.
113    fn rows_cols(&self, gguf_name: &str) -> Result<(usize, usize)> {
114        let info = self
115            .gguf
116            .tensor_info(gguf_name)
117            .ok_or_else(|| FerrumError::model(format!("tensor info missing for '{gguf_name}'")))?;
118        let dims = info.shape.dims();
119        if dims.len() != 2 {
120            return Err(FerrumError::model(format!(
121                "expected 2-D tensor for '{gguf_name}', got rank {}",
122                dims.len()
123            )));
124        }
125        Ok((dims[0], dims[1]))
126    }
127
128    /// Build a fused `Linear<B>` by row-concatenating several sub-tensors.
129    /// All parts must share `cols` (in_features); rows (out_features) sum.
130    ///
131    /// Two paths:
132    ///   1. **Fast (quant-fused)** — every part is Q4_K with no bias. The
133    ///      raw super-block bytes are byte-concatenated and handed to
134    ///      `QuantLinear::from_gguf_bytes`, so weights stay quantised in
135    ///      backend memory.
136    ///   2. **Eager (dense-fused)** — fallback. Each part is dequanted to
137    ///      fp32 and concatenated; the result wraps a dense fp16 weight
138    ///      via `GgufLinear::from_dense_rows`.
139    ///
140    /// Why the dual path: an 8B Qwen3 has 36 layers × (qkv + gate_up) of
141    /// ~140M weights apiece — eager-fp32-fusing them inflates 5 GB on disk
142    /// to 25+ GB in RAM, defeating Q4_K_M entirely. The fast path only
143    /// works for Q4K-without-bias which is the vast majority of dense
144    /// transformers; bias-bearing fusions (rare) take the eager hit.
145    fn load_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
146        if let Some(fast) = self.try_load_fused_q4k(parts)? {
147            if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
148                eprintln!("[gguf-load] {:?} → fused-Q4 (homogeneous)", parts);
149            }
150            return Ok(fast);
151        }
152        if let Some(multi) = self.try_load_fused_multi_quant(parts)? {
153            if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
154                eprintln!("[gguf-load] {:?} → MultiQuant (mixed dtype)", parts);
155            }
156            return Ok(multi);
157        }
158        if std::env::var("FERRUM_GGUF_LOAD_TRACE").is_ok() {
159            eprintln!("[gguf-load] {:?} → eager fp32 fallback ⚠", parts);
160        }
161        self.load_fused_eager(parts)
162    }
163
164    /// Multi-quant fused fast path: each part is a Q4_K or Q6_K tensor
165    /// with no bias. Parts may have **different** quant types (e.g.
166    /// Qwen3 qkv_proj where q+k are Q4_K but v is Q6_K). Builds a
167    /// `MetalQuantStore::Fused` (or whatever the backend's `Fused`
168    /// variant is) so each part stays compact in backend memory and
169    /// gemv dispatches per part with output offsets.
170    fn try_load_fused_multi_quant(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
171        let mut spec: Vec<(ferrum_kernels::backend::GgufQuantType, &[u8], usize)> = Vec::new();
172        let mut cols_check: Option<usize> = None;
173
174        for stem in parts {
175            let weight_name = format!("{stem}.weight");
176            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
177                FerrumError::model(format!(
178                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
179                ))
180            })?;
181            if !self.gguf.has_tensor(&gguf_name) {
182                return Err(FerrumError::model(format!(
183                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
184                )));
185            }
186
187            // Bias on a fused part disqualifies the whole multi-quant
188            // path; fall back to eager fusion which already handles bias.
189            let has_bias = ferrum_to_gguf(&format!("{stem}.bias"))
190                .map(|n| self.gguf.has_tensor(&n))
191                .unwrap_or(false);
192            if has_bias {
193                return Ok(None);
194            }
195
196            let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
197                FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
198            })?;
199            let kind = match info.ggml_dtype {
200                candle_core::quantized::GgmlDType::Q4K => {
201                    ferrum_kernels::backend::GgufQuantType::Q4K
202                }
203                candle_core::quantized::GgmlDType::Q6K => {
204                    ferrum_kernels::backend::GgufQuantType::Q6K
205                }
206                _ => return Ok(None), // unsupported quant in this part
207            };
208
209            let dims = info.shape.dims();
210            if dims.len() != 2 {
211                return Ok(None);
212            }
213            let (rows, cols) = (dims[0], dims[1]);
214            if cols % 256 != 0 {
215                return Ok(None);
216            }
217            match cols_check {
218                Some(c) if c != cols => {
219                    return Err(FerrumError::model(format!(
220                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
221                    )))
222                }
223                _ => cols_check = Some(cols),
224            }
225
226            // Slice the mmap directly. The slice's lifetime is tied to
227            // `&self.gguf`, which outlives this scope, so the backend
228            // can read the bytes safely without us owning a copy.
229            let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
230                FerrumError::model(format!(
231                    "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
232                ))
233            })?;
234            spec.push((kind, bytes, rows));
235        }
236
237        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
238        let parts_view: Vec<(_, &[u8], _)> = spec
239            .iter()
240            .map(|(kind, bytes, rows)| (*kind, *bytes, *rows))
241            .collect();
242        let quant = match crate::QuantLinear::<B>::from_gguf_fused(&parts_view, cols) {
243            Ok(q) => q,
244            Err(_) => return Ok(None), // backend doesn't support Fused
245        };
246        Ok(Some(Box::new(quant)))
247    }
248
249    /// Q4_K fast path for `load_fused`. Returns `Ok(None)` if any part
250    /// disqualifies (non-Q4K dtype, rank != 2, has bias, cols mismatch).
251    fn try_load_fused_q4k(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
252        let mut fused_bytes: Vec<u8> = Vec::new();
253        let mut total_rows = 0usize;
254        let mut cols_check: Option<usize> = None;
255
256        for stem in parts {
257            let weight_name = format!("{stem}.weight");
258            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
259                FerrumError::model(format!(
260                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
261                ))
262            })?;
263            if !self.gguf.has_tensor(&gguf_name) {
264                return Err(FerrumError::model(format!(
265                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
266                )));
267            }
268
269            // Disqualifier 1: bias on this part — can't byte-concat that
270            // into a single QuantLinear.
271            let bias_name = ferrum_to_gguf(&format!("{stem}.bias"))
272                .map(|n| self.gguf.has_tensor(&n))
273                .unwrap_or(false);
274            if bias_name {
275                return Ok(None);
276            }
277
278            let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
279                FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
280            })?;
281
282            // Disqualifier 2: not Q4K dtype.
283            if !matches!(info.ggml_dtype, candle_core::quantized::GgmlDType::Q4K) {
284                return Ok(None);
285            }
286
287            let dims = info.shape.dims();
288            if dims.len() != 2 {
289                return Ok(None);
290            }
291            let (rows, cols) = (dims[0], dims[1]);
292
293            // Disqualifier 3: cols not a multiple of 256 (Q4K super-block
294            // boundary) — should not happen for Q4K tensors, but guard
295            // anyway so byte-concat produces a valid block stream.
296            if cols % 256 != 0 {
297                return Ok(None);
298            }
299
300            match cols_check {
301                Some(c) if c != cols => {
302                    return Err(FerrumError::model(format!(
303                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
304                    )))
305                }
306                _ => cols_check = Some(cols),
307            }
308
309            // Read raw block bytes directly from the mmap (no candle
310            // QTensor intermediate copy). Fused tensors must still be
311            // byte-concatenated into a single buffer, so the fused
312            // payload itself remains a heap allocation — but it's
313            // a one-shot total ≪ MoE expert weights, so the
314            // consequence is negligible.
315            let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
316                FerrumError::model(format!(
317                    "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
318                ))
319            })?;
320            // Sanity: 144 bytes per super-block, super-blocks = rows * (cols / 256).
321            let expected = rows * (cols / 256) * 144;
322            debug_assert_eq!(
323                bytes.len(),
324                expected,
325                "Q4K byte count mismatch for '{gguf_name}': got {} expected {}",
326                bytes.len(),
327                expected
328            );
329
330            fused_bytes.extend_from_slice(bytes);
331            total_rows += rows;
332        }
333
334        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
335        let quant = crate::QuantLinear::<B>::from_gguf_bytes(
336            ferrum_kernels::backend::GgufQuantType::Q4K,
337            &fused_bytes,
338            total_rows,
339            cols,
340        )?;
341        Ok(Some(Box::new(quant)))
342    }
343
344    /// Eager (dequant-to-fp32 then concat) fusion. Used for non-Q4K parts
345    /// or parts with bias. See `load_fused` doc for the trade-off.
346    fn load_fused_eager(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
347        let mut fused: Vec<f32> = Vec::new();
348        let mut total_rows = 0usize;
349        let mut cols_check: Option<usize> = None;
350
351        for stem in parts {
352            let weight_name = format!("{stem}.weight");
353            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
354                FerrumError::model(format!(
355                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
356                ))
357            })?;
358            if !self.gguf.has_tensor(&gguf_name) {
359                return Err(FerrumError::model(format!(
360                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
361                )));
362            }
363            let (rows, cols) = self.rows_cols(&gguf_name)?;
364            match cols_check {
365                Some(c) if c != cols => {
366                    return Err(FerrumError::model(format!(
367                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
368                    )))
369                }
370                _ => cols_check = Some(cols),
371            }
372            let data = self.read_dequant(&gguf_name)?;
373            debug_assert_eq!(data.len(), rows * cols);
374            fused.extend_from_slice(&data);
375            total_rows += rows;
376        }
377
378        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
379        Ok(Box::new(GgufLinear::<B>::from_dense_rows(
380            &fused, total_rows, cols,
381        )))
382    }
383}
384
385impl<B: Backend> WeightLoader<B> for GgufLoader<B> {
386    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
387        let gguf_name = self.locate(name)?;
388        let raw = self.read_dequant(&gguf_name)?;
389        Ok(B::from_slice(&raw))
390    }
391
392    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
393        // 1) Direct path: <name>.weight exists as a single GGUF tensor.
394        if let Some(gguf_weight) = ferrum_to_gguf(&format!("{name}.weight")) {
395            if self.gguf.has_tensor(&gguf_weight) {
396                // Inspect the on-disk dtype before reading the payload.
397                // Q4_K_M (and future k-quant flavours) get the QuantLinear
398                // path that keeps weights quantised in backend memory;
399                // F16 / F32 / non-Q4-K dtypes fall through to GgufLinear's
400                // eager-dequant DenseLinear path.
401                let info = self.gguf.tensor_info(&gguf_weight).ok_or_else(|| {
402                    FerrumError::model(format!("tensor_info missing for '{gguf_weight}'"))
403                })?;
404                let dims = info.shape.dims();
405                if dims.len() != 2 {
406                    return Err(FerrumError::model(format!(
407                        "GgufLoader::load_linear '{name}': expected rank-2 weight, got rank {}",
408                        dims.len()
409                    )));
410                }
411                let (n_rows, n_cols) = (dims[0], dims[1]);
412
413                let quant_kind = match info.ggml_dtype {
414                    candle_core::quantized::GgmlDType::Q4K => {
415                        Some(ferrum_kernels::backend::GgufQuantType::Q4K)
416                    }
417                    candle_core::quantized::GgmlDType::Q6K => {
418                        Some(ferrum_kernels::backend::GgufQuantType::Q6K)
419                    }
420                    _ => None,
421                };
422                if let Some(kind) = quant_kind {
423                    // Read raw block bytes and hand to QuantLinear.
424                    // Bias on quantised projections is rare in GGUF
425                    // (Qwen2.5 attention biases land as F32), so we
426                    // currently take the bias path only when the bias
427                    // tensor is present AND the weight is non-quantised.
428                    // For quantised weights with bias, fall back to
429                    // eager dequant so Phase 1B's bias support keeps
430                    // working.
431                    let has_bias = ferrum_to_gguf(&format!("{name}.bias"))
432                        .map(|n| self.gguf.has_tensor(&n))
433                        .unwrap_or(false);
434                    if !has_bias {
435                        // Zero-copy: slice the mmap directly. The
436                        // backend's registry (`register_gguf_mmap`)
437                        // recognises the slice as belonging to the
438                        // shared file buffer and returns a `QuantStore`
439                        // that bind-references the big buffer with an
440                        // offset, instead of allocating a fresh device
441                        // copy. Falls back to copy if no registration
442                        // covers this slice.
443                        let bytes = self.gguf.tensor_byte_slice(&gguf_weight).ok_or_else(|| {
444                            FerrumError::model(format!(
445                                "GgufLoader: tensor_byte_slice failed for '{gguf_weight}'"
446                            ))
447                        })?;
448                        let quant =
449                            crate::QuantLinear::<B>::from_gguf_bytes(kind, bytes, n_rows, n_cols)?;
450                        return Ok(Box::new(quant));
451                    }
452                    // else fall through to eager-dequant bias path below
453                }
454
455                let qt = self
456                    .gguf
457                    .read_tensor(&gguf_weight, &self.decode_device)
458                    .map_err(candle_to_ferrum)?;
459                if let Some(gguf_bias) = ferrum_to_gguf(&format!("{name}.bias")) {
460                    if self.gguf.has_tensor(&gguf_bias) {
461                        let bqt = self
462                            .gguf
463                            .read_tensor(&gguf_bias, &self.decode_device)
464                            .map_err(candle_to_ferrum)?;
465                        let linear = GgufLinear::<B>::from_qtensor_with_bias(&qt, &bqt)
466                            .map_err(candle_to_ferrum)?;
467                        return Ok(Box::new(linear));
468                    }
469                }
470                let linear = GgufLinear::<B>::from_qtensor(&qt).map_err(candle_to_ferrum)?;
471                return Ok(Box::new(linear));
472            }
473        }
474
475        // 2) Fusion path: qkv_proj from q_proj/k_proj/v_proj
476        if let Some(layer_prefix) = name.strip_suffix("self_attn.qkv_proj") {
477            let parts = qkv_split_parts(layer_prefix);
478            return self.load_fused(&parts);
479        }
480        // 3) Fusion path: gate_up_proj from gate_proj/up_proj
481        if let Some(layer_prefix) = name.strip_suffix("mlp.gate_up_proj") {
482            let parts = gate_up_split_parts(layer_prefix);
483            return self.load_fused(&parts);
484        }
485
486        Err(FerrumError::model(format!(
487            "GgufLoader: could not load Linear '{name}' — no direct weight, no split components"
488        )))
489    }
490
491    fn has_tensor(&self, name: &str) -> bool {
492        match ferrum_to_gguf(name) {
493            Some(g) => self.gguf.has_tensor(&g),
494            None => false,
495        }
496    }
497
498    fn quant_config(&self) -> Option<&QuantConfig> {
499        // Phase 1C doesn't surface a QuantConfig — every tensor in a GGUF
500        // declares its own dtype (`GgmlDType`) per descriptor, so the
501        // model's existing branching on QuantConfig::method isn't useful
502        // here. Phase 1D may add a derived config if downstream code grows
503        // a need for it.
504        None
505    }
506}
507
508fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
509    FerrumError::model(format!("candle: {e}"))
510}