// mlx_native/gguf/mod.rs

1//! GGUF v3 file format parser.
2//!
3//! Parses GGUF headers, metadata, and tensor info on open.  Tensor data is
4//! loaded lazily on demand into [`MlxBuffer`]s — either as raw GGML blocks
5//! (for GPU quantized matmul) or dequantized to F32 (for norm weights etc.).
6//!
7//! # Example
8//!
9//! ```ignore
10//! use mlx_native::gguf::GgufFile;
11//! use std::path::Path;
12//!
13//! let gguf = GgufFile::open(Path::new("model.gguf"))?;
14//! let names = gguf.tensor_names();
15//! let buf = gguf.load_tensor("blk.0.attn_q.weight", &device)?;
16//! let norm = gguf.load_tensor_f32("blk.0.attn_norm.weight", &device)?;
17//! ```
18
19use std::collections::HashMap;
20use std::io::{BufReader, Read, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Mutex;
23
24use half::f16;
25
26use crate::ops::quantized_matmul_ggml::GgmlType;
27use crate::{DType, MlxBuffer, MlxDevice, MlxError, Result};
28
29// ---------------------------------------------------------------------------
30// GGUF constants
31// ---------------------------------------------------------------------------
32
33/// GGUF magic number: "GGUF" as little-endian u32 (bytes: 0x47 0x47 0x55 0x46).
34const GGUF_MAGIC: u32 = 0x4655_4747;
35
36/// GGUF version we support.
37const GGUF_VERSION: u32 = 3;
38
39/// Default alignment for the tensor data section.
40const GGUF_DEFAULT_ALIGNMENT: u64 = 32;
41
42/// Metadata key that overrides the default alignment.
43const GGUF_ALIGNMENT_KEY: &str = "general.alignment";
44
45// ---------------------------------------------------------------------------
46// GGUF metadata value type IDs
47// ---------------------------------------------------------------------------
48
49const GGUF_TYPE_UINT8: u32 = 0;
50const GGUF_TYPE_INT8: u32 = 1;
51const GGUF_TYPE_UINT16: u32 = 2;
52const GGUF_TYPE_INT16: u32 = 3;
53const GGUF_TYPE_UINT32: u32 = 4;
54const GGUF_TYPE_INT32: u32 = 5;
55const GGUF_TYPE_FLOAT32: u32 = 6;
56const GGUF_TYPE_BOOL: u32 = 7;
57const GGUF_TYPE_STRING: u32 = 8;
58const GGUF_TYPE_ARRAY: u32 = 9;
59const GGUF_TYPE_UINT64: u32 = 10;
60const GGUF_TYPE_INT64: u32 = 11;
61const GGUF_TYPE_FLOAT64: u32 = 12;
62
63// ---------------------------------------------------------------------------
64// GGML type IDs (from ggml.h)
65// ---------------------------------------------------------------------------
66
const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
// The ID space is deliberately non-contiguous here: the skipped values are
// other ggml types this crate does not load; `ggml_type_from_u32` rejects
// them with a parse error.
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q5_K: u32 = 13;
const GGML_TYPE_Q6_K: u32 = 14;
const GGML_TYPE_I16: u32 = 17;
75
76// ---------------------------------------------------------------------------
77// Public types
78// ---------------------------------------------------------------------------
79
/// GGUF metadata value types.
///
/// One variant per GGUF value type ID; arrays are homogeneous (one element
/// type ID) but may nest arbitrarily.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    Uint8(u8),
    Int8(i8),
    Uint16(u16),
    Int16(i16),
    Uint32(u32),
    Int32(i32),
    Float32(f32),
    Bool(bool),
    String(String),
    Array(Vec<MetadataValue>),
    Uint64(u64),
    Int64(i64),
    Float64(f64),
}

impl MetadataValue {
    /// Try to interpret this value as a string reference.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            MetadataValue::String(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// Try to interpret this value as a u32.
    ///
    /// Accepts any integer variant whose value fits in a u32. GGUF writers
    /// are inconsistent about the integer width used for keys such as
    /// "general.alignment", so wider or signed variants are converted with a
    /// range check rather than rejected outright.
    pub fn as_u32(&self) -> Option<u32> {
        match self {
            MetadataValue::Uint8(v) => Some(u32::from(*v)),
            MetadataValue::Uint16(v) => Some(u32::from(*v)),
            MetadataValue::Uint32(v) => Some(*v),
            MetadataValue::Uint64(v) => u32::try_from(*v).ok(),
            MetadataValue::Int8(v) => u32::try_from(*v).ok(),
            MetadataValue::Int16(v) => u32::try_from(*v).ok(),
            MetadataValue::Int32(v) => u32::try_from(*v).ok(),
            MetadataValue::Int64(v) => u32::try_from(*v).ok(),
            _ => None,
        }
    }

