use pyo3::exceptions::{PyRuntimeError, PyValueError};
use pyo3::prelude::*;
#[cfg(feature = "calibration")]
use crate::calibration::{methods::CalibrationMethod, ActivationEstimator, CalibrationDataset};
use crate::onnx_utils::graph_builder::QdqWeightInput;
use crate::onnx_utils::{OnnxModel, SaveOptions};
use crate::quantization::{QuantConfig, Quantizer};
/// Read-only summary of a loaded ONNX model, exposed to Python.
///
/// Every field is surfaced as a Python attribute via `#[pyo3(get)]`;
/// instances are produced by [`model_info`].
#[pyclass]
#[derive(Clone)]
struct ModelInfo {
    // Model/graph name as stored in the file.
    #[pyo3(get)]
    name: String,
    // Model version number from the ONNX metadata.
    #[pyo3(get)]
    version: i64,
    // Total node count in the graph.
    #[pyo3(get)]
    num_nodes: usize,
    // Names of the graph's inputs.
    #[pyo3(get)]
    inputs: Vec<String>,
    // Names of the graph's outputs.
    #[pyo3(get)]
    outputs: Vec<String>,
}
/// Quantize an ONNX model's weights and write the result to `output_path`.
///
/// `bits` is the default width (4 or 8); `layer_bits` overrides it per layer,
/// and `excluded_layers` / `min_elements` filter which tensors are touched.
///
/// Raises `ValueError` for an unsupported bit width and `RuntimeError` for
/// load, quantization, or save failures.
#[pyfunction]
#[pyo3(signature = (input_path, output_path, bits=8, per_channel=false, excluded_layers=None, min_elements=0, layer_bits=None, native_int4=false, symmetric=false))]
#[allow(clippy::too_many_arguments)]
fn quantize(
    input_path: &str,
    output_path: &str,
    bits: u8,
    per_channel: bool,
    excluded_layers: Option<Vec<String>>,
    min_elements: usize,
    layer_bits: Option<std::collections::HashMap<String, u8>>,
    native_int4: bool,
    symmetric: bool,
) -> PyResult<()> {
    if !matches!(bits, 4 | 8) {
        return Err(PyValueError::new_err(format!(
            "bits must be 4 or 8, got {}",
            bits
        )));
    }
    let mut model = OnnxModel::load(input_path)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load model: {}", e)))?;
    // Bail out early when there is nothing to quantize.
    if model.extract_weights().is_empty() {
        return Err(PyRuntimeError::new_err("No weights found to quantize"));
    }
    let quantizer = Quantizer::new(QuantConfig {
        bits,
        per_channel,
        symmetric,
        calibration_method: None,
        excluded_layers: excluded_layers.unwrap_or_default(),
        min_elements,
        layer_bits: layer_bits.unwrap_or_default(),
    });
    // Run quantization, keeping only the QDQ payload of each output.
    let qdq_weights: Vec<QdqWeightInput> = quantizer
        .quantize_model(&model)
        .map_err(|e| PyRuntimeError::new_err(format!("Quantization failed: {}", e)))?
        .into_iter()
        .map(|output| output.qdq)
        .collect();
    model
        .save_quantized_with_options(
            &qdq_weights,
            output_path,
            SaveOptions::default().with_native_int4(native_int4),
        )
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to save model: {}", e)))
}
/// Pick a per-sample shape for random calibration data from the model's
/// first input: dynamic (<= 0) dimensions are dropped and the leading
/// dimension is skipped (assumed to be the batch axis — TODO confirm).
/// Falls back to `[3, 224, 224]` when no usable shape is found.
#[cfg(feature = "calibration")]
fn default_sample_shape(model: &OnnxModel) -> Vec<usize> {
    model
        .input_shapes()
        .into_iter()
        .next()
        .and_then(|dims| {
            // Keep only concrete (positive) dimensions.
            let shape: Vec<usize> = dims
                .into_iter()
                .filter_map(|d| if d > 0 { Some(d as usize) } else { None })
                .collect();
            if shape.len() >= 2 {
                // Drop the leading (batch) dimension.
                Some(shape[1..].to_vec())
            } else if !shape.is_empty() {
                Some(shape)
            } else {
                None
            }
        })
        .unwrap_or_else(|| vec![3, 224, 224])
}

