use axonml_tensor::Tensor;
use crate::error::{QuantError, QuantResult};
use crate::types::QuantType;
/// Summary statistics and a histogram collected from tensor data,
/// used to choose quantization ranges (scales / zero-points).
#[derive(Debug, Clone)]
pub struct CalibrationData {
    /// Smallest value observed across all folded-in samples.
    pub min: f32,
    /// Largest value observed across all folded-in samples.
    pub max: f32,
    /// Running mean of all observed values.
    pub mean: f32,
    /// Population standard deviation of all observed values.
    pub std_dev: f32,
    /// Total number of scalar samples seen.
    pub num_samples: usize,
    // Per-bin sample counts over the range covered by `bin_edges`.
    histogram: Vec<usize>,
    // `histogram.len() + 1` edges; fixed when the struct is built
    // (NOTE(review): `update` does not rebuild them — see that method).
    bin_edges: Vec<f32>,
}
impl CalibrationData {
    /// Builds calibration statistics from a single tensor.
    ///
    /// Computes min/max, mean, population standard deviation, and a
    /// `num_bins`-bin histogram over `[min, max]`. `num_bins` is clamped to
    /// at least 1 (0 previously caused a `usize` underflow panic), and an
    /// empty tensor yields an all-zero result instead of NaN statistics
    /// and infinite min/max.
    pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
        let num_bins = num_bins.max(1); // guard: `num_bins - 1` below would underflow
        let data = tensor.to_vec();
        if data.is_empty() {
            // Degenerate input: avoid NaN mean/std_dev and +-inf min/max.
            return Self {
                min: 0.0,
                max: 0.0,
                mean: 0.0,
                std_dev: 0.0,
                num_samples: 0,
                histogram: vec![0; num_bins],
                bin_edges: vec![0.0; num_bins + 1],
            };
        }
        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        let mean = data.iter().sum::<f32>() / data.len() as f32;
        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
        let std_dev = variance.sqrt();
        let bin_width = (max - min) / num_bins as f32;
        let mut histogram = vec![0usize; num_bins];
        let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
        for &val in &data {
            // A constant tensor gives bin_width == 0 and 0/0 == NaN here;
            // `as usize` maps NaN to 0, so everything lands in bin 0 — fine.
            let bin = ((val - min) / bin_width) as usize;
            histogram[bin.min(num_bins - 1)] += 1;
        }
        Self {
            min,
            max,
            mean,
            std_dev,
            num_samples: data.len(),
            histogram,
            bin_edges,
        }
    }

    /// Folds another tensor into the running statistics.
    ///
    /// Min/max widen monotonically; the mean uses a Welford-style
    /// per-sample update; std_dev merges the previous and batch variances
    /// with the pooled-variance formula. Histogram counts are binned
    /// against the ORIGINAL `bin_edges` fixed at construction (the
    /// previous code binned against the widened min/max, silently
    /// shifting bin boundaries relative to counts already recorded and
    /// to what `percentile` reads). Out-of-range values saturate into
    /// the first/last bin.
    pub fn update(&mut self, tensor: &Tensor<f32>) {
        let data = tensor.to_vec();
        if data.is_empty() {
            return; // nothing to fold in (original code was a no-op here too)
        }
        let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
        let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
        self.min = self.min.min(new_min);
        self.max = self.max.max(new_max);
        let old_mean = self.mean;
        let old_count = self.num_samples;
        // Incremental (Welford) mean update, one sample at a time.
        for &val in &data {
            self.num_samples += 1;
            let delta = val - self.mean;
            self.mean += delta / self.num_samples as f32;
        }
        let batch_mean: f32 = data.iter().sum::<f32>() / data.len() as f32;
        let batch_var: f32 =
            data.iter().map(|&v| (v - batch_mean).powi(2)).sum::<f32>() / data.len() as f32;
        if old_count > 0 {
            // Pooled variance of two groups with different means.
            let old_var = self.std_dev * self.std_dev;
            let n1 = old_count as f32;
            let n2 = data.len() as f32;
            let combined_var = (n1 * old_var
                + n2 * batch_var
                + n1 * n2 / (n1 + n2) * (old_mean - batch_mean).powi(2))
                / (n1 + n2);
            self.std_dev = combined_var.sqrt();
        } else {
            // First batch ever: statistics are just the batch's own.
            // (num_samples already equals data.len() from the loop above.)
            self.std_dev = batch_var.sqrt();
        }
        // Bin against the edges fixed at construction so counts stay
        // consistent with `bin_edges` as consumed by `percentile`.
        let n_bins = self.histogram.len();
        if n_bins > 0 {
            let lo = self.bin_edges[0];
            let hi = self.bin_edges[n_bins];
            if hi > lo {
                let bin_width = (hi - lo) / n_bins as f32;
                for &val in &data {
                    // Negative offsets float-truncate to 0 -> saturate low.
                    let bin = ((val - lo) / bin_width).floor() as usize;
                    self.histogram[bin.min(n_bins - 1)] += 1;
                }
            }
        }
    }

    /// Width of the observed value range (`max - min`).
    pub fn dynamic_range(&self) -> f32 {
        self.max - self.min
    }

    /// Scale for symmetric quantization: `max(|min|, |max|) / max_int`.
    ///
    /// `max_int` is the largest positive representable value of the target
    /// type (e.g. 127 for signed 8-bit). Returns 0.0 for all-zero data;
    /// callers must guard before dividing by the result.
    pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
        let max_abs = self.min.abs().max(self.max.abs());
        let max_int = match quant_type {
            QuantType::Q8_0 => 127.0,
            QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
            QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
            QuantType::F16 | QuantType::F32 => 1.0,
        };
        max_abs / max_int
    }

    /// Scale and zero-point for asymmetric (affine) quantization.
    ///
    /// `max_int` is the unsigned integer range of the target type
    /// (e.g. 255 for 8-bit). For constant data (`max == min`) the scale is
    /// 0 and the zero-point is forced to 0.0 instead of the NaN/inf the
    /// previous `-min / 0.0` produced.
    pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
        let max_int = match quant_type {
            QuantType::Q8_0 => 255.0,
            QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
            QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
            QuantType::F16 | QuantType::F32 => 1.0,
        };
        let scale = (self.max - self.min) / max_int;
        let zero_point = if scale != 0.0 { -self.min / scale } else { 0.0 };
        (scale, zero_point)
    }

    /// Approximates the `p`-th percentile (`0.0..=100.0`) from the
    /// histogram.
    ///
    /// Returns the lower edge of the first bin whose cumulative count
    /// reaches `p%` of the samples; resolution is therefore one bin.
    /// Out-of-range `p` clamps to `min`/`max`.
    pub fn percentile(&self, p: f32) -> f32 {
        if p <= 0.0 {
            return self.min;
        }
        if p >= 100.0 {
            return self.max;
        }
        let target = (p / 100.0 * self.num_samples as f32) as usize;
        let mut cumsum = 0usize;
        for (i, &count) in self.histogram.iter().enumerate() {
            cumsum += count;
            if cumsum >= target {
                return self.bin_edges[i];
            }
        }
        self.max
    }
}
/// Strategy for choosing the quantization range during calibration.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Use the raw observed minimum and maximum.
    MinMax,
    /// Clip to a percentile, expressed in tenths of a percent
    /// (e.g. `999` means the 99.9th percentile).
    Percentile(u32),
    /// Pick the clip point minimizing KL divergence between the original
    /// and quantized histograms.
    Entropy,
    /// Clip to mean +/- k standard deviations, with k in tenths
    /// (e.g. `30` means 3.0 sigma).
    MeanStd(u32),
}
/// Calibrates a quantization range for `tensor` using `method`.
///
/// Always starts from a 2048-bin min/max histogram, then narrows
/// `min`/`max` according to the method.
///
/// # Errors
/// Currently never fails; the `QuantResult` return keeps the signature
/// uniform with [`calibrate_batch`].
pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
    let mut data = CalibrationData::new(tensor, 2048);
    match method {
        // Raw observed range: nothing to adjust.
        CalibrationMethod::MinMax => {}
        CalibrationMethod::Percentile(p) => {
            // `p` is in tenths of a percent, e.g. 999 -> 99.9%.
            // Compute both percentiles before assigning so the second
            // lookup cannot observe the already-mutated `min`.
            let percentile = p as f32 / 10.0;
            let lower = data.percentile(100.0 - percentile);
            let upper = data.percentile(percentile);
            data.min = lower;
            data.max = upper;
        }
        CalibrationMethod::MeanStd(k) => {
            // `k` is in tenths, e.g. 30 -> mean +/- 3.0 * std_dev.
            let k_factor = k as f32 / 10.0;
            data.min = data.mean - k_factor * data.std_dev;
            data.max = data.mean + k_factor * data.std_dev;
        }
        CalibrationMethod::Entropy => apply_entropy_clipping(&mut data),
    }
    Ok(data)
}