    /// Try to interpret this value as an f32.
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            MetadataValue::Float32(v) => Some(*v),
            MetadataValue::Float64(v) => Some(*v as f32),
            _ => None,
        }
    }
}
127
/// Information about a single tensor in the GGUF file.
///
/// Built from the tensor-info section of the header; the tensor's data is
/// not read until explicitly requested.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name (e.g. "blk.0.attn_q.weight").
    pub name: String,
    /// Tensor shape in [rows, ..., cols] order.  Note: GGUF stores
    /// dimensions on disk innermost-first, but the parser reverses them
    /// before constructing this struct, so the innermost dimension is LAST.
    pub shape: Vec<usize>,
    /// GGML quantization type.
    pub ggml_type: GgmlType,
    /// Byte offset relative to the start of the tensor data section (not
    /// relative to the start of the file).
    pub offset: u64,
    /// Total byte length of this tensor's data.
    pub byte_len: usize,
}
142
/// A parsed GGUF file, ready for lazy tensor loading.
///
/// The file is kept open so that tensor data can be read on demand via
/// [`load_tensor`](GgufFile::load_tensor) and
/// [`load_tensor_f32`](GgufFile::load_tensor_f32).
pub struct GgufFile {
    /// Header key/value metadata (e.g. "general.alignment").
    metadata: HashMap<String, MetadataValue>,
    /// Tensor info entries, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute byte offset in the file where tensor data begins.
    tensor_data_offset: u64,
    /// Open handle for lazy reads.  NOTE(review): presumably behind a Mutex
    /// so loads can go through `&self` and the seek+read pair cannot
    /// interleave across threads — confirm against the load methods.
    reader: Mutex<BufReader<std::fs::File>>,
}
155
156// ---------------------------------------------------------------------------
157// Low-level read helpers
158// ---------------------------------------------------------------------------
159
160/// Read a little-endian u8.
161fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
162    let mut buf = [0u8; 1];
163    r.read_exact(&mut buf)
164        .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
165    Ok(buf[0])
166}
167
168/// Read a little-endian i8.
169fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
170    Ok(read_u8(r)? as i8)
171}
172
173/// Read a little-endian u16.
174fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
175    let mut buf = [0u8; 2];
176    r.read_exact(&mut buf)
177        .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
178    Ok(u16::from_le_bytes(buf))
179}
180
181/// Read a little-endian i16.
182fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
183    let mut buf = [0u8; 2];
184    r.read_exact(&mut buf)
185        .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
186    Ok(i16::from_le_bytes(buf))
187}
188
189/// Read a little-endian u32.
190fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
191    let mut buf = [0u8; 4];
192    r.read_exact(&mut buf)
193        .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
194    Ok(u32::from_le_bytes(buf))
195}
196
197/// Read a little-endian i32.
198fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
199    let mut buf = [0u8; 4];
200    r.read_exact(&mut buf)
201        .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
202    Ok(i32::from_le_bytes(buf))
203}
204
205/// Read a little-endian u64.
206fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
207    let mut buf = [0u8; 8];
208    r.read_exact(&mut buf)
209        .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
210    Ok(u64::from_le_bytes(buf))
211}
212
213/// Read a little-endian i64.
214fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
215    let mut buf = [0u8; 8];
216    r.read_exact(&mut buf)
217        .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
218    Ok(i64::from_le_bytes(buf))
219}
220
221/// Read a little-endian f32.
222fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
223    let mut buf = [0u8; 4];
224    r.read_exact(&mut buf)
225        .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
226    Ok(f32::from_le_bytes(buf))
227}
228
229/// Read a little-endian f64.
230fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
231    let mut buf = [0u8; 8];
232    r.read_exact(&mut buf)
233        .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
234    Ok(f64::from_le_bytes(buf))
235}
236
237/// Read a GGUF-format string: u64 length followed by UTF-8 bytes (not
238/// null-terminated).
239fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
240    let len = read_u64(r)? as usize;
241    if len > 256 * 1024 * 1024 {
242        return Err(MlxError::GgufParseError(format!(
243            "string length {len} exceeds 256 MiB safety limit"
244        )));
245    }
246    let mut buf = vec![0u8; len];
247    r.read_exact(&mut buf)
248        .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
249    String::from_utf8(buf)
250        .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
251}
252
253// ---------------------------------------------------------------------------
254// Metadata value parsing
255// ---------------------------------------------------------------------------
256
257/// Read a single metadata value of the given type.
258fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
259    match value_type {
260        GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
261        GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
262        GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
263        GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
264        GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
265        GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
266        GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
267        GGUF_TYPE_BOOL => {
268            let byte = read_u8(r)?;
269            Ok(MetadataValue::Bool(byte != 0))
270        }
271        GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
272        GGUF_TYPE_ARRAY => {
273            let elem_type = read_u32(r)?;
274            let count = read_u64(r)? as usize;
275            if count > 64 * 1024 * 1024 {
276                return Err(MlxError::GgufParseError(format!(
277                    "array count {count} exceeds 64M element safety limit"
278                )));
279            }
280            let mut elems = Vec::with_capacity(count);
281            for _ in 0..count {
282                elems.push(read_metadata_value(r, elem_type)?);
283            }
284            Ok(MetadataValue::Array(elems))
285        }
286        GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
287        GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
288        GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
289        other => Err(MlxError::GgufParseError(format!(
290            "unknown metadata value type {other}"
291        ))),
292    }
293}
294
295// ---------------------------------------------------------------------------
296// GGML type mapping
297// ---------------------------------------------------------------------------
298
299/// Map a GGML type ID (u32 from the GGUF file) to our `GgmlType` enum.
300fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
301    match id {
302        GGML_TYPE_F32 => Ok(GgmlType::F32),
303        GGML_TYPE_F16 => Ok(GgmlType::F16),
304        GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
305        GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
306        GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
307        GGML_TYPE_Q5_K => Ok(GgmlType::Q5_K),
308        GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
309        GGML_TYPE_I16 => Ok(GgmlType::I16),
310        other => Err(MlxError::GgufParseError(format!(
311            "unsupported GGML type ID {other}"
312        ))),
313    }
314}
315
316/// Compute the byte length of a tensor from its shape and GGML type.
317///
318/// For quantized types, the innermost dimension (shape[0] in GGUF's row-major
319/// convention) must be divisible by the block's element count.
320fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
321    let total_elements: usize = shape.iter().product();
322    if total_elements == 0 {
323        return Ok(0);
324    }
325
326    let elems_per_block = ggml_type.block_values() as usize;
327    let bytes_per_block = ggml_type.block_bytes() as usize;
328
329    if total_elements % elems_per_block != 0 {
330        return Err(MlxError::GgufParseError(format!(
331            "total elements {total_elements} not divisible by block size {elems_per_block} \
332             for type {:?}",
333            ggml_type
334        )));
335    }
336
337    Ok((total_elements / elems_per_block) * bytes_per_block)
338}
339
340// ---------------------------------------------------------------------------
341// Dequantization
342// ---------------------------------------------------------------------------
343
344/// Convert a raw little-endian f16 (2 bytes) to f32.
345#[inline]
346fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
347    f16::from_le_bytes(bytes).to_f32()
348}
349
350/// Dequantize Q4_0 blocks to f32.
351///
352/// Block layout (18 bytes, 32 elements):
353///   f16 d          — scale
354///   u8  qs[16]     — packed 4-bit values (low nibble = first 16, high nibble = last 16)
355fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
356    const BLOCK_BYTES: usize = 18;
357    const BLOCK_ELEMS: usize = 32;
358
359    if data.len() % BLOCK_BYTES != 0 {
360        return Err(MlxError::GgufParseError(format!(
361            "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
362            data.len()
363        )));
364    }
365
366    let num_blocks = data.len() / BLOCK_BYTES;
367    if output.len() < num_blocks * BLOCK_ELEMS {
368        return Err(MlxError::GgufParseError(
369            "Q4_0 output buffer too small".into(),
370        ));
371    }
372
373    for i in 0..num_blocks {
374        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
375        let d = f16_from_le_bytes([block[0], block[1]]);
376        let qs = &block[2..18]; // 16 bytes
377
378        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
379
380        for j in 0..16 {
381            let x0 = (qs[j] & 0x0F) as i16 - 8;
382            let x1 = (qs[j] >> 4) as i16 - 8;
383            out[j] = x0 as f32 * d;
384            out[j + 16] = x1 as f32 * d;
385        }
386    }
387    Ok(())
388}
389
390/// Dequantize Q8_0 blocks to f32.
391///
392/// Block layout (34 bytes, 32 elements):
393///   f16 d         — scale
394///   i8  qs[32]    — signed 8-bit quantized values
395fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
396    const BLOCK_BYTES: usize = 34;
397    const BLOCK_ELEMS: usize = 32;
398
399    if data.len() % BLOCK_BYTES != 0 {
400        return Err(MlxError::GgufParseError(format!(
401            "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
402            data.len()
403        )));
404    }
405
406    let num_blocks = data.len() / BLOCK_BYTES;
407    if output.len() < num_blocks * BLOCK_ELEMS {
408        return Err(MlxError::GgufParseError(
409            "Q8_0 output buffer too small".into(),
410        ));
411    }
412
413    for i in 0..num_blocks {
414        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
415        let d = f16_from_le_bytes([block[0], block[1]]);
416        let qs = &block[2..34]; // 32 bytes of i8
417
418        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
419
420        for j in 0..32 {
421            out[j] = (qs[j] as i8) as f32 * d;
422        }
423    }
424    Ok(())
425}
426
/// Extract a (scale, min) pair for sub-block `j` from the 12-byte packed
/// scales array shared by Q4_K and Q5_K.
///
/// Matches `get_scale_min_k4` from candle / llama.cpp exactly: the first
/// four sub-blocks store 6-bit scale/min directly in bytes j and j+4; the
/// last four reassemble them from the low nibbles of bytes j+4..j+8 plus the
/// top two bits of bytes j-4..j (scale) and j (min).
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (scales[j] & 63, scales[j + 4] & 63),
        _ => {
            let packed = scales[j + 4];
            let sc_hi = scales[j - 4] >> 6;
            let m_hi = scales[j] >> 6;
            ((packed & 0x0F) | (sc_hi << 4), (packed >> 4) | (m_hi << 4))
        }
    }
}
451
452/// Dequantize Q5_K blocks to f32.
453///
454/// Block layout (176 bytes, 256 elements):
455///   f16 d           — super-block scale      (offset 0,  2 bytes)
456///   f16 dmin        — super-block minimum     (offset 2,  2 bytes)
457///   u8  scales[12]  — packed 6-bit scales/mins (offset 4,  12 bytes; shared with Q4_K)
458///   u8  qh[32]      — high bits of quants      (offset 16, 32 bytes = QK_K/8)
459///   u8  qs[128]     — low 4 bits of quants     (offset 48, 128 bytes = QK_K/2)
460///
461/// 8 sub-blocks of 32 elements each. Dequantization walks pairs of
462/// sub-blocks (is, is+1), each pair consumes 32 bytes of qs (low nibble
463/// for is, high nibble for is+1). The qh array is SHARED across all 4
464/// pairs — the high bit per element is masked out of qh using shifting
465/// selector values `u1 = 1 << (2*pair_idx)` / `u2 = 2 << (2*pair_idx)`.
466///
467/// Spec source: derived from `ggml/src/ggml-quants.c::dequantize_row_q5_K`.
468/// No code copied — formula reproduced from the mathematical definition.
469fn dequantize_q5_k(data: &[u8], output: &mut [f32]) -> Result<()> {
470    const BLOCK_BYTES: usize = 176;
471    const BLOCK_ELEMS: usize = 256;
472
473    if data.len() % BLOCK_BYTES != 0 {
474        return Err(MlxError::GgufParseError(format!(
475            "Q5_K data length {} not divisible by block size {BLOCK_BYTES}",
476            data.len()
477        )));
478    }
479
480    let num_blocks = data.len() / BLOCK_BYTES;
481    if output.len() < num_blocks * BLOCK_ELEMS {
482        return Err(MlxError::GgufParseError(
483            "Q5_K output buffer too small".into(),
484        ));
485    }
486
487    for i in 0..num_blocks {
488        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
489
490        let d = f16_from_le_bytes([block[0], block[1]]);
491        let dmin = f16_from_le_bytes([block[2], block[3]]);
492        let scales = &block[4..16]; // 12 bytes
493        let qh = &block[16..48]; // 32 bytes — high bit of quants
494        let qs = &block[48..176]; // 128 bytes — low 4 bits
495
496        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
497
498        // Process 4 pairs of sub-blocks (256 values total).
499        // u1 / u2 are the high-bit selector masks: they shift left by 2 each
500        // iteration so the 4 pairs pick bits 0/1, 2/3, 4/5, 6/7 of each qh byte.
501        let mut is = 0usize;
502        let mut u1: u8 = 1;
503        let mut u2: u8 = 2;
504        let mut ys_index = 0usize;
505        let mut ql_off = 0usize;
506
507        while ql_off < 128 {
508            let ql = &qs[ql_off..ql_off + 32];
509
510            let (sc1, m1) = get_scale_min_k4(is, scales);
511            let d1 = d * sc1 as f32;
512            let m1 = dmin * m1 as f32;
513            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
514            let d2 = d * sc2 as f32;
515            let m2 = dmin * m2 as f32;
516
517            // Sub-block `is` (low nibble + high bit from qh masked by u1).
518            for l in 0..32 {
519                let low = (ql[l] & 0x0F) as u32;
520                let high = if (qh[l] & u1) != 0 { 16 } else { 0 };
521                let q = low + high;
522                out[ys_index] = d1 * q as f32 - m1;
523                ys_index += 1;
524            }
525            // Sub-block `is + 1` (high nibble + high bit from qh masked by u2).
526            for l in 0..32 {
527                let low = (ql[l] >> 4) as u32;
528                let high = if (qh[l] & u2) != 0 { 16 } else { 0 };
529                let q = low + high;
530                out[ys_index] = d2 * q as f32 - m2;
531                ys_index += 1;
532            }
533
534            is += 2;
535            ql_off += 32;
536            u1 <<= 2;
537            u2 <<= 2;
538        }
539    }
540    Ok(())
541}
542
543/// Dequantize I16 tensors to f32.
544///
545/// Simple bitcast: `f32_val = i16_val as f32`. No scale metadata is used
546/// (apex GGUF convention — raw int16 values are meaningful as-is).
547///
548/// ADR-013 Decision 12 originally anticipated a per-tensor scale factor,
549/// but the apex GGUF does not emit one; values are stored as raw ints.
550/// If future GGUFs emit a scale, extend this with a scale parameter.
551fn dequantize_i16(data: &[u8], output: &mut [f32]) -> Result<()> {
552    if data.len() % 2 != 0 {
553        return Err(MlxError::GgufParseError(format!(
554            "I16 data length {} not even",
555            data.len()
556        )));
557    }
558    let num_elements = data.len() / 2;
559    if output.len() < num_elements {
560        return Err(MlxError::GgufParseError(
561            "I16 output buffer too small".into(),
562        ));
563    }
564    for i in 0..num_elements {
565        let v = i16::from_le_bytes([data[2 * i], data[2 * i + 1]]);
566        output[i] = v as f32;
567    }
568    Ok(())
569}
570
571/// Dequantize Q4_K blocks to f32.
572///
573/// Block layout (144 bytes, 256 elements):
574///   f16 d          — super-block scale          (offset 0,  2 bytes)
575///   f16 dmin       — super-block minimum         (offset 2,  2 bytes)
576///   u8  scales[12] — packed sub-block scales/mins (offset 4, 12 bytes)
577///   u8  qs[128]    — packed 4-bit quantized values (offset 16, 128 bytes)
578///
579/// 8 sub-blocks of 32 elements each.  Each pair of sub-blocks (64 elements)
580/// shares 32 bytes of qs — the low nibble gives the first sub-block, the
581/// high nibble gives the second.
582fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
583    const BLOCK_BYTES: usize = 144;
584    const BLOCK_ELEMS: usize = 256;
585
586    if data.len() % BLOCK_BYTES != 0 {
587        return Err(MlxError::GgufParseError(format!(
588            "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
589            data.len()
590        )));
591    }
592
593    let num_blocks = data.len() / BLOCK_BYTES;
594    if output.len() < num_blocks * BLOCK_ELEMS {
595        return Err(MlxError::GgufParseError(
596            "Q4_K output buffer too small".into(),
597        ));
598    }
599
600    for i in 0..num_blocks {
601        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
602
603        let d = f16_from_le_bytes([block[0], block[1]]);
604        let dmin = f16_from_le_bytes([block[2], block[3]]);
605        let scales = &block[4..16];   // 12 bytes
606        let qs = &block[16..144];     // 128 bytes
607
608        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
609
610        // Process 4 pairs of sub-blocks (8 sub-blocks total, 256 elements).
611        // Each iteration handles 64 elements: sub-block `is` (low nibbles)
612        // and sub-block `is+1` (high nibbles) from 32 bytes of qs.
613        let mut is = 0usize;
614        let mut ys_index = 0usize;
615
616        // Step through the 256-element super-block in chunks of 64.
617        // j tracks the byte offset within qs.
618        let mut j = 0usize;
619        while j < 128 {
620            let q = &qs[j..j + 32];
621            let (sc1, m1) = get_scale_min_k4(is, scales);
622            let d1 = d * sc1 as f32;
623            let min1 = dmin * m1 as f32;
624            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
625            let d2 = d * sc2 as f32;
626            let min2 = dmin * m2 as f32;
627
628            // Low nibbles: sub-block `is` (32 elements)
629            for byte in q.iter() {
630                out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
631                ys_index += 1;
632            }
633            // High nibbles: sub-block `is + 1` (32 elements)
634            for byte in q.iter() {
635                out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
636                ys_index += 1;
637            }
638
639            is += 2;
640            j += 32;
641        }
642    }
643    Ok(())
644}
645
646/// Dequantize Q6_K blocks to f32.
647///
648/// Block layout (210 bytes, 256 elements):
649///   u8   ql[128]   — low 4 bits of quantized values  (offset 0, 128 bytes)
650///   u8   qh[64]    — high 2 bits of quantized values  (offset 128, 64 bytes)
651///   i8   scales[16] — sub-block scales                (offset 192, 16 bytes)
652///   f16  d          — super-block scale               (offset 208, 2 bytes)
653///
654/// 256 elements organized as 2 groups of 128.  Each group of 128 has its own
655/// ql[64], qh[32] region and produces 4 interleaved sub-groups of 32.
656fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
657    const BLOCK_BYTES: usize = 210;
658    const BLOCK_ELEMS: usize = 256;
659
660    if data.len() % BLOCK_BYTES != 0 {
661        return Err(MlxError::GgufParseError(format!(
662            "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
663            data.len()
664        )));
665    }
666
667    let num_blocks = data.len() / BLOCK_BYTES;
668    if output.len() < num_blocks * BLOCK_ELEMS {
669        return Err(MlxError::GgufParseError(
670            "Q6_K output buffer too small".into(),
671        ));
672    }
673
674    for i in 0..num_blocks {
675        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
676
677        let ql = &block[0..128];
678        let qh = &block[128..192];
679        let sc = &block[192..208]; // i8 scales[16]
680        let d = f16_from_le_bytes([block[208], block[209]]);
681
682        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
683
684        // Process in two groups of 128 (idx = 0 and idx = 1).
685        for idx in 0..2 {
686            let ql_base = &ql[64 * idx..];
687            let qh_base = &qh[32 * idx..];
688            let sc_base = &sc[8 * idx..];
689            let out_base = &mut out[128 * idx..];
690
691            for l in 0..32 {
692                let is = l / 16; // 0 for l in 0..16, 1 for l in 16..32
693
694                let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
695                let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
696                    - 32_i8;
697                let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
698                let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
699                    - 32_i8;
700
701                out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
702                out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
703                out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
704                out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
705            }
706        }
707    }
708    Ok(())
709}
710
711/// Dequantize F16 data to F32.
712fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
713    if data.len() % 2 != 0 {
714        return Err(MlxError::GgufParseError(
715            "F16 data length not even".into(),
716        ));
717    }
718    let count = data.len() / 2;
719    if output.len() < count {
720        return Err(MlxError::GgufParseError(
721            "F16 output buffer too small".into(),
722        ));
723    }
724    for i in 0..count {
725        output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
726    }
727    Ok(())
728}
729
730/// Reinterpret F32 little-endian bytes into the output slice.
731fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
732    if data.len() % 4 != 0 {
733        return Err(MlxError::GgufParseError(
734            "F32 data length not multiple of 4".into(),
735        ));
736    }
737    let count = data.len() / 4;
738    if output.len() < count {
739        return Err(MlxError::GgufParseError(
740            "F32 output buffer too small".into(),
741        ));
742    }
743    for i in 0..count {
744        output[i] = f32::from_le_bytes([
745            data[4 * i],
746            data[4 * i + 1],
747            data[4 * i + 2],
748            data[4 * i + 3],
749        ]);
750    }
751    Ok(())
752}
753
754/// Dequantize raw GGML block data to f32.
755fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
756    match ggml_type {
757        GgmlType::F32 => copy_f32(data, output),
758        GgmlType::F16 => dequantize_f16(data, output),
759        GgmlType::Q4_0 => dequantize_q4_0(data, output),
760        GgmlType::Q8_0 => dequantize_q8_0(data, output),
761        GgmlType::Q4_K => dequantize_q4_k(data, output),
762        GgmlType::Q6_K => dequantize_q6_k(data, output),
763        GgmlType::Q5_K => dequantize_q5_k(data, output),
764        GgmlType::I16 => dequantize_i16(data, output),
765    }
766}
767
768// ---------------------------------------------------------------------------
769// GgufFile implementation
770// ---------------------------------------------------------------------------
771
772impl GgufFile {
773    /// Open and parse a GGUF v3 file.
774    ///
775    /// This reads the full header (magic, version, tensor count, metadata KV
776    /// pairs, tensor info entries) but does **not** read any tensor data.
777    /// Tensor data is loaded lazily via [`load_tensor`](Self::load_tensor) or
778    /// [`load_tensor_f32`](Self::load_tensor_f32).
779    ///
780    /// # Errors
781    ///
782    /// Returns `MlxError::IoError` if the file cannot be opened.
783    /// Returns `MlxError::GgufParseError` if the file is not valid GGUF v3.
784    pub fn open(path: &Path) -> Result<Self> {
785        let file = std::fs::File::open(path).map_err(|e| {
786            MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
787        })?;
788        let mut reader = BufReader::new(file);
789
790        // --- Header ---
791        let magic = read_u32(&mut reader)?;
792        if magic != GGUF_MAGIC {
793            return Err(MlxError::GgufParseError(format!(
794                "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
795            )));
796        }
797
798        let version = read_u32(&mut reader)?;
799        if version != GGUF_VERSION {
800            return Err(MlxError::GgufParseError(format!(
801                "unsupported GGUF version {version} (only v3 is supported)"
802            )));
803        }
804
805        let tensor_count = read_u64(&mut reader)? as usize;
806        let metadata_kv_count = read_u64(&mut reader)? as usize;
807
808        // Sanity limits to prevent OOM on corrupted files.
809        if tensor_count > 100_000 {
810            return Err(MlxError::GgufParseError(format!(
811                "tensor_count {tensor_count} exceeds 100k safety limit"
812            )));
813        }
814        if metadata_kv_count > 1_000_000 {
815            return Err(MlxError::GgufParseError(format!(
816                "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
817            )));
818        }
819
820        // --- Metadata KV pairs ---
821        let mut metadata = HashMap::with_capacity(metadata_kv_count);
822        for _ in 0..metadata_kv_count {
823            let key = read_gguf_string(&mut reader)?;
824            let value_type = read_u32(&mut reader)?;
825            let value = read_metadata_value(&mut reader, value_type)?;
826            metadata.insert(key, value);
827        }
828
829        // --- Determine alignment ---
830        let alignment = metadata
831            .get(GGUF_ALIGNMENT_KEY)
832            .and_then(|v| v.as_u32())
833            .map(|v| v as u64)
834            .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
835
836        if alignment == 0 || (alignment & (alignment - 1)) != 0 {
837            return Err(MlxError::GgufParseError(format!(
838                "alignment {alignment} is not a power of two"
839            )));
840        }
841
842        // --- Tensor info entries ---
843        let mut tensors = HashMap::with_capacity(tensor_count);
844        for _ in 0..tensor_count {
845            let name = read_gguf_string(&mut reader)?;
846            let n_dims = read_u32(&mut reader)? as usize;
847
848            if n_dims > 8 {
849                return Err(MlxError::GgufParseError(format!(
850                    "tensor '{name}' has {n_dims} dimensions (max 8)"
851                )));
852            }
853
854            let mut shape = Vec::with_capacity(n_dims);
855            for _ in 0..n_dims {
856                shape.push(read_u64(&mut reader)? as usize);
857            }
858            // GGUF stores dimensions innermost-first (column-major order).
859            // Reverse to match the [rows, cols] convention used by candle
860            // and by the rest of hf2q's weight loading code.
861            shape.reverse();
862
863            let ggml_type_id = read_u32(&mut reader)?;
864            let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
865                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
866            })?;
867
868            let offset = read_u64(&mut reader)?;
869            let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
870                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
871            })?;
872
873            tensors.insert(
874                name.clone(),
875                TensorInfo {
876                    name,
877                    shape,
878                    ggml_type,
879                    offset,
880                    byte_len,
881                },
882            );
883        }
884
885        // --- Compute tensor_data_offset ---
886        // The current file position is just past all tensor info entries.
887        // Tensor data starts at the next alignment boundary.
888        let pos = reader
889            .stream_position()
890            .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
891        let tensor_data_offset = align_offset(pos, alignment);
892
893        Ok(GgufFile {
894            metadata,
895            tensors,
896            tensor_data_offset,
897            reader: Mutex::new(reader),
898        })
899    }
900
901    // -----------------------------------------------------------------------
902    // Metadata accessors
903    // -----------------------------------------------------------------------
904
    /// Look up a metadata value by key.
    ///
    /// Returns the raw [`MetadataValue`]; use the typed accessors
    /// (`metadata_string`, `metadata_u32`, `metadata_f32`) when you know the
    /// expected type. Returns `None` if the key is absent.
    pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
        self.metadata.get(key)
    }
