use super::error::{validate_embedding, QuantizationError};
use super::params::QuantizationParams;
/// Running calibration statistics gathered over a stream of embeddings,
/// used to derive quantization parameters.
///
/// Per-dimension mean/variance are maintained with Welford's online
/// algorithm (`mean` + `m2`), alongside a global absolute-maximum value.
#[derive(Debug, Clone)]
pub struct CalibrationStats {
    // Largest |value| seen across all samples and all dimensions.
    pub absmax: f32,
    // Running per-dimension mean (Welford).
    pub mean: Vec<f32>,
    // Welford M2 accumulator per dimension: sum of squared deviations
    // from the current mean; variance = m2[i] / (n_samples - 1).
    pub(crate) m2: Vec<f32>,
    // Number of embeddings folded in so far.
    pub n_samples: usize,
    // Expected embedding dimensionality; `mean` and `m2` have this length.
    pub dims: usize,
}
impl CalibrationStats {
pub fn new(dims: usize) -> Self {
Self { absmax: 0.0, mean: vec![0.0; dims], m2: vec![0.0; dims], n_samples: 0, dims }
}
pub fn update(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
validate_embedding(embedding, self.dims)?;
self.n_samples += 1;
let n = self.n_samples as f32;
for (i, &v) in embedding.iter().enumerate() {
self.absmax = self.absmax.max(v.abs());
let delta = v - self.mean[i];
self.mean[i] += delta / n;
let delta2 = v - self.mean[i];
self.m2[i] += delta * delta2;
}
Ok(())
}
pub fn update_batch(&mut self, embeddings: &[Vec<f32>]) -> Result<(), QuantizationError> {
for embedding in embeddings {
self.update(embedding)?;
}
Ok(())
}
pub fn variance(&self, i: usize) -> f32 {
if self.n_samples < 2 || i >= self.dims {
return 0.0;
}
self.m2[i] / (self.n_samples - 1) as f32
}
pub fn std_dev(&self, i: usize) -> f32 {
self.variance(i).sqrt()
}
pub fn to_quant_params(&self) -> Result<QuantizationParams, QuantizationError> {
if self.n_samples == 0 {
return Err(QuantizationError::CalibrationNotInitialized);
}
let absmax = if self.absmax == 0.0 { 1.0 } else { self.absmax };
QuantizationParams::from_absmax(absmax, self.dims)
}
pub fn is_sufficient(&self, min_samples: usize) -> bool {
self.n_samples >= min_samples
}
}