//! Calibration for Quantization
//!
//! # File
//! `crates/axonml-quant/src/calibration.rs`
//!
//! # Author
//! Andrew Jewell Sr - AutomataNexus
//!
//! # Updated
//! March 8, 2026
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

17use axonml_tensor::Tensor;
18
19use crate::error::{QuantError, QuantResult};
20use crate::types::QuantType;

// =============================================================================
// Calibration Data
// =============================================================================

/// Summary statistics and a value histogram gathered from sample tensors,
/// used to choose quantization ranges and scales.
///
/// The histogram fields are private: `bin_edges` holds one more entry than
/// `histogram`, and bucket `i` lies between `bin_edges[i]` and `bin_edges[i + 1]`.
#[derive(Clone, Debug)]
pub struct CalibrationData {
    /// Smallest value observed, or the narrowed lower bound once a
    /// calibration method has been applied.
    pub min: f32,
    /// Largest value observed, or the narrowed upper bound once a
    /// calibration method has been applied.
    pub max: f32,
    /// Arithmetic mean of the observed values.
    pub mean: f32,
    /// Population standard deviation (squared deviations divided by the count).
    pub std_dev: f32,
    /// How many individual values have been observed.
    pub num_samples: usize,
    /// Per-bucket value counts, consumed by percentile-based calibration.
    histogram: Vec<usize>,
    /// Bucket boundaries for `histogram` (length `histogram.len() + 1`).
    bin_edges: Vec<f32>,
}

45impl CalibrationData {
46    /// Creates new calibration data from initial tensor.
47    pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
48        let data = tensor.to_vec();
49        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
50        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
51        let mean = data.iter().sum::<f32>() / data.len() as f32;
52
53        let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
54        let std_dev = variance.sqrt();
55
56        // Initialize histogram
57        let bin_width = (max - min) / num_bins as f32;
58        let mut histogram = vec![0usize; num_bins];
59        let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
60
61        for &val in &data {
62            let bin = ((val - min) / bin_width) as usize;
63            let bin = bin.min(num_bins - 1);
64            histogram[bin] += 1;
65        }
66
67        Self {
68            min,
69            max,
70            mean,
71            std_dev,
72            num_samples: data.len(),
73            histogram,
74            bin_edges,
75        }
76    }
77
78    /// Updates calibration data with more samples.
79    pub fn update(&mut self, tensor: &Tensor<f32>) {
80        let data = tensor.to_vec();
81        let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
82        let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
83
84        // Update min/max
85        self.min = self.min.min(new_min);
86        self.max = self.max.max(new_max);
87
88        // Update running mean
89        let old_count = self.num_samples as f32;
90        let new_count = data.len() as f32;
91        let new_mean = data.iter().sum::<f32>() / new_count;
92        self.mean = (self.mean * old_count + new_mean * new_count) / (old_count + new_count);
93
94        // Update histogram (rebuild with new range)
95        self.num_samples += data.len();
96        // Note: For proper histogram update, we'd need to keep all data or use streaming algorithms
97    }
98
99    /// Returns the dynamic range.
100    pub fn dynamic_range(&self) -> f32 {
101        self.max - self.min
102    }
103
104    /// Computes the optimal scale for symmetric quantization.
105    pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
106        let max_abs = self.min.abs().max(self.max.abs());
107        let max_int = match quant_type {
108            QuantType::Q8_0 => 127.0,
109            QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
110            QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
111            QuantType::F16 | QuantType::F32 => 1.0,
112        };
113        max_abs / max_int
114    }
115
116    /// Computes the optimal scale for asymmetric quantization.
117    pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
118        let max_int = match quant_type {
119            QuantType::Q8_0 => 255.0,
120            QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
121            QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
122            QuantType::F16 | QuantType::F32 => 1.0,
123        };
124
125        let scale = (self.max - self.min) / max_int;
126        let zero_point = -self.min / scale;
127
128        (scale, zero_point)
129    }
130
131    /// Returns the percentile value from the histogram.
132    pub fn percentile(&self, p: f32) -> f32 {
133        if p <= 0.0 {
134            return self.min;
135        }
136        if p >= 100.0 {
137            return self.max;
138        }
139
140        let target = (p / 100.0 * self.num_samples as f32) as usize;
141        let mut cumsum = 0usize;
142
143        for (i, &count) in self.histogram.iter().enumerate() {
144            cumsum += count;
145            if cumsum >= target {
146                return self.bin_edges[i];
147            }
148        }
149
150        self.max
151    }
152}

