Skip to main content

axonml_quant/
calibration.rs

1//! Calibration for Quantization
2//!
3//! # File
4//! `crates/axonml-quant/src/calibration.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr - AutomataNexus
8//!
9//! # Updated
10//! March 8, 2026
11//!
12//! # Disclaimer
13//! Use at own risk. This software is provided "as is", without warranty of any
14//! kind, express or implied. The author and AutomataNexus shall not be held
15//! liable for any damages arising from the use of this software.
16
17use axonml_tensor::Tensor;
18
19use crate::error::{QuantError, QuantResult};
20use crate::types::QuantType;
21
22// =============================================================================
23// Calibration Data
24// =============================================================================
25
/// Summary statistics and a value histogram gathered from sample tensors.
///
/// The scalar statistics are public; the histogram and its bin edges are
/// private and are only read through [`CalibrationData::percentile`] and the
/// calibration routines in this module.
#[derive(Clone, Debug)]
pub struct CalibrationData {
    /// Smallest value observed so far.
    pub min: f32,
    /// Largest value observed so far.
    pub max: f32,
    /// Arithmetic mean of the observed values.
    pub mean: f32,
    /// Standard deviation of the observed values.
    pub std_dev: f32,
    /// How many values have been observed.
    pub num_samples: usize,
    /// Per-bin sample counts, used for percentile-based calibration.
    histogram: Vec<usize>,
    /// Boundaries of the histogram bins (one more entry than `histogram`).
    bin_edges: Vec<f32>,
}
44
45impl CalibrationData {
46    /// Creates new calibration data from initial tensor.
47    pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
48        let data = tensor.to_vec();
49        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
50        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
51        let mean = data.iter().sum::<f32>() / data.len() as f32;
52
53        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
54        let std_dev = variance.sqrt();
55
56        // Initialize histogram
57        let bin_width = (max - min) / num_bins as f32;
58        let mut histogram = vec![0usize; num_bins];
59        let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
60
61        for &val in &data {
62            let bin = ((val - min) / bin_width) as usize;
63            let bin = bin.min(num_bins - 1);
64            histogram[bin] += 1;
65        }
66
67        Self {
68            min,
69            max,
70            mean,
71            std_dev,
72            num_samples: data.len(),
73            histogram,
74            bin_edges,
75        }
76    }
77
78    /// Updates calibration data with more samples (streaming Welford algorithm).
79    pub fn update(&mut self, tensor: &Tensor<f32>) {
80        let data = tensor.to_vec();
81        let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
82        let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
83
84        // Update min/max
85        self.min = self.min.min(new_min);
86        self.max = self.max.max(new_max);
87
88        // Welford's online algorithm for mean and variance
89        let old_mean = self.mean;
90        let old_count = self.num_samples;
91        for &val in &data {
92            self.num_samples += 1;
93            let delta = val - self.mean;
94            self.mean += delta / self.num_samples as f32;
95        }
96
97        // Update std_dev from running variance
98        // Use combined variance formula: Var(A∪B) from count, mean, var of each
99        if old_count > 0 && !data.is_empty() {
100            let new_mean_batch: f32 = data.iter().sum::<f32>() / data.len() as f32;
101            let new_var_batch: f32 = data
102                .iter()
103                .map(|&v| (v - new_mean_batch).powi(2))
104                .sum::<f32>()
105                / data.len() as f32;
106            let old_var = self.std_dev * self.std_dev;
107            let n1 = old_count as f32;
108            let n2 = data.len() as f32;
109            let combined_var = (n1 * old_var
110                + n2 * new_var_batch
111                + n1 * n2 / (n1 + n2) * (old_mean - new_mean_batch).powi(2))
112                / (n1 + n2);
113            self.std_dev = combined_var.sqrt();
114        } else if !data.is_empty() {
115            let m: f32 = data.iter().sum::<f32>() / data.len() as f32;
116            self.std_dev =
117                (data.iter().map(|&v| (v - m).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
118            self.num_samples = data.len();
119        }
120
121        // Update histogram bins with new data
122        if !self.histogram.is_empty() && self.max > self.min {
123            let n_bins = self.histogram.len();
124            let bin_width = (self.max - self.min) / n_bins as f32;
125            for &val in &data {
126                let bin = ((val - self.min) / bin_width).floor() as usize;
127                let bin = bin.min(n_bins - 1);
128                self.histogram[bin] += 1;
129            }
130        }
131    }
132
133    /// Returns the dynamic range.
134    pub fn dynamic_range(&self) -> f32 {
135        self.max - self.min
136    }
137
138    /// Computes the optimal scale for symmetric quantization.
139    pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
140        let max_abs = self.min.abs().max(self.max.abs());
141        let max_int = match quant_type {
142            QuantType::Q8_0 => 127.0,
143            QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
144            QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
145            QuantType::F16 | QuantType::F32 => 1.0,
146        };
147        max_abs / max_int
148    }
149
150    /// Computes the optimal scale for asymmetric quantization.
151    pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
152        let max_int = match quant_type {
153            QuantType::Q8_0 => 255.0,
154            QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
155            QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
156            QuantType::F16 | QuantType::F32 => 1.0,
157        };
158
159        let scale = (self.max - self.min) / max_int;
160        let zero_point = -self.min / scale;
161
162        (scale, zero_point)
163    }
164
165    /// Returns the percentile value from the histogram.
166    pub fn percentile(&self, p: f32) -> f32 {
167        if p <= 0.0 {
168            return self.min;
169        }
170        if p >= 100.0 {
171            return self.max;
172        }
173
174        let target = (p / 100.0 * self.num_samples as f32) as usize;
175        let mut cumsum = 0usize;
176
177        for (i, &count) in self.histogram.iter().enumerate() {
178            cumsum += count;
179            if cumsum >= target {
180                return self.bin_edges[i];
181            }
182        }
183
184        self.max
185    }
186}
187
188// =============================================================================
189// Calibration Methods
190// =============================================================================
191
/// Strategy used to choose the quantization range.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Take the observed minimum and maximum as the range.
    MinMax,
    /// Clip to a percentile to limit the influence of outliers.
    ///
    /// The payload is the percentile multiplied by 10 (e.g. `999` = 99.9%).
    Percentile(u32),
    /// Entropy-based calibration (KL divergence).
    Entropy,
    /// Use `mean ± k * std_dev` as the range.
    ///
    /// The payload is `k` multiplied by 10 (e.g. `30` = 3.0 sigma).
    MeanStd(u32),
}
204
205/// Calibrates a tensor for quantization.
206///
207/// # Arguments
208/// * `tensor` - The tensor to calibrate
209/// * `method` - The calibration method to use
210///
211/// # Returns
212/// Calibration data for the tensor
213pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
214    let mut data = CalibrationData::new(tensor, 2048);
215
216    match method {
217        CalibrationMethod::MinMax => {
218            // Already computed in new()
219        }
220        CalibrationMethod::Percentile(p) => {
221            let percentile = p as f32 / 10.0;
222            let lower = data.percentile(100.0 - percentile);
223            let upper = data.percentile(percentile);
224            data.min = lower;
225            data.max = upper;
226        }
227        CalibrationMethod::MeanStd(k) => {
228            let k_factor = k as f32 / 10.0;
229            data.min = data.mean - k_factor * data.std_dev;
230            data.max = data.mean + k_factor * data.std_dev;
231        }
232        CalibrationMethod::Entropy => {
233            // KL-divergence calibration (TensorRT-style):
234            // Find the clipping threshold that minimizes KL divergence
235            // between the reference distribution and quantized distribution.
236            let n_bins = data.histogram.len();
237            if n_bins < 4 {
238                // Fallback if histogram is too small
239                data.min = data.percentile(0.01);
240                data.max = data.percentile(99.99);
241            } else {
242                let total: usize = data.histogram.iter().sum();
243                if total == 0 {
244                    data.min = data.percentile(0.01);
245                    data.max = data.percentile(99.99);
246                } else {
247                    // Normalize histogram to probability distribution
248                    let ref_dist: Vec<f64> = data
249                        .histogram
250                        .iter()
251                        .map(|&c| c as f64 / total as f64 + 1e-12)
252                        .collect();
253
254                    let quant_bins = 128usize; // INT8 bins
255                    let mut best_kl = f64::MAX;
256                    let mut best_threshold = n_bins;
257
258                    // Try different thresholds (from half to full range)
259                    for threshold in (n_bins / 2)..n_bins {
260                        // Clip reference distribution at threshold
261                        let mut clipped = ref_dist[..threshold].to_vec();
262                        // Add outlier mass to last bin
263                        let outlier_mass: f64 = ref_dist[threshold..].iter().sum();
264                        if let Some(last) = clipped.last_mut() {
265                            *last += outlier_mass;
266                        }
267
268                        // Quantize: merge bins
269                        let bins_per_quant = threshold.div_ceil(quant_bins);
270                        let mut quant_dist = vec![0.0f64; quant_bins.min(threshold)];
271                        for (i, &p) in clipped.iter().enumerate() {
272                            let qi = (i / bins_per_quant).min(quant_dist.len() - 1);
273                            quant_dist[qi] += p;
274                        }
275
276                        // Expand back to original bins
277                        let mut expanded = vec![0.0f64; threshold];
278                        for (qi, &qval) in quant_dist.iter().enumerate() {
279                            let start = qi * bins_per_quant;
280                            let end = ((qi + 1) * bins_per_quant).min(threshold);
281                            let count = (end - start) as f64;
282                            if count > 0.0 {
283                                let val = qval / count;
284                                for slot in expanded.iter_mut().take(end).skip(start) {
285                                    *slot = val + 1e-12;
286                                }
287                            }
288                        }
289
290                        // KL divergence: sum(P * log(P/Q))
291                        let kl: f64 = clipped
292                            .iter()
293                            .zip(expanded.iter())
294                            .map(|(&p, &q)| if p > 1e-12 { p * (p / q).ln() } else { 0.0 })
295                            .sum();
296
297                        if kl < best_kl {
298                            best_kl = kl;
299                            best_threshold = threshold;
300                        }
301                    }
302
303                    // Convert threshold back to value range
304                    let bin_width = (data.max - data.min) / n_bins as f32;
305                    let clip_max = data.min + best_threshold as f32 * bin_width;
306                    data.max = clip_max;
307                    // Symmetric: clip min symmetrically if data spans zero
308                    if data.min < 0.0 && data.max > 0.0 {
309                        let abs_max = data.max.abs().max(data.min.abs());
310                        data.min = -abs_max;
311                        data.max = abs_max;
312                    }
313                }
314            }
315        }
316    }
317
318    Ok(data)
319}
320
321/// Calibrates multiple tensors and returns combined calibration data.
322pub fn calibrate_batch(
323    tensors: &[&Tensor<f32>],
324    method: CalibrationMethod,
325) -> QuantResult<CalibrationData> {
326    if tensors.is_empty() {
327        return Err(QuantError::CalibrationError(
328            "No tensors provided".to_string(),
329        ));
330    }
331
332    let mut combined = CalibrationData::new(tensors[0], 2048);
333
334    for tensor in tensors.iter().skip(1) {
335        combined.update(tensor);
336    }
337
338    // Apply method-specific adjustments
339    match method {
340        CalibrationMethod::Percentile(p) => {
341            let percentile = p as f32 / 10.0;
342            combined.min = combined.percentile(100.0 - percentile);
343            combined.max = combined.percentile(percentile);
344        }
345        CalibrationMethod::MeanStd(k) => {
346            let k_factor = k as f32 / 10.0;
347            combined.min = combined.mean - k_factor * combined.std_dev;
348            combined.max = combined.mean + k_factor * combined.std_dev;
349        }
350        _ => {}
351    }
352
353    Ok(combined)
354}
355
356// =============================================================================
357// Tests
358// =============================================================================
359
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic statistics of a small ascending sample.
    #[test]
    fn test_calibration_data() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let stats = CalibrationData::new(&tensor, 10);

        assert_eq!(stats.num_samples, 5);
        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.mean, 3.0);
    }

    /// Symmetric Q8_0 scale is the largest magnitude divided by 127.
    #[test]
    fn test_symmetric_scale() {
        let tensor = Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap();
        let stats = CalibrationData::new(&tensor, 10);

        // max_abs = 4.0, max_int = 127, scale = 4/127
        let expected = 4.0 / 127.0;
        let actual = stats.symmetric_scale(QuantType::Q8_0);
        assert!((actual - expected).abs() < 0.001);
    }

    /// Min/max and percentile calibration on a uniform ramp over [0, 9.99].
    #[test]
    fn test_calibration_methods() {
        let ramp: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(ramp, &[1000]).unwrap();

        // Min/Max calibration
        let by_minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!(by_minmax.min.abs() < 0.01);
        assert!((by_minmax.max - 9.99).abs() < 0.01);

        // Percentile calibration (99.9%)
        let by_percentile = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(by_percentile.min >= 0.0);
        assert!(by_percentile.max <= 9.99);
    }

    /// Dynamic range is simply max minus min.
    #[test]
    fn test_dynamic_range() {
        let tensor = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();
        let stats = CalibrationData::new(&tensor, 10);

        assert_eq!(stats.dynamic_range(), 15.0);
    }
}
413}