// axonml_quant/calibration.rs
use axonml_tensor::Tensor;

use crate::error::{QuantError, QuantResult};
use crate::types::QuantType;

/// Summary statistics and a fixed-width histogram gathered from tensor values.
///
/// The public fields describe the value range (observed, or clipped once a
/// calibration method has been applied) together with the mean, standard
/// deviation and sample count. The histogram and its bin edges are private
/// and back the percentile lookup.
#[derive(Debug, Clone)]
pub struct CalibrationData {
    /// Lower bound of the value range.
    pub min: f32,
    /// Upper bound of the value range.
    pub max: f32,
    /// Arithmetic mean of the observed values.
    pub mean: f32,
    /// Population standard deviation of the observed values.
    pub std_dev: f32,
    /// Number of values folded into the statistics so far.
    pub num_samples: usize,
    /// Per-bin sample counts.
    histogram: Vec<usize>,
    /// Bin boundaries; holds one more entry than `histogram`.
    bin_edges: Vec<f32>,
}
44
45impl CalibrationData {
46 pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
48 let data = tensor.to_vec();
49 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
50 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
51 let mean = data.iter().sum::<f32>() / data.len() as f32;
52
53 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
54 let std_dev = variance.sqrt();
55
56 let bin_width = (max - min) / num_bins as f32;
58 let mut histogram = vec![0usize; num_bins];
59 let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
60
61 for &val in &data {
62 let bin = ((val - min) / bin_width) as usize;
63 let bin = bin.min(num_bins - 1);
64 histogram[bin] += 1;
65 }
66
67 Self {
68 min,
69 max,
70 mean,
71 std_dev,
72 num_samples: data.len(),
73 histogram,
74 bin_edges,
75 }
76 }
77
78 pub fn update(&mut self, tensor: &Tensor<f32>) {
80 let data = tensor.to_vec();
81 let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
82 let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
83
84 self.min = self.min.min(new_min);
86 self.max = self.max.max(new_max);
87
88 let old_mean = self.mean;
90 let old_count = self.num_samples;
91 for &val in &data {
92 self.num_samples += 1;
93 let delta = val - self.mean;
94 self.mean += delta / self.num_samples as f32;
95 }
96
97 if old_count > 0 && !data.is_empty() {
100 let new_mean_batch: f32 = data.iter().sum::<f32>() / data.len() as f32;
101 let new_var_batch: f32 = data.iter().map(|&v| (v - new_mean_batch).powi(2)).sum::<f32>()
102 / data.len() as f32;
103 let old_var = self.std_dev * self.std_dev;
104 let n1 = old_count as f32;
105 let n2 = data.len() as f32;
106 let combined_var = (n1 * old_var + n2 * new_var_batch
107 + n1 * n2 / (n1 + n2) * (old_mean - new_mean_batch).powi(2))
108 / (n1 + n2);
109 self.std_dev = combined_var.sqrt();
110 } else if !data.is_empty() {
111 let m: f32 = data.iter().sum::<f32>() / data.len() as f32;
112 self.std_dev = (data.iter().map(|&v| (v - m).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
113 self.num_samples = data.len();
114 }
115
116 if !self.histogram.is_empty() && self.max > self.min {
118 let n_bins = self.histogram.len();
119 let bin_width = (self.max - self.min) / n_bins as f32;
120 for &val in &data {
121 let bin = ((val - self.min) / bin_width).floor() as usize;
122 let bin = bin.min(n_bins - 1);
123 self.histogram[bin] += 1;
124 }
125 }
126 }
127
128 pub fn dynamic_range(&self) -> f32 {
130 self.max - self.min
131 }
132
133 pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
135 let max_abs = self.min.abs().max(self.max.abs());
136 let max_int = match quant_type {
137 QuantType::Q8_0 => 127.0,
138 QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
139 QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
140 QuantType::F16 | QuantType::F32 => 1.0,
141 };
142 max_abs / max_int
143 }
144
145 pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
147 let max_int = match quant_type {
148 QuantType::Q8_0 => 255.0,
149 QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
150 QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
151 QuantType::F16 | QuantType::F32 => 1.0,
152 };
153
154 let scale = (self.max - self.min) / max_int;
155 let zero_point = -self.min / scale;
156
157 (scale, zero_point)
158 }
159
160 pub fn percentile(&self, p: f32) -> f32 {
162 if p <= 0.0 {
163 return self.min;
164 }
165 if p >= 100.0 {
166 return self.max;
167 }
168
169 let target = (p / 100.0 * self.num_samples as f32) as usize;
170 let mut cumsum = 0usize;
171
172 for (i, &count) in self.histogram.iter().enumerate() {
173 cumsum += count;
174 if cumsum >= target {
175 return self.bin_edges[i];
176 }
177 }
178
179 self.max
180 }
181}
182
/// Strategy for turning collected statistics into a quantization range.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Keep the observed minimum and maximum unchanged.
    MinMax,
    /// Clip to a percentile pair. The payload is the upper percentile in
    /// tenths of a percent (`999` means 99.9); the lower bound uses the
    /// complementary percentile (`100 - p / 10`).
    Percentile(u32),
    /// Choose the clipping threshold by a KL-divergence search over the
    /// histogram.
    Entropy,
    /// Use `mean ± k · std_dev`. The payload is `k` in tenths
    /// (`30` means 3.0 standard deviations).
    MeanStd(u32),
}
199
200pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
209 let mut data = CalibrationData::new(tensor, 2048);
210
211 match method {
212 CalibrationMethod::MinMax => {
213 }
215 CalibrationMethod::Percentile(p) => {
216 let percentile = p as f32 / 10.0;
217 let lower = data.percentile(100.0 - percentile);
218 let upper = data.percentile(percentile);
219 data.min = lower;
220 data.max = upper;
221 }
222 CalibrationMethod::MeanStd(k) => {
223 let k_factor = k as f32 / 10.0;
224 data.min = data.mean - k_factor * data.std_dev;
225 data.max = data.mean + k_factor * data.std_dev;
226 }
227 CalibrationMethod::Entropy => {
228 let n_bins = data.histogram.len();
232 if n_bins < 4 {
233 data.min = data.percentile(0.01);
235 data.max = data.percentile(99.99);
236 } else {
237 let total: usize = data.histogram.iter().sum();
238 if total == 0 {
239 data.min = data.percentile(0.01);
240 data.max = data.percentile(99.99);
241 } else {
242 let ref_dist: Vec<f64> = data.histogram.iter()
244 .map(|&c| c as f64 / total as f64 + 1e-12)
245 .collect();
246
247 let quant_bins = 128usize; let mut best_kl = f64::MAX;
249 let mut best_threshold = n_bins;
250
251 for threshold in (n_bins / 2)..n_bins {
253 let mut clipped = ref_dist[..threshold].to_vec();
255 let outlier_mass: f64 = ref_dist[threshold..].iter().sum();
257 if let Some(last) = clipped.last_mut() {
258 *last += outlier_mass;
259 }
260
261 let bins_per_quant = (threshold + quant_bins - 1) / quant_bins;
263 let mut quant_dist = vec![0.0f64; quant_bins.min(threshold)];
264 for (i, &p) in clipped.iter().enumerate() {
265 let qi = (i / bins_per_quant).min(quant_dist.len() - 1);
266 quant_dist[qi] += p;
267 }
268
269 let mut expanded = vec![0.0f64; threshold];
271 for qi in 0..quant_dist.len() {
272 let start = qi * bins_per_quant;
273 let end = ((qi + 1) * bins_per_quant).min(threshold);
274 let count = (end - start) as f64;
275 if count > 0.0 {
276 let val = quant_dist[qi] / count;
277 for j in start..end {
278 expanded[j] = val + 1e-12;
279 }
280 }
281 }
282
283 let kl: f64 = clipped.iter().zip(expanded.iter())
285 .map(|(&p, &q)| if p > 1e-12 { p * (p / q).ln() } else { 0.0 })
286 .sum();
287
288 if kl < best_kl {
289 best_kl = kl;
290 best_threshold = threshold;
291 }
292 }
293
294 let bin_width = (data.max - data.min) / n_bins as f32;
296 let clip_max = data.min + best_threshold as f32 * bin_width;
297 data.max = clip_max;
298 if data.min < 0.0 && data.max > 0.0 {
300 let abs_max = data.max.abs().max(data.min.abs());
301 data.min = -abs_max;
302 data.max = abs_max;
303 }
304 }
305 }
306 }
307 }
308
309 Ok(data)
310}
311
312pub fn calibrate_batch(
314 tensors: &[&Tensor<f32>],
315 method: CalibrationMethod,
316) -> QuantResult<CalibrationData> {
317 if tensors.is_empty() {
318 return Err(QuantError::CalibrationError(
319 "No tensors provided".to_string(),
320 ));
321 }
322
323 let mut combined = CalibrationData::new(tensors[0], 2048);
324
325 for tensor in tensors.iter().skip(1) {
326 combined.update(tensor);
327 }
328
329 match method {
331 CalibrationMethod::Percentile(p) => {
332 let percentile = p as f32 / 10.0;
333 combined.min = combined.percentile(100.0 - percentile);
334 combined.max = combined.percentile(percentile);
335 }
336 CalibrationMethod::MeanStd(k) => {
337 let k_factor = k as f32 / 10.0;
338 combined.min = combined.mean - k_factor * combined.std_dev;
339 combined.max = combined.mean + k_factor * combined.std_dev;
340 }
341 _ => {}
342 }
343
344 Ok(combined)
345}
346
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic statistics are computed from a small tensor.
    #[test]
    fn test_calibration_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_vec(data, &[5]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);

        assert_eq!(calib.min, 1.0);
        assert_eq!(calib.max, 5.0);
        assert_eq!(calib.mean, 3.0);
        assert_eq!(calib.num_samples, 5);
    }

    /// Symmetric Q8 scale is max-abs / 127.
    #[test]
    fn test_symmetric_scale() {
        let data = vec![-4.0, -2.0, 0.0, 2.0, 4.0];
        let tensor = Tensor::from_vec(data, &[5]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        let scale = calib.symmetric_scale(QuantType::Q8_0);

        assert!((scale - 4.0 / 127.0).abs() < 0.001);
    }

    /// Asymmetric Q8 scale maps the full range onto 255 steps.
    #[test]
    fn test_asymmetric_scale() {
        let tensor = Tensor::from_vec(vec![0.0, 255.0], &[2]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        let (scale, zero_point) = calib.asymmetric_scale(QuantType::Q8_0);

        assert!((scale - 1.0).abs() < 1e-6);
        assert!(zero_point.abs() < 1e-6);
    }

    /// MinMax keeps the observed range; Percentile stays inside it.
    #[test]
    fn test_calibration_methods() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(data, &[1000]).unwrap();

        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!((minmax.min - 0.0).abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        let percentile = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(percentile.min >= 0.0);
        assert!(percentile.max <= 9.99);
    }

    /// `dynamic_range` is `max - min`.
    #[test]
    fn test_dynamic_range() {
        let data = vec![-5.0, 10.0];
        let tensor = Tensor::from_vec(data, &[2]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        assert_eq!(calib.dynamic_range(), 15.0);
    }

    /// `update` widens the range and merges count and mean.
    #[test]
    fn test_update_merges_statistics() {
        let first = Tensor::from_vec(vec![1.0, 2.0, 3.0], &[3]).unwrap();
        let second = Tensor::from_vec(vec![4.0, 5.0, 6.0], &[3]).unwrap();

        let mut calib = CalibrationData::new(&first, 10);
        calib.update(&second);

        assert_eq!(calib.min, 1.0);
        assert_eq!(calib.max, 6.0);
        assert_eq!(calib.num_samples, 6);
        assert!((calib.mean - 3.5).abs() < 1e-5);
    }

    /// An empty batch is rejected with an error.
    #[test]
    fn test_calibrate_batch_rejects_empty_input() {
        let tensors: [&Tensor<f32>; 0] = [];
        assert!(calibrate_batch(&tensors, CalibrationMethod::MinMax).is_err());
    }
}