909
910    /// Look up a metadata string value by key.
911    pub fn metadata_string(&self, key: &str) -> Option<&str> {
912        self.metadata.get(key).and_then(|v| v.as_str())
913    }
914
915    /// Look up a metadata u32 value by key.
916    pub fn metadata_u32(&self, key: &str) -> Option<u32> {
917        self.metadata.get(key).and_then(|v| v.as_u32())
918    }
919
920    /// Look up a metadata f32 value by key.
921    pub fn metadata_f32(&self, key: &str) -> Option<f32> {
922        self.metadata.get(key).and_then(|v| v.as_f32())
923    }
924
925    // -----------------------------------------------------------------------
926    // Tensor info accessors
927    // -----------------------------------------------------------------------
928
929    /// Return the names of all tensors in the file.
930    pub fn tensor_names(&self) -> Vec<&str> {
931        self.tensors.keys().map(|s| s.as_str()).collect()
932    }
933
    /// Look up info for a specific tensor by name.
    ///
    /// Returns `None` if no tensor with that name exists in the file.
    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
        self.tensors.get(name)
    }
938
    /// Number of tensors in the file (as parsed from the tensor-info section).
    pub fn tensor_count(&self) -> usize {
        self.tensors.len()
    }
943
    /// Number of metadata key-value pairs (as parsed from the header).
    pub fn metadata_count(&self) -> usize {
        self.metadata.len()
    }
