Skip to main content

entrenar/hf_pipeline/export/gguf_verify/
parsing.rs

1//! GGUF binary parsing helpers
2
3use super::types::GgufTensorInfo;
4use crate::hf_pipeline::error::FetchError;
5
6/// Read a little-endian u32 from a byte slice at the given offset.
7/// Caller must ensure `pos + 4 <= data.len()`.
8pub(super) fn read_u32_le(data: &[u8], pos: usize) -> Result<u32, FetchError> {
9    let bytes: [u8; 4] = data
10        .get(pos..pos + 4)
11        .and_then(|s| s.try_into().ok())
12        .ok_or_else(|| truncation_error(pos))?;
13    Ok(u32::from_le_bytes(bytes))
14}
15
16/// Read a little-endian u64 from a byte slice at the given offset.
17/// Caller must ensure `pos + 8 <= data.len()`.
18pub(super) fn read_u64_le(data: &[u8], pos: usize) -> Result<u64, FetchError> {
19    let bytes: [u8; 8] = data
20        .get(pos..pos + 8)
21        .and_then(|s| s.try_into().ok())
22        .ok_or_else(|| truncation_error(pos))?;
23    Ok(u64::from_le_bytes(bytes))
24}
25
26/// Parsed GGUF header (magic, version, counts)
27pub(super) struct GgufHeader {
28    pub(super) version: u32,
29    pub(super) tensor_count: u64,
30    pub(super) metadata_count: u64,
31}
32
33/// Parse and validate the 24-byte GGUF header
34pub(super) fn parse_header(data: &[u8]) -> Result<GgufHeader, FetchError> {
35    if data.len() < 24 {
36        return Err(FetchError::ConfigParseError {
37            message: "GGUF file too small: less than 24 bytes".to_string(),
38        });
39    }
40
41    let magic = data.get(0..4).unwrap_or_default();
42    if magic != b"GGUF" {
43        return Err(FetchError::ConfigParseError {
44            message: format!(
45                "Invalid GGUF magic: expected 'GGUF', got '{}'",
46                String::from_utf8_lossy(magic)
47            ),
48        });
49    }
50
51    let version = read_u32_le(data, 4)?;
52    if version != 3 {
53        return Err(FetchError::ConfigParseError {
54            message: format!("Unsupported GGUF version: {version} (expected 3)"),
55        });
56    }
57
58    let tensor_count = read_u64_le(data, 8)?;
59    let metadata_count = read_u64_le(data, 16)?;
60
61    Ok(GgufHeader { version, tensor_count, metadata_count })
62}
63
64/// Skip over all metadata key-value pairs, returning the position after them
65pub(super) fn skip_all_metadata(
66    data: &[u8],
67    start: usize,
68    count: u64,
69) -> Result<usize, FetchError> {
70    let mut pos = start;
71    for _ in 0..count {
72        pos = skip_gguf_string(data, pos)?;
73        let value_type = read_u32_le(data, pos)?;
74        pos += 4;
75        pos = skip_gguf_value(data, pos, value_type)?;
76    }
77    Ok(pos)
78}
79
80/// Parse a single tensor info entry, returning (info, new_position)
81pub(super) fn parse_tensor_info(
82    data: &[u8],
83    pos: usize,
84) -> Result<(GgufTensorInfo, usize), FetchError> {
85    let (name, mut pos) = read_gguf_string(data, pos)?;
86
87    // n_dimensions
88    let n_dims = read_u32_le(data, pos)?;
89    pos += 4;
90
91    // Dimensions
92    let mut shape = Vec::with_capacity(n_dims as usize);
93    for _ in 0..n_dims {
94        shape.push(read_u64_le(data, pos)?);
95        pos += 8;
96    }
97
98    // dtype
99    let dtype = read_u32_le(data, pos)?;
100    pos += 4;
101
102    // offset
103    let offset = read_u64_le(data, pos)?;
104    pos += 8;
105
106    Ok((GgufTensorInfo { name, shape, dtype, offset }, pos))
107}
108
109/// Parse all tensor info entries
110pub(super) fn parse_all_tensor_info(
111    data: &[u8],
112    start: usize,
113    count: u64,
114) -> Result<Vec<GgufTensorInfo>, FetchError> {
115    let mut tensors = Vec::with_capacity(count as usize);
116    let mut pos = start;
117    for _ in 0..count {
118        let (info, new_pos) = parse_tensor_info(data, pos)?;
119        tensors.push(info);
120        pos = new_pos;
121    }
122    Ok(tensors)
123}
124
125/// Read a GGUF string at the given position, return (string, new_position)
126pub(super) fn read_gguf_string(data: &[u8], pos: usize) -> Result<(String, usize), FetchError> {
127    let len = read_u64_le(data, pos)? as usize;
128    let start = pos + 8;
129    let end = start + len;
130    if end > data.len() {
131        return Err(truncation_error(start));
132    }
133    let s = String::from_utf8_lossy(&data[start..end]).to_string();
134    Ok((s, end))
135}
136
137/// Skip over a GGUF string, returning the new position
138pub(super) fn skip_gguf_string(data: &[u8], pos: usize) -> Result<usize, FetchError> {
139    let (_, new_pos) = read_gguf_string(data, pos)?;
140    Ok(new_pos)
141}
142
143/// Skip a GGUF value based on its type, returning the new position
144pub(super) fn skip_gguf_value(
145    data: &[u8],
146    pos: usize,
147    value_type: u32,
148) -> Result<usize, FetchError> {
149    match value_type {
150        0 | 1 | 7 => Ok(pos + 1),         // UINT8, INT8, BOOL
151        2 | 3 => Ok(pos + 2),             // UINT16, INT16
152        4..=6 => Ok(pos + 4),             // UINT32, INT32, FLOAT32
153        8 => skip_gguf_string(data, pos), // STRING
154        10..=12 => Ok(pos + 8),           // UINT64, INT64, FLOAT64
155        9 => skip_gguf_array(data, pos),  // ARRAY
156        _ => Err(FetchError::ConfigParseError {
157            message: format!("Unknown GGUF metadata type: {value_type}"),
158        }),
159    }
160}
161
162/// Skip a GGUF array value: type(4) + count(8) + values
163pub(super) fn skip_gguf_array(data: &[u8], pos: usize) -> Result<usize, FetchError> {
164    let elem_type = read_u32_le(data, pos)?;
165    let count = read_u64_le(data, pos + 4)?;
166    let mut p = pos + 12;
167    for _ in 0..count {
168        p = skip_gguf_value(data, p, elem_type)?;
169    }
170    Ok(p)
171}
172
173/// Create a truncation error
174pub(super) fn truncation_error(pos: usize) -> FetchError {
175    FetchError::ConfigParseError { message: format!("GGUF file truncated at byte offset {pos}") }
176}
177
#[cfg(test)]
mod tests {
    use super::*;

