use crate::error::{QuantError, QuantResult};
/// Common interface for calibration observers: they are fed batches of `f32`
/// samples and derive quantization parameters from what they have seen.
pub trait Observer {
/// Feeds one batch of samples into the observer's accumulated statistics.
fn observe(&mut self, data: &[f32]);
/// Returns the `(scale, zero_point)` pair derived from the observed data.
///
/// # Errors
/// The implementations in this file return `QuantError::CalibrationRequired`
/// when no usable data has been observed yet.
fn compute_params(&self) -> QuantResult<(f32, i32)>;
/// Discards all accumulated statistics, returning to the uncalibrated state.
fn reset(&mut self);
/// Reports whether the observer has seen usable data since construction or
/// the last `reset`.
fn is_calibrated(&self) -> bool;
}
/// Computes the scale for symmetric quantization onto a signed `bits`-wide grid.
///
/// The scale maps `abs_max` onto the positive limit of the grid,
/// `2^(bits - 1) - 1`. `abs_max` is floored at `1e-8` so that an all-zero
/// input still yields a strictly positive scale.
///
/// `bits` is expected to be in `[1, 16]` (the range the observer constructors
/// accept). For `bits <= 1` the positive limit would be zero, which used to
/// produce an infinite scale; the limit is therefore floored at `1.0` so the
/// result is always finite.
fn sym_scale(abs_max: f32, bits: u32) -> f32 {
    // `saturating_sub` keeps `bits == 0` from underflowing the shift amount.
    let positive_limit = (1i32 << bits.saturating_sub(1)) as f32 - 1.0;
    // Guard against division by zero for the degenerate 1-bit grid.
    let q_max = positive_limit.max(1.0);
    abs_max.max(1e-8) / q_max
}
/// Computes `(scale, zero_point)` for asymmetric quantization onto the
/// unsigned grid `[0, 2^bits - 1]`.
///
/// The observed range `[min_val, max_val]` is first widened to contain `0.0`.
/// Without that step a range such as `[1, 3]` or `[-3, -1]` would get a scale
/// based on `max_val - min_val` while the zero point is clamped into the
/// grid, so the representable interval would no longer cover the data.
/// Including zero also guarantees that real `0.0` maps exactly onto an
/// integer grid point.
///
/// The range width is floored at `1e-8` so the scale is always positive.
/// `bits` is expected to be in `[1, 16]`.
fn asym_scale_zp(min_val: f32, max_val: f32, bits: u32) -> (f32, i32) {
    let q_range = ((1u32 << bits) - 1) as f32;
    // Widen the range so that zero is always representable.
    let lo = min_val.min(0.0);
    let hi = max_val.max(0.0);
    let scale = (hi - lo).max(1e-8) / q_range;
    // `lo <= 0`, so `-lo / scale` is non-negative; the clamp only protects
    // against rounding pushing it past the top of the grid.
    let zp = (-lo / scale).round().clamp(0.0, q_range) as i32;
    (scale, zp)
}
/// Observer that keeps the global minimum and maximum of everything it has
/// been shown.
#[derive(Debug, Clone)]
pub struct MinMaxObserver {
    /// Running minimum; starts at `+inf` until a sample is recorded.
    pub min_val: f32,
    /// Running maximum; starts at `-inf` until a sample is recorded.
    pub max_val: f32,
    /// Bit width of the target integer grid.
    pub bits: u32,
    /// Selects symmetric (`true`) or asymmetric (`false`) parameters.
    pub symmetric: bool,
}

impl MinMaxObserver {
    /// Creates an observer with an empty (uncalibrated) range.
    ///
    /// # Panics
    /// Panics if `bits` is outside `[1, 16]`.
    #[must_use]
    pub fn new(bits: u32, symmetric: bool) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        Self {
            bits,
            symmetric,
            min_val: f32::INFINITY,
            max_val: f32::NEG_INFINITY,
        }
    }
}
impl Observer for MinMaxObserver {
    /// Extends the running range with every finite sample in `data`;
    /// NaN and infinite samples are skipped.
    fn observe(&mut self, data: &[f32]) {
        for sample in data.iter().copied().filter(|s| s.is_finite()) {
            if sample < self.min_val {
                self.min_val = sample;
            }
            if sample > self.max_val {
                self.max_val = sample;
            }
        }
    }

