// mlx_native/gguf/mod.rs
//! GGUF v3 file format parser.
//!
//! Parses GGUF headers, metadata, and tensor info on open.  Tensor data is
//! loaded lazily on demand into [`MlxBuffer`]s — either as raw GGML blocks
//! (for GPU quantized matmul) or dequantized to F32 (for norm weights etc.).
//!
//! # Example
//!
//! ```ignore
//! use mlx_native::gguf::GgufFile;
//! use std::path::Path;
//!
//! let gguf = GgufFile::open(Path::new("model.gguf"))?;
//! let names = gguf.tensor_names();
//! let buf = gguf.load_tensor("blk.0.attn_q.weight", &device)?;
//! let norm = gguf.load_tensor_f32("blk.0.attn_norm.weight", &device)?;
//! ```

use std::collections::HashMap;
use std::io::{BufReader, Read, Seek, SeekFrom};
use std::path::Path;
use std::sync::Mutex;

use half::f16;

use crate::ops::quantized_matmul_ggml::GgmlType;
use crate::{DType, MlxBuffer, MlxDevice, MlxError, Result};

// ---------------------------------------------------------------------------
// GGUF constants
// ---------------------------------------------------------------------------

/// GGUF magic number: "GGUF" as little-endian u32 (bytes: 0x47 0x47 0x55 0x46).
const GGUF_MAGIC: u32 = 0x4655_4747;

/// GGUF version we support.  `open()` rejects any other version.
const GGUF_VERSION: u32 = 3;

/// Default alignment for the tensor data section, used when the file does
/// not carry an explicit `general.alignment` metadata entry.
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;

/// Metadata key that overrides the default alignment (must be a power of two).
const GGUF_ALIGNMENT_KEY: &str = "general.alignment";

// ---------------------------------------------------------------------------
// GGUF metadata value type IDs
// ---------------------------------------------------------------------------

// Value-type discriminants as defined by the GGUF spec.  Each metadata value
// (and each array's element type) is tagged with one of these in the header.
const GGUF_TYPE_UINT8: u32 = 0;
const GGUF_TYPE_INT8: u32 = 1;
const GGUF_TYPE_UINT16: u32 = 2;
const GGUF_TYPE_INT16: u32 = 3;
const GGUF_TYPE_UINT32: u32 = 4;
const GGUF_TYPE_INT32: u32 = 5;
const GGUF_TYPE_FLOAT32: u32 = 6;
const GGUF_TYPE_BOOL: u32 = 7;
const GGUF_TYPE_STRING: u32 = 8;
const GGUF_TYPE_ARRAY: u32 = 9;
const GGUF_TYPE_UINT64: u32 = 10;
const GGUF_TYPE_INT64: u32 = 11;
const GGUF_TYPE_FLOAT64: u32 = 12;

// ---------------------------------------------------------------------------
// GGML type IDs (from ggml.h)
// ---------------------------------------------------------------------------

// Only the tensor data types this crate can consume; the GGML enum defines
// many more, hence the gaps in the numbering.
const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q6_K: u32 = 14;

// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

/// GGUF metadata value types.
///
/// One variant per GGUF value type ID.  Arrays are homogeneous in the file
/// format but are represented here as a plain `Vec` of values.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    Uint8(u8),
    Int8(i8),
    Uint16(u16),
    Int16(i16),
    Uint32(u32),
    Int32(i32),
    Float32(f32),
    Bool(bool),
    String(String),
    Array(Vec<MetadataValue>),
    Uint64(u64),
    Int64(i64),
    Float64(f64),
}

