Skip to main content

axonml_serialize/
format.rs

//! Format Detection and Management
//!
//! Handles the different serialization formats used for model storage.
4
5use std::path::Path;
6
// =============================================================================
// Format Enum
// =============================================================================
10
/// Supported serialization formats.
///
/// The enum is a plain, field-less tag: it is `Copy`, comparable, and hashable,
/// so it can be passed by value and used as a key in maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum Format {
    /// Axonml native binary format (.axonml, .bin)
    Axonml,
    /// JSON format for debugging (.json)
    Json,
    /// `SafeTensors` format (.safetensors)
    SafeTensors,
}
21
22impl Format {
23    /// Get the file extension for this format.
24    #[must_use]
25    pub fn extension(&self) -> &'static str {
26        match self {
27            Format::Axonml => "axonml",
28            Format::Json => "json",
29            Format::SafeTensors => "safetensors",
30        }
31    }
32
33    /// Get a human-readable name for this format.
34    #[must_use]
35    pub fn name(&self) -> &'static str {
36        match self {
37            Format::Axonml => "Axonml Native",
38            Format::Json => "JSON",
39            Format::SafeTensors => "SafeTensors",
40        }
41    }
42
43    /// Check if this format is binary.
44    #[must_use]
45    pub fn is_binary(&self) -> bool {
46        match self {
47            Format::Axonml => true,
48            Format::Json => false,
49            Format::SafeTensors => true,
50        }
51    }
52
53    /// Check if this format supports streaming.
54    #[must_use]
55    pub fn supports_streaming(&self) -> bool {
56        match self {
57            Format::Axonml => true,
58            Format::Json => false,
59            Format::SafeTensors => true,
60        }
61    }
62
63    /// Get all supported formats.
64    #[must_use]
65    pub fn all() -> &'static [Format] {
66        &[Format::Axonml, Format::Json, Format::SafeTensors]
67    }
68}
69
70impl std::fmt::Display for Format {
71    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
72        write!(f, "{}", self.name())
73    }
74}
75
// =============================================================================
// Format Detection
// =============================================================================
79
80/// Detect the format from a file path based on extension.
81pub fn detect_format<P: AsRef<Path>>(path: P) -> Format {
82    let path = path.as_ref();
83
84    match path.extension().and_then(|e| e.to_str()) {
85        Some("axonml") => Format::Axonml,
86        Some("bin") => Format::Axonml,
87        Some("json") => Format::Json,
88        Some("safetensors") => Format::SafeTensors,
89        Some("st") => Format::SafeTensors,
90        _ => Format::Axonml, // default
91    }
92}
93
94/// Detect format from file contents (magic bytes).
95#[must_use]
96pub fn detect_format_from_bytes(bytes: &[u8]) -> Option<Format> {
97    if bytes.len() < 8 {
98        return None;
99    }
100
101    // Check for JSON (starts with '{' or '[')
102    if bytes[0] == b'{' || bytes[0] == b'[' {
103        return Some(Format::Json);
104    }
105
106    // SafeTensors has a specific header format
107    // First 8 bytes are the header size as u64 little-endian
108    // Then the header is JSON
109    if bytes.len() >= 16 {
110        let header_size = u64::from_le_bytes(bytes[0..8].try_into().ok()?);
111        if header_size < 10_000_000 && bytes.get(8) == Some(&b'{') {
112            return Some(Format::SafeTensors);
113        }
114    }
115
116    // Default to Axonml binary format
117    Some(Format::Axonml)
118}
119
// =============================================================================
// Tests
// =============================================================================
123
#[cfg(test)]
mod tests {
    use super::*;

    /// Every recognized extension maps to its format; anything else is Axonml.
    #[test]
    fn test_detect_format_from_extension() {
        assert_eq!(detect_format("model.axonml"), Format::Axonml);
        assert_eq!(detect_format("model.bin"), Format::Axonml);
        assert_eq!(detect_format("model.json"), Format::Json);
        assert_eq!(detect_format("model.safetensors"), Format::SafeTensors);
        assert_eq!(detect_format("model.st"), Format::SafeTensors);
        assert_eq!(detect_format("model.unknown"), Format::Axonml);
        // No extension at all also falls back to the default.
        assert_eq!(detect_format("model"), Format::Axonml);
    }

    /// Static properties of each variant.
    #[test]
    fn test_format_properties() {
        assert!(Format::Axonml.is_binary());
        assert!(!Format::Json.is_binary());
        assert!(Format::SafeTensors.is_binary());

        assert!(Format::Axonml.supports_streaming());
        assert!(!Format::Json.supports_streaming());
        assert!(Format::SafeTensors.supports_streaming());

        assert_eq!(Format::Axonml.extension(), "axonml");
        assert_eq!(Format::Json.extension(), "json");
        assert_eq!(Format::SafeTensors.extension(), "safetensors");

        assert_eq!(
            Format::all().to_vec(),
            vec![Format::Axonml, Format::Json, Format::SafeTensors]
        );
    }

    /// Content sniffing: JSON, SafeTensors, binary fallback, and too-short input.
    #[test]
    fn test_detect_format_from_bytes() {
        // JSON object
        assert_eq!(
            detect_format_from_bytes(b"{\"key\": \"value\"}"),
            Some(Format::Json)
        );

        // JSON array
        assert_eq!(
            detect_format_from_bytes(b"[1, 2, 3, 4]"),
            Some(Format::Json)
        );

        // SafeTensors: 8-byte little-endian header size followed by a JSON header
        let mut safetensors = 64u64.to_le_bytes().to_vec();
        safetensors.extend_from_slice(b"{\"__metadata__\":{}}");
        assert_eq!(
            detect_format_from_bytes(&safetensors),
            Some(Format::SafeTensors)
        );

        // Binary (default to Axonml)
        assert_eq!(
            detect_format_from_bytes(&[0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07]),
            Some(Format::Axonml)
        );

        // Fewer than 8 bytes cannot be classified
        assert_eq!(detect_format_from_bytes(b"{}"), None);
        assert_eq!(detect_format_from_bytes(&[]), None);
    }

    /// `Display` prints the human-readable name.
    #[test]
    fn test_format_display() {
        assert_eq!(format!("{}", Format::Axonml), "Axonml Native");
        assert_eq!(format!("{}", Format::Json), "JSON");
        assert_eq!(format!("{}", Format::SafeTensors), "SafeTensors");
    }
}