// mlx_native/gguf/mod.rs
//! GGUF v3 file format parser.
//!
//! Parses GGUF headers, metadata, and tensor info on open.  Tensor data is
//! loaded lazily on demand into [`MlxBuffer`]s — either as raw GGML blocks
//! (for GPU quantized matmul) or dequantized to F32 (for norm weights etc.).
//!
//! # Example
//!
//! ```ignore
//! use mlx_native::gguf::GgufFile;
//! use std::path::Path;
//!
//! let gguf = GgufFile::open(Path::new("model.gguf"))?;
//! let names = gguf.tensor_names();
//! let buf = gguf.load_tensor("blk.0.attn_q.weight", &device)?;
//! let norm = gguf.load_tensor_f32("blk.0.attn_norm.weight", &device)?;
//! ```

19use std::collections::HashMap;
20use std::io::{BufReader, Read, Seek, SeekFrom};
21use std::path::Path;
22use std::sync::Mutex;
23
24use half::f16;
25
26use crate::ops::quantized_matmul_ggml::GgmlType;
27use crate::{DType, MlxBuffer, MlxBufferPool, MlxDevice, MlxError, Result};
28
// ---------------------------------------------------------------------------
// GGUF constants
// ---------------------------------------------------------------------------

/// GGUF magic number: the ASCII bytes "GGUF" read as a little-endian u32
/// (file bytes 0x47 0x47 0x55 0x46 → 0x4655_4747).
const GGUF_MAGIC: u32 = 0x4655_4747;

/// The only GGUF container version this parser accepts.
const GGUF_VERSION: u32 = 3;

/// Alignment (in bytes) of the tensor data section when the file does not
/// specify one via [`GGUF_ALIGNMENT_KEY`].
const GGUF_DEFAULT_ALIGNMENT: u64 = 32;

/// Metadata key whose value, when present, overrides the default alignment.
const GGUF_ALIGNMENT_KEY: &str = "general.alignment";
44
// ---------------------------------------------------------------------------
// GGUF metadata value type IDs (wire-format discriminants, per the GGUF spec)
// ---------------------------------------------------------------------------

const GGUF_TYPE_UINT8: u32 = 0;
const GGUF_TYPE_INT8: u32 = 1;
const GGUF_TYPE_UINT16: u32 = 2;
const GGUF_TYPE_INT16: u32 = 3;
const GGUF_TYPE_UINT32: u32 = 4;
const GGUF_TYPE_INT32: u32 = 5;
const GGUF_TYPE_FLOAT32: u32 = 6;
const GGUF_TYPE_BOOL: u32 = 7;
const GGUF_TYPE_STRING: u32 = 8;
const GGUF_TYPE_ARRAY: u32 = 9;
const GGUF_TYPE_UINT64: u32 = 10;
const GGUF_TYPE_INT64: u32 = 11;
const GGUF_TYPE_FLOAT64: u32 = 12;
62
// ---------------------------------------------------------------------------
// GGML type IDs (from ggml.h) — only the types this crate supports are
// listed, hence the gaps in the numbering.
// ---------------------------------------------------------------------------

const GGML_TYPE_F32: u32 = 0;
const GGML_TYPE_F16: u32 = 1;
const GGML_TYPE_Q4_0: u32 = 2;
const GGML_TYPE_Q5_1: u32 = 7;
const GGML_TYPE_Q8_0: u32 = 8;
const GGML_TYPE_Q4_K: u32 = 12;
const GGML_TYPE_Q5_K: u32 = 13;
const GGML_TYPE_Q6_K: u32 = 14;
const GGML_TYPE_I16: u32 = 17;
const GGML_TYPE_IQ4_NL: u32 = 20;

/// IQ4_NL non-linear codebook: 16 signed entries selected by the 4-bit
/// indices stored in `block_iq4_nl::qs`. Verified byte-equal with
/// `/opt/llama.cpp/ggml/src/ggml-common.h:1109-1112`. ADR-022 Phase 1.
const KVALUES_IQ4_NL: [i8; 16] = [
    -127, -104, -83, -65, -49, -35, -22, -10, 1, 13, 25, 38, 53, 69, 89, 113,
];
84
// ---------------------------------------------------------------------------
// Public types
// ---------------------------------------------------------------------------

/// GGUF metadata value types, mirroring the 13 wire-format value kinds.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    Uint8(u8),
    Int8(i8),
    Uint16(u16),
    Int16(i16),
    Uint32(u32),
    Int32(i32),
    Float32(f32),
    Bool(bool),
    String(String),
    Array(Vec<MetadataValue>),
    Uint64(u64),
    Int64(i64),
    Float64(f64),
}

