Skip to main content

oxibonsai_model/
gguf_loader.rs

1//! Production-quality GGUF model loader with validation, streaming, and memory budgeting.
2//!
3//! This module provides high-level utilities for loading GGUF model files with:
4//! - Configurable memory budgets and validation strictness
5//! - Lazy tensor metadata loading (no weight data read upfront)
6//! - Streaming chunk iterators for progressive loading
7//! - Memory footprint estimation before committing to a full load
8
9use std::io::Read;
10use std::path::Path;
11use std::time::Instant;
12
13// ─────────────────────────────────────────────────────────────────────────────
14// LoadError
15// ─────────────────────────────────────────────────────────────────────────────
16
17/// Errors that can occur during GGUF model loading.
18#[derive(Debug, thiserror::Error)]
19pub enum LoadError {
20    /// An underlying I/O error (e.g., file not found, permission denied).
21    #[error("I/O error: {0}")]
22    Io(#[from] std::io::Error),
23
24    /// The GGUF file could not be parsed (malformed binary).
25    #[error("GGUF parse error: {0}")]
26    Parse(String),
27
28    /// Loading this file would exceed the configured memory budget.
29    #[error("memory budget exceeded: need {need} bytes, budget {budget} bytes")]
30    MemoryBudgetExceeded { need: u64, budget: u64 },
31
32    /// The GGUF version in the file header is not supported.
33    #[error("unsupported GGUF version: {0}")]
34    UnsupportedVersion(u32),
35
36    /// A required structural invariant was violated.
37    #[error("validation failed: {0}")]
38    ValidationFailed(String),
39}
40
41// ─────────────────────────────────────────────────────────────────────────────
42// LoadConfig
43// ─────────────────────────────────────────────────────────────────────────────
44
/// Tunable options that control how a GGUF model is loaded.
#[derive(Debug, Clone)]
pub struct LoadConfig {
    /// Upper bound, in bytes, on the memory the loader may use.
    /// `None` means no limit.
    pub max_memory_bytes: Option<usize>,
    /// Whether file-level checksums should be validated (currently advisory).
    pub validate_checksums: bool,
    /// When `true`, tensors whose quantisation type is not recognised are
    /// silently skipped instead of producing an error.
    pub allow_unknown_quant_types: bool,
    /// Number of bytes per chunk when streaming through [`TensorChunkIter`].
    pub streaming_chunk_size: usize,
    /// When `true`, GGUF files declaring an unsupported version are rejected.
    pub strict_version: bool,
}

impl Default for LoadConfig {
    /// Permissive defaults: no memory cap, no checksum validation, unknown
    /// quantisation types tolerated, 4 MiB streaming chunks, lenient version
    /// handling.
    fn default() -> Self {
        const DEFAULT_CHUNK_BYTES: usize = 4 << 20; // 4 MiB

        LoadConfig {
            strict_version: false,
            streaming_chunk_size: DEFAULT_CHUNK_BYTES,
            allow_unknown_quant_types: true,
            validate_checksums: false,
            max_memory_bytes: None,
        }
    }
}
72
73// ─────────────────────────────────────────────────────────────────────────────
74// LoadStats
75// ─────────────────────────────────────────────────────────────────────────────
76
/// Counters and timings collected over one model-loading operation.
///
/// `Default` yields all-zero counters and an empty warning list.
#[derive(Debug, Clone, Default)]
pub struct LoadStats {
    /// How many tensors were loaded successfully.
    pub tensors_loaded: usize,
    /// Sum of tensor weight bytes that were loaded.
    pub bytes_loaded: u64,
    /// How many tensors were skipped for having an unrecognised
    /// quantisation type.
    pub skipped_tensors: usize,
    /// Wall-clock duration of the whole load, in milliseconds.
    pub load_time_ms: u64,
    /// Approximate high-water mark of memory use during the load, in bytes.
    pub peak_memory_bytes: usize,
    /// Non-fatal findings from validation; empty when the file was clean.
    pub validation_warnings: Vec<String>,
}
93
94// ─────────────────────────────────────────────────────────────────────────────
95// TensorEntry
96// ─────────────────────────────────────────────────────────────────────────────
97
/// Known quantisation type IDs and their human-readable names.
const KNOWN_QUANT_TYPES: &[(u32, &str)] = &[
    (0, "F32"),
    (1, "F16"),
    (2, "Q4_0"),
    (3, "Q4_1"),
    (6, "Q5_0"),
    (7, "Q5_1"),
    (8, "Q8_0"),
    (9, "Q8_1"),
    (10, "Q2_K"),
    (11, "Q3_K"),
    (12, "Q4_K"),
    (13, "Q5_K"),
    (14, "Q6_K"),
    (15, "Q8_K"),
    (30, "BF16"),
    (35, "TQ2_0"),
    (41, "Q1_0_g128"),
    (42, "TQ2_0_g128"),
];

/// Looks up `quant_type_id` in [`KNOWN_QUANT_TYPES`], returning its name when present.
fn known_quant_name(quant_type_id: u32) -> Option<&'static str> {
    KNOWN_QUANT_TYPES
        .iter()
        .find(|&&(id, _)| id == quant_type_id)
        .map(|&(_, name)| name)
}

