1use serde::{Deserialize, Serialize};
27use std::collections::HashMap;
28
/// Model file format recognised by `detect_format`, together with whatever
/// header details were parsed for that format.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum ModelFormat {
    /// GGUF file; carries the parsed header fields.
    Gguf(GgufInfo),
    /// SafeTensors file; carries the parsed JSON header summary.
    SafeTensors(SafeTensorsInfo),
    /// APR file; carries the parsed header fields.
    Apr(AprInfo),
    /// ONNX file; carries the parsed leading field.
    Onnx(OnnxInfo),
    /// PyTorch checkpoint (no details are extracted).
    PyTorch,
    /// No known format matched.
    Unknown,
}
49
50impl ModelFormat {
51 #[must_use]
53 pub fn name(&self) -> &'static str {
54 match self {
55 Self::Gguf(_) => "GGUF",
56 Self::SafeTensors(_) => "SafeTensors",
57 Self::Apr(_) => "APR",
58 Self::Onnx(_) => "ONNX",
59 Self::PyTorch => "PyTorch",
60 Self::Unknown => "Unknown",
61 }
62 }
63
64 #[must_use]
66 pub fn extension(&self) -> &'static str {
67 match self {
68 Self::Gguf(_) => ".gguf",
69 Self::SafeTensors(_) => ".safetensors",
70 Self::Apr(_) => ".apr",
71 Self::Onnx(_) => ".onnx",
72 Self::PyTorch => ".pt",
73 Self::Unknown => "",
74 }
75 }
76
77 #[must_use]
79 pub fn is_quantized(&self) -> bool {
80 match self {
81 Self::Gguf(info) => info.quantization.is_some(),
82 Self::Apr(info) => info.quantization.is_some(),
83 _ => false,
84 }
85 }
86
87 #[must_use]
89 pub fn quantization(&self) -> Option<&str> {
90 match self {
91 Self::Gguf(info) => info.quantization.as_deref(),
92 Self::Apr(info) => info.quantization.as_deref(),
93 _ => None,
94 }
95 }
96
97 #[must_use]
99 pub fn parameters(&self) -> Option<u64> {
100 match self {
101 Self::Gguf(info) => info.parameters,
102 Self::SafeTensors(info) => info.parameters,
103 Self::Apr(info) => info.parameters,
104 Self::Onnx(info) => info.parameters,
105 Self::PyTorch | Self::Unknown => None,
106 }
107 }
108}
109
110#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
116pub struct GgufInfo {
117 pub version: u32,
119 pub tensor_count: u64,
121 pub metadata_count: u64,
123 pub architecture: Option<String>,
125 pub quantization: Option<String>,
127 pub context_length: Option<u32>,
129 pub embedding_dim: Option<u32>,
131 pub num_layers: Option<u32>,
133 pub num_heads: Option<u32>,
135 pub vocab_size: Option<u32>,
137 pub parameters: Option<u64>,
139 pub name: Option<String>,
141 pub author: Option<String>,
143 pub license: Option<String>,
145}
146
147impl Default for GgufInfo {
148 fn default() -> Self {
149 Self {
150 version: 0,
151 tensor_count: 0,
152 metadata_count: 0,
153 architecture: None,
154 quantization: None,
155 context_length: None,
156 embedding_dim: None,
157 num_layers: None,
158 num_heads: None,
159 vocab_size: None,
160 parameters: None,
161 name: None,
162 author: None,
163 license: None,
164 }
165 }
166}
167
168#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
170pub struct SafeTensorsInfo {
171 pub tensor_count: usize,
173 pub tensors: HashMap<String, TensorInfo>,
175 pub metadata: HashMap<String, String>,
177 pub parameters: Option<u64>,
179 pub dtype: Option<String>,
181}
182
183impl Default for SafeTensorsInfo {
184 fn default() -> Self {
185 Self {
186 tensor_count: 0,
187 tensors: HashMap::new(),
188 metadata: HashMap::new(),
189 parameters: None,
190 dtype: None,
191 }
192 }
193}
194
/// Description of a single tensor listed in a SafeTensors header.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct TensorInfo {
    /// Dimensions of the tensor, as listed in the header's `shape` array.
    pub shape: Vec<usize>,
    /// Data type label taken from the header's `dtype` field.
    pub dtype: String,
    /// Tensor offset. NOTE(review): `extract_tensor_info` always stores 0 here;
    /// the header's `data_offsets` are not read by the code in this file.
    pub offset: usize,
}
205
206#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
208pub struct AprInfo {
209 pub version: u32,
211 pub model_type: String,
213 pub quantization: Option<String>,
215 pub compressed: bool,
217 pub encrypted: bool,
219 pub signed: bool,
221 pub parameters: Option<u64>,
223 pub checksum: Option<u32>,
225}
226
227impl Default for AprInfo {
228 fn default() -> Self {
229 Self {
230 version: 0,
231 model_type: String::new(),
232 quantization: None,
233 compressed: false,
234 encrypted: false,
235 signed: false,
236 parameters: None,
237 checksum: None,
238 }
239 }
240}
241
242#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
244pub struct OnnxInfo {
245 pub ir_version: u64,
247 pub producer_name: Option<String>,
249 pub producer_version: Option<String>,
251 pub description: Option<String>,
253 pub node_count: usize,
255 pub parameters: Option<u64>,
257}
258
259impl Default for OnnxInfo {
260 fn default() -> Self {
261 Self {
262 ir_version: 0,
263 producer_name: None,
264 producer_version: None,
265 description: None,
266 node_count: 0,
267 parameters: None,
268 }
269 }
270}
271
// Magic byte sequences consulted by `detect_format` to recognise model files.
mod magic {
    /// ASCII `GGUF`.
    pub(super) const GGUF: [u8; 4] = [0x47, 0x47, 0x55, 0x46];
    /// Presumably the size of the SafeTensors length prefix in bytes; unused in
    /// the code visible here (hence the leading underscore) — TODO confirm.
    pub(super) const _SAFETENSORS_MIN_HEADER: u64 = 8;
    /// ASCII `APR` followed by a NUL byte.
    pub(super) const APR: [u8; 4] = [0x41, 0x50, 0x52, 0x00];
    /// Leading bytes associated with ONNX; `detect_format` only compares the
    /// first byte (`0x08`).
    pub(super) const ONNX: [u8; 2] = [0x08, 0x00];
    /// ASCII `PK`, checked for ZIP-based PyTorch checkpoints.
    pub(super) const PYTORCH_ZIP: [u8; 2] = [0x50, 0x4B];
    /// First byte checked for pickle-based PyTorch checkpoints.
    pub(super) const PYTORCH_PICKLE: u8 = 0x80;
}
290
291#[must_use]
301pub fn detect_format(data: &[u8]) -> ModelFormat {
302 if data.len() < 8 {
303 return ModelFormat::Unknown;
304 }
305
306 if data[..4] == magic::GGUF {
308 return parse_gguf_header(data);
309 }
310
311 if data[..4] == magic::APR {
313 return parse_apr_header(data);
314 }
315
316 if let Some(info) = try_parse_safetensors(data) {
321 return ModelFormat::SafeTensors(info);
322 }
323
324 if data[..2] == magic::PYTORCH_ZIP || data[0] == magic::PYTORCH_PICKLE {
326 return ModelFormat::PyTorch;
327 }
328
329 if data[0] == magic::ONNX[0] {
331 if let Some(info) = try_parse_onnx(data) {
332 return ModelFormat::Onnx(info);
333 }
334 }
335
336 ModelFormat::Unknown
337}
338
/// File-name suffixes (lower-case, with leading dot) paired with the format
/// name reported for them. Entries are tried in order.
const FORMAT_EXTENSIONS: &[(&str, &str)] = &[
    (".gguf", "GGUF"),
    (".safetensors", "SafeTensors"),
    (".apr", "APR"),
    (".onnx", "ONNX"),
    (".pt", "PyTorch"),
    (".pth", "PyTorch"),
    (".bin", "Binary"),
];

