Skip to main content

quantize_rs/calibration/
stats.rs

1//! Incremental activation statistics (min, max, mean, std, histogram).
2//!
3//! [`ActivationStats`] can be built from a single batch with [`from_data`](ActivationStats::from_data)
4//! and then incrementally extended with [`update`](ActivationStats::update).
5
6use crate::calibration::methods::CalibrationMethod;
7
/// Number of equal-width bins in every [`ActivationStats`] histogram (256).
const NUM_BINS: usize = 1 << 8;

/// Incremental activation statistics for a single layer.
///
/// Tracks min, max, mean, standard deviation, and a 256-bin histogram.
/// Supports incremental updates via Chan's parallel algorithm.
#[derive(Debug, Clone)]
pub struct ActivationStats {
    /// Smallest value folded in so far.
    min: f32,
    /// Largest value folded in so far.
    max: f32,
    /// Running mean of the folded-in values.
    mean: f32,
    /// Population standard deviation, `sqrt(m2 / count)`.
    std: f32,
    /// Number of values folded into the statistics.
    count: usize,

    /// Running sum of squared deviations (Welford's M2) for incremental std.
    m2: f64,

    /// Per-bin counts covering `[hist_min, hist_max]` in equal-width bins.
    histogram_bins: Vec<usize>,
    /// Lower edge of the histogram range.
    hist_min: f32,
    /// Upper edge of the histogram range.
    hist_max: f32,
}

impl ActivationStats {
    /// Smallest observed value.
    pub fn min(&self) -> f32 {
        self.min
    }

    /// Largest observed value.
    pub fn max(&self) -> f32 {
        self.max
    }

    /// Mean of all observations so far.
    pub fn mean(&self) -> f32 {
        self.mean
    }

    /// Population standard deviation of all observations so far.
    pub fn std(&self) -> f32 {
        self.std
    }

