axonml_quant/
calibration.rs1use axonml_tensor::Tensor;
19
20use crate::error::{QuantError, QuantResult};
21use crate::types::QuantType;
22
/// Summary statistics of the values observed during calibration.
///
/// Built with [`CalibrationData::new`], optionally extended with
/// [`CalibrationData::update`], and narrowed by the calibration functions
/// according to a [`CalibrationMethod`]. The histogram and its `bin_edges` are
/// private; calibration methods adjust `min`/`max` without rebuilding them.
#[derive(Debug, Clone, PartialEq)]
pub struct CalibrationData {
    /// Smallest observed value, or the lower clipping bound after calibration.
    pub min: f32,
    /// Largest observed value, or the upper clipping bound after calibration.
    pub max: f32,
    /// Arithmetic mean of the observed values.
    pub mean: f32,
    /// Population standard deviation (variance divided by `num_samples`).
    pub std_dev: f32,
    /// Number of values that contributed to the statistics.
    pub num_samples: usize,
    /// Per-bin sample counts.
    histogram: Vec<usize>,
    /// Bin boundaries: `histogram.len() + 1` entries.
    bin_edges: Vec<f32>,
}
45
46impl CalibrationData {
47 pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
49 let data = tensor.to_vec();
50 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
51 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
52 let mean = data.iter().sum::<f32>() / data.len() as f32;
53
54 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
55 let std_dev = variance.sqrt();
56
57 let bin_width = (max - min) / num_bins as f32;
59 let mut histogram = vec![0usize; num_bins];
60 let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
61
62 for &val in &data {
63 let bin = ((val - min) / bin_width) as usize;
64 let bin = bin.min(num_bins - 1);
65 histogram[bin] += 1;
66 }
67
68 Self {
69 min,
70 max,
71 mean,
72 std_dev,
73 num_samples: data.len(),
74 histogram,
75 bin_edges,
76 }
77 }
78
79 pub fn update(&mut self, tensor: &Tensor<f32>) {
81 let data = tensor.to_vec();
82 let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
83 let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
84
85 self.min = self.min.min(new_min);
87 self.max = self.max.max(new_max);
88
89 let old_mean = self.mean;
91 let old_count = self.num_samples;
92 for &val in &data {
93 self.num_samples += 1;
94 let delta = val - self.mean;
95 self.mean += delta / self.num_samples as f32;
96 }
97
98 if old_count > 0 && !data.is_empty() {
101 let new_mean_batch: f32 = data.iter().sum::<f32>() / data.len() as f32;
102 let new_var_batch: f32 = data
103 .iter()
104 .map(|&v| (v - new_mean_batch).powi(2))
105 .sum::<f32>()
106 / data.len() as f32;
107 let old_var = self.std_dev * self.std_dev;
108 let n1 = old_count as f32;
109 let n2 = data.len() as f32;
110 let combined_var = (n1 * old_var
111 + n2 * new_var_batch
112 + n1 * n2 / (n1 + n2) * (old_mean - new_mean_batch).powi(2))
113 / (n1 + n2);
114 self.std_dev = combined_var.sqrt();
115 } else if !data.is_empty() {
116 let m: f32 = data.iter().sum::<f32>() / data.len() as f32;
117 self.std_dev =
118 (data.iter().map(|&v| (v - m).powi(2)).sum::<f32>() / data.len() as f32).sqrt();
119 self.num_samples = data.len();
120 }
121
122 if !self.histogram.is_empty() && self.max > self.min {
124 let n_bins = self.histogram.len();
125 let bin_width = (self.max - self.min) / n_bins as f32;
126 for &val in &data {
127 let bin = ((val - self.min) / bin_width).floor() as usize;
128 let bin = bin.min(n_bins - 1);
129 self.histogram[bin] += 1;
130 }
131 }
132 }
133
134 pub fn dynamic_range(&self) -> f32 {
136 self.max - self.min
137 }
138
139 pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
141 let max_abs = self.min.abs().max(self.max.abs());
142 let max_int = match quant_type {
143 QuantType::Q8_0 => 127.0,
144 QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
145 QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
146 QuantType::F16 | QuantType::F32 => 1.0,
147 };
148 max_abs / max_int
149 }
150
151 pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
153 let max_int = match quant_type {
154 QuantType::Q8_0 => 255.0,
155 QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
156 QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
157 QuantType::F16 | QuantType::F32 => 1.0,
158 };
159
160 let scale = (self.max - self.min) / max_int;
161 let zero_point = -self.min / scale;
162
163 (scale, zero_point)
164 }
165
166 pub fn percentile(&self, p: f32) -> f32 {
168 if p <= 0.0 {
169 return self.min;
170 }
171 if p >= 100.0 {
172 return self.max;
173 }
174
175 let target = (p / 100.0 * self.num_samples as f32) as usize;
176 let mut cumsum = 0usize;
177
178 for (i, &count) in self.histogram.iter().enumerate() {
179 cumsum += count;
180 if cumsum >= target {
181 return self.bin_edges[i];
182 }
183 }
184
185 self.max
186 }
187}
188
/// Strategy for choosing the quantization range from calibration statistics.
///
/// Integer parameters are expressed in tenths: `Percentile(999)` refers to the
/// 99.9th percentile and `MeanStd(30)` to mean ± 3.0 standard deviations.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Keep the observed minimum and maximum as-is.
    MinMax,

    /// Clip to a percentile band; the parameter is in tenths of a percent, so
    /// `Percentile(999)` keeps the band between the 0.1th and 99.9th percentiles.
    Percentile(u32),

    /// Pick clipping thresholds by minimising the KL divergence between the value
    /// histogram and its quantized approximation.
    Entropy,

    /// Use `mean ± k * std_dev`, where the parameter is `k` in tenths.
    MeanStd(u32),
}
205
206pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
215 let mut data = CalibrationData::new(tensor, 2048);
216
217 match method {
218 CalibrationMethod::MinMax => {
219 }
221 CalibrationMethod::Percentile(p) => {
222 let percentile = p as f32 / 10.0;
223 let lower = data.percentile(100.0 - percentile);
224 let upper = data.percentile(percentile);
225 data.min = lower;
226 data.max = upper;
227 }
228 CalibrationMethod::MeanStd(k) => {
229 let k_factor = k as f32 / 10.0;
230 data.min = data.mean - k_factor * data.std_dev;
231 data.max = data.mean + k_factor * data.std_dev;
232 }
233 CalibrationMethod::Entropy => {
234 let n_bins = data.histogram.len();
238 if n_bins < 4 {
239 data.min = data.percentile(0.01);
241 data.max = data.percentile(99.99);
242 } else {
243 let total: usize = data.histogram.iter().sum();
244 if total == 0 {
245 data.min = data.percentile(0.01);
246 data.max = data.percentile(99.99);
247 } else {
248 let ref_dist: Vec<f64> = data
250 .histogram
251 .iter()
252 .map(|&c| c as f64 / total as f64 + 1e-12)
253 .collect();
254
255 let quant_bins = 128usize; let mut best_kl = f64::MAX;
257 let mut best_threshold = n_bins;
258
259 for threshold in (n_bins / 2)..n_bins {
261 let mut clipped = ref_dist[..threshold].to_vec();
263 let outlier_mass: f64 = ref_dist[threshold..].iter().sum();
265 if let Some(last) = clipped.last_mut() {
266 *last += outlier_mass;
267 }
268
269 let bins_per_quant = threshold.div_ceil(quant_bins);
271 let mut quant_dist = vec![0.0f64; quant_bins.min(threshold)];
272 for (i, &p) in clipped.iter().enumerate() {
273 let qi = (i / bins_per_quant).min(quant_dist.len() - 1);
274 quant_dist[qi] += p;
275 }
276
277 let mut expanded = vec![0.0f64; threshold];
279 for (qi, &qval) in quant_dist.iter().enumerate() {
280 let start = qi * bins_per_quant;
281 let end = ((qi + 1) * bins_per_quant).min(threshold);
282 let count = (end - start) as f64;
283 if count > 0.0 {
284 let val = qval / count;
285 for slot in expanded.iter_mut().take(end).skip(start) {
286 *slot = val + 1e-12;
287 }
288 }
289 }
290
291 let kl: f64 = clipped
293 .iter()
294 .zip(expanded.iter())
295 .map(|(&p, &q)| if p > 1e-12 { p * (p / q).ln() } else { 0.0 })
296 .sum();
297
298 if kl < best_kl {
299 best_kl = kl;
300 best_threshold = threshold;
301 }
302 }
303
304 let bin_width = (data.max - data.min) / n_bins as f32;
306 let clip_max = data.min + best_threshold as f32 * bin_width;
307 data.max = clip_max;
308 if data.min < 0.0 && data.max > 0.0 {
310 let abs_max = data.max.abs().max(data.min.abs());
311 data.min = -abs_max;
312 data.max = abs_max;
313 }
314 }
315 }
316 }
317 }
318
319 Ok(data)
320}
321
322pub fn calibrate_batch(
324 tensors: &[&Tensor<f32>],
325 method: CalibrationMethod,
326) -> QuantResult<CalibrationData> {
327 if tensors.is_empty() {
328 return Err(QuantError::CalibrationError(
329 "No tensors provided".to_string(),
330 ));
331 }
332
333 let mut combined = CalibrationData::new(tensors[0], 2048);
334
335 for tensor in tensors.iter().skip(1) {
336 combined.update(tensor);
337 }
338
339 match method {
341 CalibrationMethod::Percentile(p) => {
342 let percentile = p as f32 / 10.0;
343 combined.min = combined.percentile(100.0 - percentile);
344 combined.max = combined.percentile(percentile);
345 }
346 CalibrationMethod::MeanStd(k) => {
347 let k_factor = k as f32 / 10.0;
348 combined.min = combined.mean - k_factor * combined.std_dev;
349 combined.max = combined.mean + k_factor * combined.std_dev;
350 }
351 _ => {}
352 }
353
354 Ok(combined)
355}
356
#[cfg(test)]
mod tests {
    //! Unit tests for calibration statistics and the calibration entry points.

