use crate::errors::{QuantizeError, Result};
use serde::{Deserialize, Serialize};
use std::path::Path;
/// Top-level configuration, deserializable from YAML or TOML text
/// (see `Config::from_file`, `Config::from_yaml`, `Config::from_toml`).
///
/// The scalar fields act as global defaults: the `get_*` accessors on
/// `Config` return a model's own value when it sets one and fall back to
/// these fields otherwise.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Config {
    /// Global bit width. Falls back to `default_bits()` (8) when omitted;
    /// `Config::validate` accepts only 4 or 8.
    #[serde(default = "default_bits")]
    pub bits: u8,
    /// Global per-channel flag (`false` when omitted).
    /// NOTE(review): presumably selects per-channel rather than per-tensor
    /// quantization — confirm against the code that consumes it.
    #[serde(default)]
    pub per_channel: bool,
    /// Layer names excluded for every model (empty when omitted);
    /// `get_excluded_layers` merges this list with the model's own list.
    #[serde(default)]
    pub excluded_layers: Vec<String>,
    /// Global minimum element count (0 when omitted).
    /// NOTE(review): presumably tensors with fewer elements are skipped —
    /// confirm against the consumer.
    #[serde(default)]
    pub min_elements: usize,
    /// Global native-INT4 flag (`false` when omitted).
    /// NOTE(review): exact meaning is not visible in this file — confirm.
    #[serde(default)]
    pub native_int4: bool,
    /// Global symmetric flag (`false` when omitted).
    /// NOTE(review): presumably symmetric vs. asymmetric quantization — confirm.
    #[serde(default)]
    pub symmetric: bool,
    /// Explicitly listed models (empty when omitted).
    #[serde(default)]
    pub models: Vec<ModelConfig>,
    /// Optional batch section (`None` when omitted).
    #[serde(default)]
    pub batch: Option<BatchConfig>,
}
/// One entry of `Config::models`.
///
/// `input` and `output` are required. Every `Option` field is an override:
/// when it is `None`, the matching `Config::get_*` accessor falls back to
/// the global value on `Config`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelConfig {
    /// Input path; `Config::validate` rejects an empty string.
    pub input: String,
    /// Output path; `Config::validate` rejects an empty string.
    pub output: String,
    /// Per-model bit width override; when set, `Config::validate` accepts
    /// only 4 or 8.
    #[serde(default)]
    pub bits: Option<u8>,
    /// Per-model override of `Config::per_channel`.
    #[serde(default)]
    pub per_channel: Option<bool>,
    /// Skip flag (`false` when omitted).
    /// NOTE(review): presumably skips the model when its output already
    /// exists — not used in this file, confirm against the caller.
    #[serde(default)]
    pub skip_existing: bool,
    /// Additional excluded layer names for this model; appended to the
    /// global list by `Config::get_excluded_layers`.
    #[serde(default)]
    pub excluded_layers: Vec<String>,
    /// Bit width per layer name. `Config::validate` rejects empty layer
    /// names and any value other than 4 or 8.
    #[serde(default)]
    pub layer_bits: std::collections::HashMap<String, u8>,
    /// Per-model override of `Config::min_elements`.
    #[serde(default)]
    pub min_elements: Option<usize>,
    /// Per-model override of `Config::native_int4`.
    #[serde(default)]
    pub native_int4: Option<bool>,
    /// Per-model override of `Config::symmetric`.
    #[serde(default)]
    pub symmetric: Option<bool>,
}
/// Optional batch section of `Config` (`Config::batch`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchConfig {
    /// Batch input location; `Config::validate` rejects an empty string.
    /// NOTE(review): the unit tests in this file pass a glob pattern
    /// (`models/*.onnx`) here, so this may be a pattern rather than a plain
    /// directory — confirm against the consumer.
    pub input_dir: String,
    /// Batch output location; `Config::validate` rejects an empty string.
    pub output_dir: String,
    /// Skip flag (`false` when omitted).
    /// NOTE(review): presumably skips inputs whose output already exists —
    /// not used in this file, confirm against the caller.
    #[serde(default)]
    pub skip_existing: bool,
    /// Continue flag (`false` when omitted).
    /// NOTE(review): presumably keeps processing remaining inputs after a
    /// failure — not used in this file, confirm against the caller.
    #[serde(default)]
    pub continue_on_error: bool,
}
/// Serde fallback for the top-level `bits` field: the value used when a
/// configuration document does not specify `bits`.
fn default_bits() -> u8 {
    const DEFAULT_BITS: u8 = 8;
    DEFAULT_BITS
}
impl Config {
pub fn from_file(path: impl AsRef<Path>) -> Result<Self> {
let path = path.as_ref();
let extension =
path.extension()
.and_then(|s| s.to_str())
.ok_or_else(|| QuantizeError::Config {
reason: "Config file has no extension".into(),
})?;
let content = std::fs::read_to_string(path).map_err(|e| QuantizeError::Config {
reason: format!("Failed to read config file '{}': {e}", path.display()),
})?;
match extension {
"yaml" | "yml" => Self::from_yaml(&content),
"toml" => Self::from_toml(&content),
_ => Err(QuantizeError::Config {
reason: format!("Unsupported config format: {}", extension),
}),
}
}
pub fn from_yaml(content: &str) -> Result<Self> {
serde_yaml::from_str(content).map_err(|e| QuantizeError::Config {
reason: format!("Failed to parse YAML config: {e}"),
})
}
pub fn from_toml(content: &str) -> Result<Self> {
toml::from_str(content).map_err(|e| QuantizeError::Config {
reason: format!("Failed to parse TOML config: {e}"),
})
}
pub fn validate(&self) -> Result<()> {
if self.bits != 4 && self.bits != 8 {
return Err(QuantizeError::Config {
reason: format!("Invalid bits value: {}. Must be 4 or 8", self.bits),
});
}
for (idx, model) in self.models.iter().enumerate() {
if model.input.is_empty() {
return Err(QuantizeError::Config {
reason: format!("Model {}: input path is empty", idx),
});
}
if model.output.is_empty() {
return Err(QuantizeError::Config {
reason: format!("Model {}: output path is empty", idx),
});
}
if let Some(bits) = model.bits {
if bits != 4 && bits != 8 {
return Err(QuantizeError::Config {
reason: format!("Model {}: invalid bits value: {}", idx, bits),
});
}
}
for (layer, &bits) in &model.layer_bits {
if layer.is_empty() {
return Err(QuantizeError::Config {
reason: format!("Model {}: layer_bits contains an empty layer name", idx),
});
}
if bits != 4 && bits != 8 {
return Err(QuantizeError::Config {
reason: format!(
"Model {}: invalid bits {} for layer '{}'",
idx, bits, layer
),
});
}
}
}
if let Some(batch) = &self.batch {
if batch.input_dir.is_empty() {
return Err(QuantizeError::Config {
reason: "Batch input_dir is empty".into(),
});
}
if batch.output_dir.is_empty() {
return Err(QuantizeError::Config {
reason: "Batch output_dir is empty".into(),
});
}
}
Ok(())
}
pub fn get_bits(&self, model: &ModelConfig) -> u8 {
model.bits.unwrap_or(self.bits)
}
pub fn get_per_channel(&self, model: &ModelConfig) -> bool {
model.per_channel.unwrap_or(self.per_channel)
}
pub fn get_excluded_layers(&self, model: &ModelConfig) -> Vec<String> {
let mut layers = self.excluded_layers.clone();
for l in &model.excluded_layers {
if !layers.contains(l) {
layers.push(l.clone());
}
}
layers
}
pub fn get_min_elements(&self, model: &ModelConfig) -> usize {
model.min_elements.unwrap_or(self.min_elements)
}
pub fn get_native_int4(&self, model: &ModelConfig) -> bool {
model.native_int4.unwrap_or(self.native_int4)
}
pub fn get_symmetric(&self, model: &ModelConfig) -> bool {
model.symmetric.unwrap_or(self.symmetric)
}
pub fn get_layer_bits(&self, model: &ModelConfig) -> std::collections::HashMap<String, u8> {
model.layer_bits.clone()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A YAML document with global settings, two models and a batch section
    /// parses into the expected structure.
    ///
    /// YAML nesting is indentation-sensitive: the `output` / `per_channel`
    /// keys must be indented under their `- input:` list item and the batch
    /// keys under `batch:`, otherwise they become top-level keys and the
    /// document no longer parses.
    #[test]
    fn test_yaml_config() {
        let yaml = r#"
bits: 8
per_channel: true
models:
  - input: model1.onnx
    output: model1_int8.onnx
  - input: model2.onnx
    output: model2_int8.onnx
    per_channel: false
batch:
  input_dir: "models/*.onnx"
  output_dir: quantized/
  skip_existing: true
"#;
        let config = Config::from_yaml(yaml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        // Pin the nested values so a flattened document cannot pass unnoticed.
        assert_eq!(config.models[0].output, "model1_int8.onnx");
        assert_eq!(config.models[1].per_channel, Some(false));
        let batch = config.batch.as_ref().unwrap();
        assert_eq!(batch.output_dir, "quantized/");
        assert!(batch.skip_existing);
    }

    /// `validate` rejects a `layer_bits` map that contains an empty layer name.
    #[test]
    fn test_empty_layer_bits_key_rejected() {
        let yaml = r#"
bits: 8
models:
  - input: model.onnx
    output: out.onnx
    layer_bits:
      "": 4
"#;
        let config = Config::from_yaml(yaml).unwrap();
        let err = config.validate().unwrap_err();
        assert!(matches!(err, crate::errors::QuantizeError::Config { .. }));
        assert!(err.to_string().contains("empty layer name"));
    }

    /// The TOML equivalent of `test_yaml_config` parses into the same structure.
    #[test]
    fn test_toml_config() {
        let toml = r#"
bits = 8
per_channel = true
[[models]]
input = "model1.onnx"
output = "model1_int8.onnx"
[[models]]
input = "model2.onnx"
output = "model2_int8.onnx"
per_channel = false
[batch]
input_dir = "models/*.onnx"
output_dir = "quantized/"
skip_existing = true
"#;
        let config = Config::from_toml(toml).unwrap();
        assert_eq!(config.bits, 8);
        assert!(config.per_channel);
        assert_eq!(config.models.len(), 2);
        assert!(config.batch.is_some());
        // Same nested-value checks as the YAML test, for parity.
        assert_eq!(config.models[0].output, "model1_int8.onnx");
        assert_eq!(config.models[1].per_channel, Some(false));
        let batch = config.batch.as_ref().unwrap();
        assert_eq!(batch.output_dir, "quantized/");
        assert!(batch.skip_existing);
    }
}