1use crate::calibration::methods::CalibrationMethod;
7
/// Number of equal-width bins in every `ActivationStats` histogram.
const NUM_BINS: usize = 256;

/// Summary statistics of observed activation values, used to pick
/// quantization ranges.
///
/// Holds the finite min/max, mean, population standard deviation, the sample
/// count and a fixed-size histogram. Statistics are built from one batch with
/// `from_data` and can be extended with further batches via `update`.
#[derive(Debug, Clone)]
pub struct ActivationStats {
    /// Smallest finite value seen; `f32::INFINITY` when no samples yet.
    min: f32,
    /// Largest finite value seen; `f32::NEG_INFINITY` when no samples yet.
    max: f32,
    /// Mean of all finite samples seen so far.
    mean: f32,
    /// Population standard deviation, `sqrt(m2 / count)`.
    std: f32,
    /// Number of finite samples accumulated.
    count: usize,

    /// Sum of squared deviations from the mean, kept in f64 so batches can be
    /// merged without revisiting earlier data.
    m2: f64,

    /// `NUM_BINS` counts over `[hist_min, hist_max]`; empty for a default value.
    histogram_bins: Vec<usize>,
    /// Lower edge of the histogram range.
    hist_min: f32,
    /// Upper edge of the histogram range.
    hist_max: f32,
}
29
30impl ActivationStats {
31 pub fn min(&self) -> f32 {
33 self.min
34 }
35 pub fn max(&self) -> f32 {
37 self.max
38 }
39 pub fn mean(&self) -> f32 {
41 self.mean
42 }
43 pub fn std(&self) -> f32 {
45 self.std
46 }
47 pub fn count(&self) -> usize {
49 self.count
50 }
51}
52
53impl ActivationStats {
54 pub fn from_data(data: &[f32]) -> Self {
56 if data.is_empty() {
57 return Self::default();
58 }
59
60 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
61 if finite.is_empty() {
62 return Self::default();
63 }
64
65 let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
66 let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
67
68 let sum: f32 = finite.iter().sum();
69 let mean = sum / finite.len() as f32;
70
71 let m2: f64 = finite.iter().map(|&x| ((x - mean) as f64).powi(2)).sum();
72 let std = (m2 / finite.len() as f64).sqrt() as f32;
73
74 let histogram_bins = build_histogram(data, min, max);
75
76 Self {
77 min,
78 max,
79 mean,
80 std,
81 count: finite.len(),
82 m2,
83 histogram_bins,
84 hist_min: min,
85 hist_max: max,
86 }
87 }
88
89 pub fn update(&mut self, data: &[f32]) {
91 if data.is_empty() {
92 return;
93 }
94
95 let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
97 if finite.is_empty() {
98 return;
99 }
100
101 let data_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
102 let data_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
103
104 let new_min = self.min.min(data_min);
105 let new_max = self.max.max(data_max);
106
107 let old_count = self.count as f64;
110 let new_count = finite.len() as f64;
111 let combined_count = old_count + new_count;
112
113 let data_sum: f64 = finite.iter().map(|&x| x as f64).sum();
114 let data_mean = data_sum / new_count;
115
116 let data_m2: f64 = finite
117 .iter()
118 .map(|&x| ((x as f64) - data_mean).powi(2))
119 .sum();
120
121 let delta = data_mean - self.mean as f64;
123 self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
124
125 self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
126 self.count = combined_count as usize;
127 self.std = (self.m2 / combined_count).sqrt() as f32;
128
129 if new_min < self.hist_min || new_max > self.hist_max {
131 let mut rebinned = vec![0usize; NUM_BINS];
132 rebin(
133 &self.histogram_bins,
134 self.hist_min,
135 self.hist_max,
136 &mut rebinned,
137 new_min,
138 new_max,
139 );
140 self.histogram_bins = rebinned;
141 self.hist_min = new_min;
142 self.hist_max = new_max;
143 }
144
145 let new_hist = build_histogram(&finite, self.hist_min, self.hist_max);
147 for (i, &c) in new_hist.iter().enumerate() {
148 self.histogram_bins[i] += c;
149 }
150
151 self.min = new_min;
152 self.max = new_max;
153 }
154
155 pub fn percentile(&self, p: f32) -> f32 {
157 if self.histogram_bins.is_empty() {
158 return self.min;
159 }
160
161 let total: usize = self.histogram_bins.iter().sum();
162 if total == 0 {
163 return self.min;
164 }
165
166 let target_count = (total as f32 * p / 100.0).ceil() as usize;
169 let mut cumulative = 0;
170
171 let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
172 0.0
173 } else {
174 (self.hist_max - self.hist_min) / NUM_BINS as f32
175 };
176
177 for (i, &count) in self.histogram_bins.iter().enumerate() {
178 cumulative += count;
179 if cumulative >= target_count {
180 return self.hist_min + (i as f32 + 0.5) * bin_size;
181 }
182 }
183
184 self.max
185 }
186
187 pub fn histogram_data(&self) -> Vec<(f32, usize)> {
189 if (self.hist_max - self.hist_min).abs() < 1e-8 {
190 let total: usize = self.histogram_bins.iter().sum();
191 if total > 0 {
192 return vec![(self.hist_min, total)];
193 }
194 return Vec::new();
195 }
196 let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
197 self.histogram_bins
198 .iter()
199 .enumerate()
200 .filter(|(_, &count)| count > 0)
201 .map(|(i, &count)| {
202 let value = self.hist_min + (i as f32 + 0.5) * bin_size;
203 (value, count)
204 })
205 .collect()
206 }
207}
208
209impl Default for ActivationStats {
210 fn default() -> Self {
211 Self {
212 min: f32::INFINITY,
213 max: f32::NEG_INFINITY,
214 mean: 0.0,
215 std: 0.0,
216 count: 0,
217 m2: 0.0,
218 histogram_bins: Vec::new(),
219 hist_min: 0.0,
220 hist_max: 0.0,
221 }
222 }
223}
224
225fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
226 let mut bins = vec![0usize; NUM_BINS];
227
228 if (max - min).abs() < 1e-8 {
229 let finite_count = data.iter().filter(|v| v.is_finite()).count();
231 if !bins.is_empty() {
232 bins[0] = finite_count;
233 }
234 return bins;
235 }
236
237 let bin_size = (max - min) / NUM_BINS as f32;
238
239 for &value in data {
240 if !value.is_finite() {
241 continue;
242 }
243 let bin_idx = ((value - min) / bin_size).floor() as usize;
244 let bin_idx = bin_idx.min(NUM_BINS - 1);
245 bins[bin_idx] += 1;
246 }
247
248 bins
249}
250
/// Redistributes histogram counts from `old_bins` over `[old_min, old_max]`
/// into `new_bins` over `[new_min, new_max]`, adding to existing contents.
///
/// Each non-empty old bin is moved as a whole to the new bin containing its
/// centre. Out-of-range destinations are clamped into the last new bin. If
/// either range has (near) zero width, the whole old total is placed at the
/// midpoint of the old range. Does nothing if either slice is empty.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    if old_bins.is_empty() || new_bins.is_empty() {
        return;
    }

    let last = new_bins.len() - 1;
    let slots = new_bins.len() as f32;
    let new_range = new_max - new_min;
    // Destination bin for a value `x`, clamped to the last slot.
    let target = |x: f32| (((x - new_min) / new_range * slots).floor() as usize).min(last);

    let old_range = old_max - old_min;
    if old_range.abs() < 1e-8 || new_range.abs() < 1e-8 {
        let total: usize = old_bins.iter().sum();
        if total > 0 {
            new_bins[target((old_min + old_max) * 0.5)] += total;
        }
        return;
    }

    let old_width = old_range / old_bins.len() as f32;
    for (i, &count) in old_bins.iter().enumerate().filter(|&(_, &c)| c != 0) {
        let centre = old_min + (i as f32 + 0.5) * old_width;
        new_bins[target(centre)] += count;
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_stats() {
        let data = vec![-1.0, -0.5, 0.0, 0.5, 1.0];
        let stats = ActivationStats::from_data(&data);

        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);
        assert!((stats.mean() - 0.0).abs() < 0.01);

        let p50 = stats.percentile(50.0);
        assert!((p50 - 0.0).abs() < 0.3);
    }

    #[test]
    fn test_minmax_from_stats_matches_raw_data() {
        let data: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 500.0).collect();
        let stats = ActivationStats::from_data(&data);

        let from_stats = calculate_optimal_range_from_stats(&stats, CalibrationMethod::MinMax);
        let from_raw = calculate_optimal_range(&data, CalibrationMethod::MinMax);

        assert_eq!(from_stats.0, from_raw.0);
        assert_eq!(from_stats.1, from_raw.1);
    }

    #[test]
    fn test_percentile_from_stats_is_deterministic() {
        let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 100.0).collect();
        let stats = ActivationStats::from_data(&data);

        let r1 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Percentile(99.9));
        let r2 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Percentile(99.9));
        let r3 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Percentile(99.9));

        assert_eq!(r1, r2);
        assert_eq!(r2, r3);
    }

    #[test]
    fn test_mse_from_stats_is_deterministic() {
        let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 100.0).collect();
        let stats = ActivationStats::from_data(&data);

        let r1 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::MSE);
        let r2 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::MSE);
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_entropy_from_stats_is_deterministic() {
        let data: Vec<f32> = (0..500).map(|i| (i as f32 - 250.0) / 100.0).collect();
        let stats = ActivationStats::from_data(&data);

        let r1 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Entropy);
        let r2 = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Entropy);
        assert_eq!(r1, r2);
    }

    #[test]
    fn test_all_methods_produce_finite_ranges() {
        let data: Vec<f32> = (0..200).map(|i| (i as f32 / 50.0) - 1.0).collect();
        let stats = ActivationStats::from_data(&data);

        for method in [
            CalibrationMethod::MinMax,
            CalibrationMethod::Percentile(99.9),
            CalibrationMethod::Entropy,
            CalibrationMethod::MSE,
        ] {
            let (lo, hi) = calculate_optimal_range_from_stats(&stats, method);
            assert!(lo.is_finite(), "{:?}: lower bound not finite", method);
            assert!(hi.is_finite(), "{:?}: upper bound not finite", method);
            assert!(lo <= hi, "{:?}: lo ({}) > hi ({})", method, lo, hi);
        }
    }

    #[test]
    fn test_stats_based_matches_raw_based_on_bulk_data() {
        let data: Vec<f32> = (0..1000).map(|i| (i as f32 - 500.0) / 100.0).collect();
        let stats = ActivationStats::from_data(&data);

        let from_stats =
            calculate_optimal_range_from_stats(&stats, CalibrationMethod::Percentile(99.0));
        let from_raw = calculate_optimal_range(&data, CalibrationMethod::Percentile(99.0));

        let width = stats.max() - stats.min();
        let bin_width = width / 256.0;
        let tolerance = 3.0 * bin_width + 1e-4;
        assert!(
            (from_stats.0 - from_raw.0).abs() <= tolerance,
            "lower percentile drift: stats={} raw={} tol={}",
            from_stats.0,
            from_raw.0,
            tolerance
        );
        assert!(
            (from_stats.1 - from_raw.1).abs() <= tolerance,
            "upper percentile drift: stats={} raw={} tol={}",
            from_stats.1,
            from_raw.1,
            tolerance
        );
    }

    /// NaN and ±inf must not contribute to any summary value.
    #[test]
    fn test_non_finite_values_are_ignored() {
        let stats = ActivationStats::from_data(&[
            1.0,
            f32::NAN,
            3.0,
            f32::INFINITY,
            f32::NEG_INFINITY,
        ]);
        assert_eq!(stats.count(), 2);
        assert_eq!(stats.min(), 1.0);
        assert_eq!(stats.max(), 3.0);
        assert!((stats.mean() - 2.0).abs() < 1e-6);
    }

    /// A first `update` on a default accumulator behaves like `from_data`.
    #[test]
    fn test_update_from_default_matches_from_data() {
        let data = [1.0_f32, 2.0, 3.0];
        let mut streamed = ActivationStats::default();
        streamed.update(&data);
        let bulk = ActivationStats::from_data(&data);

        assert_eq!(streamed.count(), bulk.count());
        assert_eq!(streamed.min(), bulk.min());
        assert_eq!(streamed.max(), bulk.max());
        assert!((streamed.mean() - bulk.mean()).abs() < 1e-6);
        assert!((streamed.std() - bulk.std()).abs() < 1e-6);
    }

    /// Merging two batches (including a histogram range extension) agrees
    /// with computing the statistics over the concatenated data.
    #[test]
    fn test_update_merges_like_bulk_from_data() {
        let a: Vec<f32> = (0..300).map(|i| (i as f32 - 150.0) / 50.0).collect();
        let b: Vec<f32> = (0..200).map(|i| (i as f32) / 40.0 + 1.0).collect();

        let mut streamed = ActivationStats::from_data(&a);
        streamed.update(&b);

        let all: Vec<f32> = a.iter().chain(b.iter()).copied().collect();
        let bulk = ActivationStats::from_data(&all);

        assert_eq!(streamed.count(), bulk.count());
        assert_eq!(streamed.min(), bulk.min());
        assert_eq!(streamed.max(), bulk.max());
        assert!((streamed.mean() - bulk.mean()).abs() < 1e-4);
        assert!((streamed.std() - bulk.std()).abs() < 1e-3);
    }
}
409
410pub fn calculate_optimal_range(data: &[f32], method: CalibrationMethod) -> (f32, f32) {
412 if data.is_empty() {
413 return (0.0, 0.0);
414 }
415
416 match method {
417 CalibrationMethod::MinMax => {
418 let min = data
419 .iter()
420 .copied()
421 .filter(|v| v.is_finite())
422 .fold(f32::INFINITY, f32::min);
423 let max = data
424 .iter()
425 .copied()
426 .filter(|v| v.is_finite())
427 .fold(f32::NEG_INFINITY, f32::max);
428 (min, max)
429 }
430
431 CalibrationMethod::Percentile(p) => {
432 let stats = ActivationStats::from_data(data);
433 let lower = stats.percentile(100.0 - p);
434 let upper = stats.percentile(p);
435 (lower, upper)
436 }
437
438 CalibrationMethod::Entropy => optimize_kl_divergence(data),
439
440 CalibrationMethod::MSE => optimize_mse(data),
441 }
442}
443
444pub fn calculate_optimal_range_from_stats(
452 stats: &ActivationStats,
453 method: CalibrationMethod,
454) -> (f32, f32) {
455 match method {
456 CalibrationMethod::MinMax => (stats.min(), stats.max()),
457
458 CalibrationMethod::Percentile(p) => {
459 let lower = stats.percentile(100.0 - p);
460 let upper = stats.percentile(p);
461 (lower, upper)
462 }
463
464 CalibrationMethod::Entropy => optimize_kl_from_stats(stats),
465
466 CalibrationMethod::MSE => optimize_mse_from_stats(stats),
467 }
468}
469
470fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
472 let stats = ActivationStats::from_data(data);
473
474 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
476 let mut best_range = (stats.min, stats.max);
477 let mut best_kl = f32::INFINITY;
478
479 for &percentile in &candidates {
480 let lower = stats.percentile(100.0 - percentile);
481 let upper = stats.percentile(percentile);
482
483 let kl = calculate_kl_divergence(data, lower, upper);
484
485 if kl < best_kl {
486 best_kl = kl;
487 best_range = (lower, upper);
488 }
489 }
490
491 best_range
492}
493
494fn optimize_mse(data: &[f32]) -> (f32, f32) {
496 let stats = ActivationStats::from_data(data);
497
498 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
500 let mut best_range = (stats.min, stats.max);
501 let mut best_mse = f32::INFINITY;
502
503 for &percentile in &candidates {
504 let lower = stats.percentile(100.0 - percentile);
505 let upper = stats.percentile(percentile);
506
507 let mse = calculate_quantization_mse(data, lower, upper);
508
509 if mse < best_mse {
510 best_mse = mse;
511 best_range = (lower, upper);
512 }
513 }
514
515 best_range
516}
517
/// KL divergence between the histogram of `data` clipped to `[min, max]` and
/// the histogram of the same values after 8-bit (256-level)
/// quantize/dequantize over that range.
///
/// Both histograms use 128 equal-width bins over `[min, max]` and are smoothed
/// with a tiny epsilon so empty bins never produce `ln(0)`. Non-finite samples
/// (NaN, ±inf) are skipped, matching how `ActivationStats` treats them.
/// Returns `0.0` for a degenerate, inverted or non-finite range and when
/// `data` has no finite sample.
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    // `f32::clamp` below panics on inverted or NaN bounds, so reject those
    // together with the zero-width range.
    if !(min.is_finite() && max.is_finite()) || max - min < 1e-8 {
        return 0.0;
    }

    let num_bins: usize = 128;
    let bin_size = (max - min) / num_bins as f32;
    let scale = (max - min) / 255.0;

    let mut orig_bins = vec![0usize; num_bins];
    let mut quant_bins = vec![0usize; num_bins];
    let mut finite_count = 0usize;

    for v in data.iter().copied().filter(|v| v.is_finite()) {
        finite_count += 1;
        let clipped = v.clamp(min, max);

        let bin = (((clipped - min) / bin_size).floor() as usize).min(num_bins - 1);
        orig_bins[bin] += 1;

        // Quantize to one of 256 levels, dequantize, and bin the result.
        let q = ((clipped - min) / scale).round();
        let dequant = min + q * scale;
        let qbin =
            (((dequant.clamp(min, max) - min) / bin_size).floor() as usize).min(num_bins - 1);
        quant_bins[qbin] += 1;
    }

    if finite_count == 0 {
        return 0.0;
    }

    let n = finite_count as f32;
    let epsilon = 1e-10_f32;
    let denom = n + epsilon * num_bins as f32;
    let mut kl = 0.0_f32;

    for (&orig, &quant) in orig_bins.iter().zip(&quant_bins) {
        let p = (orig as f32 + epsilon) / denom;
        let q = (quant as f32 + epsilon) / denom;
        kl += p * (p / q).ln();
    }

    kl
}
562
/// Mean squared error of quantizing `data` to 8 bits (256 levels) over
/// `[min, max]`.
///
/// Values outside the range are clipped before quantization. The error is
/// measured against the original, unclipped value, so clipping is penalized.
/// Non-finite samples are skipped: a single NaN/inf would otherwise make every
/// candidate's MSE NaN/inf and the range search would silently fall back to
/// min/max. Accumulates in f64 like `histogram_quantization_mse`. Returns
/// `0.0` for a degenerate, inverted or non-finite range and when `data` has no
/// finite sample.
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    // `f32::clamp` panics on inverted or NaN bounds; treat those like a
    // zero-width range.
    if !(min.is_finite() && max.is_finite()) || max - min < 1e-8 {
        return 0.0;
    }

    let scale = (max - min) / 255.0;
    let mut sse = 0.0_f64;
    let mut n = 0_usize;

    for v in data.iter().copied().filter(|v| v.is_finite()) {
        let clipped = v.clamp(min, max);
        let q = ((clipped - min) / scale).round().clamp(0.0, 255.0);
        let dequantized = min + q * scale;
        let err = f64::from(v - dequantized);
        sse += err * err;
        n += 1;
    }

    if n == 0 {
        0.0
    } else {
        (sse / n as f64) as f32
    }
}
583
584fn histogram_kl_divergence(stats: &ActivationStats, min: f32, max: f32) -> f32 {
595 if (max - min).abs() < 1e-8 {
596 return 0.0;
597 }
598 let hist = stats.histogram_data();
599 if hist.is_empty() {
600 return 0.0;
601 }
602
603 const NUM_REBINS: usize = 128;
604 let rebin_size = (max - min) / NUM_REBINS as f32;
605 let scale = (max - min) / 255.0;
606
607 let mut orig = vec![0.0_f32; NUM_REBINS];
608 let mut quant = vec![0.0_f32; NUM_REBINS];
609
610 for &(center, count) in &hist {
611 let clipped = center.clamp(min, max);
612 let count_f = count as f32;
613
614 let bin = ((clipped - min) / rebin_size).floor() as usize;
615 let bin = bin.min(NUM_REBINS - 1);
616 orig[bin] += count_f;
617
618 let q = ((clipped - min) / scale).round();
619 let dq = min + q * scale;
620 let qbin = ((dq.clamp(min, max) - min) / rebin_size).floor() as usize;
621 let qbin = qbin.min(NUM_REBINS - 1);
622 quant[qbin] += count_f;
623 }
624
625 let total: f32 = orig.iter().sum();
626 if total == 0.0 {
627 return 0.0;
628 }
629
630 let epsilon = 1e-10_f32;
631 let denom = total + epsilon * NUM_REBINS as f32;
632 let mut kl = 0.0_f32;
633 for i in 0..NUM_REBINS {
634 let p = (orig[i] + epsilon) / denom;
635 let q = (quant[i] + epsilon) / denom;
636 kl += p * (p / q).ln();
637 }
638 kl
639}
640
641fn histogram_quantization_mse(stats: &ActivationStats, min: f32, max: f32) -> f32 {
644 if (max - min).abs() < 1e-8 {
645 return 0.0;
646 }
647
648 let scale = (max - min) / 255.0;
649 let mut weighted_sse = 0.0_f64;
650 let mut total_count = 0_u64;
651
652 for (center, count) in stats.histogram_data() {
653 let clipped = center.clamp(min, max);
654 let q = ((clipped - min) / scale).round().clamp(0.0, 255.0);
655 let dq = min + q * scale;
656 let err = (center - dq) as f64;
657 weighted_sse += err * err * count as f64;
658 total_count += count as u64;
659 }
660
661 if total_count == 0 {
662 0.0
663 } else {
664 (weighted_sse / total_count as f64) as f32
665 }
666}
667
668fn optimize_kl_from_stats(stats: &ActivationStats) -> (f32, f32) {
669 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
670 let mut best_range = (stats.min(), stats.max());
671 let mut best_kl = f32::INFINITY;
672
673 for &percentile in &candidates {
674 let lower = stats.percentile(100.0 - percentile);
675 let upper = stats.percentile(percentile);
676 let kl = histogram_kl_divergence(stats, lower, upper);
677 if kl < best_kl {
678 best_kl = kl;
679 best_range = (lower, upper);
680 }
681 }
682 best_range
683}
684
685fn optimize_mse_from_stats(stats: &ActivationStats) -> (f32, f32) {
686 let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
687 let mut best_range = (stats.min(), stats.max());
688 let mut best_mse = f32::INFINITY;
689
690 for &percentile in &candidates {
691 let lower = stats.percentile(100.0 - percentile);
692 let upper = stats.percentile(percentile);
693 let mse = histogram_quantization_mse(stats, lower, upper);
694 if mse < best_mse {
695 best_mse = mse;
696 best_range = (lower, upper);
697 }
698 }
699 best_range
700}