    /// How many observations have been folded in.
    pub fn count(&self) -> usize {
        self.count
    }
}
52
53impl ActivationStats {
54    /// Create stats from a single batch of observations.
55    pub fn from_data(data: &[f32]) -> Self {
56        if data.is_empty() {
57            return Self::default();
58        }
59
60        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
61        if finite.is_empty() {
62            return Self::default();
63        }
64
65        let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
66        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
67
68        let sum: f32 = finite.iter().sum();
69        let mean = sum / finite.len() as f32;
70
71        let m2: f64 = finite.iter().map(|&x| ((x - mean) as f64).powi(2)).sum();
72        let std = (m2 / finite.len() as f64).sqrt() as f32;
73
74        let histogram_bins = build_histogram(data, min, max);
75
76        Self {
77            min,
78            max,
79            mean,
80            std,
81            count: finite.len(),
82            m2,
83            histogram_bins,
84            hist_min: min,
85            hist_max: max,
86        }
87    }
88
89    /// Incrementally merge a new batch of observations into the stats.
90    pub fn update(&mut self, data: &[f32]) {
91        if data.is_empty() {
92            return;
93        }
94
95        // Only consider finite values — skip batches that are entirely NaN/Inf
96        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
97        if finite.is_empty() {
98            return;
99        }
100
101        let data_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
102        let data_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
103
104        let new_min = self.min.min(data_min);
105        let new_max = self.max.max(data_max);
106
107        // Parallel/batch variant of Welford's online algorithm:
108        // Merge two populations (existing stats + new batch) into combined stats.
109        let old_count = self.count as f64;
110        let new_count = finite.len() as f64;
111        let combined_count = old_count + new_count;
112
113        let data_sum: f64 = finite.iter().map(|&x| x as f64).sum();
114        let data_mean = data_sum / new_count;
115
116        let data_m2: f64 = finite
117            .iter()
118            .map(|&x| ((x as f64) - data_mean).powi(2))
119            .sum();
120
121        // Chan's parallel algorithm for combining M2 values
122        let delta = data_mean - self.mean as f64;
123        self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
124
125        self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
126        self.count = combined_count as usize;
127        self.std = (self.m2 / combined_count).sqrt() as f32;
128
129        // If range expanded, re-bin existing data into the new range
130        if new_min < self.hist_min || new_max > self.hist_max {
131            let mut rebinned = vec![0usize; NUM_BINS];
132            rebin(
133                &self.histogram_bins,
134                self.hist_min,
135                self.hist_max,
136                &mut rebinned,
137                new_min,
138                new_max,
139            );
140            self.histogram_bins = rebinned;
141            self.hist_min = new_min;
142            self.hist_max = new_max;
143        }
144
145        // Add new data into bins (build_histogram already filters NaN/Inf internally)
146        let new_hist = build_histogram(&finite, self.hist_min, self.hist_max);
147        for (i, &c) in new_hist.iter().enumerate() {
148            self.histogram_bins[i] += c;
149        }
150
151        self.min = new_min;
152        self.max = new_max;
153    }
154
155    /// Estimate the value at percentile `p` (0--100) from the histogram.
156    pub fn percentile(&self, p: f32) -> f32 {
157        if self.histogram_bins.is_empty() {
158            return self.min;
159        }
160
161        let total: usize = self.histogram_bins.iter().sum();
162        if total == 0 {
163            return self.min;
164        }
165
166        // ceil, not truncation: for 5 elements at p=50, target rank must be 3
167        // (the actual median), not 2 (which would return the element below it).
168        let target_count = (total as f32 * p / 100.0).ceil() as usize;
169        let mut cumulative = 0;
170
171        let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
172            0.0
173        } else {
174            (self.hist_max - self.hist_min) / NUM_BINS as f32
175        };
176
177        for (i, &count) in self.histogram_bins.iter().enumerate() {
178            cumulative += count;
179            if cumulative >= target_count {
180                return self.hist_min + (i as f32 + 0.5) * bin_size;
181            }
182        }
183
184        self.max
185    }
186
187    /// Return histogram data as (bin_center, count) pairs.
188    pub fn histogram_data(&self) -> Vec<(f32, usize)> {
189        if (self.hist_max - self.hist_min).abs() < 1e-8 {
190            let total: usize = self.histogram_bins.iter().sum();
191            if total > 0 {
192                return vec![(self.hist_min, total)];
193            }
194            return Vec::new();
195        }
196        let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
197        self.histogram_bins
198            .iter()
199            .enumerate()
200            .filter(|(_, &count)| count > 0)
201            .map(|(i, &count)| {
202                let value = self.hist_min + (i as f32 + 0.5) * bin_size;
203                (value, count)
204            })
205            .collect()
206    }
207}
208
209impl Default for ActivationStats {
210    fn default() -> Self {
211        Self {
212            min: f32::INFINITY,
213            max: f32::NEG_INFINITY,
214            mean: 0.0,
215            std: 0.0,
216            count: 0,
217            m2: 0.0,
218            histogram_bins: Vec::new(),
219            hist_min: 0.0,
220            hist_max: 0.0,
221        }
222    }
223}
224
225fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
226    let mut bins = vec![0usize; NUM_BINS];
227
228    if (max - min).abs() < 1e-8 {
229        // All values map to a single bin
230        let finite_count = data.iter().filter(|v| v.is_finite()).count();
231        if !bins.is_empty() {
232            bins[0] = finite_count;
233        }
234        return bins;
235    }
236
237    let bin_size = (max - min) / NUM_BINS as f32;
238
239    for &value in data {
240        if !value.is_finite() {
241            continue;
242        }
243        let bin_idx = ((value - min) / bin_size).floor() as usize;
244        let bin_idx = bin_idx.min(NUM_BINS - 1);
245        bins[bin_idx] += 1;
246    }
247
248    bins
249}
250
/// Re-bin histogram data from one range to another.
///
/// Each non-empty bin of `old_bins` (spanning `[old_min, old_max]`) is moved,
/// by its centre, into the matching bin of `new_bins` (spanning
/// `[new_min, new_max]`). Counts are *added* to whatever `new_bins` already
/// holds, and indices are clamped to the last bin. If either range is
/// degenerate, the whole old mass goes into the bin nearest the old range's
/// midpoint.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    let slots = new_bins.len();
    if old_bins.is_empty() || slots == 0 {
        return;
    }

    let old_span = old_max - old_min;
    let new_span = new_max - new_min;
    // Target-bin index for a value; the float→usize cast saturates at 0.
    let target_index = |value: f32| -> usize {
        let idx = ((value - new_min) / new_span * slots as f32).floor() as usize;
        idx.min(slots - 1)
    };

    if old_span.abs() < 1e-8 || new_span.abs() < 1e-8 {
        let total: usize = old_bins.iter().sum();
        if total > 0 {
            new_bins[target_index((old_min + old_max) * 0.5)] += total;
        }
        return;
    }

    let old_width = old_span / old_bins.len() as f32;
    for (i, &count) in old_bins.iter().enumerate().filter(|&(_, &c)| c > 0) {
        let center = old_min + (i as f32 + 0.5) * old_width;
        new_bins[target_index(center)] += count;
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;

    /// `len` evenly spaced samples `(i - offset) / divisor` for `i = 0..len`.
    fn ramp(len: usize, offset: f32, divisor: f32) -> Vec<f32> {
        (0..len).map(|i| (i as f32 - offset) / divisor).collect()
    }

    #[test]
    fn test_activation_stats() {
        let stats = ActivationStats::from_data(&[-1.0, -0.5, 0.0, 0.5, 1.0]);

        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);
        assert!(stats.mean().abs() < 0.01);

        let median = stats.percentile(50.0);
        assert!(median.abs() < 0.3);
    }

    // -----------------------------------------------------------------------
    // Histogram-direct range optimization
    // -----------------------------------------------------------------------

    #[test]
    fn test_minmax_from_stats_matches_raw_data() {
        let data = ramp(1000, 500.0, 500.0);
        let stats = ActivationStats::from_data(&data);

        let (stats_lo, stats_hi) =
            calculate_optimal_range_from_stats(&stats, CalibrationMethod::MinMax);
        let (raw_lo, raw_hi) = calculate_optimal_range(&data, CalibrationMethod::MinMax);

        // The MinMax path must agree bit-for-bit.
        assert_eq!(stats_lo, raw_lo);
        assert_eq!(stats_hi, raw_hi);
    }

    #[test]
    fn test_percentile_from_stats_is_deterministic() {
        // Identical stats must give an identical range on every call.  The
        // raw-data path once regenerated samples with a thread-local RNG,
        // which made its results unstable.
        let stats = ActivationStats::from_data(&ramp(500, 250.0, 100.0));
        let method = CalibrationMethod::Percentile(99.9);

        let first = calculate_optimal_range_from_stats(&stats, method);
        for _ in 0..2 {
            assert_eq!(calculate_optimal_range_from_stats(&stats, method), first);
        }
    }

    #[test]
    fn test_mse_from_stats_is_deterministic() {
        let stats = ActivationStats::from_data(&ramp(500, 250.0, 100.0));

        let first = calculate_optimal_range_from_stats(&stats, CalibrationMethod::MSE);
        let second = calculate_optimal_range_from_stats(&stats, CalibrationMethod::MSE);
        assert_eq!(first, second);
    }

    #[test]
    fn test_entropy_from_stats_is_deterministic() {
        let stats = ActivationStats::from_data(&ramp(500, 250.0, 100.0));

        let first = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Entropy);
        let second = calculate_optimal_range_from_stats(&stats, CalibrationMethod::Entropy);
        assert_eq!(first, second);
    }

    #[test]
    fn test_all_methods_produce_finite_ranges() {
        // Regression guard: the histogram-direct optimizers must never yield
        // NaN/Inf for reasonable input, skewed data included.
        let data: Vec<f32> = (0..200).map(|i| (i as f32 / 50.0) - 1.0).collect();
        let stats = ActivationStats::from_data(&data);

        let methods = [
            CalibrationMethod::MinMax,
            CalibrationMethod::Percentile(99.9),
            CalibrationMethod::Entropy,
            CalibrationMethod::MSE,
        ];
        for method in methods {
            let (lo, hi) = calculate_optimal_range_from_stats(&stats, method);
            assert!(lo.is_finite(), "{:?}: lower bound not finite", method);
            assert!(hi.is_finite(), "{:?}: upper bound not finite", method);
            assert!(lo <= hi, "{:?}: lo ({}) > hi ({})", method, lo, hi);
        }
    }

    #[test]
    fn test_stats_based_matches_raw_based_on_bulk_data() {
        // With a well-populated 256-bin histogram the stats-based and
        // raw-based percentile paths should agree to within a few bin widths
        // (the tolerance below allows three, plus a small absolute slack).
        let data = ramp(1000, 500.0, 100.0);
        let stats = ActivationStats::from_data(&data);

        let from_stats =
            calculate_optimal_range_from_stats(&stats, CalibrationMethod::Percentile(99.0));
        let from_raw = calculate_optimal_range(&data, CalibrationMethod::Percentile(99.0));

        let bin_width = (stats.max() - stats.min()) / 256.0;
        let tolerance = 3.0 * bin_width + 1e-4;

        assert!(
            (from_stats.0 - from_raw.0).abs() <= tolerance,
            "lower percentile drift: stats={} raw={} tol={}",
            from_stats.0,
            from_raw.0,
            tolerance
        );
        assert!(
            (from_stats.1 - from_raw.1).abs() <= tolerance,
            "upper percentile drift: stats={} raw={} tol={}",
            from_stats.1,
            from_raw.1,
            tolerance
        );
    }
}
409
410/// Compute the optimal quantization range for `data` using the given method.
411pub fn calculate_optimal_range(data: &[f32], method: CalibrationMethod) -> (f32, f32) {
412    if data.is_empty() {
413        return (0.0, 0.0);
414    }
415
416    match method {
417        CalibrationMethod::MinMax => {
418            let min = data
419                .iter()
420                .copied()
421                .filter(|v| v.is_finite())
422                .fold(f32::INFINITY, f32::min);
423            let max = data
424                .iter()
425                .copied()
426                .filter(|v| v.is_finite())
427                .fold(f32::NEG_INFINITY, f32::max);
428            (min, max)
429        }
430
431        CalibrationMethod::Percentile(p) => {
432            let stats = ActivationStats::from_data(data);
433            let lower = stats.percentile(100.0 - p);
434            let upper = stats.percentile(p);
435            (lower, upper)
436        }
437
438        CalibrationMethod::Entropy => optimize_kl_divergence(data),
439
440        CalibrationMethod::MSE => optimize_mse(data),
441    }
442}
443
444/// Compute the optimal quantization range directly from pre-collected
445/// [`ActivationStats`], without regenerating samples from the histogram.
446///
447/// This is the preferred path inside `Quantizer::with_calibration`: the stats
448/// already carry the full empirical distribution (min/max + 256-bin histogram),
449/// so there is no benefit to re-sampling and re-binning.  It's also
450/// deterministic (no RNG) and O(num_bins) instead of O(num_samples).
451pub fn calculate_optimal_range_from_stats(
452    stats: &ActivationStats,
453    method: CalibrationMethod,
454) -> (f32, f32) {
455    match method {
456        CalibrationMethod::MinMax => (stats.min(), stats.max()),
457
458        CalibrationMethod::Percentile(p) => {
459            let lower = stats.percentile(100.0 - p);
460            let upper = stats.percentile(p);
461            (lower, upper)
462        }
463
464        CalibrationMethod::Entropy => optimize_kl_from_stats(stats),
465
466        CalibrationMethod::MSE => optimize_mse_from_stats(stats),
467    }
468}
469
470/// Optimize range using KL divergence (entropy method)
471fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
472    let stats = ActivationStats::from_data(data);
473
474    // Try different percentile thresholds and find the one with minimum KL divergence
475    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
476    let mut best_range = (stats.min, stats.max);
477    let mut best_kl = f32::INFINITY;
478
479    for &percentile in &candidates {
480        let lower = stats.percentile(100.0 - percentile);
481        let upper = stats.percentile(percentile);
482
483        let kl = calculate_kl_divergence(data, lower, upper);
484
485        if kl < best_kl {
486            best_kl = kl;
487            best_range = (lower, upper);
488        }
489    }
490
491    best_range
492}
493
494/// Optimize range using MSE minimization
495fn optimize_mse(data: &[f32]) -> (f32, f32) {
496    let stats = ActivationStats::from_data(data);
497
498    // Try different percentile thresholds and find the one with minimum MSE
499    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
500    let mut best_range = (stats.min, stats.max);
501    let mut best_mse = f32::INFINITY;
502
503    for &percentile in &candidates {
504        let lower = stats.percentile(100.0 - percentile);
505        let upper = stats.percentile(percentile);
506
507        let mse = calculate_quantization_mse(data, lower, upper);
508
509        if mse < best_mse {
510            best_mse = mse;
511            best_range = (lower, upper);
512        }
513    }
514
515    best_range
516}
517
/// Calculate KL divergence between original and quantized distribution.
///
/// Uses dense, aligned bins so every bin index in the original histogram
/// maps to the same value range in the quantized histogram. Values are
/// clipped to `[min, max]` before binning. NaN/±Inf values are skipped, as
/// [`ActivationStats`] skips them. A degenerate range, or input with no
/// finite value, scores `0.0`.
///
/// # Panics
/// If a finite value is present, panics when `min > max` or either bound is
/// NaN (via `f32::clamp`).
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    const NUM_KL_BINS: usize = 128;

    if (max - min).abs() < 1e-8 {
        return 0.0;
    }

    let bin_size = (max - min) / NUM_KL_BINS as f32;
    let scale = (max - min) / 255.0;
    let to_bin = |v: f32| (((v - min) / bin_size).floor() as usize).min(NUM_KL_BINS - 1);

    let mut orig_bins = [0usize; NUM_KL_BINS];
    let mut quant_bins = [0usize; NUM_KL_BINS];
    let mut n = 0usize;

    // Previously a NaN was clamped to NaN and cast to bin 0, and it also
    // inflated the normalizer `n`; skip non-finite values instead.
    for v in data.iter().copied().filter(|v| v.is_finite()) {
        let clipped = v.clamp(min, max);
        orig_bins[to_bin(clipped)] += 1;

        // Simulated INT8 quantize -> dequantize, then bin.
        let q = ((clipped - min) / scale).round();
        let dequant = min + q * scale;
        quant_bins[to_bin(dequant.clamp(min, max))] += 1;

        n += 1;
    }

    if n == 0 {
        return 0.0;
    }

    let epsilon = 1e-10_f32;
    let denom = n as f32 + epsilon * NUM_KL_BINS as f32;
    let mut kl = 0.0_f32;
    for (&orig, &quant) in orig_bins.iter().zip(quant_bins.iter()) {
        let p = (orig as f32 + epsilon) / denom;
        let q = (quant as f32 + epsilon) / denom;
        kl += p * (p / q).ln();
    }
    kl
}
562
/// Mean squared error of simulated INT8 quantization of `data` over `[min, max]`.
///
/// Each value is clipped to `[min, max]`, quantized to one of 256 levels and
/// dequantized. The error is measured against the *unclipped* value, so
/// clipping loss counts. NaN/±Inf values are skipped, consistent with
/// [`ActivationStats`]. A degenerate range, or input with no finite value,
/// yields `0.0`.
///
/// # Panics
/// If a finite value is present, panics when `min > max` or either bound is
/// NaN (via `f32::clamp`).
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    if (max - min).abs() < 1e-8 {
        return 0.0;
    }

    let scale = (max - min) / 255.0;
    let (sse, n) = data
        .iter()
        .copied()
        // A single NaN/Inf used to make every candidate's MSE non-finite, so
        // `optimize_mse` always fell back to the full (min, max) range.
        .filter(|v| v.is_finite())
        .fold((0.0_f64, 0usize), |(sse, n), v| {
            let clipped = v.clamp(min, max);
            let q = ((clipped - min) / scale).round().clamp(0.0, 255.0);
            let err = f64::from(v - (min + q * scale));
            // Accumulate in f64, like `histogram_quantization_mse`.
            (sse + err * err, n + 1)
        });

    if n == 0 {
        // Also avoids the former 0/0 = NaN on empty input.
        0.0
    } else {
        (sse / n as f64) as f32
    }
}
583
584// ---------------------------------------------------------------------------
585// Histogram-direct range optimization
586//
587// The functions below walk the 256-bin histogram carried by `ActivationStats`
588// instead of reconstructing samples.  They are deterministic, RNG-free, and
589// O(candidates × num_bins) in work — independent of the original dataset size.
590// ---------------------------------------------------------------------------
591
592/// KL divergence between the empirical histogram and a simulated INT8
593/// quantize → dequantize of that histogram, restricted to `[min, max]`.
594fn histogram_kl_divergence(stats: &ActivationStats, min: f32, max: f32) -> f32 {
595    if (max - min).abs() < 1e-8 {
596        return 0.0;
597    }
598    let hist = stats.histogram_data();
599    if hist.is_empty() {
600        return 0.0;
601    }
602
603    const NUM_REBINS: usize = 128;
604    let rebin_size = (max - min) / NUM_REBINS as f32;
605    let scale = (max - min) / 255.0;
606
607    let mut orig = vec![0.0_f32; NUM_REBINS];
608    let mut quant = vec![0.0_f32; NUM_REBINS];
609
610    for &(center, count) in &hist {
611        let clipped = center.clamp(min, max);
612        let count_f = count as f32;
613
614        let bin = ((clipped - min) / rebin_size).floor() as usize;
615        let bin = bin.min(NUM_REBINS - 1);
616        orig[bin] += count_f;
617
618        let q = ((clipped - min) / scale).round();
619        let dq = min + q * scale;
620        let qbin = ((dq.clamp(min, max) - min) / rebin_size).floor() as usize;
621        let qbin = qbin.min(NUM_REBINS - 1);
622        quant[qbin] += count_f;
623    }
624
625    let total: f32 = orig.iter().sum();
626    if total == 0.0 {
627        return 0.0;
628    }
629
630    let epsilon = 1e-10_f32;
631    let denom = total + epsilon * NUM_REBINS as f32;
632    let mut kl = 0.0_f32;
633    for i in 0..NUM_REBINS {
634        let p = (orig[i] + epsilon) / denom;
635        let q = (quant[i] + epsilon) / denom;
636        kl += p * (p / q).ln();
637    }
638    kl
639}
640
641/// Quantization MSE computed directly on the histogram: sum of
642/// `(center - dequantize(quantize(center)))² × count` weighted by count.
643fn histogram_quantization_mse(stats: &ActivationStats, min: f32, max: f32) -> f32 {
644    if (max - min).abs() < 1e-8 {
645        return 0.0;
646    }
647
648    let scale = (max - min) / 255.0;
649    let mut weighted_sse = 0.0_f64;
650    let mut total_count = 0_u64;
651
652    for (center, count) in stats.histogram_data() {
653        let clipped = center.clamp(min, max);
654        let q = ((clipped - min) / scale).round().clamp(0.0, 255.0);
655        let dq = min + q * scale;
656        let err = (center - dq) as f64;
657        weighted_sse += err * err * count as f64;
658        total_count += count as u64;
659    }
660
661    if total_count == 0 {
662        0.0
663    } else {
664        (weighted_sse / total_count as f64) as f32
665    }
666}
667
668fn optimize_kl_from_stats(stats: &ActivationStats) -> (f32, f32) {
669    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
670    let mut best_range = (stats.min(), stats.max());
671    let mut best_kl = f32::INFINITY;
672
673    for &percentile in &candidates {
674        let lower = stats.percentile(100.0 - percentile);
675        let upper = stats.percentile(percentile);
676        let kl = histogram_kl_divergence(stats, lower, upper);
677        if kl < best_kl {
678            best_kl = kl;
679            best_range = (lower, upper);
680        }
681    }
682    best_range
683}
684
685fn optimize_mse_from_stats(stats: &ActivationStats) -> (f32, f32) {
686    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
687    let mut best_range = (stats.min(), stats.max());
688    let mut best_mse = f32::INFINITY;
689
690    for &percentile in &candidates {
691        let lower = stats.percentile(100.0 - percentile);
692        let upper = stats.percentile(percentile);
693        let mse = histogram_quantization_mse(stats, lower, upper);
694        if mse < best_mse {
695            best_mse = mse;
696            best_range = (lower, upper);
697        }
698    }
699    best_range
700}