// axonml_quant — crates/axonml-quant/src/calibration.rs
1//! Calibration for Quantization
2//!
3//! # File
4//! `crates/axonml-quant/src/calibration.rs`
5//!
6//! # Author
7//! Andrew Jewell Sr. — AutomataNexus LLC
8//! ORCID: 0009-0005-2158-7060
9//!
10//! # Updated
11//! April 14, 2026 11:15 PM EST
12//!
13//! # Disclaimer
14//! Use at own risk. This software is provided "as is", without warranty of any
15//! kind, express or implied. The author and AutomataNexus shall not be held
16//! liable for any damages arising from the use of this software.
17
18use axonml_tensor::Tensor;
19
20use crate::error::{QuantError, QuantResult};
21use crate::types::QuantType;
22
23// =============================================================================
24// Calibration Data
25// =============================================================================
26
/// Calibration data collected from sample inputs.
///
/// Tracks running min/max, mean, and (population) standard deviation, plus a
/// fixed-width histogram over `[min, max]` used for percentile-based range
/// selection.
#[derive(Debug, Clone)]
pub struct CalibrationData {
    /// Minimum value seen across all observed samples.
    pub min: f32,
    /// Maximum value seen across all observed samples.
    pub max: f32,
    /// Running mean of all observed values.
    pub mean: f32,
    /// Running population standard deviation of all observed values.
    pub std_dev: f32,
    /// Total number of scalar samples observed.
    pub num_samples: usize,
    /// Histogram bucket counts (for percentile calibration).
    histogram: Vec<usize>,
    /// Histogram bin edges; `bin_edges.len() == histogram.len() + 1`.
    bin_edges: Vec<f32>,
}
45
46impl CalibrationData {
47    /// Creates new calibration data from initial tensor.
48    pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
49        let data = tensor.to_vec();
50        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
51        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
52        let mean = data.iter().sum::<f32>() / data.len() as f32;
53
54        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
55        let std_dev = variance.sqrt();
56
57        // Initialize histogram
58        let bin_width = (max - min) / num_bins as f32;
59        let mut histogram = vec![0usize; num_bins];
60        let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
61
62        for &val in &data {
63            let bin = ((val - min) / bin_width) as usize;
64            let bin = bin.min(num_bins - 1);
65            histogram[bin] += 1;
66        }
67
68        Self {
69            min,
70            max,
71            mean,
72            std_dev,
73            num_samples: data.len(),
74            histogram,
75            bin_edges,
76        }
77    }
78
79    /// Updates calibration data with more samples (streaming Welford algorithm).
80    pub fn update(&mut self, tensor: &Tensor<f32>) {
81        let data = tensor.to_vec();
82        let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
83        let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
84
85        // Update min/max
86        self.min = self.min.min(new_min);
87        self.max = self.max.max(new_max);
88
89        // Welford's online algorithm for mean and variance
90        let old_mean = self.mean;
91        let old_count = self.num_samples;
92        for &val in &data {
93            self.num_samples += 1;
94            let delta = val - self.mean;
95            self.mean += delta / self.num_samples as f32;
96        }
97
98        // Update std_dev from running variance
99        // Use combined variance formula: Var(A∪B) from count, mean, var of each
100        if old_count > 0 && !data.is_empty() {
101            let new_mean_batch: f32 = data.iter().sum::<f32>() / data.len() as f32;
102            let new_var_batch: f32 = data
103                .iter()
104                .map(|&v| (v - new_mean_batch).powi(2))
105                .sum::<f32>()
106                / data.len() as f32;
107            let old_var = self.std_dev * self.std_dev;
108            let n1 = old_count as f32;
109            let n2 = data.len() as f32;
110            let combined_var = (n1 * old_var
111                + n2 * new_var_batch
112                + n1 * n2 / (n1 + n2) * (old_mean - new_mean_batch).powi(2))
113                / (n1 + n2);
114            self.std_dev = combined_var.sqrt();
115        } else if !data.is_empty() {
116            let m: f32 = data.iter().sum::<f32>() / data.len() as f32;
117            self.std_dev =
118                (data.iter().map(|&v| (v - m).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
119            self.num_samples = data.len();
120        }
121
122        // Update histogram bins with new data
123        if !self.histogram.is_empty() && self.max > self.min {
124            let n_bins = self.histogram.len();
125            let bin_width = (self.max - self.min) / n_bins as f32;
126            for &val in &data {
127                let bin = ((val - self.min) / bin_width).floor() as usize;
128                let bin = bin.min(n_bins - 1);
129                self.histogram[bin] += 1;
130            }
131        }
132    }
133
134    /// Returns the dynamic range.
135    pub fn dynamic_range(&self) -> f32 {
136        self.max - self.min
137    }
138
139    /// Computes the optimal scale for symmetric quantization.
140    pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
141        let max_abs = self.min.abs().max(self.max.abs());
142        let max_int = match quant_type {
143            QuantType::Q8_0 => 127.0,
144            QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
145            QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
146            QuantType::F16 | QuantType::F32 => 1.0,
147        };
148        max_abs / max_int
149    }
150
151    /// Computes the optimal scale for asymmetric quantization.
152    pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
153        let max_int = match quant_type {
154            QuantType::Q8_0 => 255.0,
155            QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
156            QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
157            QuantType::F16 | QuantType::F32 => 1.0,
158        };
159
160        let scale = (self.max - self.min) / max_int;
161        let zero_point = -self.min / scale;
162
163        (scale, zero_point)
164    }
165
166    /// Returns the percentile value from the histogram.
167    pub fn percentile(&self, p: f32) -> f32 {
168        if p <= 0.0 {
169            return self.min;
170        }
171        if p >= 100.0 {
172            return self.max;
173        }
174
175        let target = (p / 100.0 * self.num_samples as f32) as usize;
176        let mut cumsum = 0usize;
177
178        for (i, &count) in self.histogram.iter().enumerate() {
179            cumsum += count;
180            if cumsum >= target {
181                return self.bin_edges[i];
182            }
183        }
184
185        self.max
186    }
187}
188
189// =============================================================================
190// Calibration Methods
191// =============================================================================
192
/// Calibration method enumeration.
///
/// Selects how the quantization range `[min, max]` is derived from the
/// collected statistics.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Use min/max values directly.
    MinMax,
    /// Use percentiles (e.g., 99.9th) to reduce outlier impact.
    /// Encoded as percentile * 10 (e.g., 999 = 99.9%).
    Percentile(u32),
    /// Use entropy-based calibration (KL divergence).
    Entropy,
    /// Use mean ± k*std_dev for the range.
    /// Encoded as k * 10 (e.g., 30 = 3.0 sigma).
    MeanStd(u32),
}
205
/// Calibrates a tensor for quantization.
///
/// Collects statistics over `tensor` (with a 2048-bucket histogram), then
/// narrows `[min, max]` according to `method`. `MinMax` keeps the raw range;
/// `Percentile`/`MeanStd` clip both tails; `Entropy` searches for the clip
/// threshold minimizing KL divergence against an INT8-bucketed distribution.
///
/// # Arguments
/// * `tensor` - The tensor to calibrate
/// * `method` - The calibration method to use
///
/// # Returns
/// Calibration data for the tensor, with `min`/`max` adjusted per `method`.
pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
    // 2048 reference bins = 16 per INT8 bucket at full range, enough
    // resolution for the entropy threshold search below.
    let mut data = CalibrationData::new(tensor, 2048);

    match method {
        CalibrationMethod::MinMax => {
            // Already computed in new()
        }
        CalibrationMethod::Percentile(p) => {
            // p is encoded as percentile * 10 (999 => 99.9%).
            let percentile = p as f32 / 10.0;
            // Clip both tails: lower at (100 - p)%, upper at p%.
            let lower = data.percentile(100.0 - percentile);
            let upper = data.percentile(percentile);
            data.min = lower;
            data.max = upper;
        }
        CalibrationMethod::MeanStd(k) => {
            // k is encoded as sigma * 10 (30 => 3.0 sigma).
            let k_factor = k as f32 / 10.0;
            data.min = data.mean - k_factor * data.std_dev;
            data.max = data.mean + k_factor * data.std_dev;
        }
        CalibrationMethod::Entropy => {
            // KL-divergence calibration (TensorRT-style):
            // Find the clipping threshold that minimizes KL divergence
            // between the reference distribution and quantized distribution.
            let n_bins = data.histogram.len();
            if n_bins < 4 {
                // Fallback if histogram is too small for a meaningful search.
                data.min = data.percentile(0.01);
                data.max = data.percentile(99.99);
            } else {
                let total: usize = data.histogram.iter().sum();
                if total == 0 {
                    // No samples landed in the histogram; fall back as well.
                    data.min = data.percentile(0.01);
                    data.max = data.percentile(99.99);
                } else {
                    // Normalize histogram to a probability distribution.
                    // The 1e-12 floor keeps the log-ratios finite below.
                    let ref_dist: Vec<f64> = data
                        .histogram
                        .iter()
                        .map(|&c| c as f64 / total as f64 + 1e-12)
                        .collect();

                    let quant_bins = 128usize; // INT8 bins
                    let mut best_kl = f64::MAX;
                    let mut best_threshold = n_bins;

                    // Try different thresholds (from half to full range)
                    for threshold in (n_bins / 2)..n_bins {
                        // Clip reference distribution at threshold
                        let mut clipped = ref_dist[..threshold].to_vec();
                        // Add outlier mass to last bin
                        let outlier_mass: f64 = ref_dist[threshold..].iter().sum();
                        if let Some(last) = clipped.last_mut() {
                            *last += outlier_mass;
                        }

                        // Quantize: merge runs of reference bins into at most
                        // `quant_bins` coarse buckets.
                        let bins_per_quant = threshold.div_ceil(quant_bins);
                        let mut quant_dist = vec![0.0f64; quant_bins.min(threshold)];
                        for (i, &p) in clipped.iter().enumerate() {
                            let qi = (i / bins_per_quant).min(quant_dist.len() - 1);
                            quant_dist[qi] += p;
                        }

                        // Expand back to reference resolution by spreading each
                        // coarse bucket's mass uniformly over its member bins.
                        let mut expanded = vec![0.0f64; threshold];
                        for (qi, &qval) in quant_dist.iter().enumerate() {
                            let start = qi * bins_per_quant;
                            let end = ((qi + 1) * bins_per_quant).min(threshold);
                            let count = (end - start) as f64;
                            if count > 0.0 {
                                let val = qval / count;
                                for slot in expanded.iter_mut().take(end).skip(start) {
                                    *slot = val + 1e-12;
                                }
                            }
                        }

                        // KL divergence: sum(P * log(P/Q))
                        let kl: f64 = clipped
                            .iter()
                            .zip(expanded.iter())
                            .map(|(&p, &q)| if p > 1e-12 { p * (p / q).ln() } else { 0.0 })
                            .sum();

                        if kl < best_kl {
                            best_kl = kl;
                            best_threshold = threshold;
                        }
                    }

                    // Convert the winning threshold (a bin index) back to a
                    // clip value in the original range.
                    let bin_width = (data.max - data.min) / n_bins as f32;
                    let clip_max = data.min + best_threshold as f32 * bin_width;
                    data.max = clip_max;
                    // Symmetric: clip min symmetrically if data spans zero
                    if data.min < 0.0 && data.max > 0.0 {
                        let abs_max = data.max.abs().max(data.min.abs());
                        data.min = -abs_max;
                        data.max = abs_max;
                    }
                }
            }
        }
    }

    Ok(data)
}
321
322/// Calibrates multiple tensors and returns combined calibration data.
323pub fn calibrate_batch(
324    tensors: &[&Tensor<f32>],
325    method: CalibrationMethod,
326) -> QuantResult<CalibrationData> {
327    if tensors.is_empty() {
328        return Err(QuantError::CalibrationError(
329            "No tensors provided".to_string(),
330        ));
331    }
332
333    let mut combined = CalibrationData::new(tensors[0], 2048);
334
335    for tensor in tensors.iter().skip(1) {
336        combined.update(tensor);
337    }
338
339    // Apply method-specific adjustments
340    match method {
341        CalibrationMethod::Percentile(p) => {
342            let percentile = p as f32 / 10.0;
343            combined.min = combined.percentile(100.0 - percentile);
344            combined.max = combined.percentile(percentile);
345        }
346        CalibrationMethod::MeanStd(k) => {
347            let k_factor = k as f32 / 10.0;
348            combined.min = combined.mean - k_factor * combined.std_dev;
349            combined.max = combined.mean + k_factor * combined.std_dev;
350        }
351        _ => {}
352    }
353
354    Ok(combined)
355}
356
357// =============================================================================
358// Tests
359// =============================================================================
360
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic statistics from a small known sample.
    #[test]
    fn test_calibration_data() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let calib = CalibrationData::new(&tensor, 10);

        assert_eq!((calib.min, calib.max), (1.0, 5.0));
        assert_eq!(calib.mean, 3.0);
        assert_eq!(calib.num_samples, 5);
    }

    /// Symmetric scale maps the largest magnitude onto the int ceiling.
    #[test]
    fn test_symmetric_scale() {
        let values = vec![-4.0, -2.0, 0.0, 2.0, 4.0];
        let tensor = Tensor::from_vec(values, &[5]).unwrap();
        let scale = CalibrationData::new(&tensor, 10).symmetric_scale(QuantType::Q8_0);

        // max_abs = 4.0, max_int = 127 => scale = 4/127.
        let expected = 4.0 / 127.0;
        assert!((scale - expected).abs() < 0.001);
    }

    /// Min/max and percentile calibration on a uniform ramp of 1000 values.
    #[test]
    fn test_calibration_methods() {
        let ramp: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(ramp, &[1000]).unwrap();

        // Min/Max keeps the raw observed range [0.0, 9.99].
        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!(minmax.min.abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        // 99.9th percentile clipping stays inside the raw range.
        let clipped = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    /// Dynamic range is simply max - min.
    #[test]
    fn test_dynamic_range() {
        let tensor = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();
        assert_eq!(CalibrationData::new(&tensor, 10).dynamic_range(), 15.0);
    }
}