// quantize_rs/config.rs

//! YAML and TOML configuration file support.
//!
//! A configuration file can specify global quantization settings
//! (`bits`, `per_channel`), per-model overrides, and batch processing
//! parameters.

use crate::errors::{QuantizeError, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;

11/// Top-level quantization configuration.
12///
13/// Can be loaded from a YAML or TOML file with [`Config::from_file`].
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Config {
16    /// Default bit width (4 or 8). Defaults to 8.
17    #[serde(default = "default_bits")]
18    pub bits: u8,
19
20    /// Default per-channel setting. Defaults to `false`.
21    #[serde(default)]
22    pub per_channel: bool,
23
24    /// Layer names to exclude from quantization globally.
25    #[serde(default)]
26    pub excluded_layers: Vec<String>,
27
28    /// Minimum number of elements a tensor must have to be quantized.
29    /// Tensors smaller than this are kept in FP32. Defaults to 0 (no minimum).
30    #[serde(default)]
31    pub min_elements: usize,
32
33    /// Per-model configuration overrides.
34    #[serde(default)]
35    pub models: Vec<ModelConfig>,
36
37    /// Batch processing configuration.
38    #[serde(default)]
39    pub batch: Option<BatchConfig>,
40}
41
42/// Per-model quantization overrides.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct ModelConfig {
45    /// Path to the input ONNX model.
46    pub input: String,
47
48    /// Path for the quantized output model.
49    pub output: String,
50
51    /// Override bit width for this model.
52    #[serde(default)]
53    pub bits: Option<u8>,
54
55    /// Override per-channel setting for this model.
56    #[serde(default)]
57    pub per_channel: Option<bool>,
58
59    /// Skip this model if the output file already exists.
60    #[serde(default)]
61    pub skip_existing: bool,
62
63    /// Layer names to exclude from quantization for this model.
64    /// Merged with (but does not replace) the global `excluded_layers`.
65    #[serde(default)]
66    pub excluded_layers: Vec<String>,
67
68    /// Per-layer bit-width overrides for this model.
69    /// Key = initializer name, value = 4 or 8.
70    #[serde(default)]
71    pub layer_bits: std::collections::HashMap<String, u8>,
72
73    /// Override the global `min_elements` threshold for this model.
74    #[serde(default)]
75    pub min_elements: Option<usize>,
76}
77
78/// Batch processing configuration for quantizing multiple models.
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct BatchConfig {
81    /// Glob pattern or directory for input models.
82    pub input_dir: String,
83
84    /// Output directory for quantized models.
85    pub output_dir: String,
86
87    /// Skip models whose output already exists.
88    #[serde(default)]
89    pub skip_existing: bool,
90
91    /// Continue processing remaining models after a failure.
92    #[serde(default)]
93    pub continue_on_error: bool,
94}
95
/// Serde default for the global `bits` field: 8-bit quantization.
fn default_bits() -> u8 {
    8
}

