axonml_quant/
calibration.rs1use axonml_tensor::Tensor;
18
19use crate::error::{QuantError, QuantResult};
20use crate::types::QuantType;
21
/// Summary statistics and a fixed-width histogram describing the values
/// observed while calibrating a tensor for quantization.
#[derive(Debug, Clone)]
pub struct CalibrationData {
    /// Smallest value of the calibration range.
    pub min: f32,
    /// Largest value of the calibration range.
    pub max: f32,
    /// Arithmetic mean of the observed values.
    pub mean: f32,
    /// Standard deviation of the observed values.
    pub std_dev: f32,
    /// Number of values that contributed to the statistics.
    pub num_samples: usize,
    /// Per-bin sample counts; bin `i` is delimited by `bin_edges[i]` and
    /// `bin_edges[i + 1]`.
    histogram: Vec<usize>,
    /// Bin boundaries of `histogram`, holding one more entry than there are bins.
    bin_edges: Vec<f32>,
}
44
45impl CalibrationData {
46 pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
48 let data = tensor.to_vec();
49 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
50 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
51 let mean = data.iter().sum::<f32>() / data.len() as f32;
52
53 let variance = data.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / data.len() as f32;
54 let std_dev = variance.sqrt();
55
56 let bin_width = (max - min) / num_bins as f32;
58 let mut histogram = vec![0usize; num_bins];
59 let bin_edges: Vec<f32> = (0..=num_bins).map(|i| min + i as f32 * bin_width).collect();
60
61 for &val in &data {
62 let bin = ((val - min) / bin_width) as usize;
63 let bin = bin.min(num_bins - 1);
64 histogram[bin] += 1;
65 }
66
67 Self {
68 min,
69 max,
70 mean,
71 std_dev,
72 num_samples: data.len(),
73 histogram,
74 bin_edges,
75 }
76 }
77
78 pub fn update(&mut self, tensor: &Tensor<f32>) {
80 let data = tensor.to_vec();
81 let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
82 let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
83
84 self.min = self.min.min(new_min);
86 self.max = self.max.max(new_max);
87
88 let old_count = self.num_samples as f32;
90 let new_count = data.len() as f32;
91 let new_mean = data.iter().sum::<f32>() / new_count;
92 self.mean = (self.mean * old_count + new_mean * new_count) / (old_count + new_count);
93
94 self.num_samples += data.len();
96 }
98
99 pub fn dynamic_range(&self) -> f32 {
101 self.max - self.min
102 }
103
104 pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
106 let max_abs = self.min.abs().max(self.max.abs());
107 let max_int = match quant_type {
108 QuantType::Q8_0 => 127.0,
109 QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
110 QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
111 QuantType::F16 | QuantType::F32 => 1.0,
112 };
113 max_abs / max_int
114 }
115
116 pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
118 let max_int = match quant_type {
119 QuantType::Q8_0 => 255.0,
120 QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
121 QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
122 QuantType::F16 | QuantType::F32 => 1.0,
123 };
124
125 let scale = (self.max - self.min) / max_int;
126 let zero_point = -self.min / scale;
127
128 (scale, zero_point)
129 }
130
131 pub fn percentile(&self, p: f32) -> f32 {
133 if p <= 0.0 {
134 return self.min;
135 }
136 if p >= 100.0 {
137 return self.max;
138 }
139
140 let target = (p / 100.0 * self.num_samples as f32) as usize;
141 let mut cumsum = 0usize;
142
143 for (i, &count) in self.histogram.iter().enumerate() {
144 cumsum += count;
145 if cumsum >= target {
146 return self.bin_edges[i];
147 }
148 }
149
150 self.max
151 }
152}
153
/// Strategy for choosing the calibration range.
///
/// The payload meanings below follow how `calibrate` interprets them.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Keep the observed minimum and maximum.
    MinMax,
    /// Clip to a percentile given in tenths of a percent (`999` means 99.9).
    Percentile(u32),
    /// Entropy-based calibration; `calibrate` currently clips to the 0.01 and
    /// 99.99 percentiles for this variant.
    Entropy,
    /// Clip to the mean plus/minus a number of standard deviations given in
    /// tenths (`30` means 3.0).
    MeanStd(u32),
}
170
171pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
180 let mut data = CalibrationData::new(tensor, 2048);
181
182 match method {
183 CalibrationMethod::MinMax => {
184 }
186 CalibrationMethod::Percentile(p) => {
187 let percentile = p as f32 / 10.0;
188 let lower = data.percentile(100.0 - percentile);
189 let upper = data.percentile(percentile);
190 data.min = lower;
191 data.max = upper;
192 }
193 CalibrationMethod::MeanStd(k) => {
194 let k_factor = k as f32 / 10.0;
195 data.min = data.mean - k_factor * data.std_dev;
196 data.max = data.mean + k_factor * data.std_dev;
197 }
198 CalibrationMethod::Entropy => {
199 data.min = data.percentile(0.01);
201 data.max = data.percentile(99.99);
202 }
203 }
204
205 Ok(data)
206}
207
208pub fn calibrate_batch(
210 tensors: &[&Tensor<f32>],
211 method: CalibrationMethod,
212) -> QuantResult<CalibrationData> {
213 if tensors.is_empty() {
214 return Err(QuantError::CalibrationError(
215 "No tensors provided".to_string(),
216 ));
217 }
218
219 let mut combined = CalibrationData::new(tensors[0], 2048);
220
221 for tensor in tensors.iter().skip(1) {
222 combined.update(tensor);
223 }
224
225 match method {
227 CalibrationMethod::Percentile(p) => {
228 let percentile = p as f32 / 10.0;
229 combined.min = combined.percentile(100.0 - percentile);
230 combined.max = combined.percentile(percentile);
231 }
232 CalibrationMethod::MeanStd(k) => {
233 let k_factor = k as f32 / 10.0;
234 combined.min = combined.mean - k_factor * combined.std_dev;
235 combined.max = combined.mean + k_factor * combined.std_dev;
236 }
237 _ => {}
238 }
239
240 Ok(combined)
241}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// Basic statistics of a small, evenly spaced sample.
    #[test]
    fn test_calibration_data() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();

        let stats = CalibrationData::new(&tensor, 10);

        assert_eq!(stats.min, 1.0);
        assert_eq!(stats.max, 5.0);
        assert_eq!(stats.mean, 3.0);
        assert_eq!(stats.num_samples, 5);
    }

    /// The symmetric Q8_0 scale is the largest magnitude divided by 127.
    #[test]
    fn test_symmetric_scale() {
        let tensor = Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap();

        let scale = CalibrationData::new(&tensor, 10).symmetric_scale(QuantType::Q8_0);

        let expected = 4.0 / 127.0;
        assert!((scale - expected).abs() < 0.001);
    }

    /// Min/max keeps the full range; percentile clipping stays inside it.
    #[test]
    fn test_calibration_methods() {
        let values: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(values, &[1000]).unwrap();

        let full_range = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!((full_range.min - 0.0).abs() < 0.01);
        assert!((full_range.max - 9.99).abs() < 0.01);

        let clipped = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    /// The dynamic range is the distance between the extremes.
    #[test]
    fn test_dynamic_range() {
        let tensor = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();

        let stats = CalibrationData::new(&tensor, 10);

        assert_eq!(stats.dynamic_range(), 15.0);
    }
}