1use serde::{Deserialize, Serialize};
4
/// Serialization formats a model can be stored in.
///
/// Each variant maps to a file extension via [`ModelFormat::extension`] and
/// can be recovered from one via [`ModelFormat::from_extension`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum ModelFormat {
    /// JSON; uses the `json` extension.
    Json,

    /// YAML; uses the `yaml` extension (`yml` is also accepted when parsing).
    Yaml,

    /// SafeTensors; uses the `safetensors` extension.
    SafeTensors,

    /// Format identified by the `apr` extension.
    // NOTE(review): the on-disk layout of this format is not defined in this
    // file; confirm against the code that reads/writes it.
    Apr,

    /// GGUF; uses the `gguf` extension. Only compiled in when the `gguf`
    /// feature is enabled.
    #[cfg(feature = "gguf")]
    Gguf,
}
24
25impl ModelFormat {
26 pub fn extension(&self) -> &str {
28 match self {
29 ModelFormat::Json => "json",
30 ModelFormat::Yaml => "yaml",
31 ModelFormat::SafeTensors => "safetensors",
32 ModelFormat::Apr => "apr",
33 #[cfg(feature = "gguf")]
34 ModelFormat::Gguf => "gguf",
35 }
36 }
37
38 pub fn from_extension(ext: &str) -> Option<Self> {
40 match ext.to_lowercase().as_str() {
41 "json" => Some(ModelFormat::Json),
42 "yaml" | "yml" => Some(ModelFormat::Yaml),
43 "safetensors" => Some(ModelFormat::SafeTensors),
44 "apr" => Some(ModelFormat::Apr),
45 #[cfg(feature = "gguf")]
46 "gguf" => Some(ModelFormat::Gguf),
47 _ => None,
48 }
49 }
50}
51
52#[derive(Debug, Clone)]
54pub struct SaveConfig {
55 pub format: ModelFormat,
57
58 pub pretty: bool,
60
61 pub compress: bool,
63}
64
65impl SaveConfig {
66 pub fn new(format: ModelFormat) -> Self {
68 Self { format, pretty: true, compress: false }
69 }
70
71 pub fn with_pretty(mut self, pretty: bool) -> Self {
73 self.pretty = pretty;
74 self
75 }
76
77 pub fn with_compress(mut self, compress: bool) -> Self {
79 self.compress = compress;
80 self
81 }
82}
83
84impl Default for SaveConfig {
85 fn default() -> Self {
86 Self::new(ModelFormat::Json).with_pretty(true)
87 }
88}
89
#[cfg(test)]
mod tests {
    use super::*;

    /// Every format compiled into this build, so table-driven tests pick up
    /// feature-gated variants automatically.
    fn all_formats() -> Vec<ModelFormat> {
        // `mut` is only exercised when the `gguf` feature is enabled.
        #[allow(unused_mut)]
        let mut formats = vec![
            ModelFormat::Json,
            ModelFormat::Yaml,
            ModelFormat::SafeTensors,
            ModelFormat::Apr,
        ];
        #[cfg(feature = "gguf")]
        formats.push(ModelFormat::Gguf);
        formats
    }

    #[test]
    fn test_format_extension() {
        assert_eq!(ModelFormat::Json.extension(), "json");
        assert_eq!(ModelFormat::Yaml.extension(), "yaml");
        assert_eq!(ModelFormat::SafeTensors.extension(), "safetensors");
        assert_eq!(ModelFormat::Apr.extension(), "apr");
    }

    #[test]
    fn test_format_from_extension() {
        assert_eq!(ModelFormat::from_extension("json"), Some(ModelFormat::Json));
        assert_eq!(ModelFormat::from_extension("JSON"), Some(ModelFormat::Json));
        assert_eq!(ModelFormat::from_extension("yaml"), Some(ModelFormat::Yaml));
        assert_eq!(ModelFormat::from_extension("yml"), Some(ModelFormat::Yaml));
        assert_eq!(ModelFormat::from_extension("safetensors"), Some(ModelFormat::SafeTensors));
        assert_eq!(ModelFormat::from_extension("SAFETENSORS"), Some(ModelFormat::SafeTensors));
        assert_eq!(ModelFormat::from_extension("apr"), Some(ModelFormat::Apr));
        assert_eq!(ModelFormat::from_extension("unknown"), None);
        assert_eq!(ModelFormat::from_extension(""), None);
    }

    /// The canonical extension of every format must parse back to that format.
    #[test]
    fn test_format_extension_round_trip() {
        for format in all_formats() {
            assert_eq!(ModelFormat::from_extension(format.extension()), Some(format));
        }
    }

    /// Covers the variant that only exists with the `gguf` feature.
    #[cfg(feature = "gguf")]
    #[test]
    fn test_gguf_format() {
        assert_eq!(ModelFormat::Gguf.extension(), "gguf");
        assert_eq!(ModelFormat::from_extension("gguf"), Some(ModelFormat::Gguf));
        assert_eq!(ModelFormat::from_extension("GGUF"), Some(ModelFormat::Gguf));
    }

    #[test]
    fn test_safetensors_format_serde() {
        let format = ModelFormat::SafeTensors;
        let serialized = serde_json::to_string(&format).expect("JSON serialization should succeed");
        let deserialized: ModelFormat =
            serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
        assert_eq!(format, deserialized);
    }

    #[test]
    fn test_save_config_safetensors() {
        let config = SaveConfig::new(ModelFormat::SafeTensors);
        assert_eq!(config.format, ModelFormat::SafeTensors);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    #[test]
    fn test_save_config_builder() {
        let config = SaveConfig::new(ModelFormat::Json).with_pretty(false).with_compress(true);

        assert_eq!(config.format, ModelFormat::Json);
        assert!(!config.pretty);
        assert!(config.compress);
    }

    #[test]
    fn test_save_config_default() {
        let config = SaveConfig::default();
        assert_eq!(config.format, ModelFormat::Json);
        assert!(config.pretty);
        assert!(!config.compress);
    }

    /// Every compiled-in format must survive a JSON serde round trip.
    #[test]
    fn test_model_format_serde() {
        for format in all_formats() {
            let serialized =
                serde_json::to_string(&format).expect("JSON serialization should succeed");
            let deserialized: ModelFormat =
                serde_json::from_str(&serialized).expect("JSON deserialization should succeed");
            assert_eq!(format, deserialized);
        }
    }

    #[test]
    fn test_save_config_clone() {
        let config = SaveConfig::new(ModelFormat::Yaml).with_compress(true);
        let cloned = config.clone();
        // Compared field by field so the test does not depend on
        // `SaveConfig` implementing `PartialEq`.
        assert_eq!(config.format, cloned.format);
        assert_eq!(config.pretty, cloned.pretty);
        assert_eq!(config.compress, cloned.compress);
    }
}