impl MetadataValue {
    /// Try to interpret this value as a string reference.
    ///
    /// Returns `None` for every non-string variant.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            MetadataValue::String(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// Try to interpret this value as a u32.
    ///
    /// Accepts any integer variant whose value fits in a non-negative u32;
    /// 64-bit variants are accepted when in range, since GGUF writers
    /// sometimes store small counts as 64-bit integers.
    pub fn as_u32(&self) -> Option<u32> {
        match self {
            MetadataValue::Uint32(v) => Some(*v),
            MetadataValue::Uint8(v) => Some(*v as u32),
            MetadataValue::Uint16(v) => Some(*v as u32),
            MetadataValue::Int32(v) if *v >= 0 => Some(*v as u32),
            // Generalization: in-range 64-bit values are also accepted.
            MetadataValue::Uint64(v) => u32::try_from(*v).ok(),
            MetadataValue::Int64(v) => u32::try_from(*v).ok(),
            _ => None,
        }
    }

    /// Try to interpret this value as an f32.
    ///
    /// Float64 values are narrowed with a plain `as` cast (may lose precision).
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            MetadataValue::Float32(v) => Some(*v),
            MetadataValue::Float64(v) => Some(*v as f32),
            _ => None,
        }
    }
}

/// Information about a single tensor in the GGUF file.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name (e.g. "blk.0.attn_q.weight").
    pub name: String,
    /// Tensor shape in `[rows, cols]` order.  NOTE: GGUF stores dimensions
    /// innermost-first on disk; `GgufFile::open` reverses them before
    /// constructing this struct.
    pub shape: Vec<usize>,
    /// GGML quantization type.
    pub ggml_type: GgmlType,
    /// Byte offset relative to the start of the tensor data section
    /// (not the start of the file).
    pub offset: u64,
    /// Total byte length of this tensor's data (whole GGML blocks).
    pub byte_len: usize,
}

/// A parsed GGUF file, ready for lazy tensor loading.
///
/// The file is kept open so that tensor data can be read on demand via
/// [`load_tensor`](GgufFile::load_tensor) and
/// [`load_tensor_f32`](GgufFile::load_tensor_f32).
pub struct GgufFile {
    /// All metadata key-value pairs from the header, keyed by full key name.
    metadata: HashMap<String, MetadataValue>,
    /// Tensor info entries, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute byte offset in the file where tensor data begins.
    tensor_data_offset: u64,
    /// Open file handle; the Mutex serializes the seek+read pair so the
    /// `&self` loading methods can be called from multiple threads.
    reader: Mutex<BufReader<std::fs::File>>,
}