    /// Derives `(scale, zero_point)` from the running range.
    ///
    /// Symmetric mode uses the larger magnitude of the two bounds and a zero
    /// point of `0`; asymmetric mode delegates to `asym_scale_zp`.
    ///
    /// # Errors
    /// Returns `QuantError::CalibrationRequired` while the range is still empty.
    fn compute_params(&self) -> QuantResult<(f32, i32)> {
        if !self.is_calibrated() {
            return Err(QuantError::CalibrationRequired("MinMaxObserver"));
        }
        let params = if self.symmetric {
            let magnitude = f32::max(self.min_val.abs(), self.max_val.abs());
            (sym_scale(magnitude, self.bits), 0)
        } else {
            asym_scale_zp(self.min_val, self.max_val, self.bits)
        };
        Ok(params)
    }

    /// Restores the empty range (`+inf` / `-inf`).
    fn reset(&mut self) {
        self.max_val = f32::NEG_INFINITY;
        self.min_val = f32::INFINITY;
    }

    /// The observer counts as calibrated once both bounds are finite.
    fn is_calibrated(&self) -> bool {
        [self.min_val, self.max_val].iter().all(|bound| bound.is_finite())
    }
}
/// Observer that smooths per-batch minima and maxima with an exponential
/// moving average.
#[derive(Debug, Clone)]
pub struct MovingAvgObserver {
    /// Smoothed minimum; `0.0` until the first usable batch.
    pub min_val: f32,
    /// Smoothed maximum; `0.0` until the first usable batch.
    pub max_val: f32,
    /// Weight given to the previous estimate on each update, in `(0, 1)`.
    pub momentum: f32,
    /// Bit width of the target integer grid.
    pub bits: u32,
    /// Selects symmetric (`true`) or asymmetric (`false`) parameters.
    pub symmetric: bool,
    /// Set once the first usable batch has seeded the averages.
    initialized: bool,
}

impl MovingAvgObserver {
    /// Creates an uncalibrated observer.
    ///
    /// # Panics
    /// Panics if `bits` is outside `[1, 16]` or if `momentum` is not strictly
    /// between `0` and `1` (NaN is rejected as well).
    #[must_use]
    pub fn new(bits: u32, symmetric: bool, momentum: f32) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        assert!(
            0.0 < momentum && momentum < 1.0,
            "momentum must be in (0, 1), got {momentum}"
        );
        Self {
            bits,
            symmetric,
            momentum,
            initialized: false,
            min_val: 0.0,
            max_val: 0.0,
        }
    }
}
impl Observer for MovingAvgObserver {
    /// Updates the smoothed range with the min/max of the finite samples in
    /// `data`.
    ///
    /// The first usable batch seeds the averages directly; later batches are
    /// blended in as `momentum * old + (1 - momentum) * batch`. Batches with
    /// no finite sample (including empty ones) are ignored.
    fn observe(&mut self, data: &[f32]) {
        let (batch_lo, batch_hi) = data
            .iter()
            .copied()
            .filter(|s| s.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), s| {
                (lo.min(s), hi.max(s))
            });
        // Still at the fold seeds: nothing finite was seen.
        if !(batch_lo.is_finite() && batch_hi.is_finite()) {
            return;
        }

        if self.initialized {
            let keep = self.momentum;
            self.min_val = keep * self.min_val + (1.0 - keep) * batch_lo;
            self.max_val = keep * self.max_val + (1.0 - keep) * batch_hi;
        } else {
            self.min_val = batch_lo;
            self.max_val = batch_hi;
            self.initialized = true;
        }
    }

    /// Derives `(scale, zero_point)` from the smoothed range.
    ///
    /// Symmetric mode uses the larger magnitude of the two bounds and a zero
    /// point of `0`; asymmetric mode delegates to `asym_scale_zp`.
    ///
    /// # Errors
    /// Returns `QuantError::CalibrationRequired` before the first usable batch.
    fn compute_params(&self) -> QuantResult<(f32, i32)> {
        if !self.is_calibrated() {
            return Err(QuantError::CalibrationRequired("MovingAvgObserver"));
        }
        let params = if self.symmetric {
            let magnitude = f32::max(self.min_val.abs(), self.max_val.abs());
            (sym_scale(magnitude, self.bits), 0)
        } else {
            asym_scale_zp(self.min_val, self.max_val, self.bits)
        };
        Ok(params)
    }

    /// Clears the averages and marks the observer as uncalibrated.
    fn reset(&mut self) {
        self.initialized = false;
        self.min_val = 0.0;
        self.max_val = 0.0;
    }

    /// True once at least one usable batch has been observed.
    fn is_calibrated(&self) -> bool {
        self.initialized
    }
}
/// Observer that accumulates a fixed-size histogram of the samples it sees
/// and uses it to search for a clipping range.
#[derive(Debug, Clone)]
pub struct HistogramObserver {
    /// Per-bin sample counts; always `n_bins` long.
    bins: Vec<u64>,
    /// Lower edge of the first bin.
    range_min: f32,
    /// Upper edge of the last bin.
    range_max: f32,
    /// Number of equally wide bins.
    n_bins: usize,
    /// Bit width of the target integer grid.
    pub bits: u32,
    /// Selects symmetric (`true`) or asymmetric (`false`) parameters.
    pub symmetric: bool,
    /// Set once at least one finite sample has been recorded.
    initialized: bool,
}

