1use crate::errors::{QuantizeError, Result};
8use serde::{Deserialize, Serialize};
9use std::path::Path;
10
/// Top-level quantization configuration, loadable from YAML or TOML
/// (see `Config::from_file`).
///
/// The scalar fields act as global defaults. Entries in `models` may override
/// some of them; the `Config::get_*` accessors resolve the effective value for
/// a given model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Global quantization bit width. Defaults to 8 when omitted;
    /// `Config::validate` accepts only 4 or 8.
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Global per-channel flag (defaults to `false`); a model may override it.
    #[serde(default)]
    pub per_channel: bool,

    /// Layer names excluded for every model. `Config::get_excluded_layers`
    /// merges this list with each model's own `excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Global minimum-element threshold (defaults to 0); a model may override it.
    /// NOTE(review): presumably tensors with fewer elements are left
    /// unquantized — confirm against the quantizer that consumes this value.
    #[serde(default)]
    pub min_elements: usize,

    /// Individual model jobs (input/output pairs plus optional overrides).
    #[serde(default)]
    pub models: Vec<ModelConfig>,

    /// Optional directory-level batch job; `None` when the section is absent.
    #[serde(default)]
    pub batch: Option<BatchConfig>,
}
41
/// One entry of `Config::models`.
///
/// `Option` fields are per-model overrides: when `None`, the corresponding
/// global value in `Config` is used (see `Config::get_bits`,
/// `Config::get_per_channel`, `Config::get_min_elements`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Path of the input model; `Config::validate` rejects an empty string.
    pub input: String,

    /// Output path for this model; `Config::validate` rejects an empty string.
    pub output: String,

    /// Bit-width override (must be 4 or 8 if set).
    #[serde(default)]
    pub bits: Option<u8>,

    /// Per-channel override.
    #[serde(default)]
    pub per_channel: Option<bool>,

    /// Defaults to `false`. Not read anywhere in this file.
    /// NOTE(review): presumably skips the job when `output` already exists —
    /// confirm with the caller.
    #[serde(default)]
    pub skip_existing: bool,

    /// Additional layer names to exclude; merged (without duplicates) with the
    /// global list by `Config::get_excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Per-layer bit-width overrides keyed by layer name. Each value must be
    /// 4 or 8 (checked by `Config::validate`). There is no global equivalent.
    #[serde(default)]
    pub layer_bits: std::collections::HashMap<String, u8>,

    /// Minimum-element threshold override.
    #[serde(default)]
    pub min_elements: Option<usize>,
}
77
/// Optional batch section of `Config`, describing a directory-level job.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Input location; `Config::validate` rejects an empty string.
    /// NOTE(review): the tests in this file use a glob (`models/*.onnx`) here —
    /// confirm whether plain directory paths are also accepted by the caller.
    pub input_dir: String,

    /// Output directory; `Config::validate` rejects an empty string.
    pub output_dir: String,

    /// Defaults to `false`. Not read anywhere in this file.
    /// NOTE(review): presumably skips inputs whose output already exists — confirm.
    #[serde(default)]
    pub skip_existing: bool,

    /// Defaults to `false`. Not read anywhere in this file.
    /// NOTE(review): presumably keeps processing remaining inputs after a
    /// failure — confirm with the batch driver.
    #[serde(default)]
    pub continue_on_error: bool,
}
95
/// Serde default for `Config::bits` when the key is absent: 8-bit.
fn default_bits() -> u8 {
    const DEFAULT_BIT_WIDTH: u8 = 8;
    DEFAULT_BIT_WIDTH
}
99
100impl Config {
101 pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
107 let path = path.as_ref();
108 let extension = path.extension()
109 .and_then(|s| s.to_str())
110 .ok_or_else(|| QuantizeError::Config { reason: "Config file has no extension".into() })?;
111
112 let content = std::fs::read_to_string(path)
113 .map_err(|e| QuantizeError::Config { reason: format!("Failed to read config file '{}': {e}", path.display()) })?;
114
115 match extension {
116 "yaml" | "yml" => Self::from_yaml(&content),
117 "toml" => Self::from_toml(&content),
118 _ => Err(QuantizeError::Config { reason: format!("Unsupported config format: {}", extension) }),
119 }
120 }
121
122 pub fn from_yaml(content: &str) -> Result<Self> {
124 serde_yaml::from_str(content)
125 .map_err(|e| QuantizeError::Config { reason: format!("Failed to parse YAML config: {e}") })
126 }
127
128 pub fn from_toml(content: &str) -> Result<Self> {
130 toml::from_str(content)
131 .map_err(|e| QuantizeError::Config { reason: format!("Failed to parse TOML config: {e}") })
132 }
133
134 pub fn validate(&self) -> Result<()> {
140 if self.bits != 4 && self.bits != 8 {
141 return Err(QuantizeError::Config { reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits) });
142 }
143
144 for (idx, model) in self.models.iter().enumerate() {
145 if model.input.is_empty() {
146 return Err(QuantizeError::Config { reason: format!("Model {}: input path is empty", idx) });
147 }
148 if model.output.is_empty() {
149 return Err(QuantizeError::Config { reason: format!("Model {}: output path is empty", idx) });
150 }
151 if let Some(bits) = model.bits {
152 if bits != 4 && bits != 8 {
153 return Err(QuantizeError::Config { reason: format!("Model {}: invalid bits value: {}", idx, bits) });
154 }
155 }
156 for (layer, &bits) in &model.layer_bits {
157 if bits != 4 && bits != 8 {
158 return Err(QuantizeError::Config { reason: format!("Model {}: invalid bits {} for layer '{}'", idx, bits, layer) });
159 }
160 }
161 }
162
163 if let Some(batch) = &self.batch {
164 if batch.input_dir.is_empty() {
165 return Err(QuantizeError::Config { reason: "Batch input_dir is empty".into() });
166 }
167 if batch.output_dir.is_empty() {
168 return Err(QuantizeError::Config { reason: "Batch output_dir is empty".into() });
169 }
170 }
171
172 Ok(())
173 }
174
175 pub fn get_bits(&self, model: &ModelConfig) -> u8 {
177 model.bits.unwrap_or(self.bits)
178 }
179
180 pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
182 model.per_channel.unwrap_or(self.per_channel)
183 }
184
185 pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
187 let mut layers = self.excluded_layers.clone();
188 for l in &model.excluded_layers {
189 if !layers.contains(l) {
190 layers.push(l.clone());
191 }
192 }
193 layers
194 }
195
196 pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
198 model.min_elements.unwrap_or(self.min_elements)
199 }
200
201 pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
206 model.layer_bits.clone()
207 }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx

  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        assert!(config.validate().is_ok());
        // Model 1 inherits the global flag; model 2 overrides it.
        assert!(config.get_per_channel(&config.models[0]));
        assert!(!config.get_per_channel(&config.models[1]));
    }

    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_model_overrides_fall_back_to_globals() {
        let yaml = r#"
excluded_layers: [a, b]
min_elements: 5
models:
  - input: m1.onnx
    output: o1.onnx
  - input: m2.onnx
    output: o2.onnx
    bits: 4
    min_elements: 10
    excluded_layers: [b, c]
    layer_bits:
      conv1: 8
"#;

        let config = Config::from_yaml(yaml).unwrap();
        config.validate().unwrap();
        let (plain, custom) = (&config.models[0], &config.models[1]);

        // Omitted global `bits` falls back to the serde default of 8.
        assert_eq!(config.get_bits(plain), 8);
        assert_eq!(config.get_bits(custom), 4);
        assert_eq!(config.get_min_elements(plain), 5);
        assert_eq!(config.get_min_elements(custom), 10);
        assert_eq!(config.get_excluded_layers(plain), vec!["a", "b"]);
        assert_eq!(config.get_excluded_layers(custom), vec!["a", "b", "c"]);
        assert!(config.get_layer_bits(plain).is_empty());
        assert_eq!(config.get_layer_bits(custom).get("conv1"), Some(&8));
    }

    #[test]
    fn test_validate_rejects_invalid_bits() {
        let global = Config::from_yaml("bits: 5").unwrap();
        assert!(global.validate().is_err());

        let per_model = Config::from_yaml(
            r#"
models:
  - input: m.onnx
    output: o.onnx
    bits: 16
"#,
        )
        .unwrap();
        assert!(per_model.validate().is_err());

        let per_layer = Config::from_yaml(
            r#"
models:
  - input: m.onnx
    output: o.onnx
    layer_bits:
      conv1: 3
"#,
        )
        .unwrap();
        assert!(per_layer.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_empty_paths() {
        let empty_output = Config::from_yaml(
            r#"
models:
  - input: m.onnx
    output: ""
"#,
        )
        .unwrap();
        assert!(empty_output.validate().is_err());

        let empty_batch_dir = Config::from_yaml(
            r#"
batch:
  input_dir: models
  output_dir: ""
"#,
        )
        .unwrap();
        assert!(empty_batch_dir.validate().is_err());
    }

    #[test]
    fn test_from_file_rejects_unusable_paths() {
        assert!(Config::from_file("config_without_extension").is_err());
        assert!(Config::from_file("config.json").is_err());
    }
}