use crate::errors::{QuantizeError, Result};
use std::collections::HashMap;
use tract_onnx::prelude::*;
use crate::calibration::stats::ActivationStats;
use crate::calibration::CalibrationDataset;
use crate::onnx_utils::OnnxModel;
/// Collects per-layer activation statistics by running calibration samples
/// through a tract inference plan that exposes every intermediate output.
pub struct ActivationEstimator {
    // The quantizer's own model representation, handed back to callers via
    // `model()` / `into_model()` after calibration.
    model: OnnxModel,
    // Runnable tract plan. `from_path` registers every non-source,
    // non-constant node output as a graph output before optimization, so a
    // single run yields all intermediate activations.
    #[allow(clippy::type_complexity)]
    tract_model: SimplePlan<TypedFact, Box<dyn TypedOp>, Graph<TypedFact, Box<dyn TypedOp>>>,
    // Accumulated statistics, keyed by the optimized graph's node name.
    layer_stats: HashMap<String, ActivationStats>,
    // Node name for each plan output, in the order `run()` returns tensors.
    output_names: Vec<String>,
}
impl std::fmt::Debug for ActivationEstimator {
    /// Manual `Debug` impl: the tract plan type does not implement `Debug`,
    /// so only the model and summary counts are shown.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut dbg = f.debug_struct("ActivationEstimator");
        dbg.field("model", &self.model);
        dbg.field("layer_stats_count", &self.layer_stats.len());
        dbg.field("output_names_count", &self.output_names.len());
        dbg.finish()
    }
}
impl ActivationEstimator {
    /// Builds an estimator by loading the ONNX file at `onnx_path` with tract.
    ///
    /// Every intermediate node output (except graph sources and constants) is
    /// registered as an extra model output, so a single inference pass yields
    /// the activations of every layer. The graph is then optimized and turned
    /// into a runnable plan.
    ///
    /// # Errors
    /// Returns `QuantizeError::Calibration` when tract fails to load,
    /// optimize, or plan the model.
    pub fn from_path(model: OnnxModel, onnx_path: &str) -> Result<Self> {
        let mut tract_model = tract_onnx::onnx().model_for_path(onnx_path).map_err(|e| {
            QuantizeError::Calibration {
                reason: format!("tract failed to load ONNX model '{}': {e}", onnx_path),
            }
        })?;
        // Tap every intermediate tensor: register each non-source,
        // non-constant node output as a graph output so inference reports
        // per-layer activations in addition to the final result.
        let node_count = tract_model.nodes.len();
        let original_outputs: Vec<OutletId> = tract_model.outputs.to_vec();
        for node_id in 0..node_count {
            let node = &tract_model.nodes[node_id];
            // Sources (model inputs) and constants carry no activations.
            if node.op_is::<tract_onnx::tract_core::ops::source::TypedSource>()
                || node.op_is::<tract_onnx::tract_core::ops::konst::Const>()
            {
                continue;
            }
            for output_idx in 0..node.outputs.len() {
                let outlet = OutletId::new(node_id, output_idx);
                if !original_outputs.contains(&outlet) {
                    tract_model.outputs.push(outlet);
                }
            }
        }
        let optimized_model =
            tract_model
                .into_optimized()
                .map_err(|e| QuantizeError::Calibration {
                    reason: format!("tract optimization failed: {e}"),
                })?;
        // Record the node name behind each output; run() returns tensors in
        // this same order, so output index i maps to output_names[i].
        let output_names: Vec<String> = optimized_model
            .outputs
            .iter()
            .map(|outlet| optimized_model.nodes[outlet.node].name.clone())
            .collect();
        let tract_model =
            optimized_model
                .into_runnable()
                .map_err(|e| QuantizeError::Calibration {
                    reason: format!("tract failed to create runnable plan: {e}"),
                })?;
        Ok(Self {
            model,
            tract_model,
            layer_stats: HashMap::new(),
            output_names,
        })
    }

    /// Convenience alias for [`Self::from_path`].
    pub fn new(model: OnnxModel, onnx_path: &str) -> Result<Self> {
        Self::from_path(model, onnx_path)
    }

    /// Runs every sample of `dataset` through the model, accumulating
    /// per-layer activation statistics. Progress is printed at roughly
    /// every 10% of samples processed.
    ///
    /// # Errors
    /// Returns `QuantizeError::Calibration` for an empty dataset or when
    /// inference on a sample fails.
    pub fn calibrate(&mut self, dataset: &CalibrationDataset) -> Result<()> {
        if dataset.is_empty() {
            return Err(QuantizeError::Calibration {
                reason: "Calibration dataset is empty".into(),
            });
        }
        println!(
            "Running activation-based calibration on {} samples...",
            dataset.len()
        );
        let num_samples = dataset.len();
        for (sample_idx, sample) in dataset.samples.iter().enumerate() {
            self.process_sample(sample, &dataset.shape)?;
            // Report at every ~10% boundary and on the final sample; the
            // .max(1) keeps the modulus nonzero for tiny datasets.
            if (sample_idx + 1) % (num_samples / 10).max(1) == 0 || sample_idx == num_samples - 1 {
                println!(" Processed {}/{} samples", sample_idx + 1, num_samples);
            }
        }
        println!(
            "✓ Calibration complete: {} layers tracked",
            self.layer_stats.len()
        );
        Ok(())
    }