/// Quantize an ONNX model using activation-calibrated scales.
///
/// Calibration data is loaded from the `.npy`/`.npz` file at
/// `calibration_data` when given; otherwise `num_samples` random samples are
/// generated with shape `sample_shape` (or a shape inferred from the model's
/// first input). `method` selects the calibration algorithm (e.g. "minmax").
///
/// Raises `ValueError` for an unsupported bit width and `RuntimeError` for
/// any load, calibration, quantization, or save failure.
#[cfg(feature = "calibration")]
#[pyfunction]
#[pyo3(signature = (
    input_path,
    output_path,
    calibration_data=None,
    bits=8,
    per_channel=false,
    method="minmax",
    num_samples=100,
    sample_shape=None,
    native_int4=false,
    symmetric=false,
))]
#[allow(clippy::too_many_arguments)]
fn quantize_with_calibration(
    input_path: &str,
    output_path: &str,
    calibration_data: Option<&str>,
    bits: u8,
    per_channel: bool,
    method: &str,
    num_samples: usize,
    sample_shape: Option<Vec<usize>>,
    native_int4: bool,
    symmetric: bool,
) -> PyResult<()> {
    if bits != 4 && bits != 8 {
        return Err(PyValueError::new_err(format!(
            "bits must be 4 or 8, got {}",
            bits
        )));
    }
    // Validate the method string before doing any heavy lifting.
    let calib_method: CalibrationMethod = method
        .parse()
        .map_err(|e| PyRuntimeError::new_err(format!("{}", e)))?;
    let model = OnnxModel::load(input_path)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load model: {}", e)))?;
    // Real calibration data when provided, otherwise synthetic samples.
    let dataset = if let Some(path) = calibration_data {
        CalibrationDataset::from_numpy(path).map_err(|e| {
            PyRuntimeError::new_err(format!("Failed to load calibration data: {}", e))
        })?
    } else {
        let shape = sample_shape.unwrap_or_else(|| default_sample_shape(&model));
        CalibrationDataset::random(shape, num_samples, (0.0, 1.0)).map_err(|e| {
            PyRuntimeError::new_err(format!("Failed to create random dataset: {}", e))
        })?
    };
    let mut estimator = ActivationEstimator::new(model, input_path)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to create estimator: {}", e)))?;
    estimator
        .calibrate(&dataset)
        .map_err(|e| PyRuntimeError::new_err(format!("Calibration failed: {}", e)))?;
    // Copy the stats out so the estimator can release the model.
    let activation_stats: std::collections::HashMap<String, _> = estimator
        .get_layer_stats()
        .into_iter()
        .map(|(k, v)| (k, v.clone()))
        .collect();
    let mut model = estimator.into_model();
    let config = QuantConfig {
        bits,
        per_channel,
        symmetric,
        calibration_method: Some(calib_method),
        ..Default::default()
    };
    let outputs = Quantizer::with_calibration(config, activation_stats)
        .quantize_model(&model)
        .map_err(|e| PyRuntimeError::new_err(format!("Quantization failed: {}", e)))?;
    let quantized_data: Vec<QdqWeightInput> = outputs.into_iter().map(|o| o.qdq).collect();
    let save_options = SaveOptions::default().with_native_int4(native_int4);
    model
        .save_quantized_with_options(&quantized_data, output_path, save_options)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to save model: {}", e)))?;
    Ok(())
}
/// Load the ONNX model at `input_path` and return a [`ModelInfo`] summary.
///
/// Raises `RuntimeError` when the model cannot be loaded.
#[pyfunction]
fn model_info(input_path: &str) -> PyResult<ModelInfo> {
    let model = OnnxModel::load(input_path)
        .map_err(|e| PyRuntimeError::new_err(format!("Failed to load model: {}", e)))?;
    // Copy the crate-side summary into the Python-visible struct.
    let summary = model.info();
    Ok(ModelInfo {
        name: summary.name,
        version: summary.version,
        num_nodes: summary.num_nodes,
        inputs: summary.inputs,
        outputs: summary.outputs,
    })
}
/// Python module entry point: registers the quantization functions, the
/// `ModelInfo` class, and the crate version under `__version__`.
/// `quantize_with_calibration` is present only with the "calibration" feature.
#[pymodule]
fn quantize_rs(module: &Bound<'_, PyModule>) -> PyResult<()> {
    module.add_function(wrap_pyfunction!(quantize, module)?)?;
    #[cfg(feature = "calibration")]
    module.add_function(wrap_pyfunction!(quantize_with_calibration, module)?)?;
    module.add_function(wrap_pyfunction!(model_info, module)?)?;
    module.add_class::<ModelInfo>()?;
    module.add("__version__", env!("CARGO_PKG_VERSION"))?;
    Ok(())
}