1use crate::errors::{QuantizeError, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
/// Top-level quantization configuration, loadable from a YAML or TOML file.
///
/// The global settings (`bits`, `per_channel`, `excluded_layers`,
/// `min_elements`) act as defaults. Each entry in `models` may override them,
/// and the effective per-model values are resolved by the `Config::get_*`
/// helpers. Parsing does not check values; call `Config::validate` afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Default quantization bit width. Falls back to 8 (`default_bits`) when
    /// omitted; `validate` only accepts 4 or 8.
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Global per-channel flag (defaults to `false`). A model's own
    /// `per_channel`, when set, takes precedence (see `get_per_channel`).
    #[serde(default)]
    pub per_channel: bool,

    /// Layer names excluded for every model (defaults to empty). Each model's
    /// own exclusions are appended to this list by `get_excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Global `min_elements` threshold (defaults to 0). A model's own
    /// `min_elements`, when set, takes precedence (see `get_min_elements`).
    // NOTE(review): presumably tensors smaller than this are left unquantized
    // — confirm against the quantizer.
    #[serde(default)]
    pub min_elements: usize,

    /// Explicitly listed models to process (defaults to empty).
    #[serde(default)]
    pub models: Vec<ModelConfig>,

    /// Optional batch section; `None` when absent from the file.
    #[serde(default)]
    pub batch: Option<BatchConfig>,
}
41
/// One entry of [`Config::models`].
///
/// `input` and `output` are required (they carry no serde default). Every
/// other field is optional; the `Option` fields override the matching global
/// setting for this model only.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Input model path. Required; `validate` rejects an empty string.
    pub input: String,

    /// Output path for the quantized model. Required; `validate` rejects an
    /// empty string.
    pub output: String,

    /// Per-model override of `Config::bits`; must be 4 or 8 when set.
    #[serde(default)]
    pub bits: Option<u8>,

    /// Per-model override of `Config::per_channel`.
    #[serde(default)]
    pub per_channel: Option<bool>,

    /// Defaults to `false`.
    // NOTE(review): presumably skips this model when `output` already exists
    // — confirm with the caller.
    #[serde(default)]
    pub skip_existing: bool,

    /// Additional layer names to exclude. These are appended to
    /// `Config::excluded_layers` by `get_excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Per-layer bit widths keyed by layer name. `validate` rejects empty
    /// layer names and any width other than 4 or 8.
    #[serde(default)]
    pub layer_bits: std::collections::HashMap<String, u8>,

    /// Per-model override of `Config::min_elements`.
    #[serde(default)]
    pub min_elements: Option<usize>,
}
77
/// Optional batch-processing section of [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Batch input location. Required; `validate` rejects an empty string.
    // NOTE(review): the tests pass a glob such as "models/*.onnx" here, so
    // this may be a pattern rather than a plain directory — confirm.
    pub input_dir: String,

    /// Batch output directory. Required; `validate` rejects an empty string.
    pub output_dir: String,

    /// Defaults to `false`.
    // NOTE(review): presumably skips inputs whose output already exists
    // — confirm with the batch runner.
    #[serde(default)]
    pub skip_existing: bool,

    /// Defaults to `false`.
    // NOTE(review): presumably keeps processing the remaining inputs after a
    // failure — confirm with the batch runner.
    #[serde(default)]
    pub continue_on_error: bool,
}
95
/// Serde fallback for [`Config::bits`] when the key is missing: 8-bit
/// quantization.
fn default_bits() -> u8 {
    const EIGHT_BITS: u8 = 8;
    EIGHT_BITS
}
99
100impl Config {
101 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
107 let path = path.as_ref();
108 let extension =
109 path.extension()
110 .and_then(|s| s.to_str())
111 .ok_or_else(|| QuantizeError::Config {
112 reason: "Config file has no extension".into(),
113 })?;
114
115 let content = std::fs::read_to_string(path).map_err(|e| QuantizeError::Config {
116 reason: format!("Failed to read config file '{}': {e}", path.display()),
117 })?;
118
119 match extension {
120 "yaml" | "yml" => Self::from_yaml(&content),
121 "toml" => Self::from_toml(&content),
122 _ => Err(QuantizeError::Config {
123 reason: format!("Unsupported config format: {}", extension),
124 }),
125 }
126 }
127
128 pub fn from_yaml(content: &str) -> Result<Self> {
130 serde_yaml::from_str(content).map_err(|e| QuantizeError::Config {
131 reason: format!("Failed to parse YAML config: {e}"),
132 })
133 }
134
135 pub fn from_toml(content: &str) -> Result<Self> {
137 toml::from_str(content).map_err(|e| QuantizeError::Config {
138 reason: format!("Failed to parse TOML config: {e}"),
139 })
140 }
141
142 pub fn validate(&self) -> Result<()> {
148 if self.bits != 4 && self.bits != 8 {
149 return Err(QuantizeError::Config {
150 reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits),
151 });
152 }
153
154 for (idx, model) in self.models.iter().enumerate() {
155 if model.input.is_empty() {
156 return Err(QuantizeError::Config {
157 reason: format!("Model {}: input path is empty", idx),
158 });
159 }
160 if model.output.is_empty() {
161 return Err(QuantizeError::Config {
162 reason: format!("Model {}: output path is empty", idx),
163 });
164 }
165 if let Some(bits) = model.bits {
166 if bits != 4 && bits != 8 {
167 return Err(QuantizeError::Config {
168 reason: format!("Model {}: invalid bits value: {}", idx, bits),
169 });
170 }
171 }
172 for (layer, &bits) in &model.layer_bits {
173 if layer.is_empty() {
174 return Err(QuantizeError::Config {
175 reason: format!("Model {}: layer_bits contains an empty layer name", idx),
176 });
177 }
178 if bits != 4 && bits != 8 {
179 return Err(QuantizeError::Config {
180 reason: format!(
181 "Model {}: invalid bits {} for layer '{}'",
182 idx, bits, layer
183 ),
184 });
185 }
186 }
187 }
188
189 if let Some(batch) = &self.batch {
190 if batch.input_dir.is_empty() {
191 return Err(QuantizeError::Config {
192 reason: "Batch input_dir is empty".into(),
193 });
194 }
195 if batch.output_dir.is_empty() {
196 return Err(QuantizeError::Config {
197 reason: "Batch output_dir is empty".into(),
198 });
199 }
200 }
201
202 Ok(())
203 }
204
205 pub fn get_bits(&self, model: &ModelConfig) -> u8 {
207 model.bits.unwrap_or(self.bits)
208 }
209
210 pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
212 model.per_channel.unwrap_or(self.per_channel)
213 }
214
215 pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
217 let mut layers = self.excluded_layers.clone();
218 for l in &model.excluded_layers {
219 if !layers.contains(l) {
220 layers.push(l.clone());
221 }
222 }
223 layers
224 }
225
226 pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
228 model.min_elements.unwrap_or(self.min_elements)
229 }
230
231 pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
236 model.layer_bits.clone()
237 }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx

  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        config.validate().unwrap();

        // A model-level `per_channel` overrides the global flag only where set.
        assert!(config.get_per_channel(&config.models[0]));
        assert!(!config.get_per_channel(&config.models[1]));
    }

    #[test]
    fn test_empty_layer_bits_key_rejected() {
        let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: out.onnx
    layer_bits:
      "": 4
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("empty layer name"));
    }

    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        config.validate().unwrap();
    }

    #[test]
    fn test_bits_default_to_eight() {
        let config = Config::from_yaml("per_channel: false\n").unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.models.is_empty());
        assert!(config.batch.is_none());
        config.validate().unwrap();
    }

    #[test]
    fn test_invalid_global_bits_rejected() {
        let config = Config::from_yaml("bits: 3\n").unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("Invalid bits value: 3"));
    }

    #[test]
    fn test_invalid_model_bits_rejected() {
        let yaml = r#"
models:
  - input: model.onnx
    output: out.onnx
    bits: 16
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("invalid bits value: 16"));
    }

    #[test]
    fn test_model_overrides_and_merged_exclusions() {
        let yaml = r#"
bits: 8
min_elements: 128
excluded_layers: [a, b]
models:
  - input: model.onnx
    output: out.onnx
    bits: 4
    min_elements: 16
    excluded_layers: [b, c, c]
    layer_bits:
      conv1: 8
  - input: other.onnx
    output: other_out.onnx
"#;
        let config = Config::from_yaml(yaml).unwrap();
        config.validate().unwrap();

        let overridden = &config.models[0];
        assert_eq!(config.get_bits(overridden), 4);
        assert_eq!(config.get_min_elements(overridden), 16);
        assert_eq!(config.get_excluded_layers(overridden), vec!["a", "b", "c"]);
        assert_eq!(config.get_layer_bits(overridden).get("conv1"), Some(&8));

        let inherited = &config.models[1];
        assert_eq!(config.get_bits(inherited), 8);
        assert_eq!(config.get_min_elements(inherited), 128);
        assert_eq!(config.get_excluded_layers(inherited), vec!["a", "b"]);
        assert!(config.get_layer_bits(inherited).is_empty());
    }
}