1use crate::errors::{QuantizeError, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Config {
16 #[serde(default = "default_bits")]
18 pub bits: u8,
19
20 #[serde(default)]
22 pub per_channel: bool,
23
24 #[serde(default)]
26 pub excluded_layers: Vec<String>,
27
28 #[serde(default)]
31 pub min_elements: usize,
32
33 #[serde(default)]
37 pub native_int4: bool,
38
39 #[serde(default)]
43 pub symmetric: bool,
44
45 #[serde(default)]
47 pub models: Vec<ModelConfig>,
48
49 #[serde(default)]
51 pub batch: Option<BatchConfig>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub struct ModelConfig {
57 pub input: String,
59
60 pub output: String,
62
63 #[serde(default)]
65 pub bits: Option<u8>,
66
67 #[serde(default)]
69 pub per_channel: Option<bool>,
70
71 #[serde(default)]
73 pub skip_existing: bool,
74
75 #[serde(default)]
78 pub excluded_layers: Vec<String>,
79
80 #[serde(default)]
83 pub layer_bits: std::collections::HashMap<String, u8>,
84
85 #[serde(default)]
87 pub min_elements: Option<usize>,
88
89 #[serde(default)]
91 pub native_int4: Option<bool>,
92
93 #[serde(default)]
95 pub symmetric: Option<bool>,
96}
97
98#[derive(Debug, Clone, Serialize, Deserialize)]
100pub struct BatchConfig {
101 pub input_dir: String,
103
104 pub output_dir: String,
106
107 #[serde(default)]
109 pub skip_existing: bool,
110
111 #[serde(default)]
113 pub continue_on_error: bool,
114}
115
/// Serde default for [`Config::bits`]: the bit width used when the config
/// file does not specify one.
const fn default_bits() -> u8 {
    8
}
119
120impl Config {
121 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
127 let path = path.as_ref();
128 let extension =
129 path.extension()
130 .and_then(|s| s.to_str())
131 .ok_or_else(|| QuantizeError::Config {
132 reason: "Config file has no extension".into(),
133 })?;
134
135 let content = std::fs::read_to_string(path).map_err(|e| QuantizeError::Config {
136 reason: format!("Failed to read config file '{}': {e}", path.display()),
137 })?;
138
139 match extension {
140 "yaml" | "yml" => Self::from_yaml(&content),
141 "toml" => Self::from_toml(&content),
142 _ => Err(QuantizeError::Config {
143 reason: format!("Unsupported config format: {}", extension),
144 }),
145 }
146 }
147
148 pub fn from_yaml(content: &str) -> Result<Self> {
150 serde_yaml::from_str(content).map_err(|e| QuantizeError::Config {
151 reason: format!("Failed to parse YAML config: {e}"),
152 })
153 }
154
155 pub fn from_toml(content: &str) -> Result<Self> {
157 toml::from_str(content).map_err(|e| QuantizeError::Config {
158 reason: format!("Failed to parse TOML config: {e}"),
159 })
160 }
161
162 pub fn validate(&self) -> Result<()> {
168 if self.bits != 4 && self.bits != 8 {
169 return Err(QuantizeError::Config {
170 reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits),
171 });
172 }
173
174 for (idx, model) in self.models.iter().enumerate() {
175 if model.input.is_empty() {
176 return Err(QuantizeError::Config {
177 reason: format!("Model {}: input path is empty", idx),
178 });
179 }
180 if model.output.is_empty() {
181 return Err(QuantizeError::Config {
182 reason: format!("Model {}: output path is empty", idx),
183 });
184 }
185 if let Some(bits) = model.bits {
186 if bits != 4 && bits != 8 {
187 return Err(QuantizeError::Config {
188 reason: format!("Model {}: invalid bits value: {}", idx, bits),
189 });
190 }
191 }
192 for (layer, &bits) in &model.layer_bits {
193 if layer.is_empty() {
194 return Err(QuantizeError::Config {
195 reason: format!("Model {}: layer_bits contains an empty layer name", idx),
196 });
197 }
198 if bits != 4 && bits != 8 {
199 return Err(QuantizeError::Config {
200 reason: format!(
201 "Model {}: invalid bits {} for layer '{}'",
202 idx, bits, layer
203 ),
204 });
205 }
206 }
207 }
208
209 if let Some(batch) = &self.batch {
210 if batch.input_dir.is_empty() {
211 return Err(QuantizeError::Config {
212 reason: "Batch input_dir is empty".into(),
213 });
214 }
215 if batch.output_dir.is_empty() {
216 return Err(QuantizeError::Config {
217 reason: "Batch output_dir is empty".into(),
218 });
219 }
220 }
221
222 Ok(())
223 }
224
225 pub fn get_bits(&self, model: &ModelConfig) -> u8 {
227 model.bits.unwrap_or(self.bits)
228 }
229
230 pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
232 model.per_channel.unwrap_or(self.per_channel)
233 }
234
235 pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
237 let mut layers = self.excluded_layers.clone();
238 for l in &model.excluded_layers {
239 if !layers.contains(l) {
240 layers.push(l.clone());
241 }
242 }
243 layers
244 }
245
246 pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
248 model.min_elements.unwrap_or(self.min_elements)
249 }
250
251 pub fn get_native_int4(&self, model: &ModelConfig) -> bool {
253 model.native_int4.unwrap_or(self.native_int4)
254 }
255
256 pub fn get_symmetric(&self, model: &ModelConfig) -> bool {
258 model.symmetric.unwrap_or(self.symmetric)
259 }
260
261 pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
266 model.layer_bits.clone()
267 }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx

  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());

        // A well-formed config must pass validation.
        config.validate().unwrap();

        // Per-model overrides fall back to the global value when unset.
        assert!(config.get_per_channel(&config.models[0]));
        assert!(!config.get_per_channel(&config.models[1]));
        assert_eq!(config.get_bits(&config.models[0]), 8);

        let batch = config.batch.as_ref().unwrap();
        assert_eq!(batch.input_dir, "models/*.onnx");
        assert_eq!(batch.output_dir, "quantized/");
        assert!(batch.skip_existing);
        assert!(!batch.continue_on_error);
    }

    #[test]
    fn test_empty_layer_bits_key_rejected() {
        let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: out.onnx
    layer_bits:
      "": 4
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("empty layer name"));
    }

    #[test]
    fn test_invalid_bits_rejected() {
        let config = Config::from_yaml("bits: 3\n").unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("Invalid bits value: 3"));

        let yaml = r#"
