Skip to main content

batuta/oracle/rag/quantization/
calibration.rs

//! Calibration statistics for quantization.
//!
//! Following Kaizen, the statistics are refined continuously from the observed
//! query distribution. Welford's online algorithm is used for numerical stability
//! (Higham, 2002).
5
6use super::error::{validate_embedding, QuantizationError};
7use super::params::QuantizationParams;
8
/// Running statistics gathered from calibration embeddings, used to derive
/// quantization parameters.
///
/// Kaizen-style: the statistics keep improving as more of the query
/// distribution is observed. Per-dimension mean and spread are tracked with
/// Welford's online algorithm for numerical stability (Higham, 2002).
#[derive(Clone, Debug)]
pub struct CalibrationStats {
    /// Largest absolute component value observed in any calibration embedding.
    pub absmax: f32,
    /// Per-dimension running mean of the observed embeddings.
    pub mean: Vec<f32>,
    /// Per-dimension running sum of squared deviations from the mean
    /// (Welford's "M2" accumulator); divide by `n_samples - 1` for the
    /// sample variance.
    pub(crate) m2: Vec<f32>,
    /// How many embeddings have been folded into the statistics so far.
    pub n_samples: usize,
    /// Expected embedding length; `mean` and `m2` are created with this many
    /// entries.
    pub dims: usize,
}
26
27impl CalibrationStats {
28    /// Create new calibration stats for given dimensions
29    pub fn new(dims: usize) -> Self {
30        Self { absmax: 0.0, mean: vec![0.0; dims], m2: vec![0.0; dims], n_samples: 0, dims }
31    }
32
33    /// Update calibration with new embedding (Kaizen loop)
34    ///
35    /// Uses Welford's online algorithm for numerical stability.
36    pub fn update(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
37        // Validate dimensions and finite values upfront
38        validate_embedding(embedding, self.dims)?;
39
40        self.n_samples += 1;
41        let n = self.n_samples as f32;
42
43        for (i, &v) in embedding.iter().enumerate() {
44            // Update absmax
45            self.absmax = self.absmax.max(v.abs());
46
47            // Welford's algorithm for mean and variance
48            let delta = v - self.mean[i];
49            self.mean[i] += delta / n;
50            let delta2 = v - self.mean[i];
51            self.m2[i] += delta * delta2;
52        }
53
54        Ok(())
55    }
56
57    /// Update with batch of embeddings (Heijunka batching)
58    pub fn update_batch(&mut self, embeddings: &[Vec<f32>]) -> Result<(), QuantizationError> {
59        for embedding in embeddings {
60            self.update(embedding)?;
61        }
62        Ok(())
63    }
64
65    /// Get variance for dimension i
66    pub fn variance(&self, i: usize) -> f32 {
67        if self.n_samples < 2 || i >= self.dims {
68            return 0.0;
69        }
70        self.m2[i] / (self.n_samples - 1) as f32
71    }
72
73    /// Get standard deviation for dimension i
74    pub fn std_dev(&self, i: usize) -> f32 {
75        self.variance(i).sqrt()
76    }
77
78    /// Convert calibration to quantization parameters
79    pub fn to_quant_params(&self) -> Result<QuantizationParams, QuantizationError> {
80        if self.n_samples == 0 {
81            return Err(QuantizationError::CalibrationNotInitialized);
82        }
83        let absmax = if self.absmax == 0.0 { 1.0 } else { self.absmax };
84        QuantizationParams::from_absmax(absmax, self.dims)
85    }
86
87    /// Check if calibration has sufficient samples
88    pub fn is_sufficient(&self, min_samples: usize) -> bool {
89        self.n_samples >= min_samples
90    }
91}