quantize_rs/config.rs

//! YAML and TOML configuration file support.
//!
//! A configuration file can specify global quantization settings
//! (`bits`, `per_channel`), per-model overrides, and batch processing
//! parameters.
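//!
//! # Example
//!
//! A minimal YAML config (file names here are illustrative):
//!
//! ```yaml
//! bits: 8
//! per_channel: true
//!
//! models:
//!   - input: model1.onnx
//!     output: model1_int8.onnx
//! ```
//!
//! See [`Config::from_file`] for loading such a file and [`Config::validate`]
//! for checking it.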

use crate::errors::{QuantizeError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;

/// Top-level quantization configuration.
///
/// Can be loaded from a YAML or TOML file with [`Config::from_file`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Default bit width (4 or 8). Defaults to 8.
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Default per-channel setting. Defaults to `false`.
    #[serde(default)]
    pub per_channel: bool,

    /// Layer names to exclude from quantization globally.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Minimum number of elements a tensor must have to be quantized.
    /// Tensors smaller than this are kept in FP32. Defaults to 0 (no minimum).
    #[serde(default)]
    pub min_elements: usize,

    /// Store INT4 weights as native ONNX `DataType::Int4` (opset 21).
    /// Only affects models quantized with `bits=4` or per-layer INT4 overrides.
    /// Defaults to `false` (widen INT4 to INT8, compatible with opset 10+).
    #[serde(default)]
    pub native_int4: bool,

    /// Use symmetric quantization (zero_point == 0). Required by most
    /// ONNX Runtime / TensorRT INT8 matmul kernels for per-channel weight
    /// quantization. Defaults to `false` (asymmetric).
    #[serde(default)]
    pub symmetric: bool,

    /// Per-model configuration overrides.
    #[serde(default)]
    pub models: Vec<ModelConfig>,

    /// Batch processing configuration.
    #[serde(default)]
    pub batch: Option<BatchConfig>,
}

/// Per-model quantization overrides.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Path to the input ONNX model.
    pub input: String,

    /// Path for the quantized output model.
    pub output: String,

    /// Override bit width for this model.
    #[serde(default)]
    pub bits: Option<u8>,

    /// Override per-channel setting for this model.
    #[serde(default)]
    pub per_channel: Option<bool>,

    /// Skip this model if the output file already exists.
    #[serde(default)]
    pub skip_existing: bool,

    /// Layer names to exclude from quantization for this model.
    /// Merged with (but does not replace) the global `excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Per-layer bit-width overrides for this model.
    /// Key = initializer name, value = 4 or 8.
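    ///
    /// Illustrative YAML snippet (the layer names are hypothetical):
    ///
    /// ```yaml
    /// layer_bits:
    ///   "encoder.layers.0.weight": 4
    ///   "classifier.weight": 8
    /// ```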
    #[serde(default)]
    pub layer_bits: std::collections::HashMap<String, u8>,

    /// Override the global `min_elements` threshold for this model.
    #[serde(default)]
    pub min_elements: Option<usize>,

    /// Override the global `native_int4` flag for this model.
    #[serde(default)]
    pub native_int4: Option<bool>,

    /// Override the global `symmetric` flag for this model.
    #[serde(default)]
    pub symmetric: Option<bool>,
}

/// Batch processing configuration for quantizing multiple models.
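///
/// Illustrative YAML snippet (directory and glob values are arbitrary):
///
/// ```yaml
/// batch:
///   input_dir: "models/*.onnx"
///   output_dir: quantized/
///   skip_existing: true
/// ```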
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Glob pattern or directory for input models.
    pub input_dir: String,

    /// Output directory for quantized models.
    pub output_dir: String,

    /// Skip models whose output already exists.
    #[serde(default)]
    pub skip_existing: bool,

    /// Continue processing remaining models after a failure.
    #[serde(default)]
    pub continue_on_error: bool,
}

fn default_bits() -> u8 {
    8
}

impl Config {
    /// Load a config from a YAML or TOML file (auto-detected by extension).
    ///
    /// # Errors
    ///
    /// Returns [`QuantizeError::Config`] on I/O, parse, or unsupported format errors.
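    ///
    /// # Examples
    ///
    /// A sketch of typical usage; the file name is arbitrary and the
    /// `quantize_rs::config` / `quantize_rs::errors` paths are assumptions
    /// about the crate layout:
    ///
    /// ```no_run
    /// use quantize_rs::config::Config;
    ///
    /// let config = Config::from_file("quantize.yaml")?;
    /// config.validate()?;
    /// # Ok::<(), quantize_rs::errors::QuantizeError>(())
    /// ```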
    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
        let path = path.as_ref();
        let extension =
            path.extension()
                .and_then(|s| s.to_str())
                .ok_or_else(|| QuantizeError::Config {
                    reason: "Config file has no extension".into(),
                })?;

        let content = std::fs::read_to_string(path).map_err(|e| QuantizeError::Config {
            reason: format!("Failed to read config file '{}': {e}", path.display()),
        })?;

        match extension {
            "yaml" | "yml" => Self::from_yaml(&content),
            "toml" => Self::from_toml(&content),
            _ => Err(QuantizeError::Config {
                reason: format!("Unsupported config format: {}", extension),
            }),
        }
    }

    /// Parse configuration from a YAML string.
    pub fn from_yaml(content: &str) -> Result<Self> {
        serde_yaml::from_str(content).map_err(|e| QuantizeError::Config {
            reason: format!("Failed to parse YAML config: {e}"),
        })
    }

    /// Parse configuration from a TOML string.
    pub fn from_toml(content: &str) -> Result<Self> {
        toml::from_str(content).map_err(|e| QuantizeError::Config {
            reason: format!("Failed to parse TOML config: {e}"),
        })
    }

    /// Validate the configuration (bits values, non-empty paths).
    ///
    /// # Errors
    ///
    /// Returns [`QuantizeError::Config`] if any field is invalid.
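    ///
    /// # Examples
    ///
    /// An unsupported bit width is rejected (assuming the crate is exposed
    /// as `quantize_rs` with a public `config` module):
    ///
    /// ```
    /// use quantize_rs::config::Config;
    ///
    /// let config = Config::from_yaml("bits: 5").unwrap();
    /// assert!(config.validate().is_err());
    /// ```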
    pub fn validate(&self) -> Result<()> {
        if self.bits != 4 && self.bits != 8 {
            return Err(QuantizeError::Config {
                reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits),
            });
        }

        for (idx, model) in self.models.iter().enumerate() {
            if model.input.is_empty() {
                return Err(QuantizeError::Config {
                    reason: format!("Model {}: input path is empty", idx),
                });
            }
            if model.output.is_empty() {
                return Err(QuantizeError::Config {
                    reason: format!("Model {}: output path is empty", idx),
                });
            }
            if let Some(bits) = model.bits {
                if bits != 4 && bits != 8 {
                    return Err(QuantizeError::Config {
                        reason: format!("Model {}: invalid bits value: {}", idx, bits),
                    });
                }
            }
            for (layer, &bits) in &model.layer_bits {
                if layer.is_empty() {
                    return Err(QuantizeError::Config {
                        reason: format!("Model {}: layer_bits contains an empty layer name", idx),
                    });
                }
                if bits != 4 && bits != 8 {
                    return Err(QuantizeError::Config {
                        reason: format!(
                            "Model {}: invalid bits {} for layer '{}'",
                            idx, bits, layer
                        ),
                    });
                }
            }
        }

        if let Some(batch) = &self.batch {
            if batch.input_dir.is_empty() {
                return Err(QuantizeError::Config {
                    reason: "Batch input_dir is empty".into(),
                });
            }
            if batch.output_dir.is_empty() {
                return Err(QuantizeError::Config {
                    reason: "Batch output_dir is empty".into(),
                });
            }
        }

        Ok(())
    }

    /// Effective bit width for a model (model override or global default).
    pub fn get_bits(&self, model: &ModelConfig) -> u8 {
        model.bits.unwrap_or(self.bits)
    }

    /// Effective per-channel setting for a model (model override or global default).
    pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
        model.per_channel.unwrap_or(self.per_channel)
    }

    /// Effective excluded-layers list: global list merged with model-level list.
    pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
        let mut layers = self.excluded_layers.clone();
        for l in &model.excluded_layers {
            if !layers.contains(l) {
                layers.push(l.clone());
            }
        }
        layers
    }

    /// Effective min-elements threshold for a model.
    pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
        model.min_elements.unwrap_or(self.min_elements)
    }

    /// Effective native-INT4 flag for a model (override or global default).
    pub fn get_native_int4(&self, model: &ModelConfig) -> bool {
        model.native_int4.unwrap_or(self.native_int4)
    }

    /// Effective symmetric flag for a model (override or global default).
    pub fn get_symmetric(&self, model: &ModelConfig) -> bool {
        model.symmetric.unwrap_or(self.symmetric)
    }

    /// Effective per-layer bit-width overrides for a model.
    ///
    /// Layer names are model-specific so there is no global map to merge;
    /// this simply returns the model's own `layer_bits` map.
    pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
        model.layer_bits.clone()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx

  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
    }

    #[test]
    fn test_empty_layer_bits_key_rejected() {
        let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: out.onnx
    layer_bits:
      "": 4
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("empty layer name"));
    }

    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
    }
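
    // Illustrative extra test: the global and per-model `excluded_layers`
    // lists are merged without duplicates by `get_excluded_layers`.
    #[test]
    fn test_excluded_layers_merge() {
        let yaml = r#"
bits: 8
excluded_layers: ["embeddings", "lm_head"]
models:
  - input: model.onnx
    output: out.onnx
    excluded_layers: ["lm_head", "final_norm"]
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let layers = config.get_excluded_layers(&config.models[0]);
        assert_eq!(layers, vec!["embeddings", "lm_head", "final_norm"]);
    }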
}