impl MetadataValue {
    /// Try to interpret this value as a string reference.
    ///
    /// Returns `None` for every non-`String` variant.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            MetadataValue::String(s) => Some(s.as_str()),
            _ => None,
        }
    }

    /// Try to interpret this value as a u32.
    ///
    /// Accepts any integer variant whose value fits losslessly in `u32`
    /// (GGUF writers are free to store small counts as 8/16/32/64-bit,
    /// signed or unsigned). Returns `None` for negative values, values
    /// above `u32::MAX`, and non-integer variants.
    pub fn as_u32(&self) -> Option<u32> {
        match self {
            MetadataValue::Uint32(v) => Some(*v),
            MetadataValue::Uint8(v) => Some(*v as u32),
            MetadataValue::Uint16(v) => Some(*v as u32),
            // Lossless narrowing only: out-of-range 64-bit values map to None.
            MetadataValue::Uint64(v) => u32::try_from(*v).ok(),
            MetadataValue::Int8(v) if *v >= 0 => Some(*v as u32),
            MetadataValue::Int16(v) if *v >= 0 => Some(*v as u32),
            MetadataValue::Int32(v) if *v >= 0 => Some(*v as u32),
            MetadataValue::Int64(v) => u32::try_from(*v).ok(),
            _ => None,
        }
    }

    /// Try to interpret this value as an f32.
    ///
    /// `Float64` is narrowed with `as`, which may round; returns `None`
    /// for non-float variants.
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            MetadataValue::Float32(v) => Some(*v),
            MetadataValue::Float64(v) => Some(*v as f32),
            _ => None,
        }
    }
}
136
/// Information about a single tensor in the GGUF file, as parsed from the
/// tensor-info section of the header.
#[derive(Debug, Clone)]
pub struct TensorInfo {
    /// Tensor name (e.g. "blk.0.attn_q.weight").
    pub name: String,
    /// Tensor shape, innermost dimension first (as stored in GGUF).
    pub shape: Vec<usize>,
    /// GGML quantization type.
    pub ggml_type: GgmlType,
    /// Byte offset relative to the start of the tensor data section
    /// (not the start of the file).
    pub offset: u64,
    /// Total byte length of this tensor's data.
    pub byte_len: usize,
}
151
/// A parsed GGUF file, ready for lazy tensor loading.
///
/// The file is kept open so that tensor data can be read on demand via
/// [`load_tensor`](GgufFile::load_tensor) and
/// [`load_tensor_f32`](GgufFile::load_tensor_f32).
pub struct GgufFile {
    /// Key/value metadata from the header (e.g. "general.alignment").
    metadata: HashMap<String, MetadataValue>,
    /// Per-tensor info, keyed by tensor name.
    tensors: HashMap<String, TensorInfo>,
    /// Absolute byte offset in the file where tensor data begins.
    tensor_data_offset: u64,
    /// Open handle for on-demand reads. The Mutex serializes seek+read
    /// pairs, presumably so tensor loads can go through `&self` — confirm
    /// against the load_tensor implementations.
    reader: Mutex<BufReader<std::fs::File>>,
}
164
165// ---------------------------------------------------------------------------
166// Low-level read helpers
167// ---------------------------------------------------------------------------
168
169/// Read a little-endian u8.
170fn read_u8<R: Read>(r: &mut R) -> Result<u8> {
171    let mut buf = [0u8; 1];
172    r.read_exact(&mut buf)
173        .map_err(|e| MlxError::GgufParseError(format!("read u8: {e}")))?;
174    Ok(buf[0])
175}
176
177/// Read a little-endian i8.
178fn read_i8<R: Read>(r: &mut R) -> Result<i8> {
179    Ok(read_u8(r)? as i8)
180}
181
182/// Read a little-endian u16.
183fn read_u16<R: Read>(r: &mut R) -> Result<u16> {
184    let mut buf = [0u8; 2];
185    r.read_exact(&mut buf)
186        .map_err(|e| MlxError::GgufParseError(format!("read u16: {e}")))?;
187    Ok(u16::from_le_bytes(buf))
188}
189
190/// Read a little-endian i16.
191fn read_i16<R: Read>(r: &mut R) -> Result<i16> {
192    let mut buf = [0u8; 2];
193    r.read_exact(&mut buf)
194        .map_err(|e| MlxError::GgufParseError(format!("read i16: {e}")))?;
195    Ok(i16::from_le_bytes(buf))
196}
197
198/// Read a little-endian u32.
199fn read_u32<R: Read>(r: &mut R) -> Result<u32> {
200    let mut buf = [0u8; 4];
201    r.read_exact(&mut buf)
202        .map_err(|e| MlxError::GgufParseError(format!("read u32: {e}")))?;
203    Ok(u32::from_le_bytes(buf))
204}
205
206/// Read a little-endian i32.
207fn read_i32<R: Read>(r: &mut R) -> Result<i32> {
208    let mut buf = [0u8; 4];
209    r.read_exact(&mut buf)
210        .map_err(|e| MlxError::GgufParseError(format!("read i32: {e}")))?;
211    Ok(i32::from_le_bytes(buf))
212}
213
214/// Read a little-endian u64.
215fn read_u64<R: Read>(r: &mut R) -> Result<u64> {
216    let mut buf = [0u8; 8];
217    r.read_exact(&mut buf)
218        .map_err(|e| MlxError::GgufParseError(format!("read u64: {e}")))?;
219    Ok(u64::from_le_bytes(buf))
220}
221
222/// Read a little-endian i64.
223fn read_i64<R: Read>(r: &mut R) -> Result<i64> {
224    let mut buf = [0u8; 8];
225    r.read_exact(&mut buf)
226        .map_err(|e| MlxError::GgufParseError(format!("read i64: {e}")))?;
227    Ok(i64::from_le_bytes(buf))
228}
229
230/// Read a little-endian f32.
231fn read_f32<R: Read>(r: &mut R) -> Result<f32> {
232    let mut buf = [0u8; 4];
233    r.read_exact(&mut buf)
234        .map_err(|e| MlxError::GgufParseError(format!("read f32: {e}")))?;
235    Ok(f32::from_le_bytes(buf))
236}
237
238/// Read a little-endian f64.
239fn read_f64<R: Read>(r: &mut R) -> Result<f64> {
240    let mut buf = [0u8; 8];
241    r.read_exact(&mut buf)
242        .map_err(|e| MlxError::GgufParseError(format!("read f64: {e}")))?;
243    Ok(f64::from_le_bytes(buf))
244}
245
246/// Read a GGUF-format string: u64 length followed by UTF-8 bytes (not
247/// null-terminated).
248fn read_gguf_string<R: Read>(r: &mut R) -> Result<String> {
249    let len = read_u64(r)? as usize;
250    if len > 256 * 1024 * 1024 {
251        return Err(MlxError::GgufParseError(format!(
252            "string length {len} exceeds 256 MiB safety limit"
253        )));
254    }
255    let mut buf = vec![0u8; len];
256    r.read_exact(&mut buf)
257        .map_err(|e| MlxError::GgufParseError(format!("read string bytes: {e}")))?;
258    String::from_utf8(buf)
259        .map_err(|e| MlxError::GgufParseError(format!("invalid UTF-8 in string: {e}")))
260}
261
262// ---------------------------------------------------------------------------
263// Metadata value parsing
264// ---------------------------------------------------------------------------
265
266/// Read a single metadata value of the given type.
267fn read_metadata_value<R: Read>(r: &mut R, value_type: u32) -> Result<MetadataValue> {
268    match value_type {
269        GGUF_TYPE_UINT8 => Ok(MetadataValue::Uint8(read_u8(r)?)),
270        GGUF_TYPE_INT8 => Ok(MetadataValue::Int8(read_i8(r)?)),
271        GGUF_TYPE_UINT16 => Ok(MetadataValue::Uint16(read_u16(r)?)),
272        GGUF_TYPE_INT16 => Ok(MetadataValue::Int16(read_i16(r)?)),
273        GGUF_TYPE_UINT32 => Ok(MetadataValue::Uint32(read_u32(r)?)),
274        GGUF_TYPE_INT32 => Ok(MetadataValue::Int32(read_i32(r)?)),
275        GGUF_TYPE_FLOAT32 => Ok(MetadataValue::Float32(read_f32(r)?)),
276        GGUF_TYPE_BOOL => {
277            let byte = read_u8(r)?;
278            Ok(MetadataValue::Bool(byte != 0))
279        }
280        GGUF_TYPE_STRING => Ok(MetadataValue::String(read_gguf_string(r)?)),
281        GGUF_TYPE_ARRAY => {
282            let elem_type = read_u32(r)?;
283            let count = read_u64(r)? as usize;
284            if count > 64 * 1024 * 1024 {
285                return Err(MlxError::GgufParseError(format!(
286                    "array count {count} exceeds 64M element safety limit"
287                )));
288            }
289            let mut elems = Vec::with_capacity(count);
290            for _ in 0..count {
291                elems.push(read_metadata_value(r, elem_type)?);
292            }
293            Ok(MetadataValue::Array(elems))
294        }
295        GGUF_TYPE_UINT64 => Ok(MetadataValue::Uint64(read_u64(r)?)),
296        GGUF_TYPE_INT64 => Ok(MetadataValue::Int64(read_i64(r)?)),
297        GGUF_TYPE_FLOAT64 => Ok(MetadataValue::Float64(read_f64(r)?)),
298        other => Err(MlxError::GgufParseError(format!(
299            "unknown metadata value type {other}"
300        ))),
301    }
302}
303
304// ---------------------------------------------------------------------------
305// GGML type mapping
306// ---------------------------------------------------------------------------
307
308/// Map a GGML type ID (u32 from the GGUF file) to our `GgmlType` enum.
309fn ggml_type_from_u32(id: u32) -> Result<GgmlType> {
310    match id {
311        GGML_TYPE_F32 => Ok(GgmlType::F32),
312        GGML_TYPE_F16 => Ok(GgmlType::F16),
313        GGML_TYPE_Q4_0 => Ok(GgmlType::Q4_0),
314        GGML_TYPE_Q5_1 => Ok(GgmlType::Q5_1),
315        GGML_TYPE_Q8_0 => Ok(GgmlType::Q8_0),
316        GGML_TYPE_Q4_K => Ok(GgmlType::Q4_K),
317        GGML_TYPE_Q5_K => Ok(GgmlType::Q5_K),
318        GGML_TYPE_Q6_K => Ok(GgmlType::Q6_K),
319        GGML_TYPE_I16 => Ok(GgmlType::I16),
320        GGML_TYPE_IQ4_NL => Ok(GgmlType::IQ4_NL),
321        other => Err(MlxError::GgufParseError(format!(
322            "unsupported GGML type ID {other}"
323        ))),
324    }
325}
326
327/// Compute the byte length of a tensor from its shape and GGML type.
328///
329/// For quantized types, the innermost dimension (shape[0] in GGUF's row-major
330/// convention) must be divisible by the block's element count.
331fn compute_byte_len(shape: &[usize], ggml_type: GgmlType) -> Result<usize> {
332    let total_elements: usize = shape.iter().product();
333    if total_elements == 0 {
334        return Ok(0);
335    }
336
337    let elems_per_block = ggml_type.block_values() as usize;
338    let bytes_per_block = ggml_type.block_bytes() as usize;
339
340    if total_elements % elems_per_block != 0 {
341        return Err(MlxError::GgufParseError(format!(
342            "total elements {total_elements} not divisible by block size {elems_per_block} \
343             for type {:?}",
344            ggml_type
345        )));
346    }
347
348    Ok((total_elements / elems_per_block) * bytes_per_block)
349}
350
351// ---------------------------------------------------------------------------
352// Dequantization
353// ---------------------------------------------------------------------------
354
355/// Convert a raw little-endian f16 (2 bytes) to f32.
356#[inline]
357fn f16_from_le_bytes(bytes: [u8; 2]) -> f32 {
358    f16::from_le_bytes(bytes).to_f32()
359}
360
361/// Dequantize Q4_0 blocks to f32.
362///
363/// Block layout (18 bytes, 32 elements):
364///   f16 d          — scale
365///   u8  qs[16]     — packed 4-bit values (low nibble = first 16, high nibble = last 16)
366fn dequantize_q4_0(data: &[u8], output: &mut [f32]) -> Result<()> {
367    const BLOCK_BYTES: usize = 18;
368    const BLOCK_ELEMS: usize = 32;
369
370    if data.len() % BLOCK_BYTES != 0 {
371        return Err(MlxError::GgufParseError(format!(
372            "Q4_0 data length {} not divisible by block size {BLOCK_BYTES}",
373            data.len()
374        )));
375    }
376
377    let num_blocks = data.len() / BLOCK_BYTES;
378    if output.len() < num_blocks * BLOCK_ELEMS {
379        return Err(MlxError::GgufParseError(
380            "Q4_0 output buffer too small".into(),
381        ));
382    }
383
384    for i in 0..num_blocks {
385        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
386        let d = f16_from_le_bytes([block[0], block[1]]);
387        let qs = &block[2..18]; // 16 bytes
388
389        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
390
391        for j in 0..16 {
392            let x0 = (qs[j] & 0x0F) as i16 - 8;
393            let x1 = (qs[j] >> 4) as i16 - 8;
394            out[j] = x0 as f32 * d;
395            out[j + 16] = x1 as f32 * d;
396        }
397    }
398    Ok(())
399}
400
401/// Dequantize Q8_0 blocks to f32.
402///
403/// Block layout (34 bytes, 32 elements):
404///   f16 d         — scale
405///   i8  qs[32]    — signed 8-bit quantized values
406fn dequantize_q8_0(data: &[u8], output: &mut [f32]) -> Result<()> {
407    const BLOCK_BYTES: usize = 34;
408    const BLOCK_ELEMS: usize = 32;
409
410    if data.len() % BLOCK_BYTES != 0 {
411        return Err(MlxError::GgufParseError(format!(
412            "Q8_0 data length {} not divisible by block size {BLOCK_BYTES}",
413            data.len()
414        )));
415    }
416
417    let num_blocks = data.len() / BLOCK_BYTES;
418    if output.len() < num_blocks * BLOCK_ELEMS {
419        return Err(MlxError::GgufParseError(
420            "Q8_0 output buffer too small".into(),
421        ));
422    }
423
424    for i in 0..num_blocks {
425        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
426        let d = f16_from_le_bytes([block[0], block[1]]);
427        let qs = &block[2..34]; // 32 bytes of i8
428
429        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
430
431        for j in 0..32 {
432            out[j] = (qs[j] as i8) as f32 * d;
433        }
434    }
435    Ok(())
436}
437
/// Extract a (scale, min) pair for sub-block `j` from the 12-byte scales
/// array used by Q4_K and Q5_K.
///
/// Matches `get_scale_min_k4` from candle / llama.cpp exactly. The first
/// four sub-blocks store 6-bit scales and mins directly; the last four
/// pack their low 4 bits into bytes 8..12 and reuse the top 2 bits of
/// bytes 0..8 as the high bits:
///
/// For j < 4:
///   scale = scales[j] & 63
///   min   = scales[j + 4] & 63
///
/// For j >= 4:
///   scale = (scales[j + 4] & 0xF) | ((scales[j - 4] >> 6) << 4)
///   min   = (scales[j + 4] >> 4)  | ((scales[j]     >> 6) << 4)
#[inline]
fn get_scale_min_k4(j: usize, scales: &[u8]) -> (u8, u8) {
    match j {
        0..=3 => (scales[j] & 63, scales[j + 4] & 63),
        _ => {
            let sc_hi = (scales[j - 4] >> 6) << 4;
            let m_hi = (scales[j] >> 6) << 4;
            ((scales[j + 4] & 0x0F) | sc_hi, (scales[j + 4] >> 4) | m_hi)
        }
    }
}
462
/// Dequantize Q5_K blocks to f32.
///
/// Block layout (176 bytes, 256 elements):
///   f16 d           — super-block scale       (offset 0,  2 bytes)
///   f16 dmin        — super-block minimum     (offset 2,  2 bytes)
///   u8  scales[12]  — packed 6-bit scales/mins (offset 4,  12 bytes; shared with Q4_K)
///   u8  qh[32]      — high bits of quants      (offset 16, 32 bytes = QK_K/8)
///   u8  qs[128]     — low 4 bits of quants     (offset 48, 128 bytes = QK_K/2)
///
/// 8 sub-blocks of 32 elements each. Dequantization walks pairs of
/// sub-blocks (is, is+1), each pair consumes 32 bytes of qs (low nibble
/// for is, high nibble for is+1). The qh array is SHARED across all 4
/// pairs — the high bit per element is masked out of qh using shifting
/// selector values `u1 = 1 << (2*pair_idx)` / `u2 = 2 << (2*pair_idx)`.
///
/// Per element: `out = d * scale * q - dmin * min`, where q is the 5-bit
/// value assembled from a qs nibble plus one qh bit (worth +16).
///
/// Spec source: derived from `ggml/src/ggml-quants.c::dequantize_row_q5_K`.
/// No code copied — formula reproduced from the mathematical definition.
fn dequantize_q5_k(data: &[u8], output: &mut [f32]) -> Result<()> {
    const BLOCK_BYTES: usize = 176;
    const BLOCK_ELEMS: usize = 256;

    // Input must be a whole number of 176-byte super-blocks.
    if data.len() % BLOCK_BYTES != 0 {
        return Err(MlxError::GgufParseError(format!(
            "Q5_K data length {} not divisible by block size {BLOCK_BYTES}",
            data.len()
        )));
    }

    let num_blocks = data.len() / BLOCK_BYTES;
    if output.len() < num_blocks * BLOCK_ELEMS {
        return Err(MlxError::GgufParseError(
            "Q5_K output buffer too small".into(),
        ));
    }

    for i in 0..num_blocks {
        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];

        let d = f16_from_le_bytes([block[0], block[1]]);
        let dmin = f16_from_le_bytes([block[2], block[3]]);
        let scales = &block[4..16]; // 12 bytes
        let qh = &block[16..48]; // 32 bytes — high bit of quants
        let qs = &block[48..176]; // 128 bytes — low 4 bits

        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];

        // Process 4 pairs of sub-blocks (256 values total).
        // u1 / u2 are the high-bit selector masks: they shift left by 2 each
        // iteration so the 4 pairs pick bits 0/1, 2/3, 4/5, 6/7 of each qh byte.
        let mut is = 0usize;
        let mut u1: u8 = 1;
        let mut u2: u8 = 2;
        let mut ys_index = 0usize;
        let mut ql_off = 0usize;

        while ql_off < 128 {
            let ql = &qs[ql_off..ql_off + 32];

            // Effective scale/min for each sub-block of the pair:
            // out = d1 * q - m1 (the min term is subtracted, not added).
            let (sc1, m1) = get_scale_min_k4(is, scales);
            let d1 = d * sc1 as f32;
            let m1 = dmin * m1 as f32;
            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
            let d2 = d * sc2 as f32;
            let m2 = dmin * m2 as f32;

            // Sub-block `is` (low nibble + high bit from qh masked by u1).
            for l in 0..32 {
                let low = (ql[l] & 0x0F) as u32;
                let high = if (qh[l] & u1) != 0 { 16 } else { 0 };
                let q = low + high;
                out[ys_index] = d1 * q as f32 - m1;
                ys_index += 1;
            }
            // Sub-block `is + 1` (high nibble + high bit from qh masked by u2).
            for l in 0..32 {
                let low = (ql[l] >> 4) as u32;
                let high = if (qh[l] & u2) != 0 { 16 } else { 0 };
                let q = low + high;
                out[ys_index] = d2 * q as f32 - m2;
                ys_index += 1;
            }

            is += 2;
            ql_off += 32;
            // On the final pair these shifts discard bits (u8 `<<` drops
            // overflowed bits without panicking), but the loop has already
            // consumed all 128 qs bytes by then, so the values are unused.
            u1 <<= 2;
            u2 <<= 2;
        }
    }
    Ok(())
}
553
554/// Dequantize I16 tensors to f32.
555///
556/// Simple bitcast: `f32_val = i16_val as f32`. No scale metadata is used
557/// (apex GGUF convention — raw int16 values are meaningful as-is).
558///
559/// ADR-013 Decision 12 originally anticipated a per-tensor scale factor,
560/// but the apex GGUF does not emit one; values are stored as raw ints.
561/// If future GGUFs emit a scale, extend this with a scale parameter.
562fn dequantize_i16(data: &[u8], output: &mut [f32]) -> Result<()> {
563    if data.len() % 2 != 0 {
564        return Err(MlxError::GgufParseError(format!(
565            "I16 data length {} not even",
566            data.len()
567        )));
568    }
569    let num_elements = data.len() / 2;
570    if output.len() < num_elements {
571        return Err(MlxError::GgufParseError(
572            "I16 output buffer too small".into(),
573        ));
574    }
575    for i in 0..num_elements {
576        let v = i16::from_le_bytes([data[2 * i], data[2 * i + 1]]);
577        output[i] = v as f32;
578    }
579    Ok(())
580}
581
582/// Dequantize Q4_K blocks to f32.
583///
584/// Block layout (144 bytes, 256 elements):
585///   f16 d          — super-block scale          (offset 0,  2 bytes)
586///   f16 dmin       — super-block minimum         (offset 2,  2 bytes)
587///   u8  scales[12] — packed sub-block scales/mins (offset 4, 12 bytes)
588///   u8  qs[128]    — packed 4-bit quantized values (offset 16, 128 bytes)
589///
590/// 8 sub-blocks of 32 elements each.  Each pair of sub-blocks (64 elements)
591/// shares 32 bytes of qs — the low nibble gives the first sub-block, the
592/// high nibble gives the second.
593fn dequantize_q4_k(data: &[u8], output: &mut [f32]) -> Result<()> {
594    const BLOCK_BYTES: usize = 144;
595    const BLOCK_ELEMS: usize = 256;
596
597    if data.len() % BLOCK_BYTES != 0 {
598        return Err(MlxError::GgufParseError(format!(
599            "Q4_K data length {} not divisible by block size {BLOCK_BYTES}",
600            data.len()
601        )));
602    }
603
604    let num_blocks = data.len() / BLOCK_BYTES;
605    if output.len() < num_blocks * BLOCK_ELEMS {
606        return Err(MlxError::GgufParseError(
607            "Q4_K output buffer too small".into(),
608        ));
609    }
610
611    for i in 0..num_blocks {
612        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
613
614        let d = f16_from_le_bytes([block[0], block[1]]);
615        let dmin = f16_from_le_bytes([block[2], block[3]]);
616        let scales = &block[4..16];   // 12 bytes
617        let qs = &block[16..144];     // 128 bytes
618
619        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
620
621        // Process 4 pairs of sub-blocks (8 sub-blocks total, 256 elements).
622        // Each iteration handles 64 elements: sub-block `is` (low nibbles)
623        // and sub-block `is+1` (high nibbles) from 32 bytes of qs.
624        let mut is = 0usize;
625        let mut ys_index = 0usize;
626
627        // Step through the 256-element super-block in chunks of 64.
628        // j tracks the byte offset within qs.
629        let mut j = 0usize;
630        while j < 128 {
631            let q = &qs[j..j + 32];
632            let (sc1, m1) = get_scale_min_k4(is, scales);
633            let d1 = d * sc1 as f32;
634            let min1 = dmin * m1 as f32;
635            let (sc2, m2) = get_scale_min_k4(is + 1, scales);
636            let d2 = d * sc2 as f32;
637            let min2 = dmin * m2 as f32;
638
639            // Low nibbles: sub-block `is` (32 elements)
640            for byte in q.iter() {
641                out[ys_index] = d1 * (*byte & 0xF) as f32 - min1;
642                ys_index += 1;
643            }
644            // High nibbles: sub-block `is + 1` (32 elements)
645            for byte in q.iter() {
646                out[ys_index] = d2 * (*byte >> 4) as f32 - min2;
647                ys_index += 1;
648            }
649
650            is += 2;
651            j += 32;
652        }
653    }
654    Ok(())
655}
656
/// Dequantize Q6_K blocks to f32.
///
/// Block layout (210 bytes, 256 elements):
///   u8   ql[128]   — low 4 bits of quantized values  (offset 0, 128 bytes)
///   u8   qh[64]    — high 2 bits of quantized values  (offset 128, 64 bytes)
///   i8   scales[16] — sub-block scales                (offset 192, 16 bytes)
///   f16  d          — super-block scale               (offset 208, 2 bytes)
///
/// 256 elements organized as 2 groups of 128.  Each group of 128 has its own
/// ql[64], qh[32] region and produces 4 interleaved sub-groups of 32: one
/// qh byte carries the two high bits for four different output positions
/// (l, l+32, l+64, l+96), in bit pairs 0/1, 2/3, 4/5, 6/7 respectively.
/// Each 6-bit value is biased by 32 before scaling.
fn dequantize_q6_k(data: &[u8], output: &mut [f32]) -> Result<()> {
    const BLOCK_BYTES: usize = 210;
    const BLOCK_ELEMS: usize = 256;

    // Input must be a whole number of 210-byte super-blocks.
    if data.len() % BLOCK_BYTES != 0 {
        return Err(MlxError::GgufParseError(format!(
            "Q6_K data length {} not divisible by block size {BLOCK_BYTES}",
            data.len()
        )));
    }

    let num_blocks = data.len() / BLOCK_BYTES;
    if output.len() < num_blocks * BLOCK_ELEMS {
        return Err(MlxError::GgufParseError(
            "Q6_K output buffer too small".into(),
        ));
    }

    for i in 0..num_blocks {
        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];

        let ql = &block[0..128];
        let qh = &block[128..192];
        let sc = &block[192..208]; // i8 scales[16]
        let d = f16_from_le_bytes([block[208], block[209]]);

        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];

        // Process in two groups of 128 (idx = 0 and idx = 1).
        for idx in 0..2 {
            let ql_base = &ql[64 * idx..];
            let qh_base = &qh[32 * idx..];
            let sc_base = &sc[8 * idx..];
            let out_base = &mut out[128 * idx..];

            for l in 0..32 {
                // Scale index: sub-groups of 16 share one scale entry.
                let is = l / 16; // 0 for l in 0..16, 1 for l in 16..32

                // Assemble each 6-bit quant from a ql nibble plus a 2-bit
                // field of the shared qh byte, then remove the +32 bias.
                let q1 = ((ql_base[l] & 0xF) | ((qh_base[l] & 3) << 4)) as i8 - 32_i8;
                let q2 = ((ql_base[l + 32] & 0xF) | (((qh_base[l] >> 2) & 3) << 4)) as i8
                    - 32_i8;
                let q3 = ((ql_base[l] >> 4) | (((qh_base[l] >> 4) & 3) << 4)) as i8 - 32_i8;
                let q4 = ((ql_base[l + 32] >> 4) | (((qh_base[l] >> 6) & 3) << 4)) as i8
                    - 32_i8;

                // Four interleaved sub-groups, each with its own i8 scale.
                out_base[l] = d * sc_base[is] as i8 as f32 * q1 as f32;
                out_base[l + 32] = d * sc_base[is + 2] as i8 as f32 * q2 as f32;
                out_base[l + 64] = d * sc_base[is + 4] as i8 as f32 * q3 as f32;
                out_base[l + 96] = d * sc_base[is + 6] as i8 as f32 * q4 as f32;
            }
        }
    }
    Ok(())
}
721
722/// Dequantize F16 data to F32.
723fn dequantize_f16(data: &[u8], output: &mut [f32]) -> Result<()> {
724    if data.len() % 2 != 0 {
725        return Err(MlxError::GgufParseError(
726            "F16 data length not even".into(),
727        ));
728    }
729    let count = data.len() / 2;
730    if output.len() < count {
731        return Err(MlxError::GgufParseError(
732            "F16 output buffer too small".into(),
733        ));
734    }
735    for i in 0..count {
736        output[i] = f16_from_le_bytes([data[2 * i], data[2 * i + 1]]);
737    }
738    Ok(())
739}
740
741/// Reinterpret F32 little-endian bytes into the output slice.
742fn copy_f32(data: &[u8], output: &mut [f32]) -> Result<()> {
743    if data.len() % 4 != 0 {
744        return Err(MlxError::GgufParseError(
745            "F32 data length not multiple of 4".into(),
746        ));
747    }
748    let count = data.len() / 4;
749    if output.len() < count {
750        return Err(MlxError::GgufParseError(
751            "F32 output buffer too small".into(),
752        ));
753    }
754    for i in 0..count {
755        output[i] = f32::from_le_bytes([
756            data[4 * i],
757            data[4 * i + 1],
758            data[4 * i + 2],
759            data[4 * i + 3],
760        ]);
761    }
762    Ok(())
763}
764
765/// Dequantize Q5_1 blocks to f32.
766///
767/// Block layout (24 bytes, 32 elements):
768///   f16 d   — block scale            (offset 0,  2 bytes)
769///   f16 m   — block min term         (offset 2,  2 bytes)
770///   u32 qh  — high-bit pack          (offset 4,  4 bytes)
771///   u8  qs[16] — packed 4-bit lo nibbles (offset 8, 16 bytes)
772///
773/// Per-element: `out[j]      = d * x0 + m`, `out[j + 16] = d * x1 + m`,
774/// where `x0 = (qs[j] & 0x0F) | ((qh >> j) << 4) & 0x10`,
775///       `x1 = (qs[j] >> 4)  | ((qh >> (j + 12)) & 0x10)`.
776///
777/// Reference: `/opt/llama.cpp/ggml/src/ggml-quants.c:464` `dequantize_row_q5_1`.
778/// ADR-022 Phase 1.
779fn dequantize_q5_1(data: &[u8], output: &mut [f32]) -> Result<()> {
780    const BLOCK_BYTES: usize = 24;
781    const BLOCK_ELEMS: usize = 32;
782
783    if data.len() % BLOCK_BYTES != 0 {
784        return Err(MlxError::GgufParseError(format!(
785            "Q5_1 data length {} not divisible by block size {BLOCK_BYTES}",
786            data.len()
787        )));
788    }
789
790    let num_blocks = data.len() / BLOCK_BYTES;
791    if output.len() < num_blocks * BLOCK_ELEMS {
792        return Err(MlxError::GgufParseError(
793            "Q5_1 output buffer too small".into(),
794        ));
795    }
796
797    for i in 0..num_blocks {
798        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
799
800        let d = f16_from_le_bytes([block[0], block[1]]);
801        let m = f16_from_le_bytes([block[2], block[3]]);
802        let qh = u32::from_le_bytes([block[4], block[5], block[6], block[7]]);
803        let qs = &block[8..24]; // 16 bytes
804
805        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
806
807        for j in 0..(BLOCK_ELEMS / 2) {
808            // High-bit packed: bit j of qh contributes to position j;
809            // bit (j + 16) contributes to position j + 16. Mirrors
810            // `dequantize_row_q5_1` byte-for-byte.
811            let xh_0 = (((qh >> j) << 4) & 0x10) as u8;
812            let xh_1 = ((qh >> (j + 12)) & 0x10) as u8;
813            let x0 = ((qs[j] & 0x0F) | xh_0) as i32;
814            let x1 = ((qs[j] >> 4) | xh_1) as i32;
815            out[j] = (x0 as f32) * d + m;
816            out[j + BLOCK_ELEMS / 2] = (x1 as f32) * d + m;
817        }
818    }
819    Ok(())
820}
821
822/// Dequantize IQ4_NL blocks to f32.
823///
824/// Block layout (18 bytes, 32 elements):
825///   f16 d      — block scale                (offset 0,  2 bytes)
826///   u8  qs[16] — 16 × pair of 4-bit indices (offset 2, 16 bytes)
827///
828/// Per-element: `out[j]      = d * KVALUES_IQ4_NL[qs[j] & 0x0F]`,
829///              `out[j + 16] = d * KVALUES_IQ4_NL[qs[j] >> 4]`.
830///
831/// Reference: `/opt/llama.cpp/ggml/src/ggml-quants.c:2649` `dequantize_row_iq4_nl`.
832/// Codebook table verified against `ggml-common.h:1109-1112`. ADR-022 Phase 1.
833fn dequantize_iq4_nl(data: &[u8], output: &mut [f32]) -> Result<()> {
834    const BLOCK_BYTES: usize = 18;
835    const BLOCK_ELEMS: usize = 32;
836
837    if data.len() % BLOCK_BYTES != 0 {
838        return Err(MlxError::GgufParseError(format!(
839            "IQ4_NL data length {} not divisible by block size {BLOCK_BYTES}",
840            data.len()
841        )));
842    }
843
844    let num_blocks = data.len() / BLOCK_BYTES;
845    if output.len() < num_blocks * BLOCK_ELEMS {
846        return Err(MlxError::GgufParseError(
847            "IQ4_NL output buffer too small".into(),
848        ));
849    }
850
851    for i in 0..num_blocks {
852        let block = &data[i * BLOCK_BYTES..(i + 1) * BLOCK_BYTES];
853
854        let d = f16_from_le_bytes([block[0], block[1]]);
855        let qs = &block[2..18];
856
857        let out = &mut output[i * BLOCK_ELEMS..(i + 1) * BLOCK_ELEMS];
858
859        for j in 0..(BLOCK_ELEMS / 2) {
860            let lo = (qs[j] & 0x0F) as usize;
861            let hi = (qs[j] >> 4) as usize;
862            out[j] = d * KVALUES_IQ4_NL[lo] as f32;
863            out[j + BLOCK_ELEMS / 2] = d * KVALUES_IQ4_NL[hi] as f32;
864        }
865    }
866    Ok(())
867}
868
869/// Test-only export of `dequantize_q5_1` for ADR-022 parity tests in
870/// `/opt/mlx-native/tests/adr_022_phase1_dequant_parity.rs`. Hidden
871/// behind a doc(hidden) marker so it's not part of the public API but
872/// is accessible from integration tests via crate::gguf.
873#[doc(hidden)]
874pub fn test_only_dequantize_q5_1(data: &[u8], output: &mut [f32]) -> Result<()> {
875    dequantize_q5_1(data, output)
876}
877
878/// Test-only export of `dequantize_iq4_nl` for ADR-022 parity tests.
879#[doc(hidden)]
880pub fn test_only_dequantize_iq4_nl(data: &[u8], output: &mut [f32]) -> Result<()> {
881    dequantize_iq4_nl(data, output)
882}
883
884/// Test-only accessor for `KVALUES_IQ4_NL` so parity tests can pin the
885/// codebook bytes against the llama.cpp source of truth.
886#[doc(hidden)]
887pub fn test_only_kvalues_iq4_nl() -> [i8; 16] {
888    KVALUES_IQ4_NL
889}
890
891/// Test-only export of `dequantize_to_f32` for ADR-022 Phase-2 Q5_K
892/// dense parity tests. Routes through the same dispatch as the
893/// production load path. Hidden from rustdoc.
894#[doc(hidden)]
895pub fn test_only_dequantize(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
896    dequantize_to_f32(data, ggml_type, output)
897}
898
899/// Dequantize raw GGML block data to f32.
900fn dequantize_to_f32(data: &[u8], ggml_type: GgmlType, output: &mut [f32]) -> Result<()> {
901    match ggml_type {
902        GgmlType::F32 => copy_f32(data, output),
903        GgmlType::F16 => dequantize_f16(data, output),
904        GgmlType::Q4_0 => dequantize_q4_0(data, output),
905        GgmlType::Q8_0 => dequantize_q8_0(data, output),
906        GgmlType::Q4_K => dequantize_q4_k(data, output),
907        GgmlType::Q6_K => dequantize_q6_k(data, output),
908        GgmlType::Q5_K => dequantize_q5_k(data, output),
909        GgmlType::I16 => dequantize_i16(data, output),
910        GgmlType::Q5_1 => dequantize_q5_1(data, output),
911        GgmlType::IQ4_NL => dequantize_iq4_nl(data, output),
912    }
913}
914
915// ---------------------------------------------------------------------------
916// GgufFile implementation
917// ---------------------------------------------------------------------------
918
919impl GgufFile {
920    /// Open and parse a GGUF v3 file.
921    ///
922    /// This reads the full header (magic, version, tensor count, metadata KV
923    /// pairs, tensor info entries) but does **not** read any tensor data.
924    /// Tensor data is loaded lazily via [`load_tensor`](Self::load_tensor) or
925    /// [`load_tensor_f32`](Self::load_tensor_f32).
926    ///
927    /// # Errors
928    ///
929    /// Returns `MlxError::IoError` if the file cannot be opened.
930    /// Returns `MlxError::GgufParseError` if the file is not valid GGUF v3.
931    pub fn open(path: &Path) -> Result<Self> {
932        let file = std::fs::File::open(path).map_err(|e| {
933            MlxError::IoError(format!("cannot open GGUF file '{}': {e}", path.display()))
934        })?;
935        let mut reader = BufReader::new(file);
936
937        // --- Header ---
938        let magic = read_u32(&mut reader)?;
939        if magic != GGUF_MAGIC {
940            return Err(MlxError::GgufParseError(format!(
941                "bad magic: expected 0x{GGUF_MAGIC:08X}, got 0x{magic:08X}"
942            )));
943        }
944
945        let version = read_u32(&mut reader)?;
946        if version != GGUF_VERSION {
947            return Err(MlxError::GgufParseError(format!(
948                "unsupported GGUF version {version} (only v3 is supported)"
949            )));
950        }
951
952        let tensor_count = read_u64(&mut reader)? as usize;
953        let metadata_kv_count = read_u64(&mut reader)? as usize;
954
955        // Sanity limits to prevent OOM on corrupted files.
956        if tensor_count > 100_000 {
957            return Err(MlxError::GgufParseError(format!(
958                "tensor_count {tensor_count} exceeds 100k safety limit"
959            )));
960        }
961        if metadata_kv_count > 1_000_000 {
962            return Err(MlxError::GgufParseError(format!(
963                "metadata_kv_count {metadata_kv_count} exceeds 1M safety limit"
964            )));
965        }
966
967        // --- Metadata KV pairs ---
968        let mut metadata = HashMap::with_capacity(metadata_kv_count);
969        for _ in 0..metadata_kv_count {
970            let key = read_gguf_string(&mut reader)?;
971            let value_type = read_u32(&mut reader)?;
972            let value = read_metadata_value(&mut reader, value_type)?;
973            metadata.insert(key, value);
974        }
975
976        // --- Determine alignment ---
977        let alignment = metadata
978            .get(GGUF_ALIGNMENT_KEY)
979            .and_then(|v| v.as_u32())
980            .map(|v| v as u64)
981            .unwrap_or(GGUF_DEFAULT_ALIGNMENT);
982
983        if alignment == 0 || (alignment & (alignment - 1)) != 0 {
984            return Err(MlxError::GgufParseError(format!(
985                "alignment {alignment} is not a power of two"
986            )));
987        }
988
989        // --- Tensor info entries ---
990        let mut tensors = HashMap::with_capacity(tensor_count);
991        for _ in 0..tensor_count {
992            let name = read_gguf_string(&mut reader)?;
993            let n_dims = read_u32(&mut reader)? as usize;
994
995            if n_dims > 8 {
996                return Err(MlxError::GgufParseError(format!(
997                    "tensor '{name}' has {n_dims} dimensions (max 8)"
998                )));
999            }
1000
1001            let mut shape = Vec::with_capacity(n_dims);
1002            for _ in 0..n_dims {
1003                shape.push(read_u64(&mut reader)? as usize);
1004            }
1005            // GGUF stores dimensions innermost-first (column-major order).
1006            // Reverse to match the [rows, cols] convention used by candle
1007            // and by the rest of hf2q's weight loading code.
1008            shape.reverse();
1009
1010            let ggml_type_id = read_u32(&mut reader)?;
1011            let ggml_type = ggml_type_from_u32(ggml_type_id).map_err(|e| {
1012                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
1013            })?;
1014
1015            let offset = read_u64(&mut reader)?;
1016            let byte_len = compute_byte_len(&shape, ggml_type).map_err(|e| {
1017                MlxError::GgufParseError(format!("tensor '{name}': {e}"))
1018            })?;
1019
1020            tensors.insert(
1021                name.clone(),
1022                TensorInfo {
1023                    name,
1024                    shape,
1025                    ggml_type,
1026                    offset,
1027                    byte_len,
1028                },
1029            );
1030        }
1031
1032        // --- Compute tensor_data_offset ---
1033        // The current file position is just past all tensor info entries.
1034        // Tensor data starts at the next alignment boundary.
1035        let pos = reader
1036            .stream_position()
1037            .map_err(|e| MlxError::GgufParseError(format!("stream_position: {e}")))?;
1038        let tensor_data_offset = align_offset(pos, alignment);
1039
1040        Ok(GgufFile {
1041            metadata,
1042            tensors,
1043            tensor_data_offset,
1044            reader: Mutex::new(reader),
1045        })
1046    }
1047
1048    // -----------------------------------------------------------------------
1049    // Metadata accessors
1050    // -----------------------------------------------------------------------
1051
1052    /// Look up a metadata value by key.
1053    pub fn metadata(&self, key: &str) -> Option<&MetadataValue> {
1054        self.metadata.get(key)
1055    }
1056
1057    /// Look up a metadata string value by key.
1058    pub fn metadata_string(&self, key: &str) -> Option<&str> {
1059        self.metadata.get(key).and_then(|v| v.as_str())
1060    }
1061
1062    /// Look up a metadata u32 value by key.
1063    pub fn metadata_u32(&self, key: &str) -> Option<u32> {
1064        self.metadata.get(key).and_then(|v| v.as_u32())
1065    }
1066
1067    /// Look up a metadata f32 value by key.
1068    pub fn metadata_f32(&self, key: &str) -> Option<f32> {
1069        self.metadata.get(key).and_then(|v| v.as_f32())
1070    }
1071
1072    // -----------------------------------------------------------------------
1073    // Tensor info accessors
1074    // -----------------------------------------------------------------------
1075
1076    /// Return the names of all tensors in the file.
1077    pub fn tensor_names(&self) -> Vec<&str> {
1078        self.tensors.keys().map(|s| s.as_str()).collect()
1079    }
1080
1081    /// Look up info for a specific tensor by name.
1082    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
1083        self.tensors.get(name)
1084    }
1085
1086    /// Number of tensors in the file.
1087    pub fn tensor_count(&self) -> usize {
1088        self.tensors.len()
1089    }
1090
1091    /// Number of metadata key-value pairs.
1092    pub fn metadata_count(&self) -> usize {
1093        self.metadata.len()
1094    }
1095
1096    // -----------------------------------------------------------------------
1097    // Tensor loading
1098    // -----------------------------------------------------------------------
1099
1100    /// Read raw tensor bytes from the file.
1101    ///
1102    /// This is a private helper that seeks to the tensor's location and reads
1103    /// `byte_len` bytes.
1104    fn read_tensor_bytes(&self, info: &TensorInfo) -> Result<Vec<u8>> {
1105        let abs_offset = self.tensor_data_offset + info.offset;
1106        let mut reader = self
1107            .reader
1108            .lock()
1109            .map_err(|_| MlxError::GgufParseError("reader mutex poisoned".into()))?;
1110
1111        reader
1112            .seek(SeekFrom::Start(abs_offset))
1113            .map_err(|e| MlxError::IoError(format!("seek to tensor '{}': {e}", info.name)))?;
1114
1115        let mut buf = vec![0u8; info.byte_len];
1116        reader.read_exact(&mut buf).map_err(|e| {
1117            MlxError::IoError(format!(
1118                "read tensor '{}' ({} bytes at offset {}): {e}",
1119                info.name, info.byte_len, abs_offset
1120            ))
1121        })?;
1122
1123        Ok(buf)
1124    }
1125
1126    /// Load a tensor as a raw buffer on the Metal device.
1127    ///
1128    /// For quantized types (Q4_0, Q8_0, Q4_K, Q6_K) the buffer contains raw
1129    /// GGML blocks with dtype `U8` — these are consumed directly by
1130    /// `quantized_matmul_ggml` kernels.
1131    ///
1132    /// For F32 and F16 tensors the buffer has the corresponding typed dtype.
1133    ///
1134    /// # Errors
1135    ///
1136    /// Returns an error if the tensor name is not found, or if reading fails.
1137    pub fn load_tensor(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1138        let info = self.tensors.get(name).ok_or_else(|| {
1139            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1140        })?;
1141
1142        let data = self.read_tensor_bytes(info)?;
1143
1144        match info.ggml_type {
1145            GgmlType::F32 => {
1146                let mut buf =
1147                    device.alloc_buffer(info.byte_len, DType::F32, info.shape.clone())?;
1148                {
1149                    let slice: &mut [u8] = buf.as_mut_slice()?;
1150                    slice.copy_from_slice(&data);
1151                }
1152                Ok(buf)
1153            }
1154            GgmlType::F16 => {
1155                let mut buf =
1156                    device.alloc_buffer(info.byte_len, DType::F16, info.shape.clone())?;
1157                {
1158                    let slice: &mut [u8] = buf.as_mut_slice()?;
1159                    slice.copy_from_slice(&data);
1160                }
1161                Ok(buf)
1162            }
1163            GgmlType::Q4_0
1164            | GgmlType::Q8_0
1165            | GgmlType::Q4_K
1166            | GgmlType::Q5_K
1167            | GgmlType::Q6_K
1168            | GgmlType::I16
1169            | GgmlType::Q5_1
1170            | GgmlType::IQ4_NL => {
1171                // Store raw GGML blocks as a U8 buffer. Where a Metal
1172                // quantized-matmul kernel exists for the type, it consumes
1173                // these blocks directly without an explicit dequant pass on
1174                // the GPU; otherwise the U8 view is opaque on-device storage
1175                // pending either a kernel port or a host-side dequant.
1176                //
1177                // Coverage status (ADR-022 in-flight; see ADR for the live
1178                // matrix). Per-type Metal kernel coverage is owned by
1179                // `quantized_matmul_ggml.rs` `kernel_name` / `mm_kernel_name`
1180                // / `mm_tensor_kernel_name` and the matmul-id counterparts;
1181                // host-side dequant for parity / no-kernel-yet paths is
1182                // wired into `dequantize_to_f32` directly above.
1183                let mut buf =
1184                    device.alloc_buffer(info.byte_len, DType::U8, info.shape.clone())?;
1185                {
1186                    let slice: &mut [u8] = buf.as_mut_slice()?;
1187                    slice.copy_from_slice(&data);
1188                }
1189                Ok(buf)
1190            }
1191        }
1192    }
1193
1194    /// Load a tensor, dequantizing to F32 on the CPU, then upload to the
1195    /// Metal device.
1196    ///
1197    /// This is used for norm weights, embedding tables, and other tensors
1198    /// where the inference kernels operate on F32 directly.
1199    ///
1200    /// # Errors
1201    ///
1202    /// Returns an error if the tensor name is not found, reading fails, or
1203    /// dequantization encounters malformed data.
1204    pub fn load_tensor_f32(&self, name: &str, device: &MlxDevice) -> Result<MlxBuffer> {
1205        let info = self.tensors.get(name).ok_or_else(|| {
1206            MlxError::GgufParseError(format!("tensor '{name}' not found in GGUF file"))
1207        })?;
1208
1209        let data = self.read_tensor_bytes(info)?;
1210        let total_elements: usize = info.shape.iter().product();
1211
1212        if total_elements == 0 {
1213            return Err(MlxError::GgufParseError(format!(
1214                "tensor '{name}' has zero elements"
1215            )));
1216        }
1217
1218        let f32_byte_len = total_elements * 4;
1219        let mut buf =
1220            device.alloc_buffer(f32_byte_len, DType::F32, info.shape.clone())?;
1221
1222        {
1223            let out_slice: &mut [f32] = buf.as_mut_slice()?;
1224            dequantize_to_f32(&data, info.ggml_type, out_slice)?;
1225        }
1226
1227        Ok(buf)
1228    }
1229
1230    /// Load a tensor and register its underlying Metal buffer with `pool`'s
1231    /// residency set, returning the [`MlxBuffer`] to the caller.
1232    ///
1233    /// This is functionally equivalent to:
1234    ///
1235    /// ```ignore
1236    /// let buf = gguf.load_tensor(name, device)?;
1237    /// pool.register_existing(device, &buf)?;
1238    /// ```
1239    ///
1240    /// but exists as a single call so callers don't need to reach for the
1241    /// underlying [`MlxBufferPool::register_existing`] API directly.  See
1242    /// that method's docs for the residency-set ownership contract.
1243    ///
1244    /// # Why a separate method instead of a `pool` parameter on `load_tensor`
1245    ///
1246    /// `load_tensor` has stable callers across the codebase that pass only
1247    /// `&MlxDevice`; making the pool registration optional via a new method
1248    /// keeps the existing signature wire-compatible.
1249    ///
1250    /// # Note on bucket-rounding
1251    ///
1252    /// The buffer is allocated at exactly `info.byte_len` via
1253    /// [`MlxDevice::alloc_buffer`](crate::MlxDevice::alloc_buffer) (no
1254    /// bucket-rounding) and added to the pool's residency set only —
1255    /// it is not placed in the recycling free list.  This is the path
1256    /// hf2q's static weight loader uses to gain MTLResidencySet hints
1257    /// without paying the 48% bucket-rounding tax that would have
1258    /// inflated 17 GB of weights to 25 GB.
1259    ///
1260    /// # Errors
1261    ///
1262    /// Same as [`load_tensor`](Self::load_tensor), plus any
1263    /// [`MlxError::InvalidArgument`] from
1264    /// [`MlxBufferPool::register_existing`].
1265    pub fn load_tensor_into_pool(
1266        &self,
1267        name: &str,
1268        device: &MlxDevice,
1269        pool: &mut MlxBufferPool,
1270    ) -> Result<MlxBuffer> {
1271        let buf = self.load_tensor(name, device)?;
1272        pool.register_existing(device, &buf)?;
1273        Ok(buf)
1274    }
1275}
1276
1277// ---------------------------------------------------------------------------
1278// Utility
1279// ---------------------------------------------------------------------------
1280
/// Round `offset` up to the next multiple of `alignment`.
///
/// `alignment` must be a power of two (validated by `GgufFile::open`
/// before this is called); the bitmask trick below relies on that.
fn align_offset(offset: u64, alignment: u64) -> u64 {
    // Add alignment-1, then clear the low bits: classic power-of-two
    // round-up without a division.
    (offset + (alignment - 1)) & !(alignment - 1)
}