154// ---------------------------------------------------------------------------
155// Low-level read helpers
156// ---------------------------------------------------------------------------
157
158/// Read a little-endian u8.
159fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
160    let mut buf = [0u8; 1];
161    r.read_exact(&mut buf)
162        .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
163    Ok(buf[0])
164}
165
166/// Read a little-endian i8.
167fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
168    Ok(read_u8(r)? as i8)
169}
170
171/// Read a little-endian u16.
172fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
173    let mut buf = [0u8; 2];
174    r.read_exact(&mut buf)
175        .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
176    Ok(u16::from_le_bytes(buf))
177}
178
179/// Read a little-endian i16.
180fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
181    let mut buf = [0u8; 2];
182    r.read_exact(&mut buf)
183        .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
184    Ok(i16::from_le_bytes(buf))
185}
186
187/// Read a little-endian u32.
188fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
189    let mut buf = [0u8; 4];
190    r.read_exact(&mut buf)
191        .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
192    Ok(u32::from_le_bytes(buf))
193}
194
195/// Read a little-endian i32.
196fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
197    let mut buf = [0u8; 4];
198    r.read_exact(&mut buf)
199        .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
200    Ok(i32::from_le_bytes(buf))
201}
202
203/// Read a little-endian u64.
204fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
205    let mut buf = [0u8; 8];
206    r.read_exact(&mut buf)
207        .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
208    Ok(u64::from_le_bytes(buf))
209}
210
211/// Read a little-endian i64.
212fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
213    let mut buf = [0u8; 8];
214    r.read_exact(&mut buf)
215        .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
216    Ok(i64::from_le_bytes(buf))
217}
218
219/// Read a little-endian f32.
220fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
221    let mut buf = [0u8; 4];
222    r.read_exact(&mut buf)
223        .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
224    Ok(f32::from_le_bytes(buf))
225}
226
227/// Read a little-endian f64.
228fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
229    let mut buf = [0u8; 8];
230    r.read_exact(&mut buf)
231        .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
232    Ok(f64::from_le_bytes(buf))
233}
234
235/// Read a GGUF-format string: u64 length followed by UTF-8 bytes (not
236/// null-terminated).
237fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
238    let len = read_u64(r)? as usize;
239    if len > 256 * 1024 * 1024 {
240        return Err(MlxError::GgufParseError(format!(
241            "string length {len} exceeds 256 MiB safety limit"
242        )));
243    }
244    let mut buf = vec![0u8; len];
245    r.read_exact(&mut buf)
246        .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
247    String::from_utf8(buf)
248        .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
249}
250
251// ---------------------------------------------------------------------------
252// Metadata value parsing
253// ---------------------------------------------------------------------------
254
255/// Read a single metadata value of the given type.
256fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
257    match value_type {
258        GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
259        GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
260        GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
261        GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
262        GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
263        GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
264        GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
265        GGUF_TYPE_BOOL => {
266            let byte = read_u8(r)?;
267            Ok(MetadataValue::Bool(byte != 0))
268        }
269        GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
270        GGUF_TYPE_ARRAY => {
271            let elem_type = read_u32(r)?;
272            let count = read_u64(r)? as usize;
273            if count > 64 * 1024 * 1024 {
274                return Err(MlxError::GgufParseError(format!(
275                    "array count {count} exceeds 64M element safety limit"
276                )));
277            }
278            let mut elems = Vec::with_capacity(count);
279            for _ in 0..count {
280                elems.push(read_metadata_value(r, elem_type)?);
281            }
282            Ok(MetadataValue::Array(elems))
283        }
284        GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
285        GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
286        GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
287        other => Err(MlxError::GgufParseError(format!(
288            "unknown metadata value type {other}"
289        ))),
290    }
291}
292
293// ---------------------------------------------------------------------------
294// GGML type mapping
295// ---------------------------------------------------------------------------
296
297/// Map a GGML type ID (u32 from the GGUF file) to our `GgmlType` enum.
298fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
299    match id {
300        GGML_TYPE_F32 => Ok(GgmlType::F32),
301        GGML_TYPE_F16 => Ok(GgmlType::F16),
302        GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
303        GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
304        GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
305        GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
306        other => Err(MlxError::GgufParseError(format!(
307            "unsupported GGML type ID {other}"
308        ))),
309    }
310}
311
312/// Compute the byte length of a tensor from its shape and GGML type.
313///
314/// For quantized types, the innermost dimension (shape[0] in GGUF's row-major
315/// convention) must be divisible by the block's element count.
316fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
317    let total_elements: usize = shape.iter().product();
318    if total_elements == 0 {
319        return Ok(0);
320    }
321
322    let elems_per_block = ggml_type.block_values() as usize;
323    let bytes_per_block = ggml_type.block_bytes() as usize;
324
325    if total_elements % elems_per_block != 0 {
326        return Err(MlxError::GgufParseError(format!(
327            "total elements {total_elements} not divisible by block size {elems_per_block} \
328             for type {:?}",
329            ggml_type
330        )));
331    }
332
333    Ok((total_elements / elems_per_block) * bytes_per_block)
334}
335
// ---------------------------------------------------------------------------
// Dequantization
// ---------------------------------------------------------------------------

/// Convert a raw little-endian f16 (2 bytes) to f32.
///
/// Used for the per-block scale factors that GGML block layouts store as f16.
#[inline]
fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
    f16::from_le_bytes(bytes).to_f32()
}

