Skip to main content

entrenar/prune/calibrate/
stats.rs

//! Per-layer activation statistics for calibration.
2
/// Per-layer activation statistics collected during calibration.
///
/// Keeps plain running sums per input channel: the sum of `|x|` and the sum of
/// `x^2`. It does not maintain a running mean/variance. Per-channel means are
/// derived on demand by dividing these sums by `count`.
///
/// All accumulation is done in `f32`. Very large sample counts or very large
/// activations can therefore lose precision.
#[derive(Debug, Clone, Default)]
pub struct LayerActivationStats {
    /// Running sum of absolute activation values per input channel (sum of `|x|`).
    ///
    /// Despite the field name, this is not a sum of L2 norms. It backs the
    /// mean-absolute-value statistic.
    input_norm_sum: Vec<f32>,
    /// Running sum of squared activations per input channel (sum of `x^2`).
    squared_sum: Vec<f32>,
    /// Number of samples (rows) accumulated so far.
    count: usize,
    /// Input feature dimension: the expected length of every sample and of
    /// both per-channel sum vectors.
    input_dim: usize,
}
17
18impl LayerActivationStats {
19    /// Create new statistics tracker for a given input dimension.
20    pub fn new(input_dim: usize) -> Self {
21        Self {
22            input_norm_sum: vec![0.0; input_dim],
23            squared_sum: vec![0.0; input_dim],
24            count: 0,
25            input_dim,
26        }
27    }
28
29    /// Update statistics with a new batch of activations.
30    ///
31    /// # Arguments
32    ///
33    /// * `activations` - Batch of activations [batch_size, input_dim]
34    ///
35    /// # Panics
36    ///
37    /// Panics if activation dimensions don't match.
38    pub fn update(&mut self, activations: &[Vec<f32>]) {
39        contract_pre_update!();
40        if activations.is_empty() {
41            return;
42        }
43
44        for sample in activations {
45            assert_eq!(
46                sample.len(),
47                self.input_dim,
48                "Activation dimension mismatch: expected {}, got {}",
49                self.input_dim,
50                sample.len()
51            );
52
53            for (i, &val) in sample.iter().enumerate() {
54                // Accumulate squared values for L2 norm computation
55                self.squared_sum[i] += val * val;
56                self.input_norm_sum[i] += val.abs();
57            }
58            self.count += 1;
59        }
60    }
61
62    /// Get the mean L2 norm for each input channel.
63    ///
64    /// Returns sqrt(mean(x^2)) for each channel.
65    pub fn input_norms(&self) -> Vec<f32> {
66        if self.count == 0 {
67            return vec![0.0; self.input_dim];
68        }
69
70        self.squared_sum.iter().map(|&sum| (sum / self.count as f32).sqrt()).collect()
71    }
72
73    /// Get the mean absolute value for each input channel.
74    pub fn mean_abs(&self) -> Vec<f32> {
75        if self.count == 0 {
76            return vec![0.0; self.input_dim];
77        }
78
79        self.input_norm_sum.iter().map(|&sum| sum / self.count as f32).collect()
80    }
81
82    /// Get the number of samples processed.
83    pub fn count(&self) -> usize {
84        self.count
85    }
86
87    /// Get the input dimension.
88    pub fn input_dim(&self) -> usize {
89        self.input_dim
90    }
91
92    /// Check if any statistics have been collected.
93    pub fn is_empty(&self) -> bool {
94        self.count == 0
95    }
96
97    /// Reset all statistics.
98    pub fn reset(&mut self) {
99        self.input_norm_sum.fill(0.0);
100        self.squared_sum.fill(0.0);
101        self.count = 0;
102    }
103}