entrenar/prune/calibrate/stats.rs
1//! Per-layer activation statistics for calibration.
2
/// Per-layer activation statistics collected during calibration.
///
/// For every input channel this keeps a running sum of squared activations
/// (`x²`) and a running sum of absolute activations (`|x|`), plus the number
/// of samples seen. Per-channel means are derived on demand by dividing these
/// sums by the sample count; there is no incremental mean/variance recurrence.
///
/// NOTE: sums are stored as `f32`, so very long calibration runs accumulate
/// ordinary floating-point rounding error.
#[derive(Debug, Clone, Default)]
pub struct LayerActivationStats {
    /// Running sum of absolute activations (`|x|`) per input channel.
    input_norm_sum: Vec<f32>,
    /// Running sum of squared activations (`x²`) per input channel.
    squared_sum: Vec<f32>,
    /// Number of samples (rows) processed since creation or the last reset.
    count: usize,
    /// Input feature dimension; every sample fed to the tracker must have this length.
    input_dim: usize,
}
17
18impl LayerActivationStats {
19 /// Create new statistics tracker for a given input dimension.
20 pub fn new(input_dim: usize) -> Self {
21 Self {
22 input_norm_sum: vec![0.0; input_dim],
23 squared_sum: vec![0.0; input_dim],
24 count: 0,
25 input_dim,
26 }
27 }
28
29 /// Update statistics with a new batch of activations.
30 ///
31 /// # Arguments
32 ///
33 /// * `activations` - Batch of activations [batch_size, input_dim]
34 ///
35 /// # Panics
36 ///
37 /// Panics if activation dimensions don't match.
38 pub fn update(&mut self, activations: &[Vec<f32>]) {
39 contract_pre_update!();
40 if activations.is_empty() {
41 return;
42 }
43
44 for sample in activations {
45 assert_eq!(
46 sample.len(),
47 self.input_dim,
48 "Activation dimension mismatch: expected {}, got {}",
49 self.input_dim,
50 sample.len()
51 );
52
53 for (i, &val) in sample.iter().enumerate() {
54 // Accumulate squared values for L2 norm computation
55 self.squared_sum[i] += val * val;
56 self.input_norm_sum[i] += val.abs();
57 }
58 self.count += 1;
59 }
60 }
61
62 /// Get the mean L2 norm for each input channel.
63 ///
64 /// Returns sqrt(mean(x^2)) for each channel.
65 pub fn input_norms(&self) -> Vec<f32> {
66 if self.count == 0 {
67 return vec![0.0; self.input_dim];
68 }
69
70 self.squared_sum.iter().map(|&sum| (sum / self.count as f32).sqrt()).collect()
71 }
72
73 /// Get the mean absolute value for each input channel.
74 pub fn mean_abs(&self) -> Vec<f32> {
75 if self.count == 0 {
76 return vec![0.0; self.input_dim];
77 }
78
79 self.input_norm_sum.iter().map(|&sum| sum / self.count as f32).collect()
80 }
81
82 /// Get the number of samples processed.
83 pub fn count(&self) -> usize {
84 self.count
85 }
86
87 /// Get the input dimension.
88 pub fn input_dim(&self) -> usize {
89 self.input_dim
90 }
91
92 /// Check if any statistics have been collected.
93 pub fn is_empty(&self) -> bool {
94 self.count == 0
95 }
96
97 /// Reset all statistics.
98 pub fn reset(&mut self) {
99 self.input_norm_sum.fill(0.0);
100 self.squared_sum.fill(0.0);
101 self.count = 0;
102 }
103}