/// Streaming feature standardizer based on Welford's online algorithm.
///
/// Maintains a running mean and sum of squared deviations per feature so
/// inputs can be mapped to z-scores (`(x - mean) / std_dev`) without
/// retaining past samples. Variance is the population variance (divide
/// by `count`), and a small `variance_floor` keeps zero-variance
/// features from producing NaN/inf during transformation.
#[derive(Clone, Debug)]
pub struct IncrementalNormalizer {
    // Running mean of each feature.
    means: Vec<f64>,
    // Running sum of squared deviations from the mean (Welford's M2).
    m2s: Vec<f64>,
    // Number of samples folded in so far.
    count: u64,
    // Feature dimensionality; `None` until the first `update` when
    // constructed via `new`.
    n_features: Option<usize>,
    // Added to the variance before the square root so standardization
    // of constant features stays finite.
    variance_floor: f64,
}
impl IncrementalNormalizer {
    /// Variance floor used by `new` and `with_n_features`.
    const DEFAULT_VARIANCE_FLOOR: f64 = 1e-12;

    /// Creates a normalizer whose dimensionality is inferred from the
    /// first call to [`update`](Self::update).
    pub fn new() -> Self {
        Self {
            means: Vec::new(),
            m2s: Vec::new(),
            count: 0,
            n_features: None,
            variance_floor: Self::DEFAULT_VARIANCE_FLOOR,
        }
    }

    /// Creates a normalizer pre-allocated for exactly `n` features.
    pub fn with_n_features(n: usize) -> Self {
        // Same as `with_variance_floor` with the default floor.
        Self::with_variance_floor(n, Self::DEFAULT_VARIANCE_FLOOR)
    }

    /// Creates a normalizer for `n` features with a custom variance floor.
    ///
    /// # Panics
    /// Panics if `floor` is negative or non-finite; either would make
    /// every subsequent `transform` produce NaN.
    pub fn with_variance_floor(n: usize, floor: f64) -> Self {
        assert!(
            floor.is_finite() && floor >= 0.0,
            "IncrementalNormalizer: variance floor must be finite and non-negative, got {}",
            floor
        );
        Self {
            means: vec![0.0; n],
            m2s: vec![0.0; n],
            count: 0,
            n_features: Some(n),
            variance_floor: floor,
        }
    }

    /// Folds one sample into the running statistics (Welford update).
    ///
    /// On the first call to a lazily-constructed normalizer this fixes
    /// the feature dimensionality.
    ///
    /// # Panics
    /// Panics if `features.len()` disagrees with the established
    /// dimensionality.
    pub fn update(&mut self, features: &[f64]) {
        match self.n_features {
            None => {
                // Lazy initialization: adopt the first sample's width.
                let n = features.len();
                self.means = vec![0.0; n];
                self.m2s = vec![0.0; n];
                self.n_features = Some(n);
            }
            Some(_) => {
                // Shared dimension check (same panic message as before).
                self.check_features_len(features.len());
            }
        }
        self.count += 1;
        let count_f = self.count as f64;
        for (&x, (mean, m2)) in features
            .iter()
            .zip(self.means.iter_mut().zip(self.m2s.iter_mut()))
        {
            // Welford: `delta` uses the old mean, the second factor the
            // updated mean; their product accumulates into M2.
            let delta = x - *mean;
            *mean += delta / count_f;
            *m2 += delta * (x - *mean);
        }
    }

    /// Returns the z-scores of `features` under the current statistics.
    ///
    /// # Panics
    /// Panics if no samples have been observed yet or if the length of
    /// `features` disagrees with the established dimensionality.
    pub fn transform(&self, features: &[f64]) -> Vec<f64> {
        self.check_features_len(features.len());
        assert!(
            self.count > 0,
            "IncrementalNormalizer: cannot transform before any updates"
        );
        let count_f = self.count as f64;
        features
            .iter()
            .zip(self.means.iter().zip(self.m2s.iter()))
            .map(|(&x, (&mean, &m2))| {
                let var = m2 / count_f;
                (x - mean) / (var + self.variance_floor).sqrt()
            })
            .collect()
    }

