Skip to main content

rlx_gguf/
lib.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! GGUF (GGML Universal Format) parser, dequantizer, **quantization
17//! encoder**, and **file writer**.
18//!
19//! Standalone: no `rlx-*` dependencies. Higher-level `WeightLoader` /
20//! HF name mapping lives in the separate model-builders repo (see root README).
21//!
22//! Supports GGUF v1, v2, v3 (the live formats). Tensor dtypes
23//! decoded today: F32, F16, BF16, Q8_0, Q4_0, Q4_1, Q5_0, Q5_1,
24//! and the full K-quant family Q2_K, Q3_K, Q4_K, Q5_K, Q6_K, Q8_K.
25//! The encoder side covers every dtype the decoder accepts — see
26//! [`quantize`] and [`writer::GgufWriter`]. Anything outside that set
27//! parses fine but errors on `dequant_f32` so callers know exactly
28//! which key is unreadable; extending is a one-arm match.
29//!
//! Endianness: little-endian assumed (the only flavor that ships in
//! practice). The GGUF spec permits big-endian files but defines no
//! header flag that identifies them; this parser does not support them.
33//!
34//! # Reading a GGUF file
35//!
36//! ```ignore
37//! use rlx_gguf::GgufFile;
38//!
39//! let f = GgufFile::from_path("model.gguf")?;
40//! let (data, shape) = f.dequant_f32("token_embd.weight")?;
41//! ```
42//!
43//! # Writing a GGUF file with mixed quant schemes
44//!
45//! ```ignore
46//! use rlx_gguf::{GgmlType, GgufWriter, MetaValue, quantize};
47//!
48//! let w_floats: Vec<f32> = /* … */;
49//! let bias_floats: Vec<f32> = /* … */;
50//!
51//! let mut w = GgufWriter::new();
52//! w.set_arch("llama");
53//! w.set_meta("general.name", MetaValue::String("my-model".into()));
54//!
//! // Big projection → 4-bit K-quant. Small bias → F16: it is tiny next
//! // to the weights, so keeping it near full precision costs almost nothing.
57//! w.add_tensor_bytes("w", vec![4096, 4096], GgmlType::Q4K,
58//!     quantize(&w_floats, GgmlType::Q4K)?)?;
59//! w.add_tensor_bytes("b", vec![4096], GgmlType::F16,
60//!     quantize(&bias_floats, GgmlType::F16)?)?;
61//! w.write_to_path("out.gguf")?;
62//! ```
63//!
64//! For end-to-end safetensors / ONNX → GGUF conversion with per-tensor
65//! scheme rules see the companion [`rlx-gguf-convert`] crate.
66//!
67//! [`rlx-gguf-convert`]: https://docs.rs/rlx-gguf-convert
68
69use std::collections::HashMap;
70use std::fs::File;
71use std::io::{Read, Seek, SeekFrom};
72use std::path::Path;
73
74use anyhow::{Context, Result, anyhow, bail};
75
76pub mod quantize;
77pub mod writer;
78pub use quantize::quantize;
79pub use writer::{GgufWriter, TensorPayload};
80
/// File magic: the ASCII bytes `GGUF` read as a little-endian `u32`.
pub const GGUF_MAGIC: u32 = u32::from_le_bytes(*b"GGUF");
/// Tensor-data alignment (bytes) used when `general.alignment` is absent.
pub const DEFAULT_ALIGNMENT: u64 = 32;
83
84// ─── GGML tensor dtype codes ──────────────────────────────────────
85//
86// Subset of upstream `ggml_type`. Codes are stable; adding new ones
87// is append-only.
88
/// Storage type of a GGUF tensor, as recorded in its tensor-info entry.
///
/// Discriminants are the on-disk `ggml_type` codes, so `ty as u32` is the
/// value written to / read from the file. Codes missing from this enum
/// (4, 5, 31–33, 36–38) are not represented. `Hash` is derived so a dtype
/// can key a `HashMap` / `HashSet` (e.g. per-dtype statistics or rules).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[repr(u32)]
pub enum GgmlType {
    F32 = 0,
    F16 = 1,
    Q4_0 = 2,
    Q4_1 = 3,
    Q5_0 = 6,
    Q5_1 = 7,
    Q8_0 = 8,
    Q8_1 = 9,
    // K-quants
    Q2K = 10,
    Q3K = 11,
    Q4K = 12,
    Q5K = 13,
    Q6K = 14,
    Q8K = 15,
    // I-quants
    IQ2XXS = 16,
    IQ2XS = 17,
    IQ3XXS = 18,
    IQ1S = 19,
    IQ4NL = 20,
    IQ3S = 21,
    IQ2S = 22,
    IQ4XS = 23,
    // Plain integer / float dtypes
    I8 = 24,
    I16 = 25,
    I32 = 26,
    I64 = 27,
    F64 = 28,
    IQ1M = 29,
    BF16 = 30,
    // Ternary / MX
    TQ1_0 = 34,
    TQ2_0 = 35,
    MXFP4 = 39,
    NVFP4 = 40,
    Q1_0 = 41,
}
131
132impl GgmlType {
133    pub fn from_u32(v: u32) -> Result<Self> {
134        Ok(match v {
135            0 => Self::F32,
136            1 => Self::F16,
137            2 => Self::Q4_0,
138            3 => Self::Q4_1,
139            6 => Self::Q5_0,
140            7 => Self::Q5_1,
141            8 => Self::Q8_0,
142            9 => Self::Q8_1,
143            10 => Self::Q2K,
144            11 => Self::Q3K,
145            12 => Self::Q4K,
146            13 => Self::Q5K,
147            14 => Self::Q6K,
148            15 => Self::Q8K,
149            16 => Self::IQ2XXS,
150            17 => Self::IQ2XS,
151            18 => Self::IQ3XXS,
152            19 => Self::IQ1S,
153            20 => Self::IQ4NL,
154            21 => Self::IQ3S,
155            22 => Self::IQ2S,
156            23 => Self::IQ4XS,
157            24 => Self::I8,
158            25 => Self::I16,
159            26 => Self::I32,
160            27 => Self::I64,
161            28 => Self::F64,
162            29 => Self::IQ1M,
163            30 => Self::BF16,
164            34 => Self::TQ1_0,
165            35 => Self::TQ2_0,
166            39 => Self::MXFP4,
167            40 => Self::NVFP4,
168            41 => Self::Q1_0,
169            other => bail!("unknown ggml type {other}"),
170        })
171    }
172}
173
174// ─── Metadata value ───────────────────────────────────────────────
175
/// One GGUF metadata value. Each variant corresponds to one GGUF value
/// type tag (noted per variant); `Array` holds the elements of a tag-9
/// array. Arrays read from a file never nest, because the reader rejects
/// an array element type of 9.
#[derive(Clone, Debug)]
pub enum MetaValue {
    /// Tag 0.
    U8(u8),
    /// Tag 1.
    I8(i8),
    /// Tag 2.
    U16(u16),
    /// Tag 3.
    I16(i16),
    /// Tag 4.
    U32(u32),
    /// Tag 5.
    I32(i32),
    /// Tag 10.
    U64(u64),
    /// Tag 11.
    I64(i64),
    /// Tag 6.
    F32(f32),
    /// Tag 12.
    F64(f64),
    /// Tag 7 (any non-zero byte reads as `true`).
    Bool(bool),
    /// Tag 8 (length-prefixed UTF-8).
    String(String),
    /// Tag 9.
    Array(Vec<MetaValue>),
}
192
193impl MetaValue {
194    pub fn as_u32(&self) -> Option<u32> {
195        match self {
196            Self::U32(v) => Some(*v),
197            Self::I32(v) if *v >= 0 => Some(*v as u32),
198            Self::U64(v) if *v <= u32::MAX as u64 => Some(*v as u32),
199            _ => None,
200        }
201    }
202    pub fn as_u64(&self) -> Option<u64> {
203        match self {
204            Self::U32(v) => Some(*v as u64),
205            Self::U64(v) => Some(*v),
206            Self::I64(v) if *v >= 0 => Some(*v as u64),
207            _ => None,
208        }
209    }
210    pub fn as_str(&self) -> Option<&str> {
211        match self {
212            Self::String(s) => Some(s.as_str()),
213            _ => None,
214        }
215    }
216}
217
218// ─── Parsed file ──────────────────────────────────────────────────
219
220#[derive(Debug, Clone)]
221pub struct GgufTensor {
222    pub name: String,
223    pub shape: Vec<usize>, // GGML order (innermost first); kept verbatim.
224    pub dtype: GgmlType,
225    /// Offset within the tensor data segment (relative to `data_start`,
226    /// not to the start of the file).
227    pub offset: u64,
228}
229
230impl GgufTensor {
231    pub fn n_elements(&self) -> usize {
232        self.shape.iter().product()
233    }
234}
235
236pub struct GgufFile {
237    pub version: u32,
238    pub alignment: u64,
239    pub metadata: HashMap<String, MetaValue>,
240    pub tensors: HashMap<String, GgufTensor>,
241    /// Raw tensor-data segment (`data_start` to EOF). Slurped into
242    /// memory — fine for embed-class models. Future mmap path slots
243    /// in here.
244    data: Vec<u8>,
245}
246
247impl GgufFile {
248    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
249        let path = path.as_ref();
250        let mut f = File::open(path).with_context(|| format!("opening {}", path.display()))?;
251        Self::from_reader(&mut f)
252    }
253
254    /// Merge a multi-part GGUF split (`split.count` > 1) into one in-memory file.
255    pub fn from_split_paths(paths: &[impl AsRef<Path>]) -> Result<Self> {
256        if paths.is_empty() {
257            bail!("from_split_paths: empty path list");
258        }
259        if paths.len() == 1 {
260            return Self::from_path(paths[0].as_ref());
261        }
262        let mut parts: Vec<Self> = paths
263            .iter()
264            .map(|p| Self::from_path(p.as_ref()))
265            .collect::<Result<_>>()?;
266        let mut merged = parts.remove(0);
267        for part in parts {
268            let base = merged.data.len() as u64;
269            for (name, mut t) in part.tensors {
270                t.offset = t.offset.saturating_add(base);
271                if merged.tensors.insert(name, t).is_some() {
272                    bail!("from_split_paths: duplicate tensor name in split merge");
273                }
274            }
275            merged.data.extend_from_slice(&part.data);
276        }
277        merged
278            .metadata
279            .insert("split.count".into(), MetaValue::U32(1));
280        merged.metadata.remove("split.no");
281        Ok(merged)
282    }
283
284    pub fn from_reader<R: Read + Seek>(r: &mut R) -> Result<Self> {
285        let magic = read_u32(r)?;
286        if magic != GGUF_MAGIC {
287            bail!("not a GGUF file (magic {magic:#x})");
288        }
289        let version = read_u32(r)?;
290        if !(1..=3).contains(&version) {
291            bail!("unsupported GGUF version {version}");
292        }
293
294        // v1 used u32 counts; v2/v3 use u64. Same field order.
295        let (tensor_count, kv_count) = if version == 1 {
296            (read_u32(r)? as u64, read_u32(r)? as u64)
297        } else {
298            (read_u64(r)?, read_u64(r)?)
299        };
300
301        let mut metadata = HashMap::with_capacity(kv_count as usize);
302        for _ in 0..kv_count {
303            let key = read_string(r, version)?;
304            let value = read_value(r, version)?;
305            metadata.insert(key, value);
306        }
307
308        let alignment = metadata
309            .get("general.alignment")
310            .and_then(MetaValue::as_u64)
311            .unwrap_or(DEFAULT_ALIGNMENT);
312
313        let mut tensors = HashMap::with_capacity(tensor_count as usize);
314        for _ in 0..tensor_count {
315            let name = read_string(r, version)?;
316            let n_dims = read_u32(r)?;
317            let mut shape = Vec::with_capacity(n_dims as usize);
318            for _ in 0..n_dims {
319                let d = if version == 1 {
320                    read_u32(r)? as u64
321                } else {
322                    read_u64(r)?
323                };
324                shape.push(d as usize);
325            }
326            let dtype_raw = read_u32(r)?;
327            let dtype =
328                GgmlType::from_u32(dtype_raw).with_context(|| format!("tensor {name}: dtype"))?;
329            let offset = read_u64(r)?;
330            tensors.insert(
331                name.clone(),
332                GgufTensor {
333                    name,
334                    shape,
335                    dtype,
336                    offset,
337                },
338            );
339        }
340
341        // Data segment starts at the next `alignment` boundary.
342        let pos = r.stream_position()?;
343        let pad = (alignment - (pos % alignment)) % alignment;
344        r.seek(SeekFrom::Current(pad as i64))?;
345
346        let mut data = Vec::new();
347        r.read_to_end(&mut data)?;
348
349        Ok(Self {
350            version,
351            alignment,
352            metadata,
353            tensors,
354            data,
355        })
356    }
357
358    pub fn keys(&self) -> impl Iterator<Item = &str> {
359        self.tensors.keys().map(|s| s.as_str())
360    }
361
362    pub fn get(&self, name: &str) -> Option<&GgufTensor> {
363        self.tensors.get(name)
364    }
365
366    /// Dequantize a tensor to f32. Shape is verbatim from the file
367    /// (GGML's innermost-first order); transpose / reorder is the
368    /// caller's job since conventions vary by model family.
369    pub fn dequant_f32(&self, name: &str) -> Result<(Vec<f32>, Vec<usize>)> {
370        let t = self
371            .tensors
372            .get(name)
373            .ok_or_else(|| anyhow!("tensor not found: {name}"))?;
374        let n = t.n_elements();
375        let bytes = self.tensor_bytes(t)?;
376        let data = match t.dtype {
377            GgmlType::F32 => dequant_f32_raw(bytes, n)?,
378            GgmlType::F16 => dequant_f16(bytes, n)?,
379            GgmlType::BF16 => dequant_bf16(bytes, n)?,
380            GgmlType::Q8_0 => dequant_q8_0(bytes, n)?,
381            GgmlType::Q4_0 => dequant_q4_0(bytes, n)?,
382            GgmlType::Q4_1 => dequant_q4_1(bytes, n)?,
383            GgmlType::Q5_0 => dequant_q5_0(bytes, n)?,
384            GgmlType::Q5_1 => dequant_q5_1(bytes, n)?,
385            GgmlType::Q4K => dequant_q4_k(bytes, n)?,
386            GgmlType::Q5K => dequant_q5_k(bytes, n)?,
387            GgmlType::Q6K => dequant_q6_k(bytes, n)?,
388            GgmlType::Q8K => dequant_q8_k(bytes, n)?,
389            GgmlType::Q2K => dequant_q2_k(bytes, n)?,
390            GgmlType::Q3K => dequant_q3_k(bytes, n)?,
391            other => bail!("dequant for {other:?} not implemented yet (tensor {name})"),
392        };
393        Ok((data, t.shape.clone()))
394    }
395
396    /// Slice the raw tensor bytes out of the data segment. Public so
397    /// callers writing custom kernels can pass quantized blocks
398    /// straight through.
399    pub fn tensor_bytes(&self, t: &GgufTensor) -> Result<&[u8]> {
400        let n = t.n_elements();
401        let nbytes = bytes_for(t.dtype, n)
402            .ok_or_else(|| anyhow!("element count {n} not aligned to block for {:?}", t.dtype))?;
403        let start = t.offset as usize;
404        let end = start
405            .checked_add(nbytes)
406            .ok_or_else(|| anyhow!("tensor {} byte range overflow", t.name))?;
407        if end > self.data.len() {
408            bail!(
409                "tensor {} extends past data segment ({end} > {})",
410                t.name,
411                self.data.len()
412            );
413        }
414        Ok(&self.data[start..end])
415    }
416}
417
418// ─── byte-count helpers ───────────────────────────────────────────
419
/// Legacy Q8_0 block size (32 elements).
pub const QK8_0: usize = 32;
/// Legacy Q4_0 block size (32 elements).
pub const QK4_0: usize = 32;
/// Legacy Q4_1 block size (32 elements).
const QK4_1: usize = 32;
/// Legacy Q5_0 block size (32 elements).
const QK5_0: usize = 32;
/// Legacy Q5_1 block size (32 elements).
const QK5_1: usize = 32;
/// Super-block size shared by every K-quant format (Q2_K through Q8_K):
/// 256 elements, as in llama.cpp's `ggml-quants.h`. A tensor stored in
/// any K-quant format must have an element count divisible by 256.
pub const QK_K: usize = 256;
/// Byte size of the packed scales+mins region in `block_q4_K` /
/// `block_q5_K`: 8 sub-blocks × 12 bits (6-bit scale + 6-bit min)
/// = 96 bits = 12 bytes, with the same layout in both formats. `bytes_for`
/// also uses it as the size of Q3_K's 12-byte scales field.
pub const K_SCALE_SIZE: usize = 12;
436
437/// Bytes a tensor of `n` elements occupies in storage for `dtype`.
438/// Returns `None` if `n` doesn't divide the scheme's block size.
439pub fn bytes_for_public(dtype: GgmlType, n: usize) -> Option<usize> {
440    bytes_for(dtype, n)
441}
442
443fn bytes_for(dtype: GgmlType, n: usize) -> Option<usize> {
444    let blk = |qk: usize, blk_bytes: usize| -> Option<usize> {
445        if !n.is_multiple_of(qk) {
446            return None;
447        }
448        Some((n / qk) * blk_bytes)
449    };
450    match dtype {
451        GgmlType::F32 => Some(n * 4),
452        GgmlType::F16 | GgmlType::BF16 => Some(n * 2),
453        GgmlType::Q8_0 => blk(QK8_0, 2 + QK8_0), // f16 d + 32×i8
454        GgmlType::Q4_0 => blk(QK4_0, 2 + QK4_0 / 2), // f16 d + 16×u8
455        GgmlType::Q4_1 => blk(QK4_1, 2 + 2 + QK4_1 / 2), // f16 d + f16 m + 16×u8
456        GgmlType::Q5_0 => blk(QK5_0, 2 + 4 + QK5_0 / 2), // f16 d + u32 qh + 16×u8
457        GgmlType::Q5_1 => blk(QK5_1, 2 + 2 + 4 + QK5_1 / 2), // f16 d + f16 m + u32 qh + 16×u8
458        // K-quants (super-block = 256 elements):
459        GgmlType::Q4K => blk(QK_K, 2 + 2 + K_SCALE_SIZE + QK_K / 2), // d + dmin + 12 scales + 128 quant
460        GgmlType::Q5K => blk(QK_K, 2 + 2 + K_SCALE_SIZE + QK_K / 8 + QK_K / 2), // + 32 high bits
461        GgmlType::Q6K => blk(QK_K, QK_K / 2 + QK_K / 4 + QK_K / 16 + 2), // ql + qh + scales(i8) + d
462        GgmlType::Q8K => blk(QK_K, 4 + QK_K + (QK_K / 16) * 2), // f32 d + 256 i8 + 16 i16 bsums
463        GgmlType::Q2K => blk(QK_K, 2 + 2 + QK_K / 16 + QK_K / 4), // d + dmin + 16 scales + 64 qs
464        GgmlType::Q3K => blk(QK_K, 2 + K_SCALE_SIZE + QK_K / 8 + QK_K / 4), // d + 12 scales + 32 hmask + 64 qs
465        // Anything else: not yet supported. dequant_f32 will reject
466        // these too; tensor_bytes returns None to stay consistent.
467        _ => None,
468    }
469}
470
471// ─── reader helpers ───────────────────────────────────────────────
472
473fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
474    let mut b = [0u8; 1];
475    r.read_exact(&mut b)?;
476    Ok(b[0])
477}
478fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
479    Ok(read_u8(r)? as i8)
480}
481fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
482    let mut b = [0u8; 2];
483    r.read_exact(&mut b)?;
484    Ok(u16::from_le_bytes(b))
485}
486fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
487    let mut b = [0u8; 2];
488    r.read_exact(&mut b)?;
489    Ok(i16::from_le_bytes(b))
490}
491fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
492    let mut b = [0u8; 4];
493    r.read_exact(&mut b)?;
494    Ok(u32::from_le_bytes(b))
495}
496fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
497    let mut b = [0u8; 4];
498    r.read_exact(&mut b)?;
499    Ok(i32::from_le_bytes(b))
500}
501fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
502    let mut b = [0u8; 8];
503    r.read_exact(&mut b)?;
504    Ok(u64::from_le_bytes(b))
505}
506fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
507    let mut b = [0u8; 8];
508    r.read_exact(&mut b)?;
509    Ok(i64::from_le_bytes(b))
510}
511fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
512    let mut b = [0u8; 4];
513    r.read_exact(&mut b)?;
514    Ok(f32::from_le_bytes(b))
515}
516fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
517    let mut b = [0u8; 8];
518    r.read_exact(&mut b)?;
519    Ok(f64::from_le_bytes(b))
520}
521fn read_bool<R: Read>(r: &mut R) -> Result<bool> {
522    Ok(read_u8(r)? != 0)
523}
524
525fn read_string<R: Read>(r: &mut R, version: u32) -> Result<String> {
526    let len = if version == 1 {
527        read_u32(r)? as u64
528    } else {
529        read_u64(r)?
530    };
531    let mut buf = vec![0u8; len as usize];
532    r.read_exact(&mut buf)?;
533    String::from_utf8(buf).map_err(|e| anyhow!("non-UTF8 string: {e}"))
534}
535
536fn read_value<R: Read + Seek>(r: &mut R, version: u32) -> Result<MetaValue> {
537    let ty = read_u32(r)?;
538    Ok(match ty {
539        0 => MetaValue::U8(read_u8(r)?),
540        1 => MetaValue::I8(read_i8(r)?),
541        2 => MetaValue::U16(read_u16(r)?),
542        3 => MetaValue::I16(read_i16(r)?),
543        4 => MetaValue::U32(read_u32(r)?),
544        5 => MetaValue::I32(read_i32(r)?),
545        6 => MetaValue::F32(read_f32(r)?),
546        7 => MetaValue::Bool(read_bool(r)?),
547        8 => MetaValue::String(read_string(r, version)?),
548        9 => {
549            let inner_ty = read_u32(r)?;
550            let len = if version == 1 {
551                read_u32(r)? as u64
552            } else {
553                read_u64(r)?
554            };
555            let mut out = Vec::with_capacity(len as usize);
556            for _ in 0..len {
557                out.push(read_scalar(r, inner_ty, version)?);
558            }
559            MetaValue::Array(out)
560        }
561        10 => MetaValue::U64(read_u64(r)?),
562        11 => MetaValue::I64(read_i64(r)?),
563        12 => MetaValue::F64(read_f64(r)?),
564        other => bail!("unknown metadata value type {other}"),
565    })
566}
567
568fn read_scalar<R: Read + Seek>(r: &mut R, ty: u32, version: u32) -> Result<MetaValue> {
569    // Arrays don't nest (per spec). We re-implement primitive reads
570    // here rather than calling read_value because that one expects a
571    // leading type tag.
572    Ok(match ty {
573        0 => MetaValue::U8(read_u8(r)?),
574        1 => MetaValue::I8(read_i8(r)?),
575        2 => MetaValue::U16(read_u16(r)?),
576        3 => MetaValue::I16(read_i16(r)?),
577        4 => MetaValue::U32(read_u32(r)?),
578        5 => MetaValue::I32(read_i32(r)?),
579        6 => MetaValue::F32(read_f32(r)?),
580        7 => MetaValue::Bool(read_bool(r)?),
581        8 => MetaValue::String(read_string(r, version)?),
582        10 => MetaValue::U64(read_u64(r)?),
583        11 => MetaValue::I64(read_i64(r)?),
584        12 => MetaValue::F64(read_f64(r)?),
585        9 => bail!("nested arrays not allowed in GGUF metadata"),
586        other => bail!("unknown array element type {other}"),
587    })
588}
589
590// ─── dequant kernels ──────────────────────────────────────────────
591//
592// Reference formulas mirror llama.cpp's `dequantize_row_*` in
593// `ggml-quants.c`. Naive on purpose — runs once at load, not hot.
594
595fn dequant_f32_raw(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
596    if bytes.len() != n * 4 {
597        bail!("F32: {} bytes for {n} elements", bytes.len());
598    }
599    let f: &[f32] = bytemuck::cast_slice(bytes);
600    Ok(f.to_vec())
601}
602
603fn dequant_f16(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
604    if bytes.len() != n * 2 {
605        bail!("F16: {} bytes for {n} elements", bytes.len());
606    }
607    let h: &[half::f16] = bytemuck::cast_slice(bytes);
608    Ok(h.iter().map(|x| x.to_f32()).collect())
609}
610
611fn dequant_bf16(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
612    if bytes.len() != n * 2 {
613        bail!("BF16: {} bytes for {n} elements", bytes.len());
614    }
615    let h: &[half::bf16] = bytemuck::cast_slice(bytes);
616    Ok(h.iter().map(|x| x.to_f32()).collect())
617}
618
619fn read_f16_le(b: &[u8]) -> f32 {
620    half::f16::from_le_bytes([b[0], b[1]]).to_f32()
621}
622
623/// Dequant one Q8_0 block (`2 + QK8_0` bytes → `QK8_0` f32 values).
624pub fn dequant_q8_0_block(block: &[u8], out: &mut [f32]) {
625    assert!(block.len() >= 2 + QK8_0 && out.len() >= QK8_0);
626    let d = read_f16_le(&block[..2]);
627    let qs = &block[2..2 + QK8_0];
628    for (o, &q) in out.iter_mut().zip(qs.iter()).take(QK8_0) {
629        *o = d * (q as i8) as f32;
630    }
631}
632
633/// Dequant one Q4_0 block (`2 + QK4_0/2` bytes → `QK4_0` f32 values).
634pub fn dequant_q4_0_block(block: &[u8], out: &mut [f32]) {
635    assert!(block.len() >= 2 + QK4_0 / 2 && out.len() >= QK4_0);
636    let d = read_f16_le(&block[..2]);
637    let qs = &block[2..2 + QK4_0 / 2];
638    // Layout: low nibbles → first half of block, high nibbles → second half.
639    for j in 0..QK4_0 / 2 {
640        let v0 = (qs[j] & 0x0F) as i32 - 8;
641        out[j] = d * v0 as f32;
642    }
643    for j in 0..QK4_0 / 2 {
644        let v1 = (qs[j] >> 4) as i32 - 8;
645        out[QK4_0 / 2 + j] = d * v1 as f32;
646    }
647}
648
649/// Full-tensor Q8_0 dequant (element count must be a multiple of [`QK8_0`]).
650pub fn dequant_q8_0(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
651    if !n.is_multiple_of(QK8_0) {
652        bail!("Q8_0: n={n} not divisible by {QK8_0}");
653    }
654    let nb = n / QK8_0;
655    let blk = 2 + QK8_0;
656    if bytes.len() != nb * blk {
657        bail!("Q8_0: bad byte count");
658    }
659    let mut out = vec![0f32; n];
660    for i in 0..nb {
661        let off = i * blk;
662        dequant_q8_0_block(&bytes[off..off + blk], &mut out[i * QK8_0..(i + 1) * QK8_0]);
663    }
664    Ok(out)
665}
666
667/// Full-tensor Q4_0 dequant (element count must be a multiple of [`QK4_0`]).
668pub fn dequant_q4_0(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
669    if !n.is_multiple_of(QK4_0) {
670        bail!("Q4_0: n={n} not divisible by {QK4_0}");
671    }
672    let nb = n / QK4_0;
673    let blk = 2 + QK4_0 / 2;
674    if bytes.len() != nb * blk {
675        bail!("Q4_0: bad byte count");
676    }
677    let mut out = vec![0f32; n];
678    for i in 0..nb {
679        let off = i * blk;
680        dequant_q4_0_block(&bytes[off..off + blk], &mut out[i * QK4_0..(i + 1) * QK4_0]);
681    }
682    Ok(out)
683}
684
685fn dequant_q4_1(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
686    if !n.is_multiple_of(QK4_1) {
687        bail!("Q4_1: n={n} not divisible by {QK4_1}");
688    }
689    let nb = n / QK4_1;
690    let blk = 2 + 2 + QK4_1 / 2;
691    if bytes.len() != nb * blk {
692        bail!("Q4_1: bad byte count");
693    }
694    let mut out = Vec::with_capacity(n);
695    for i in 0..nb {
696        let off = i * blk;
697        let d = read_f16_le(&bytes[off..off + 2]);
698        let m = read_f16_le(&bytes[off + 2..off + 4]);
699        let qs = &bytes[off + 4..off + 4 + QK4_1 / 2];
700        for j in 0..QK4_1 / 2 {
701            let v0 = (qs[j] & 0x0F) as f32;
702            out.push(d * v0 + m);
703        }
704        for j in 0..QK4_1 / 2 {
705            let v1 = (qs[j] >> 4) as f32;
706            out.push(d * v1 + m);
707        }
708    }
709    Ok(out)
710}
711
712fn dequant_q5_0(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
713    if !n.is_multiple_of(QK5_0) {
714        bail!("Q5_0: n={n} not divisible by {QK5_0}");
715    }
716    let nb = n / QK5_0;
717    let blk = 2 + 4 + QK5_0 / 2;
718    if bytes.len() != nb * blk {
719        bail!("Q5_0: bad byte count");
720    }
721    let mut out = Vec::with_capacity(n);
722    for i in 0..nb {
723        let off = i * blk;
724        let d = read_f16_le(&bytes[off..off + 2]);
725        let qh = u32::from_le_bytes([
726            bytes[off + 2],
727            bytes[off + 3],
728            bytes[off + 4],
729            bytes[off + 5],
730        ]);
731        let qs = &bytes[off + 6..off + 6 + QK5_0 / 2];
732        for j in 0..QK5_0 / 2 {
733            let xh0 = (((qh >> j) & 0x01) as u8) << 4;
734            let v0 = ((qs[j] & 0x0F) | xh0) as i32 - 16;
735            out.push(d * v0 as f32);
736        }
737        for j in 0..QK5_0 / 2 {
738            let xh1 = (((qh >> (j + 16)) & 0x01) as u8) << 4;
739            let v1 = ((qs[j] >> 4) | xh1) as i32 - 16;
740            out.push(d * v1 as f32);
741        }
742    }
743    Ok(out)
744}
745
746fn dequant_q5_1(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
747    if !n.is_multiple_of(QK5_1) {
748        bail!("Q5_1: n={n} not divisible by {QK5_1}");
749    }
750    let nb = n / QK5_1;
751    let blk = 2 + 2 + 4 + QK5_1 / 2;
752    if bytes.len() != nb * blk {
753        bail!("Q5_1: bad byte count");
754    }
755    let mut out = Vec::with_capacity(n);
756    for i in 0..nb {
757        let off = i * blk;
758        let d = read_f16_le(&bytes[off..off + 2]);
759        let m = read_f16_le(&bytes[off + 2..off + 4]);
760        let qh = u32::from_le_bytes([
761            bytes[off + 4],
762            bytes[off + 5],
763            bytes[off + 6],
764            bytes[off + 7],
765        ]);
766        let qs = &bytes[off + 8..off + 8 + QK5_1 / 2];
767        for j in 0..QK5_1 / 2 {
768            let xh0 = (((qh >> j) & 0x01) as u8) << 4;
769            let v0 = ((qs[j] & 0x0F) | xh0) as f32;
770            out.push(d * v0 + m);
771        }
772        for j in 0..QK5_1 / 2 {
773            let xh1 = (((qh >> (j + 16)) & 0x01) as u8) << 4;
774            let v1 = ((qs[j] >> 4) | xh1) as f32;
775            out.push(d * v1 + m);
776        }
777    }
778    Ok(out)
779}
780
781// ─── K-quants ─────────────────────────────────────────────────────
782//
// Every K-quant format uses a 256-element (`QK_K`) super-block. Q4_K /
// Q5_K divide it into 8 sub-blocks of 32 elements and pack a 6-bit scale
// and a 6-bit min for each sub-block into the shared 12-byte `scales`
// region; Q2_K and Q6_K instead use 16 groups of 16 elements. The
787// `get_scale_min_k4` helper mirrors the bit-interleaving used in
788// llama.cpp's reference decoder (`ggml-quants.c`):
789//
790//   j < 4:  scale = q[j]   & 0x3F                  min = q[j+4]  & 0x3F
791//   j >= 4: scale = (q[j+4]   & 0x0F) | ((q[j-4] >> 6) << 4)
792//           min   = (q[j+4]   >> 4)   | ((q[j  ] >> 6) << 4)
/// Unpack the 6-bit (scale, min) pair of sub-block `j` (0..8) from the
/// 12-byte K-quant `scales` region `q`.
///
/// Sub-blocks 0–3 keep their values in the low 6 bits of bytes `j` and
/// `j + 4`. Sub-blocks 4–7 combine a nibble of byte `j + 4` with the top
/// two bits of byte `j - 4` (scale) or byte `j` (min).
fn get_scale_min_k4(j: usize, q: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (q[j] & 0x3F, q[j + 4] & 0x3F),
        _ => {
            let scale = (q[j + 4] & 0x0F) | ((q[j - 4] >> 6) << 4);
            let min = (q[j + 4] >> 4) | ((q[j] >> 6) << 4);
            (scale, min)
        }
    }
}
802
803/// Dequantize one Q4_K super-block (144 bytes) into `out` (256 f32s).
804pub fn dequant_q4_k_block(block: &[u8], out: &mut [f32; QK_K]) {
805    let d = read_f16_le(&block[0..2]);
806    let dmin = read_f16_le(&block[2..4]);
807    let scales = &block[4..4 + K_SCALE_SIZE];
808    let qs = &block[4 + K_SCALE_SIZE..];
809    let mut is = 0usize;
810    let mut out_i = 0usize;
811    for j in (0..8).step_by(2) {
812        let (sc0, m0) = get_scale_min_k4(j, scales);
813        let (sc1, m1) = get_scale_min_k4(j + 1, scales);
814        let d0 = d * sc0 as f32;
815        let m0f = dmin * m0 as f32;
816        let d1 = d * sc1 as f32;
817        let m1f = dmin * m1 as f32;
818        for l in 0..32 {
819            let q = qs[is + l];
820            out[out_i] = d0 * (q & 0x0F) as f32 - m0f;
821            out_i += 1;
822        }
823        for l in 0..32 {
824            let q = qs[is + l];
825            out[out_i] = d1 * (q >> 4) as f32 - m1f;
826            out_i += 1;
827        }
828        is += 32;
829    }
830}
831
832/// Dequantize one Q5_K super-block (176 bytes) into `out`.
833pub fn dequant_q5_k_block(block: &[u8], out: &mut [f32; QK_K]) {
834    let d = read_f16_le(&block[0..2]);
835    let dmin = read_f16_le(&block[2..4]);
836    let scales = &block[4..4 + K_SCALE_SIZE];
837    let qh = &block[4 + K_SCALE_SIZE..4 + K_SCALE_SIZE + QK_K / 8];
838    let qs = &block[4 + K_SCALE_SIZE + QK_K / 8..];
839    let mut is = 0usize;
840    let mut out_i = 0usize;
841    let mut u1: u8 = 1;
842    let mut u2: u8 = 2;
843    for j in (0..8).step_by(2) {
844        let (sc0, m0) = get_scale_min_k4(j, scales);
845        let (sc1, m1) = get_scale_min_k4(j + 1, scales);
846        let d0 = d * sc0 as f32;
847        let m0f = dmin * m0 as f32;
848        let d1 = d * sc1 as f32;
849        let m1f = dmin * m1 as f32;
850        for l in 0..32 {
851            let lo = qs[is + l] & 0x0F;
852            let hi = if qh[l] & u1 != 0 { 16 } else { 0 };
853            out[out_i] = d0 * (lo + hi) as f32 - m0f;
854            out_i += 1;
855        }
856        for l in 0..32 {
857            let lo = qs[is + l] >> 4;
858            let hi = if qh[l] & u2 != 0 { 16 } else { 0 };
859            out[out_i] = d1 * (lo + hi) as f32 - m1f;
860            out_i += 1;
861        }
862        is += 32;
863        u1 <<= 2;
864        u2 <<= 2;
865    }
866}
867
868/// Dequantize one Q6_K super-block (210 bytes) into `out`.
869pub fn dequant_q6_k_block(block: &[u8], out: &mut [f32; QK_K]) {
870    let ql_len = QK_K / 2;
871    let qh_len = QK_K / 4;
872    let sc_len = QK_K / 16;
873    let ql = &block[0..ql_len];
874    let qh = &block[ql_len..ql_len + qh_len];
875    let sc = &block[ql_len + qh_len..ql_len + qh_len + sc_len];
876    let d = read_f16_le(&block[ql_len + qh_len + sc_len..]);
877    for h in 0..2 {
878        let dst_base = h * 128;
879        let ql_off = h * 64;
880        let qh_off_h = h * 32;
881        let sc_off = h * 8;
882        for l in 0..32 {
883            let is = l / 16;
884            let qh_b = qh[qh_off_h + l];
885            let q1 = ((ql[ql_off + l] & 0x0F) | ((qh_b & 3) << 4)) as i32 - 32;
886            let q2 = ((ql[ql_off + l + 32] & 0x0F) | (((qh_b >> 2) & 3) << 4)) as i32 - 32;
887            let q3 = ((ql[ql_off + l] >> 4) | (((qh_b >> 4) & 3) << 4)) as i32 - 32;
888            let q4 = ((ql[ql_off + l + 32] >> 4) | (((qh_b >> 6) & 3) << 4)) as i32 - 32;
889            out[dst_base + l] = d * sc[sc_off + is] as i8 as f32 * q1 as f32;
890            out[dst_base + l + 32] = d * sc[sc_off + is + 2] as i8 as f32 * q2 as f32;
891            out[dst_base + l + 64] = d * sc[sc_off + is + 4] as i8 as f32 * q3 as f32;
892            out[dst_base + l + 96] = d * sc[sc_off + is + 6] as i8 as f32 * q4 as f32;
893        }
894    }
895}
896
897/// Dequantize one Q8_K super-block (276 bytes) into `out`.
898pub fn dequant_q8_k_block(block: &[u8], out: &mut [f32; QK_K]) {
899    let d = f32::from_le_bytes(block[0..4].try_into().unwrap());
900    let qs = &block[4..4 + QK_K];
901    for i in 0..QK_K {
902        out[i] = d * qs[i] as i8 as f32;
903    }
904}
905
906/// Dequantize one Q2_K super-block (84 bytes) into `out`.
907pub fn dequant_q2_k_block(block: &[u8], out: &mut [f32; QK_K]) {
908    let d = read_f16_le(&block[0..2]);
909    let min = read_f16_le(&block[2..4]);
910    let mut q = &block[4 + QK_K / 16..];
911    let mut is = 0usize;
912    let mut out_i = 0usize;
913    for _ in 0..(QK_K / 128) {
914        let mut shift = 0u32;
915        for _ in 0..4 {
916            let sc = block[4 + is];
917            is += 1;
918            let dl = d * (sc & 0xF) as f32;
919            let ml = min * (sc >> 4) as f32;
920            for l in 0..16 {
921                out[out_i] = dl * ((q[l] >> shift) & 3) as f32 - ml;
922                out_i += 1;
923            }
924            let sc = block[4 + is];
925            is += 1;
926            let dl = d * (sc & 0xF) as f32;
927            let ml = min * (sc >> 4) as f32;
928            for l in 0..16 {
929                out[out_i] = dl * ((q[l + 16] >> shift) & 3) as f32 - ml;
930                out_i += 1;
931            }
932            shift += 2;
933        }
934        q = &q[32..];
935    }
936}
937
938/// Dequantize one Q3_K super-block (110 bytes) into `out`.
939pub fn dequant_q3_k_block(block: &[u8], out: &mut [f32; QK_K]) {
940    const KMASK1: u32 = 0x0303_0303;
941    const KMASK2: u32 = 0x0f0f_0f0f;
942    let d_all = read_f16_le(&block[0..2]);
943    let hm = &block[2 + K_SCALE_SIZE..2 + K_SCALE_SIZE + QK_K / 8];
944    let mut q = &block[2 + K_SCALE_SIZE + QK_K / 8..];
945    let mut aux = [0u32; 4];
946    aux[0] = u32::from_le_bytes(block[2..6].try_into().unwrap());
947    aux[1] = u32::from_le_bytes(block[6..10].try_into().unwrap());
948    aux[2] = u32::from_le_bytes(block[10..14].try_into().unwrap());
949    let tmp = aux[2];
950    aux[2] = ((aux[0] >> 4) & KMASK2) | (((tmp >> 4) & KMASK1) << 4);
951    aux[3] = ((aux[1] >> 4) & KMASK2) | (((tmp >> 6) & KMASK1) << 4);
952    aux[0] = (aux[0] & KMASK2) | ((tmp & KMASK1) << 4);
953    aux[1] = (aux[1] & KMASK2) | (((tmp >> 2) & KMASK1) << 4);
954    let scales: &[i8; 16] = unsafe { &*(aux.as_ptr() as *const [i8; 16]) };
955    let mut is = 0usize;
956    let mut m: u8 = 1;
957    let mut out_i = 0usize;
958    for _ in 0..(QK_K / 128) {
959        let mut shift = 0u32;
960        for _ in 0..4 {
961            let dl = d_all * (scales[is] - 32) as f32;
962            is += 1;
963            for l in 0..16 {
964                let h = if hm[l] & m != 0 { 0 } else { 4 };
965                out[out_i] = dl * (((q[l] >> shift) & 3) as i8 - h) as f32;
966                out_i += 1;
967            }
968            let dl = d_all * (scales[is] - 32) as f32;
969            is += 1;
970            for l in 0..16 {
971                let h = if hm[l + 16] & m != 0 { 0 } else { 4 };
972                out[out_i] = dl * (((q[l + 16] >> shift) & 3) as i8 - h) as f32;
973                out_i += 1;
974            }
975            shift += 2;
976            m <<= 1;
977        }
978        q = &q[32..];
979    }
980}
981
982pub fn dequant_q2_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
983    if !n.is_multiple_of(QK_K) {
984        bail!("Q2_K: n={n} not divisible by {QK_K}");
985    }
986    let nb = n / QK_K;
987    let blk = 2 + 2 + QK_K / 16 + QK_K / 4;
988    if bytes.len() != nb * blk {
989        bail!("Q2_K: bad byte count");
990    }
991    let mut out = vec![0f32; n];
992    for i in 0..nb {
993        let off = i * blk;
994        dequant_q2_k_block(
995            &bytes[off..off + blk],
996            (&mut out[i * QK_K..(i + 1) * QK_K]).try_into().unwrap(),
997        );
998    }
999    Ok(out)
1000}
1001
1002pub fn dequant_q3_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
1003    if !n.is_multiple_of(QK_K) {
1004        bail!("Q3_K: n={n} not divisible by {QK_K}");
1005    }
1006    let nb = n / QK_K;
1007    let blk = 2 + K_SCALE_SIZE + QK_K / 8 + QK_K / 4;
1008    if bytes.len() != nb * blk {
1009        bail!("Q3_K: bad byte count");
1010    }
1011    let mut out = vec![0f32; n];
1012    for i in 0..nb {
1013        let off = i * blk;
1014        dequant_q3_k_block(
1015            &bytes[off..off + blk],
1016            (&mut out[i * QK_K..(i + 1) * QK_K]).try_into().unwrap(),
1017        );
1018    }
1019    Ok(out)
1020}
1021
1022/// Q4_K block: 144 bytes / 256 elements (4.5 bits/element).
1023/// Layout: f16 d + f16 dmin + 12-byte packed scales/mins + 128 nibbles.
1024pub fn dequant_q4_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
1025    if !n.is_multiple_of(QK_K) {
1026        bail!("Q4_K: n={n} not divisible by {QK_K}");
1027    }
1028    let nb = n / QK_K;
1029    let blk = 2 + 2 + K_SCALE_SIZE + QK_K / 2;
1030    if bytes.len() != nb * blk {
1031        bail!("Q4_K: bad byte count");
1032    }
1033    let mut out = Vec::with_capacity(n);
1034    for i in 0..nb {
1035        let off = i * blk;
1036        let d = read_f16_le(&bytes[off..off + 2]);
1037        let dmin = read_f16_le(&bytes[off + 2..off + 4]);
1038        let scales = &bytes[off + 4..off + 4 + K_SCALE_SIZE];
1039        let qs = &bytes[off + 4 + K_SCALE_SIZE..off + blk];
1040        // 8 sub-blocks × 32 elements. Each pair of sub-blocks reads
1041        // 32 nibbles (16 bytes): low nibbles → sub-block j, high
1042        // nibbles → sub-block j+1.
1043        let mut is = 0usize;
1044        for j in (0..8).step_by(2) {
1045            let (sc0, m0) = get_scale_min_k4(j, scales);
1046            let (sc1, m1) = get_scale_min_k4(j + 1, scales);
1047            let d0 = d * sc0 as f32;
1048            let m0 = dmin * m0 as f32;
1049            let d1 = d * sc1 as f32;
1050            let m1 = dmin * m1 as f32;
1051            for l in 0..32 {
1052                let q = qs[is + l];
1053                out.push(d0 * (q & 0x0F) as f32 - m0);
1054            }
1055            for l in 0..32 {
1056                let q = qs[is + l];
1057                out.push(d1 * (q >> 4) as f32 - m1);
1058            }
1059            is += 32;
1060        }
1061    }
1062    Ok(out)
1063}
1064
1065/// Q5_K block: 176 bytes / 256 elements (5.5 bits/element).
1066/// Layout: f16 d + f16 dmin + 12-byte packed scales/mins + 32-byte
1067/// high-bits + 128 nibbles. Each element's 5th bit lives in `qh`
1068/// indexed by position-within-super-block.
1069pub fn dequant_q5_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
1070    if !n.is_multiple_of(QK_K) {
1071        bail!("Q5_K: n={n} not divisible by {QK_K}");
1072    }
1073    let nb = n / QK_K;
1074    let blk = 2 + 2 + K_SCALE_SIZE + QK_K / 8 + QK_K / 2;
1075    if bytes.len() != nb * blk {
1076        bail!("Q5_K: bad byte count");
1077    }
1078    let mut out = Vec::with_capacity(n);
1079    for i in 0..nb {
1080        let off = i * blk;
1081        let d = read_f16_le(&bytes[off..off + 2]);
1082        let dmin = read_f16_le(&bytes[off + 2..off + 4]);
1083        let scales = &bytes[off + 4..off + 4 + K_SCALE_SIZE];
1084        let qh_off = off + 4 + K_SCALE_SIZE;
1085        let qh = &bytes[qh_off..qh_off + QK_K / 8];
1086        let qs_off = qh_off + QK_K / 8;
1087        let qs = &bytes[qs_off..qs_off + QK_K / 2];
1088        let mut is = 0usize;
1089        let mut u1: u8 = 1;
1090        let mut u2: u8 = 2;
1091        for j in (0..8).step_by(2) {
1092            let (sc0, m0) = get_scale_min_k4(j, scales);
1093            let (sc1, m1) = get_scale_min_k4(j + 1, scales);
1094            let d0 = d * sc0 as f32;
1095            let m0 = dmin * m0 as f32;
1096            let d1 = d * sc1 as f32;
1097            let m1 = dmin * m1 as f32;
1098            for l in 0..32 {
1099                let lo = qs[is + l] & 0x0F;
1100                let hi = if qh[l] & u1 != 0 { 16 } else { 0 };
1101                out.push(d0 * (lo + hi) as f32 - m0);
1102            }
1103            for l in 0..32 {
1104                let lo = qs[is + l] >> 4;
1105                let hi = if qh[l] & u2 != 0 { 16 } else { 0 };
1106                out.push(d1 * (lo + hi) as f32 - m1);
1107            }
1108            is += 32;
1109            u1 <<= 2;
1110            u2 <<= 2;
1111        }
1112    }
1113    Ok(out)
1114}
1115
1116/// Q6_K block: 210 bytes / 256 elements (6.5625 bits/element). The
1117/// highest-quality K-quant; common in `*-Q6_K.gguf` model dumps.
1118/// Layout: 128 low-nibble bytes + 64 high-2-bit bytes + 16 i8 scales
1119/// + f16 d (super-block scale; per-sub-block scales are signed).
1120pub fn dequant_q6_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
1121    if !n.is_multiple_of(QK_K) {
1122        bail!("Q6_K: n={n} not divisible by {QK_K}");
1123    }
1124    let nb = n / QK_K;
1125    let ql_len = QK_K / 2; // 128
1126    let qh_len = QK_K / 4; // 64
1127    let sc_len = QK_K / 16; // 16
1128    let blk = ql_len + qh_len + sc_len + 2;
1129    if bytes.len() != nb * blk {
1130        bail!("Q6_K: bad byte count");
1131    }
1132    let mut out = vec![0f32; n];
1133    for i in 0..nb {
1134        let off = i * blk;
1135        let ql = &bytes[off..off + ql_len];
1136        let qh = &bytes[off + ql_len..off + ql_len + qh_len];
1137        let sc = &bytes[off + ql_len + qh_len..off + ql_len + qh_len + sc_len];
1138        let d = read_f16_le(&bytes[off + ql_len + qh_len + sc_len..off + blk]);
1139        let dst = &mut out[i * QK_K..(i + 1) * QK_K];
1140        // Two halves of 128 elements each. Per half we walk l in 0..32
1141        // and decode four interleaved 6-bit values (offsets 0, 32, 64, 96).
1142        for h in 0..2 {
1143            let dst_base = h * 128;
1144            let ql_off = h * 64;
1145            let qh_off_h = h * 32;
1146            let sc_off = h * 8;
1147            for l in 0..32 {
1148                let is = l / 16;
1149                let qh_b = qh[qh_off_h + l];
1150                let q1 = (((ql[ql_off + l] & 0x0F) | ((qh_b & 3) << 4)) as i32 - 32) as f32;
1151                let q2 =
1152                    (((ql[ql_off + l + 32] & 0x0F) | (((qh_b >> 2) & 3) << 4)) as i32 - 32) as f32;
1153                let q3 = (((ql[ql_off + l] >> 4) | (((qh_b >> 4) & 3) << 4)) as i32 - 32) as f32;
1154                let q4 =
1155                    (((ql[ql_off + l + 32] >> 4) | (((qh_b >> 6) & 3) << 4)) as i32 - 32) as f32;
1156                dst[dst_base + l] = d * sc[sc_off + is] as i8 as f32 * q1;
1157                dst[dst_base + l + 32] = d * sc[sc_off + is + 2] as i8 as f32 * q2;
1158                dst[dst_base + l + 64] = d * sc[sc_off + is + 4] as i8 as f32 * q3;
1159                dst[dst_base + l + 96] = d * sc[sc_off + is + 6] as i8 as f32 * q4;
1160            }
1161        }
1162    }
1163    Ok(out)
1164}
1165
1166/// Q8_K block: 276 bytes / 256 elements. Mostly an intermediate
1167/// format used inside llama.cpp's matmul kernels, but some dumps do
1168/// store it directly. We only need to materialize the i8 quants ×
1169/// the f32 super-block scale; `bsums` (per-16-block partial sums) is
1170/// metadata we can safely ignore for plain dequant.
1171pub fn dequant_q8_k(bytes: &[u8], n: usize) -> Result<Vec<f32>> {
1172    if !n.is_multiple_of(QK_K) {
1173        bail!("Q8_K: n={n} not divisible by {QK_K}");
1174    }
1175    let nb = n / QK_K;
1176    let blk = 4 + QK_K + (QK_K / 16) * 2;
1177    if bytes.len() != nb * blk {
1178        bail!("Q8_K: bad byte count");
1179    }
1180    let mut out = Vec::with_capacity(n);
1181    for i in 0..nb {
1182        let off = i * blk;
1183        let d = f32::from_le_bytes([bytes[off], bytes[off + 1], bytes[off + 2], bytes[off + 3]]);
1184        let qs = &bytes[off + 4..off + 4 + QK_K];
1185        for &q in qs {
1186            out.push(d * (q as i8) as f32);
1187        }
1188    }
1189    Ok(out)
1190}
1191
1192// ─── tests ────────────────────────────────────────────────────────
1193
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Cursor;

    #[test]
    fn roundtrip_f32_v3() {
        let data = [1.0f32, -2.0, 3.5, 0.0];
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&1u64.to_le_bytes()); // tensor_count
        buf.extend_from_slice(&0u64.to_le_bytes()); // kv_count
        let name = "w";
        buf.extend_from_slice(&(name.len() as u64).to_le_bytes());
        buf.extend_from_slice(name.as_bytes());
        buf.extend_from_slice(&1u32.to_le_bytes()); // n_dims
        buf.extend_from_slice(&(data.len() as u64).to_le_bytes());
        buf.extend_from_slice(&(GgmlType::F32 as u32).to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes());
        while !buf.len().is_multiple_of(DEFAULT_ALIGNMENT as usize) {
            buf.push(0);
        }
        for v in &data {
            buf.extend_from_slice(&v.to_le_bytes());
        }

        let mut c = Cursor::new(buf);
        let f = GgufFile::from_reader(&mut c).unwrap();
        assert_eq!(f.version, 3);
        assert_eq!(f.tensors.len(), 1);
        let (out, shape) = f.dequant_f32("w").unwrap();
        assert_eq!(shape, vec![4]);
        assert_eq!(out, data);
    }

    #[test]
    fn rejects_wrong_magic() {
        let buf = vec![0u8; 16];
        let mut c = Cursor::new(buf);
        assert!(GgufFile::from_reader(&mut c).is_err());
    }

    #[test]
    fn dequant_q8_0_block() {
        let mut bytes = Vec::new();
        let d = half::f16::from_f32(0.5);
        bytes.extend_from_slice(&d.to_le_bytes());
        let qs: [i8; QK8_0] = std::array::from_fn(|i| (i as i8) - 16);
        for q in qs {
            bytes.push(q as u8);
        }

        let out = dequant_q8_0(&bytes, QK8_0).unwrap();
        assert_eq!(out.len(), QK8_0);
        for i in 0..QK8_0 {
            assert!((out[i] - 0.5 * (qs[i] as f32)).abs() < 1e-6);
        }
    }

    #[test]
    fn dequant_q4_0_block() {
        let mut bytes = Vec::new();
        let d = half::f16::from_f32(1.0);
        bytes.extend_from_slice(&d.to_le_bytes());
        // Byte i = (i << 4) | i  → low nibble = i, high nibble = i.
        let qs: [u8; 16] = std::array::from_fn(|i| (i as u8 & 0x0F) | ((i as u8 & 0x0F) << 4));
        bytes.extend_from_slice(&qs);

        let out = dequant_q4_0(&bytes, QK4_0).unwrap();
        assert_eq!(out.len(), QK4_0);
        for i in 0..16 {
            assert_eq!(out[i], (i as f32) - 8.0);
        }
        for i in 0..16 {
            assert_eq!(out[16 + i], (i as f32) - 8.0);
        }
    }

    #[test]
    fn dequant_q4_k_block_constant_value() {
        // Hand-build one Q4_K block (144 bytes / 256 elements) in which all
        // 8 sub-blocks decode to scale = 1 and min = 0, and every quant
        // nibble is 7. With d = 1 (dmin is irrelevant once every min is 0)
        // every output must be 7.0.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&half::f16::from_f32(1.0).to_le_bytes()); // d
        bytes.extend_from_slice(&half::f16::from_f32(1.0).to_le_bytes()); // dmin
        // Packed 6-bit scales/mins, as unpacked by get_scale_min_k4:
        //   j <  4: sc = scales[j] & 63,      m = scales[j + 4] & 63
        //   j >= 4: sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4)
        //           m  = (scales[j + 4] >> 4)  | ((scales[j] >> 6) << 4)
        // scales[0..4]  = 0x01 → sc = 1 for j < 4; top 2 bits are 0, so no
        //                        carry into the j >= 4 scales.
        // scales[4..8]  = 0x00 → m = 0 for j < 4 and no carry into j >= 4 mins.
        // scales[8..12] = 0x01 → sc = 1, m = 0 for j >= 4.
        let mut scales = [0u8; K_SCALE_SIZE];
        scales[0..4].fill(0x01);
        scales[8..12].fill(0x01);
        bytes.extend_from_slice(&scales);
        // qs[128]: every nibble = 7 → byte = 0x77.
        bytes.extend(std::iter::repeat_n(0x77u8, QK_K / 2));
        let out = dequant_q4_k(&bytes, QK_K).unwrap();
        assert_eq!(out.len(), QK_K);
        for v in &out {
            assert!((v - 7.0).abs() < 1e-5, "Q4K decode mismatch: {v}");
        }
    }

    #[test]
    fn dequant_q6_k_block_matches_full_with_signed_scale() {
        const BLK: usize = QK_K / 2 + QK_K / 4 + QK_K / 16 + 2;
        let mut block = [0u8; BLK];
        let sc_off = QK_K / 2 + QK_K / 4;
        block[sc_off] = 0xFF;
        block[0] = 0x21;
        block[QK_K / 2] = 0x08;
        block[BLK - 2..].copy_from_slice(&half::f16::ONE.to_le_bytes());

        let mut out_block = [0f32; QK_K];
        dequant_q6_k_block(&block, &mut out_block);
        let full = dequant_q6_k(&block, QK_K).unwrap();
        assert!((out_block[0] - full[0]).abs() < 1e-4);
        assert!(
            (out_block[0] - 31.0).abs() < 1e-4,
            "unexpected value {}",
            out_block[0]
        );
    }

    #[test]
    fn dequant_q6_k_block_constant_value() {
        // Build one Q6_K block where every per-sub-block scale = 1 and
        // every 6-bit quant value = 32 (i.e. 0 after the -32 bias) plus
        // d = 1. Output should be all zeros.
        let ql_len = QK_K / 2;
        let qh_len = QK_K / 4;
        let sc_len = QK_K / 16;
        let mut bytes = Vec::with_capacity(ql_len + qh_len + sc_len + 2);
        // ql: low 4 bits of each 6-bit value = 0 (since 32 = 0b100000 → low=0, high=2)
        bytes.resize(ql_len, 0u8);
        // qh: each pair of high bits = 2 (binary 10). Packed 4 pairs per byte
        // in the order (bits 0-1, 2-3, 4-5, 6-7) for offsets 0, 32, 64, 96.
        // Pattern: 0b10_10_10_10 = 0xAA.
        bytes.resize(ql_len + qh_len, 0xAAu8);
        // sc: 16 i8 scales, all = 1.
        bytes.extend(std::iter::repeat_n(1u8, sc_len));
        // d = 1.0
        bytes.extend_from_slice(&half::f16::from_f32(1.0).to_le_bytes());

        let out = dequant_q6_k(&bytes, QK_K).unwrap();
        assert_eq!(out.len(), QK_K);
        for v in &out {
            assert!(v.abs() < 1e-5, "Q6K decode mismatch: {v}");
        }
    }

    #[test]
    fn dequant_q2_k_block_constant_value() {
        // Q2_K: d=1, min=0, all scales encode sc=1/min=0, all 2-bit quants = 3.
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&half::f16::from_f32(1.0).to_le_bytes()); // d
        bytes.extend_from_slice(&half::f16::from_f32(0.0).to_le_bytes()); // min
        bytes.extend(std::iter::repeat_n(0x01u8, QK_K / 16)); // sc=1, min=0
        bytes.extend(std::iter::repeat_n(0xFFu8, QK_K / 4)); // all 2-bit fields = 3
        let out = dequant_q2_k(&bytes, QK_K).unwrap();
        assert_eq!(out.len(), QK_K);
        for v in &out {
            assert!((v - 3.0).abs() < 1e-4, "Q2K decode mismatch: {v}");
        }
    }

    #[test]
    fn dequant_q3_k_check() {
        let blk = 2 + K_SCALE_SIZE + QK_K / 8 + QK_K / 4;
        let bytes = vec![0u8; blk];
        let out = dequant_q3_k(&bytes, QK_K).unwrap();
        assert_eq!(out.len(), QK_K);
        assert!(out.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn dequant_q8_k_block() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&0.25f32.to_le_bytes());
        let qs: [i8; QK_K] = std::array::from_fn(|i| (i as i32 - 128) as i8);
        for q in qs {
            bytes.push(q as u8);
        }
        // bsums: 16 i16 — unused by the decoder, but they must be present
        // so the byte-count check in `dequant_q8_k` passes.
        for _ in 0..(QK_K / 16) {
            bytes.extend_from_slice(&0i16.to_le_bytes());
        }
        let out = dequant_q8_k(&bytes, QK_K).unwrap();
        for i in 0..QK_K {
            assert!((out[i] - 0.25 * (qs[i] as f32)).abs() < 1e-6);
        }
    }

    #[test]
    fn metadata_roundtrip() {
        let mut buf: Vec<u8> = Vec::new();
        buf.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        buf.extend_from_slice(&3u32.to_le_bytes());
        buf.extend_from_slice(&0u64.to_le_bytes()); // tensor_count
        buf.extend_from_slice(&1u64.to_le_bytes()); // kv_count
        let key = "general.architecture";
        buf.extend_from_slice(&(key.len() as u64).to_le_bytes());
        buf.extend_from_slice(key.as_bytes());
        buf.extend_from_slice(&8u32.to_le_bytes()); // type=string
        let val = "llama";
        buf.extend_from_slice(&(val.len() as u64).to_le_bytes());
        buf.extend_from_slice(val.as_bytes());
        while !buf.len().is_multiple_of(DEFAULT_ALIGNMENT as usize) {
            buf.push(0);
        }

        let mut c = Cursor::new(buf);
        let f = GgufFile::from_reader(&mut c).unwrap();
        assert_eq!(
            f.metadata.get(key).and_then(MetaValue::as_str),
            Some("llama")
        );
    }
}