axonml_serialize/
format.rs1use std::path::Path;
6
/// Serialization formats recognised by this module.
///
/// [`detect_format`] maps a file path to one of these variants and
/// [`detect_format_from_bytes`] maps leading file bytes to one.
///
/// The type is a small `Copy` value; it also derives `Hash` so it can be used
/// as a key in `HashMap`/`HashSet`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// Axonml's native binary format (canonical extension `axonml`).
    Axonml,
    /// JSON text (canonical extension `json`).
    Json,
    /// SafeTensors binary format (canonical extension `safetensors`).
    SafeTensors,
}
22impl Format {
23 #[must_use]
25 pub fn extension(&self) -> &'static str {
26 match self {
27 Format::Axonml => "axonml",
28 Format::Json => "json",
29 Format::SafeTensors => "safetensors",
30 }
31 }
32
33 #[must_use]
35 pub fn name(&self) -> &'static str {
36 match self {
37 Format::Axonml => "Axonml Native",
38 Format::Json => "JSON",
39 Format::SafeTensors => "SafeTensors",
40 }
41 }
42
43 #[must_use]
45 pub fn is_binary(&self) -> bool {
46 match self {
47 Format::Axonml => true,
48 Format::Json => false,
49 Format::SafeTensors => true,
50 }
51 }
52
53 #[must_use]
55 pub fn supports_streaming(&self) -> bool {
56 match self {
57 Format::Axonml => true,
58 Format::Json => false,
59 Format::SafeTensors => true,
60 }
61 }
62
63 #[must_use]
65 pub fn all() -> &'static [Format] {
66 &[Format::Axonml, Format::Json, Format::SafeTensors]
67 }
68}
69
70impl std::fmt::Display for Format {
71 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72 write!(f, "{}", self.name())
73 }
74}
75
76pub fn detect_format<P: AsRef<Path>>(path: P) -> Format {
82 let path = path.as_ref();
83
84 match path.extension().and_then(|e| e.to_str()) {
85 Some("axonml") => Format::Axonml,
86 Some("bin") => Format::Axonml,
87 Some("json") => Format::Json,
88 Some("safetensors") => Format::SafeTensors,
89 Some("st") => Format::SafeTensors,
90 _ => Format::Axonml, }
92}
93
94#[must_use]
96pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<Format> {
97 if bytes.len() < 8 {
98 return None;
99 }
100
101 if bytes[0] == b'{' || bytes[0] == b'[' {
103 return Some(Format::Json);
104 }
105
106 if bytes.len() >= 16 {
110 let header_size = u64::from_le_bytes(bytes[0..8].try_into().ok()?);
111 if header_size < 10_000_000 && bytes.get(8) == Some(&b'{') {
112 return Some(Format::SafeTensors);
113 }
114 }
115
116 Some(Format::Axonml)
118}
119
// Unit tests for format metadata and format detection.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_detect_format_from_extension() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.bin"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.st"), Format::SafeTensors);
        assert_eq!(detect_format("model.unknown"), Format::Axonml);
        // No extension at all also falls back to the native format.
        assert_eq!(detect_format("model"), Format::Axonml);
    }

    #[test]
    fn test_format_properties() {
        assert!(Format::Axonml.is_binary());
        assert!(!Format::Json.is_binary());
        assert!(Format::SafeTensors.is_binary());

        assert!(Format::Axonml.supports_streaming());
        assert!(!Format::Json.supports_streaming());
        assert!(Format::SafeTensors.supports_streaming());

        assert_eq!(Format::Axonml.extension(), "axonml");
        assert_eq!(Format::Json.extension(), "json");
        assert_eq!(Format::SafeTensors.extension(), "safetensors");

        assert_eq!(
            Format::all().to_vec(),
            vec![Format::Axonml, Format::Json, Format::SafeTensors]
        );
    }

    #[test]
    fn test_detect_format_from_bytes() {
        // JSON object text.
        assert_eq!(
            detect_format_from_bytes(b"{\"key\": \"value\"}"),
            Some(Format::Json)
        );

        // Arbitrary binary bytes fall back to the native format.
        assert_eq!(
            detect_format_from_bytes(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]),
            Some(Format::Axonml)
        );

        // SafeTensors: 8-byte little-endian header length, then a JSON header.
        let mut safetensors = 64u64.to_le_bytes().to_vec();
        safetensors.extend_from_slice(b"{\"__metadata__\":{}}");
        assert_eq!(
            detect_format_from_bytes(&safetensors),
            Some(Format::SafeTensors)
        );

        // Fewer than 8 bytes is too little to decide.
        assert_eq!(detect_format_from_bytes(&[0u8; 4]), None);
        assert_eq!(detect_format_from_bytes(b""), None);
    }

    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", Format::Axonml), "Axonml Native");
        assert_eq!(format!("{}", Format::Json), "JSON");
        assert_eq!(format!("{}", Format::SafeTensors), "SafeTensors");
    }
}