/// Guesses a format name from the suffix of `path`, ignoring case.
///
/// Returns `None` when the path ends with none of the known suffixes.
#[must_use]
pub fn detect_format_from_path(path: &str) -> Option<&'static str> {
    let lowered = path.to_lowercase();
    FORMAT_EXTENSIONS
        .iter()
        .find_map(|&(suffix, label)| lowered.ends_with(suffix).then_some(label))
}
356
357fn parse_gguf_header(data: &[u8]) -> ModelFormat {
359 if data.len() < 24 {
360 return ModelFormat::Gguf(GgufInfo::default());
361 }
362
363 let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
370 let tensor_count = u64::from_le_bytes([
371 data[8], data[9], data[10], data[11], data[12], data[13], data[14], data[15],
372 ]);
373 let metadata_count = u64::from_le_bytes([
374 data[16], data[17], data[18], data[19], data[20], data[21], data[22], data[23],
375 ]);
376
377 ModelFormat::Gguf(GgufInfo { version, tensor_count, metadata_count, ..Default::default() })
381}
382
383fn parse_apr_header(data: &[u8]) -> ModelFormat {
385 if data.len() < 16 {
386 return ModelFormat::Apr(AprInfo::default());
387 }
388
389 let version = u32::from_le_bytes([data[4], data[5], data[6], data[7]]);
396 let flags = u32::from_le_bytes([data[8], data[9], data[10], data[11]]);
397
398 let compressed = (flags & 0x01) != 0;
399 let encrypted = (flags & 0x02) != 0;
400 let signed = (flags & 0x04) != 0;
401
402 let model_type_len = u32::from_le_bytes([data[12], data[13], data[14], data[15]]) as usize;
404 let model_type = if data.len() >= 16 + model_type_len {
405 String::from_utf8_lossy(&data[16..16 + model_type_len]).to_string()
406 } else {
407 String::new()
408 };
409
410 ModelFormat::Apr(AprInfo {
411 version,
412 model_type,
413 compressed,
414 encrypted,
415 signed,
416 ..Default::default()
417 })
418}
419
/// Returns the JSON header bytes of a SafeTensors file, if `data` holds a
/// complete, plausible header.
///
/// The file starts with an 8-byte little-endian header length followed by the
/// header itself. `None` is returned when:
/// - fewer than 8 bytes are available,
/// - the declared length is 0 or exceeds 100,000,000 bytes,
/// - `data` is too short to contain the whole header, or
/// - the header does not start with `{`.
fn read_safetensors_header(data: &[u8]) -> Option<&[u8]> {
    const MAX_HEADER_BYTES: u64 = 100_000_000;

    let mut prefix = [0u8; 8];
    prefix.copy_from_slice(data.get(..8)?);
    let declared = u64::from_le_bytes(prefix);

    // Validate while the length is still a u64, so that an oversized value
    // cannot wrap into the accepted range when `usize` is narrower than 64 bits.
    if declared == 0 || declared > MAX_HEADER_BYTES {
        return None;
    }
    // Lossless on targets with at least 32-bit `usize`: bounded by MAX_HEADER_BYTES.
    let header_len = declared as usize;

    let header_json = data.get(8..)?.get(..header_len)?;
    if header_json.first() != Some(&b'{') {
        return None;
    }
    Some(header_json)
}
444
445fn extract_metadata(header: &HashMap<String, serde_json::Value>, info: &mut SafeTensorsInfo) {
447 let Some(meta) = header.get("__metadata__") else { return };
448 let Some(obj) = meta.as_object() else { return };
449 for (k, v) in obj {
450 if let Some(s) = v.as_str() {
451 info.metadata.insert(k.clone(), s.to_string());
452 }
453 }
454}
455
456fn try_parse_safetensors(data: &[u8]) -> Option<SafeTensorsInfo> {
458 if data.len() >= 8 {
460 let header_size = u64::from_le_bytes([
461 data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7],
462 ]) as usize;
463 if header_size > 0
464 && header_size <= 100_000_000
465 && data.len() < 8 + header_size
466 && data.get(8) == Some(&b'{')
467 {
468 return Some(SafeTensorsInfo { tensor_count: 0, ..Default::default() });
469 }
470 }
471
472 let header_json = read_safetensors_header(data)?;
473 let header: HashMap<String, serde_json::Value> = serde_json::from_slice(header_json).ok()?;
474
475 let mut info = SafeTensorsInfo::default();
476 info.tensor_count = header.keys().filter(|k| *k != "__metadata__").count();
477 extract_metadata(&header, &mut info);
478 info.parameters = Some(extract_tensor_info(&header, &mut info));
479 Some(info)
480}
481
482fn extract_tensor_info(
484 header: &HashMap<String, serde_json::Value>,
485 info: &mut SafeTensorsInfo,
486) -> u64 {
487 let mut total_params: u64 = 0;
488 for (name, value) in header {
489 if name == "__metadata__" {
490 continue;
491 }
492 let Some(obj) = value.as_object() else {
493 continue;
494 };
495 let (Some(dtype), Some(shape)) = (obj.get("dtype"), obj.get("shape")) else {
496 continue;
497 };
498 let dtype_str = dtype.as_str().unwrap_or("F32").to_string();
499 let shape_vec: Vec<usize> = shape
500 .as_array()
501 .map(|arr| arr.iter().filter_map(|v| v.as_u64().map(|n| n as usize)).collect())
502 .unwrap_or_default();
503
504 let elements: u64 = shape_vec.iter().map(|&s| s as u64).product();
505 total_params += elements;
506
507 info.tensors.insert(
508 name.clone(),
509 TensorInfo { shape: shape_vec, dtype: dtype_str.clone(), offset: 0 },
510 );
511
512 if info.dtype.is_none() {
513 info.dtype = Some(dtype_str);
514 }
515 }
516 total_params
517}
518
519fn try_parse_onnx(data: &[u8]) -> Option<OnnxInfo> {
521 if data.len() < 16 {
525 return None;
526 }
527
528 if data[0] != 0x08 {
531 return None;
532 }
533
534 let (ir_version, _) = read_varint(&data[1..])?;
536
537 Some(OnnxInfo { ir_version, ..Default::default() })
538}
539
/// Decodes a little-endian base-128 varint from the start of `data`.
///
/// Each byte contributes its low 7 bits; a clear high bit marks the final
/// byte. Returns the decoded value and the number of bytes consumed, or `None`
/// if `data` ends before a terminating byte or no terminator appears within
/// the first 10 bytes.
fn read_varint(data: &[u8]) -> Option<(u64, usize)> {
    // A u64 needs at most ten 7-bit groups.
    const MAX_VARINT_BYTES: usize = 10;

    let mut value = 0u64;
    for (index, byte) in data.iter().copied().take(MAX_VARINT_BYTES).enumerate() {
        let payload = u64::from(byte & 0x7F);
        value |= payload << (7 * index);

        let is_last = byte & 0x80 == 0;
        if is_last {
            return Some((value, index + 1));
        }
    }

    None
}
560
/// Weight storage / quantization type of a model.
///
/// The textual form of each variant (see the `Display` impl and
/// `QuantType::from_str`) is identical to the variant name.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
// Variant names deliberately mirror the quantization labels (e.g. `Q4_K_M`),
// which are not camel case.
#[allow(non_camel_case_types)]
pub enum QuantType {
    F32,
    F16,
    BF16,
    Q8_0,
    Q8_K,
    Q6_K,
    Q5_K_S,
    Q5_K_M,
    Q5_0,
    Q5_1,
    Q4_K_S,
    Q4_K_M,
    Q4_0,
    Q4_1,
    Q3_K_S,
    Q3_K_M,
    Q3_K_L,
    Q2_K_S,
    Q2_K,
    IQ4_NL,
    IQ4_XS,
    IQ3_S,
    IQ3_M,
    IQ3_XS,
    IQ3_XXS,
    IQ2_S,
    IQ2_XS,
    IQ2_XXS,
    IQ1_S,
    IQ1_M,
}
630
631impl QuantType {
632 #[must_use]
634 pub fn from_str(s: &str) -> Option<Self> {
635 let s = s.to_uppercase();
636 match s.as_str() {
637 "F32" | "FP32" => Some(Self::F32),
638 "F16" | "FP16" => Some(Self::F16),
639 "BF16" => Some(Self::BF16),
640 "Q8_0" => Some(Self::Q8_0),
641 "Q8_K" => Some(Self::Q8_K),
642 "Q6_K" => Some(Self::Q6_K),
643 "Q5_K_S" => Some(Self::Q5_K_S),
644 "Q5_K_M" => Some(Self::Q5_K_M),
645 "Q5_0" => Some(Self::Q5_0),
646 "Q5_1" => Some(Self::Q5_1),
647 "Q4_K_S" => Some(Self::Q4_K_S),
648 "Q4_K_M" => Some(Self::Q4_K_M),
649 "Q4_0" => Some(Self::Q4_0),
650 "Q4_1" => Some(Self::Q4_1),
651 "Q3_K_S" => Some(Self::Q3_K_S),
652 "Q3_K_M" => Some(Self::Q3_K_M),
653 "Q3_K_L" => Some(Self::Q3_K_L),
654 "Q2_K_S" => Some(Self::Q2_K_S),
655 "Q2_K" => Some(Self::Q2_K),
656 "IQ4_NL" => Some(Self::IQ4_NL),
657 "IQ4_XS" => Some(Self::IQ4_XS),
658 "IQ3_S" => Some(Self::IQ3_S),
659 "IQ3_M" => Some(Self::IQ3_M),
660 "IQ3_XS" => Some(Self::IQ3_XS),
661 "IQ3_XXS" => Some(Self::IQ3_XXS),
662 "IQ2_S" => Some(Self::IQ2_S),
663 "IQ2_XS" => Some(Self::IQ2_XS),
664 "IQ2_XXS" => Some(Self::IQ2_XXS),
665 "IQ1_S" => Some(Self::IQ1_S),
666 "IQ1_M" => Some(Self::IQ1_M),
667 _ => None,
668 }
669 }
670
671 #[must_use]
673 pub const fn bits_per_weight(&self) -> f32 {
674 match self {
675 Self::F32 => 32.0,
676 Self::F16 | Self::BF16 => 16.0,
677 Self::Q8_0 | Self::Q8_K => 8.0,
678 Self::Q6_K => 6.5,
679 Self::Q5_K_S | Self::Q5_K_M | Self::Q5_0 | Self::Q5_1 => 5.5,
680 Self::Q4_K_S | Self::Q4_K_M | Self::Q4_0 | Self::Q4_1 => 4.5,
681 Self::Q3_K_S | Self::Q3_K_M | Self::Q3_K_L => 3.5,
682 Self::Q2_K_S | Self::Q2_K => 2.5,
683 Self::IQ4_NL | Self::IQ4_XS => 4.25,
684 Self::IQ3_S | Self::IQ3_M | Self::IQ3_XS | Self::IQ3_XXS => 3.0,
685 Self::IQ2_S | Self::IQ2_XS | Self::IQ2_XXS => 2.0,
686 Self::IQ1_S | Self::IQ1_M => 1.5,
687 }
688 }
689
690 #[must_use]
692 pub fn estimate_size(&self, parameters: u64) -> u64 {
693 let bits = self.bits_per_weight() as f64;
694 let bytes = (parameters as f64 * bits) / 8.0;
695 (bytes * 1.1) as u64
697 }
698
699 #[must_use]
701 pub const fn quality_tier(&self) -> u8 {
702 match self {
703 Self::F32 | Self::F16 | Self::BF16 => 5,
704 Self::Q8_0 | Self::Q8_K => 5,
705 Self::Q6_K => 4,
706 Self::Q5_K_S | Self::Q5_K_M | Self::Q5_0 | Self::Q5_1 => 4,
707 Self::Q4_K_S | Self::Q4_K_M | Self::Q4_0 | Self::Q4_1 => 3,
708 Self::IQ4_NL | Self::IQ4_XS => 3,
709 Self::Q3_K_S | Self::Q3_K_M | Self::Q3_K_L => 2,
710 Self::IQ3_S | Self::IQ3_M | Self::IQ3_XS | Self::IQ3_XXS => 2,
711 Self::Q2_K_S | Self::Q2_K => 1,
712 Self::IQ2_S | Self::IQ2_XS | Self::IQ2_XXS => 1,
713 Self::IQ1_S | Self::IQ1_M => 1,
714 }
715 }
716
717 #[must_use]
719 pub fn vram_requirement(&self, parameters: u64) -> f64 {
720 let model_size = self.estimate_size(parameters) as f64;
722 let context_overhead = 2.0 * 1024.0 * 1024.0 * 1024.0;
724 (model_size + context_overhead) / (1024.0 * 1024.0 * 1024.0)
726 }
727}
728
729impl std::fmt::Display for QuantType {
730 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
731 let s = match self {
732 Self::F32 => "F32",
733 Self::F16 => "F16",
734 Self::BF16 => "BF16",
735 Self::Q8_0 => "Q8_0",
736 Self::Q8_K => "Q8_K",
737 Self::Q6_K => "Q6_K",
738 Self::Q5_K_S => "Q5_K_S",
739 Self::Q5_K_M => "Q5_K_M",
740 Self::Q5_0 => "Q5_0",
741 Self::Q5_1 => "Q5_1",
742 Self::Q4_K_S => "Q4_K_S",
743 Self::Q4_K_M => "Q4_K_M",
744 Self::Q4_0 => "Q4_0",
745 Self::Q4_1 => "Q4_1",
746 Self::Q3_K_S => "Q3_K_S",
747 Self::Q3_K_M => "Q3_K_M",
748 Self::Q3_K_L => "Q3_K_L",
749 Self::Q2_K_S => "Q2_K_S",
750 Self::Q2_K => "Q2_K",
751 Self::IQ4_NL => "IQ4_NL",
752 Self::IQ4_XS => "IQ4_XS",
753 Self::IQ3_S => "IQ3_S",
754 Self::IQ3_M => "IQ3_M",
755 Self::IQ3_XS => "IQ3_XS",
756 Self::IQ3_XXS => "IQ3_XXS",
757 Self::IQ2_S => "IQ2_S",
758 Self::IQ2_XS => "IQ2_XS",
759 Self::IQ2_XXS => "IQ2_XXS",
760 Self::IQ1_S => "IQ1_S",
761 Self::IQ1_M => "IQ1_M",
762 };
763 write!(f, "{s}")
764 }
765}
766
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an in-memory SafeTensors file: the 8-byte little-endian header
    /// length followed by the header text (no tensor data).
    fn safetensors_bytes(header: &str) -> Vec<u8> {
        let mut data = Vec::with_capacity(8 + header.len());
        data.extend_from_slice(&(header.len() as u64).to_le_bytes());
        data.extend_from_slice(header.as_bytes());
        data
    }

    // ---- Format detection from bytes ----

    #[test]
    fn test_detect_gguf_format() {
        let mut data = vec![0u8; 100];
        data[0..4].copy_from_slice(&magic::GGUF);
        data[4..8].copy_from_slice(&3u32.to_le_bytes());
        data[8..16].copy_from_slice(&100u64.to_le_bytes());
        data[16..24].copy_from_slice(&50u64.to_le_bytes());

        match detect_format(&data) {
            ModelFormat::Gguf(info) => {
                assert_eq!(info.version, 3);
                assert_eq!(info.tensor_count, 100);
                assert_eq!(info.metadata_count, 50);
            }
            other => panic!("Expected GGUF but got {:?}", other),
        }
    }

    #[test]
    fn test_detect_apr_format() {
        let mut data = vec![0u8; 100];
        data[0..4].copy_from_slice(&magic::APR);
        data[4..8].copy_from_slice(&1u32.to_le_bytes());
        // Flags 0x05 = compressed (0x01) | signed (0x04).
        data[8..12].copy_from_slice(&0x05u32.to_le_bytes());
        // Length-prefixed model type.
        data[12..16].copy_from_slice(&4u32.to_le_bytes());
        data[16..20].copy_from_slice(b"Test");

        match detect_format(&data) {
            ModelFormat::Apr(info) => {
                assert_eq!(info.version, 1);
                assert!(info.compressed);
                assert!(!info.encrypted);
                assert!(info.signed);
                assert_eq!(info.model_type, "Test");
            }
            other => panic!("Expected APR but got {:?}", other),
        }
    }

    #[test]
    fn test_detect_pytorch_zip_format() {
        let mut data = vec![0u8; 100];
        data[0..2].copy_from_slice(&magic::PYTORCH_ZIP);

        let format = detect_format(&data);
        assert!(matches!(format, ModelFormat::PyTorch));
    }

    #[test]
    fn test_detect_pytorch_pickle_format() {
        let mut data = vec![0u8; 100];
        data[0] = magic::PYTORCH_PICKLE;

        let format = detect_format(&data);
        assert!(matches!(format, ModelFormat::PyTorch));
    }

    #[test]
    fn test_detect_safetensors_format() {
        let data = safetensors_bytes(r#"{"tensor1":{"dtype":"F32","shape":[100,100]}}"#);

        match detect_format(&data) {
            ModelFormat::SafeTensors(info) => {
                assert_eq!(info.tensor_count, 1);
                assert!(info.tensors.contains_key("tensor1"));
                // 100 * 100 elements.
                assert_eq!(info.parameters, Some(10000));
            }
            other => panic!("Expected SafeTensors but got {:?}", other),
        }
    }

    #[test]
    fn test_detect_unknown_format() {
        let data = vec![0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07];
        let format = detect_format(&data);
        assert!(matches!(format, ModelFormat::Unknown));
    }

    #[test]
    fn test_detect_empty_data() {
        let format = detect_format(&[]);
        assert!(matches!(format, ModelFormat::Unknown));
    }

    #[test]
    fn test_detect_short_data() {
        // Starts like the GGUF magic but is shorter than the 8-byte minimum.
        let data = vec![0x47, 0x47, 0x55];
        let format = detect_format(&data);
        assert!(matches!(format, ModelFormat::Unknown));
    }

    // ---- ModelFormat accessors ----

    #[test]
    fn test_model_format_name() {
        assert_eq!(ModelFormat::Gguf(GgufInfo::default()).name(), "GGUF");
        assert_eq!(ModelFormat::SafeTensors(SafeTensorsInfo::default()).name(), "SafeTensors");
        assert_eq!(ModelFormat::Apr(AprInfo::default()).name(), "APR");
        assert_eq!(ModelFormat::Onnx(OnnxInfo::default()).name(), "ONNX");
        assert_eq!(ModelFormat::PyTorch.name(), "PyTorch");
        assert_eq!(ModelFormat::Unknown.name(), "Unknown");
    }

    #[test]
    fn test_model_format_extension() {
        assert_eq!(ModelFormat::Gguf(GgufInfo::default()).extension(), ".gguf");
        assert_eq!(
            ModelFormat::SafeTensors(SafeTensorsInfo::default()).extension(),
            ".safetensors"
        );
        assert_eq!(ModelFormat::Apr(AprInfo::default()).extension(), ".apr");
        assert_eq!(ModelFormat::Onnx(OnnxInfo::default()).extension(), ".onnx");
        assert_eq!(ModelFormat::PyTorch.extension(), ".pt");
        assert_eq!(ModelFormat::Unknown.extension(), "");
    }

    #[test]
    fn test_model_format_is_quantized() {
        let gguf_quant = ModelFormat::Gguf(GgufInfo {
            quantization: Some("Q4_K_M".to_string()),
            ..Default::default()
        });
        assert!(gguf_quant.is_quantized());

        let gguf_no_quant = ModelFormat::Gguf(GgufInfo::default());
        assert!(!gguf_no_quant.is_quantized());
    }

    #[test]
    fn test_model_format_quantization() {
        let format = ModelFormat::Gguf(GgufInfo {
            quantization: Some("Q8_0".to_string()),
            ..Default::default()
        });
        assert_eq!(format.quantization(), Some("Q8_0"));
    }

    #[test]
    fn test_model_format_parameters() {
        let format =
            ModelFormat::Gguf(GgufInfo { parameters: Some(7_000_000_000), ..Default::default() });
        assert_eq!(format.parameters(), Some(7_000_000_000));

        assert_eq!(ModelFormat::PyTorch.parameters(), None);
    }

    // ---- Format detection from path ----

    #[test]
    fn test_detect_format_from_path() {
        assert_eq!(detect_format_from_path("model.gguf"), Some("GGUF"));
        assert_eq!(detect_format_from_path("model.GGUF"), Some("GGUF"));
        assert_eq!(detect_format_from_path("model.safetensors"), Some("SafeTensors"));
        assert_eq!(detect_format_from_path("model.apr"), Some("APR"));
        assert_eq!(detect_format_from_path("model.onnx"), Some("ONNX"));
        assert_eq!(detect_format_from_path("model.pt"), Some("PyTorch"));
        assert_eq!(detect_format_from_path("model.pth"), Some("PyTorch"));
        assert_eq!(detect_format_from_path("model.bin"), Some("Binary"));
        assert_eq!(detect_format_from_path("model.txt"), None);
    }

    // ---- QuantType ----

    #[test]
    fn test_quant_type_from_str() {
        assert_eq!(QuantType::from_str("Q4_K_M"), Some(QuantType::Q4_K_M));
        assert_eq!(QuantType::from_str("q4_k_m"), Some(QuantType::Q4_K_M));
        assert_eq!(QuantType::from_str("F16"), Some(QuantType::F16));
        assert_eq!(QuantType::from_str("fp16"), Some(QuantType::F16));
        assert_eq!(QuantType::from_str("invalid"), None);
    }

    #[test]
    fn test_quant_type_bits_per_weight() {
        assert!((QuantType::F32.bits_per_weight() - 32.0).abs() < f32::EPSILON);
        assert!((QuantType::F16.bits_per_weight() - 16.0).abs() < f32::EPSILON);
        assert!((QuantType::Q8_0.bits_per_weight() - 8.0).abs() < f32::EPSILON);
        assert!((QuantType::Q4_K_M.bits_per_weight() - 4.5).abs() < f32::EPSILON);
    }

    #[test]
    fn test_quant_type_estimate_size() {
        let params = 7_000_000_000u64;

        // 7e9 weights * 4 bytes * 1.1 ≈ 30.8 GB.
        let f32_size = QuantType::F32.estimate_size(params);
        assert!(f32_size > 28_000_000_000 && f32_size < 32_000_000_000);

        // 7e9 weights * 4.5 bits / 8 * 1.1 ≈ 4.3 GB.
        let q4_size = QuantType::Q4_K_M.estimate_size(params);
        assert!(q4_size > 4_000_000_000 && q4_size < 5_000_000_000);
    }

    #[test]
    fn test_quant_type_quality_tier() {
        assert_eq!(QuantType::F32.quality_tier(), 5);
        assert_eq!(QuantType::Q8_0.quality_tier(), 5);
        assert_eq!(QuantType::Q4_K_M.quality_tier(), 3);
        assert_eq!(QuantType::Q2_K.quality_tier(), 1);
    }

    #[test]
    fn test_quant_type_vram_requirement() {
        let params = 7_000_000_000u64;
        let vram_f32 = QuantType::F32.vram_requirement(params);
        let vram_q4 = QuantType::Q4_K_M.vram_requirement(params);

        assert!(vram_f32 > vram_q4);
        assert!(vram_f32 > 0.0);
        assert!(vram_q4 > 0.0);
    }

    #[test]
    fn test_quant_type_display() {
        assert_eq!(format!("{}", QuantType::Q4_K_M), "Q4_K_M");
        assert_eq!(format!("{}", QuantType::F16), "F16");
        assert_eq!(format!("{}", QuantType::IQ3_XXS), "IQ3_XXS");
    }

    // ---- Serialization round trips ----

    #[test]
    fn test_gguf_info_serialization() {
        let info = GgufInfo {
            version: 3,
            tensor_count: 100,
            metadata_count: 50,
            architecture: Some("llama".to_string()),
            quantization: Some("Q4_K_M".to_string()),
            ..Default::default()
        };

        let json = serde_json::to_string(&info).unwrap();
        assert!(json.contains("llama"));
        assert!(json.contains("Q4_K_M"));

        let parsed: GgufInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.version, 3);
        assert_eq!(parsed.architecture, Some("llama".to_string()));
    }

    #[test]
    fn test_safetensors_info_serialization() {
        let mut tensors = HashMap::new();
        tensors.insert(
            "weight".to_string(),
            TensorInfo { shape: vec![100, 100], dtype: "F32".to_string(), offset: 0 },
        );

        let info = SafeTensorsInfo {
            tensor_count: 1,
            tensors,
            parameters: Some(10000),
            ..Default::default()
        };

        let json = serde_json::to_string(&info).unwrap();
        let parsed: SafeTensorsInfo = serde_json::from_str(&json).unwrap();

        assert_eq!(parsed.tensor_count, 1);
        assert_eq!(parsed.parameters, Some(10000));
    }

    #[test]
    fn test_model_format_serialization() {
        let format = ModelFormat::Gguf(GgufInfo { version: 3, ..Default::default() });

        let json = serde_json::to_string(&format).unwrap();
        let parsed: ModelFormat = serde_json::from_str(&json).unwrap();

        assert!(matches!(parsed, ModelFormat::Gguf(_)));
    }

    #[test]
    fn test_quant_type_serialization() {
        let qt = QuantType::Q4_K_M;
        let json = serde_json::to_string(&qt).unwrap();
        let parsed: QuantType = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, QuantType::Q4_K_M);
    }

    // ---- SafeTensors edge cases ----

    #[test]
    fn test_safetensors_with_metadata() {
        let data = safetensors_bytes(
            r#"{"__metadata__":{"format":"pt"},"tensor1":{"dtype":"F16","shape":[512]}}"#,
        );

        match detect_format(&data) {
            ModelFormat::SafeTensors(info) => {
                // `__metadata__` is not counted as a tensor.
                assert_eq!(info.tensor_count, 1);
                assert_eq!(info.metadata.get("format"), Some(&"pt".to_string()));
                assert_eq!(info.dtype, Some("F16".to_string()));
            }
            other => panic!("Expected SafeTensors format but got {:?}", other),
        }
    }

    /// Regression test for pacha#4: a SafeTensors file whose header is exactly
    /// 128 bytes long starts with byte 0x80, which is also the PyTorch pickle
    /// magic. It must still be detected as SafeTensors.
    #[test]
    fn test_pacha4_safetensors_header_size_0x80_not_pytorch() {
        let header = r#"{"__metadata__":{"format":"pt"},"tensor1":{"dtype":"F32","shape":[32],"data_offsets":[0,128]}}"#;
        let target_size = 128usize;
        assert!(header.len() <= target_size, "header too large for test");

        // Pad with spaces just before the final `}` so the JSON stays valid
        // while the header length becomes exactly `target_size`.
        let (body, closing_brace) = header.split_at(header.len() - 1);
        let padded_header =
            format!("{}{}{}", body, " ".repeat(target_size - header.len()), closing_brace);

        let mut data = safetensors_bytes(&padded_header);
        // Tensor payload: 32 F32 values.
        data.extend_from_slice(&[0u8; 128]);

        assert_eq!(data[0], 0x80, "Test setup: first byte must be 0x80");

        match detect_format(&data) {
            ModelFormat::SafeTensors(info) => {
                assert_eq!(info.tensor_count, 1);
                assert_eq!(info.metadata.get("format"), Some(&"pt".to_string()));
            }
            other => panic!("Expected SafeTensors but got {:?} — pacha#4 regression", other),
        }
    }

    /// A declared header length far larger than the supplied bytes, followed
    /// by `{`, is treated as a truncated SafeTensors file: the format is still
    /// reported, but no tensors are parsed.
    #[test]
    fn test_safetensors_invalid_header_size() {
        let header_size = 1_000_000u64;
        let mut data = Vec::new();
        data.extend_from_slice(&header_size.to_le_bytes());
        data.extend_from_slice(b"{}");

        match detect_format(&data) {
            ModelFormat::SafeTensors(info) => {
                assert_eq!(info.tensor_count, 0);
                assert!(info.tensors.is_empty());
            }
            other => panic!("Expected SafeTensors but got {:?}", other),
        }
    }

    // ---- Defaults ----

    #[test]
    fn test_gguf_info_default() {
        let info = GgufInfo::default();
        assert_eq!(info.version, 0);
        assert!(info.architecture.is_none());
    }

    #[test]
    fn test_apr_info_default() {
        let info = AprInfo::default();
        assert_eq!(info.version, 0);
        assert!(!info.compressed);
        assert!(!info.encrypted);
    }
}