346/// Dequantize Q4_0 blocks to f32.
347///
348/// Block layout (18 bytes, 32 elements):
349///   f16 d          — scale
350///   u8  qs[16]     — packed 4-bit values (low nibble = first 16, high nibble = last 16)
351fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
352    const BLOCK_BYTES: usize = 18;
353    const BLOCK_ELEMS: usize = 32;
354
355    if data.len() % BLOCK_BYTES != 0 {
356        return Err(MlxError::GgufParseError(format!(
357            "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
358            data.len()
359        )));
360    }
361
362    let num_blocks = data.len() / BLOCK_BYTES;
363    if output.len() < num_blocks * BLOCK_ELEMS {
364        return Err(MlxError::GgufParseError(
365            "Q4_0 output buffer too small".into(),
366        ));
367    }
368
369    for i in 0..num_blocks {
370        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
371        let d = f16_from_le_bytes([block[0], block[1]]);
372        let qs = &block[2..18]; // 16 bytes
373
374        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
375
376        for j in 0..16 {
377            let x0 = (qs[j] & 0x0F) as i16 - 8;
378            let x1 = (qs[j] >> 4) as i16 - 8;
379            out[j] = x0 as f32 * d;
380            out[j + 16] = x1 as f32 * d;
381        }
382    }
383    Ok(())
384}
385
386/// Dequantize Q8_0 blocks to f32.
387///
388/// Block layout (34 bytes, 32 elements):
389///   f16 d         — scale
390///   i8  qs[32]    — signed 8-bit quantized values
391fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
392    const BLOCK_BYTES: usize = 34;
393    const BLOCK_ELEMS: usize = 32;
394
395    if data.len() % BLOCK_BYTES != 0 {
396        return Err(MlxError::GgufParseError(format!(
397            "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
398            data.len()
399        )));
400    }
401
402    let num_blocks = data.len() / BLOCK_BYTES;
403    if output.len() < num_blocks * BLOCK_ELEMS {
404        return Err(MlxError::GgufParseError(
405            "Q8_0 output buffer too small".into(),
406        ));
407    }
408
409    for i in 0..num_blocks {
410        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
411        let d = f16_from_le_bytes([block[0], block[1]]);
412        let qs = &block[2..34]; // 32 bytes of i8
413
414        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
415
416        for j in 0..32 {
417            out[j] = (qs[j] as i8) as f32 * d;
418        }
419    }
420    Ok(())
421}
422
/// Extract a (scale, min) pair for sub-block `j` from the 12-byte scales
/// array used by Q4_K and Q5_K.
///
/// This matches `get_scale_min_k4` from candle / llama.cpp exactly:
///
/// For j < 4:
///   scale = scales[j] & 63
///   min   = scales[j + 4] & 63
///
/// For j >= 4:
///   scale = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4)
///   min   = (scales[j + 4] >> 4)  | ((scales[j]     >> 6) << 4)
///
/// `scales` must be at least 12 bytes and `j` < 8; callers pass slices of
/// the fixed 12-byte region, so indexing cannot go out of bounds there.
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    if j < 4 {
        // First four sub-blocks: 6-bit values stored directly.
        let sc = scales[j] & 63;
        let m = scales[j + 4] & 63;
        (sc, m)
    } else {
        // Last four: low 4 bits packed in bytes 8..12, high 2 bits borrowed
        // from the top bits of the first eight bytes.
        let sc = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4);
        let m = (scales[j + 4] >> 4) | ((scales[j] >> 6) << 4);
        (sc, m)
    }
}