    /// Standardizes `features` in place; same contract as [`transform`](Self::transform).
    ///
    /// # Panics
    /// Panics if no samples have been observed yet or if the length of
    /// `features` disagrees with the established dimensionality.
    pub fn transform_in_place(&self, features: &mut [f64]) {
        self.check_features_len(features.len());
        assert!(
            self.count > 0,
            "IncrementalNormalizer: cannot transform before any updates"
        );
        let count_f = self.count as f64;
        for (x, (&mean, &m2)) in features
            .iter_mut()
            .zip(self.means.iter().zip(self.m2s.iter()))
        {
            let var = m2 / count_f;
            *x = (*x - mean) / (var + self.variance_floor).sqrt();
        }
    }

    /// Updates the statistics with `features`, then standardizes the
    /// same sample using the post-update statistics.
    pub fn update_and_transform(&mut self, features: &[f64]) -> Vec<f64> {
        self.update(features);
        self.transform(features)
    }

    /// Running mean of feature `idx`.
    ///
    /// # Panics
    /// Panics if `idx` is out of range.
    pub fn mean(&self, idx: usize) -> f64 {
        self.means[idx]
    }

    /// Population variance of feature `idx` (0.0 before any updates).
    ///
    /// # Panics
    /// Panics if `idx` is out of range and at least one update occurred.
    pub fn variance(&self, idx: usize) -> f64 {
        if self.count == 0 {
            return 0.0;
        }
        self.m2s[idx] / self.count as f64
    }

    /// Population standard deviation of feature `idx`.
    pub fn std_dev(&self, idx: usize) -> f64 {
        self.variance(idx).sqrt()
    }

    /// Configured feature dimensionality, or `None` if not yet fixed.
    pub fn n_features(&self) -> Option<usize> {
        self.n_features
    }

    /// Number of samples observed so far.
    pub fn count(&self) -> u64 {
        self.count
    }

    /// Clears all statistics while keeping the dimensionality (and the
    /// variance floor) intact.
    pub fn reset(&mut self) {
        self.count = 0;
        self.means.fill(0.0);
        self.m2s.fill(0.0);
    }

    /// Validates a caller-provided feature count against the configured
    /// dimensionality and returns it.
    ///
    /// # Panics
    /// Panics when the dimensionality is unset or `len` disagrees with it.
    fn check_features_len(&self, len: usize) -> usize {
        let n = self
            .n_features
            .expect("IncrementalNormalizer: n_features not set (call update first)");
        assert_eq!(
            len, n,
            "IncrementalNormalizer: expected {} features, got {}",
            n, len
        );
        n
    }
}
impl Default for IncrementalNormalizer {
fn default() -> Self {
Self::new()
}
}
/// Pipeline adapter: each trait method forwards to the inherent
/// `IncrementalNormalizer` implementation of the same name.
impl crate::pipeline::StreamingPreprocessor for IncrementalNormalizer {
    fn transform(&self, features: &[f64]) -> Vec<f64> {
        // Fully qualified so the inherent method is called, not this one.
        IncrementalNormalizer::transform(self, features)
    }

    fn update_and_transform(&mut self, features: &[f64]) -> Vec<f64> {
        IncrementalNormalizer::update_and_transform(self, features)
    }

    fn output_dim(&self) -> Option<usize> {
        // The normalizer's feature count is its output dimensionality.
        IncrementalNormalizer::n_features(self)
    }

