Skip to main content

entrenar/io/
format.rs

1//! Serialization format definitions
2
3use serde::{Deserialize, Serialize};
4
5/// Supported model serialization formats
6#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
7pub enum ModelFormat {
8    /// JSON format (human-readable, larger file size)
9    Json,
10
11    /// YAML format (human-readable, good for configs)
12    Yaml,
13
14    /// SafeTensors format (HuggingFace compatible, efficient binary)
15    SafeTensors,
16
17    /// APR format (sovereign stack universal format — atomic, binary, training+inference)
18    Apr,
19
20    /// Placeholder for future GGUF support
21    #[cfg(feature = "gguf")]
22    Gguf,
23}
24
25impl ModelFormat {
26    /// Get file extension for this format
27    pub fn extension(&self) -> &str {
28        match self {
29            ModelFormat::Json => "json",
30            ModelFormat::Yaml => "yaml",
31            ModelFormat::SafeTensors => "safetensors",
32            ModelFormat::Apr => "apr",
33            #[cfg(feature = "gguf")]
34            ModelFormat::Gguf => "gguf",
35        }
36    }
37
38    /// Detect format from file extension
39    pub fn from_extension(ext: &str) -> Option<Self> {
40        match ext.to_lowercase().as_str() {
41            "json" => Some(ModelFormat::Json),
42            "yaml" | "yml" => Some(ModelFormat::Yaml),
43            "safetensors" => Some(ModelFormat::SafeTensors),
44            "apr" => Some(ModelFormat::Apr),
45            #[cfg(feature = "gguf")]
46            "gguf" => Some(ModelFormat::Gguf),
47            _ => None,
48        }
49    }
50}
51
52/// Configuration for saving models
53#[derive(Debug, Clone)]
54pub struct SaveConfig {
55    /// Serialization format
56    pub format: ModelFormat,
57
58    /// Whether to pretty-print (for text formats)
59    pub pretty: bool,
60
61    /// Whether to compress the output
62    pub compress: bool,
63}
64
65impl SaveConfig {
66    /// Create new save config with format
67    pub fn new(format: ModelFormat) -> Self {
68        Self { format, pretty: true, compress: false }
69    }
70
71    /// Enable/disable pretty printing
72    pub fn with_pretty(mut self, pretty: bool) -> Self {
73        self.pretty = pretty;
74        self
75    }
76
77    /// Enable/disable compression
78    pub fn with_compress(mut self, compress: bool) -> Self {
79        self.compress = compress;
80        self
81    }
82}
83
84impl Default for SaveConfig {
85    fn default() -> Self {
86        Self::new(ModelFormat::Json).with_pretty(true)
87    }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ModelFormat` variant that exists without optional features.
    const CORE_FORMATS: [ModelFormat; 4] =
        [ModelFormat::Json, ModelFormat::Yaml, ModelFormat::SafeTensors, ModelFormat::Apr];

    #[test]
    fn test_format_extension() {
        assert_eq!(ModelFormat::Json.extension(), "json");
        assert_eq!(ModelFormat::Yaml.extension(), "yaml");
        assert_eq!(ModelFormat::SafeTensors.extension(), "safetensors");
        assert_eq!(ModelFormat::Apr.extension(), "apr");
    }

    #[test]
    fn test_format_from_extension() {
        assert_eq!(ModelFormat::from_extension("json"), Some(ModelFormat::Json));
        assert_eq!(ModelFormat::from_extension("JSON"), Some(ModelFormat::Json));
        assert_eq!(ModelFormat::from_extension("yaml"), Some(ModelFormat::Yaml));
        assert_eq!(ModelFormat::from_extension("yml"), Some(ModelFormat::Yaml));
        assert_eq!(ModelFormat::from_extension("safetensors"), Some(ModelFormat::SafeTensors));
        assert_eq!(ModelFormat::from_extension("SAFETENSORS"), Some(ModelFormat::SafeTensors));
        assert_eq!(ModelFormat::from_extension("apr"), Some(ModelFormat::Apr));
        assert_eq!(ModelFormat::from_extension("unknown"), None);
    }

    #[test]
    fn test_extension_round_trip() {
        // `from_extension` must recognize whatever `extension` produces.
        for format in CORE_FORMATS {
            assert_eq!(ModelFormat::from_extension(format.extension()), Some(format));
        }
    }

    #[test]
    fn test_from_extension_rejects_dotted_and_empty() {
        // Extensions are expected without the leading dot.
        assert_eq!(ModelFormat::from_extension(".json"), None);
        assert_eq!(ModelFormat::from_extension(""), None);
    }

    #[cfg(feature = "gguf")]
    #[test]
    fn test_gguf_extension_round_trip() {
        assert_eq!(ModelFormat::Gguf.extension(), "gguf");
        assert_eq!(ModelFormat::from_extension("GGUF"), Some(ModelFormat::Gguf));
    }

    #[test]
    fn test_safetensors_format_serde() {
        let format = ModelFormat::SafeTensors;
        let serialized = serde_json::to_string(&format).expect("JSON serialization should succeed");
        let deserialized: ModelFormat =
            serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
        assert_eq!(format, deserialized);
    }

    #[test]
    fn test_all_core_formats_serde_round_trip() {
        for format in CORE_FORMATS {
            let serialized =
                serde_json::to_string(&format).expect("JSON serialization should succeed");
            let deserialized: ModelFormat =
                serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
            assert_eq!(format, deserialized);
        }
    }

    #[test]
    fn test_save_config_safetensors() {
        let config = SaveConfig::new(ModelFormat::SafeTensors);
        assert_eq!(config.format, ModelFormat::SafeTensors);
        // `new` enables pretty-printing regardless of format; the flag is only
        // meaningful for text formats, but it is still set here.
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false).with_compress(true);

        assert_eq!(config.format, ModelFormat::Json);
        assert!(!config.pretty);
        assert!(config.compress);
    }

    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_model_format_serde() {
        // Test serialization/deserialization
        let format = ModelFormat::Json;
        let serialized = serde_json::to_string(&format).expect("JSON serialization should succeed");
        let deserialized: ModelFormat =
            serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
        assert_eq!(format, deserialized);

        let format_yaml = ModelFormat::Yaml;
        let serialized =
            serde_json::to_string(&format_yaml).expect("JSON serialization should succeed");
        let deserialized: ModelFormat =
            serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
        assert_eq!(format_yaml, deserialized);
    }

    #[test]
    fn test_save_config_clone() {
        let config = SaveConfig::new(ModelFormat::Yaml).with_compress(true);
        let cloned = config.clone();
        assert_eq!(config.format, cloned.format);
        assert_eq!(config.pretty, cloned.pretty);
        assert_eq!(config.compress, cloned.compress);
    }
}