/// A loaded tensor entry — contains metadata only; no weight bytes are held here.
/// Use [`load_tensor_metadata`] to obtain a collection of these. Note that
/// `offset` is relative to the start of the tensor data section, not to the
/// start of the file.
#[derive(Debug, Clone)]
pub struct TensorEntry {
    /// Tensor name as stored in the GGUF file (e.g. `"blk.0.attn_q.weight"`).
    pub name: String,
    /// Shape dimensions (e.g. `[4096, 4096]`).
    pub shape: Vec<u64>,
    /// Raw GGUF quantisation type ID.
    pub quant_type_id: u32,
    /// Byte offset of this tensor's data from the start of the tensor data section.
    pub offset: u64,
    /// Number of bytes occupied by this tensor in the data section.
    pub size_bytes: u64,
}

impl TensorEntry {
    /// Total number of elements across all dimensions.
    ///
    /// An empty shape yields `1`. Dimensions come straight from the file, so
    /// the product saturates at `u64::MAX` instead of overflowing (which would
    /// panic in debug builds and silently wrap in release builds).
    pub fn element_count(&self) -> u64 {
        self.shape
            .iter()
            .fold(1u64, |count, &dim| count.saturating_mul(dim))
    }

    /// Human-readable name for the quantisation type, or `"UNKNOWN"`.
    pub fn quant_name(&self) -> &'static str {
        known_quant_name(self.quant_type_id).unwrap_or("UNKNOWN")
    }

    /// Returns `true` when the quantisation type ID is one OxiBonsai recognises.
    pub fn is_known_quant(&self) -> bool {
        known_quant_name(self.quant_type_id).is_some()
    }
}
159
160// ─────────────────────────────────────────────────────────────────────────────
161// GGUF raw-parsing helpers (pure Rust, no external deps beyond std)
162// ─────────────────────────────────────────────────────────────────────────────
163
/// File signature: the ASCII bytes `"GGUF"` read as a little-endian `u32`
/// (numerically `0x4655_4747`).
const GGUF_MAGIC: u32 = u32::from_le_bytes(*b"GGUF");