100impl Config {
101    /// Load a config from a YAML or TOML file (auto-detected by extension).
102    ///
103    /// # Errors
104    ///
105    /// Returns [`QuantizeError::Config`] on I/O, parse, or unsupported format errors.
106    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
107        let path = path.as_ref();
108        let extension =
109            path.extension()
110                .and_then(|s| s.to_str())
111                .ok_or_else(|| QuantizeError::Config {
112                    reason: "Config file has no extension".into(),
113                })?;
114
115        let content = std::fs::read_to_string(path).map_err(|e| QuantizeError::Config {
116            reason: format!("Failed to read config file '{}': {e}", path.display()),
117        })?;
118
119        match extension {
120            "yaml" | "yml" => Self::from_yaml(&content),
121            "toml" => Self::from_toml(&content),
122            _ => Err(QuantizeError::Config {
123                reason: format!("Unsupported config format: {}", extension),
124            }),
125        }
126    }
127
128    /// Parse configuration from a YAML string.
129    pub fn from_yaml(content: &str) -> Result<Self> {
130        serde_yaml::from_str(content).map_err(|e| QuantizeError::Config {
131            reason: format!("Failed to parse YAML config: {e}"),
132        })
133    }
134
135    /// Parse configuration from a TOML string.
136    pub fn from_toml(content: &str) -> Result<Self> {
137        toml::from_str(content).map_err(|e| QuantizeError::Config {
138            reason: format!("Failed to parse TOML config: {e}"),
139        })
140    }
141
142    /// Validate the configuration (bits values, non-empty paths).
143    ///
144    /// # Errors
145    ///
146    /// Returns [`QuantizeError::Config`] if any field is invalid.
147    pub fn validate(&self) -> Result<()> {
148        if self.bits != 4 && self.bits != 8 {
149            return Err(QuantizeError::Config {
150                reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits),
151            });
152        }
153
154        for (idx, model) in self.models.iter().enumerate() {
155            if model.input.is_empty() {
156                return Err(QuantizeError::Config {
157                    reason: format!("Model {}: input path is empty", idx),
158                });
159            }
160            if model.output.is_empty() {
161                return Err(QuantizeError::Config {
162                    reason: format!("Model {}: output path is empty", idx),
163                });
164            }
165            if let Some(bits) = model.bits {
166                if bits != 4 && bits != 8 {
167                    return Err(QuantizeError::Config {
168                        reason: format!("Model {}: invalid bits value: {}", idx, bits),
169                    });
170                }
171            }
172            for (layer, &bits) in &model.layer_bits {
173                if layer.is_empty() {
174                    return Err(QuantizeError::Config {
175                        reason: format!("Model {}: layer_bits contains an empty layer name", idx),
176                    });
177                }
178                if bits != 4 && bits != 8 {
179                    return Err(QuantizeError::Config {
180                        reason: format!(
181                            "Model {}: invalid bits {} for layer '{}'",
182                            idx, bits, layer
183                        ),
184                    });
185                }
186            }
187        }
188
189        if let Some(batch) = &self.batch {
190            if batch.input_dir.is_empty() {
191                return Err(QuantizeError::Config {
192                    reason: "Batch input_dir is empty".into(),
193                });
194            }
195            if batch.output_dir.is_empty() {
196                return Err(QuantizeError::Config {
197                    reason: "Batch output_dir is empty".into(),
198                });
199            }
200        }
201
202        Ok(())
203    }
204
205    /// Effective bit width for a model (model override or global default).
206    pub fn get_bits(&self, model: &ModelConfig) -> u8 {
207        model.bits.unwrap_or(self.bits)
208    }
209
210    /// Effective per-channel setting for a model (model override or global default).
211    pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
212        model.per_channel.unwrap_or(self.per_channel)
213    }
214
215    /// Effective excluded-layers list: global list merged with model-level list.
216    pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
217        let mut layers = self.excluded_layers.clone();
218        for l in &model.excluded_layers {
219            if !layers.contains(l) {
220                layers.push(l.clone());
221            }
222        }
223        layers
224    }
225
226    /// Effective min-elements threshold for a model.
227    pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
228        model.min_elements.unwrap_or(self.min_elements)
229    }
230
231    /// Effective per-layer bit-width overrides for a model.
232    ///
233    /// Layer names are model-specific so there is no global map to merge;
234    /// this simply returns the model's own `layer_bits` map.
235    pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
236        model.layer_bits.clone()
237    }
238}
239
#[cfg(test)]
mod tests {
    use super::*;

    /// A full YAML document (globals, two models, batch section) parses.
    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx

  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
    }

    /// An empty key in `layer_bits` parses but is rejected by `validate`.
    #[test]
    fn test_empty_layer_bits_key_rejected() {
        let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: out.onnx
    layer_bits:
      "": 4
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("empty layer name"));
    }

    /// The TOML equivalent of `test_yaml_config` parses to the same shape.
    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
    }

    /// A global bit width other than 4 or 8 is rejected by `validate`.
    #[test]
    fn test_invalid_global_bits_rejected() {
        let config = Config::from_yaml("bits: 3").unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("Invalid bits value"));
    }

    /// Model-level exclusions are appended to the global ones without duplicates.
    #[test]
    fn test_excluded_layers_merged_without_duplicates() {
        let yaml = r#"
excluded_layers: [embed, head]
models:
  - input: model.onnx
    output: out.onnx
    excluded_layers: [head, norm]
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let merged = config.get_excluded_layers(&config.models[0]);
        assert_eq!(merged, ["embed", "head", "norm"]);
    }
}