axonml_serialize/
format.rs1use std::path::Path;
18
/// On-disk serialization formats understood by this module.
///
/// Use [`detect_format`] (by path) or [`detect_format_from_bytes`] (by content)
/// to infer a format, and [`Format::extension`] for its canonical extension.
///
/// `Hash` is derived alongside `Eq` so formats can be used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// Native Axonml binary format (`.axonml`; `.bin` is also mapped here).
    Axonml,
    /// Text-based JSON (`.json`).
    Json,
    /// SafeTensors binary format (`.safetensors`; `.st` is also mapped here).
    SafeTensors,
}
33
34impl Format {
35 #[must_use]
37 pub fn extension(&self) -> &'static str {
38 match self {
39 Format::Axonml => "axonml",
40 Format::Json => "json",
41 Format::SafeTensors => "safetensors",
42 }
43 }
44
45 #[must_use]
47 pub fn name(&self) -> &'static str {
48 match self {
49 Format::Axonml => "Axonml Native",
50 Format::Json => "JSON",
51 Format::SafeTensors => "SafeTensors",
52 }
53 }
54
55 #[must_use]
57 pub fn is_binary(&self) -> bool {
58 match self {
59 Format::Axonml => true,
60 Format::Json => false,
61 Format::SafeTensors => true,
62 }
63 }
64
65 #[must_use]
67 pub fn supports_streaming(&self) -> bool {
68 match self {
69 Format::Axonml => true,
70 Format::Json => false,
71 Format::SafeTensors => true,
72 }
73 }
74
75 #[must_use]
77 pub fn all() -> &'static [Format] {
78 &[Format::Axonml, Format::Json, Format::SafeTensors]
79 }
80}
81
82impl std::fmt::Display for Format {
83 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
84 write!(f, "{}", self.name())
85 }
86}
87
88pub fn detect_format<P: AsRef<Path>>(path: P) -> Format {
94 let path = path.as_ref();
95
96 match path.extension().and_then(|e| e.to_str()) {
97 Some("axonml") => Format::Axonml,
98 Some("bin") => Format::Axonml,
99 Some("json") => Format::Json,
100 Some("safetensors") => Format::SafeTensors,
101 Some("st") => Format::SafeTensors,
102 _ => Format::Axonml, }
104}
105
106#[must_use]
108pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<Format> {
109 if bytes.len() < 8 {
110 return None;
111 }
112
113 if bytes[0] == b'{' || bytes[0] == b'[' {
115 return Some(Format::Json);
116 }
117
118 if bytes.len() >= 16 {
122 let header_size = u64::from_le_bytes(bytes[0..8].try_into().ok()?);
123 if header_size < 10_000_000 && bytes.get(8) == Some(&b'{') {
124 return Some(Format::SafeTensors);
125 }
126 }
127
128 Some(Format::Axonml)
130}
131
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_format_from_extension() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.bin"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.st"), Format::SafeTensors);
        assert_eq!(detect_format("model.unknown"), Format::Axonml);
        // No extension at all falls back to the native format.
        assert_eq!(detect_format("model"), Format::Axonml);
    }

    #[test]
    fn test_extension_round_trips_through_detect_format() {
        // Every canonical extension must be recognised as its own format.
        for &format in Format::all() {
            let path = format!("model.{}", format.extension());
            assert_eq!(detect_format(&path), format);
        }
    }

    #[test]
    fn test_format_properties() {
        assert!(Format::Axonml.is_binary());
        assert!(!Format::Json.is_binary());
        assert!(Format::SafeTensors.is_binary());

        assert!(Format::Axonml.supports_streaming());
        assert!(!Format::Json.supports_streaming());
        assert!(Format::SafeTensors.supports_streaming());

        assert_eq!(Format::Axonml.extension(), "axonml");
        assert_eq!(Format::Json.extension(), "json");
        assert_eq!(Format::SafeTensors.extension(), "safetensors");

        assert_eq!(
            Format::all(),
            &[Format::Axonml, Format::Json, Format::SafeTensors][..]
        );
    }

    #[test]
    fn test_detect_format_from_bytes() {
        // JSON object.
        assert_eq!(
            detect_format_from_bytes(b"{\"key\": \"value\"}"),
            Some(Format::Json)
        );

        // Arbitrary binary falls back to the native format.
        assert_eq!(
            detect_format_from_bytes(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]),
            Some(Format::Axonml)
        );

        // SafeTensors: little-endian u64 header length, then a JSON header.
        let mut safetensors = 64u64.to_le_bytes().to_vec();
        safetensors.extend_from_slice(b"{\"__metadata__\":{}}");
        assert_eq!(
            detect_format_from_bytes(&safetensors),
            Some(Format::SafeTensors)
        );

        // Fewer than 8 bytes is too short to classify.
        assert_eq!(detect_format_from_bytes(b"{}"), None);
        assert_eq!(detect_format_from_bytes(&[]), None);
    }

    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", Format::Axonml), "Axonml Native");
        assert_eq!(format!("{}", Format::Json), "JSON");
        assert_eq!(format!("{}", Format::SafeTensors), "SafeTensors");
    }
}