Skip to main content

ferrum_quantization/gguf/
loader.rs

1//! `GgufLoader<B>`: implements `WeightLoader<B>` against a GGUF file.
2//!
3//! Bridges the model layer (which addresses weights by ferrum's HuggingFace-
4//! style names) to the on-disk GGUF format (llama.cpp's `blk.{i}.attn_q.weight`
5//! shorthand). Three responsibilities:
6//!
7//!   1. **Name translation** — delegates to `gguf::names::ferrum_to_gguf`
8//!   2. **Tensor materialisation** — uses Phase 1A's `GgufFile::read_tensor`
9//!      then dequant on CPU into `B::Buffer` for `load_tensor`, or wraps
10//!      the QTensor in `GgufLinear<B>` for `load_linear`.
11//!   3. **Fusion** — reproduces the `qkv_proj` / `gate_up_proj` shims the
12//!      model expects: q/k/v split tensors are concatenated row-wise into
13//!      a single fused weight before the Linear is built.
14//!
//! `load_tensor` always dequantises eagerly to fp32. `load_linear` hands
//! bias-free Q4_K / Q6_K weights to `QuantLinear` as raw block bytes so
//! they are not expanded to fp32, and falls back to eager dequantisation
//! for every other case; the public `WeightLoader<B>` API is the same on
//! both routes.
19
20use std::path::Path;
21use std::sync::Arc;
22
23use candle_core::Device;
24use ferrum_kernels::backend::{Backend, BackendQuantGguf, BackendQuantMarlin};
25use ferrum_types::{FerrumError, Result};
26
27use crate::config::QuantConfig;
28use crate::gguf::file::GgufFile;
29use crate::gguf::linear::GgufLinear;
30use crate::gguf::names::{ferrum_to_gguf, gate_up_split_parts, qkv_split_parts};
31use crate::loader::WeightLoader;
32use crate::traits::Linear;
33
/// Name of the environment variable that switches on fusion-path tracing.
const GGUF_LOAD_TRACE_ENV: &str = "FERRUM_GGUF_LOAD_TRACE";

/// Runtime switches sampled from the process environment when a loader is
/// constructed.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
struct GgufLoaderRuntimeConfig {
    /// `true` when `FERRUM_GGUF_LOAD_TRACE` is present in the environment.
    load_trace: bool,
}

impl GgufLoaderRuntimeConfig {
    /// Build the config from the current process environment.
    ///
    /// Uses `std::env::vars_os` rather than `std::env::vars`: the latter
    /// panics as soon as it yields a variable whose key or value is not
    /// valid Unicode, which would make loader construction depend on
    /// unrelated environment entries. Lossy conversion is safe here because
    /// a lossily-converted key can never compare equal to the ASCII flag
    /// name, and values are ignored.
    fn from_env() -> Self {
        Self::from_env_vars(std::env::vars_os().map(|(key, value)| {
            (
                key.to_string_lossy().into_owned(),
                value.to_string_lossy().into_owned(),
            )
        }))
    }