impl HistogramObserver {
    /// Creates an empty histogram with `n_bins` bins.
    ///
    /// # Panics
    /// Panics if `bits` is outside `[1, 16]` or if `n_bins` is zero.
    #[must_use]
    pub fn new(bits: u32, symmetric: bool, n_bins: usize) -> Self {
        assert!((1..=16).contains(&bits), "bits must be in [1, 16]");
        assert!(n_bins > 0, "n_bins must be > 0");
        Self {
            bits,
            symmetric,
            n_bins,
            bins: vec![0_u64; n_bins],
            range_min: 0.0,
            range_max: 0.0,
            initialized: false,
        }
    }

    /// Width of a single bin over the current range.
    fn bin_width(&self) -> f32 {
        let span = self.range_max - self.range_min;
        span / self.n_bins as f32
    }

    /// Estimates the mean squared quantization error if the histogram were
    /// quantized uniformly onto `[lo, hi]` with `2^bits - 1` steps.
    ///
    /// Every non-empty bin is represented by its centre: centres outside
    /// `[lo, hi]` are clipped to the nearer bound, the rest are snapped to the
    /// nearest grid point. Returns `+inf` when the histogram is empty or the
    /// interval is narrower than `1e-12`.
    fn estimate_mse(&self, lo: f32, hi: f32) -> f32 {
        let samples: u64 = self.bins.iter().sum();
        if samples == 0 || (hi - lo).abs() < 1e-12 {
            return f32::INFINITY;
        }

        let width = self.bin_width();
        let levels = ((1u32 << self.bits) - 1) as f32;
        let step = (hi - lo) / levels;

        let weighted_sq_err = self
            .bins
            .iter()
            .enumerate()
            .filter(|&(_, &count)| count != 0)
            .fold(0.0_f32, |acc, (bin, &count)| {
                let center = self.range_min + (bin as f32 + 0.5) * width;
                let snapped = if center <= lo {
                    lo
                } else if center >= hi {
                    hi
                } else {
                    lo + ((center - lo) / step).round() * step
                };
                let err = center - snapped;
                acc + count as f32 * err * err
            });

        weighted_sq_err / samples as f32
    }
}
impl Observer for HistogramObserver {
    /// Adds the finite samples of `data` to the histogram.
    ///
    /// The histogram range grows to cover every sample seen so far. When a
    /// later batch widens the range, the counts already collected are
    /// re-bucketed into the new bin layout first; otherwise old counts would
    /// silently be reinterpreted under bin edges they were never binned with.
    /// NaN and infinite samples are skipped, and a batch without any finite
    /// sample leaves the observer untouched.
    fn observe(&mut self, data: &[f32]) {
        let (d_min, d_max) = data
            .iter()
            .copied()
            .filter(|v| v.is_finite())
            .fold((f32::INFINITY, f32::NEG_INFINITY), |(lo, hi), v| {
                (lo.min(v), hi.max(v))
            });
        // Still at the fold seed: the batch had no finite sample.
        if !d_min.is_finite() {
            return;
        }

        let (new_min, mut new_max) = if self.initialized {
            (self.range_min.min(d_min), self.range_max.max(d_max))
        } else {
            (d_min, d_max)
        };
        // Keep the range from collapsing to a single point.
        if (new_max - new_min).abs() < 1e-8 {
            new_max = new_min + 1e-8;
        }

        // The range only ever grows, so growth is the only way it can change.
        let range_grew = new_min < self.range_min || new_max > self.range_max;
        if self.initialized && range_grew {
            self.rebin_into(new_min, new_max);
        }
        self.range_min = new_min;
        self.range_max = new_max;
        self.initialized = true;

        let bw = self.bin_width();
        let origin = self.range_min;
        let last = self.n_bins - 1;
        for v in data.iter().copied().filter(|v| v.is_finite()) {
            // Float-to-int `as` saturates and maps NaN to 0, so a zero-width
            // bin (0/0) lands in bin 0 instead of panicking.
            let idx = (((v - origin) / bw) as usize).min(last);
            self.bins[idx] += 1;
        }
    }

