Skip to main content

ai_hwaccel/
model_format.rs

1//! Model file format detection.
2//!
3//! Parses `.safetensors`, `.gguf`, and `.onnx` file headers to extract model
4//! metadata (parameter count, data type, tensor names) without loading the
5//! full model into memory. Only the first few kilobytes are read.
6//!
7//! # Examples
8//!
9//! ```rust,no_run
10//! use ai_hwaccel::model_format::{detect_format, ModelFormat};
11//!
12//! let metadata = detect_format(std::path::Path::new("model.safetensors")).unwrap();
13//! println!("Format: {}", metadata.format);
14//! println!("Parameters: {}", metadata.param_count.unwrap_or(0));
15//! ```
16
17use std::fmt;
18use std::path::Path;
19
20use serde::{Deserialize, Serialize};
21
/// Upper bound on the number of bytes read from the start of a file (16 KiB).
const MAX_HEADER_BYTES: usize = 16 << 10;

/// GGUF magic number: the ASCII bytes `GGUF` interpreted as a little-endian
/// `u32` (`0x4655_4747`).
const GGUF_MAGIC: u32 = u32::from_le_bytes(*b"GGUF");

/// Expected first byte of an ONNX file: the protobuf tag for field 1 with
/// wire type 0 (varint), i.e. `0x08`.
const ONNX_IR_VERSION_TAG: u8 = 1 << 3;
30
31/// Detected model file format.
32#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
33#[non_exhaustive]
34pub enum ModelFormat {
35    /// HuggingFace SafeTensors format.
36    SafeTensors,
37    /// GGML/GGUF format (llama.cpp).
38    GGUF,
39    /// ONNX (Open Neural Network Exchange).
40    ONNX,
41    /// PyTorch serialized format (pickle-based).
42    PyTorch,
43}
44
45impl fmt::Display for ModelFormat {
46    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47        match self {
48            Self::SafeTensors => write!(f, "SafeTensors"),
49            Self::GGUF => write!(f, "GGUF"),
50            Self::ONNX => write!(f, "ONNX"),
51            Self::PyTorch => write!(f, "PyTorch"),
52        }
53    }
54}
55
56/// Metadata extracted from a model file header.
57#[derive(Debug, Clone, Serialize, Deserialize)]
58pub struct ModelMetadata {
59    /// Detected format.
60    pub format: ModelFormat,
61    /// Estimated total parameter count (if extractable from header).
62    pub param_count: Option<u64>,
63    /// Weight data type (e.g. "F16", "BF16", "F32", "Q4_0").
64    pub dtype: Option<String>,
65    /// Number of tensors found in header.
66    pub tensor_count: Option<u32>,
67    /// GGUF version (if applicable).
68    pub format_version: Option<u32>,
69}
70
71/// Detect the model format from a file path.
72///
73/// Reads only the first few kilobytes to identify the format and extract
74/// metadata. Returns `None` if the format is unrecognised.
75#[must_use]
76pub fn detect_format(path: &Path) -> Option<ModelMetadata> {
77    use std::io::Read;
78    let mut file = std::fs::File::open(path).ok()?;
79    let mut buf = vec![0u8; MAX_HEADER_BYTES];
80    let n = file.read(&mut buf).ok()?;
81    buf.truncate(n);
82    detect_format_from_bytes(&buf)
83}
84
85/// Detect the model format from raw bytes (typically the first 16 KB).
86///
87/// This is the WASM-compatible entry point — no file I/O required.
88#[must_use]
89pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<ModelMetadata> {
90    // Try each format in order of specificity.
91    if let Some(meta) = parse_safetensors_header(bytes) {
92        return Some(meta);
93    }
94    if let Some(meta) = parse_gguf_header(bytes) {
95        return Some(meta);
96    }
97    if let Some(meta) = parse_onnx_header(bytes) {
98        return Some(meta);
99    }
100    if is_pytorch_format(bytes) {
101        return Some(ModelMetadata {
102            format: ModelFormat::PyTorch,
103            param_count: None,
104            dtype: None,
105            tensor_count: None,
106            format_version: None,
107        });
108    }
109    None
110}
111
112// ---------------------------------------------------------------------------
113// SafeTensors parser
114// ---------------------------------------------------------------------------
115
116/// Parse SafeTensors file header.
117///
118/// Format: 8-byte LE header_size, then JSON metadata of that size.
119/// JSON contains tensor names as keys with `{dtype, shape, data_offsets}`.
120fn parse_safetensors_header(bytes: &[u8]) -> Option<ModelMetadata> {
121    if bytes.len() < 8 {
122        return None;
123    }
124
125    let header_size = u64::from_le_bytes(bytes[..8].try_into().ok()?) as usize;
126
127    // Sanity: header should be reasonable (< 100 MB) and start with '{'.
128    if header_size == 0 || header_size > 100 * 1024 * 1024 {
129        return None;
130    }
131
132    // We may not have the full header in our buffer, but we can still
133    // identify the format and parse what we have.
134    let json_end = (8 + header_size).min(bytes.len());
135    let json_bytes = &bytes[8..json_end];
136
137    // Must start with '{' to be valid JSON object.
138    let first_non_ws = json_bytes.iter().find(|b| !b.is_ascii_whitespace())?;
139    if *first_non_ws != b'{' {
140        return None;
141    }
142
143    // Try to parse as complete JSON if we have enough bytes.
144    let json_str = std::str::from_utf8(json_bytes).ok()?;
145
146    // If we have the complete header, parse it fully.
147    if json_end - 8 >= header_size
148        && let Ok(header) = serde_json::from_str::<serde_json::Value>(json_str)
149    {
150        return Some(extract_safetensors_metadata(&header));
151    }
152
153    // Partial header — we can still identify format.
154    Some(ModelMetadata {
155        format: ModelFormat::SafeTensors,
156        param_count: None,
157        dtype: None,
158        tensor_count: None,
159        format_version: None,
160    })
161}
162
163/// Extract metadata from a parsed SafeTensors JSON header.
164fn extract_safetensors_metadata(header: &serde_json::Value) -> ModelMetadata {
165    let obj = match header.as_object() {
166        Some(o) => o,
167        None => {
168            return ModelMetadata {
169                format: ModelFormat::SafeTensors,
170                param_count: None,
171                dtype: None,
172                tensor_count: None,
173                format_version: None,
174            };
175        }
176    };
177
178    let mut total_params: u64 = 0;
179    let mut tensor_count: u32 = 0;
180    let mut dtype = None;
181
182    for (key, value) in obj {
183        // Skip metadata key "__metadata__".
184        if key == "__metadata__" {
185            continue;
186        }
187
188        tensor_count = tensor_count.saturating_add(1);
189
190        if let Some(tensor_obj) = value.as_object() {
191            // Extract dtype from first tensor.
192            if dtype.is_none()
193                && let Some(dt) = tensor_obj.get("dtype").and_then(|v| v.as_str())
194            {
195                dtype = Some(dt.to_string());
196            }
197
198            // Count parameters from shape (skip empty shapes — scalars).
199            if let Some(shape) = tensor_obj.get("shape").and_then(|v| v.as_array())
200                && !shape.is_empty()
201            {
202                let params: u64 = shape.iter().filter_map(|d| d.as_u64()).product();
203                total_params = total_params.saturating_add(params);
204            }
205        }
206    }
207
208    ModelMetadata {
209        format: ModelFormat::SafeTensors,
210        param_count: if total_params > 0 {
211            Some(total_params)
212        } else {
213            None
214        },
215        dtype,
216        tensor_count: Some(tensor_count),
217        format_version: None,
218    }
219}
220
221// ---------------------------------------------------------------------------
222// GGUF parser
223// ---------------------------------------------------------------------------
224
225/// Parse GGUF file header.
226///
227/// Format: 4-byte magic "GGUF", 4-byte version, 8-byte tensor count,
228/// 8-byte metadata KV count, then metadata pairs.
229fn parse_gguf_header(bytes: &[u8]) -> Option<ModelMetadata> {
230    if bytes.len() < 20 {
231        return None;
232    }
233
234    let magic = u32::from_le_bytes(bytes[..4].try_into().ok()?);
235    if magic != GGUF_MAGIC {
236        return None;
237    }
238
239    let version = u32::from_le_bytes(bytes[4..8].try_into().ok()?);
240    let tensor_count = u64::from_le_bytes(bytes[8..16].try_into().ok()?);
241    let _kv_count = u64::from_le_bytes(bytes[16..24].try_into().ok()?);
242
243    // Try to extract dtype from metadata KV pairs.
244    // GGUF metadata is complex to parse fully; extract what we can.
245    let dtype = extract_gguf_dtype(bytes, 24);
246
247    Some(ModelMetadata {
248        format: ModelFormat::GGUF,
249        param_count: None, // GGUF doesn't store param count directly.
250        dtype,
251        tensor_count: if tensor_count <= u32::MAX as u64 {
252            Some(tensor_count as u32)
253        } else {
254            None
255        },
256        format_version: Some(version),
257    })
258}
259
/// Best-effort lookup of `general.file_type` in GGUF metadata.
///
/// Rather than walking the variable-length key/value pairs, this scans
/// `bytes[offset..]` for the first occurrence of the key text and reads the
/// 4-byte value-type tag and 4-byte value that follow it (little-endian).
///
/// Returns `None` if `offset` is past the end, the key is not found, fewer
/// than 8 bytes follow it, or the value type is not UINT32. Known file types
/// map to their usual names; anything else becomes `GGUF_TYPE_<n>`.
fn extract_gguf_dtype(bytes: &[u8], offset: usize) -> Option<String> {
    const KEY: &[u8] = b"general.file_type";
    /// GGUF value-type tag for a `u32` value.
    const VALUE_TYPE_UINT32: u32 = 4;
    /// Known `general.file_type` ids and their human-readable names.
    const FILE_TYPE_NAMES: &[(u32, &str)] = &[
        (0, "F32"),
        (1, "F16"),
        (2, "Q4_0"),
        (3, "Q4_1"),
        (7, "Q8_0"),
        (8, "Q5_0"),
        (9, "Q5_1"),
        (10, "Q2_K"),
        (11, "Q3_K_S"),
        (12, "Q3_K_M"),
        (13, "Q3_K_L"),
        (14, "Q4_K_S"),
        (15, "Q4_K_M"),
        (16, "Q5_K_S"),
        (17, "Q5_K_M"),
        (18, "Q6_K"),
        (19, "IQ2_XXS"),
        (20, "IQ2_XS"),
    ];

    let tail = bytes.get(offset..)?;
    let key_pos = tail.windows(KEY.len()).position(|w| w == KEY)?;

    // Both the type tag and the value must be fully present after the key.
    let payload = tail.get(key_pos + KEY.len()..)?.get(..8)?;
    let (tag_bytes, value_bytes) = payload.split_at(4);

    let value_type = u32::from_le_bytes(tag_bytes.try_into().ok()?);
    if value_type != VALUE_TYPE_UINT32 {
        return None;
    }
    let file_type = u32::from_le_bytes(value_bytes.try_into().ok()?);

    let label = FILE_TYPE_NAMES
        .iter()
        .find(|entry| entry.0 == file_type)
        .map_or_else(
            || format!("GGUF_TYPE_{file_type}"),
            |entry| entry.1.to_string(),
        );
    Some(label)
}
311
312// ---------------------------------------------------------------------------
313// ONNX parser
314// ---------------------------------------------------------------------------
315
316/// Parse ONNX file header.
317///
318/// ONNX uses Protocol Buffers. The file starts with the ModelProto message.
319/// Field 1 (ir_version) is a varint, field 5 (model_version) is a varint.
320fn parse_onnx_header(bytes: &[u8]) -> Option<ModelMetadata> {
321    if bytes.len() < 4 {
322        return None;
323    }
324
325    // ONNX starts with protobuf field 1 (ir_version), tag = 0x08.
326    if bytes[0] != ONNX_IR_VERSION_TAG {
327        return None;
328    }
329
330    // Parse ir_version as varint.
331    let (ir_version, consumed) = parse_varint(&bytes[1..])?;
332
333    // Sanity: ir_version should be 1-10 (current range).
334    if ir_version == 0 || ir_version > 20 {
335        return None;
336    }
337
338    // Strengthen detection: require a valid second protobuf field after ir_version.
339    // ONNX ModelProto field 2 (opset_import) has tag 0x3A (field 7, wire type 2)
340    // or field 8 (metadata_props) 0x42, or producer_name (field 2) 0x12.
341    // We accept any valid protobuf field tag (wire type 0-2, field 1-15).
342    let next_offset = 1 + consumed;
343    if next_offset < bytes.len() {
344        let next_tag = bytes[next_offset];
345        let wire_type = next_tag & 0x07;
346        let field_num = next_tag >> 3;
347        // Valid protobuf: wire type 0 (varint), 1 (64-bit), 2 (length-delimited)
348        // and field number 1-15 (single-byte tags).
349        if wire_type > 2 || field_num == 0 {
350            return None;
351        }
352    } else {
353        // Only ir_version and nothing else — too short to be a real ONNX model.
354        return None;
355    }
356
357    Some(ModelMetadata {
358        format: ModelFormat::ONNX,
359        param_count: None,
360        dtype: None,
361        tensor_count: None,
362        format_version: Some(ir_version as u32),
363    })
364}
365
/// Decode a protobuf varint from the start of `bytes`.
///
/// Each byte contributes its low 7 bits, least-significant group first; a
/// clear high bit marks the final byte. Returns `(value, bytes_consumed)`, or
/// `None` if the input ends before a terminating byte or no terminator is
/// found within the first 10 bytes.
fn parse_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    // Ten 7-bit groups (shifts 0, 7, ..., 63) are the most a u64 can take.
    let shifts = (0u32..64).step_by(7);
    let mut value: u64 = 0;

    for (idx, (&byte, shift)) in bytes.iter().zip(shifts).enumerate() {
        value |= u64::from(byte & 0x7F) << shift;
        if byte < 0x80 {
            return Some((value, idx + 1));
        }
    }
    None
}
382
383// ---------------------------------------------------------------------------
384// PyTorch detector
385// ---------------------------------------------------------------------------
386
/// Report whether `bytes` begin with the ZIP local-file-header signature.
///
/// PyTorch files are ZIP archives containing pickle data, so the ZIP magic
/// `PK\x03\x04` is used as the marker. Only these four bytes are checked;
/// the archive contents are not inspected.
fn is_pytorch_format(bytes: &[u8]) -> bool {
    bytes.starts_with(b"PK\x03\x04")
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a SafeTensors byte stream: 8-byte LE length prefix + JSON header.
    fn safetensors_bytes(json: &str) -> Vec<u8> {
        let mut bytes = (json.len() as u64).to_le_bytes().to_vec();
        bytes.extend_from_slice(json.as_bytes());
        bytes
    }

    /// Build the 24-byte fixed GGUF header.
    fn gguf_fixed_header(version: u32, tensors: u64, kv_pairs: u64) -> Vec<u8> {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        bytes.extend_from_slice(&version.to_le_bytes());
        bytes.extend_from_slice(&tensors.to_le_bytes());
        bytes.extend_from_slice(&kv_pairs.to_le_bytes());
        bytes
    }

    // SafeTensors tests
    #[test]
    fn safetensors_valid_header() {
        let bytes = safetensors_bytes(
            r#"{"weight":{"dtype":"F16","shape":[768,768],"data_offsets":[0,1179648]}}"#,
        );

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(768 * 768));
        assert_eq!(meta.dtype.as_deref(), Some("F16"));
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn safetensors_multi_tensor() {
        let bytes = safetensors_bytes(
            r#"{"w1":{"dtype":"BF16","shape":[1024,512],"data_offsets":[0,1]},"w2":{"dtype":"BF16","shape":[512,256],"data_offsets":[1,2]},"__metadata__":{"format":"pt"}}"#,
        );

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(1024 * 512 + 512 * 256));
        assert_eq!(meta.dtype.as_deref(), Some("BF16"));
        assert_eq!(meta.tensor_count, Some(2)); // __metadata__ excluded
    }

    #[test]
    fn safetensors_partial_header_still_detected() {
        // Declared size is larger than what the buffer holds: the format is
        // identified, but no metadata can be extracted.
        let partial = r#"{"weight":{"dtype":"F16""#;
        let mut bytes = 1000u64.to_le_bytes().to_vec();
        bytes.extend_from_slice(partial.as_bytes());

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, None);
        assert_eq!(meta.dtype, None);
        assert_eq!(meta.tensor_count, None);
    }

    #[test]
    fn safetensors_too_small() {
        assert!(detect_format_from_bytes(&[0u8; 4]).is_none());
    }

    #[test]
    fn safetensors_bad_header_size() {
        // Header size says 1 GB — too large to be real.
        let bytes = (1_000_000_000u64).to_le_bytes();
        assert!(parse_safetensors_header(&bytes).is_none());
    }

    // GGUF tests
    #[test]
    fn gguf_valid_header() {
        let bytes = gguf_fixed_header(3, 42, 5);

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::GGUF);
        assert_eq!(meta.tensor_count, Some(42));
        assert_eq!(meta.format_version, Some(3));
    }

    #[test]
    fn gguf_file_type_extracted() {
        let mut bytes = gguf_fixed_header(3, 1, 1);
        bytes.extend_from_slice(&17u64.to_le_bytes()); // key length
        bytes.extend_from_slice(b"general.file_type");
        bytes.extend_from_slice(&4u32.to_le_bytes()); // value type UINT32
        bytes.extend_from_slice(&15u32.to_le_bytes()); // Q4_K_M

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::GGUF);
        assert_eq!(meta.dtype.as_deref(), Some("Q4_K_M"));
    }

    #[test]
    fn gguf_wrong_magic() {
        let bytes = [0u8; 24];
        assert!(parse_gguf_header(&bytes).is_none());
    }

    #[test]
    fn gguf_too_small() {
        assert!(parse_gguf_header(&[0u8; 10]).is_none());
    }

    // ONNX tests
    #[test]
    fn onnx_valid_header() {
        // ir_version = 9 encoded as varint: 0x08, 0x09
        let bytes = [0x08, 0x09, 0x12, 0x00];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn onnx_bad_ir_version() {
        // ir_version = 0 — invalid. Padded to 4 bytes so the version check,
        // not the minimum-length guard, is what rejects it.
        let bytes = [0x08, 0x00, 0x12, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_ir_version_out_of_range() {
        // ir_version = 21 — above the accepted range.
        let bytes = [0x08, 0x15, 0x12, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    // PyTorch tests
    #[test]
    fn pytorch_zip_magic() {
        let bytes = [0x50, 0x4B, 0x03, 0x04, 0x00, 0x00];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::PyTorch);
    }

    #[test]
    fn pytorch_not_zip() {
        let bytes = [0x00, 0x00, 0x00, 0x00];
        assert!(!is_pytorch_format(&bytes));
    }

    // Format detection priority
    #[test]
    fn unknown_format_returns_none() {
        let bytes = [0xFF, 0xFE, 0xFD, 0xFC, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert!(detect_format_from_bytes(&bytes).is_none());
    }

    // Display
    #[test]
    fn format_display() {
        assert_eq!(ModelFormat::SafeTensors.to_string(), "SafeTensors");
        assert_eq!(ModelFormat::GGUF.to_string(), "GGUF");
        assert_eq!(ModelFormat::ONNX.to_string(), "ONNX");
        assert_eq!(ModelFormat::PyTorch.to_string(), "PyTorch");
    }

    // Serde roundtrip
    #[test]
    fn format_serde_roundtrip() {
        for fmt in [
            ModelFormat::SafeTensors,
            ModelFormat::GGUF,
            ModelFormat::ONNX,
            ModelFormat::PyTorch,
        ] {
            let json = serde_json::to_string(&fmt).unwrap();
            let back: ModelFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(fmt, back);
        }
    }

    // Varint parser
    #[test]
    fn varint_single_byte() {
        assert_eq!(parse_varint(&[0x09]), Some((9, 1)));
    }

    #[test]
    fn varint_multi_byte() {
        // 300 = 0b100101100 → [0xAC, 0x02]
        assert_eq!(parse_varint(&[0xAC, 0x02]), Some((300, 2)));
    }

    #[test]
    fn varint_max_value() {
        // u64::MAX takes the full ten bytes.
        let mut bytes = vec![0xFF; 9];
        bytes.push(0x01);
        assert_eq!(parse_varint(&bytes), Some((u64::MAX, 10)));
    }

    #[test]
    fn varint_empty() {
        assert_eq!(parse_varint(&[]), None);
    }

    #[test]
    fn varint_unterminated() {
        // All continuation bits set, never terminates.
        assert_eq!(parse_varint(&[0x80, 0x80, 0x80]), None);
    }

    #[test]
    fn varint_too_long() {
        // Terminator arrives only after the ten-byte limit.
        let mut bytes = vec![0x80; 10];
        bytes.push(0x00);
        assert_eq!(parse_varint(&bytes), None);
    }

    // Audit edge cases
    #[test]
    fn safetensors_empty_shape_not_counted() {
        // Scalar tensor with empty shape should not inflate param count.
        let bytes = safetensors_bytes(r#"{"bias":{"dtype":"F32","shape":[],"data_offsets":[0,4]}}"#);

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, None); // Empty shape → no params counted.
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn onnx_too_short_after_ir_version() {
        // ir_version = 1 written as an over-long three-byte varint, so the
        // varint consumes the rest of the buffer and no second field follows.
        // Four bytes long so the minimum-length guard is not what rejects it.
        let bytes = [0x08, 0x81, 0x80, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_invalid_second_field() {
        // ir_version=9, then wire type 7 (invalid) at field 0. Padded to
        // 4 bytes so the tag check itself is exercised.
        let bytes = [0x08, 0x09, 0x07, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_second_field_number_zero() {
        // ir_version=9, then a valid wire type (2) but field number 0.
        let bytes = [0x08, 0x09, 0x02, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_valid_second_field() {
        // ir_version=9, then field 2 wire type 2 (producer_name, length-delimited).
        let bytes = [0x08, 0x09, 0x12, 0x05, b'o', b'n', b'n', b'x', b'!'];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn random_0x08_not_onnx() {
        // Arbitrary binary starting with 0x08 should not be detected as ONNX.
        let bytes = [0x08, 0x05, 0xFF, 0xFF]; // wire type 7 = invalid
        assert!(parse_onnx_header(&bytes).is_none());
    }
}