    /// Encode `s` as a GGUF string: little-endian `u64` length, then the bytes.
    fn gguf_string(s: &str) -> Vec<u8> {
        let mut out = (s.len() as u64).to_le_bytes().to_vec();
        out.extend_from_slice(s.as_bytes());
        out
    }

    /// Build a 24-byte GGUF header from its fields.
    fn header_bytes(
        magic: &[u8; 4],
        version: u32,
        tensor_count: u64,
        metadata_count: u64,
    ) -> Vec<u8> {
        let mut out = magic.to_vec();
        out.extend_from_slice(&version.to_le_bytes());
        out.extend_from_slice(&tensor_count.to_le_bytes());
        out.extend_from_slice(&metadata_count.to_le_bytes());
        out
    }

    /// Encode a single tensor-info entry in the layout `parse_tensor_info` reads.
    fn tensor_info_bytes(name: &str, shape: &[u64], dtype: u32, offset: u64) -> Vec<u8> {
        let mut out = gguf_string(name);
        out.extend_from_slice(&(shape.len() as u32).to_le_bytes());
        for dim in shape {
            out.extend_from_slice(&dim.to_le_bytes());
        }
        out.extend_from_slice(&dtype.to_le_bytes());
        out.extend_from_slice(&offset.to_le_bytes());
        out
    }

    /// Extract the message of a `ConfigParseError`, panicking on anything else.
    fn config_error_message<T>(result: Result<T, FetchError>) -> String {
        match result {
            Err(FetchError::ConfigParseError { message }) => message,
            Err(other) => panic!("expected ConfigParseError, got {other:?}"),
            Ok(_) => panic!("expected an error, got Ok"),
        }
    }

    // Little-endian integer readers