    /// Searches 20 percentile-based clipping candidates for the one with the
    /// lowest estimated quantization MSE and converts it to
    /// `(scale, zero_point)`.
    ///
    /// * Asymmetric mode keeps `range_min` as the lower bound and tries upper
    ///   bounds at the 5 %, 10 %, …, 100 % cumulative-count bins.
    /// * Symmetric mode tries bounds `[-b, b]` where `b` is taken at the same
    ///   percentiles of the sample *magnitude*, so negative-dominated data is
    ///   covered as well as positive-dominated data and the candidate
    ///   interval is never inverted.
    ///
    /// If no candidate has a finite MSE the full observed range is used.
    ///
    /// # Errors
    /// Returns `QuantError::CalibrationRequired` when nothing has been
    /// observed yet.
    fn compute_params(&self) -> QuantResult<(f32, i32)> {
        if !self.is_calibrated() {
            return Err(QuantError::CalibrationRequired("HistogramObserver"));
        }
        let total: u64 = self.bins.iter().sum();
        if total == 0 {
            return Err(QuantError::CalibrationRequired("HistogramObserver"));
        }

        let n_search = 20_usize;
        let bw = self.bin_width();

        if self.symmetric {
            // Each non-empty bin, keyed by the magnitude of its outer edge.
            let mut by_magnitude: Vec<(f32, u64)> = self
                .bins
                .iter()
                .enumerate()
                .filter(|&(_, &cnt)| cnt > 0)
                .map(|(b, &cnt)| {
                    let left = self.range_min + b as f32 * bw;
                    let right = left + bw;
                    (left.abs().max(right.abs()), cnt)
                })
                .collect();
            by_magnitude.sort_by(|a, b| {
                a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal)
            });

            let full_bound = self.range_min.abs().max(self.range_max.abs());
            let mut best_mse = f32::INFINITY;
            let mut best_bound = full_bound;
            for i in 1..=n_search {
                let pct = i as f32 / n_search as f32;
                let threshold = (pct * total as f32) as u64;
                let mut cum = 0_u64;
                let mut bound = full_bound;
                for &(magnitude, cnt) in &by_magnitude {
                    cum += cnt;
                    if cum >= threshold {
                        bound = magnitude;
                        break;
                    }
                }
                let mse = self.estimate_mse(-bound, bound);
                if mse < best_mse {
                    best_mse = mse;
                    best_bound = bound;
                }
            }
            return Ok((sym_scale(best_bound, self.bits), 0));
        }

        let mut best_mse = f32::INFINITY;
        let mut best_lo = self.range_min;
        let mut best_hi = self.range_max;
        for i in 1..=n_search {
            let pct = i as f32 / n_search as f32;
            let threshold = (pct * total as f32) as u64;
            let mut cum = 0_u64;
            let mut cut_bin = self.n_bins - 1;
            for (b, &cnt) in self.bins.iter().enumerate() {
                cum += cnt;
                if cum >= threshold {
                    cut_bin = b;
                    break;
                }
            }
            // Clip at the upper edge of the bin that reaches the percentile.
            let hi = self.range_min + (cut_bin as f32 + 1.0) * bw;
            let lo = self.range_min;
            let mse = self.estimate_mse(lo, hi);
            if mse < best_mse {
                best_mse = mse;
                best_lo = lo;
                best_hi = hi;
            }
        }
        Ok(asym_scale_zp(best_lo, best_hi, self.bits))
    }

    /// Clears all counts and the range, and marks the observer uncalibrated.
    fn reset(&mut self) {
        self.bins.fill(0);
        self.range_min = 0.0;
        self.range_max = 0.0;
        self.initialized = false;
    }

    /// True once at least one finite sample has been recorded.
    fn is_calibrated(&self) -> bool {
        self.initialized
    }
}

