Skip to main content

axonml_quant/
calibration.rs

1//! Calibration for Quantization
2//!
3//! Calibration methods for determining optimal quantization parameters.
4//!
5//! @version 0.1.0
6//! @author AutomataNexus Development Team
7
8use axonml_tensor::Tensor;
9
10use crate::error::{QuantError, QuantResult};
11use crate::types::QuantType;
12
13// =============================================================================
14// Calibration Data
15// =============================================================================
16
17/// Calibration data collected from sample inputs.
18#[derive(Debug, Clone)]
19pub struct CalibrationData {
20    /// Minimum value seen.
21    pub min: f32,
22    /// Maximum value seen.
23    pub max: f32,
24    /// Mean value.
25    pub mean: f32,
26    /// Standard deviation.
27    pub std_dev: f32,
28    /// Number of samples.
29    pub num_samples: usize,
30    /// Histogram buckets (for percentile calibration).
31    histogram: Vec<usize>,
32    /// Histogram bin edges.
33    bin_edges: Vec<f32>,
34}
35
36impl CalibrationData {
37    /// Creates new calibration data from initial tensor.
38    pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
39        let data = tensor.to_vec();
40        let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
41        let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
42        let mean = data.iter().sum::<f32>() / data.len() as f32;
43
44        let variance = data
45            .iter()
46            .map(|x| (x - mean).powi(2))
47            .sum::<f32>()
48            / data.len() as f32;
49        let std_dev = variance.sqrt();
50
51        // Initialize histogram
52        let bin_width = (max - min) / num_bins as f32;
53        let mut histogram = vec![0usize; num_bins];
54        let bin_edges: Vec<f32> = (0..=num_bins)
55            .map(|i| min + i as f32 * bin_width)
56            .collect();
57
58        for &val in &data {
59            let bin = ((val - min) / bin_width) as usize;
60            let bin = bin.min(num_bins - 1);
61            histogram[bin] += 1;
62        }
63
64        Self {
65            min,
66            max,
67            mean,
68            std_dev,
69            num_samples: data.len(),
70            histogram,
71            bin_edges,
72        }
73    }
74
75    /// Updates calibration data with more samples.
76    pub fn update(&mut self, tensor: &Tensor<f32>) {
77        let data = tensor.to_vec();
78        let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
79        let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
80
81        // Update min/max
82        self.min = self.min.min(new_min);
83        self.max = self.max.max(new_max);
84
85        // Update running mean
86        let old_count = self.num_samples as f32;
87        let new_count = data.len() as f32;
88        let new_mean = data.iter().sum::<f32>() / new_count;
89        self.mean = (self.mean * old_count + new_mean * new_count) / (old_count + new_count);
90
91        // Update histogram (rebuild with new range)
92        self.num_samples += data.len();
93        // Note: For proper histogram update, we'd need to keep all data or use streaming algorithms
94    }
95
96    /// Returns the dynamic range.
97    pub fn dynamic_range(&self) -> f32 {
98        self.max - self.min
99    }
100
101    /// Computes the optimal scale for symmetric quantization.
102    pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
103        let max_abs = self.min.abs().max(self.max.abs());
104        let max_int = match quant_type {
105            QuantType::Q8_0 => 127.0,
106            QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
107            QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
108            QuantType::F16 | QuantType::F32 => 1.0,
109        };
110        max_abs / max_int
111    }
112
113    /// Computes the optimal scale for asymmetric quantization.
114    pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
115        let max_int = match quant_type {
116            QuantType::Q8_0 => 255.0,
117            QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
118            QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
119            QuantType::F16 | QuantType::F32 => 1.0,
120        };
121
122        let scale = (self.max - self.min) / max_int;
123        let zero_point = -self.min / scale;
124
125        (scale, zero_point)
126    }
127
128    /// Returns the percentile value from the histogram.
129    pub fn percentile(&self, p: f32) -> f32 {
130        if p <= 0.0 {
131            return self.min;
132        }
133        if p >= 100.0 {
134            return self.max;
135        }
136
137        let target = (p / 100.0 * self.num_samples as f32) as usize;
138        let mut cumsum = 0usize;
139
140        for (i, &count) in self.histogram.iter().enumerate() {
141            cumsum += count;
142            if cumsum >= target {
143                return self.bin_edges[i];
144            }
145        }
146
147        self.max
148    }
149}
150
151// =============================================================================
152// Calibration Methods
153// =============================================================================
154
/// Strategy used to choose the quantization range from collected statistics.
///
/// The two parameterised variants carry their argument as an integer scaled
/// by ten, which keeps the enum `Eq`/`Copy` without storing a float.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CalibrationMethod {
    /// Use min/max values directly.
    MinMax,
    /// Use percentiles (e.g., 99.9th) to reduce outlier impact.
    ///
    /// The payload is the percentile multiplied by 10 (e.g., `999` = 99.9%).
    Percentile(u32),
    /// Use entropy-based calibration (KL divergence).
    Entropy,
    /// Use mean + k*std_dev for range.
    ///
    /// The payload is `k` multiplied by 10 (e.g., `30` = 3.0 sigma).
    MeanStd(u32),
}
167
168/// Calibrates a tensor for quantization.
169///
170/// # Arguments
171/// * `tensor` - The tensor to calibrate
172/// * `method` - The calibration method to use
173///
174/// # Returns
175/// Calibration data for the tensor
176pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
177    let mut data = CalibrationData::new(tensor, 2048);
178
179    match method {
180        CalibrationMethod::MinMax => {
181            // Already computed in new()
182        }
183        CalibrationMethod::Percentile(p) => {
184            let percentile = p as f32 / 10.0;
185            let lower = data.percentile(100.0 - percentile);
186            let upper = data.percentile(percentile);
187            data.min = lower;
188            data.max = upper;
189        }
190        CalibrationMethod::MeanStd(k) => {
191            let k_factor = k as f32 / 10.0;
192            data.min = data.mean - k_factor * data.std_dev;
193            data.max = data.mean + k_factor * data.std_dev;
194        }
195        CalibrationMethod::Entropy => {
196            // Simplified entropy calibration - use 99.99th percentile
197            data.min = data.percentile(0.01);
198            data.max = data.percentile(99.99);
199        }
200    }
201
202    Ok(data)
203}
204
205/// Calibrates multiple tensors and returns combined calibration data.
206pub fn calibrate_batch(
207    tensors: &[&Tensor<f32>],
208    method: CalibrationMethod,
209) -> QuantResult<CalibrationData> {
210    if tensors.is_empty() {
211        return Err(QuantError::CalibrationError("No tensors provided".to_string()));
212    }
213
214    let mut combined = CalibrationData::new(tensors[0], 2048);
215
216    for tensor in tensors.iter().skip(1) {
217        combined.update(tensor);
218    }
219
220    // Apply method-specific adjustments
221    match method {
222        CalibrationMethod::Percentile(p) => {
223            let percentile = p as f32 / 10.0;
224            combined.min = combined.percentile(100.0 - percentile);
225            combined.max = combined.percentile(percentile);
226        }
227        CalibrationMethod::MeanStd(k) => {
228            let k_factor = k as f32 / 10.0;
229            combined.min = combined.mean - k_factor * combined.std_dev;
230            combined.max = combined.mean + k_factor * combined.std_dev;
231        }
232        _ => {}
233    }
234
235    Ok(combined)
236}
237
238// =============================================================================
239// Tests
240// =============================================================================
241
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic statistics are taken straight from the input samples.
    #[test]
    fn test_calibration_data() {
        let samples = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let stats = CalibrationData::new(&samples, 10);

        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.num_samples, 5);
    }

    /// The symmetric Q8_0 scale is the largest magnitude divided by 127.
    #[test]
    fn test_symmetric_scale() {
        let samples = Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap();
        let stats = CalibrationData::new(&samples, 10);

        // max_abs = 4.0, max_int = 127, so the expected scale is 4/127.
        let expected = 4.0 / 127.0;
        let actual = stats.symmetric_scale(QuantType::Q8_0);
        assert!((actual - expected).abs() < 0.001);
    }

    /// Min/max keeps the observed range; percentile clipping stays inside it.
    #[test]
    fn test_calibration_methods() {
        let ramp: Vec<f32> = (0..1000).map(|step| step as f32 / 100.0).collect();
        let samples = Tensor::from_vec(ramp, &[1000]).unwrap();

        // Min/Max calibration
        let full_range = calibrate(&samples, CalibrationMethod::MinMax).unwrap();
        assert!((full_range.min - 0.0).abs() < 0.01);
        assert!((full_range.max - 9.99).abs() < 0.01);

        // Percentile calibration (99.9%)
        let clipped = calibrate(&samples, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    /// The dynamic range is the distance between the extremes.
    #[test]
    fn test_dynamic_range() {
        let samples = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();
        let stats = CalibrationData::new(&samples, 10);

        assert_eq!(stats.dynamic_range(), 15.0);
    }
}