    /// Runs one calibration sample through the tract plan and folds every
    /// output tensor into the per-layer statistics.
    fn process_sample(&mut self, sample: &[f32], shape: &[usize]) -> Result<()> {
        // Prepend the batch dimension of 1 expected by the model.
        let mut input_shape = Vec::with_capacity(shape.len() + 1);
        input_shape.push(1);
        input_shape.extend_from_slice(shape);
        let input_tensor =
            tract_core::prelude::Tensor::from_shape(&input_shape, sample).map_err(|e| {
                QuantizeError::Calibration {
                    reason: format!("Failed to create input tensor from calibration sample: {e}"),
                }
            })?;
        let outputs = self
            .tract_model
            .run(tvec!(input_tensor.into()))
            .map_err(|e| QuantizeError::Calibration {
                reason: format!("tract inference failed on calibration sample: {e}"),
            })?;
        // zip() pairs each output with its recorded name and silently drops
        // any trailing outputs without one — same behavior as the previous
        // explicit index-bounds check.
        for (layer_name, tvalue) in self.output_names.iter().zip(outputs.iter()) {
            let tensor = tvalue.clone().into_tensor();
            let data = extract_f32_data(&tensor)?;
            self.layer_stats
                .entry(layer_name.clone())
                .and_modify(|stats| stats.update(&data))
                .or_insert_with(|| ActivationStats::from_data(&data));
        }
        Ok(())
    }

    /// Borrowed view of the collected statistics, keyed by layer name.
    pub fn get_layer_stats(&self) -> HashMap<String, &ActivationStats> {
        self.layer_stats
            .iter()
            .map(|(name, stats)| (name.clone(), stats))
            .collect()
    }

    /// Consumes the estimator, yielding the owned statistics map.
    pub fn into_layer_stats(self) -> HashMap<String, ActivationStats> {
        self.layer_stats
    }

    /// Mutable access to the statistics map (e.g. for post-processing).
    pub fn get_layer_stats_mut(&mut self) -> &mut HashMap<String, ActivationStats> {
        &mut self.layer_stats
    }

    /// Consumes the estimator, returning the wrapped model.
    pub fn into_model(self) -> OnnxModel {
        self.model
    }

    /// Borrowed access to the wrapped model.
    pub fn model(&self) -> &OnnxModel {
        &self.model
    }
}
/// Collects a tract tensor's contents into an owned `Vec<f32>`, casting
/// non-f32 tensors to f32 first.
fn extract_f32_data(tensor: &Tensor) -> Result<Vec<f32>> {
    // Fast path: the tensor already holds f32 data, no cast needed.
    if let Ok(view) = tensor.to_array_view::<f32>() {
        return Ok(view.iter().copied().collect());
    }
    // Fallback: cast to f32 (for non-f32 outputs), then read the view.
    let cast = tensor
        .cast_to::<f32>()
        .map_err(|e| QuantizeError::Calibration {
            reason: format!("Failed to cast tensor to f32 for activation statistics: {e}"),
        })?;
    let view = cast
        .to_array_view::<f32>()
        .map_err(|e| QuantizeError::Calibration {
            reason: format!("Tensor cast succeeded but array view failed: {e}"),
        })?;
    Ok(view.iter().copied().collect())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// End-to-end check against a real ONNX file; ignored by default because
    /// a model must be placed on disk manually.
    #[test]
    #[ignore]
    fn test_activation_estimator_real_inference() {
        // Look for a test model in any of the conventional locations.
        let model_paths = [
            "mnist.onnx",
            "test_models/mnist.onnx",
            "resnet18-v1-7.onnx",
            "test_models/resnet18-v1-7.onnx",
        ];
        let found_path = model_paths
            .iter()
            .find(|path| std::path::Path::new(path).exists());
        let model_path = match found_path {
            Some(p) => *p,
            None => {
                println!(
                    "No test model found. Place mnist.onnx or resnet18-v1-7.onnx in project root."
                );
                return;
            }
        };
        println!("Testing with model: {}", model_path);
        let model = OnnxModel::load(model_path).expect("Failed to load model");
        let info = model.info();
        println!("Model: {}, {} nodes", info.name, info.num_nodes);
        let input_shape = if model_path.contains("mnist") {
            vec![1, 28, 28]
        } else {
            vec![3, 224, 224]
        };
        let dataset = CalibrationDataset::random(input_shape, 5, (0.0, 1.0)).unwrap();
        let mut estimator = ActivationEstimator::new(model, model_path)
            .expect("Failed to create ActivationEstimator");
        estimator.calibrate(&dataset).expect("Calibration failed");
        let stats = estimator.get_layer_stats();
        assert!(!stats.is_empty(), "No activation statistics collected");
        println!("\nCollected stats for {} layers:", stats.len());
        for (name, stat) in stats.iter().take(5) {
            println!(
                " {}: min={:.4}, max={:.4}, mean={:.4}",
                name,
                stat.min(),
                stat.max(),
                stat.mean()
            );
        }
        // Every tracked layer should produce a non-constant activation range.
        for (name, stat) in stats.iter() {
            assert!(
                (stat.max() - stat.min()).abs() > 1e-6,
                "Layer {} has constant output (min={}, max={})",
                name,
                stat.min(),
                stat.max()
            );
        }
    }

    /// Smoke test wiring a CalibrationDataset into the estimator; ignored
    /// because it requires mnist.onnx on disk.
    #[test]
    #[ignore]
    fn test_calibration_dataset_integration() {
        let model_path = "mnist.onnx";
        if !std::path::Path::new(model_path).exists() {
            println!("mnist.onnx not found, skipping integration test");
            return;
        }
        let model = OnnxModel::load(model_path).unwrap();
        let dataset = CalibrationDataset::random(vec![1, 28, 28], 10, (0.0, 1.0)).unwrap();
        let mut estimator = ActivationEstimator::new(model, model_path).unwrap();
        estimator.calibrate(&dataset).unwrap();
        let stats = estimator.get_layer_stats();
        assert!(!stats.is_empty());
        for (_name, stat) in stats.iter() {
            assert!(stat.count() > 0);
        }
    }
}