948
949    // -----------------------------------------------------------------------
950    // Tensor loading
951    // -----------------------------------------------------------------------
952
953    /// Read raw tensor bytes from the file.
954    ///
955    /// This is a private helper that seeks to the tensor's location and reads
956    /// `byte_len` bytes.
957    fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
958        let abs_offset = self.tensor_data_offset + info.offset;
959        let mut reader = self
960            .reader
961            .lock()
962            .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
963
964        reader
965            .seek(SeekFrom::Start(abs_offset))
966            .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
967
968        let mut buf = vec![0u8; info.byte_len];
969        reader.read_exact(&mut buf).map_err(|e| {
970            MlxError::IoError(format!(
971                "read tensor '{}' ({} bytes at offset {}): {e}",
972                info.name, info.byte_len, abs_offset
973            ))
974        })?;
975
976        Ok(buf)
977    }
978
979    /// Load a tensor as a raw buffer on the Metal device.
980    ///
981    /// For quantized types (Q4_0, Q8_0, Q4_K, Q6_K) the buffer contains raw
982    /// GGML blocks with dtype `U8` — these are consumed directly by
983    /// `quantized_matmul_ggml` kernels.
984    ///
985    /// For F32 and F16 tensors the buffer has the corresponding typed dtype.
986    ///
987    /// # Errors
988    ///
989    /// Returns an error if the tensor name is not found, or if reading fails.
990    pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
991        let info = self.tensors.get(name).ok_or_else(|| {
992            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
993        })?;
994
995        let data = self.read_tensor_bytes(info)?;
996
997        match info.ggml_type {
998            GgmlType::F32 => {
999                let mut buf =
1000                    device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
1001                {
1002                    let slice: &mut [u8] = buf.as_mut_slice()?;
1003                    slice.copy_from_slice(&data);
1004                }
1005                Ok(buf)
1006            }
1007            GgmlType::F16 => {
1008                let mut buf =
1009                    device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
1010                {
1011                    let slice: &mut [u8] = buf.as_mut_slice()?;
1012                    slice.copy_from_slice(&data);
1013                }
1014                Ok(buf)
1015            }
1016            GgmlType::Q4_0
1017            | GgmlType::Q8_0
1018            | GgmlType::Q4_K
1019            | GgmlType::Q5_K
1020            | GgmlType::Q6_K
1021            | GgmlType::I16 => {
1022                // Store raw GGML blocks as U8 buffer. Q5_K and I16 are held
1023                // opaque until dequant kernels land (ADR-013 Decision 12).
1024                let mut buf =
1025                    device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
1026                {
1027                    let slice: &mut [u8] = buf.as_mut_slice()?;
1028                    slice.copy_from_slice(&data);
1029                }
1030                Ok(buf)
1031            }
1032        }
1033    }
1034
1035    /// Load a tensor, dequantizing to F32 on the CPU, then upload to the
1036    /// Metal device.
1037    ///
1038    /// This is used for norm weights, embedding tables, and other tensors
1039    /// where the inference kernels operate on F32 directly.
1040    ///
1041    /// # Errors
1042    ///
1043    /// Returns an error if the tensor name is not found, reading fails, or
1044    /// dequantization encounters malformed data.
1045    pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1046        let info = self.tensors.get(name).ok_or_else(|| {
1047            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1048        })?;
1049
1050        let data = self.read_tensor_bytes(info)?;
1051        let total_elements: usize = info.shape.iter().product();
1052
1053        if total_elements == 0 {
1054            return Err(MlxError::GgufParseError(format!(
1055                "tensor '{name}' has zero elements"
1056            )));
1057        }
1058
1059        let f32_byte_len = total_elements * 4;
1060        let mut buf =
1061            device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
1062
1063        {
1064            let out_slice: &mut [f32] = buf.as_mut_slice()?;
1065            dequantize_to_f32(&data, info.ggml_type, out_slice)?;
1066        }
1067
1068        Ok(buf)
1069    }
1070}
1071
1072// ---------------------------------------------------------------------------
1073// Utility
1074// ---------------------------------------------------------------------------
1075
/// Round `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a nonzero power of two (validated by `GgufFile::open`
/// before any call site reaches this); the bit-mask trick below returns
/// garbage for other values, so the precondition is asserted in debug builds.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    debug_assert!(
        alignment != 0 && alignment.is_power_of_two(),
        "alignment must be a nonzero power of two, got {alignment}"
    );
    let mask = alignment - 1;
    (offset + mask) & !mask
}