Skip to main content

quantize_rs/calibration/
stats.rs

1//! Incremental activation statistics (min, max, mean, std, histogram).
2//!
3//! [`ActivationStats`] can be built from a single batch with [`from_data`](ActivationStats::from_data)
4//! and then incrementally extended with [`update`](ActivationStats::update).
5
6use crate::calibration::methods::CalibrationMethod;
7
8const NUM_BINS: usize = 256;
9
10/// Incremental activation statistics for a single layer.
11///
12/// Tracks min, max, mean, standard deviation, and a 256-bin histogram.
13/// Supports incremental updates via Chan's parallel algorithm.
14#[derive(Debug, Clone)]
15pub struct ActivationStats {
16    min: f32,
17    max: f32,
18    mean: f32,
19    std: f32,
20    count: usize,
21
22    /// Running sum of squared deviations (Welford's M2) for incremental std.
23    m2: f64,
24
25    histogram_bins: Vec<usize>,
26    hist_min: f32,
27    hist_max: f32,
28}
29
30impl ActivationStats {
31    /// Minimum observed value.
32    pub fn min(&self) -> f32 { self.min }
33    /// Maximum observed value.
34    pub fn max(&self) -> f32 { self.max }
35    /// Running mean.
36    pub fn mean(&self) -> f32 { self.mean }
37    /// Running standard deviation.
38    pub fn std(&self) -> f32 { self.std }
39    /// Number of observations.
40    pub fn count(&self) -> usize { self.count }
41}
42
43impl ActivationStats {
44    /// Create stats from a single batch of observations.
45    pub fn from_data(data: &[f32]) -> Self {
46        if data.is_empty() {
47            return Self::default();
48        }
49
50        let finite: Vec<f32> = data.iter().copied().filter(|v| v.is_finite()).collect();
51        if finite.is_empty() {
52            return Self::default();
53        }
54
55        let min = finite.iter().copied().fold(f32::INFINITY, f32::min);
56        let max = finite.iter().copied().fold(f32::NEG_INFINITY, f32::max);
57
58        let sum: f32 = finite.iter().sum();
59        let mean = sum / finite.len() as f32;
60
61        let m2: f64 = finite.iter()
62            .map(|&x| ((x - mean) as f64).powi(2))
63            .sum();
64        let std = (m2 / finite.len() as f64).sqrt() as f32;
65
66        let histogram_bins = build_histogram(data, min, max);
67
68        Self {
69            min,
70            max,
71            mean,
72            std,
73            count: finite.len(),
74            m2,
75            histogram_bins,
76            hist_min: min,
77            hist_max: max,
78        }
79    }
80
81    /// Incrementally merge a new batch of observations into the stats.
82    pub fn update(&mut self, data: &[f32]) {
83        if data.is_empty() {
84            return;
85        }
86
87        let data_min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
88        let data_max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
89
90        let new_min = self.min.min(data_min);
91        let new_max = self.max.max(data_max);
92
93        // Parallel/batch variant of Welford's online algorithm:
94        // Merge two populations (existing stats + new batch) into combined stats.
95        let old_count = self.count as f64;
96        let new_count = data.len() as f64;
97        let combined_count = old_count + new_count;
98
99        let data_sum: f64 = data.iter().map(|&x| x as f64).sum();
100        let data_mean = data_sum / new_count;
101
102        let data_m2: f64 = data.iter()
103            .map(|&x| ((x as f64) - data_mean).powi(2))
104            .sum();
105
106        // Chan's parallel algorithm for combining M2 values
107        let delta = data_mean - self.mean as f64;
108        self.m2 = self.m2 + data_m2 + delta * delta * old_count * new_count / combined_count;
109
110        self.mean = ((self.mean as f64) * old_count + data_sum) as f32 / combined_count as f32;
111        self.count = combined_count as usize;
112        self.std = (self.m2 / combined_count).sqrt() as f32;
113
114        // If range expanded, re-bin existing data into the new range
115        if new_min < self.hist_min || new_max > self.hist_max {
116            let mut rebinned = vec![0usize; NUM_BINS];
117            rebin(&self.histogram_bins, self.hist_min, self.hist_max, &mut rebinned, new_min, new_max);
118            self.histogram_bins = rebinned;
119            self.hist_min = new_min;
120            self.hist_max = new_max;
121        }
122
123        // Add new data into bins
124        let new_hist = build_histogram(data, self.hist_min, self.hist_max);
125        for (i, &c) in new_hist.iter().enumerate() {
126            self.histogram_bins[i] += c;
127        }
128
129        self.min = new_min;
130        self.max = new_max;
131    }
132
133    /// Estimate the value at percentile `p` (0--100) from the histogram.
134    pub fn percentile(&self, p: f32) -> f32 {
135        if self.histogram_bins.is_empty() {
136            return self.min;
137        }
138
139        let total: usize = self.histogram_bins.iter().sum();
140        if total == 0 {
141            return self.min;
142        }
143
144        // ceil, not truncation: for 5 elements at p=50, target rank must be 3
145        // (the actual median), not 2 (which would return the element below it).
146        let target_count = (total as f32 * p / 100.0).ceil() as usize;
147        let mut cumulative = 0;
148
149        let bin_size = if (self.hist_max - self.hist_min).abs() < 1e-8 {
150            0.0
151        } else {
152            (self.hist_max - self.hist_min) / NUM_BINS as f32
153        };
154
155        for (i, &count) in self.histogram_bins.iter().enumerate() {
156            cumulative += count;
157            if cumulative >= target_count {
158                return self.hist_min + (i as f32 + 0.5) * bin_size;
159            }
160        }
161
162        self.max
163    }
164
165    /// Return histogram data as (bin_center, count) pairs.
166    pub fn histogram_data(&self) -> Vec<(f32, usize)> {
167        if (self.hist_max - self.hist_min).abs() < 1e-8 {
168            let total: usize = self.histogram_bins.iter().sum();
169            if total > 0 {
170                return vec![(self.hist_min, total)];
171            }
172            return Vec::new();
173        }
174        let bin_size = (self.hist_max - self.hist_min) / NUM_BINS as f32;
175        self.histogram_bins.iter()
176            .enumerate()
177            .filter(|(_, &count)| count > 0)
178            .map(|(i, &count)| {
179                let value = self.hist_min + (i as f32 + 0.5) * bin_size;
180                (value, count)
181            })
182            .collect()
183    }
184}
185
186impl Default for ActivationStats {
187    fn default() -> Self {
188        Self {
189            min: f32::INFINITY,
190            max: f32::NEG_INFINITY,
191            mean: 0.0,
192            std: 0.0,
193            count: 0,
194            m2: 0.0,
195            histogram_bins: Vec::new(),
196            hist_min: 0.0,
197            hist_max: 0.0,
198        }
199    }
200}
201
202fn build_histogram(data: &[f32], min: f32, max: f32) -> Vec<usize> {
203    let mut bins = vec![0usize; NUM_BINS];
204
205    if (max - min).abs() < 1e-8 {
206        // All values map to a single bin
207        let finite_count = data.iter().filter(|v| v.is_finite()).count();
208        if !bins.is_empty() {
209            bins[0] = finite_count;
210        }
211        return bins;
212    }
213
214    let bin_size = (max - min) / NUM_BINS as f32;
215
216    for &value in data {
217        if !value.is_finite() { continue; }
218        let bin_idx = ((value - min) / bin_size).floor() as usize;
219        let bin_idx = bin_idx.min(NUM_BINS - 1);
220        bins[bin_idx] += 1;
221    }
222
223    bins
224}
225
/// Re-bin histogram data from one range to another.
///
/// Each non-empty old bin moves whole into the new bin containing its centre.
/// Counts are *added* to `new_bins`. If either range is (nearly) degenerate,
/// the entire old total lands in the new bin holding the old range's midpoint.
/// Does nothing if either slice is empty.
fn rebin(
    old_bins: &[usize],
    old_min: f32,
    old_max: f32,
    new_bins: &mut [usize],
    new_min: f32,
    new_max: f32,
) {
    let slots = new_bins.len();
    if old_bins.is_empty() || slots == 0 {
        return;
    }

    let old_span = old_max - old_min;
    let new_span = new_max - new_min;

    // Index of the new bin holding `value`; below-range values saturate to 0
    // in the cast, above-range values are clamped to the last bin.
    let slot_for = |value: f32| -> usize {
        let raw = ((value - new_min) / new_span * slots as f32).floor() as usize;
        raw.min(slots - 1)
    };

    let degenerate = old_span.abs() < 1e-8 || new_span.abs() < 1e-8;
    if degenerate {
        let total: usize = old_bins.iter().sum();
        if total > 0 {
            new_bins[slot_for((old_min + old_max) * 0.5)] += total;
        }
        return;
    }

    let old_width = old_span / old_bins.len() as f32;
    old_bins
        .iter()
        .enumerate()
        .filter(|&(_, &c)| c > 0)
        .for_each(|(i, &c)| {
            let centre = old_min + (i as f32 + 0.5) * old_width;
            new_bins[slot_for(centre)] += c;
        });
}
261
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_activation_stats() {
        let data = [-1.0, -0.5, 0.0, 0.5, 1.0];
        let stats = ActivationStats::from_data(&data);

        assert_eq!(stats.min(), -1.0);
        assert_eq!(stats.max(), 1.0);
        assert!((stats.mean() - 0.0).abs() < 0.01);

        let p50 = stats.percentile(50.0);
        assert!((p50 - 0.0).abs() < 0.3);
    }

    /// Merging two batches with `update` must give the same moments as a
    /// single `from_data` over the concatenated data (Chan's algorithm).
    #[test]
    fn test_update_matches_single_batch() {
        let mut incremental = ActivationStats::from_data(&[1.0, 2.0, 3.0]);
        incremental.update(&[4.0, 5.0]);
        let whole = ActivationStats::from_data(&[1.0, 2.0, 3.0, 4.0, 5.0]);

        assert_eq!(incremental.count(), whole.count());
        assert_eq!(incremental.min(), whole.min());
        assert_eq!(incremental.max(), whole.max());
        assert!((incremental.mean() - whole.mean()).abs() < 1e-6);
        assert!((incremental.std() - whole.std()).abs() < 1e-6);
    }

    #[test]
    fn test_from_data_ignores_non_finite() {
        let stats = ActivationStats::from_data(&[f32::NAN, 1.0, f32::INFINITY, 3.0]);

        assert_eq!(stats.count(), 2);
        assert_eq!(stats.min(), 1.0);
        assert_eq!(stats.max(), 3.0);
        assert!((stats.mean() - 2.0).abs() < 1e-6);
    }
}
279
280/// Compute the optimal quantization range for `data` using the given method.
281pub fn calculate_optimal_range(
282    data: &[f32],
283    method: CalibrationMethod,
284) -> (f32, f32) {
285    if data.is_empty() {
286        return (0.0, 0.0);
287    }
288
289    match method {
290        CalibrationMethod::MinMax => {
291            let min = data.iter().copied().filter(|v| v.is_finite()).fold(f32::INFINITY, f32::min);
292            let max = data.iter().copied().filter(|v| v.is_finite()).fold(f32::NEG_INFINITY, f32::max);
293            (min, max)
294        }
295
296        CalibrationMethod::Percentile(p) => {
297            let stats = ActivationStats::from_data(data);
298            let lower = stats.percentile(100.0 - p);
299            let upper = stats.percentile(p);
300            (lower, upper)
301        }
302
303        CalibrationMethod::Entropy => {
304            optimize_kl_divergence(data)
305        }
306
307        CalibrationMethod::MSE => {
308            optimize_mse(data)
309        }
310    }
311}
312
313/// Optimize range using KL divergence (entropy method)
314fn optimize_kl_divergence(data: &[f32]) -> (f32, f32) {
315    let stats = ActivationStats::from_data(data);
316
317    // Try different percentile thresholds and find the one with minimum KL divergence
318    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
319    let mut best_range = (stats.min, stats.max);
320    let mut best_kl = f32::INFINITY;
321
322    for &percentile in &candidates {
323        let lower = stats.percentile(100.0 - percentile);
324        let upper = stats.percentile(percentile);
325
326        let kl = calculate_kl_divergence(data, lower, upper);
327
328        if kl < best_kl {
329            best_kl = kl;
330            best_range = (lower, upper);
331        }
332    }
333
334    best_range
335}
336
337/// Optimize range using MSE minimization
338fn optimize_mse(data: &[f32]) -> (f32, f32) {
339    let stats = ActivationStats::from_data(data);
340
341    // Try different percentile thresholds and find the one with minimum MSE
342    let candidates = [99.0, 99.5, 99.9, 99.95, 99.99];
343    let mut best_range = (stats.min, stats.max);
344    let mut best_mse = f32::INFINITY;
345
346    for &percentile in &candidates {
347        let lower = stats.percentile(100.0 - percentile);
348        let upper = stats.percentile(percentile);
349
350        let mse = calculate_quantization_mse(data, lower, upper);
351
352        if mse < best_mse {
353            best_mse = mse;
354            best_range = (lower, upper);
355        }
356    }
357
358    best_range
359}
360
/// Calculate KL divergence between original and quantized distribution.
///
/// Both distributions are 128-bin histograms over `[min, max]`, built from
/// the finite values of `data` after clipping them to that range. The
/// "quantized" histogram bins each value after a simulated 256-level
/// quantize -> dequantize. Uses dense, aligned bins so every bin index in the
/// original histogram maps to the same value range in the quantized histogram.
///
/// Non-finite values in `data` are skipped. Returns:
/// - `0.0` for a (near-)degenerate range, or when `data` has no finite values;
/// - `f32::INFINITY` for an invalid range (a non-finite bound, or `max < min`),
///   so callers never prefer such a range.
fn calculate_kl_divergence(data: &[f32], min: f32, max: f32) -> f32 {
    if (max - min).abs() < 1e-8 {
        return 0.0;
    }
    // `f32::clamp` below panics on a NaN bound or `min > max`.
    if !min.is_finite() || !max.is_finite() || max < min {
        return f32::INFINITY;
    }

    let num_bins = 128usize;
    let bin_size = (max - min) / num_bins as f32;
    let scale = (max - min) / 255.0;
    let bin_of = |x: f32| (((x - min) / bin_size).floor() as usize).min(num_bins - 1);

    let mut orig_bins = vec![0usize; num_bins];
    let mut quant_bins = vec![0usize; num_bins];
    let mut n = 0usize;

    // Skip NaN/inf, as `ActivationStats` does. Otherwise NaN would be
    // silently counted in bin 0 and +/-inf piled onto the range edges.
    for v in data.iter().copied().filter(|v| v.is_finite()) {
        let clipped = v.clamp(min, max);

        // Original bin
        orig_bins[bin_of(clipped)] += 1;

        // Simulated INT8 quantize -> dequantize, then bin
        let q = ((clipped - min) / scale).round();
        let dequant = min + q * scale;
        quant_bins[bin_of(dequant.clamp(min, max))] += 1;

        n += 1;
    }

    if n == 0 {
        return 0.0;
    }

    // Epsilon smoothing keeps both distributions strictly positive, so the
    // log ratio is always defined.
    let epsilon = 1e-10_f32;
    let denom = n as f32 + epsilon * num_bins as f32;
    let mut kl = 0.0_f32;
    for (&o, &qc) in orig_bins.iter().zip(&quant_bins) {
        let p = (o as f32 + epsilon) / denom;
        let q = (qc as f32 + epsilon) / denom;
        kl += p * (p / q).ln();
    }

    kl
}
405
/// Mean squared error of simulated 8-bit quantization of `data` over `[min, max]`.
///
/// Each finite value is clipped to `[min, max]`, quantized to one of 256
/// levels, and dequantized. The result is compared with the *unclipped*
/// value, so the score includes both rounding and clipping error.
///
/// Non-finite values are skipped. Returns:
/// - `0.0` for a (near-)degenerate range, or when there are no finite values;
/// - `f32::INFINITY` for an invalid range (a non-finite bound, or `max < min`).
fn calculate_quantization_mse(data: &[f32], min: f32, max: f32) -> f32 {
    if (max - min).abs() < 1e-8 {
        return 0.0;
    }
    // `f32::clamp` below panics on a NaN bound or `min > max`.
    if !min.is_finite() || !max.is_finite() || max < min {
        return f32::INFINITY;
    }

    let scale = (max - min) / 255.0;

    // A single NaN would otherwise turn the whole mean into NaN, and NaN
    // never wins the `<` comparison in the range search.
    let (sum, count) = data
        .iter()
        .copied()
        .filter(|v| v.is_finite())
        .fold((0.0_f32, 0usize), |(sum, count), v| {
            let clipped = v.clamp(min, max);
            let q = ((clipped - min) / scale).round();
            let dequantized = min + q * scale;
            (sum + (v - dequantized).powi(2), count + 1)
        });

    if count == 0 {
        0.0
    } else {
        sum / count as f32
    }
}