    /// Build the config from an explicit list of `(name, value)` pairs.
    ///
    /// The flag is presence-based: `load_trace` is `true` when any pair is
    /// named `FERRUM_GGUF_LOAD_TRACE`, whatever its value (including the
    /// empty string or `"0"`).
    fn from_env_vars<I, K, V>(vars: I) -> Self
    where
        I: IntoIterator<Item = (K, V)>,
        K: Into<String>,
        V: Into<String>,
    {
        Self {
            load_trace: vars
                .into_iter()
                .any(|(name, _value)| name.into() == GGUF_LOAD_TRACE_ENV),
        }
    }
}
59
60/// Backend-generic weight loader for GGUF files.
61///
62/// Build with [`GgufLoader::open`]. The underlying file stays mmap'd for
63/// the lifetime of the loader so per-tensor reads only do byte slicing,
64/// not file I/O.
65pub struct GgufLoader<B: Backend + BackendQuantGguf + BackendQuantMarlin> {
66    gguf: Arc<GgufFile>,
67    /// Decode device for `QTensor::dequantize`. We always use CPU here:
68    /// the dequant is followed by `B::from_slice`, which uploads to the
69    /// backend's preferred memory. Going through Metal/CUDA candle paths
70    /// would add a cross-allocator hop with no benefit (Phase 1D revisits).
71    decode_device: Device,
72    runtime_config: GgufLoaderRuntimeConfig,
73    _marker: std::marker::PhantomData<B>,
74}
75
76impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> GgufLoader<B> {
77    /// Open and parse a `.gguf` file. Tensor payloads stay on disk (mmap'd)
78    /// until each `load_tensor` / `load_linear` call.
79    pub fn open(path: impl AsRef<Path>) -> Result<Self> {
80        let gguf = GgufFile::open(path).map_err(candle_to_ferrum)?;
81        Ok(Self {
82            gguf: Arc::new(gguf),
83            decode_device: Device::Cpu,
84            runtime_config: GgufLoaderRuntimeConfig::from_env(),
85            _marker: std::marker::PhantomData,
86        })
87    }
88
89    /// Build from an already-opened [`GgufFile`] (test helper, also useful
90    /// when several loaders share the same mmap).
91    pub fn from_file(gguf: Arc<GgufFile>) -> Self {
92        Self {
93            gguf,
94            decode_device: Device::Cpu,
95            runtime_config: GgufLoaderRuntimeConfig::from_env(),
96            _marker: std::marker::PhantomData,
97        }
98    }
99
100    /// Direct access to the underlying file — exposes metadata + tensor
101    /// descriptor lookups for callers that need them (e.g. a config helper
102    /// that reads `general.architecture` and `<arch>.block_count`).
103    pub fn gguf(&self) -> &GgufFile {
104        &self.gguf
105    }
106
107    // ── Internals ────────────────────────────────────────────────────────
108
109    /// Look up a ferrum-named tensor in the GGUF, returning the GGUF tensor
110    /// name on success.
111    fn locate(&self, ferrum_name: &str) -> Result<String> {
112        let gguf_name = ferrum_to_gguf(ferrum_name).ok_or_else(|| {
113            FerrumError::model(format!(
114                "GgufLoader: unrecognised tensor name '{ferrum_name}' (no GGUF mapping)"
115            ))
116        })?;
117        if !self.gguf.has_tensor(&gguf_name) {
118            return Err(FerrumError::model(format!(
119                "GgufLoader: tensor '{ferrum_name}' (mapped to '{gguf_name}') not present in GGUF"
120            )));
121        }
122        Ok(gguf_name)
123    }
124
125    /// Read a quantized tensor and dequantize to fp32 row-major. Used by
126    /// both `load_tensor` (raw buffer) and the fusion path (concat sources).
127    fn read_dequant(&self, gguf_name: &str) -> Result<Vec<f32>> {
128        let qt = self
129            .gguf
130            .read_tensor(gguf_name, &self.decode_device)
131            .map_err(candle_to_ferrum)?;
132        let dense = qt
133            .dequantize(&self.decode_device)
134            .map_err(candle_to_ferrum)?;
135        let flat = dense.flatten_all().map_err(candle_to_ferrum)?;
136        flat.to_vec1::<f32>().map_err(candle_to_ferrum)
137    }
138
139    /// Look up a tensor's `[rows, cols]` (2-D) without reading the payload.
140    /// Errors if the tensor isn't 2-D — fusion needs row counts to compute
141    /// the combined output dim.
142    fn rows_cols(&self, gguf_name: &str) -> Result<(usize, usize)> {
143        let info = self
144            .gguf
145            .tensor_info(gguf_name)
146            .ok_or_else(|| FerrumError::model(format!("tensor info missing for '{gguf_name}'")))?;
147        let dims = info.shape.dims();
148        if dims.len() != 2 {
149            return Err(FerrumError::model(format!(
150                "expected 2-D tensor for '{gguf_name}', got rank {}",
151                dims.len()
152            )));
153        }
154        Ok((dims[0], dims[1]))
155    }
156
157    /// Build a fused `Linear<B>` by row-concatenating several sub-tensors.
158    /// All parts must share `cols` (in_features); rows (out_features) sum.
159    ///
160    /// Two paths:
161    ///   1. **Fast (quant-fused)** — every part is Q4_K with no bias. The
162    ///      raw super-block bytes are byte-concatenated and handed to
163    ///      `QuantLinear::from_gguf_bytes`, so weights stay quantised in
164    ///      backend memory.
165    ///   2. **Eager (dense-fused)** — fallback. Each part is dequanted to
166    ///      fp32 and concatenated; the result wraps a dense fp16 weight
167    ///      via `GgufLinear::from_dense_rows`.
168    ///
169    /// Why the dual path: an 8B Qwen3 has 36 layers × (qkv + gate_up) of
170    /// ~140M weights apiece — eager-fp32-fusing them inflates 5 GB on disk
171    /// to 25+ GB in RAM, defeating Q4_K_M entirely. The fast path only
172    /// works for Q4K-without-bias which is the vast majority of dense
173    /// transformers; bias-bearing fusions (rare) take the eager hit.
174    fn load_fused(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
175        if let Some(fast) = self.try_load_fused_q4k(parts)? {
176            if self.runtime_config.load_trace {
177                eprintln!("[gguf-load] {:?} → fused-Q4 (homogeneous)", parts);
178            }
179            return Ok(fast);
180        }
181        if let Some(multi) = self.try_load_fused_multi_quant(parts)? {
182            if self.runtime_config.load_trace {
183                eprintln!("[gguf-load] {:?} → MultiQuant (mixed dtype)", parts);
184            }
185            return Ok(multi);
186        }
187        if self.runtime_config.load_trace {
188            eprintln!("[gguf-load] {:?} → eager fp32 fallback ⚠", parts);
189        }
190        self.load_fused_eager(parts)
191    }
192
193    /// Multi-quant fused fast path: each part is a Q4_K or Q6_K tensor
194    /// with no bias. Parts may have **different** quant types (e.g.
195    /// Qwen3 qkv_proj where q+k are Q4_K but v is Q6_K). Builds a
196    /// `MetalQuantStore::Fused` (or whatever the backend's `Fused`
197    /// variant is) so each part stays compact in backend memory and
198    /// gemv dispatches per part with output offsets.
199    fn try_load_fused_multi_quant(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
200        let mut spec: Vec<(ferrum_kernels::backend::GgufQuantType, &[u8], usize)> = Vec::new();
201        let mut cols_check: Option<usize> = None;
202
203        for stem in parts {
204            let weight_name = format!("{stem}.weight");
205            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
206                FerrumError::model(format!(
207                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
208                ))
209            })?;
210            if !self.gguf.has_tensor(&gguf_name) {
211                return Err(FerrumError::model(format!(
212                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
213                )));
214            }
215
216            // Bias on a fused part disqualifies the whole multi-quant
217            // path; fall back to eager fusion which already handles bias.
218            let has_bias = ferrum_to_gguf(&format!("{stem}.bias"))
219                .map(|n| self.gguf.has_tensor(&n))
220                .unwrap_or(false);
221            if has_bias {
222                return Ok(None);
223            }
224
225            let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
226                FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
227            })?;
228            let kind = match info.ggml_dtype {
229                candle_core::quantized::GgmlDType::Q4K => {
230                    ferrum_kernels::backend::GgufQuantType::Q4K
231                }
232                candle_core::quantized::GgmlDType::Q6K => {
233                    ferrum_kernels::backend::GgufQuantType::Q6K
234                }
235                _ => return Ok(None), // unsupported quant in this part
236            };
237
238            let dims = info.shape.dims();
239            if dims.len() != 2 {
240                return Ok(None);
241            }
242            let (rows, cols) = (dims[0], dims[1]);
243            if cols % 256 != 0 {
244                return Ok(None);
245            }
246            match cols_check {
247                Some(c) if c != cols => {
248                    return Err(FerrumError::model(format!(
249                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
250                    )))
251                }
252                _ => cols_check = Some(cols),
253            }
254
255            // Slice the mmap directly. The slice's lifetime is tied to
256            // `&self.gguf`, which outlives this scope, so the backend
257            // can read the bytes safely without us owning a copy.
258            let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
259                FerrumError::model(format!(
260                    "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
261                ))
262            })?;
263            spec.push((kind, bytes, rows));
264        }
265
266        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
267        let parts_view: Vec<(_, &[u8], _)> = spec
268            .iter()
269            .map(|(kind, bytes, rows)| (*kind, *bytes, *rows))
270            .collect();
271        let quant = match crate::QuantLinear::<B>::from_gguf_fused(&parts_view, cols) {
272            Ok(q) => q,
273            Err(_) => return Ok(None), // backend doesn't support Fused
274        };
275        Ok(Some(Box::new(quant)))
276    }
277
278    /// Q4_K fast path for `load_fused`. Returns `Ok(None)` if any part
279    /// disqualifies (non-Q4K dtype, rank != 2, has bias, cols mismatch).
280    fn try_load_fused_q4k(&self, parts: &[String]) -> Result<Option<Box<dyn Linear<B>>>> {
281        let mut fused_bytes: Vec<u8> = Vec::new();
282        let mut total_rows = 0usize;
283        let mut cols_check: Option<usize> = None;
284
285        for stem in parts {
286            let weight_name = format!("{stem}.weight");
287            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
288                FerrumError::model(format!(
289                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
290                ))
291            })?;
292            if !self.gguf.has_tensor(&gguf_name) {
293                return Err(FerrumError::model(format!(
294                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
295                )));
296            }
297
298            // Disqualifier 1: bias on this part — can't byte-concat that
299            // into a single QuantLinear.
300            let bias_name = ferrum_to_gguf(&format!("{stem}.bias"))
301                .map(|n| self.gguf.has_tensor(&n))
302                .unwrap_or(false);
303            if bias_name {
304                return Ok(None);
305            }
306
307            let info = self.gguf.tensor_info(&gguf_name).ok_or_else(|| {
308                FerrumError::model(format!("tensor_info missing for '{gguf_name}'"))
309            })?;
310
311            // Disqualifier 2: not Q4K dtype.
312            if !matches!(info.ggml_dtype, candle_core::quantized::GgmlDType::Q4K) {
313                return Ok(None);
314            }
315
316            let dims = info.shape.dims();
317            if dims.len() != 2 {
318                return Ok(None);
319            }
320            let (rows, cols) = (dims[0], dims[1]);
321
322            // Disqualifier 3: cols not a multiple of 256 (Q4K super-block
323            // boundary) — should not happen for Q4K tensors, but guard
324            // anyway so byte-concat produces a valid block stream.
325            if cols % 256 != 0 {
326                return Ok(None);
327            }
328
329            match cols_check {
330                Some(c) if c != cols => {
331                    return Err(FerrumError::model(format!(
332                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
333                    )))
334                }
335                _ => cols_check = Some(cols),
336            }
337
338            // Read raw block bytes directly from the mmap (no candle
339            // QTensor intermediate copy). Fused tensors must still be
340            // byte-concatenated into a single buffer, so the fused
341            // payload itself remains a heap allocation — but it's
342            // a one-shot total ≪ MoE expert weights, so the
343            // consequence is negligible.
344            let bytes = self.gguf.tensor_byte_slice(&gguf_name).ok_or_else(|| {
345                FerrumError::model(format!(
346                    "GgufLoader: tensor_byte_slice failed for '{gguf_name}'"
347                ))
348            })?;
349            // Sanity: 144 bytes per super-block, super-blocks = rows * (cols / 256).
350            let expected = rows * (cols / 256) * 144;
351            debug_assert_eq!(
352                bytes.len(),
353                expected,
354                "Q4K byte count mismatch for '{gguf_name}': got {} expected {}",
355                bytes.len(),
356                expected
357            );
358
359            fused_bytes.extend_from_slice(bytes);
360            total_rows += rows;
361        }
362
363        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
364        let quant = crate::QuantLinear::<B>::from_gguf_bytes(
365            ferrum_kernels::backend::GgufQuantType::Q4K,
366            &fused_bytes,
367            total_rows,
368            cols,
369        )?;
370        Ok(Some(Box::new(quant)))
371    }
372
373    /// Eager (dequant-to-fp32 then concat) fusion. Used for non-Q4K parts
374    /// or parts with bias. See `load_fused` doc for the trade-off.
375    fn load_fused_eager(&self, parts: &[String]) -> Result<Box<dyn Linear<B>>> {
376        let mut fused: Vec<f32> = Vec::new();
377        let mut total_rows = 0usize;
378        let mut cols_check: Option<usize> = None;
379
380        for stem in parts {
381            let weight_name = format!("{stem}.weight");
382            let gguf_name = ferrum_to_gguf(&weight_name).ok_or_else(|| {
383                FerrumError::model(format!(
384                    "GgufLoader: fusion source '{weight_name}' has no GGUF mapping"
385                ))
386            })?;
387            if !self.gguf.has_tensor(&gguf_name) {
388                return Err(FerrumError::model(format!(
389                    "GgufLoader: fusion source '{weight_name}' (gguf '{gguf_name}') missing"
390                )));
391            }
392            let (rows, cols) = self.rows_cols(&gguf_name)?;
393            match cols_check {
394                Some(c) if c != cols => {
395                    return Err(FerrumError::model(format!(
396                        "GgufLoader: fusion in_features mismatch ({c} vs {cols} for '{stem}')"
397                    )))
398                }
399                _ => cols_check = Some(cols),
400            }
401            let data = self.read_dequant(&gguf_name)?;
402            debug_assert_eq!(data.len(), rows * cols);
403            fused.extend_from_slice(&data);
404            total_rows += rows;
405        }
406
407        let cols = cols_check.ok_or_else(|| FerrumError::model("fusion: no parts"))?;
408        Ok(Box::new(GgufLinear::<B>::from_dense_rows(
409            &fused, total_rows, cols,
410        )))
411    }
412}
413
414impl<B: Backend + BackendQuantGguf + BackendQuantMarlin> WeightLoader<B> for GgufLoader<B> {
415    fn load_tensor(&self, name: &str) -> Result<B::Buffer> {
416        let gguf_name = self.locate(name)?;
417        let raw = self.read_dequant(&gguf_name)?;
418        Ok(B::from_slice(&raw))
419    }
420
421    fn load_linear(&self, name: &str) -> Result<Box<dyn Linear<B>>> {
422        // 1) Direct path: <name>.weight exists as a single GGUF tensor.
423        if let Some(gguf_weight) = ferrum_to_gguf(&format!("{name}.weight")) {
424            if self.gguf.has_tensor(&gguf_weight) {
425                // Inspect the on-disk dtype before reading the payload.
426                // Q4_K_M (and future k-quant flavours) get the QuantLinear
427                // path that keeps weights quantised in backend memory;
428                // F16 / F32 / non-Q4-K dtypes fall through to GgufLinear's
429                // eager-dequant DenseLinear path.
430                let info = self.gguf.tensor_info(&gguf_weight).ok_or_else(|| {
431                    FerrumError::model(format!("tensor_info missing for '{gguf_weight}'"))
432                })?;
433                let dims = info.shape.dims();
434                if dims.len() != 2 {
435                    return Err(FerrumError::model(format!(
436                        "GgufLoader::load_linear '{name}': expected rank-2 weight, got rank {}",
437                        dims.len()
438                    )));
439                }
440                let (n_rows, n_cols) = (dims[0], dims[1]);
441
442                let quant_kind = match info.ggml_dtype {
443                    candle_core::quantized::GgmlDType::Q4K => {
444                        Some(ferrum_kernels::backend::GgufQuantType::Q4K)
445                    }
446                    candle_core::quantized::GgmlDType::Q6K => {
447                        Some(ferrum_kernels::backend::GgufQuantType::Q6K)
448                    }
449                    _ => None,
450                };
451                if let Some(kind) = quant_kind {
452                    // Read raw block bytes and hand to QuantLinear.
453                    // Bias on quantised projections is rare in GGUF
454                    // (Qwen2.5 attention biases land as F32), so we
455                    // currently take the bias path only when the bias
456                    // tensor is present AND the weight is non-quantised.
457                    // For quantised weights with bias, fall back to
458                    // eager dequant so Phase 1B's bias support keeps
459                    // working.
460                    let has_bias = ferrum_to_gguf(&format!("{name}.bias"))
461                        .map(|n| self.gguf.has_tensor(&n))
462                        .unwrap_or(false);
463                    if !has_bias {
464                        // Zero-copy: slice the mmap directly. The
465                        // backend's registry (`register_gguf_mmap`)
466                        // recognises the slice as belonging to the
467                        // shared file buffer and returns a `QuantStore`
468                        // that bind-references the big buffer with an
469                        // offset, instead of allocating a fresh device
470                        // copy. Falls back to copy if no registration
471                        // covers this slice.
472                        let bytes = self.gguf.tensor_byte_slice(&gguf_weight).ok_or_else(|| {
473                            FerrumError::model(format!(
474                                "GgufLoader: tensor_byte_slice failed for '{gguf_weight}'"
475                            ))
476                        })?;
477                        let quant =
478                            crate::QuantLinear::<B>::from_gguf_bytes(kind, bytes, n_rows, n_cols)?;
479                        return Ok(Box::new(quant));
480                    }
481                    // else fall through to eager-dequant bias path below
482                }
483
484                let qt = self
485                    .gguf
486                    .read_tensor(&gguf_weight, &self.decode_device)
487                    .map_err(candle_to_ferrum)?;
488                if let Some(gguf_bias) = ferrum_to_gguf(&format!("{name}.bias")) {
489                    if self.gguf.has_tensor(&gguf_bias) {
490                        let bqt = self
491                            .gguf
492                            .read_tensor(&gguf_bias, &self.decode_device)
493                            .map_err(candle_to_ferrum)?;
494                        let linear = GgufLinear::<B>::from_qtensor_with_bias(&qt, &bqt)
495                            .map_err(candle_to_ferrum)?;
496                        return Ok(Box::new(linear));
497                    }
498                }
499                let linear = GgufLinear::<B>::from_qtensor(&qt).map_err(candle_to_ferrum)?;
500                return Ok(Box::new(linear));
501            }
502        }
503
504        // 2) Fusion path: qkv_proj from q_proj/k_proj/v_proj
505        if let Some(layer_prefix) = name.strip_suffix("self_attn.qkv_proj") {
506            let parts = qkv_split_parts(layer_prefix);
507            return self.load_fused(&parts);
508        }
509        // 3) Fusion path: gate_up_proj from gate_proj/up_proj
510        if let Some(layer_prefix) = name.strip_suffix("mlp.gate_up_proj") {
511            let parts = gate_up_split_parts(layer_prefix);
512            return self.load_fused(&parts);
513        }
514
515        Err(FerrumError::model(format!(
516            "GgufLoader: could not load Linear '{name}' — no direct weight, no split components"
517        )))
518    }
519
520    fn has_tensor(&self, name: &str) -> bool {
521        match ferrum_to_gguf(name) {
522            Some(g) => self.gguf.has_tensor(&g),
523            None => false,
524        }
525    }
526
    /// Always returns `None`: this loader does not expose a
    /// [`QuantConfig`].
    fn quant_config(&self) -> Option<&QuantConfig> {
        // Phase 1C doesn't surface a QuantConfig — every tensor in a GGUF
        // declares its own dtype (`GgmlDType`) per descriptor, so the
        // model's existing branching on QuantConfig::method isn't useful
        // here. Phase 1D may add a derived config if downstream code grows
        // a need for it.
        None
    }
535}
536
537fn candle_to_ferrum(e: candle_core::Error) -> FerrumError {
538    FerrumError::model(format!("candle: {e}"))
539}
540
#[cfg(test)]
mod tests {
    use super::*;

    /// The trace flag depends only on the variable being present; its
    /// value (here the empty string) is irrelevant.
    #[test]
    fn gguf_loader_runtime_config_parses_load_trace_presence() {
        let with_flag = [(GGUF_LOAD_TRACE_ENV, ""), ("OTHER", "1")];
        assert!(GgufLoaderRuntimeConfig::from_env_vars(with_flag).load_trace);

        let without_flag = [("OTHER", "1")];
        assert!(!GgufLoaderRuntimeConfig::from_env_vars(without_flag).load_trace);
    }
}