    fn reset(&mut self) {
        IncrementalNormalizer::reset(self);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for exact floating-point expectations.
    const EPS: f64 = 1e-9;

    #[test]
    fn lazy_init_sets_n_features_from_first_update() {
        let mut normalizer = IncrementalNormalizer::new();
        // Dimensionality is unknown until the first sample arrives.
        assert_eq!(normalizer.n_features(), None);
        normalizer.update(&[1.0, 2.0, 3.0]);
        assert_eq!(normalizer.n_features(), Some(3));
        assert_eq!(normalizer.count(), 1);
    }

    #[test]
    fn explicit_init_pre_allocates() {
        let normalizer = IncrementalNormalizer::with_n_features(4);
        assert_eq!(normalizer.n_features(), Some(4));
        assert_eq!(normalizer.count(), 0);
    }

    #[test]
    fn single_feature_mean_is_correct() {
        let samples = [2.0, 4.0, 6.0, 8.0, 10.0];
        let mut normalizer = IncrementalNormalizer::new();
        samples.iter().for_each(|&s| normalizer.update(&[s]));
        let expected_mean = samples.iter().sum::<f64>() / samples.len() as f64;
        assert!((normalizer.mean(0) - expected_mean).abs() < EPS);
    }

    #[test]
    fn single_feature_variance_is_correct() {
        let samples = [2.0, 4.0, 6.0, 8.0, 10.0];
        let mut normalizer = IncrementalNormalizer::new();
        for &s in samples.iter() {
            normalizer.update(&[s]);
        }
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        // Population variance (divide by n), matching `variance()`.
        let expected_var = samples.iter().map(|x| (x - mean) * (x - mean)).sum::<f64>() / n;
        assert!((normalizer.variance(0) - expected_var).abs() < EPS);
    }

    #[test]
    fn transform_produces_standardized_output() {
        let mut normalizer = IncrementalNormalizer::new();
        for i in 0..100 {
            let x = i as f64;
            normalizer.update(&[x, 2.0 * x + 50.0]);
        }
        // The mean itself must map to (approximately) zero in z-space.
        let z = normalizer.transform(&[normalizer.mean(0), normalizer.mean(1)]);
        assert!(z[0].abs() < EPS, "z[0] = {}", z[0]);
        assert!(z[1].abs() < EPS, "z[1] = {}", z[1]);
        // A point one standard deviation above the mean maps to ~1.
        let one_sd = [
            normalizer.mean(0) + normalizer.std_dev(0),
            normalizer.mean(1) + normalizer.std_dev(1),
        ];
        let z1 = normalizer.transform(&one_sd);
        assert!((z1[0] - 1.0).abs() < 0.01, "z1[0] = {}", z1[0]);
        assert!((z1[1] - 1.0).abs() < 0.01, "z1[1] = {}", z1[1]);
    }

    #[test]
    fn variance_floor_prevents_nan() {
        let mut normalizer = IncrementalNormalizer::new();
        normalizer.update(&[42.0]);
        // One sample means zero variance; the floor keeps z finite.
        let z = normalizer.transform(&[42.0]);
        assert!(z[0].is_finite(), "z[0] should be finite, got {}", z[0]);
    }

    #[test]
    fn update_and_transform_matches_separate_calls() {
        let samples = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]];
        let mut combined = IncrementalNormalizer::new();
        let mut separate = IncrementalNormalizer::new();
        for row in &samples {
            let z_combined = combined.update_and_transform(row);
            separate.update(row);
            let z_separate = separate.transform(row);
            for (a, b) in z_combined.iter().zip(z_separate.iter()) {
                assert!((a - b).abs() < EPS, "mismatch: {} vs {}", a, b);
            }
        }
    }

    #[test]
    fn reset_clears_statistics() {
        let mut normalizer = IncrementalNormalizer::with_n_features(2);
        normalizer.update(&[10.0, 20.0]);
        normalizer.update(&[30.0, 40.0]);
        assert_eq!(normalizer.count(), 2);
        normalizer.reset();
        assert_eq!(normalizer.count(), 0);
        assert!(normalizer.mean(0).abs() < EPS);
        assert!(normalizer.variance(0).abs() < EPS);
        // Dimensionality survives a reset.
        assert_eq!(normalizer.n_features(), Some(2));
    }

    #[test]
    fn transform_in_place_matches_transform() {
        let rows = [[1.0, 5.0], [2.0, 6.0], [3.0, 7.0], [4.0, 8.0]];
        let mut normalizer = IncrementalNormalizer::new();
        for row in &rows {
            normalizer.update(row);
        }
        let input = [2.5, 6.5];
        let allocated = normalizer.transform(&input);
        let mut in_place = input;
        normalizer.transform_in_place(&mut in_place);
        for (a, b) in allocated.iter().zip(in_place.iter()) {
            assert!((a - b).abs() < EPS, "mismatch: {} vs {}", a, b);
        }
    }

    #[test]
    #[should_panic(expected = "expected 2 features, got 3")]
    fn panics_on_dimension_mismatch() {
        let mut normalizer = IncrementalNormalizer::with_n_features(2);
        normalizer.update(&[1.0, 2.0, 3.0]);
    }
}