1use std::fmt;
18use std::path::Path;
19
20use serde::{Deserialize, Serialize};
21
/// Maximum number of bytes read from the start of a file when sniffing formats.
const MAX_HEADER_BYTES: usize = 16 * 1024;

/// Little-endian magic at the start of a GGUF file ("GGUF" in ASCII bytes).
const GGUF_MAGIC: u32 = 0x4655_4747;

/// Protobuf tag byte (field 1, varint wire type) that opens ONNX's `ir_version`.
const ONNX_IR_VERSION_TAG: u8 = 0x08;
30
/// Model container formats this module can sniff from a file's leading bytes.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[non_exhaustive]
pub enum ModelFormat {
    /// SafeTensors: 8-byte little-endian length prefix followed by a JSON header.
    SafeTensors,
    /// GGUF: binary format starting with the "GGUF" magic number.
    GGUF,
    /// ONNX: protobuf-encoded model whose first field is `ir_version`.
    ONNX,
    /// PyTorch checkpoint: detected via the ZIP local-file magic (`PK\x03\x04`).
    PyTorch,
}
44
45impl fmt::Display for ModelFormat {
46 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
47 match self {
48 Self::SafeTensors => write!(f, "SafeTensors"),
49 Self::GGUF => write!(f, "GGUF"),
50 Self::ONNX => write!(f, "ONNX"),
51 Self::PyTorch => write!(f, "PyTorch"),
52 }
53 }
54}
55
/// Metadata sniffed from a model file's leading bytes.
///
/// Every field other than `format` is best-effort and may be `None` when the
/// header was truncated or did not carry the corresponding information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadata {
    /// Detected container format.
    pub format: ModelFormat,
    /// Total parameter count summed over tensor shapes (SafeTensors only).
    pub param_count: Option<u64>,
    /// Dtype / quantization name of the first tensor (or GGUF file_type) found.
    pub dtype: Option<String>,
    /// Number of tensors declared in the header, when it fits in `u32`.
    pub tensor_count: Option<u32>,
    /// Container format version (GGUF version, ONNX IR version).
    pub format_version: Option<u32>,
}
70
71#[must_use]
76pub fn detect_format(path: &Path) -> Option<ModelMetadata> {
77 use std::io::Read;
78 let mut file = std::fs::File::open(path).ok()?;
79 let mut buf = vec![0u8; MAX_HEADER_BYTES];
80 let n = file.read(&mut buf).ok()?;
81 buf.truncate(n);
82 detect_format_from_bytes(&buf)
83}
84
85#[must_use]
89pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<ModelMetadata> {
90 if let Some(meta) = parse_safetensors_header(bytes) {
92 return Some(meta);
93 }
94 if let Some(meta) = parse_gguf_header(bytes) {
95 return Some(meta);
96 }
97 if let Some(meta) = parse_onnx_header(bytes) {
98 return Some(meta);
99 }
100 if is_pytorch_format(bytes) {
101 return Some(ModelMetadata {
102 format: ModelFormat::PyTorch,
103 param_count: None,
104 dtype: None,
105 tensor_count: None,
106 format_version: None,
107 });
108 }
109 None
110}
111
112fn parse_safetensors_header(bytes: &[u8]) -> Option<ModelMetadata> {
121 if bytes.len() < 8 {
122 return None;
123 }
124
125 let header_size = u64::from_le_bytes(bytes[..8].try_into().ok()?) as usize;
126
127 if header_size == 0 || header_size > 100 * 1024 * 1024 {
129 return None;
130 }
131
132 let json_end = (8 + header_size).min(bytes.len());
135 let json_bytes = &bytes[8..json_end];
136
137 let first_non_ws = json_bytes.iter().find(|b| !b.is_ascii_whitespace())?;
139 if *first_non_ws != b'{' {
140 return None;
141 }
142
143 let json_str = std::str::from_utf8(json_bytes).ok()?;
145
146 if json_end - 8 >= header_size
148 && let Ok(header) = serde_json::from_str::<serde_json::Value>(json_str)
149 {
150 return Some(extract_safetensors_metadata(&header));
151 }
152
153 Some(ModelMetadata {
155 format: ModelFormat::SafeTensors,
156 param_count: None,
157 dtype: None,
158 tensor_count: None,
159 format_version: None,
160 })
161}
162
163fn extract_safetensors_metadata(header: &serde_json::Value) -> ModelMetadata {
165 let obj = match header.as_object() {
166 Some(o) => o,
167 None => {
168 return ModelMetadata {
169 format: ModelFormat::SafeTensors,
170 param_count: None,
171 dtype: None,
172 tensor_count: None,
173 format_version: None,
174 };
175 }
176 };
177
178 let mut total_params: u64 = 0;
179 let mut tensor_count: u32 = 0;
180 let mut dtype = None;
181
182 for (key, value) in obj {
183 if key == "__metadata__" {
185 continue;
186 }
187
188 tensor_count = tensor_count.saturating_add(1);
189
190 if let Some(tensor_obj) = value.as_object() {
191 if dtype.is_none()
193 && let Some(dt) = tensor_obj.get("dtype").and_then(|v| v.as_str())
194 {
195 dtype = Some(dt.to_string());
196 }
197
198 if let Some(shape) = tensor_obj.get("shape").and_then(|v| v.as_array())
200 && !shape.is_empty()
201 {
202 let params: u64 = shape.iter().filter_map(|d| d.as_u64()).product();
203 total_params = total_params.saturating_add(params);
204 }
205 }
206 }
207
208 ModelMetadata {
209 format: ModelFormat::SafeTensors,
210 param_count: if total_params > 0 {
211 Some(total_params)
212 } else {
213 None
214 },
215 dtype,
216 tensor_count: Some(tensor_count),
217 format_version: None,
218 }
219}
220
221fn parse_gguf_header(bytes: &[u8]) -> Option<ModelMetadata> {
230 if bytes.len() < 20 {
231 return None;
232 }
233
234 let magic = u32::from_le_bytes(bytes[..4].try_into().ok()?);
235 if magic != GGUF_MAGIC {
236 return None;
237 }
238
239 let version = u32::from_le_bytes(bytes[4..8].try_into().ok()?);
240 let tensor_count = u64::from_le_bytes(bytes[8..16].try_into().ok()?);
241 let _kv_count = u64::from_le_bytes(bytes[16..24].try_into().ok()?);
242
243 let dtype = extract_gguf_dtype(bytes, 24);
246
247 Some(ModelMetadata {
248 format: ModelFormat::GGUF,
249 param_count: None, dtype,
251 tensor_count: if tensor_count <= u32::MAX as u64 {
252 Some(tensor_count as u32)
253 } else {
254 None
255 },
256 format_version: Some(version),
257 })
258}
259
/// Best-effort scan of a GGUF KV section for the `general.file_type` entry.
///
/// Searches for the key's raw bytes starting at `offset`, then expects a u32
/// value-type tag (4 = UINT32) followed by the u32 file-type code, which is
/// translated into a human-readable quantization name.
fn extract_gguf_dtype(bytes: &[u8], offset: usize) -> Option<String> {
    const NEEDLE: &[u8] = b"general.file_type";

    let tail = bytes.get(offset..)?;
    let pos = tail.windows(NEEDLE.len()).position(|w| w == NEEDLE)?;
    let value_start = offset + pos + NEEDLE.len();

    // Need 8 bytes after the key: u32 value type + u32 file type.
    let raw = bytes.get(value_start..value_start + 8)?;
    let value_type = u32::from_le_bytes(raw[..4].try_into().ok()?);
    if value_type != 4 {
        return None;
    }
    let file_type = u32::from_le_bytes(raw[4..8].try_into().ok()?);

    let name = match file_type {
        0 => "F32",
        1 => "F16",
        2 => "Q4_0",
        3 => "Q4_1",
        7 => "Q8_0",
        8 => "Q5_0",
        9 => "Q5_1",
        10 => "Q2_K",
        11 => "Q3_K_S",
        12 => "Q3_K_M",
        13 => "Q3_K_L",
        14 => "Q4_K_S",
        15 => "Q4_K_M",
        16 => "Q5_K_S",
        17 => "Q5_K_M",
        18 => "Q6_K",
        19 => "IQ2_XXS",
        20 => "IQ2_XS",
        // Unknown or retired codes still yield a distinguishable label.
        _ => return Some(format!("GGUF_TYPE_{file_type}")),
    };
    Some(name.to_string())
}
311
312fn parse_onnx_header(bytes: &[u8]) -> Option<ModelMetadata> {
321 if bytes.len() < 4 {
322 return None;
323 }
324
325 if bytes[0] != ONNX_IR_VERSION_TAG {
327 return None;
328 }
329
330 let (ir_version, consumed) = parse_varint(&bytes[1..])?;
332
333 if ir_version == 0 || ir_version > 20 {
335 return None;
336 }
337
338 let next_offset = 1 + consumed;
343 if next_offset < bytes.len() {
344 let next_tag = bytes[next_offset];
345 let wire_type = next_tag & 0x07;
346 let field_num = next_tag >> 3;
347 if wire_type > 2 || field_num == 0 {
350 return None;
351 }
352 } else {
353 return None;
355 }
356
357 Some(ModelMetadata {
358 format: ModelFormat::ONNX,
359 param_count: None,
360 dtype: None,
361 tensor_count: None,
362 format_version: Some(ir_version as u32),
363 })
364}
365
/// Decodes a protobuf base-128 varint from the front of `bytes`.
///
/// Returns the decoded value and the number of bytes consumed, or `None` when
/// the input is empty, unterminated, or longer than a u64 can hold.
fn parse_varint(bytes: &[u8]) -> Option<(u64, usize)> {
    let mut acc: u64 = 0;
    for (idx, &byte) in bytes.iter().enumerate() {
        let shift = 7 * idx as u32;
        if shift >= 64 {
            // More than ten continuation bytes cannot fit in 64 bits.
            return None;
        }
        acc |= u64::from(byte & 0x7F) << shift;
        if byte & 0x80 == 0 {
            // High bit clear marks the final byte.
            return Some((acc, idx + 1));
        }
    }
    None
}
382
/// Returns `true` when `bytes` begins with the ZIP local-file magic
/// (`PK\x03\x04`), the container used by PyTorch checkpoints.
fn is_pytorch_format(bytes: &[u8]) -> bool {
    bytes.starts_with(&[0x50, 0x4B, 0x03, 0x04])
}
394
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a SafeTensors byte stream: 8-byte LE JSON length + JSON header.
    fn st_bytes(json: &str) -> Vec<u8> {
        let mut out = (json.len() as u64).to_le_bytes().to_vec();
        out.extend_from_slice(json.as_bytes());
        out
    }

    #[test]
    fn safetensors_valid_header() {
        let data =
            st_bytes(r#"{"weight":{"dtype":"F16","shape":[768,768],"data_offsets":[0,1179648]}}"#);
        let meta = detect_format_from_bytes(&data).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(768 * 768));
        assert_eq!(meta.dtype.as_deref(), Some("F16"));
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn safetensors_multi_tensor() {
        let data = st_bytes(
            r#"{"w1":{"dtype":"BF16","shape":[1024,512],"data_offsets":[0,1]},"w2":{"dtype":"BF16","shape":[512,256],"data_offsets":[1,2]},"__metadata__":{"format":"pt"}}"#,
        );
        let meta = detect_format_from_bytes(&data).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, Some(1024 * 512 + 512 * 256));
        assert_eq!(meta.dtype.as_deref(), Some("BF16"));
        assert_eq!(meta.tensor_count, Some(2));
    }

    #[test]
    fn safetensors_too_small() {
        assert!(detect_format_from_bytes(&[0u8; 4]).is_none());
    }

    #[test]
    fn safetensors_bad_header_size() {
        // A 1 GB header length is rejected as implausible.
        let data = (1_000_000_000u64).to_le_bytes();
        assert!(parse_safetensors_header(&data).is_none());
    }

    #[test]
    fn gguf_valid_header() {
        let mut data = Vec::new();
        data.extend_from_slice(&GGUF_MAGIC.to_le_bytes());
        data.extend_from_slice(&3u32.to_le_bytes()); // version
        data.extend_from_slice(&42u64.to_le_bytes()); // tensor count
        data.extend_from_slice(&5u64.to_le_bytes()); // kv count
        let meta = detect_format_from_bytes(&data).unwrap();
        assert_eq!(meta.format, ModelFormat::GGUF);
        assert_eq!(meta.tensor_count, Some(42));
        assert_eq!(meta.format_version, Some(3));
    }

    #[test]
    fn gguf_wrong_magic() {
        assert!(parse_gguf_header(&[0u8; 24]).is_none());
    }

    #[test]
    fn gguf_too_small() {
        assert!(parse_gguf_header(&[0u8; 10]).is_none());
    }

    #[test]
    fn onnx_valid_header() {
        let meta = detect_format_from_bytes(&[0x08, 0x09, 0x12, 0x00]).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn onnx_bad_ir_version() {
        assert!(parse_onnx_header(&[0x08, 0x00]).is_none());
    }

    #[test]
    fn pytorch_zip_magic() {
        let meta = detect_format_from_bytes(&[0x50, 0x4B, 0x03, 0x04, 0x00, 0x00]).unwrap();
        assert_eq!(meta.format, ModelFormat::PyTorch);
    }

    #[test]
    fn pytorch_not_zip() {
        assert!(!is_pytorch_format(&[0x00, 0x00, 0x00, 0x00]));
    }

    #[test]
    fn unknown_format_returns_none() {
        let data = [0xFF, 0xFE, 0xFD, 0xFC, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00];
        assert!(detect_format_from_bytes(&data).is_none());
    }

    #[test]
    fn format_display() {
        assert_eq!(ModelFormat::SafeTensors.to_string(), "SafeTensors");
        assert_eq!(ModelFormat::GGUF.to_string(), "GGUF");
        assert_eq!(ModelFormat::ONNX.to_string(), "ONNX");
        assert_eq!(ModelFormat::PyTorch.to_string(), "PyTorch");
    }

    #[test]
    fn format_serde_roundtrip() {
        let formats = [
            ModelFormat::SafeTensors,
            ModelFormat::GGUF,
            ModelFormat::ONNX,
            ModelFormat::PyTorch,
        ];
        for fmt in formats {
            let json = serde_json::to_string(&fmt).unwrap();
            let back: ModelFormat = serde_json::from_str(&json).unwrap();
            assert_eq!(fmt, back);
        }
    }

    #[test]
    fn varint_single_byte() {
        assert_eq!(parse_varint(&[0x09]), Some((9, 1)));
    }

    #[test]
    fn varint_multi_byte() {
        assert_eq!(parse_varint(&[0xAC, 0x02]), Some((300, 2)));
    }

    #[test]
    fn varint_empty() {
        assert_eq!(parse_varint(&[]), None);
    }

    #[test]
    fn varint_unterminated() {
        assert_eq!(parse_varint(&[0x80, 0x80, 0x80]), None);
    }

    #[test]
    fn safetensors_empty_shape_not_counted() {
        // Scalar tensors (empty shape) must not contribute to param_count.
        let data = st_bytes(r#"{"bias":{"dtype":"F32","shape":[],"data_offsets":[0,4]}}"#);
        let meta = detect_format_from_bytes(&data).unwrap();
        assert_eq!(meta.format, ModelFormat::SafeTensors);
        assert_eq!(meta.param_count, None);
        assert_eq!(meta.tensor_count, Some(1));
    }

    #[test]
    fn onnx_too_short_after_ir_version() {
        assert!(parse_onnx_header(&[0x08, 0x09]).is_none());
    }

    #[test]
    fn onnx_invalid_second_field() {
        assert!(parse_onnx_header(&[0x08, 0x09, 0x07]).is_none());
    }

    #[test]
    fn onnx_valid_second_field() {
        let data = [0x08, 0x09, 0x12, 0x05, b'o', b'n', b'n', b'x', b'!'];
        let meta = detect_format_from_bytes(&data).unwrap();
        assert_eq!(meta.format, ModelFormat::ONNX);
        assert_eq!(meta.format_version, Some(9));
    }

    #[test]
    fn random_0x08_not_onnx() {
        // Starts with the ir_version tag, but the follow-up byte (0xFF) is
        // not a plausible protobuf field tag.
        assert!(parse_onnx_header(&[0x08, 0x05, 0xFF, 0xFF]).is_none());
    }
}