bits: 8
models:
  - input: a.onnx
    output: b.onnx
    bits: 2
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(err.to_string().contains("Model 0: invalid bits value: 2"));
    }

    #[test]
    fn test_model_overrides_and_exclusions() {
        let yaml = r#"
bits: 8
min_elements: 1024
excluded_layers: [a, b]
models:
  - input: m.onnx
    output: o.onnx
    bits: 4
    min_elements: 16
    symmetric: true
    excluded_layers: [b, c, c]
"#;
        let config = Config::from_yaml(yaml).unwrap();
        config.validate().unwrap();

        let model = &config.models[0];
        assert_eq!(config.get_bits(model), 4);
        assert_eq!(config.get_min_elements(model), 16);
        assert!(config.get_symmetric(model));
        assert!(!config.get_native_int4(model));
        assert!(!config.get_per_channel(model));
        assert!(config.get_layer_bits(model).is_empty());

        // Global exclusions first, then new model exclusions, deduplicated.
        assert_eq!(config.get_excluded_layers(model), vec!["a", "b", "c"]);
    }

    #[test]
    fn test_from_file_requires_extension() {
        let err = Config::from_file("quantize_config_without_extension").unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("no extension"));
    }

    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());

        config.validate().unwrap();
        assert!(config.get_per_channel(&config.models[0]));
        assert!(!config.get_per_channel(&config.models[1]));

        let batch = config.batch.as_ref().unwrap();
        assert_eq!(batch.output_dir, "quantized/");
        assert!(batch.skip_existing);
        assert!(!batch.continue_on_error);
    }
}