1use crate::calibration::methods::CalibrationMethod;
7
/// Number of equal-width bins used for [`ActivationStats`] histograms.
const NUM_BINS: usize = 256;

/// Summary of observed activation values: extrema, mean, population standard
/// deviation and a fixed-size histogram used to answer percentile queries.
///
/// Built with `ActivationStats::from_data` and extended with
/// `ActivationStats::update`; only finite samples are recorded.
#[derive(Debug, Clone)]
pub struct ActivationStats {
    /// Smallest finite sample recorded (`+inf` for a default instance).
    min: f32,
    /// Largest finite sample recorded (`-inf` for a default instance).
    max: f32,
    /// Arithmetic mean of the recorded samples.
    mean: f32,
    /// Population standard deviation, computed as `sqrt(m2 / count)`.
    std: f32,
    /// Number of finite samples recorded.
    count: usize,

    /// Sum of squared deviations from the mean ("M2"). Kept in f64 and merged
    /// across batches by `update`.
    m2: f64,

    /// `NUM_BINS` sample counts spanning `[hist_min, hist_max]`; empty for a
    /// default instance.
    histogram_bins: Vec<usize>,
    /// Lower edge of the histogram range.
    hist_min: f32,
    /// Upper edge of the histogram range.
    hist_max: f32,
}
29
30impl ActivationStats {
31 pub fn min(&self) -> f32 {
33 self.min
34 }
35 pub fn max(&self) -> f32 {
37 self.max
38 }
39 pub fn mean(&self) -> f32 {
41 self.mean
42 }
43 pub fn std(&self) -> f32 {
45 self.std
46 }
47 pub fn count(&self) -> usize {
49 self.count
50 }
51}
52
53impl ActivationStats {
54 pub fn from_data(data: &[f32]) -> Self {
56 if data.is_empty() {
57 return Self::default();
58 }
59
60 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
61 if finite.is_empty() {
62 return Self::default();
63 }
64
65 let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
66 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
67
68 let sum: f32 = finite.iter().sum();
69 let mean = sum / finite.len() as f32;
70
71 let m2: f64 = finite.iter().map(|&x| ((x - mean) as f64).powi(2)).sum();
72 let std = (m2 / finite.len() as f64).sqrt() as f32;
73
74 let histogram_bins = build_histogram(data, min, max);
75
76 Self {
77 min,
78 max,
79 mean,
80 std,
81 count: finite.len(),
82 m2,
83 histogram_bins,
84 hist_min: min,
85 hist_max: max,
86 }
87 }
88
89 pub fn update(&mut self, data: &[f32]) {
91 if data.is_empty() {
92 return;
93 }
94
95 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
97 if finite.is_empty() {
98 return;
99 }
100
101 let data_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
102 let data_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
103
104 let new_min = self.min.min(data_min);
105 let new_max = self.max.max(data_max);
106
107 let old_count = self.count as f64;
110 let new_count = finite.len() as f64;
111 let combined_count = old_count + new_count;
112
113 let data_sum: f64 = finite.iter().map(|&x| x as f64).sum();
114 let data_mean = data_sum / new_count;
115
116 let data_m2: f64 = finite
117 .iter()
118 .map(|&x| ((x as f64) - data_mean).powi(2))
119 .sum();
120
121 let delta = data_mean - self.mean as f64;
123 self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
124
125 self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
126 self.count = combined_count as usize;
127 self.std = (self.m2 / combined_count).sqrt() as f32;
128
129 if new_min < self.hist_min || new_max > self.hist_max {
131 let mut rebinned = vec![0usize; NUM_BINS];
132 rebin(
133 &self.histogram_bins,
134 self.hist_min,
135 self.hist_max,
136 &mut rebinned,
137 new_min,
138 new_max,
139 );
140 self.histogram_bins = rebinned;
141 self.hist_min = new_min;
142 self.hist_max = new_max;
143 }
144
145 let new_hist = build_histogram(&finite, self.hist_min, self.hist_max);
147 for (i, &c) in new_hist.iter().enumerate() {
148 self.histogram_bins[i] += c;
149 }
150
151 self.min = new_min;
152 self.max = new_max;
153 }
154
155 pub fn percentile(&self, p: f32) -> f32 {
157 if self.histogram_bins.is_empty() {
158 return self.min;
159 }
160
161 let total: usize = self.histogram_bins.iter().sum();
162 if total == 0 {
163 return self.min;
164 }
165
166 let target_count = (total as f32 * p / 100.0).ceil() as usize;
169 let mut cumulative = 0;
170
171 let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
172 0.0
173 } else {
174 (self.hist_max - self.hist_min) / NUM_BINS as f32
175 };
176
177 for (i, &count) in self.histogram_bins.iter().enumerate() {
178 cumulative += count;
179 if cumulative >= target_count {
180 return self.hist_min + (i as f32 + 0.5) * bin_size;
181 }
182 }
183
184 self.max
185 }
186
187 pub fn histogram_data(&self) -> Vec<(f32, usize)> {
189 if (self.hist_max - self.hist_min).abs() < 1e-8 {
190 let total: usize = self.histogram_bins.iter().sum();
191 if total > 0 {
192 return vec![(self.hist_min, total)];
193 }
194 return Vec::new();
195 }
196 let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
197 self.histogram_bins
198 .iter()
199 .enumerate()
200 .filter(|(_, &count)| count > 0)
201 .map(|(i, &count)| {
202 let value = self.hist_min + (i as f32 + 0.5) * bin_size;
203 (value, count)
204 })
205 .collect()
206 }
207}
208
209impl Default for ActivationStats {
210 fn default() -> Self {
211 Self {
212 min: f32::INFINITY,
213 max: f32::NEG_INFINITY,
214 mean: 0.0,
215 std: 0.0,
216 count: 0,
217 m2: 0.0,
218 histogram_bins: Vec::new(),
219 hist_min: 0.0,
220 hist_max: 0.0,
221 }
222 }
223}
224
225fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
226 let mut bins = vec![0usize; NUM_BINS];
227
228 if (max - min).abs() < 1e-8 {
229 let finite_count = data.iter().filter(|v| v.is_finite()).count();
231 if !bins.is_empty() {
232 bins[0] = finite_count;
233 }
234 return bins;
235 }
236
237 let bin_size = (max - min) / NUM_BINS as f32;
238
239 for &value in data {
240 if !value.is_finite() {
241 continue;
242 }
243 let bin_idx = ((value - min) / bin_size).floor() as usize;
244 let bin_idx = bin_idx.min(NUM_BINS - 1);
245 bins[bin_idx] += 1;
246 }
247
248 bins
249}
250
/// Adds the counts of a histogram over `[old_min, old_max]` into `new_bins`,
/// a histogram over `[new_min, new_max]`.
///
/// Each non-empty source bin is moved wholesale to the destination bin that
/// contains its centre, so the result is approximate. If either range has
/// (near) zero width, the whole old total goes to the bin containing the old
/// range's midpoint. Does nothing when either slice is empty. Existing counts
/// in `new_bins` are added to, not overwritten.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    if old_bins.is_empty() || new_bins.is_empty() {
        return;
    }

    let target_len = new_bins.len();
    let new_range = new_max - new_min;
    // Destination bin for `value`; `as usize` saturates and `min` keeps the
    // upper edge inside the last bin.
    let dest = |value: f32| -> usize {
        let idx = ((value - new_min) / new_range * target_len as f32).floor() as usize;
        idx.min(target_len - 1)
    };

    let old_range = old_max - old_min;
    if old_range.abs() < 1e-8 || new_range.abs() < 1e-8 {
        let total: usize = old_bins.iter().sum();
        if total > 0 {
            new_bins[dest((old_min + old_max) * 0.5)] += total;
        }
        return;
    }

    let old_width = old_range / old_bins.len() as f32;
    for (i, &count) in old_bins.iter().enumerate().filter(|&(_, &c)| c != 0) {
        new_bins[dest(old_min + (i as f32 + 0.5) * old_width)] += count;
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_stats() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let stats = ActivationStats::from_data(&data);

        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);
        assert!((stats.mean() - 0.0).abs() < 0.01);

        let p50 = stats.percentile(50.0);
        assert!((p50 - 0.0).abs() < 0.3);
    }

    /// Merging a second batch with `update` must agree with computing the
    /// statistics over both batches at once.
    #[test]
    fn update_merges_like_a_single_batch() {
        let first = [-1.0_f32, -0.5, 0.0, 0.5, 1.0];
        let second = [2.0_f32, 3.0];
        let mut incremental = ActivationStats::from_data(&first);
        incremental.update(&second);

        let combined: Vec<f32> = first.iter().chain(&second).copied().collect();
        let batch = ActivationStats::from_data(&combined);

        assert_eq!(incremental.count(), 7);
        assert_eq!(incremental.count(), batch.count());
        assert_eq!(incremental.min(), -1.0);
        assert_eq!(incremental.max(), 3.0);
        assert!((incremental.mean() - batch.mean()).abs() < 1e-5);
        assert!((incremental.std() - batch.std()).abs() < 1e-5);

        let recorded: usize = incremental.histogram_data().iter().map(|&(_, c)| c).sum();
        assert_eq!(recorded, 7);
    }

    #[test]
    fn non_finite_samples_are_ignored() {
        let stats = ActivationStats::from_data(&[1.0, f32::NAN, 3.0, f32::INFINITY]);
        assert_eq!(stats.count(), 2);
        assert!((stats.mean() - 2.0).abs() < 1e-6);
        assert!((stats.std() - 1.0).abs() < 1e-6);

        assert_eq!(ActivationStats::from_data(&[f32::NAN]).count(), 0);
    }

    #[test]
    fn min_max_range_skips_non_finite() {
        let range = calculate_optimal_range(&[-2.0, f32::NAN, 3.0], CalibrationMethod::MinMax);
        assert_eq!(range, (-2.0, 3.0));
    }
}
306
307pub fn calculate_optimal_range(data: &[f32], method: CalibrationMethod) -> (f32, f32) {
309 if data.is_empty() {
310 return (0.0, 0.0);
311 }
312
313 match method {
314 CalibrationMethod::MinMax => {
315 let min = data
316 .iter()
317 .copied()
318 .filter(|v| v.is_finite())
319 .fold(f32::INFINITY, f32::min);
320 let max = data
321 .iter()
322 .copied()
323 .filter(|v| v.is_finite())
324 .fold(f32::NEG_INFINITY, f32::max);
325 (min, max)
326 }
327
328 CalibrationMethod::Percentile(p) => {
329 let stats = ActivationStats::from_data(data);
330 let lower = stats.percentile(100.0 - p);
331 let upper = stats.percentile(p);
332 (lower, upper)
333 }
334
335 CalibrationMethod::Entropy => optimize_kl_divergence(data),
336
337 CalibrationMethod::MSE => optimize_mse(data),
338 }
339}
340
341fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
343 let stats = ActivationStats::from_data(data);
344
345 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
347 let mut best_range = (stats.min, stats.max);
348 let mut best_kl = f32::INFINITY;
349
350 for &percentile in &candidates {
351 let lower = stats.percentile(100.0 - percentile);
352 let upper = stats.percentile(percentile);
353
354 let kl = calculate_kl_divergence(data, lower, upper);
355
356 if kl < best_kl {
357 best_kl = kl;
358 best_range = (lower, upper);
359 }
360 }
361
362 best_range
363}
364
365fn optimize_mse(data: &[f32]) -> (f32, f32) {
367 let stats = ActivationStats::from_data(data);
368
369 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
371 let mut best_range = (stats.min, stats.max);
372 let mut best_mse = f32::INFINITY;
373
374 for &percentile in &candidates {
375 let lower = stats.percentile(100.0 - percentile);
376 let upper = stats.percentile(percentile);
377
378 let mse = calculate_quantization_mse(data, lower, upper);
379
380 if mse < best_mse {
381 best_mse = mse;
382 best_range = (lower, upper);
383 }
384 }
385
386 best_range
387}
388
/// Scores the clipping range `[min, max]` by the KL divergence between two
/// 128-bin histograms over that range. The first is the clipped samples; the
/// second is the same samples after an 8-bit (256-level) quantize/dequantize
/// round trip.
///
/// Non-finite samples are ignored, matching `ActivationStats::from_data`.
/// A zero-width range scores `0.0`. Non-finite or inverted bounds score
/// `f32::INFINITY`, so they are never selected; this also keeps `f32::clamp`
/// from panicking.
///
/// NOTE(review): both histograms are built from the *clipped* values, and the
/// 256 quantization levels are finer than the 128 bins. As a result this score
/// is near zero for most candidates and barely reflects clipping error.
/// Confirm this is the intended entropy criterion.
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    const NUM_KL_BINS: usize = 128;

    if !(min.is_finite() && max.is_finite()) {
        return f32::INFINITY;
    }
    if (max - min).abs() < 1e-8 {
        return 0.0;
    }
    if min > max {
        return f32::INFINITY;
    }

    let bin_size = (max - min) / NUM_KL_BINS as f32;
    let scale = (max - min) / 255.0;
    let bin_of =
        |value: f32| (((value - min) / bin_size).floor() as usize).min(NUM_KL_BINS - 1);

    let mut orig_bins = [0usize; NUM_KL_BINS];
    let mut quant_bins = [0usize; NUM_KL_BINS];
    let mut n_finite = 0usize;

    for v in data.iter().copied().filter(|v| v.is_finite()) {
        n_finite += 1;
        let clipped = v.clamp(min, max);
        orig_bins[bin_of(clipped)] += 1;

        let q = ((clipped - min) / scale).round();
        let dequant = min + q * scale;
        quant_bins[bin_of(dequant.clamp(min, max))] += 1;
    }

    // Additive smoothing keeps empty bins from producing ln(0) or division by zero.
    let epsilon = 1e-10_f32;
    let denom = n_finite as f32 + epsilon * NUM_KL_BINS as f32;

    orig_bins
        .iter()
        .zip(&quant_bins)
        .fold(0.0_f32, |kl, (&orig, &quant)| {
            let p = (orig as f32 + epsilon) / denom;
            let q = (quant as f32 + epsilon) / denom;
            kl + p * (p / q).ln()
        })
}
433
/// Mean squared error between the finite samples of `data` and their 8-bit
/// (256-level) quantize/dequantize reconstruction over `[min, max]`.
///
/// Samples outside the range are clipped first, so the error includes the
/// clipping loss. Non-finite samples are ignored, matching
/// `ActivationStats::from_data`. A zero-width range reconstructs every sample
/// as `min`.
///
/// Returns `0.0` when `data` has no finite samples. Returns `f32::INFINITY`
/// for non-finite or inverted bounds, so such ranges are never selected;
/// this also keeps `f32::clamp` from panicking.
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    if !(min.is_finite() && max.is_finite()) {
        return f32::INFINITY;
    }
    let degenerate = (max - min).abs() < 1e-8;
    if !degenerate && min > max {
        return f32::INFINITY;
    }
    let scale = (max - min) / 255.0;

    // Accumulate in f64: an f32 running sum loses precision on large tensors.
    let (sum_sq, n) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((0.0_f64, 0_usize), |(acc, n), v| {
            let dequantized = if degenerate {
                min
            } else {
                let q = ((v.clamp(min, max) - min) / scale).round().clamp(0.0, 255.0);
                min + q * scale
            };
            let err = f64::from(v) - f64::from(dequantized);
            (acc + err * err, n + 1)
        });

    if n == 0 {
        return 0.0;
    }
    (sum_sq / n as f64) as f32
}