/// Dequantize Q4_K blocks to f32.
///
/// Block layout (144 bytes, 256 elements):
///   f16 d          — super-block scale          (offset 0,  2 bytes)
///   f16 dmin       — super-block minimum         (offset 2,  2 bytes)
///   u8  scales[12] — packed sub-block scales/mins (offset 4, 12 bytes)
///   u8  qs[128]    — packed 4-bit quantized values (offset 16, 128 bytes)
///
/// 8 sub-blocks of 32 elements each.  Each pair of sub-blocks (64 elements)
/// shares 32 bytes of qs — the low nibble gives the first sub-block, the
/// high nibble gives the second.
fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
    const BLOCK_BYTES: usize = 144;
    const BLOCK_ELEMS: usize = 256;

    if data.len() % BLOCK_BYTES != 0 {
        return Err(MlxError::GgufParseError(format!(
            "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
            data.len()
        )));
    }

    let num_blocks = data.len() / BLOCK_BYTES;
    if output.len() < num_blocks * BLOCK_ELEMS {
        return Err(MlxError::GgufParseError(
            "Q4_K output buffer too small".into(),
        ));
    }

    for i in 0..num_blocks {
        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];

        // Super-block scale and minimum; sub-block values are 6-bit
        // multipliers of these (see get_scale_min_k4).
        let d = f16_from_le_bytes([block[0], block[1]]);
        let dmin = f16_from_le_bytes([block[2], block[3]]);
        let scales = &block[4..16];   // 12 bytes
        let qs = &block[16..144];     // 128 bytes

        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];

        // Process 4 pairs of sub-blocks (8 sub-blocks total, 256 elements).
        // Each iteration handles 64 elements: sub-block `is` (low nibbles)
        // and sub-block `is+1` (high nibbles) from 32 bytes of qs.
        let mut is = 0usize;
        let mut ys_index = 0usize;

        // Step through the 256-element super-block in chunks of 64.
        // j tracks the byte offset within qs.
        let mut j = 0usize;
        while j < 128 {
            let q = &qs[j..j + 32];
            let (sc1, m1) = get_scale_min_k4(is, scales);
            let d1 = d * sc1 as f32;
            let min1 = dmin * m1 as f32;
            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
            let d2 = d * sc2 as f32;
            let min2 = dmin * m2 as f32;

            // Low nibbles: sub-block `is` (32 elements)
            for byte in q.iter() {
                out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
                ys_index += 1;
            }
            // High nibbles: sub-block `is + 1` (32 elements)
            for byte in q.iter() {
                out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
                ys_index += 1;
            }

            is += 2;
            j += 32;
        }
    }
    Ok(())
}

/// Dequantize Q6_K blocks to f32.
///
/// Block layout (210 bytes, 256 elements):
///   u8   ql[128]   — low 4 bits of quantized values  (offset 0, 128 bytes)
///   u8   qh[64]    — high 2 bits of quantized values  (offset 128, 64 bytes)
///   i8   scales[16] — sub-block scales                (offset 192, 16 bytes)
///   f16  d          — super-block scale               (offset 208, 2 bytes)
///
/// 256 elements organized as 2 groups of 128.  Each group of 128 has its own
/// ql[64], qh[32] region and produces 4 interleaved sub-groups of 32.
fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
    const BLOCK_BYTES: usize = 210;
    const BLOCK_ELEMS: usize = 256;

    if data.len() % BLOCK_BYTES != 0 {
        return Err(MlxError::GgufParseError(format!(
            "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
            data.len()
        )));
    }

    let num_blocks = data.len() / BLOCK_BYTES;
    if output.len() < num_blocks * BLOCK_ELEMS {
        return Err(MlxError::GgufParseError(
            "Q6_K output buffer too small".into(),
        ));
    }

    for i in 0..num_blocks {
        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];

        let ql = &block[0..128];
        let qh = &block[128..192];
        let sc = &block[192..208]; // i8 scales[16]
        let d = f16_from_le_bytes([block[208], block[209]]);

        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];

        // Process in two groups of 128 (idx = 0 and idx = 1).
        for idx in 0..2 {
            let ql_base = &ql[64 * idx..];
            let qh_base = &qh[32 * idx..];
            let sc_base = &sc[8 * idx..];
            let out_base = &mut out[128 * idx..];

            for l in 0..32 {
                let is = l / 16; // 0 for l in 0..16, 1 for l in 16..32

                // Reassemble each 6-bit value: 4 bits from ql, 2 bits from a
                // shared qh byte, then center by subtracting the bias 32.
                // q1..q4 land 32 elements apart (interleaved sub-groups).
                let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
                let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
                    - 32_i8;
                let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
                let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
                    - 32_i8;

                // Each sub-group of 32 has its own signed scale in sc_base.
                out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
                out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
                out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
                out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
            }
        }
    }
    Ok(())
}