/// Entropy (KL-divergence) calibration, TensorRT-style: searches the upper
/// half of the histogram for the clip threshold whose 128-bin quantized
/// approximation has minimal KL divergence from the reference
/// distribution, then clips `data.max` to that threshold's bin edge.
/// Falls back to percentile clipping when the histogram is unusable.
fn apply_entropy_clipping(data: &mut CalibrationData) {
    let n_bins = data.histogram.len();
    let total: usize = data.histogram.iter().sum();
    if n_bins < 4 || total == 0 {
        // Too few bins / no samples: percentile clipping is the best we can do.
        data.min = data.percentile(0.01);
        data.max = data.percentile(99.99);
        return;
    }
    // Reference distribution with a small floor so ln() never sees zero.
    let ref_dist: Vec<f64> = data
        .histogram
        .iter()
        .map(|&c| c as f64 / total as f64 + 1e-12)
        .collect();
    let quant_bins = 128usize;
    let mut best_kl = f64::MAX;
    let mut best_threshold = n_bins;
    for threshold in (n_bins / 2)..n_bins {
        let mut clipped = ref_dist[..threshold].to_vec();
        // Mass beyond the clip point is folded into the last kept bin.
        let outlier_mass: f64 = ref_dist[threshold..].iter().sum();
        if let Some(last) = clipped.last_mut() {
            *last += outlier_mass;
        }
        let bins_per_quant = threshold.div_ceil(quant_bins);
        let mut quant_dist = vec![0.0f64; quant_bins.min(threshold)];
        for (i, &p) in clipped.iter().enumerate() {
            let qi = (i / bins_per_quant).min(quant_dist.len() - 1);
            quant_dist[qi] += p;
        }
        // Expand the quantized distribution back to `threshold` bins.
        let mut expanded = vec![0.0f64; threshold];
        for (qi, &qval) in quant_dist.iter().enumerate() {
            let start = qi * bins_per_quant;
            // BUGFIX: with div_ceil rounding, trailing quant bins can start
            // at or past `threshold` (e.g. n_bins=2048, threshold=1025 ->
            // bins_per_quant=9, start up to 1143). `end - start` then
            // underflowed usize and panicked in debug builds. Those bins
            // hold no mass and map to no slot, so skip them.
            if start >= threshold {
                break;
            }
            let end = ((qi + 1) * bins_per_quant).min(threshold);
            let count = (end - start) as f64; // > 0: start < threshold guarantees end > start
            let val = qval / count;
            for slot in expanded.iter_mut().take(end).skip(start) {
                *slot = val + 1e-12;
            }
        }
        // KL(P || Q) over the clipped range; zero-mass bins contribute 0.
        let kl: f64 = clipped
            .iter()
            .zip(expanded.iter())
            .map(|(&p, &q)| if p > 1e-12 { p * (p / q).ln() } else { 0.0 })
            .sum();
        if kl < best_kl {
            best_kl = kl;
            best_threshold = threshold;
        }
    }
    // Clip max at the winning threshold's bin edge.
    let bin_width = (data.max - data.min) / n_bins as f32;
    data.max = data.min + best_threshold as f32 * bin_width;
    // Keep the range symmetric around zero when it straddles zero, so
    // symmetric quantization schemes can use it directly.
    if data.min < 0.0 && data.max > 0.0 {
        let abs_max = data.max.abs().max(data.min.abs());
        data.min = -abs_max;
        data.max = abs_max;
    }
}
/// Calibrates across several tensors by folding them all into one
/// [`CalibrationData`], then narrowing the range per `method`.
///
/// # Errors
/// Returns a calibration error when `tensors` is empty.
pub fn calibrate_batch(
    tensors: &[&Tensor<f32>],
    method: CalibrationMethod,
) -> QuantResult<CalibrationData> {
    let (first, rest) = tensors
        .split_first()
        .ok_or_else(|| QuantError::CalibrationError("No tensors provided".to_string()))?;
    let mut combined = CalibrationData::new(first, 2048);
    for tensor in rest {
        combined.update(tensor);
    }
    match method {
        CalibrationMethod::Percentile(p) => {
            // `p` is tenths of a percent (999 -> 99.9%).
            let pct = p as f32 / 10.0;
            combined.min = combined.percentile(100.0 - pct);
            combined.max = combined.percentile(pct);
        }
        CalibrationMethod::MeanStd(k) => {
            // `k` is tenths of a sigma (30 -> 3.0).
            let sigma = k as f32 / 10.0;
            combined.min = combined.mean - sigma * combined.std_dev;
            combined.max = combined.mean + sigma * combined.std_dev;
        }
        // MinMax keeps the raw combined range; Entropy is not applied
        // batch-wise here (matches the original `_ => {}` behavior).
        CalibrationMethod::MinMax | CalibrationMethod::Entropy => {}
    }
    Ok(combined)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calibration_data() {
        // Five evenly spaced values give easy closed-form statistics.
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let calib = CalibrationData::new(&tensor, 10);
        assert_eq!(calib.min, 1.0);
        assert_eq!(calib.max, 5.0);
        assert_eq!(calib.mean, 3.0);
        assert_eq!(calib.num_samples, 5);
    }

    #[test]
    fn test_symmetric_scale() {
        // Symmetric data: Q8_0 scale is max(|min|, |max|) / 127.
        let tensor = Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap();
        let calib = CalibrationData::new(&tensor, 10);
        assert!((calib.symmetric_scale(QuantType::Q8_0) - 4.0 / 127.0).abs() < 0.001);
    }

    #[test]
    fn test_calibration_methods() {
        // 1000 samples uniformly spread over [0, 9.99].
        let values: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(values, &[1000]).unwrap();

        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!((minmax.min - 0.0).abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        // 99.9th-percentile clipping must stay within the observed range.
        let clipped = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    #[test]
    fn test_dynamic_range() {
        let tensor = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();
        let calib = CalibrationData::new(&tensor, 10);
        assert_eq!(calib.dynamic_range(), 15.0);
    }
}