/// GGUF format versions this loader regards as supported.
const SUPPORTED_VERSIONS: &[u32] = &[2, 3];
169
170/// Read a little-endian u32 from a cursor.
171fn read_u32_le(buf: &[u8], pos: &mut usize) -> Result<u32, LoadError> {
172    if *pos + 4 > buf.len() {
173        return Err(LoadError::Parse(format!(
174            "unexpected EOF at offset {} reading u32",
175            pos
176        )));
177    }
178    let v = u32::from_le_bytes(
179        buf[*pos..*pos + 4]
180            .try_into()
181            .map_err(|_| LoadError::Parse("slice conversion failed for u32".to_string()))?,
182    );
183    *pos += 4;
184    Ok(v)
185}
186
187/// Read a little-endian u64 from a cursor.
188fn read_u64_le(buf: &[u8], pos: &mut usize) -> Result<u64, LoadError> {
189    if *pos + 8 > buf.len() {
190        return Err(LoadError::Parse(format!(
191            "unexpected EOF at offset {} reading u64",
192            pos
193        )));
194    }
195    let v = u64::from_le_bytes(
196        buf[*pos..*pos + 8]
197            .try_into()
198            .map_err(|_| LoadError::Parse("slice conversion failed for u64".to_string()))?,
199    );
200    *pos += 8;
201    Ok(v)
202}
203
204/// Read a GGUF string: [u64 length][bytes].
205fn read_gguf_string(buf: &[u8], pos: &mut usize) -> Result<String, LoadError> {
206    let len = read_u64_le(buf, pos)? as usize;
207    if *pos + len > buf.len() {
208        return Err(LoadError::Parse(format!(
209            "string of length {len} extends beyond buffer"
210        )));
211    }
212    let s = std::str::from_utf8(&buf[*pos..*pos + len])
213        .map_err(|e| LoadError::Parse(format!("invalid UTF-8 in string: {e}")))?
214        .to_string();
215    *pos += len;
216    Ok(s)
217}
218
219/// Skip over a GGUF metadata value (we don't need the values for metadata-only loading).
220/// Returns the number of bytes consumed.
221fn skip_metadata_value(buf: &[u8], pos: &mut usize, value_type: u32) -> Result<(), LoadError> {
222    match value_type {
223        0 | 1 => {
224            // uint8, int8
225            if *pos + 1 > buf.len() {
226                return Err(LoadError::Parse("EOF in u8/i8 value".to_string()));
227            }
228            *pos += 1;
229        }
230        2 | 3 => {
231            // uint16, int16
232            if *pos + 2 > buf.len() {
233                return Err(LoadError::Parse("EOF in u16/i16 value".to_string()));
234            }
235            *pos += 2;
236        }
237        4..=7 => {
238            // uint32, int32, float32, bool
239            if *pos + 4 > buf.len() {
240                return Err(LoadError::Parse(
241                    "EOF in u32/i32/f32/bool value".to_string(),
242                ));
243            }
244            *pos += 4;
245        }
246        8 => {
247            // string
248            read_gguf_string(buf, pos)?;
249        }
250        9 => {
251            // array: [value_type: u32][count: u64][elements...]
252            let elem_type = read_u32_le(buf, pos)?;
253            let count = read_u64_le(buf, pos)?;
254            for _ in 0..count {
255                skip_metadata_value(buf, pos, elem_type)?;
256            }
257        }
258        10..=12 => {
259            // uint64, int64, float64
260            if *pos + 8 > buf.len() {
261                return Err(LoadError::Parse("EOF in u64/i64/f64 value".to_string()));
262            }
263            *pos += 8;
264        }
265        other => {
266            return Err(LoadError::Parse(format!(
267                "unknown metadata value type id: {other}"
268            )));
269        }
270    }
271    Ok(())
272}
273
274/// Low-level result of parsing a GGUF file's header + metadata + tensor info.
275struct ParsedGgufMeta {
276    version: u32,
277    tensor_entries: Vec<TensorEntry>,
278}
279
280/// Parse GGUF header, skip metadata KV, parse tensor info entries.
281/// Does NOT read any tensor weight bytes.
282fn parse_gguf_meta(buf: &[u8]) -> Result<ParsedGgufMeta, LoadError> {
283    let mut pos = 0usize;
284
285    // --- Header ---
286    let magic = read_u32_le(buf, &mut pos)?;
287    if magic != GGUF_MAGIC {
288        return Err(LoadError::Parse(format!(
289            "invalid GGUF magic: 0x{:08X} (expected 0x{:08X})",
290            magic, GGUF_MAGIC
291        )));
292    }
293
294    let version = read_u32_le(buf, &mut pos)?;
295
296    let tensor_count = read_u64_le(buf, &mut pos)?;
297    let metadata_kv_count = read_u64_le(buf, &mut pos)?;
298
299    // --- Metadata KV pairs (skip values, we only need structure) ---
300    for _ in 0..metadata_kv_count {
301        // key
302        read_gguf_string(buf, &mut pos)?;
303        // value type
304        let value_type = read_u32_le(buf, &mut pos)?;
305        // value
306        skip_metadata_value(buf, &mut pos, value_type)?;
307    }
308
309    // --- Tensor info entries ---
310    let mut tensor_entries = Vec::with_capacity(tensor_count as usize);
311    for _ in 0..tensor_count {
312        let name = read_gguf_string(buf, &mut pos)?;
313        let n_dims = read_u32_le(buf, &mut pos)?;
314        let mut shape = Vec::with_capacity(n_dims as usize);
315        for _ in 0..n_dims {
316            shape.push(read_u64_le(buf, &mut pos)?);
317        }
318        let quant_type_id = read_u32_le(buf, &mut pos)?;
319        let offset = read_u64_le(buf, &mut pos)?;
320
321        // Compute size_bytes using quantisation block math.
322        let size_bytes = compute_tensor_size_bytes(&shape, quant_type_id);
323
324        tensor_entries.push(TensorEntry {
325            name,
326            shape,
327            quant_type_id,
328            offset,
329            size_bytes,
330        });
331    }
332
333    Ok(ParsedGgufMeta {
334        version,
335        tensor_entries,
336    })
337}
338
/// Compute the byte size of a tensor given its shape and quant type ID.
///
/// The element count (product of `shape`, `1` for an empty shape) is rounded
/// up to whole quantisation blocks and multiplied by the per-block byte size.
/// Unrecognised type IDs fall back to one byte per element.
///
/// `shape` comes from the file and is untrusted, so every multiplication
/// saturates at `u64::MAX`. A saturated result is deliberately enormous: any
/// later budget comparison will reject it rather than see a wrapped-around
/// small number.
fn compute_tensor_size_bytes(shape: &[u64], quant_type_id: u32) -> u64 {
    let element_count = shape
        .iter()
        .fold(1u64, |count, &dim| count.saturating_mul(dim));

    // (elements per block, bytes per block)
    let (block_size, block_bytes): (u64, u64) = match quant_type_id {
        0 => (1, 4),      // F32
        1 => (1, 2),      // F16
        2 => (32, 18),    // Q4_0
        3 => (32, 20),    // Q4_1
        6 => (32, 22),    // Q5_0
        7 => (32, 24),    // Q5_1
        8 => (32, 34),    // Q8_0
        // NOTE(review): upstream ggml's current block_q8_1 appears to be
        // 36 bytes (two f16 + 32 × i8); 40 matches the older f32 layout — confirm.
        9 => (32, 40),    // Q8_1
        10 => (256, 84),  // Q2_K
        11 => (256, 110), // Q3_K
        12 => (256, 144), // Q4_K
        13 => (256, 176), // Q5_K
        14 => (256, 210), // Q6_K
        15 => (256, 292), // Q8_K
        30 => (1, 2),     // BF16
        35 => (256, 66),  // TQ2_0 (llama.cpp ternary, 256-element groups)
        41 => (128, 18),  // Q1_0_g128
        42 => (128, 34),  // TQ2_0_g128 (PrismML ternary, 128-element groups)
        // Unknown type: assume 1 byte per element as a conservative fallback
        _ => (1, 1),
    };

    element_count
        .div_ceil(block_size)
        .saturating_mul(block_bytes)
}
367
368// ─────────────────────────────────────────────────────────────────────────────
369// Public API
370// ─────────────────────────────────────────────────────────────────────────────
371
372/// Validates a GGUF file at `path`, checking:
373/// - File exists and is readable
374/// - Magic bytes are correct
375/// - Version is in the supported set
376/// - Tensor count and metadata are self-consistent
377///
378/// Returns a (possibly empty) list of advisory warning strings.
379/// An empty list means the file passed all checks.
380pub fn validate_gguf_file(path: &Path) -> Result<Vec<String>, LoadError> {
381    let mut file = std::fs::File::open(path)?;
382    let mut buf = Vec::new();
383    file.read_to_end(&mut buf)?;
384
385    let mut warnings = Vec::new();
386    let start = Instant::now();
387
388    let meta = parse_gguf_meta(&buf)?;
389
390    if !SUPPORTED_VERSIONS.contains(&meta.version) {
391        warnings.push(format!(
392            "GGUF version {} is not in the officially supported set {:?}",
393            meta.version, SUPPORTED_VERSIONS
394        ));
395    }
396
397    if meta.tensor_entries.is_empty() {
398        warnings.push("file contains zero tensors".to_string());
399    }
400
401    for entry in &meta.tensor_entries {
402        if !entry.is_known_quant() {
403            warnings.push(format!(
404                "tensor '{}' has unknown quantisation type id {}",
405                entry.name, entry.quant_type_id
406            ));
407        }
408        if entry.shape.is_empty() {
409            warnings.push(format!(
410                "tensor '{}' has zero-dimensional shape",
411                entry.name
412            ));
413        }
414    }
415
416    let _elapsed = start.elapsed();
417    Ok(warnings)
418}
419
420/// Loads tensor metadata (names, shapes, types, offsets) from a GGUF file.
421///
422/// This is intentionally fast — no weight bytes are read.  Call this to build
423/// a directory of available tensors before deciding which to materialise.
424pub fn load_tensor_metadata(path: &Path) -> Result<Vec<TensorEntry>, LoadError> {
425    let _t0 = Instant::now();
426
427    let mut file = std::fs::File::open(path)?;
428    let mut buf = Vec::new();
429    file.read_to_end(&mut buf)?;
430
431    let meta = parse_gguf_meta(&buf)?;
432    Ok(meta.tensor_entries)
433}
434
435/// Computes the expected memory footprint (in bytes) for fully loading all
436/// tensor weight data from the given GGUF file.
437///
438/// This reads only the file header and tensor metadata — no weight bytes.
439pub fn estimate_memory_bytes(path: &Path) -> Result<u64, LoadError> {
440    let entries = load_tensor_metadata(path)?;
441    let total: u64 = entries.iter().map(|e| e.size_bytes).sum();
442    Ok(total)
443}
444
445/// Returns `true` when the GGUF file at `path` fits within `budget_bytes`.
446///
447/// Identical to calling [`estimate_memory_bytes`] and comparing.
448pub fn fits_in_budget(path: &Path, budget_bytes: u64) -> Result<bool, LoadError> {
449    let need = estimate_memory_bytes(path)?;
450    Ok(need <= budget_bytes)
451}
452
453// ─────────────────────────────────────────────────────────────────────────────
454// Streaming iterator
455// ─────────────────────────────────────────────────────────────────────────────
456
/// Walks an owned byte buffer in fixed-size pieces, yielding each piece as a
/// freshly allocated `Vec<u8>`. The last piece may be shorter than the rest.
/// Intended for progressive / streaming loading of a tensor's raw data.
///
/// ```
/// # use oxibonsai_model::gguf_loader::TensorChunkIter;
/// let data = vec![0u8; 100];
/// let mut iter = TensorChunkIter::new(data, 32);
/// assert_eq!(iter.total_chunks(), 4); // ceil(100/32)
/// ```
pub struct TensorChunkIter {
    /// The complete buffer being chunked.
    data: Vec<u8>,
    /// Maximum length of each yielded chunk; always non-zero.
    chunk_size: usize,
    /// Index of the first byte that has not been yielded yet.
    pos: usize,
}