// =============================================================================
// Calibration Methods
// =============================================================================

/// Strategy for choosing the quantization range from calibration statistics.
///
/// Numeric payloads are integers scaled by 10 (an `f32` payload would prevent
/// deriving `Eq`).
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationMethod {
    /// Use the observed minimum and maximum unchanged.
    MinMax,
    /// Clip to a symmetric pair of percentiles to reduce outlier impact.
    /// The payload is the percentile times 10, e.g. `999` means 99.9%.
    Percentile(u32),
    /// Entropy-based calibration (KL divergence). `calibrate` currently
    /// approximates this by clipping to the 0.01th/99.99th percentiles.
    Entropy,
    /// Use `mean ± k * std_dev` as the range.
    /// The payload is `k` times 10, e.g. `30` means 3.0 sigma.
    MeanStd(u32),
}

171/// Calibrates a tensor for quantization.
172///
173/// # Arguments
174/// * `tensor` - The tensor to calibrate
175/// * `method` - The calibration method to use
176///
177/// # Returns
178/// Calibration data for the tensor
179pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
180    let mut data = CalibrationData::new(tensor, 2048);
181
182    match method {
183        CalibrationMethod::MinMax => {
184            // Already computed in new()
185        }
186        CalibrationMethod::Percentile(p) => {
187            let percentile = p as f32 / 10.0;
188            let lower = data.percentile(100.0 - percentile);
189            let upper = data.percentile(percentile);
190            data.min = lower;
191            data.max = upper;
192        }
193        CalibrationMethod::MeanStd(k) => {
194            let k_factor = k as f32 / 10.0;
195            data.min = data.mean - k_factor * data.std_dev;
196            data.max = data.mean + k_factor * data.std_dev;
197        }
198        CalibrationMethod::Entropy => {
199            // Simplified entropy calibration - use 99.99th percentile
200            data.min = data.percentile(0.01);
201            data.max = data.percentile(99.99);
202        }
203    }
204
205    Ok(data)
206}
207
208/// Calibrates multiple tensors and returns combined calibration data.
209pub fn calibrate_batch(
210    tensors: &[&Tensor<f32>],
211    method: CalibrationMethod,
212) -> QuantResult<CalibrationData> {
213    if tensors.is_empty() {
214        return Err(QuantError::CalibrationError(
215            "No tensors provided".to_string(),
216        ));
217    }
218
219    let mut combined = CalibrationData::new(tensors[0], 2048);
220
221    for tensor in tensors.iter().skip(1) {
222        combined.update(tensor);
223    }
224
225    // Apply method-specific adjustments
226    match method {
227        CalibrationMethod::Percentile(p) => {
228            let percentile = p as f32 / 10.0;
229            combined.min = combined.percentile(100.0 - percentile);
230            combined.max = combined.percentile(percentile);
231        }
232        CalibrationMethod::MeanStd(k) => {
233            let k_factor = k as f32 / 10.0;
234            combined.min = combined.mean - k_factor * combined.std_dev;
235            combined.max = combined.mean + k_factor * combined.std_dev;
236        }
237        _ => {}
238    }
239
240    Ok(combined)
241}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calibration_data() {
        let calib = CalibrationData::new(
            &Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap(),
            10,
        );

        assert_eq!(calib.min, 1.0);
        assert_eq!(calib.max, 5.0);
        assert_eq!(calib.mean, 3.0);
        assert_eq!(calib.num_samples, 5);
    }

    #[test]
    fn test_symmetric_scale() {
        let calib = CalibrationData::new(
            &Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap(),
            10,
        );

        // Largest magnitude is 4.0 and Q8_0 maps it onto 127.
        let expected = 4.0 / 127.0;
        let actual = calib.symmetric_scale(QuantType::Q8_0);
        assert!((actual - expected).abs() < 0.001);
    }

    #[test]
    fn test_calibration_methods() {
        let values: Vec<f32> = (0..1000).map(|i| i as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(values, &[1000]).unwrap();

        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!(minmax.min.abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        // 999 encodes the 99.9th percentile.
        let clipped = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    #[test]
    fn test_dynamic_range() {
        let calib = CalibrationData::new(&Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap(), 10);

        assert_eq!(calib.dynamic_range(), 15.0);
    }
}