axonml_quant/
calibration.rs1use axonml_tensor::Tensor;
9
10use crate::error::{QuantError, QuantResult};
11use crate::types::QuantType;
12
/// Summary statistics plus a value histogram gathered from calibration tensors.
///
/// `min`/`max` start out as the observed extremes, but [`calibrate`] may
/// overwrite them with a clipping range; the histogram and its bin edges are
/// not touched by that and keep describing the values they were built from.
#[derive(Debug, Clone)]
pub struct CalibrationData {
    /// Lower end of the calibration range.
    pub min: f32,
    /// Upper end of the calibration range.
    pub max: f32,
    /// Arithmetic mean of the sampled values.
    pub mean: f32,
    /// Population standard deviation (variance divides by `n`, not `n - 1`).
    pub std_dev: f32,
    /// Number of values that have been folded into these statistics.
    pub num_samples: usize,
    /// Count of values falling into each histogram bin.
    histogram: Vec<usize>,
    /// Bin boundaries; always one entry longer than `histogram`.
    bin_edges: Vec<f32>,
}
35
36impl CalibrationData {
37 pub fn new(tensor: &Tensor<f32>, num_bins: usize) -> Self {
39 let data = tensor.to_vec();
40 let min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
41 let max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
42 let mean = data.iter().sum::<f32>() / data.len() as f32;
43
44 let variance = data
45 .iter()
46 .map(|x| (x - mean).powi(2))
47 .sum::<f32>()
48 / data.len() as f32;
49 let std_dev = variance.sqrt();
50
51 let bin_width = (max - min) / num_bins as f32;
53 let mut histogram = vec![0usize; num_bins];
54 let bin_edges: Vec<f32> = (0..=num_bins)
55 .map(|i| min + i as f32 * bin_width)
56 .collect();
57
58 for &val in &data {
59 let bin = ((val - min) / bin_width) as usize;
60 let bin = bin.min(num_bins - 1);
61 histogram[bin] += 1;
62 }
63
64 Self {
65 min,
66 max,
67 mean,
68 std_dev,
69 num_samples: data.len(),
70 histogram,
71 bin_edges,
72 }
73 }
74
75 pub fn update(&mut self, tensor: &Tensor<f32>) {
77 let data = tensor.to_vec();
78 let new_min = data.iter().fold(f32::INFINITY, |a, &b| a.min(b));
79 let new_max = data.iter().fold(f32::NEG_INFINITY, |a, &b| a.max(b));
80
81 self.min = self.min.min(new_min);
83 self.max = self.max.max(new_max);
84
85 let old_count = self.num_samples as f32;
87 let new_count = data.len() as f32;
88 let new_mean = data.iter().sum::<f32>() / new_count;
89 self.mean = (self.mean * old_count + new_mean * new_count) / (old_count + new_count);
90
91 self.num_samples += data.len();
93 }
95
96 pub fn dynamic_range(&self) -> f32 {
98 self.max - self.min
99 }
100
101 pub fn symmetric_scale(&self, quant_type: QuantType) -> f32 {
103 let max_abs = self.min.abs().max(self.max.abs());
104 let max_int = match quant_type {
105 QuantType::Q8_0 => 127.0,
106 QuantType::Q4_0 | QuantType::Q4_1 => 7.0,
107 QuantType::Q5_0 | QuantType::Q5_1 => 15.0,
108 QuantType::F16 | QuantType::F32 => 1.0,
109 };
110 max_abs / max_int
111 }
112
113 pub fn asymmetric_scale(&self, quant_type: QuantType) -> (f32, f32) {
115 let max_int = match quant_type {
116 QuantType::Q8_0 => 255.0,
117 QuantType::Q4_0 | QuantType::Q4_1 => 15.0,
118 QuantType::Q5_0 | QuantType::Q5_1 => 31.0,
119 QuantType::F16 | QuantType::F32 => 1.0,
120 };
121
122 let scale = (self.max - self.min) / max_int;
123 let zero_point = -self.min / scale;
124
125 (scale, zero_point)
126 }
127
128 pub fn percentile(&self, p: f32) -> f32 {
130 if p <= 0.0 {
131 return self.min;
132 }
133 if p >= 100.0 {
134 return self.max;
135 }
136
137 let target = (p / 100.0 * self.num_samples as f32) as usize;
138 let mut cumsum = 0usize;
139
140 for (i, &count) in self.histogram.iter().enumerate() {
141 cumsum += count;
142 if cumsum >= target {
143 return self.bin_edges[i];
144 }
145 }
146
147 self.max
148 }
149}
150
/// Strategy [`calibrate`] uses to turn collected statistics into a clipping range.
///
/// Integer payloads are fixed-point in tenths (see each variant).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CalibrationMethod {
    /// Keep the observed minimum and maximum unchanged.
    MinMax,
    /// Clip to a symmetric percentile band. The value is in tenths of a percent,
    /// e.g. `999` keeps roughly the 0.1th..99.9th percentiles.
    Percentile(u32),
    /// In [`calibrate`], clip to the 0.01th..99.99th histogram percentiles.
    ///
    /// NOTE(review): despite the name, no entropy / KL-divergence search is done.
    Entropy,
    /// Clip to `mean ± (k / 10) * std_dev`, where the payload is `k`
    /// (e.g. `30` means three standard deviations).
    MeanStd(u32),
}
167
168pub fn calibrate(tensor: &Tensor<f32>, method: CalibrationMethod) -> QuantResult<CalibrationData> {
177 let mut data = CalibrationData::new(tensor, 2048);
178
179 match method {
180 CalibrationMethod::MinMax => {
181 }
183 CalibrationMethod::Percentile(p) => {
184 let percentile = p as f32 / 10.0;
185 let lower = data.percentile(100.0 - percentile);
186 let upper = data.percentile(percentile);
187 data.min = lower;
188 data.max = upper;
189 }
190 CalibrationMethod::MeanStd(k) => {
191 let k_factor = k as f32 / 10.0;
192 data.min = data.mean - k_factor * data.std_dev;
193 data.max = data.mean + k_factor * data.std_dev;
194 }
195 CalibrationMethod::Entropy => {
196 data.min = data.percentile(0.01);
198 data.max = data.percentile(99.99);
199 }
200 }
201
202 Ok(data)
203}
204
205pub fn calibrate_batch(
207 tensors: &[&Tensor<f32>],
208 method: CalibrationMethod,
209) -> QuantResult<CalibrationData> {
210 if tensors.is_empty() {
211 return Err(QuantError::CalibrationError("No tensors provided".to_string()));
212 }
213
214 let mut combined = CalibrationData::new(tensors[0], 2048);
215
216 for tensor in tensors.iter().skip(1) {
217 combined.update(tensor);
218 }
219
220 match method {
222 CalibrationMethod::Percentile(p) => {
223 let percentile = p as f32 / 10.0;
224 combined.min = combined.percentile(100.0 - percentile);
225 combined.max = combined.percentile(percentile);
226 }
227 CalibrationMethod::MeanStd(k) => {
228 let k_factor = k as f32 / 10.0;
229 combined.min = combined.mean - k_factor * combined.std_dev;
230 combined.max = combined.mean + k_factor * combined.std_dev;
231 }
232 _ => {}
233 }
234
235 Ok(combined)
236}
237
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_calibration_data() {
        let tensor = Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0, 5.0], &[5]).unwrap();
        let stats = CalibrationData::new(&tensor, 10);

        // (min, max, mean, num_samples) for the ramp 1..=5.
        assert_eq!(
            (stats.min, stats.max, stats.mean, stats.num_samples),
            (1.0, 5.0, 3.0, 5)
        );
    }

    #[test]
    fn test_symmetric_scale() {
        let tensor = Tensor::from_vec(vec![-4.0, -2.0, 0.0, 2.0, 4.0], &[5]).unwrap();
        let scale = CalibrationData::new(&tensor, 10).symmetric_scale(QuantType::Q8_0);

        let expected = 4.0 / 127.0;
        assert!((scale - expected).abs() < 0.001);
    }

    #[test]
    fn test_calibration_methods() {
        let ramp: Vec<f32> = (0..1000).map(|x| x as f32 / 100.0).collect();
        let tensor = Tensor::from_vec(ramp, &[1000]).unwrap();

        let minmax = calibrate(&tensor, CalibrationMethod::MinMax).unwrap();
        assert!(minmax.min.abs() < 0.01);
        assert!((minmax.max - 9.99).abs() < 0.01);

        let clipped = calibrate(&tensor, CalibrationMethod::Percentile(999)).unwrap();
        assert!(clipped.min >= 0.0);
        assert!(clipped.max <= 9.99);
    }

    #[test]
    fn test_dynamic_range() {
        let tensor = Tensor::from_vec(vec![-5.0, 10.0], &[2]).unwrap();
        let range = CalibrationData::new(&tensor, 10).dynamic_range();

        assert_eq!(range, 15.0);
    }
}