Skip to main content

quantize_rs/calibration/
stats.rs

1//! Incremental activation statistics (min, max, mean, std, histogram).
2//!
3//! [`ActivationStats`] can be built from a single batch with [`from_data`](ActivationStats::from_data)
4//! and then incrementally extended with [`update`](ActivationStats::update).
5
6use crate::calibration::methods::CalibrationMethod;
7
/// Number of equal-width bins in every activation histogram (2^8).
const NUM_BINS: usize = 1 << 8;
9
/// Incremental activation statistics for a single layer.
///
/// Keeps the extrema, mean, population standard deviation and a 256-bin
/// histogram of every finite value observed. New batches are folded in with
/// Chan's parallel variance-combination algorithm.
#[derive(Clone, Debug)]
pub struct ActivationStats {
    /// Smallest finite value observed.
    min: f32,
    /// Largest finite value observed.
    max: f32,
    /// Mean of all finite observations.
    mean: f32,
    /// Population standard deviation, i.e. `sqrt(m2 / count)`.
    std: f32,
    /// Number of finite observations folded in so far.
    count: usize,

    /// Sum of squared deviations from the mean (Welford/Chan "M2"),
    /// kept in f64 so incremental updates do not drift.
    m2: f64,

    /// Per-bin counts over `[hist_min, hist_max]`; empty until data arrives.
    histogram_bins: Vec<usize>,
    /// Lower edge of the value range covered by `histogram_bins`.
    hist_min: f32,
    /// Upper edge of the value range covered by `histogram_bins`.
    hist_max: f32,
}
29
30impl ActivationStats {
31    /// Minimum observed value.
32    pub fn min(&self) -> f32 {
33        self.min
34    }
35    /// Maximum observed value.
36    pub fn max(&self) -> f32 {
37        self.max
38    }
39    /// Running mean.
40    pub fn mean(&self) -> f32 {
41        self.mean
42    }
43    /// Running standard deviation.
44    pub fn std(&self) -> f32 {
45        self.std
46    }
47    /// Number of observations.
48    pub fn count(&self) -> usize {
49        self.count
50    }
51}
52
53impl ActivationStats {
54    /// Create stats from a single batch of observations.
55    pub fn from_data(data: &[f32]) -> Self {
56        if data.is_empty() {
57            return Self::default();
58        }
59
60        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
61        if finite.is_empty() {
62            return Self::default();
63        }
64
65        let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
66        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
67
68        let sum: f32 = finite.iter().sum();
69        let mean = sum / finite.len() as f32;
70
71        let m2: f64 = finite.iter().map(|&x| ((x - mean) as f64).powi(2)).sum();
72        let std = (m2 / finite.len() as f64).sqrt() as f32;
73
74        let histogram_bins = build_histogram(data, min, max);
75
76        Self {
77            min,
78            max,
79            mean,
80            std,
81            count: finite.len(),
82            m2,
83            histogram_bins,
84            hist_min: min,
85            hist_max: max,
86        }
87    }
88
89    /// Incrementally merge a new batch of observations into the stats.
90    pub fn update(&mut self, data: &[f32]) {
91        if data.is_empty() {
92            return;
93        }
94
95        // Only consider finite values — skip batches that are entirely NaN/Inf
96        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
97        if finite.is_empty() {
98            return;
99        }
100
101        let data_min = finite.iter().copied().fold(f32::INFINITY, f32::min);
102        let data_max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
103
104        let new_min = self.min.min(data_min);
105        let new_max = self.max.max(data_max);
106
107        // Parallel/batch variant of Welford's online algorithm:
108        // Merge two populations (existing stats + new batch) into combined stats.
109        let old_count = self.count as f64;
110        let new_count = finite.len() as f64;
111        let combined_count = old_count + new_count;
112
113        let data_sum: f64 = finite.iter().map(|&x| x as f64).sum();
114        let data_mean = data_sum / new_count;
115
116        let data_m2: f64 = finite
117            .iter()
118            .map(|&x| ((x as f64) - data_mean).powi(2))
119            .sum();
120
121        // Chan's parallel algorithm for combining M2 values
122        let delta = data_mean - self.mean as f64;
123        self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
124
125        self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
126        self.count = combined_count as usize;
127        self.std = (self.m2 / combined_count).sqrt() as f32;
128
129        // If range expanded, re-bin existing data into the new range
130        if new_min < self.hist_min || new_max > self.hist_max {
131            let mut rebinned = vec![0usize; NUM_BINS];
132            rebin(
133                &self.histogram_bins,
134                self.hist_min,
135                self.hist_max,
136                &mut rebinned,
137                new_min,
138                new_max,
139            );
140            self.histogram_bins = rebinned;
141            self.hist_min = new_min;
142            self.hist_max = new_max;
143        }
144
145        // Add new data into bins (build_histogram already filters NaN/Inf internally)
146        let new_hist = build_histogram(&finite, self.hist_min, self.hist_max);
147        for (i, &c) in new_hist.iter().enumerate() {
148            self.histogram_bins[i] += c;
149        }
150
151        self.min = new_min;
152        self.max = new_max;
153    }
154
155    /// Estimate the value at percentile `p` (0--100) from the histogram.
156    pub fn percentile(&self, p: f32) -> f32 {
157        if self.histogram_bins.is_empty() {
158            return self.min;
159        }
160
161        let total: usize = self.histogram_bins.iter().sum();
162        if total == 0 {
163            return self.min;
164        }
165
166        // ceil, not truncation: for 5 elements at p=50, target rank must be 3
167        // (the actual median), not 2 (which would return the element below it).
168        let target_count = (total as f32 * p / 100.0).ceil() as usize;
169        let mut cumulative = 0;
170
171        let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
172            0.0
173        } else {
174            (self.hist_max - self.hist_min) / NUM_BINS as f32
175        };
176
177        for (i, &count) in self.histogram_bins.iter().enumerate() {
178            cumulative += count;
179            if cumulative >= target_count {
180                return self.hist_min + (i as f32 + 0.5) * bin_size;
181            }
182        }
183
184        self.max
185    }
186
187    /// Return histogram data as (bin_center, count) pairs.
188    pub fn histogram_data(&self) -> Vec<(f32, usize)> {
189        if (self.hist_max - self.hist_min).abs() < 1e-8 {
190            let total: usize = self.histogram_bins.iter().sum();
191            if total > 0 {
192                return vec![(self.hist_min, total)];
193            }
194            return Vec::new();
195        }
196        let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
197        self.histogram_bins
198            .iter()
199            .enumerate()
200            .filter(|(_, &count)| count > 0)
201            .map(|(i, &count)| {
202                let value = self.hist_min + (i as f32 + 0.5) * bin_size;
203                (value, count)
204            })
205            .collect()
206    }
207}
208
209impl Default for ActivationStats {
210    fn default() -> Self {
211        Self {
212            min: f32::INFINITY,
213            max: f32::NEG_INFINITY,
214            mean: 0.0,
215            std: 0.0,
216            count: 0,
217            m2: 0.0,
218            histogram_bins: Vec::new(),
219            hist_min: 0.0,
220            hist_max: 0.0,
221        }
222    }
223}
224
225fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
226    let mut bins = vec![0usize; NUM_BINS];
227
228    if (max - min).abs() < 1e-8 {
229        // All values map to a single bin
230        let finite_count = data.iter().filter(|v| v.is_finite()).count();
231        if !bins.is_empty() {
232            bins[0] = finite_count;
233        }
234        return bins;
235    }
236
237    let bin_size = (max - min) / NUM_BINS as f32;
238
239    for &value in data {
240        if !value.is_finite() {
241            continue;
242        }
243        let bin_idx = ((value - min) / bin_size).floor() as usize;
244        let bin_idx = bin_idx.min(NUM_BINS - 1);
245        bins[bin_idx] += 1;
246    }
247
248    bins
249}
250
/// Redistribute a histogram over `[old_min, old_max]` into `new_bins`, which spans
/// `[new_min, new_max]`.
///
/// Each old bin's count moves wholesale to the new bin that contains the old bin's
/// centre. Negative positions saturate to the first bin, and positions past the end
/// clamp to the last. Counts are *added* into `new_bins`, so callers normally pass
/// it zeroed. If either span is (near) zero, the whole total goes to the slot of the
/// old span's midpoint. Empty inputs leave `new_bins` untouched.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    let slots = new_bins.len();
    if old_bins.is_empty() || slots == 0 {
        return;
    }

    let old_span = old_max - old_min;
    let new_span = new_max - new_min;
    // Destination slot for a value in the new range.
    let slot_of = |value: f32| -> usize {
        (((value - new_min) / new_span * slots as f32).floor() as usize).min(slots - 1)
    };

    if old_span.abs() < 1e-8 || new_span.abs() < 1e-8 {
        let total: usize = old_bins.iter().sum();
        if total > 0 {
            new_bins[slot_of((old_min + old_max) * 0.5)] += total;
        }
        return;
    }

    let old_width = old_span / old_bins.len() as f32;
    for (i, &count) in old_bins.iter().enumerate().filter(|&(_, &c)| c > 0) {
        let centre = old_min + (i as f32 + 0.5) * old_width;
        new_bins[slot_of(centre)] += count;
    }
}
288
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_stats() {
        let stats = ActivationStats::from_data(&[-1.0, -0.5, 0.0, 0.5, 1.0]);

        // Extrema are tracked exactly.
        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);

        // Symmetric data: mean and histogram median both sit near zero.
        assert!(stats.mean().abs() < 0.01);
        let median = stats.percentile(50.0);
        assert!(median.abs() < 0.3);
    }
}
306
307/// Compute the optimal quantization range for `data` using the given method.
308pub fn calculate_optimal_range(data: &[f32], method: CalibrationMethod) -> (f32, f32) {
309    if data.is_empty() {
310        return (0.0, 0.0);
311    }
312
313    match method {
314        CalibrationMethod::MinMax => {
315            let min = data
316                .iter()
317                .copied()
318                .filter(|v| v.is_finite())
319                .fold(f32::INFINITY, f32::min);
320            let max = data
321                .iter()
322                .copied()
323                .filter(|v| v.is_finite())
324                .fold(f32::NEG_INFINITY, f32::max);
325            (min, max)
326        }
327
328        CalibrationMethod::Percentile(p) => {
329            let stats = ActivationStats::from_data(data);
330            let lower = stats.percentile(100.0 - p);
331            let upper = stats.percentile(p);
332            (lower, upper)
333        }
334
335        CalibrationMethod::Entropy => optimize_kl_divergence(data),
336
337        CalibrationMethod::MSE => optimize_mse(data),
338    }
339}
340
341/// Optimize range using KL divergence (entropy method)
342fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
343    let stats = ActivationStats::from_data(data);
344
345    // Try different percentile thresholds and find the one with minimum KL divergence
346    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
347    let mut best_range = (stats.min, stats.max);
348    let mut best_kl = f32::INFINITY;
349
350    for &percentile in &candidates {
351        let lower = stats.percentile(100.0 - percentile);
352        let upper = stats.percentile(percentile);
353
354        let kl = calculate_kl_divergence(data, lower, upper);
355
356        if kl < best_kl {
357            best_kl = kl;
358            best_range = (lower, upper);
359        }
360    }
361
362    best_range
363}
364
365/// Optimize range using MSE minimization
366fn optimize_mse(data: &[f32]) -> (f32, f32) {
367    let stats = ActivationStats::from_data(data);
368
369    // Try different percentile thresholds and find the one with minimum MSE
370    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
371    let mut best_range = (stats.min, stats.max);
372    let mut best_mse = f32::INFINITY;
373
374    for &percentile in &candidates {
375        let lower = stats.percentile(100.0 - percentile);
376        let upper = stats.percentile(percentile);
377
378        let mse = calculate_quantization_mse(data, lower, upper);
379
380        if mse < best_mse {
381            best_mse = mse;
382            best_range = (lower, upper);
383        }
384    }
385
386    best_range
387}
388
/// Calculate KL divergence between original and quantized distribution.
///
/// Uses dense, aligned bins so every bin index in the original histogram
/// maps to the same value range in the quantized histogram. Values are clipped
/// to `[min, max]` before binning, and non-finite values are ignored.
///
/// Returns `0.0` for a zero-width, inverted, or NaN range. `f32::clamp` would
/// panic on an inverted or NaN range.
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    let span = max - min;
    if span.is_nan() || span < 1e-8 {
        return 0.0;
    }

    let num_bins = 128;
    let bin_size = span / num_bins as f32;
    let scale = span / 255.0;

    let mut orig_bins = vec![0usize; num_bins];
    let mut quant_bins = vec![0usize; num_bins];
    let mut counted = 0usize;

    // Skip non-finite values. Otherwise NaN would fall into bin 0 (NaN casts to
    // 0) and ±Inf would be clamped onto the range edges, distorting both
    // histograms.
    for v in data.iter().copied().filter(|v| v.is_finite()) {
        let clipped = v.clamp(min, max);

        // Original bin
        let bin = (((clipped - min) / bin_size).floor() as usize).min(num_bins - 1);
        orig_bins[bin] += 1;

        // Simulated INT8 quantize -> dequantize, then bin
        let q = ((clipped - min) / scale).round();
        let dequant = min + q * scale;
        let qbin =
            (((dequant.clamp(min, max) - min) / bin_size).floor() as usize).min(num_bins - 1);
        quant_bins[qbin] += 1;

        counted += 1;
    }

    // Smooth both distributions with epsilon so empty bins do not produce ln(0).
    let epsilon = 1e-10_f32;
    let denom = counted as f32 + epsilon * num_bins as f32;

    orig_bins
        .iter()
        .zip(&quant_bins)
        .map(|(&o, &qc)| {
            let p = (o as f32 + epsilon) / denom;
            let q = (qc as f32 + epsilon) / denom;
            p * (p / q).ln()
        })
        .sum()
}
433
/// Mean squared error between `data` and its simulated INT8 (256-level)
/// quantize -> dequantize round trip over `[min, max]`.
///
/// Values outside the range are clipped before quantizing, so clipping error is
/// part of the score. Non-finite values are ignored. With no finite values the
/// result is `0.0`.
///
/// A zero-width (or inverted / NaN) range can represent only one level, so every
/// value is scored as if it dequantized to `min`. This stops a collapsed range
/// from looking like a perfect, zero-error fit.
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    let span = max - min;
    let collapsed = span.is_nan() || span < 1e-8;
    let scale = span / 255.0;

    // Accumulate in f64 so large inputs do not lose precision in the sum.
    let (sum_sq, n) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((0.0_f64, 0_usize), |(acc, n), v| {
            let dequantized = if collapsed {
                min
            } else {
                let clipped = v.clamp(min, max);
                let q = ((clipped - min) / scale).round().clamp(0.0, 255.0);
                min + q * scale
            };
            let err = f64::from(v - dequantized);
            (acc + err * err, n + 1)
        });

    if n == 0 {
        0.0
    } else {
        (sum_sq / n as f64) as f32
    }
}