Skip to main content

rlx_gguf_convert/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12
13//! Convert tensors from external formats (safetensors, ONNX) into
14//! GGUF with per-tensor quantization. Designed to be called at first
15//! inference load: read the source file once, write a GGUF blob with
16//! a chosen quant scheme, then on subsequent loads dequant the GGUF
17//! directly — cutting both disk footprint and memory at load time
18//! for transformer weights (often ≥4× shrink at Q4_K_M).
19//!
20//! # Quick start
21//!
22//! ```ignore
23//! use rlx_gguf_convert::{Converter, Scheme};
24//!
25//! let report = Converter::from_safetensors("model.safetensors")?
26//!     .default_scheme(Scheme::Q4_K)
27//!     .skip_quant_for(|name, shape| {
28//!         // Tiny 1-D tensors (norms, biases) stay full-precision.
29//!         name.contains("norm") || name.contains("bias") || shape.len() < 2
30//!     })
31//!     .architecture("llama")
32//!     .write_gguf("model.q4_k.gguf")?;
33//! println!("wrote {} tensors, {:.2}× smaller",
34//!          report.tensors,
35//!          report.compression_ratio());
36//! # Ok::<(), anyhow::Error>(())
37//! ```
38//!
39//! # Real-weight benchmarks
40//!
//! Validated end-to-end against two production checkpoints. "Mean
//! cosine" is the cosine similarity between each quantized weight
//! tensor's source values and its dequantized [`Converter::write_gguf`]
//! output, averaged over all quantized tensors; non-quantized tensors
//! round-trip exactly and are excluded. Measured on an M2 mini,
//! release build.
46//!
47//! | Model              | Source size      | Scheme | Output | Shrink | Mean cosine | Wall  |
48//! |--------------------|------------------|--------|--------|-------:|------------:|------:|
49//! | Bio_ClinicalBERT   | 416 MB F32       | Q8_0   | 113 MB | 3.75×  |    0.999984 | 0.27s |
50//! | Bio_ClinicalBERT   | 416 MB F32       | Q6_K   |  86 MB | 4.85×  |    0.999815 | 0.22s |
51//! | Bio_ClinicalBERT   | 416 MB F32       | Q4_K   |  59 MB | 7.05×  |    0.996785 | 0.44s |
52//! | Bio_ClinicalBERT   | 416 MB F32       | Q4_0   |  59 MB | 7.05×  |    0.996169 | 0.44s |
53//! | Qwen3-TTS 0.6B     | 1.7 GB BF16      | Q4_K   | 491 MB | 3.55×  |    0.996712 | 3.7s  |
54//!
55//! [`ConvertReport::compression_ratio`] reports source-byte shrink
56//! (BF16 inputs naturally compress less than F32 inputs because
57//! they're already 2× smaller on disk).
58//!
59//! # Per-tensor schemes
60//!
61//! Three priority levels, applied in order:
62//!
63//! 1. Exact-name override — [`Converter::scheme_for_name`].
64//! 2. Predicate override — [`Converter::scheme_for`] returning
65//!    `Some(scheme)` to override or `None` to fall through.
66//! 3. Default — [`Converter::default_scheme`].
67//!
//! Tensors whose element count doesn't divide the chosen scheme's
//! block size fall back to F16. [`Converter::skip_quant_for`] is
//! consulted after both overrides and before the default: tensors it
//! matches stay at their source dtype ([`Scheme::F32`] /
//! [`Scheme::F16`] / [`Scheme::BF16`]; any other source dtype is
//! stored as F16).
72//!
73//! # Crate layout
74//!
75//! * [`Scheme`] / [`Converter`] / [`ConvertReport`] are the public API.
76//! * Source readers gate behind features:
77//!   * `safetensors` (default) — `.safetensors` files.
78//!   * `onnx` — ONNX initializer tensors via `rlx-onnx-import`.
79//! * The encoder side is shared with [`rlx_gguf`], so output
80//!   round-trips through [`rlx_gguf::GgufFile::dequant_f32`].
81
82use std::collections::HashMap;
83use std::path::{Path, PathBuf};
84
85use anyhow::{Context, Result, bail};
86
87pub use rlx_gguf::{GgmlType, MetaValue};
88
89mod source;
90pub use source::{NamedTensor, TensorReader};
91
92/// Quantization scheme to apply to a tensor when converting. Mirrors
93/// the [`GgmlType`] variants we have encoders for.
94///
95/// Variant naming follows the canonical GGUF convention (`Q4_K`,
96/// `Q8_0`, …) so it survives copy/paste from llama.cpp docs and CLI
97/// flags. Parse a name with [`Scheme::parse`]; map to the underlying
98/// [`GgmlType`] with [`Scheme::to_ggml`].
99///
100/// `Mixed`-style presets (mostly-Q4_K with a few critical projections
101/// at Q6_K, llama.cpp's `Q4_K_M`) are not first-class enum variants —
102/// express them with a per-tensor override via
103/// [`Converter::scheme_for`] / [`Converter::scheme_for_name`].
104///
105/// # Picking a scheme
106///
107/// | Scheme  | Bits / elem | When to use |
108/// |---------|-------------|-------------|
109/// | `F32`   | 32          | "Don't touch" — debugging, gold reference. |
110/// | `F16`/`BF16` | 16     | Lossless-ish; default fallback for shape mismatches. |
111/// | `Q8_0`  | 8.5         | Highest decode speed; ~0% accuracy loss. |
112/// | `Q6_K`  | 6.5         | Near-F16 quality at < 50% size. |
113/// | `Q5_K`  | 5.5         | "Best balance" for memory-constrained inference. |
114/// | `Q4_K`  | 4.5         | Standard 4-bit; ~7× shrink vs F32 source. |
115/// | `Q4_0`  | 4.5         | Legacy; faster decode kernels, slightly worse accuracy. |
116/// | `Q3_K`, `Q2_K` | 3.4 / 2.6 | Aggressive shrink; only for tolerant models. |
117#[allow(non_camel_case_types)]
118#[derive(Debug, Clone, Copy, PartialEq, Eq)]
119pub enum Scheme {
120    F32,
121    F16,
122    BF16,
123    Q8_0,
124    Q4_0,
125    Q4_1,
126    Q5_0,
127    Q5_1,
128    Q2_K,
129    Q3_K,
130    Q4_K,
131    Q5_K,
132    Q6_K,
133    Q8_K,
134}
135
136impl Scheme {
137    pub fn to_ggml(self) -> GgmlType {
138        match self {
139            Self::F32 => GgmlType::F32,
140            Self::F16 => GgmlType::F16,
141            Self::BF16 => GgmlType::BF16,
142            Self::Q8_0 => GgmlType::Q8_0,
143            Self::Q4_0 => GgmlType::Q4_0,
144            Self::Q4_1 => GgmlType::Q4_1,
145            Self::Q5_0 => GgmlType::Q5_0,
146            Self::Q5_1 => GgmlType::Q5_1,
147            Self::Q2_K => GgmlType::Q2K,
148            Self::Q3_K => GgmlType::Q3K,
149            Self::Q4_K => GgmlType::Q4K,
150            Self::Q5_K => GgmlType::Q5K,
151            Self::Q6_K => GgmlType::Q6K,
152            Self::Q8_K => GgmlType::Q8K,
153        }
154    }
155
156    /// Parse a scheme name (`"q4_k"`, `"f16"`, …). Case-insensitive.
157    pub fn parse(s: &str) -> Result<Self> {
158        Ok(match s.to_ascii_uppercase().as_str() {
159            "F32" => Self::F32,
160            "F16" => Self::F16,
161            "BF16" => Self::BF16,
162            "Q8_0" => Self::Q8_0,
163            "Q4_0" => Self::Q4_0,
164            "Q4_1" => Self::Q4_1,
165            "Q5_0" => Self::Q5_0,
166            "Q5_1" => Self::Q5_1,
167            "Q2_K" => Self::Q2_K,
168            "Q3_K" => Self::Q3_K,
169            "Q4_K" => Self::Q4_K,
170            "Q5_K" => Self::Q5_K,
171            "Q6_K" => Self::Q6_K,
172            "Q8_K" => Self::Q8_K,
173            other => bail!("unknown scheme {other}"),
174        })
175    }
176
177    /// The required element-count divisor for this scheme. For example,
178    /// Q4_K requires multiples of 256; Q8_0 requires multiples of 32.
179    pub fn block_size(self) -> usize {
180        match self {
181            Self::F32 | Self::F16 | Self::BF16 => 1,
182            Self::Q8_0 | Self::Q4_0 | Self::Q4_1 | Self::Q5_0 | Self::Q5_1 => 32,
183            Self::Q2_K | Self::Q3_K | Self::Q4_K | Self::Q5_K | Self::Q6_K | Self::Q8_K => 256,
184        }
185    }
186}
187
188/// Conversion summary returned by [`Converter::write_gguf`]. Use it
189/// to log compression ratios, generate a per-scheme histogram, or
190/// drive a re-convert pass with different scheme rules.
191///
192/// Byte counts are measured against the actual source-file layout
193/// (e.g. n×2 for BF16 source tensors), not f32-lifted equivalents —
194/// so the ratio matches what a user would see comparing the two
195/// files on disk.
196#[derive(Debug, Clone)]
197pub struct ConvertReport {
198    /// Number of tensors written to the output GGUF.
199    pub tensors: usize,
200    /// Total source-file bytes summed across all converted tensors.
201    pub input_bytes: usize,
202    /// Total bytes occupied by the encoded tensors (does **not**
203    /// include the GGUF header/metadata overhead — that's typically
204    /// well under 0.1% of the data segment).
205    pub output_bytes: usize,
206    /// Per-tensor scheme assignment in the order tensors were written.
207    /// Useful for verifying that overrides matched the right tensors.
208    pub schemes: Vec<(String, Scheme)>,
209    /// Where the GGUF file was written.
210    pub output_path: PathBuf,
211}
212
213impl ConvertReport {
214    /// `input_bytes / output_bytes` — the "shrink factor" most users
215    /// expect to see. Returns 0.0 for empty conversions to avoid a
216    /// divide-by-zero panic.
217    pub fn compression_ratio(&self) -> f64 {
218        if self.output_bytes == 0 {
219            0.0
220        } else {
221            self.input_bytes as f64 / self.output_bytes as f64
222        }
223    }
224}
225
226type SchemeFn = Box<dyn Fn(&str, &[usize]) -> Option<Scheme>>;
227type SkipFn = Box<dyn Fn(&str, &[usize]) -> bool>;
228
229/// Top-level conversion driver. Build with [`Converter::from_reader`]
230/// (or the `from_safetensors` / `from_onnx` convenience constructors
231/// behind their feature gates), set a default + per-tensor scheme,
232/// then [`Converter::write_gguf`].
233///
234/// # Builder ordering
235///
236/// The builder methods are independent; call them in any order. Only
237/// the final [`Converter::write_gguf`] call performs I/O. Predicates
238/// installed by [`Converter::scheme_for`] / [`Converter::skip_quant_for`]
239/// own their captured state (`Fn + 'static`), so the converter is
240/// `Send + 'static` itself.
241///
242/// # Example
243///
244/// ```ignore
245/// use rlx_gguf_convert::{Converter, Scheme};
246///
247/// let report = Converter::from_safetensors("model.safetensors")?
248///     .default_scheme(Scheme::Q4_K)
249///     // Promote the embed + output projection — these dominate
250///     // quality loss on small models.
251///     .scheme_for(|name, _| {
252///         if name.contains("embed") || name.ends_with("lm_head.weight") {
253///             Some(Scheme::Q6_K)
254///         } else {
255///             None
256///         }
257///     })
258///     // Keep biases / norms / 1-D tensors at native precision.
259///     .skip_quant_for(|name, shape| {
260///         shape.len() < 2 || name.contains("norm") || name.contains("bias")
261///     })
262///     .architecture("llama")
263///     .write_gguf("model.q4_k.gguf")?;
264/// # Ok::<(), anyhow::Error>(())
265/// ```
266pub struct Converter {
267    reader: Box<dyn TensorReader>,
268    default_scheme: Scheme,
269    per_tensor: HashMap<String, Scheme>,
270    scheme_fn: Option<SchemeFn>,
271    skip_fn: Option<SkipFn>,
272    arch: Option<String>,
273    meta: Vec<(String, MetaValue)>,
274}
275
276impl Converter {
277    /// Build a converter from any [`TensorReader`].
278    pub fn from_reader(reader: Box<dyn TensorReader>) -> Self {
279        Self {
280            reader,
281            default_scheme: Scheme::Q4_K,
282            per_tensor: HashMap::new(),
283            scheme_fn: None,
284            skip_fn: None,
285            arch: None,
286            meta: Vec::new(),
287        }
288    }
289
290    /// Convenience: open a `.safetensors` file at `path`. Requires the
291    /// `safetensors` feature (on by default).
292    #[cfg(feature = "safetensors")]
293    pub fn from_safetensors(path: impl AsRef<Path>) -> Result<Self> {
294        let reader = source::SafetensorsReader::open(path.as_ref())?;
295        Ok(Self::from_reader(Box::new(reader)))
296    }
297
298    /// Convenience: open a `.onnx` file at `path` and read its
299    /// initializer tensors. Requires the `onnx` feature.
300    #[cfg(feature = "onnx")]
301    pub fn from_onnx(path: impl AsRef<Path>) -> Result<Self> {
302        let reader = source::OnnxReader::open(path.as_ref())?;
303        Ok(Self::from_reader(Box::new(reader)))
304    }
305
306    /// Set the default scheme used when no override matches.
307    pub fn default_scheme(mut self, scheme: Scheme) -> Self {
308        self.default_scheme = scheme;
309        self
310    }
311
312    /// Override the scheme for a specific tensor name.
313    pub fn scheme_for_name(mut self, name: impl Into<String>, scheme: Scheme) -> Self {
314        self.per_tensor.insert(name.into(), scheme);
315        self
316    }
317
318    /// Callback for matching scheme overrides by name + shape. Returns
319    /// `Some(scheme)` to set, `None` to fall through to per-name +
320    /// default. Use for patterns like "every tensor whose name ends
321    /// with `.weight`".
322    pub fn scheme_for<F>(mut self, f: F) -> Self
323    where
324        F: Fn(&str, &[usize]) -> Option<Scheme> + 'static,
325    {
326        self.scheme_fn = Some(Box::new(f));
327        self
328    }
329
330    /// Callback to skip quantization entirely (leave the tensor at
331    /// its native dtype: F32 / F16 / BF16). Common pattern: skip 1-D
332    /// tensors (biases, norms) and any tensor whose element count
333    /// doesn't divide the chosen scheme's block size.
334    pub fn skip_quant_for<F>(mut self, f: F) -> Self
335    where
336        F: Fn(&str, &[usize]) -> bool + 'static,
337    {
338        self.skip_fn = Some(Box::new(f));
339        self
340    }
341
342    /// Set `general.architecture` metadata (e.g. `"llama"`, `"qwen3"`).
343    pub fn architecture(mut self, arch: impl Into<String>) -> Self {
344        self.arch = Some(arch.into());
345        self
346    }
347
348    /// Add a custom metadata key/value pair.
349    pub fn meta(mut self, key: impl Into<String>, value: MetaValue) -> Self {
350        self.meta.push((key.into(), value));
351        self
352    }
353
354    /// Pick the scheme for `(name, shape)` applying overrides in
355    /// priority order: `scheme_for_name` → `scheme_for` → default.
356    fn resolve_scheme(&self, name: &str, shape: &[usize], native: GgmlType) -> Scheme {
357        if let Some(s) = self.per_tensor.get(name) {
358            return *s;
359        }
360        if let Some(f) = self.scheme_fn.as_ref() {
361            if let Some(s) = f(name, shape) {
362                return s;
363            }
364        }
365        if let Some(f) = self.skip_fn.as_ref() {
366            if f(name, shape) {
367                return native_to_scheme(native);
368            }
369        }
370        let elems: usize = shape.iter().product();
371        // If the tensor's shape doesn't divide the chosen scheme's
372        // block size, fall back to F16 — much better than failing the
373        // entire convert. (Embeddings often have a head-aligned final
374        // dimension but bias rows of 1 element, for example.)
375        if !elems.is_multiple_of(self.default_scheme.block_size()) {
376            return Scheme::F16;
377        }
378        self.default_scheme
379    }
380
381    /// Run the conversion and write the output GGUF file.
382    ///
383    /// For each tensor in the source file:
384    /// 1. Read into f32 (lifting from whatever native dtype it was).
385    /// 2. Resolve a [`Scheme`] via the override stack (name → predicate
386    ///    → default), falling back to F16 if the element count doesn't
387    ///    divide the chosen scheme's block size.
388    /// 3. Encode with [`rlx_gguf::quantize`] and stream into the
389    ///    [`rlx_gguf::GgufWriter`].
390    ///
391    /// On success returns a [`ConvertReport`] describing the per-tensor
392    /// scheme assignment, byte counts, and the output path.
393    pub fn write_gguf(self, out: impl AsRef<Path>) -> Result<ConvertReport> {
394        let out_path = out.as_ref().to_path_buf();
395        let names = self.reader.names();
396        let mut writer = rlx_gguf::GgufWriter::new();
397        if let Some(arch) = &self.arch {
398            writer.set_arch(arch);
399        }
400        for (k, v) in &self.meta {
401            writer.set_meta(k.clone(), v.clone());
402        }
403        let mut input_bytes = 0usize;
404        let mut output_bytes = 0usize;
405        let mut schemes: Vec<(String, Scheme)> = Vec::with_capacity(names.len());
406        for name in names {
407            let NamedTensor {
408                name,
409                shape,
410                data,
411                native,
412                source_bytes,
413            } = self
414                .reader
415                .read_tensor(&name)
416                .with_context(|| format!("reading tensor {name}"))?;
417            input_bytes += source_bytes;
418            let scheme = self.resolve_scheme(&name, &shape, native);
419            let dtype = scheme.to_ggml();
420            let bytes = rlx_gguf::quantize(&data, dtype)
421                .with_context(|| format!("quantize tensor {name} → {scheme:?}"))?;
422            output_bytes += bytes.len();
423            writer.add_tensor_bytes(name.clone(), shape, dtype, bytes)?;
424            schemes.push((name, scheme));
425        }
426        writer.write_to_path(&out_path)?;
427        Ok(ConvertReport {
428            tensors: schemes.len(),
429            input_bytes,
430            output_bytes,
431            schemes,
432            output_path: out_path,
433        })
434    }
435}
436
437fn native_to_scheme(dtype: GgmlType) -> Scheme {
438    match dtype {
439        GgmlType::F32 => Scheme::F32,
440        GgmlType::F16 => Scheme::F16,
441        GgmlType::BF16 => Scheme::BF16,
442        // Default for any other native dtype is F16 — we just want to
443        // skip lossy quantization, not preserve some exotic input.
444        _ => Scheme::F16,
445    }
446}
447
448// ─── tests ────────────────────────────────────────────────────────
449
450#[cfg(test)]
451mod tests {
452    use super::*;
453
454    struct StubReader {
455        names: Vec<String>,
456        tensors: HashMap<String, (Vec<f32>, Vec<usize>, GgmlType)>,
457    }
458
459    impl TensorReader for StubReader {
460        fn names(&self) -> Vec<String> {
461            self.names.clone()
462        }
463        fn read_tensor(&self, name: &str) -> Result<NamedTensor> {
464            let (data, shape, native) = self
465                .tensors
466                .get(name)
467                .ok_or_else(|| anyhow::anyhow!("no tensor {name}"))?
468                .clone();
469            let source_bytes = data.len() * 4;
470            Ok(NamedTensor {
471                name: name.to_string(),
472                shape,
473                data,
474                native,
475                source_bytes,
476            })
477        }
478    }
479
480    #[test]
481    fn convert_stub_to_q4_k_roundtrips() {
482        // 2 super-blocks of 256 elements each.
483        let n = 512;
484        let data: Vec<f32> = (0..n).map(|i| (i as f32 - n as f32 / 2.0) * 0.01).collect();
485        let mut tensors = HashMap::new();
486        tensors.insert("w".to_string(), (data.clone(), vec![2, 256], GgmlType::F32));
487        tensors.insert(
488            "bias".to_string(),
489            (vec![0.5, -0.5], vec![2], GgmlType::F32),
490        );
491        let reader = StubReader {
492            names: vec!["w".into(), "bias".into()],
493            tensors,
494        };
495        let tmp = tempfile::NamedTempFile::new().unwrap();
496        let report = Converter::from_reader(Box::new(reader))
497            .default_scheme(Scheme::Q4_K)
498            .skip_quant_for(|_, shape| shape.len() < 2)
499            .architecture("test")
500            .write_gguf(tmp.path())
501            .unwrap();
502        assert_eq!(report.tensors, 2);
503        let parsed = rlx_gguf::GgufFile::from_path(tmp.path()).unwrap();
504        let (out, shape) = parsed.dequant_f32("w").unwrap();
505        assert_eq!(shape, vec![2, 256]);
506        let cos: f32 = {
507            let dot: f32 = data.iter().zip(&out).map(|(a, b)| a * b).sum();
508            let na: f32 = data.iter().map(|x| x * x).sum::<f32>().sqrt();
509            let nb: f32 = out.iter().map(|x| x * x).sum::<f32>().sqrt();
510            dot / (na * nb)
511        };
512        assert!(cos > 0.99, "Q4_K conversion cosine {cos}");
513        // bias was kept at native dtype (F32 here — input was F32 and
514        // the skip-quant predicate matched on shape.len() < 2).
515        assert_eq!(parsed.get("bias").unwrap().dtype, GgmlType::F32);
516    }
517
518    #[test]
519    fn scheme_parse_roundtrip() {
520        for s in [
521            Scheme::F32,
522            Scheme::F16,
523            Scheme::BF16,
524            Scheme::Q8_0,
525            Scheme::Q4_K,
526            Scheme::Q6_K,
527        ] {
528            let name = format!("{s:?}");
529            // Skip variants Rust prints with caps that don't match parse
530            // exactly (e.g. `Q4_K` is already canonical).
531            let parsed = Scheme::parse(&name).unwrap();
532            assert_eq!(parsed, s);
533        }
534    }
535}