impl HistogramObserver {
    /// Re-buckets the existing counts for a histogram spanning
    /// `[new_min, new_max]`.
    ///
    /// Each old bin is moved as a whole to the new bin containing its centre,
    /// which is an approximation but keeps the total count intact. Only
    /// `bins` is replaced; the caller updates `range_min` / `range_max`
    /// afterwards, because the old values are needed here to locate the old
    /// bin centres.
    fn rebin_into(&mut self, new_min: f32, new_max: f32) {
        let old_min = self.range_min;
        let old_bw = self.bin_width();
        let new_bw = (new_max - new_min) / self.n_bins as f32;
        let last = self.n_bins - 1;

        let mut rebinned = vec![0_u64; self.n_bins];
        for (b, &cnt) in self.bins.iter().enumerate() {
            if cnt == 0 {
                continue;
            }
            let center = old_min + (b as f32 + 0.5) * old_bw;
            let idx = (((center - new_min) / new_bw) as usize).min(last);
            rebinned[idx] += cnt;
        }
        self.bins = rebinned;
    }
}
// Unit tests for the three observers in this file.
#[cfg(test)]
mod tests {
use super::*;
use approx::assert_abs_diff_eq;
// Symmetric 8-bit: the largest magnitude (2.0) maps onto 127.
#[test]
fn minmax_symmetric_scale() {
let mut obs = MinMaxObserver::new(8, true);
obs.observe(&[-2.0_f32, -1.0, 0.5, 2.0]);
let (scale, zp) = obs.compute_params().unwrap();
assert_abs_diff_eq!(scale, 2.0 / 127.0, epsilon = 1e-6);
assert_eq!(zp, 0);
}
// Asymmetric 8-bit over [0, 3]: 255 steps span the range and 0.0 sits at grid point 0.
#[test]
fn minmax_asymmetric_scale_zp() {
let mut obs = MinMaxObserver::new(8, false);
obs.observe(&[0.0_f32, 1.0, 2.0, 3.0]);
let (scale, zp) = obs.compute_params().unwrap();
assert_abs_diff_eq!(scale, 3.0 / 255.0, epsilon = 1e-5);
assert_eq!(zp, 0);
}
// Asking for parameters before any data was observed must fail.
#[test]
fn minmax_calibration_required() {
let obs = MinMaxObserver::new(8, true);
assert!(matches!(
obs.compute_params(),
Err(QuantError::CalibrationRequired(_))
));
}
// `reset` returns the observer to the uncalibrated state.
#[test]
fn minmax_reset() {
let mut obs = MinMaxObserver::new(8, true);
obs.observe(&[1.0_f32, 2.0]);
obs.reset();
assert!(!obs.is_calibrated());
}
// The first batch seeds the moving average directly, without blending.
#[test]
fn moving_avg_first_batch_exact() {
let mut obs = MovingAvgObserver::new(8, true, 0.9);
obs.observe(&[-1.0_f32, 1.0]);
let (scale, zp) = obs.compute_params().unwrap();
assert_abs_diff_eq!(scale, 1.0 / 127.0, epsilon = 1e-5);
assert_eq!(zp, 0);
}
// Second batch is blended in: 0.9 * 2.0 + 0.1 * 4.0 = 2.2.
#[test]
fn moving_avg_ema_update() {
let mut obs = MovingAvgObserver::new(8, true, 0.9);
obs.observe(&[2.0_f32, 2.0]); obs.observe(&[4.0_f32, 4.0]); assert_abs_diff_eq!(obs.max_val, 2.2, epsilon = 1e-5);
}
// Asking for parameters before any data was observed must fail.
#[test]
fn moving_avg_calibration_required() {
let obs = MovingAvgObserver::new(8, true, 0.9);
assert!(matches!(
obs.compute_params(),
Err(QuantError::CalibrationRequired(_))
));
}
// Smoke test: 1024 evenly spaced samples in [-1, 1) yield a positive scale
// and, in symmetric mode, a zero point of 0.
#[test]
fn histogram_observer_calibrates() {
let mut obs = HistogramObserver::new(8, true, 256);
let data: Vec<f32> = (0..1024).map(|i| (i as f32 / 512.0) - 1.0).collect();
obs.observe(&data);
assert!(obs.is_calibrated());
let (scale, zp) = obs.compute_params().unwrap();
assert!(scale > 0.0, "scale must be positive: {scale}");
assert_eq!(zp, 0, "symmetric: zp must be 0");
}
// `reset` returns the observer to the uncalibrated state.
#[test]
fn histogram_observer_reset() {
let mut obs = HistogramObserver::new(8, true, 128);
obs.observe(&[1.0_f32, 2.0]);
obs.reset();
assert!(!obs.is_calibrated());
}
// Asking for parameters before any data was observed must fail.
#[test]
fn histogram_observer_uncalibrated_error() {
let obs = HistogramObserver::new(8, true, 64);
assert!(matches!(
obs.compute_params(),
Err(QuantError::CalibrationRequired(_))
));
}
}