use std::fmt;
use std::path::Path;
use serde::{Deserialize, Serialize};
/// Size of the buffer `detect_format` allocates for the leading bytes of a file.
const MAX_HEADER_BYTES: usize = 16 * 1024;
/// GGUF magic number: the ASCII bytes `G`, `G`, `U`, `F` read as a little-endian `u32`.
const GGUF_MAGIC: u32 = 0x4655_4747;
/// First byte `parse_onnx_header` requires: a protobuf tag for field 1 with the varint
/// wire type (`(1 << 3) | 0`), whose value that parser reports as the IR version.
const ONNX_IR_VERSION_TAG: u8 = 0x08;
/// Model file formats that the sniffing functions in this file can report.
///
/// The enum is `#[non_exhaustive]`, so code in other crates must include a wildcard
/// arm when matching on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ModelFormat {
/// Reported when the input starts with an 8-byte little-endian length followed by
/// data whose first non-whitespace byte is `{`.
SafeTensors,
/// Reported when the input starts with the `GGUF` magic number.
GGUF,
/// Reported when the input starts with a protobuf varint field 1 holding a value in
/// `1..=20`, followed by a plausible second field tag.
ONNX,
/// Reported when the input starts with the bytes `PK\x03\x04` (a ZIP signature).
PyTorch,
}
impl fmt::Display for ModelFormat {
    /// Writes the variant's name (`SafeTensors`, `GGUF`, `ONNX` or `PyTorch`) verbatim.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let label = match self {
            Self::SafeTensors => "SafeTensors",
            Self::GGUF => "GGUF",
            Self::ONNX => "ONNX",
            Self::PyTorch => "PyTorch",
        };
        f.write_str(label)
    }
}
/// Result of sniffing a model file: the detected format plus whatever header details
/// the format-specific parser could extract. Fields a parser cannot fill stay `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
/// The detected file format.
pub format: ModelFormat,
/// Total element count summed over tensor shapes. Only the SafeTensors parser fills
/// this, and it reports `None` when the total is zero.
pub param_count: Option<u64>,
/// For SafeTensors, the `dtype` string of the first tensor entry that carries one;
/// for GGUF, a name derived from the `general.file_type` value when it is found.
pub dtype: Option<String>,
/// For SafeTensors, the number of header entries other than `__metadata__`;
/// for GGUF, the tensor count from the header when it fits in a `u32`.
pub tensor_count: Option<u32>,
/// For GGUF, the version field of the header; for ONNX, the decoded IR version.
pub format_version: Option<u32>,
}
/// Sniffs the model format of the file at `path` from its leading bytes.
///
/// At most `MAX_HEADER_BYTES` bytes are read from the start of the file and handed to
/// [`detect_format_from_bytes`].
///
/// Returns `None` when the file cannot be opened or read, or when no known format
/// matches the bytes that were read.
#[must_use]
pub fn detect_format(path: &Path) -> Option<ModelMetadata> {
    use std::io::Read;

    let limit = u64::try_from(MAX_HEADER_BYTES).ok()?;
    let file = std::fs::File::open(path).ok()?;

    // A single `read` call may legally return fewer bytes than are available (and can
    // fail with `Interrupted`). `take` + `read_to_end` keeps reading until the limit or
    // end of file is reached, so the parsers always see the full available prefix.
    let mut header = Vec::with_capacity(MAX_HEADER_BYTES);
    file.take(limit).read_to_end(&mut header).ok()?;

    detect_format_from_bytes(&header)
}
/// Sniffs the model format from the leading bytes of a file.
///
/// The format probes run in a fixed order — SafeTensors, GGUF, ONNX, then the ZIP
/// signature used for PyTorch — and the first one that matches wins. A PyTorch match
/// carries no further metadata. Returns `None` when nothing matches.
#[must_use]
pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<ModelMetadata> {
    parse_safetensors_header(bytes)
        .or_else(|| parse_gguf_header(bytes))
        .or_else(|| parse_onnx_header(bytes))
        .or_else(|| {
            is_pytorch_format(bytes).then(|| ModelMetadata {
                format: ModelFormat::PyTorch,
                param_count: None,
                dtype: None,
                tensor_count: None,
                format_version: None,
            })
        })
}
fn parse_safetensors_header(bytes: &[u8]) -> Option<ModelMetadata> {
if bytes.len() < 8 {
return None;
}
let header_size = u64::from_le_bytes(bytes[..8].try_into().ok()?) as usize;
if header_size == 0 || header_size > 100 * 1024 * 1024 {
return None;
}
let json_end = (8 + header_size).min(bytes.len());
let json_bytes = &bytes[8..json_end];
let first_non_ws = json_bytes.iter().find(|b| !b.is_ascii_whitespace())?;
if *first_non_ws != b'{' {
return None;
}
let json_str = std::str::from_utf8(json_bytes).ok()?;
if json_end - 8 >= header_size
&& let Ok(header) = serde_json::from_str::<serde_json::Value>(json_str)
{
return Some(extract_safetensors_metadata(&header));
}
Some(ModelMetadata {
format: ModelFormat::SafeTensors,
param_count: None,
dtype: None,
tensor_count: None,
format_version: None,
})
}
/// Summarises a parsed SafeTensors JSON header.
///
/// Every top-level key except `__metadata__` is counted as one tensor. For entries that
/// are JSON objects:
/// - the first `dtype` string encountered becomes the reported dtype;
/// - a non-empty `shape` array contributes the product of its unsigned-integer
///   dimensions to the parameter total (other elements are skipped, and entries with an
///   empty shape contribute nothing).
///
/// All arithmetic saturates, so hostile shapes cannot overflow. `param_count` is `None`
/// when the total is zero. If `header` is not a JSON object, only the format is set.
fn extract_safetensors_metadata(header: &serde_json::Value) -> ModelMetadata {
    let Some(entries) = header.as_object() else {
        return ModelMetadata {
            format: ModelFormat::SafeTensors,
            param_count: None,
            dtype: None,
            tensor_count: None,
            format_version: None,
        };
    };

    let mut total_params: u64 = 0;
    let mut tensor_count: u32 = 0;
    let mut dtype: Option<String> = None;

    for (name, entry) in entries {
        if name == "__metadata__" {
            continue;
        }
        tensor_count = tensor_count.saturating_add(1);

        let Some(tensor) = entry.as_object() else {
            continue;
        };

        if dtype.is_none() {
            dtype = tensor
                .get("dtype")
                .and_then(|v| v.as_str())
                .map(str::to_owned);
        }

        let dims = tensor
            .get("shape")
            .and_then(|v| v.as_array())
            .filter(|dims| !dims.is_empty());
        if let Some(dims) = dims {
            // The header is untrusted input: a plain `product()` would panic in debug
            // builds (and wrap in release) on absurd dimensions.
            let elements = dims
                .iter()
                .filter_map(|d| d.as_u64())
                .fold(1u64, u64::saturating_mul);
            total_params = total_params.saturating_add(elements);
        }
    }

    ModelMetadata {
        format: ModelFormat::SafeTensors,
        param_count: (total_params > 0).then_some(total_params),
        dtype,
        tensor_count: Some(tensor_count),
        format_version: None,
    }
}
fn parse_gguf_header(bytes: &[u8]) -> Option<ModelMetadata> {
if bytes.len() < 20 {
return None;
}
let magic = u32::from_le_bytes(bytes[..4].try_into().ok()?);
if magic != GGUF_MAGIC {
return None;
}
let version = u32::from_le_bytes(bytes[4..8].try_into().ok()?);
let tensor_count = u64::from_le_bytes(bytes[8..16].try_into().ok()?);
let _kv_count = u64::from_le_bytes(bytes[16..24].try_into().ok()?);
let dtype = extract_gguf_dtype(bytes, 24);
Some(ModelMetadata {
format: ModelFormat::GGUF,
param_count: None, dtype,
tensor_count: if tensor_count <= u32::MAX as u64 {
Some(tensor_count as u32)
} else {
None
},
format_version: Some(version),
})
}
/// Scans `bytes` from `offset` onwards for the `general.file_type` key and turns the
/// value stored right after it into a type name.
///
/// The eight bytes following the key are read as two little-endian `u32`s: a value-type
/// tag, which must be `4`, and the file-type code. Known codes map to names such as
/// `F16` or `Q4_K_M`; any other code yields `GGUF_TYPE_<code>`.
///
/// Returns `None` when `offset` is past the end of `bytes`, the key is absent, fewer
/// than eight bytes follow it, or the value-type tag is not `4`.
fn extract_gguf_dtype(bytes: &[u8], offset: usize) -> Option<String> {
    const KEY: &[u8] = b"general.file_type";
    const EXPECTED_VALUE_TYPE: u32 = 4;
    const FILE_TYPE_NAMES: &[(u32, &str)] = &[
        (0, "F32"),
        (1, "F16"),
        (2, "Q4_0"),
        (3, "Q4_1"),
        (7, "Q8_0"),
        (8, "Q5_0"),
        (9, "Q5_1"),
        (10, "Q2_K"),
        (11, "Q3_K_S"),
        (12, "Q3_K_M"),
        (13, "Q3_K_L"),
        (14, "Q4_K_S"),
        (15, "Q4_K_M"),
        (16, "Q5_K_S"),
        (17, "Q5_K_M"),
        (18, "Q6_K"),
        (19, "IQ2_XXS"),
        (20, "IQ2_XS"),
    ];

    let tail = bytes.get(offset..)?;
    let key_at = tail.windows(KEY.len()).position(|window| window == KEY)?;

    // Value-type tag (4 bytes) followed by the value itself (4 bytes).
    let fields = tail.get(key_at + KEY.len()..)?.get(..8)?;
    let value_type = u32::from_le_bytes(fields[..4].try_into().ok()?);
    if value_type != EXPECTED_VALUE_TYPE {
        return None;
    }
    let file_type = u32::from_le_bytes(fields[4..].try_into().ok()?);

    let known = FILE_TYPE_NAMES
        .iter()
        .find(|(code, _)| *code == file_type)
        .map(|(_, name)| *name);

    Some(match known {
        Some(name) => name.to_string(),
        None => format!("GGUF_TYPE_{file_type}"),
    })
}
fn parse_onnx_header(bytes: &[u8]) -> Option<ModelMetadata> {
if bytes.len() < 4 {
return None;
}
if bytes[0] != ONNX_IR_VERSION_TAG {
return None;
}
let (ir_version, consumed) = parse_varint(&bytes[1..])?;
if ir_version == 0 || ir_version > 20 {
return None;
}
let next_offset = 1 + consumed;
if next_offset < bytes.len() {
let next_tag = bytes[next_offset];
let wire_type = next_tag & 0x07;
let field_num = next_tag >> 3;
if wire_type > 2 || field_num == 0 {
return None;
}
} else {
return None;
}
Some(ModelMetadata {
format: ModelFormat::ONNX,
param_count: None,
dtype: None,
tensor_count: None,
format_version: Some(ir_version as u32),
})
}
/// Decodes one base-128 varint (7 payload bits per byte, least-significant group first,
/// high bit set on every byte except the last) from the start of `bytes`.
///
/// Returns the decoded value together with the number of bytes consumed, or `None`
/// when:
/// - `bytes` is empty or ends before a terminating byte is seen;
/// - the encoding runs past ten bytes;
/// - the tenth byte carries payload bits that do not fit in a `u64`.
fn parse_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    // Ten groups of 7 bits are enough to cover all 64 bits.
    const MAX_VARINT_BYTES: usize = 10;

    let mut value: u64 = 0;
    let mut shift = 0u32;
    for (index, &byte) in bytes.iter().take(MAX_VARINT_BYTES).enumerate() {
        let payload = u64::from(byte & 0x7F);
        // At shift 63 only the lowest payload bit still fits. Anything above it would
        // be shifted out silently and produce a wrong value, so reject it instead.
        if shift == 63 && payload > 1 {
            return None;
        }
        value |= payload << shift;
        if byte & 0x80 == 0 {
            return Some((value, index + 1));
        }
        shift += 7;
    }
    None
}
/// Returns `true` when `bytes` begins with the four-byte signature `PK\x03\x04`.
fn is_pytorch_format(bytes: &[u8]) -> bool {
    const SIGNATURE: [u8; 4] = [0x50, 0x4B, 0x03, 0x04];
    bytes.starts_with(&SIGNATURE)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a SafeTensors-style byte stream: the JSON length as a little-endian `u64`
    /// followed by the JSON text itself.
    fn safetensors_bytes(json: &str) -> Vec<u8> {
        let declared_len = json.len() as u64;
        let mut bytes = declared_len.to_le_bytes().to_vec();
        bytes.extend_from_slice(json.as_bytes());
        bytes
    }

    #[test]
    fn safetensors_valid_header() {
        let bytes = safetensors_bytes(
            r#"{"weight":{"dtype":"F16","shape":[768,768],"data_offsets":[0,1179648]}}"#,
        );
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(768 * 768));
        assert_eq!(meta.dtype.as_deref(), Some("F16"));
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn safetensors_multi_tensor() {
        let bytes = safetensors_bytes(
            r#"{"w1":{"dtype":"BF16","shape":[1024,512],"data_offsets":[0,1]},"w2":{"dtype":"BF16","shape":[512,256],"data_offsets":[1,2]},"__metadata__":{"format":"pt"}}"#,
        );
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(1024 * 512 + 512 * 256));
        assert_eq!(meta.dtype.as_deref(), Some("BF16"));
        // The `__metadata__` entry is not counted as a tensor.
        assert_eq!(meta.tensor_count, Some(2));
    }

    #[test]
    fn safetensors_too_small() {
        let four_zero_bytes = [0u8; 4];
        assert!(detect_format_from_bytes(&four_zero_bytes).is_none());
    }

    #[test]
    fn safetensors_bad_header_size() {
        let oversized_prefix = (1_000_000_000u64).to_le_bytes();
        assert!(parse_safetensors_header(&oversized_prefix).is_none());
    }

    #[test]
    fn gguf_valid_header() {
        let mut bytes = Vec::new();
        bytes.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        bytes.extend_from_slice(&3u32.to_le_bytes());
        bytes.extend_from_slice(&42u64.to_le_bytes());
        bytes.extend_from_slice(&5u64.to_le_bytes());

        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::GGUF);
        assert_eq!(meta.tensor_count, Some(42));
        assert_eq!(meta.format_version, Some(3));
    }

    #[test]
    fn gguf_wrong_magic() {
        let zeroed = [0u8; 24];
        assert!(parse_gguf_header(&zeroed).is_none());
    }

    #[test]
    fn gguf_too_small() {
        let short = [0u8; 10];
        assert!(parse_gguf_header(&short).is_none());
    }

    #[test]
    fn onnx_valid_header() {
        let bytes = [0x08, 0x09, 0x12, 0x00];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn onnx_bad_ir_version() {
        let bytes = [0x08, 0x00];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn pytorch_zip_magic() {
        let bytes = [0x50, 0x4B, 0x03, 0x04, 0x00, 0x00];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::PyTorch);
    }

    #[test]
    fn pytorch_not_zip() {
        let bytes = [0x00, 0x00, 0x00, 0x00];
        assert!(!is_pytorch_format(&bytes));
    }

    #[test]
    fn unknown_format_returns_none() {
        let bytes = [0xFF, 0xFE, 0xFD, 0xFC, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert!(detect_format_from_bytes(&bytes).is_none());
    }

    #[test]
    fn format_display() {
        assert_eq!(ModelFormat::SafeTensors.to_string(), "SafeTensors");
        assert_eq!(ModelFormat::GGUF.to_string(), "GGUF");
        assert_eq!(ModelFormat::ONNX.to_string(), "ONNX");
        assert_eq!(ModelFormat::PyTorch.to_string(), "PyTorch");
    }

    #[test]
    fn format_serde_roundtrip() {
        let all_formats = [
            ModelFormat::SafeTensors,
            ModelFormat::GGUF,
            ModelFormat::ONNX,
            ModelFormat::PyTorch,
        ];
        for fmt in all_formats {
            let json = serde_json::to_string(&fmt).unwrap();
            let back: ModelFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(fmt, back);
        }
    }

    #[test]
    fn varint_single_byte() {
        assert_eq!(parse_varint(&[0x09]), Some((9, 1)));
    }

    #[test]
    fn varint_multi_byte() {
        assert_eq!(parse_varint(&[0xAC, 0x02]), Some((300, 2)));
    }

    #[test]
    fn varint_empty() {
        assert_eq!(parse_varint(&[]), None);
    }

    #[test]
    fn varint_unterminated() {
        assert_eq!(parse_varint(&[0x80, 0x80, 0x80]), None);
    }

    #[test]
    fn safetensors_empty_shape_not_counted() {
        let bytes =
            safetensors_bytes(r#"{"bias":{"dtype":"F32","shape":[],"data_offsets":[0,4]}}"#);
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, None);
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn onnx_too_short_after_ir_version() {
        let bytes = [0x08, 0x09];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_invalid_second_field() {
        let bytes = [0x08, 0x09, 0x07];
        assert!(parse_onnx_header(&bytes).is_none());
    }

    #[test]
    fn onnx_valid_second_field() {
        let bytes = [0x08, 0x09, 0x12, 0x05, b'o', b'n', b'n', b'x', b'!'];
        let meta = detect_format_from_bytes(&bytes).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn random_0x08_not_onnx() {
        let bytes = [0x08, 0x05, 0xFF, 0xFF];
        assert!(parse_onnx_header(&bytes).is_none());
    }
}