impl TensorChunkIter {
    /// Builds an iterator over `data` that yields chunks of at most
    /// `chunk_size` bytes.
    ///
    /// # Panics
    /// Panics if `chunk_size` is zero.
    pub fn new(data: Vec<u8>, chunk_size: usize) -> Self {
        assert!(chunk_size > 0, "chunk_size must be > 0");
        TensorChunkIter {
            pos: 0,
            chunk_size,
            data,
        }
    }

    /// Number of chunks the full buffer divides into, counting a trailing
    /// partial chunk. Independent of how far iteration has progressed.
    pub fn total_chunks(&self) -> usize {
        match self.data.len() {
            0 => 0,
            len => len.div_ceil(self.chunk_size),
        }
    }

    /// Count of bytes that have not been yielded so far.
    pub fn bytes_remaining(&self) -> usize {
        let total = self.data.len();
        total.saturating_sub(self.pos)
    }
}

impl Iterator for TensorChunkIter {
    type Item = Vec<u8>;

    /// Yields a copy of the next chunk, or `None` once the buffer is exhausted.
    fn next(&mut self) -> Option<Self::Item> {
        let start = self.pos;
        let total = self.data.len();
        if start >= total {
            return None;
        }

        let stop = total.min(start + self.chunk_size);
        self.pos = stop;
        Some(self.data[start..stop].to_vec())
    }
}