// quantize_rs/config.rs
//! YAML and TOML configuration file support.
//!
//! A configuration file can specify global quantization settings
//! (`bits`, `per_channel`), per-model overrides, and batch processing
//! parameters.

use crate::errors::{QuantizeError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;

/// Top-level quantization configuration.
///
/// Can be loaded from a YAML or TOML file with [`Config::from_file`];
/// call [`Config::validate`] afterwards to check field values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Default bit width (4 or 8). Defaults to 8.
    /// Validity is enforced by [`Config::validate`], not by deserialization.
    #[serde(default = "default_bits")]
    pub bits: u8,

    /// Default per-channel setting. Defaults to `false`.
    #[serde(default)]
    pub per_channel: bool,

    /// Layer names to exclude from quantization globally.
    /// Merged with each model's own `excluded_layers` list.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Minimum number of elements a tensor must have to be quantized.
    /// Tensors smaller than this are kept in FP32. Defaults to 0 (no minimum).
    #[serde(default)]
    pub min_elements: usize,

    /// Per-model configuration overrides.
    #[serde(default)]
    pub models: Vec<ModelConfig>,

    /// Batch processing configuration. `None` when the config file has no
    /// `batch` section.
    #[serde(default)]
    pub batch: Option<BatchConfig>,
}

/// Per-model quantization overrides.
///
/// Any `Option` field left as `None` falls back to the corresponding
/// global setting on [`Config`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Path to the input ONNX model.
    pub input: String,

    /// Path for the quantized output model.
    pub output: String,

    /// Override bit width for this model.
    #[serde(default)]
    pub bits: Option<u8>,

    /// Override per-channel setting for this model.
    #[serde(default)]
    pub per_channel: Option<bool>,

    /// Skip this model if the output file already exists.
    #[serde(default)]
    pub skip_existing: bool,

    /// Layer names to exclude from quantization for this model.
    /// Merged with (but does not replace) the global `excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,

    /// Per-layer bit-width overrides for this model.
    /// Key = initializer name, value = 4 or 8 (checked by [`Config::validate`]).
    #[serde(default)]
    pub layer_bits: std::collections::HashMap<String, u8>,

    /// Override the global `min_elements` threshold for this model.
    #[serde(default)]
    pub min_elements: Option<usize>,
}

/// Batch processing configuration for quantizing multiple models.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Glob pattern or directory for input models.
    pub input_dir: String,

    /// Output directory for quantized models.
    pub output_dir: String,

    /// Skip models whose output already exists. Defaults to `false`.
    #[serde(default)]
    pub skip_existing: bool,

    /// Continue processing remaining models after a failure.
    /// Defaults to `false` (stop at the first error).
    #[serde(default)]
    pub continue_on_error: bool,
}

/// Serde default for [`Config::bits`]: 8-bit quantization.
fn default_bits() -> u8 { 8 }

100impl Config {
101    /// Load a config from a YAML or TOML file (auto-detected by extension).
102    ///
103    /// # Errors
104    ///
105    /// Returns [`QuantizeError::Config`] on I/O, parse, or unsupported format errors.
106    pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
107        let path = path.as_ref();
108        let extension = path.extension()
109            .and_then(|s| s.to_str())
110            .ok_or_else(|| QuantizeError::Config { reason: "Config file has no extension".into() })?;
111
112        let content = std::fs::read_to_string(path)
113            .map_err(|e| QuantizeError::Config { reason: format!("Failed to read config file '{}': {e}", path.display()) })?;
114
115        match extension {
116            "yaml" | "yml" => Self::from_yaml(&content),
117            "toml" => Self::from_toml(&content),
118            _ => Err(QuantizeError::Config { reason: format!("Unsupported config format: {}", extension) }),
119        }
120    }
121
122    /// Parse configuration from a YAML string.
123    pub fn from_yaml(content: &str) -> Result<Self> {
124        serde_yaml::from_str(content)
125            .map_err(|e| QuantizeError::Config { reason: format!("Failed to parse YAML config: {e}") })
126    }
127
128    /// Parse configuration from a TOML string.
129    pub fn from_toml(content: &str) -> Result<Self> {
130        toml::from_str(content)
131            .map_err(|e| QuantizeError::Config { reason: format!("Failed to parse TOML config: {e}") })
132    }
133
134    /// Validate the configuration (bits values, non-empty paths).
135    ///
136    /// # Errors
137    ///
138    /// Returns [`QuantizeError::Config`] if any field is invalid.
139    pub fn validate(&self) -> Result<()> {
140        if self.bits != 4 && self.bits != 8 {
141            return Err(QuantizeError::Config { reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits) });
142        }
143
144        for (idx, model) in self.models.iter().enumerate() {
145            if model.input.is_empty() {
146                return Err(QuantizeError::Config { reason: format!("Model {}: input path is empty", idx) });
147            }
148            if model.output.is_empty() {
149                return Err(QuantizeError::Config { reason: format!("Model {}: output path is empty", idx) });
150            }
151            if let Some(bits) = model.bits {
152                if bits != 4 && bits != 8 {
153                    return Err(QuantizeError::Config { reason: format!("Model {}: invalid bits value: {}", idx, bits) });
154                }
155            }
156            for (layer, &bits) in &model.layer_bits {
157                if bits != 4 && bits != 8 {
158                    return Err(QuantizeError::Config { reason: format!("Model {}: invalid bits {} for layer '{}'", idx, bits, layer) });
159                }
160            }
161        }
162
163        if let Some(batch) = &self.batch {
164            if batch.input_dir.is_empty() {
165                return Err(QuantizeError::Config { reason: "Batch input_dir is empty".into() });
166            }
167            if batch.output_dir.is_empty() {
168                return Err(QuantizeError::Config { reason: "Batch output_dir is empty".into() });
169            }
170        }
171
172        Ok(())
173    }
174
175    /// Effective bit width for a model (model override or global default).
176    pub fn get_bits(&self, model: &ModelConfig) -> u8 {
177        model.bits.unwrap_or(self.bits)
178    }
179
180    /// Effective per-channel setting for a model (model override or global default).
181    pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
182        model.per_channel.unwrap_or(self.per_channel)
183    }
184
185    /// Effective excluded-layers list: global list merged with model-level list.
186    pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
187        let mut layers = self.excluded_layers.clone();
188        for l in &model.excluded_layers {
189            if !layers.contains(l) {
190                layers.push(l.clone());
191            }
192        }
193        layers
194    }
195
196    /// Effective min-elements threshold for a model.
197    pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
198        model.min_elements.unwrap_or(self.min_elements)
199    }
200
201    /// Effective per-layer bit-width overrides for a model.
202    ///
203    /// Layer names are model-specific so there is no global map to merge;
204    /// this simply returns the model's own `layer_bits` map.
205    pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
206        model.layer_bits.clone()
207    }
208}
209
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_yaml_config() {
        let src = r#"
bits: 8
per_channel: true

models:
  - input: model1.onnx
    output: model1_int8.onnx
  
  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false

batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;

        let parsed = Config::from_yaml(src).expect("YAML fixture should parse");

        assert_eq!(parsed.bits, 8);
        assert!(parsed.per_channel);
        assert_eq!(parsed.models.len(), 2);
        assert!(parsed.batch.is_some());
    }

    #[test]
    fn test_toml_config() {
        let src = r#"
bits = 8
per_channel = true

[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"

[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false

[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;

        let parsed = Config::from_toml(src).expect("TOML fixture should parse");

        assert_eq!(parsed.bits, 8);
        assert!(parsed.per_channel);
        assert_eq!(parsed.models.len(), 2);
        assert!(parsed.batch.is_some());
    }
}