    use super::*;

    #[test]
    fn test_calibration_data() {
        let data = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let tensor = Tensor::from_vec(data, &[5]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);

        assert_eq!(calib.min, 1.0);
        assert_eq!(calib.max, 5.0);
        assert_eq!(calib.mean, 3.0);
        assert_eq!(calib.num_samples, 5);
    }

    #[test]
    fn test_symmetric_scale() {
        let data = vec![-4.0, -2.0, 0.0, 2.0, 4.0];
        let tensor = Tensor::from_vec(data, &[5]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        let scale = calib.symmetric_scale(QuantType::Q8_0);

        assert!((scale - 4.0 / 127.0).abs() < 0.001);
    }

    #[test]
    fn test_asymmetric_scale() {
        let tensor = Tensor::from_vec(vec![0.0f32, 255.0], &[2]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        let (scale, zero_point) = calib.asymmetric_scale(QuantType::Q8_0);

        assert!((scale - 1.0).abs() < 1e-6);
        assert!(zero_point.abs() < 1e-6);
    }

    #[test]
    fn test_calibration_methods() {
        let data: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(data, &[1000]).unwrap();

        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!((minmax.min - 0.0).abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        let percentile = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(percentile.min >= 0.0);
        assert!(percentile.max <= 9.99);
    }

    #[test]
    fn test_mean_std_method() {
        let tensor = Tensor::from_vec(vec![1.0f32, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        // MeanStd(10) means k = 1.0; population std of 1..=5 is sqrt(2).
        let calib = calibrate(&tensor, CalibrationMethod::MeanStd(10)).unwrap();
        let sd = 2.0f32.sqrt();

        assert!((calib.min - (3.0 - sd)).abs() < 1e-5);
        assert!((calib.max - (3.0 + sd)).abs() < 1e-5);
    }

    #[test]
    fn test_calibrate_batch_rejects_empty_input() {
        assert!(calibrate_batch(&[], CalibrationMethod::MinMax).is_err());
    }

    #[test]
    fn test_calibrate_batch_merges_statistics() {
        let a = Tensor::from_vec(vec![0.0f32, 1.0], &[2]).unwrap();
        let b = Tensor::from_vec(vec![-2.0f32, 3.0], &[2]).unwrap();

        let calib = calibrate_batch(&[&a, &b], CalibrationMethod::MinMax).unwrap();

        assert_eq!(calib.min, -2.0);
        assert_eq!(calib.max, 3.0);
        assert_eq!(calib.num_samples, 4);
        // Combined values {0, 1, -2, 3}: mean 0.5, population variance 3.25.
        assert!((calib.mean - 0.5).abs() < 1e-5);
        assert!((calib.std_dev - 3.25f32.sqrt()).abs() < 1e-5);
    }

    #[test]
    fn test_dynamic_range() {
        let data = vec![-5.0, 10.0];
        let tensor = Tensor::from_vec(data, &[2]).unwrap();

        let calib = CalibrationData::new(&tensor, 10);
        assert_eq!(calib.dynamic_range(), 15.0);
    }
}