/// Running per-dimension activation statistics for a network layer.
///
/// Accumulates the sum of squares and the sum of absolute values of each
/// input dimension across samples, so per-dimension RMS norms and mean
/// absolute values can be reported at any point without storing samples.
#[derive(Debug, Clone, Default)]
pub struct LayerActivationStats {
    /// Per-dimension running sum of |activation| values.
    input_norm_sum: Vec<f32>,
    /// Per-dimension running sum of squared activation values.
    squared_sum: Vec<f32>,
    /// Number of samples folded in so far.
    count: usize,
    /// Expected length of every activation sample.
    input_dim: usize,
}
impl LayerActivationStats {
    /// Creates an empty accumulator for activation vectors of length `input_dim`.
    pub fn new(input_dim: usize) -> Self {
        Self {
            input_norm_sum: vec![0.0; input_dim],
            squared_sum: vec![0.0; input_dim],
            count: 0,
            input_dim,
        }
    }

    /// Folds a batch of activation samples into the running statistics.
    ///
    /// An empty batch is a no-op.
    ///
    /// # Panics
    /// Panics if any sample's length differs from `input_dim`.
    pub fn update(&mut self, activations: &[Vec<f32>]) {
        if activations.is_empty() {
            return;
        }
        for sample in activations {
            assert_eq!(
                sample.len(),
                self.input_dim,
                "Activation dimension mismatch: expected {}, got {}",
                self.input_dim,
                sample.len()
            );
            // Zip the accumulators with the sample so the hot loop carries no
            // per-element bounds checks (all lengths are equal: accumulators
            // by construction, the sample by the assert above).
            for ((sq, abs_sum), &val) in self
                .squared_sum
                .iter_mut()
                .zip(self.input_norm_sum.iter_mut())
                .zip(sample.iter())
            {
                *sq += val * val;
                *abs_sum += val.abs();
            }
            self.count += 1;
        }
    }

    /// Per-dimension RMS (root of the mean of squares) of all accumulated
    /// samples. Returns all zeros when no samples have been recorded.
    pub fn input_norms(&self) -> Vec<f32> {
        if self.count == 0 {
            return vec![0.0; self.input_dim];
        }
        // Hoist the cast out of the per-dimension closure.
        let n = self.count as f32;
        self.squared_sum.iter().map(|&sum| (sum / n).sqrt()).collect()
    }

    /// Per-dimension mean of |activation| over all accumulated samples.
    /// Returns all zeros when no samples have been recorded.
    pub fn mean_abs(&self) -> Vec<f32> {
        if self.count == 0 {
            return vec![0.0; self.input_dim];
        }
        let n = self.count as f32;
        self.input_norm_sum.iter().map(|&sum| sum / n).collect()
    }

    /// Number of samples accumulated since creation or the last `reset`.
    pub fn count(&self) -> usize {
        self.count
    }

    /// Expected activation dimension this accumulator was built for.
    pub fn input_dim(&self) -> usize {
        self.input_dim
    }

    /// True if no samples have been accumulated.
    pub fn is_empty(&self) -> bool {
        self.count == 0
    }

    /// Clears all accumulated statistics while keeping the configured
    /// dimension (and the existing allocations).
    pub fn reset(&mut self) {
        self.input_norm_sum.fill(0.0);
        self.squared_sum.fill(0.0);
        self.count = 0;
    }
}