588/// Dequantize F16 data to F32.
589fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
590    if data.len() % 2 != 0 {
591        return Err(MlxError::GgufParseError(
592            "F16 data length not even".into(),
593        ));
594    }
595    let count = data.len() / 2;
596    if output.len() < count {
597        return Err(MlxError::GgufParseError(
598            "F16 output buffer too small".into(),
599        ));
600    }
601    for i in 0..count {
602        output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
603    }
604    Ok(())
605}
606
607/// Reinterpret F32 little-endian bytes into the output slice.
608fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
609    if data.len() % 4 != 0 {
610        return Err(MlxError::GgufParseError(
611            "F32 data length not multiple of 4".into(),
612        ));
613    }
614    let count = data.len() / 4;
615    if output.len() < count {
616        return Err(MlxError::GgufParseError(
617            "F32 output buffer too small".into(),
618        ));
619    }
620    for i in 0..count {
621        output[i] = f32::from_le_bytes([
622            data[4 * i],
623            data[4 * i + 1],
624            data[4 * i + 2],
625            data[4 * i + 3],
626        ]);
627    }
628    Ok(())
629}
630
631/// Dequantize raw GGML block data to f32.
632fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
633    match ggml_type {
634        GgmlType::F32 => copy_f32(data, output),
635        GgmlType::F16 => dequantize_f16(data, output),
636        GgmlType::Q4_0 => dequantize_q4_0(data, output),
637        GgmlType::Q8_0 => dequantize_q8_0(data, output),
638        GgmlType::Q4_K => dequantize_q4_k(data, output),
639        GgmlType::Q6_K => dequantize_q6_k(data, output),
640    }
641}
642
643// ---------------------------------------------------------------------------
644// GgufFile implementation
645// ---------------------------------------------------------------------------
646
647impl GgufFile {
648    /// Open and parse a GGUF v3 file.
649    ///
650    /// This reads the full header (magic, version, tensor count, metadata KV
651    /// pairs, tensor info entries) but does **not** read any tensor data.
652    /// Tensor data is loaded lazily via [`load_tensor`](Self::load_tensor) or
653    /// [`load_tensor_f32`](Self::load_tensor_f32).
654    ///
655    /// # Errors
656    ///
657    /// Returns `MlxError::IoError` if the file cannot be opened.
658    /// Returns `MlxError::GgufParseError` if the file is not valid GGUF v3.
659    pub fn open(path: &Path) -> Result<Self> {
660        let file = std::fs::File::open(path).map_err(|e| {
661            MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
662        })?;
663        let mut reader = BufReader::new(file);
664
665        // --- Header ---
666        let magic = read_u32(&mut reader)?;
667        if magic != GGUF_MAGIC {
668            return Err(MlxError::GgufParseError(format!(
669                "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
670            )));
671        }
672
673        let version = read_u32(&mut reader)?;
674        if version != GGUF_VERSION {
675            return Err(MlxError::GgufParseError(format!(
676                "unsupported GGUF version {version} (only v3 is supported)"
677            )));
678        }
679
680        let tensor_count = read_u64(&mut reader)? as usize;
681        let metadata_kv_count = read_u64(&mut reader)? as usize;
682
683        // Sanity limits to prevent OOM on corrupted files.
684        if tensor_count > 100_000 {
685            return Err(MlxError::GgufParseError(format!(
686                "tensor_count {tensor_count} exceeds 100k safety limit"
687            )));
688        }
689        if metadata_kv_count > 1_000_000 {
690            return Err(MlxError::GgufParseError(format!(
691                "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
692            )));
693        }
694
695        // --- Metadata KV pairs ---
696        let mut metadata = HashMap::with_capacity(metadata_kv_count);
697        for _ in 0..metadata_kv_count {
698            let key = read_gguf_string(&mut reader)?;
699            let value_type = read_u32(&mut reader)?;
700            let value = read_metadata_value(&mut reader, value_type)?;
701            metadata.insert(key, value);
702        }
703
704        // --- Determine alignment ---
705        let alignment = metadata
706            .get(GGUF_ALIGNMENT_KEY)
707            .and_then(|v| v.as_u32())
708            .map(|v| v as u64)
709            .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
710
711        if alignment == 0 || (alignment & (alignment - 1)) != 0 {
712            return Err(MlxError::GgufParseError(format!(
713                "alignment {alignment} is not a power of two"
714            )));
715        }
716
717        // --- Tensor info entries ---
718        let mut tensors = HashMap::with_capacity(tensor_count);
719        for _ in 0..tensor_count {
720            let name = read_gguf_string(&mut reader)?;
721            let n_dims = read_u32(&mut reader)? as usize;
722
723            if n_dims > 8 {
724                return Err(MlxError::GgufParseError(format!(
725                    "tensor '{name}' has {n_dims} dimensions (max 8)"
726                )));
727            }
728
729            let mut shape = Vec::with_capacity(n_dims);
730            for _ in 0..n_dims {
731                shape.push(read_u64(&mut reader)? as usize);
732            }
733            // GGUF stores dimensions innermost-first (column-major order).
734            // Reverse to match the [rows, cols] convention used by candle
735            // and by the rest of hf2q's weight loading code.
736            shape.reverse();
737
738            let ggml_type_id = read_u32(&mut reader)?;
739            let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
740                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
741            })?;
742
743            let offset = read_u64(&mut reader)?;
744            let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
745                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
746            })?;
747
748            tensors.insert(
749                name.clone(),
750                TensorInfo {
751                    name,
752                    shape,
753                    ggml_type,
754                    offset,
755                    byte_len,
756                },
757            );
758        }
759
760        // --- Compute tensor_data_offset ---
761        // The current file position is just past all tensor info entries.
762        // Tensor data starts at the next alignment boundary.
763        let pos = reader
764            .stream_position()
765            .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
766        let tensor_data_offset = align_offset(pos, alignment);
767
768        Ok(GgufFile {
769            metadata,
770            tensors,
771            tensor_data_offset,
772            reader: Mutex::new(reader),
773        })
774    }
775
776    // -----------------------------------------------------------------------
777    // Metadata accessors
778    // -----------------------------------------------------------------------
779
780    /// Look up a metadata value by key.
781    pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
782        self.metadata.get(key)
783    }
784
785    /// Look up a metadata string value by key.
786    pub fn metadata_string(&self, key: &str) -> Option<&str> {
787        self.metadata.get(key).and_then(|v| v.as_str())
788    }
789
790    /// Look up a metadata u32 value by key.
791    pub fn metadata_u32(&self, key: &str) -> Option<u32> {
792        self.metadata.get(key).and_then(|v| v.as_u32())
793    }
794
795    /// Look up a metadata f32 value by key.
796    pub fn metadata_f32(&self, key: &str) -> Option<f32> {
797        self.metadata.get(key).and_then(|v| v.as_f32())
798    }
799
800    // -----------------------------------------------------------------------
801    // Tensor info accessors
802    // -----------------------------------------------------------------------
803
804    /// Return the names of all tensors in the file.
805    pub fn tensor_names(&self) -> Vec<&str> {
806        self.tensors.keys().map(|s| s.as_str()).collect()
807    }
808
809    /// Look up info for a specific tensor by name.
810    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
811        self.tensors.get(name)
812    }
813
814    /// Number of tensors in the file.
815    pub fn tensor_count(&self) -> usize {
816        self.tensors.len()
817    }
818
819    /// Number of metadata key-value pairs.
820    pub fn metadata_count(&self) -> usize {
821        self.metadata.len()
822    }
823
824    // -----------------------------------------------------------------------
825    // Tensor loading
826    // -----------------------------------------------------------------------
827
828    /// Read raw tensor bytes from the file.
829    ///
830    /// This is a private helper that seeks to the tensor's location and reads
831    /// `byte_len` bytes.
832    fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
833        let abs_offset = self.tensor_data_offset + info.offset;
834        let mut reader = self
835            .reader
836            .lock()
837            .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
838
839        reader
840            .seek(SeekFrom::Start(abs_offset))
841            .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
842
843        let mut buf = vec![0u8; info.byte_len];
844        reader.read_exact(&mut buf).map_err(|e| {
845            MlxError::IoError(format!(
846                "read tensor '{}' ({} bytes at offset {}): {e}",
847                info.name, info.byte_len, abs_offset
848            ))
849        })?;
850
851        Ok(buf)
852    }
853
854    /// Load a tensor as a raw buffer on the Metal device.
855    ///
856    /// For quantized types (Q4_0, Q8_0, Q4_K, Q6_K) the buffer contains raw
857    /// GGML blocks with dtype `U8` — these are consumed directly by
858    /// `quantized_matmul_ggml` kernels.
859    ///
860    /// For F32 and F16 tensors the buffer has the corresponding typed dtype.
861    ///
862    /// # Errors
863    ///
864    /// Returns an error if the tensor name is not found, or if reading fails.
865    pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
866        let info = self.tensors.get(name).ok_or_else(|| {
867            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
868        })?;
869
870        let data = self.read_tensor_bytes(info)?;
871
872        match info.ggml_type {
873            GgmlType::F32 => {
874                let mut buf =
875                    device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
876                {
877                    let slice: &mut [u8] = buf.as_mut_slice()?;
878                    slice.copy_from_slice(&data);
879                }
880                Ok(buf)
881            }
882            GgmlType::F16 => {
883                let mut buf =
884                    device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
885                {
886                    let slice: &mut [u8] = buf.as_mut_slice()?;
887                    slice.copy_from_slice(&data);
888                }
889                Ok(buf)
890            }
891            GgmlType::Q4_0 | GgmlType::Q8_0 | GgmlType::Q4_K | GgmlType::Q6_K => {
892                // Store raw GGML blocks as U8 buffer.
893                let mut buf =
894                    device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
895                {
896                    let slice: &mut [u8] = buf.as_mut_slice()?;
897                    slice.copy_from_slice(&data);
898                }
899                Ok(buf)
900            }
901        }
902    }
903
904    /// Load a tensor, dequantizing to F32 on the CPU, then upload to the
905    /// Metal device.
906    ///
907    /// This is used for norm weights, embedding tables, and other tensors
908    /// where the inference kernels operate on F32 directly.
909    ///
910    /// # Errors
911    ///
912    /// Returns an error if the tensor name is not found, reading fails, or
913    /// dequantization encounters malformed data.
914    pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
915        let info = self.tensors.get(name).ok_or_else(|| {
916            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
917        })?;
918
919        let data = self.read_tensor_bytes(info)?;
920        let total_elements: usize = info.shape.iter().product();
921
922        if total_elements == 0 {
923            return Err(MlxError::GgufParseError(format!(
924                "tensor '{name}' has zero elements"
925            )));
926        }
927
928        let f32_byte_len = total_elements * 4;
929        let mut buf =
930            device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
931
932        {
933            let out_slice: &mut [f32] = buf.as_mut_slice()?;
934            dequantize_to_f32(&data, info.ggml_type, out_slice)?;
935        }
936
937        Ok(buf)
938    }
939}
940
// ---------------------------------------------------------------------------
// Utility
// ---------------------------------------------------------------------------

/// Round `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a non-zero power of two; `GgufFile::open` validates
/// this before calling.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    // How far past the previous boundary we are (power-of-two mask trick).
    let misalignment = offset & (alignment - 1);
    if misalignment == 0 {
        offset
    } else {
        offset + (alignment - misalignment)
    }
}