axonml_serialize/
format.rs1use std::path::Path;
6
/// Serialization formats understood by this crate.
///
/// `Hash` is derived alongside `Eq` so a `Format` can be used as a map or set key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// The crate's native binary format (`.axonml`; [`detect_format`] also maps `.bin` here).
    Axonml,
    /// JSON text (`.json`).
    Json,
    /// The SafeTensors format (`.safetensors` or `.st`).
    SafeTensors,
}

22impl Format {
23 #[must_use] pub fn extension(&self) -> &'static str {
25 match self {
26 Format::Axonml => "axonml",
27 Format::Json => "json",
28 Format::SafeTensors => "safetensors",
29 }
30 }
31
32 #[must_use] pub fn name(&self) -> &'static str {
34 match self {
35 Format::Axonml => "Axonml Native",
36 Format::Json => "JSON",
37 Format::SafeTensors => "SafeTensors",
38 }
39 }
40
41 #[must_use] pub fn is_binary(&self) -> bool {
43 match self {
44 Format::Axonml => true,
45 Format::Json => false,
46 Format::SafeTensors => true,
47 }
48 }
49
50 #[must_use] pub fn supports_streaming(&self) -> bool {
52 match self {
53 Format::Axonml => true,
54 Format::Json => false,
55 Format::SafeTensors => true,
56 }
57 }
58
59 #[must_use] pub fn all() -> &'static [Format] {
61 &[Format::Axonml, Format::Json, Format::SafeTensors]
62 }
63}
64
65impl std::fmt::Display for Format {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 write!(f, "{}", self.name())
68 }
69}
70
71pub fn detect_format<P: AsRef<Path>>(path: P) -> Format {
77 let path = path.as_ref();
78
79 match path.extension().and_then(|e| e.to_str()) {
80 Some("axonml") => Format::Axonml,
81 Some("bin") => Format::Axonml,
82 Some("json") => Format::Json,
83 Some("safetensors") => Format::SafeTensors,
84 Some("st") => Format::SafeTensors,
85 _ => Format::Axonml, }
87}
88
89#[must_use] pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<Format> {
91 if bytes.len() < 8 {
92 return None;
93 }
94
95 if bytes[0] == b'{' || bytes[0] == b'[' {
97 return Some(Format::Json);
98 }
99
100 if bytes.len() >= 16 {
104 let header_size = u64::from_le_bytes(bytes[0..8].try_into().ok()?);
105 if header_size < 10_000_000 && bytes.get(8) == Some(&b'{') {
106 return Some(Format::SafeTensors);
107 }
108 }
109
110 Some(Format::Axonml)
112}
113
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_format_from_extension() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.bin"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.st"), Format::SafeTensors);
        assert_eq!(detect_format("model.unknown"), Format::Axonml);
        // No extension at all also falls back to the native format.
        assert_eq!(detect_format("model"), Format::Axonml);
    }

    #[test]
    fn test_extension_round_trips_through_detect_format() {
        for &format in Format::all() {
            let path = format!("model.{}", format.extension());
            assert_eq!(detect_format(&path), format);
        }
    }

    #[test]
    fn test_format_properties() {
        assert!(Format::Axonml.is_binary());
        assert!(!Format::Json.is_binary());
        assert!(Format::SafeTensors.is_binary());

        assert!(Format::Axonml.supports_streaming());
        assert!(!Format::Json.supports_streaming());
        assert!(Format::SafeTensors.supports_streaming());

        assert_eq!(Format::Axonml.extension(), "axonml");
        assert_eq!(Format::Json.extension(), "json");
        assert_eq!(Format::SafeTensors.extension(), "safetensors");
    }

    #[test]
    fn test_all_lists_each_format_once() {
        assert_eq!(
            Format::all(),
            &[Format::Axonml, Format::Json, Format::SafeTensors][..]
        );
    }

    #[test]
    fn test_detect_format_from_bytes() {
        assert_eq!(
            detect_format_from_bytes(b"{\"key\": \"value\"}"),
            Some(Format::Json)
        );
        assert_eq!(detect_format_from_bytes(b"[1, 2, 3, 4]"), Some(Format::Json));

        assert_eq!(
            detect_format_from_bytes(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]),
            Some(Format::Axonml)
        );

        // Fewer than 8 bytes is not enough to decide.
        assert_eq!(detect_format_from_bytes(b"{}"), None);

        // 8-byte little-endian header length (8) followed by an 8-byte JSON header.
        let mut safetensors = 8u64.to_le_bytes().to_vec();
        safetensors.extend_from_slice(b"{}      ");
        assert_eq!(
            detect_format_from_bytes(&safetensors),
            Some(Format::SafeTensors)
        );
    }

    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", Format::Axonml), "Axonml Native");
        assert_eq!(format!("{}", Format::Json), "JSON");
        assert_eq!(Format::SafeTensors.to_string(), "SafeTensors");
    }
}