use crate::Tensor;
use super::helpers::rand_simple;
use super::types::{CalibrationMethod, CalibrationResult};
#[derive(Clone, Debug)]
pub struct Calibrator {
method: CalibrationMethod,
symmetric: bool,
bits: usize,
running_min: Option<f32>,
running_max: Option<f32>,
samples: Vec<f32>,
max_samples: usize,
num_batches: usize,
}
impl Calibrator {
pub fn min_max(bits: usize, symmetric: bool) -> Self {
Self {
method: CalibrationMethod::MinMax,
symmetric,
bits,
running_min: None,
running_max: None,
samples: Vec::new(),
max_samples: 0,
num_batches: 0,
}
}
pub fn percentile(
bits: usize,
symmetric: bool,
lower: f32,
upper: f32,
max_samples: usize,
) -> Self {
Self {
method: CalibrationMethod::Percentile { lower, upper },
symmetric,
bits,
running_min: None,
running_max: None,
samples: Vec::with_capacity(max_samples.min(10000)),
max_samples,
num_batches: 0,
}
}
pub fn moving_average(bits: usize, symmetric: bool, momentum: f32) -> Self {
Self {
method: CalibrationMethod::MovingAverage { momentum },
symmetric,
bits,
running_min: None,
running_max: None,
samples: Vec::new(),
max_samples: 0,
num_batches: 0,
}
}
pub fn observe(&mut self, data: &[f32]) {
if data.is_empty() {
return;
}
match &self.method {
CalibrationMethod::MinMax => {
self.observe_min_max(data);
}
CalibrationMethod::Percentile { .. } => {
self.observe_percentile(data);
}
CalibrationMethod::MovingAverage { momentum } => {
let momentum = *momentum;
self.observe_moving_average(data, momentum);
}
}
self.num_batches += 1;
}
pub fn observe_tensor(&mut self, tensor: &Tensor) {
if let Some(slice) = tensor.data().as_slice() {
self.observe(slice);
}
}
pub fn observe_tensors(&mut self, tensors: &[&Tensor]) {
for tensor in tensors {
self.observe_tensor(tensor);
}
}
pub fn compute(&self) -> CalibrationResult {
let (observed_min, observed_max) = match &self.method {
CalibrationMethod::MinMax | CalibrationMethod::MovingAverage { .. } => {
(self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0))
}
CalibrationMethod::Percentile { lower, upper } => {
self.compute_percentile_bounds(*lower, *upper)
}
};
let (scale, zero_point) = self.compute_scale_zero_point(observed_min, observed_max);
CalibrationResult {
scale,
zero_point,
observed_min,
observed_max,
method: self.method.clone(),
}
}
pub fn num_batches(&self) -> usize {
self.num_batches
}
pub fn method(&self) -> &CalibrationMethod {
&self.method
}
pub fn has_data(&self) -> bool {
self.num_batches > 0
}
pub fn reset(&mut self) {
self.running_min = None;
self.running_max = None;
self.samples.clear();
self.num_batches = 0;
}
fn observe_min_max(&mut self, data: &[f32]) {
let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
self.running_min = Some(self.running_min.map_or(batch_min, |m| m.min(batch_min)));
self.running_max = Some(self.running_max.map_or(batch_max, |m| m.max(batch_max)));
}
fn observe_percentile(&mut self, data: &[f32]) {
if self.samples.len() < self.max_samples {
let remaining = self.max_samples - self.samples.len();
self.samples.extend(data.iter().take(remaining).copied());
} else {
let total_seen = self.num_batches * data.len() + data.len();
for (i, &val) in data.iter().enumerate() {
let j = rand_simple(total_seen + i);
if j < self.max_samples {
self.samples[j] = val;
}
}
}
self.observe_min_max(data);
}
fn observe_moving_average(&mut self, data: &[f32], momentum: f32) {
let batch_min = data.iter().copied().fold(f32::INFINITY, f32::min);
let batch_max = data.iter().copied().fold(f32::NEG_INFINITY, f32::max);
self.running_min = Some(
self.running_min.map_or(batch_min, |m| m * (1.0 - momentum) + batch_min * momentum),
);
self.running_max = Some(
self.running_max.map_or(batch_max, |m| m * (1.0 - momentum) + batch_max * momentum),
);
}
fn compute_percentile_bounds(&self, lower: f32, upper: f32) -> (f32, f32) {
if self.samples.is_empty() {
return (self.running_min.unwrap_or(0.0), self.running_max.unwrap_or(0.0));
}
let mut sorted = self.samples.clone();
sorted.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
let n = sorted.len();
let lower_idx = ((lower / 100.0) * n as f32) as usize;
let upper_idx = ((upper / 100.0) * n as f32).min((n - 1) as f32) as usize;
(sorted[lower_idx], sorted[upper_idx])
}
fn compute_scale_zero_point(&self, min_val: f32, max_val: f32) -> (f32, i32) {
let qmax = (1 << (self.bits - 1)) - 1;
let qmin = if self.symmetric { -qmax } else { 0 };
let qmax_full = if self.symmetric { qmax } else { (1 << self.bits) - 1 };
if self.symmetric {
let max_abs = min_val.abs().max(max_val.abs());
let scale = if max_abs < 1e-10 { 1e-10 } else { max_abs / qmax as f32 };
(scale, 0)
} else {
let range = max_val - min_val;
let scale = if range < 1e-10 { 1e-10 } else { range / (qmax_full - qmin) as f32 };
let zero_point = (qmin as f32 - min_val / scale).round() as i32;
let zero_point = zero_point.clamp(qmin, qmax_full);
(scale, zero_point)
}
}
}