Skip to main content

axonml_serialize/
format.rs

1//! Format Detection and Management
2//!
3//! Handles different serialization formats for model storage.
4
5use std::path::Path;
6
7// =============================================================================
8// Format Enum
9// =============================================================================
10
/// The storage encodings a model file can use.
///
/// Variants are plain tags without payload, so the type is cheap to copy
/// and compare.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum Format {
    /// Native Axonml binary encoding (`.axonml`, `.bin`).
    Axonml,
    /// JSON text, intended for debugging (`.json`).
    Json,
    /// The `SafeTensors` container (`.safetensors`).
    SafeTensors,
}
21
22impl Format {
23    /// Get the file extension for this format.
24    #[must_use] pub fn extension(&self) -> &'static str {
25        match self {
26            Format::Axonml => "axonml",
27            Format::Json => "json",
28            Format::SafeTensors => "safetensors",
29        }
30    }
31
32    /// Get a human-readable name for this format.
33    #[must_use] pub fn name(&self) -> &'static str {
34        match self {
35            Format::Axonml => "Axonml Native",
36            Format::Json => "JSON",
37            Format::SafeTensors => "SafeTensors",
38        }
39    }
40
41    /// Check if this format is binary.
42    #[must_use] pub fn is_binary(&self) -> bool {
43        match self {
44            Format::Axonml => true,
45            Format::Json => false,
46            Format::SafeTensors => true,
47        }
48    }
49
50    /// Check if this format supports streaming.
51    #[must_use] pub fn supports_streaming(&self) -> bool {
52        match self {
53            Format::Axonml => true,
54            Format::Json => false,
55            Format::SafeTensors => true,
56        }
57    }
58
59    /// Get all supported formats.
60    #[must_use] pub fn all() -> &'static [Format] {
61        &[Format::Axonml, Format::Json, Format::SafeTensors]
62    }
63}
64
65impl std::fmt::Display for Format {
66    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67        write!(f, "{}", self.name())
68    }
69}
70
71// =============================================================================
72// Format Detection
73// =============================================================================
74
75/// Detect the format from a file path based on extension.
76pub fn detect_format<P: AsRef<Path>>(path: P) -> Format {
77    let path = path.as_ref();
78
79    match path.extension().and_then(|e| e.to_str()) {
80        Some("axonml") => Format::Axonml,
81        Some("bin") => Format::Axonml,
82        Some("json") => Format::Json,
83        Some("safetensors") => Format::SafeTensors,
84        Some("st") => Format::SafeTensors,
85        _ => Format::Axonml, // default
86    }
87}
88
89/// Detect format from file contents (magic bytes).
90#[must_use] pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<Format> {
91    if bytes.len() < 8 {
92        return None;
93    }
94
95    // Check for JSON (starts with '{' or '[')
96    if bytes[0] == b'{' || bytes[0] == b'[' {
97        return Some(Format::Json);
98    }
99
100    // SafeTensors has a specific header format
101    // First 8 bytes are the header size as u64 little-endian
102    // Then the header is JSON
103    if bytes.len() >= 16 {
104        let header_size = u64::from_le_bytes(bytes[0..8].try_into().ok()?);
105        if header_size < 10_000_000 && bytes.get(8) == Some(&b'{') {
106            return Some(Format::SafeTensors);
107        }
108    }
109
110    // Default to Axonml binary format
111    Some(Format::Axonml)
112}
113
114// =============================================================================
115// Tests
116// =============================================================================
117
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognized extension maps to its format; unknown or missing
    /// extensions fall back to Axonml.
    #[test]
    fn test_detect_format_from_extension() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.bin"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.st"), Format::SafeTensors);
        assert_eq!(detect_format("model.unknown"), Format::Axonml);
        // No extension at all also takes the default.
        assert_eq!(detect_format("model"), Format::Axonml);
    }

    /// Static per-format properties.
    #[test]
    fn test_format_properties() {
        assert!(Format::Axonml.is_binary());
        assert!(!Format::Json.is_binary());
        assert!(Format::SafeTensors.is_binary());

        assert!(Format::Axonml.supports_streaming());
        assert!(!Format::Json.supports_streaming());
        assert!(Format::SafeTensors.supports_streaming());

        assert_eq!(Format::Axonml.extension(), "axonml");
        assert_eq!(Format::Json.extension(), "json");
        assert_eq!(Format::SafeTensors.extension(), "safetensors");

        assert_eq!(
            Format::all(),
            &[Format::Axonml, Format::Json, Format::SafeTensors][..]
        );
    }

    /// Content sniffing: JSON, SafeTensors, binary fallback, and too-short input.
    #[test]
    fn test_detect_format_from_bytes() {
        // JSON object
        assert_eq!(
            detect_format_from_bytes(b"{\"key\": \"value\"}"),
            Some(Format::Json)
        );

        // JSON array
        assert_eq!(detect_format_from_bytes(b"[1, 2, 3, 4]"), Some(Format::Json));

        // SafeTensors: 8-byte little-endian header length, then a JSON header.
        let mut safetensors = 64u64.to_le_bytes().to_vec();
        safetensors.extend_from_slice(b"{\"a\":{}}");
        assert_eq!(
            detect_format_from_bytes(&safetensors),
            Some(Format::SafeTensors)
        );

        // Binary (default to Axonml)
        assert_eq!(
            detect_format_from_bytes(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]),
            Some(Format::Axonml)
        );

        // Fewer than 8 bytes cannot be classified.
        assert_eq!(detect_format_from_bytes(b"{}"), None);
        assert_eq!(detect_format_from_bytes(&[]), None);
    }

    /// `Display` prints the human-readable name.
    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", Format::Axonml), "Axonml Native");
        assert_eq!(format!("{}", Format::Json), "JSON");
        assert_eq!(format!("{}", Format::SafeTensors), "SafeTensors");
    }
}