Skip to main content

trueno/inference/
gguf.rs

1//! GGUF file reader — loads tensor data for inference.
2//!
//! Reads GGUF v2 and v3 files (llama.cpp compatible). Parses the header, metadata,
//! and tensor info, then serves tensor data from the fully loaded file bytes.
5//!
6//! # Format
7//!
8//! ```text
9//! [magic: u32] [version: u32] [tensor_count: u64] [metadata_kv_count: u64]
10//! [metadata KV pairs...]
11//! [tensor info entries...]
12//! [alignment padding]
13//! [tensor data (contiguous)]
14//! ```
15
16use std::collections::HashMap;
17use std::io::{self, Read, Seek};
18use std::path::Path;
19
20use crate::error::TruenoError;
21
/// Expected value of the first header field when read as a little-endian `u32`;
/// the bytes on disk are the ASCII characters `G`, `G`, `U`, `F`.
const GGUF_MAGIC: u32 = 0x4655_4747; // "GGUF" in little-endian
23
/// GGML tensor type IDs (subset used for LLM inference).
///
/// Each discriminant is the raw numeric type ID that `from_u32` maps back to
/// the variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u32)]
pub enum GgmlType {
    F32 = 0,
    F16 = 1,
    Q4_0 = 2,
    Q4_1 = 3,
    Q5_0 = 6,
    Q5_1 = 7,
    Q8_0 = 8,
    Q8_1 = 9,
    Q2K = 10,
    Q3K = 11,
    Q4K = 12,
    Q5K = 13,
    Q6K = 14,
    Q8K = 15,
    Bf16 = 30,
}

impl GgmlType {
    /// Maps a raw type ID to its variant, or `None` if the ID is not part of
    /// the supported subset.
    fn from_u32(v: u32) -> Option<Self> {
        match v {
            0 => Some(Self::F32),
            1 => Some(Self::F16),
            2 => Some(Self::Q4_0),
            3 => Some(Self::Q4_1),
            6 => Some(Self::Q5_0),
            7 => Some(Self::Q5_1),
            8 => Some(Self::Q8_0),
            9 => Some(Self::Q8_1),
            10 => Some(Self::Q2K),
            11 => Some(Self::Q3K),
            12 => Some(Self::Q4K),
            13 => Some(Self::Q5K),
            14 => Some(Self::Q6K),
            15 => Some(Self::Q8K),
            30 => Some(Self::Bf16),
            _ => None,
        }
    }

    /// Bytes per block for this quantization type.
    ///
    /// For the unquantized types a "block" is a single element.
    pub fn block_bytes(&self) -> usize {
        match self {
            Self::F32 => 4,
            Self::F16 | Self::Bf16 => 2,
            Self::Q4_0 => 18, // 32 weights: 2 (scale) + 16 (4-bit)
            Self::Q4_1 => 20, // 32 weights: 2+2 (scale+min) + 16
            Self::Q5_0 => 22, // 32 weights
            Self::Q5_1 => 24,
            Self::Q8_0 => 34, // 32 weights: 2 (scale) + 32 (8-bit)
            Self::Q8_1 => 36,
            Self::Q2K => 84, // 256 weights
            Self::Q3K => 110,
            Self::Q4K => 144, // 256 weights
            Self::Q5K => 176,
            Self::Q6K => 210,
            Self::Q8K => 292,
        }
    }

    /// Weights per block.
    pub fn block_size(&self) -> usize {
        match self {
            Self::F32 | Self::F16 | Self::Bf16 => 1,
            Self::Q4_0 | Self::Q4_1 | Self::Q5_0 | Self::Q5_1 | Self::Q8_0 | Self::Q8_1 => 32,
            Self::Q2K | Self::Q3K | Self::Q4K | Self::Q5K | Self::Q6K | Self::Q8K => 256,
        }
    }

