batuta/oracle/rag/quantization/
calibration.rs1use super::error::{validate_embedding, QuantizationError};
7use super::params::QuantizationParams;
8
/// Running statistics gathered from sample embeddings for calibration.
///
/// The per-dimension `mean` / `m2` pair is maintained incrementally (one
/// sample at a time), so no sample needs to be retained after it is observed.
#[derive(Clone, Debug)]
pub struct CalibrationStats {
    /// Largest absolute component value observed so far, taken over every
    /// dimension of every sample.
    pub absmax: f32,
    /// Running mean of each dimension.
    pub mean: Vec<f32>,
    /// Running sum of squared deviations from the mean, per dimension
    /// (the `M2` accumulator of Welford's online variance algorithm).
    pub(crate) m2: Vec<f32>,
    /// Number of samples folded into the statistics.
    pub n_samples: usize,
    /// Number of components each embedding is expected to have.
    pub dims: usize,
}
26
27impl CalibrationStats {
28 pub fn new(dims: usize) -> Self {
30 Self { absmax: 0.0, mean: vec![0.0; dims], m2: vec![0.0; dims], n_samples: 0, dims }
31 }
32
33 pub fn update(&mut self, embedding: &[f32]) -> Result<(), QuantizationError> {
37 validate_embedding(embedding, self.dims)?;
39
40 self.n_samples += 1;
41 let n = self.n_samples as f32;
42
43 for (i, &v) in embedding.iter().enumerate() {
44 self.absmax = self.absmax.max(v.abs());
46
47 let delta = v - self.mean[i];
49 self.mean[i] += delta / n;
50 let delta2 = v - self.mean[i];
51 self.m2[i] += delta * delta2;
52 }
53
54 Ok(())
55 }
56
57 pub fn update_batch(&mut self, embeddings: &[Vec<f32>]) -> Result<(), QuantizationError> {
59 for embedding in embeddings {
60 self.update(embedding)?;
61 }
62 Ok(())
63 }
64
65 pub fn variance(&self, i: usize) -> f32 {
67 if self.n_samples < 2 || i >= self.dims {
68 return 0.0;
69 }
70 self.m2[i] / (self.n_samples - 1) as f32
71 }
72
73 pub fn std_dev(&self, i: usize) -> f32 {
75 self.variance(i).sqrt()
76 }
77
78 pub fn to_quant_params(&self) -> Result<QuantizationParams, QuantizationError> {
80 if self.n_samples == 0 {
81 return Err(QuantizationError::CalibrationNotInitialized);
82 }
83 let absmax = if self.absmax == 0.0 { 1.0 } else { self.absmax };
84 QuantizationParams::from_absmax(absmax, self.dims)
85 }
86
87 pub fn is_sufficient(&self, min_samples: usize) -> bool {
89 self.n_samples >= min_samples
90 }
91}