    #[test]
    fn test_read_le_integers() {
        let data: [u8; 9] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09];
        assert_eq!(read_u32_le(&data, 0).expect("4 bytes available"), 0x0403_0201);
        assert_eq!(read_u64_le(&data, 1).expect("8 bytes available"), 0x0908_0706_0504_0302);
    }

    #[test]
    fn test_read_le_integers_truncated() {
        let data = [0u8; 9];
        assert!(read_u32_le(&data, 6).is_err());
        assert!(read_u64_le(&data, 2).is_err());
        assert!(read_u32_le(&data, 100).is_err());
    }

    // Header parsing

    #[test]
    fn test_parse_header_valid() {
        let header = parse_header(&header_bytes(b"GGUF", 3, 5, 7)).expect("valid header");
        assert_eq!(header.version, 3);
        assert_eq!(header.tensor_count, 5);
        assert_eq!(header.metadata_count, 7);
    }

    #[test]
    fn test_parse_header_rejects_short_input() {
        let bytes = header_bytes(b"GGUF", 3, 0, 0);
        let message = config_error_message(parse_header(&bytes[..23]));
        assert!(message.contains("too small"), "{message}");
    }

    #[test]
    fn test_parse_header_rejects_bad_magic() {
        let message = config_error_message(parse_header(&header_bytes(b"GGML", 3, 0, 0)));
        assert!(message.contains("Invalid GGUF magic"), "{message}");
    }

    #[test]
    fn test_parse_header_rejects_unsupported_version() {
        let message = config_error_message(parse_header(&header_bytes(b"GGUF", 2, 0, 0)));
        assert!(message.contains("Unsupported GGUF version: 2"), "{message}");
    }

    // Strings and metadata

    #[test]
    fn test_read_gguf_string_roundtrip() {
        let mut data = gguf_string("hello");
        data.push(0xFF); // trailing byte that must not be consumed
        let (s, end) = read_gguf_string(&data, 0).expect("string fits");
        assert_eq!(s, "hello");
        assert_eq!(end, 13);
        assert_eq!(skip_gguf_string(&data, 0).expect("string fits"), 13);
    }

    #[test]
    fn test_read_gguf_string_truncated_payload() {
        let data = gguf_string("hello");
        assert!(read_gguf_string(&data[..10], 0).is_err());
        assert!(skip_gguf_string(&data[..10], 0).is_err());
    }

    #[test]
    fn test_skip_all_metadata() {
        let mut data = gguf_string("general.alignment");
        data.extend_from_slice(&4u32.to_le_bytes()); // UINT32 tag
        data.extend_from_slice(&32u32.to_le_bytes());
        data.extend(gguf_string("general.name"));
        data.extend_from_slice(&8u32.to_le_bytes()); // STRING tag
        data.extend(gguf_string("tiny"));

        assert_eq!(skip_all_metadata(&data, 0, 2).expect("two pairs"), data.len());
        assert_eq!(skip_all_metadata(&data, 0, 0).expect("no pairs"), 0);
        assert!(skip_all_metadata(&data, 0, 3).is_err());
    }

    // Tensor info

    #[test]
    fn test_parse_tensor_info_roundtrip() {
        let data = tensor_info_bytes("blk.0.attn_q.weight", &[4096, 11008], 1, 128);
        let (info, end) = parse_tensor_info(&data, 0).expect("entry fits");
        assert_eq!(info.name, "blk.0.attn_q.weight");
        assert_eq!(info.shape, vec![4096, 11008]);
        assert_eq!(info.dtype, 1);
        assert_eq!(info.offset, 128);
        assert_eq!(end, data.len());
    }

    #[test]
    fn test_parse_tensor_info_truncated() {
        let data = tensor_info_bytes("w", &[2, 3], 0, 0);
        assert!(parse_tensor_info(&data[..data.len() - 1], 0).is_err());
    }

    #[test]
    fn test_parse_all_tensor_info() {
        let mut data = tensor_info_bytes("a", &[2], 0, 0);
        data.extend(tensor_info_bytes("b", &[3, 4], 1, 64));

        let tensors = parse_all_tensor_info(&data, 0, 2).expect("two entries");
        assert_eq!(tensors.len(), 2);
        assert_eq!(tensors[1].name, "b");
        assert_eq!(tensors[1].shape, vec![3, 4]);

        assert!(parse_all_tensor_info(&data, 0, 0).expect("no entries").is_empty());
        assert!(parse_all_tensor_info(&data, 0, 3).is_err());
    }

    // skip_gguf_value match arm coverage for all GGUF metadata type tags

    #[test]
    fn test_skip_gguf_value_variant_0_1_7() {
        let data = [0u8; 16];
        // UINT8 (0)
        assert_eq!(skip_gguf_value(&data, 0, 0).expect("operation should succeed"), 1);
        // INT8 (1)
        assert_eq!(skip_gguf_value(&data, 0, 1).expect("operation should succeed"), 1);
        // BOOL (7)
        assert_eq!(skip_gguf_value(&data, 0, 7).expect("operation should succeed"), 1);
    }

    #[test]
    fn test_skip_gguf_value_variant_2_3() {
        let data = [0u8; 16];
        // UINT16 (2)
        assert_eq!(skip_gguf_value(&data, 0, 2).expect("operation should succeed"), 2);
        // INT16 (3)
        assert_eq!(skip_gguf_value(&data, 0, 3).expect("operation should succeed"), 2);
    }

    #[test]
    fn test_skip_gguf_value_variant_4_to_6() {
        let data = [0u8; 16];
        // UINT32 (4)
        assert_eq!(skip_gguf_value(&data, 0, 4).expect("operation should succeed"), 4);
        // INT32 (5)
        assert_eq!(skip_gguf_value(&data, 0, 5).expect("operation should succeed"), 4);
        // FLOAT32 (6)
        assert_eq!(skip_gguf_value(&data, 0, 6).expect("operation should succeed"), 4);
    }

    #[test]
    fn test_skip_gguf_value_variant_8() {
        // STRING: 8 bytes length (u64 LE) + string bytes
        let mut data = vec![0u8; 16];
        // length = 3 (u64 LE)
        data[0] = 3;
        // 3 bytes of string data
        data[8] = b'a';
        data[9] = b'b';
        data[10] = b'c';
        assert_eq!(skip_gguf_value(&data, 0, 8).expect("operation should succeed"), 11);
    }

    #[test]
    fn test_skip_gguf_value_variant_10_to_12() {
        let data = [0u8; 16];
        // UINT64 (10)
        assert_eq!(skip_gguf_value(&data, 0, 10).expect("operation should succeed"), 8);
        // INT64 (11)
        assert_eq!(skip_gguf_value(&data, 0, 11).expect("operation should succeed"), 8);
        // FLOAT64 (12)
        assert_eq!(skip_gguf_value(&data, 0, 12).expect("operation should succeed"), 8);
    }

    #[test]
    fn test_skip_gguf_value_variant_9() {
        // ARRAY: type(4) + count(8) + values
        // Array of 2 UINT8 values
        let mut data = vec![0u8; 32];
        // elem_type = 0 (UINT8)
        data[0] = 0;
        // count = 2 (u64 LE)
        data[4] = 2;
        assert_eq!(skip_gguf_value(&data, 0, 9).expect("operation should succeed"), 14);
    }

    #[test]
    fn test_skip_gguf_value_array_of_strings() {
        let mut data = Vec::new();
        data.extend_from_slice(&8u32.to_le_bytes()); // elem_type = STRING
        data.extend_from_slice(&2u64.to_le_bytes()); // count = 2
        data.extend(gguf_string("a"));
        data.extend(gguf_string("bc"));
        assert_eq!(skip_gguf_value(&data, 0, 9).expect("array fits"), data.len());
        assert_eq!(skip_gguf_array(&data, 0).expect("array fits"), data.len());
    }

    #[test]
    fn test_skip_gguf_value_nested_array() {
        // ARRAY of 1 ARRAY of 3 UINT16 values
        let mut data = Vec::new();
        data.extend_from_slice(&9u32.to_le_bytes()); // outer elem_type = ARRAY
        data.extend_from_slice(&1u64.to_le_bytes()); // outer count
        data.extend_from_slice(&2u32.to_le_bytes()); // inner elem_type = UINT16
        data.extend_from_slice(&3u64.to_le_bytes()); // inner count
        data.extend_from_slice(&[0u8; 6]);
        assert_eq!(skip_gguf_value(&data, 0, 9).expect("nested array fits"), data.len());
    }

    #[test]
    fn test_skip_gguf_value_unknown_type() {
        let data = [0u8; 16];
        assert!(skip_gguf_value(&data, 0, 99).is_err());
    }
}