    /// Total bytes for `n_elements` weights.
    ///
    /// A partially filled trailing block is counted as a whole block. The
    /// result saturates at `usize::MAX` rather than overflowing, so an absurd
    /// element count yields a size that fails any later bounds check instead
    /// of panicking or wrapping to a small value.
    pub fn tensor_bytes(&self, n_elements: usize) -> usize {
        let bs = self.block_size();
        // Ceiling division written without `n_elements + bs - 1`, which
        // overflows for element counts near `usize::MAX`.
        let n_blocks = n_elements / bs + usize::from(n_elements % bs != 0);
        n_blocks.saturating_mul(self.block_bytes())
    }
}
103
104/// Info about a single tensor in the GGUF file.
105#[derive(Debug, Clone)]
106pub struct TensorInfo {
107    pub name: String,
108    pub dtype: GgmlType,
109    pub dims: Vec<u64>,
110    /// Offset from start of data section (NOT from file start).
111    pub offset: u64,
112}
113
114impl TensorInfo {
115    pub fn n_elements(&self) -> u64 {
116        self.dims.iter().product::<u64>().max(1)
117    }
118
119    pub fn byte_size(&self) -> usize {
120        self.dtype.tensor_bytes(self.n_elements() as usize)
121    }
122}
123
/// Parsed GGUF file ready for tensor extraction.
pub struct GgufFile {
    /// Tensor count as declared in the file header.
    pub tensor_count: u64,
    /// Metadata key/value pairs; if a key repeats, the last value read wins.
    pub metadata: HashMap<String, MetadataValue>,
    /// Tensor info entries, in the order they appear in the file.
    pub tensors: Vec<TensorInfo>,
    /// Offset in bytes from file start where tensor data begins.
    pub data_offset: u64,
    /// Raw bytes of the entire file, held in memory; tensor slices borrow from
    /// this buffer.
    data: Vec<u8>,
}
134
/// A single decoded GGUF metadata value.
#[derive(Debug, Clone)]
pub enum MetadataValue {
    U8(u8),
    I8(i8),
    U16(u16),
    I16(i16),
    U32(u32),
    I32(i32),
    U64(u64),
    I64(i64),
    F32(f32),
    F64(f64),
    Bool(bool),
    String(String),
    Array(Vec<MetadataValue>),
}

impl MetadataValue {
    /// Returns the value as a `u32` if it is a `U32`, or a `U64`/`I32` whose
    /// value is representable as `u32`.
    ///
    /// Out-of-range values (a `U64` above `u32::MAX`, a negative `I32`) yield
    /// `None` instead of being silently truncated or reinterpreted.
    pub fn as_u32(&self) -> Option<u32> {
        match self {
            Self::U32(v) => Some(*v),
            // The guards make the narrowing casts below lossless.
            Self::U64(v) if *v <= u64::from(u32::MAX) => Some(*v as u32),
            Self::I32(v) if *v >= 0 => Some(*v as u32),
            _ => None,
        }
    }

    /// Returns the value as an `f32` if it is an `F32` or `F64`.
    ///
    /// An `F64` is narrowed to `f32` precision.
    pub fn as_f32(&self) -> Option<f32> {
        match self {
            Self::F32(v) => Some(*v),
            Self::F64(v) => Some(*v as f32),
            _ => None,
        }
    }

    /// Returns the contained string slice if the value is a `String`.
    pub fn as_str(&self) -> Option<&str> {
        match self {
            Self::String(s) => Some(s),
            _ => None,
        }
    }
}
177
178impl GgufFile {
179    /// Load and parse a GGUF file.
180    pub fn load(path: &Path) -> Result<Self, TruenoError> {
181        let data = std::fs::read(path).map_err(|e| {
182            TruenoError::InvalidInput(format!("Failed to read GGUF file {}: {e}", path.display()))
183        })?;
184
185        Self::parse(data)
186    }
187
188    /// Parse GGUF from raw bytes.
189    pub fn parse(data: Vec<u8>) -> Result<Self, TruenoError> {
190        let mut cursor = io::Cursor::new(&data);
191
192        // Header
193        let magic = read_u32(&mut cursor)?;
194        if magic != GGUF_MAGIC {
195            return Err(TruenoError::InvalidInput(format!(
196                "Not a GGUF file: magic=0x{magic:08x}, expected 0x{GGUF_MAGIC:08x}"
197            )));
198        }
199        let version = read_u32(&mut cursor)?;
200        if !(2..=3).contains(&version) {
201            return Err(TruenoError::InvalidInput(format!(
202                "Unsupported GGUF version {version} (need 2 or 3)"
203            )));
204        }
205        let tensor_count = read_u64(&mut cursor)?;
206        let metadata_kv_count = read_u64(&mut cursor)?;
207
208        // Metadata
209        let mut metadata = HashMap::new();
210        for _ in 0..metadata_kv_count {
211            let key = read_gguf_string(&mut cursor)?;
212            let value = read_metadata_value(&mut cursor)?;
213            metadata.insert(key, value);
214        }
215
216        // Tensor info
217        let mut tensors = Vec::with_capacity(tensor_count as usize);
218        for _ in 0..tensor_count {
219            let name = read_gguf_string(&mut cursor)?;
220            let n_dims = read_u32(&mut cursor)? as usize;
221            let mut dims = Vec::with_capacity(n_dims);
222            for _ in 0..n_dims {
223                dims.push(read_u64(&mut cursor)?);
224            }
225            let dtype_u32 = read_u32(&mut cursor)?;
226            let dtype = GgmlType::from_u32(dtype_u32).ok_or_else(|| {
227                TruenoError::InvalidInput(format!(
228                    "Unknown GGML type {dtype_u32} for tensor '{name}'"
229                ))
230            })?;
231            let offset = read_u64(&mut cursor)?;
232            tensors.push(TensorInfo { name, dtype, dims, offset });
233        }
234
235        // Data section starts at next alignment boundary (default 32 bytes)
236        let alignment =
237            metadata.get("general.alignment").and_then(|v| v.as_u32()).unwrap_or(32) as u64;
238        let pos = cursor.position();
239        let data_offset = (pos + alignment - 1) / alignment * alignment;
240
241        Ok(Self { tensor_count, metadata, tensors, data_offset, data })
242    }
243
244    /// Get raw bytes for a tensor by name.
245    pub fn tensor_data(&self, name: &str) -> Option<&[u8]> {
246        let info = self.tensors.iter().find(|t| t.name == name)?;
247        let start = self.data_offset as usize + info.offset as usize;
248        let end = start + info.byte_size();
249        if end <= self.data.len() {
250            Some(&self.data[start..end])
251        } else {
252            None
253        }
254    }
255
256    /// Get tensor info by name.
257    pub fn tensor_info(&self, name: &str) -> Option<&TensorInfo> {
258        self.tensors.iter().find(|t| t.name == name)
259    }
260
261    /// Get a metadata string value.
262    pub fn meta_str(&self, key: &str) -> Option<&str> {
263        self.metadata.get(key)?.as_str()
264    }
265
266    /// Get a metadata u32 value.
267    pub fn meta_u32(&self, key: &str) -> Option<u32> {
268        self.metadata.get(key)?.as_u32()
269    }
270
271    /// Get a metadata f32 value.
272    pub fn meta_f32(&self, key: &str) -> Option<f32> {
273        self.metadata.get(key)?.as_f32()
274    }
275}
276
277// ── Binary readers ──
278
279fn read_u8<R: Read>(r: &mut R) -> Result<u8, TruenoError> {
280    let mut buf = [0u8; 1];
281    r.read_exact(&mut buf)
282        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
283    Ok(buf[0])
284}
285
286fn read_u16<R: Read>(r: &mut R) -> Result<u16, TruenoError> {
287    let mut buf = [0u8; 2];
288    r.read_exact(&mut buf)
289        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
290    Ok(u16::from_le_bytes(buf))
291}
292
293fn read_u32<R: Read>(r: &mut R) -> Result<u32, TruenoError> {
294    let mut buf = [0u8; 4];
295    r.read_exact(&mut buf)
296        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
297    Ok(u32::from_le_bytes(buf))
298}
299
300fn read_i32<R: Read>(r: &mut R) -> Result<i32, TruenoError> {
301    let mut buf = [0u8; 4];
302    r.read_exact(&mut buf)
303        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
304    Ok(i32::from_le_bytes(buf))
305}
306
307fn read_u64<R: Read>(r: &mut R) -> Result<u64, TruenoError> {
308    let mut buf = [0u8; 8];
309    r.read_exact(&mut buf)
310        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
311    Ok(u64::from_le_bytes(buf))
312}
313
314fn read_i64<R: Read>(r: &mut R) -> Result<i64, TruenoError> {
315    let mut buf = [0u8; 8];
316    r.read_exact(&mut buf)
317        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
318    Ok(i64::from_le_bytes(buf))
319}
320
321fn read_f32_val<R: Read>(r: &mut R) -> Result<f32, TruenoError> {
322    let mut buf = [0u8; 4];
323    r.read_exact(&mut buf)
324        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
325    Ok(f32::from_le_bytes(buf))
326}
327
328fn read_f64_val<R: Read>(r: &mut R) -> Result<f64, TruenoError> {
329    let mut buf = [0u8; 8];
330    r.read_exact(&mut buf)
331        .map_err(|e| TruenoError::InvalidInput(format!("GGUF read error: {e}")))?;
332    Ok(f64::from_le_bytes(buf))
333}
334
335fn read_gguf_string<R: Read>(r: &mut R) -> Result<String, TruenoError> {
336    let len = read_u64(r)? as usize;
337    if len > 1_000_000 {
338        return Err(TruenoError::InvalidInput(format!("GGUF string too long: {len}")));
339    }
340    let mut buf = vec![0u8; len];
341    r.read_exact(&mut buf)
342        .map_err(|e| TruenoError::InvalidInput(format!("GGUF string read error: {e}")))?;
343    String::from_utf8(buf)
344        .map_err(|e| TruenoError::InvalidInput(format!("GGUF string not UTF-8: {e}")))
345}
346
347fn read_metadata_value<R: Read + Seek>(r: &mut R) -> Result<MetadataValue, TruenoError> {
348    let value_type = read_u32(r)?;
349    match value_type {
350        0 => Ok(MetadataValue::U8(read_u8(r)?)),
351        1 => Ok(MetadataValue::I8(read_u8(r)? as i8)),
352        2 => Ok(MetadataValue::U16(read_u16(r)?)),
353        3 => Ok(MetadataValue::I16(read_u16(r)? as i16)),
354        4 => Ok(MetadataValue::U32(read_u32(r)?)),
355        5 => Ok(MetadataValue::I32(read_i32(r)?)),
356        6 => Ok(MetadataValue::F32(read_f32_val(r)?)),
357        7 => Ok(MetadataValue::Bool(read_u8(r)? != 0)),
358        8 => Ok(MetadataValue::String(read_gguf_string(r)?)),
359        9 => {
360            // Array
361            let elem_type = read_u32(r)?;
362            let count = read_u64(r)? as usize;
363            if count > 10_000_000 {
364                return Err(TruenoError::InvalidInput(format!("GGUF array too large: {count}")));
365            }
366            let mut items = Vec::with_capacity(count.min(1024));
367            for _ in 0..count {
368                // Read elements of the declared type
369                let item = match elem_type {
370                    0 => MetadataValue::U8(read_u8(r)?),
371                    1 => MetadataValue::I8(read_u8(r)? as i8),
372                    4 => MetadataValue::U32(read_u32(r)?),
373                    5 => MetadataValue::I32(read_i32(r)?),
374                    6 => MetadataValue::F32(read_f32_val(r)?),
375                    8 => MetadataValue::String(read_gguf_string(r)?),
376                    10 => MetadataValue::U64(read_u64(r)?),
377                    11 => MetadataValue::I64(read_i64(r)?),
378                    12 => MetadataValue::F64(read_f64_val(r)?),
379                    _ => {
380                        return Err(TruenoError::InvalidInput(format!(
381                            "Unsupported GGUF array element type {elem_type}"
382                        )))
383                    }
384                };
385                items.push(item);
386            }
387            Ok(MetadataValue::Array(items))
388        }
389        10 => Ok(MetadataValue::U64(read_u64(r)?)),
390        11 => Ok(MetadataValue::I64(read_i64(r)?)),
391        12 => Ok(MetadataValue::F64(read_f64_val(r)?)),
392        _ => Err(TruenoError::InvalidInput(format!("Unknown GGUF metadata type {value_type}"))),
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    /// Q4K block geometry and the derived tensor size.
    #[test]
    fn test_ggml_type_q4k_properties() {
        let ty = GgmlType::Q4K;
        assert_eq!((ty.block_size(), ty.block_bytes()), (256, 144));
        // 4096 weights / 256 per block = 16 blocks; 16 × 144 bytes = 2304.
        assert_eq!(ty.tensor_bytes(4096), 2304);
    }

    /// F32 is treated as one 4-byte element per block.
    #[test]
    fn test_ggml_type_f32_properties() {
        let ty = GgmlType::F32;
        assert_eq!((ty.block_size(), ty.block_bytes()), (1, 4));
        assert_eq!(ty.tensor_bytes(1024), 4096);
    }

    /// All-zero input has a zero magic field and must be rejected.
    #[test]
    fn test_gguf_magic_check() {
        assert!(GgufFile::parse(vec![0u8; 32]).is_err());
    }

    /// A header-only GGUF v3 file with no tensors and no metadata parses.
    #[test]
    fn test_minimal_gguf() {
        let mut bytes = Vec::with_capacity(32);
        bytes.extend_from_slice(&GGUF_MAGIC.to_le_bytes()); // magic
        bytes.extend_from_slice(&3u32.to_le_bytes()); // version
        bytes.extend_from_slice(&[0u8; 8]); // tensor_count = 0
        bytes.extend_from_slice(&[0u8; 8]); // metadata_kv_count = 0
        bytes.extend_from_slice(&[0u8; 8]); // zero padding up to 32 bytes

        let parsed = GgufFile::parse(bytes).expect("valid minimal GGUF");
        assert_eq!(parsed.tensor_count, 0);
        assert!(parsed.tensors.is_empty());
    }
}
439}