use crate::preprocessing::incremental::{IncrementalScaler, WelfordState};
use crate::{Result, TreeBoostError};
/// Batch feature scaling over row-major `f32` data, where
/// `data.len() == rows * num_features`.
pub trait Scaler {
    /// Learns per-feature statistics from `data`, replacing any previously
    /// fitted state.
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()>;
    /// Scales `data` in place using the fitted statistics.
    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()>;
    /// Convenience: `fit` followed by `transform` on the same buffer.
    fn fit_transform(&mut self, data: &mut [f32], num_features: usize) -> Result<()> {
        self.fit(data, num_features)?;
        self.transform(data, num_features)?;
        Ok(())
    }
    /// Returns `true` once a fit (batch or incremental) has succeeded.
    fn is_fitted(&self) -> bool;
}
/// Z-score scaler: `x' = (x - mean) / std` per feature.
///
/// Supports batch fitting, incremental fitting backed by Welford accumulators,
/// and an optional exponential-moving-average "forget factor" for streams
/// whose distribution drifts.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct StandardScaler {
    /// Per-feature means applied by `transform`.
    pub means: Vec<f32>,
    /// Per-feature standard deviations; near-zero stds are stored as 1.0.
    pub stds: Vec<f32>,
    fitted: bool,
    /// Running Welford accumulators backing incremental fitting.
    #[serde(default)]
    welford_states: Vec<WelfordState>,
    /// When set, `partial_fit` blends batch statistics with this EMA weight.
    #[serde(default)]
    forget_factor: Option<f32>,
}
impl StandardScaler {
    /// Creates an unfitted scaler that accumulates cumulative statistics via
    /// Welford's algorithm when used incrementally.
    pub fn new() -> Self {
        Self {
            means: Vec::new(),
            stds: Vec::new(),
            fitted: false,
            welford_states: Vec::new(),
            forget_factor: None,
        }
    }

    /// Creates an unfitted scaler whose `partial_fit` blends batch statistics
    /// with an exponential moving average of weight `forget_factor` (clamped
    /// to `[0, 1]`); larger values track recent batches more aggressively.
    pub fn with_forget_factor(forget_factor: f32) -> Self {
        Self {
            means: Vec::new(),
            stds: Vec::new(),
            fitted: false,
            welford_states: Vec::new(),
            forget_factor: Some(forget_factor.clamp(0.0, 1.0)),
        }
    }

    /// Enables (`Some`, clamped to `[0, 1]`) or disables (`None`) EMA updates
    /// for subsequent `partial_fit` calls.
    pub fn set_forget_factor(&mut self, factor: Option<f32>) {
        self.forget_factor = factor.map(|f| f.clamp(0.0, 1.0));
    }

    /// Current EMA weight, if EMA mode is enabled.
    pub fn forget_factor(&self) -> Option<f32> {
        self.forget_factor
    }

    /// Fitted per-feature means.
    pub fn means(&self) -> &[f32] {
        &self.means
    }

    /// Fitted per-feature standard deviations (1.0 for zero-variance features).
    pub fn stds(&self) -> &[f32] {
        &self.stds
    }

    /// Refreshes `means`/`stds` from the Welford accumulators.
    fn sync_from_welford(&mut self) {
        let num_features = self.welford_states.len();
        self.means.resize(num_features, 0.0);
        self.stds.resize(num_features, 1.0);
        for (i, state) in self.welford_states.iter().enumerate() {
            self.means[i] = state.mean as f32;
            let std = state.std() as f32;
            // Floor near-zero stds at 1.0 so transform never divides by ~0.
            self.stds[i] = if std < 1e-8 { 1.0 } else { std };
        }
    }

    /// Per-feature `(mean, population variance)` of a row-major batch,
    /// accumulated in f64 for numerical stability.
    fn compute_batch_stats(data: &[f32], num_features: usize) -> Vec<(f64, f64)> {
        let num_rows = data.len() / num_features;
        let mut stats = vec![(0.0f64, 0.0f64); num_features];
        if num_rows == 0 {
            return stats;
        }
        for feat in 0..num_features {
            let mut sum = 0.0f64;
            for row in 0..num_rows {
                sum += data[row * num_features + feat] as f64;
            }
            stats[feat].0 = sum / num_rows as f64;
        }
        for feat in 0..num_features {
            let mean = stats[feat].0;
            let mut variance = 0.0f64;
            for row in 0..num_rows {
                let x = data[row * num_features + feat] as f64;
                variance += (x - mean).powi(2);
            }
            stats[feat].1 = variance / num_rows as f64;
        }
        stats
    }

    /// EMA-based incremental fit used when `forget_factor` is `Some(alpha)`.
    ///
    /// The first effective batch seeds means/stds (and the Welford sample
    /// counters) directly from batch statistics; later batches blend with
    /// `new = (1 - alpha) * old + alpha * batch`.
    fn partial_fit_ema(&mut self, data: &[f32], num_features: usize, alpha: f32) -> Result<()> {
        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Ok(());
        }
        let batch_stats = Self::compute_batch_stats(data, num_features);
        if self.means.is_empty() || !self.fitted {
            self.means = vec![0.0; num_features];
            self.stds = vec![1.0; num_features];
            self.welford_states = vec![WelfordState::new(); num_features];
            for feat in 0..num_features {
                let (mean, var) = batch_stats[feat];
                self.means[feat] = mean as f32;
                let std = var.sqrt() as f32;
                self.stds[feat] = if std < 1e-8 { 1.0 } else { std };
                // Seed the accumulators so n_samples() reflects this batch.
                self.welford_states[feat].n = num_rows as u64;
                self.welford_states[feat].mean = mean;
                self.welford_states[feat].m2 = var * num_rows as f64;
            }
            self.fitted = true;
            return Ok(());
        }
        if self.means.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.means.len(),
                num_features
            )));
        }
        // Bug fix: `fitted` may have been set by a plain `fit()`, which leaves
        // `welford_states` empty; indexing it below used to panic with
        // index-out-of-bounds. Ensure the per-feature counters exist first.
        if self.welford_states.len() != num_features {
            self.welford_states.resize(num_features, WelfordState::new());
        }
        let alpha_64 = alpha as f64;
        let decay = 1.0 - alpha_64;
        for feat in 0..num_features {
            let (batch_mean, batch_var) = batch_stats[feat];
            let old_mean = self.means[feat] as f64;
            let new_mean = decay * old_mean + alpha_64 * batch_mean;
            self.means[feat] = new_mean as f32;
            // Blend variances (not stds) so the EMA is linear in variance.
            let old_var = (self.stds[feat] as f64).powi(2);
            let new_var = decay * old_var + alpha_64 * batch_var;
            let new_std = new_var.sqrt() as f32;
            self.stds[feat] = if new_std < 1e-8 { 1.0 } else { new_std };
            // In EMA mode only the sample count is tracked here; the Welford
            // mean/m2 intentionally stay at their seeded values.
            self.welford_states[feat].n += num_rows as u64;
        }
        Ok(())
    }
}
impl Default for StandardScaler {
fn default() -> Self {
Self::new()
}
}
impl Scaler for StandardScaler {
fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
if num_features == 0 {
return Err(TreeBoostError::Data("num_features must be > 0".into()));
}
if !data.len().is_multiple_of(num_features) {
return Err(TreeBoostError::Data(format!(
"Data length {} not divisible by num_features {}",
data.len(),
num_features
)));
}
let num_rows = data.len() / num_features;
if num_rows == 0 {
return Err(TreeBoostError::Data("No rows to fit".into()));
}
self.means = vec![0.0; num_features];
self.stds = vec![0.0; num_features];
for feat in 0..num_features {
let mut sum = 0.0;
for row in 0..num_rows {
sum += data[row * num_features + feat];
}
self.means[feat] = sum / num_rows as f32;
}
for feat in 0..num_features {
let mean = self.means[feat];
let mut variance = 0.0;
for row in 0..num_rows {
let x = data[row * num_features + feat];
variance += (x - mean).powi(2);
}
let std = (variance / num_rows as f32).sqrt();
self.stds[feat] = if std < 1e-8 { 1.0 } else { std };
}
self.fitted = true;
Ok(())
}
fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
if !self.fitted {
return Err(TreeBoostError::Data(
"StandardScaler not fitted. Call fit() first.".into(),
));
}
if num_features != self.means.len() {
return Err(TreeBoostError::Data(format!(
"num_features mismatch: fit with {}, transform with {}",
self.means.len(),
num_features
)));
}
if !data.len().is_multiple_of(num_features) {
return Err(TreeBoostError::Data(format!(
"Data length {} not divisible by num_features {}",
data.len(),
num_features
)));
}
let num_rows = data.len() / num_features;
for feat in 0..num_features {
let mean = self.means[feat];
let std = self.stds[feat];
for row in 0..num_rows {
let idx = row * num_features + feat;
data[idx] = (data[idx] - mean) / std;
}
}
Ok(())
}
fn is_fitted(&self) -> bool {
self.fitted
}
}
impl IncrementalScaler for StandardScaler {
    /// Folds one row-major batch into the running statistics.
    ///
    /// In EMA mode (`forget_factor` set) this delegates to `partial_fit_ema`;
    /// otherwise every finite value updates its feature's Welford state.
    /// Non-finite values (NaN/±inf) are skipped, so per-feature sample counts
    /// can diverge when the data contains them.
    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            // An empty batch is a no-op rather than an error.
            return Ok(());
        }
        if let Some(alpha) = self.forget_factor {
            return self.partial_fit_ema(data, num_features, alpha);
        }
        if self.welford_states.is_empty() {
            // First batch fixes the feature count for all later calls.
            self.welford_states = vec![WelfordState::new(); num_features];
        } else if self.welford_states.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.welford_states.len(),
                num_features
            )));
        }
        for row in 0..num_rows {
            for feat in 0..num_features {
                let x = data[row * num_features + feat] as f64;
                if x.is_finite() {
                    self.welford_states[feat].update(x);
                }
            }
        }
        self.sync_from_welford();
        self.fitted = true;
        Ok(())
    }

    /// Sample count of the first feature's accumulator (0 before any
    /// `partial_fit`). NOTE(review): with non-finite inputs other features may
    /// have different counts — confirm feature 0 is representative enough.
    fn n_samples(&self) -> u64 {
        self.welford_states.first().map(|s| s.n).unwrap_or(0)
    }

    /// Merges another scaler's Welford accumulators into this one.
    ///
    /// NOTE(review): in EMA mode `partial_fit_ema` does not keep the Welford
    /// mean/m2 current, so the `sync_from_welford` here would overwrite the
    /// EMA estimates with stale values — confirm EMA scalers are never merged.
    fn merge(&mut self, other: &Self) -> Result<()> {
        if self.welford_states.is_empty() {
            // Nothing accumulated locally: adopt the other scaler's state.
            self.welford_states = other.welford_states.clone();
            self.sync_from_welford();
            self.fitted = other.fitted;
            return Ok(());
        }
        if other.welford_states.is_empty() {
            return Ok(());
        }
        if self.welford_states.len() != other.welford_states.len() {
            return Err(TreeBoostError::Data(format!(
                "Cannot merge scalers with different num_features: {} vs {}",
                self.welford_states.len(),
                other.welford_states.len()
            )));
        }
        for (self_state, other_state) in self.welford_states.iter_mut().zip(&other.welford_states) {
            self_state.merge(other_state);
        }
        self.sync_from_welford();
        Ok(())
    }
}
/// Min-max scaler: maps each feature's observed `[min, max]` onto
/// `feature_range` (default `[0, 1]`), clamping transformed values.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct MinMaxScaler {
    /// Per-feature minimums observed during fitting.
    pub mins: Vec<f32>,
    /// Per-feature maximums; near-constant features are widened to
    /// `min + 1.0` during fitting so the scale stays finite.
    pub maxs: Vec<f32>,
    /// Output range `(low, high)` produced by `transform`.
    pub feature_range: (f32, f32),
    fitted: bool,
    /// Rows consumed by `partial_fit`; reported by `n_samples()` and summed
    /// by `merge()`.
    #[serde(default)]
    n_samples: u64,
}
impl MinMaxScaler {
pub fn new() -> Self {
Self {
mins: Vec::new(),
maxs: Vec::new(),
feature_range: (0.0, 1.0),
fitted: false,
n_samples: 0,
}
}
pub fn with_range(mut self, min: f32, max: f32) -> Self {
self.feature_range = (min, max);
self
}
}
impl Default for MinMaxScaler {
fn default() -> Self {
Self::new()
}
}
impl Scaler for MinMaxScaler {
    /// Records per-feature min/max over all rows, replacing any prior state.
    ///
    /// Non-finite values (NaN/±inf) are ignored — consistent with
    /// `partial_fit` — so a single bad value cannot collapse the transform
    /// range. Features with no finite observations fall back to a `[0, 1]`
    /// range; near-constant features get `max = min + 1.0`.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` when `num_features == 0`, when the data
    /// length is not a multiple of `num_features`, or when there are no rows.
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Err(TreeBoostError::Data("No rows to fit".into()));
        }
        self.mins = vec![f32::INFINITY; num_features];
        self.maxs = vec![f32::NEG_INFINITY; num_features];
        for feat in 0..num_features {
            for row in 0..num_rows {
                let val = data[row * num_features + feat];
                // Skip NaN/±inf, matching the incremental path.
                if val.is_finite() {
                    self.mins[feat] = self.mins[feat].min(val);
                    self.maxs[feat] = self.maxs[feat].max(val);
                }
            }
            if !self.mins[feat].is_finite() || !self.maxs[feat].is_finite() {
                // No finite observations for this feature: fall back to a
                // unit range so transform stays well-defined.
                self.mins[feat] = 0.0;
                self.maxs[feat] = 1.0;
            } else if (self.maxs[feat] - self.mins[feat]).abs() < 1e-8 {
                self.maxs[feat] = self.mins[feat] + 1.0;
            }
        }
        // A batch fit replaces all state, so the sample counter restarts too.
        self.n_samples = num_rows as u64;
        self.fitted = true;
        Ok(())
    }

    /// Scales `data` in place into `feature_range`, clamping to the range so
    /// values outside the fitted bounds cannot escape it.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` when the scaler is unfitted, the feature
    /// count does not match the fit, or the data length is not a multiple of
    /// `num_features`.
    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Data(
                "MinMaxScaler not fitted. Call fit() first.".into(),
            ));
        }
        if num_features != self.mins.len() {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: fit with {}, transform with {}",
                self.mins.len(),
                num_features
            )));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let num_rows = data.len() / num_features;
        let (a, b) = self.feature_range;
        for feat in 0..num_features {
            let min = self.mins[feat];
            let max = self.maxs[feat];
            let scale = b - a;
            for row in 0..num_rows {
                let idx = row * num_features + feat;
                data[idx] = (data[idx] - min) / (max - min) * scale + a;
                data[idx] = data[idx].clamp(a, b);
            }
        }
        Ok(())
    }

    /// Whether `fit`/`partial_fit` has succeeded at least once.
    fn is_fitted(&self) -> bool {
        self.fitted
    }
}
impl IncrementalScaler for MinMaxScaler {
    /// Widens the running per-feature `[min, max]` bounds with one more batch.
    ///
    /// Non-finite values are ignored. After each batch, near-constant features
    /// get their max bumped to `min + 1.0` so `transform` never divides by ~0.
    fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let rows = data.len() / num_features;
        if rows == 0 {
            // An empty batch is a no-op rather than an error.
            return Ok(());
        }
        if self.mins.is_empty() {
            // First batch fixes the feature count for all later calls.
            self.mins = vec![f32::INFINITY; num_features];
            self.maxs = vec![f32::NEG_INFINITY; num_features];
        } else if self.mins.len() != num_features {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: initialized with {}, partial_fit with {}",
                self.mins.len(),
                num_features
            )));
        }
        // Walk the batch row by row; each chunk is exactly one row.
        for row_vals in data.chunks_exact(num_features) {
            for (feat, &val) in row_vals.iter().enumerate() {
                if val.is_finite() {
                    self.mins[feat] = self.mins[feat].min(val);
                    self.maxs[feat] = self.maxs[feat].max(val);
                }
            }
        }
        // Keep the transform scale finite for (near-)constant features.
        for (lo, hi) in self.mins.iter().zip(self.maxs.iter_mut()) {
            if (*hi - *lo).abs() < 1e-8 {
                *hi = *lo + 1.0;
            }
        }
        self.n_samples += rows as u64;
        self.fitted = true;
        Ok(())
    }

    /// Total rows consumed by `partial_fit` (including counts from `merge`).
    fn n_samples(&self) -> u64 {
        self.n_samples
    }

    /// Folds `other`'s bounds into `self` (elementwise min/max) and sums the
    /// sample counts.
    fn merge(&mut self, other: &Self) -> Result<()> {
        if self.mins.is_empty() {
            // Nothing accumulated locally: adopt the other scaler wholesale.
            self.mins = other.mins.clone();
            self.maxs = other.maxs.clone();
            self.n_samples = other.n_samples;
            self.fitted = other.fitted;
            return Ok(());
        }
        if other.mins.is_empty() {
            return Ok(());
        }
        if self.mins.len() != other.mins.len() {
            return Err(TreeBoostError::Data(format!(
                "Cannot merge scalers with different num_features: {} vs {}",
                self.mins.len(),
                other.mins.len()
            )));
        }
        for (lo, &other_lo) in self.mins.iter_mut().zip(&other.mins) {
            *lo = lo.min(other_lo);
        }
        for (hi, &other_hi) in self.maxs.iter_mut().zip(&other.maxs) {
            *hi = hi.max(other_hi);
        }
        self.n_samples += other.n_samples;
        Ok(())
    }
}
/// Robust scaler: `x' = (x - median) / IQR` per feature, with quantiles
/// estimated via a t-digest sketch so outliers have limited influence.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct RobustScaler {
    /// Per-feature median (50th percentile) estimates.
    pub medians: Vec<f32>,
    /// Per-feature interquartile ranges (Q3 - Q1); near-zero IQRs are
    /// stored as 1.0.
    pub iqrs: Vec<f32>,
    fitted: bool,
}
impl RobustScaler {
pub fn new() -> Self {
Self {
medians: Vec::new(),
iqrs: Vec::new(),
fitted: false,
}
}
}
impl Default for RobustScaler {
fn default() -> Self {
Self::new()
}
}
impl Scaler for RobustScaler {
    /// Estimates per-feature median and IQR from all rows with a t-digest.
    ///
    /// Non-finite values are excluded from the estimates. Each feature's
    /// column is collected once and merged into the digest in a single call —
    /// the previous per-value `merge_unsorted(vec![v])` performed one Vec
    /// allocation and one digest merge per element.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` when `num_features == 0`, when the data
    /// length is not a multiple of `num_features`, or when there are no rows.
    fn fit(&mut self, data: &[f32], num_features: usize) -> Result<()> {
        if num_features == 0 {
            return Err(TreeBoostError::Data("num_features must be > 0".into()));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let num_rows = data.len() / num_features;
        if num_rows == 0 {
            return Err(TreeBoostError::Data("No rows to fit".into()));
        }
        self.medians = vec![0.0; num_features];
        self.iqrs = vec![1.0; num_features];
        use tdigest::TDigest;
        for feat in 0..num_features {
            // Gather the feature's finite values in one pass.
            let column: Vec<f64> = (0..num_rows)
                .map(|row| data[row * num_features + feat] as f64)
                .filter(|v| v.is_finite())
                .collect();
            if column.is_empty() {
                // No finite observations: keep the identity defaults (0, 1).
                continue;
            }
            // Single batched merge instead of one merge per value.
            let digest = TDigest::new_with_size(100).merge_unsorted(column);
            let q1 = digest.estimate_quantile(0.25) as f32;
            let median = digest.estimate_quantile(0.50) as f32;
            let q3 = digest.estimate_quantile(0.75) as f32;
            self.medians[feat] = median;
            let iqr = q3 - q1;
            // Guard against ~zero IQR so transform never divides by ~0.
            self.iqrs[feat] = if iqr < 1e-8 { 1.0 } else { iqr };
        }
        self.fitted = true;
        Ok(())
    }

    /// Scales `data` in place: `x' = (x - median) / IQR` per feature.
    ///
    /// # Errors
    /// Returns `TreeBoostError::Data` when the scaler is unfitted, the feature
    /// count does not match the fit, or the data length is not a multiple of
    /// `num_features`.
    fn transform(&self, data: &mut [f32], num_features: usize) -> Result<()> {
        if !self.fitted {
            return Err(TreeBoostError::Data(
                "RobustScaler not fitted. Call fit() first.".into(),
            ));
        }
        if num_features != self.medians.len() {
            return Err(TreeBoostError::Data(format!(
                "num_features mismatch: fit with {}, transform with {}",
                self.medians.len(),
                num_features
            )));
        }
        if !data.len().is_multiple_of(num_features) {
            return Err(TreeBoostError::Data(format!(
                "Data length {} not divisible by num_features {}",
                data.len(),
                num_features
            )));
        }
        let num_rows = data.len() / num_features;
        for feat in 0..num_features {
            let median = self.medians[feat];
            let iqr = self.iqrs[feat];
            for row in 0..num_rows {
                let idx = row * num_features + feat;
                data[idx] = (data[idx] - median) / iqr;
            }
        }
        Ok(())
    }

    /// Whether `fit` has succeeded at least once.
    fn is_fitted(&self) -> bool {
        self.fitted
    }
}
#[cfg(test)]
mod tests {
    //! Unit tests covering batch fitting, incremental fitting, merging, and
    //! the EMA (forget-factor) behavior of the three scalers.
    use super::*;

    /// Batch fit on a 2x3 row-major matrix: means are exact column averages.
    #[test]
    fn test_standard_scaler_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
        let num_features = 3;
        let mut scaler = StandardScaler::new();
        assert!(!scaler.is_fitted());
        scaler.fit(&data, num_features).unwrap();
        assert!(scaler.is_fitted());
        assert_eq!(scaler.means(), &[2.5, 3.5, 4.5]);
        scaler.transform(&mut data, num_features).unwrap();
    }

    /// A constant column must get std = 1.0 so transform only shifts it.
    #[test]
    fn test_standard_scaler_zero_variance() {
        let mut data = vec![5.0, 1.0, 2.0, 5.0, 3.0, 4.0];
        let num_features = 3;
        let mut scaler = StandardScaler::new();
        scaler.fit(&data, num_features).unwrap();
        assert_eq!(scaler.stds[0], 1.0);
        assert_eq!(scaler.means[0], 5.0);
        scaler.transform(&mut data, num_features).unwrap();
    }

    /// Default [0, 1] range: min maps to 0, max to 1, midpoint to 0.5.
    #[test]
    fn test_minmax_scaler_basic() {
        let mut data = vec![1.0, 10.0, 2.0, 20.0, 3.0, 30.0];
        let num_features = 2;
        let mut scaler = MinMaxScaler::new();
        scaler.fit(&data, num_features).unwrap();
        assert_eq!(scaler.mins, vec![1.0, 10.0]);
        assert_eq!(scaler.maxs, vec![3.0, 30.0]);
        scaler.transform(&mut data, num_features).unwrap();
        assert!((data[0] - 0.0).abs() < 1e-6);
        assert!((data[2] - 0.5).abs() < 1e-6);
        assert!((data[4] - 1.0).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[3] - 0.5).abs() < 1e-6);
        assert!((data[5] - 1.0).abs() < 1e-6);
    }

    /// `with_range` redirects the output to a custom [-1, 1] interval.
    #[test]
    fn test_minmax_scaler_custom_range() {
        let mut data = vec![1.0, 2.0, 3.0];
        let num_features = 1;
        let mut scaler = MinMaxScaler::new().with_range(-1.0, 1.0);
        scaler.fit(&data, num_features).unwrap();
        scaler.transform(&mut data, num_features).unwrap();
        assert!((data[0] - (-1.0)).abs() < 1e-6);
        assert!((data[1] - 0.0).abs() < 1e-6);
        assert!((data[2] - 1.0).abs() < 1e-6);
    }

    /// Median estimate for feature 0 (values 1.0 and 3.0) should be ~2.0.
    #[test]
    fn test_robust_scaler_basic() {
        let mut data = vec![1.0, 2.0, 3.0, 100.0];
        let num_features = 2;
        let mut scaler = RobustScaler::new();
        scaler.fit(&data, num_features).unwrap();
        assert!((scaler.medians[0] - 2.0).abs() < 1e-6);
        scaler.transform(&mut data, num_features).unwrap();
    }

    /// transform before fit must fail with a "not fitted" error.
    #[test]
    fn test_scaler_not_fitted_error() {
        let mut data = vec![1.0, 2.0, 3.0];
        let scaler = StandardScaler::new();
        let result = scaler.transform(&mut data, 1);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("not fitted"));
    }

    /// Feature count differing between fit and transform must be rejected.
    #[test]
    fn test_scaler_feature_mismatch_error() {
        let data = vec![1.0, 2.0, 3.0, 4.0];
        let mut scaler = StandardScaler::new();
        scaler.fit(&data, 2).unwrap();
        let mut test_data = vec![5.0, 6.0, 7.0];
        let result = scaler.transform(&mut test_data, 3);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("mismatch"));
    }

    /// Chunked partial_fit must match a single batch fit within tolerance.
    #[test]
    fn test_standard_scaler_incremental_equivalence() {
        let all_data: Vec<f32> = (0..1000).map(|i| i as f32).collect();
        let num_features = 1;
        let mut scaler_a = StandardScaler::new();
        scaler_a.fit(&all_data, num_features).unwrap();
        let mut scaler_b = StandardScaler::new();
        for chunk in all_data.chunks(100) {
            scaler_b.partial_fit(chunk, num_features).unwrap();
        }
        assert!(
            (scaler_a.means[0] - scaler_b.means[0]).abs() < 1e-3,
            "Means differ: {} vs {}",
            scaler_a.means[0],
            scaler_b.means[0]
        );
        assert!(
            (scaler_a.stds[0] - scaler_b.stds[0]).abs() < 1e-3,
            "Stds differ: {} vs {}",
            scaler_a.stds[0],
            scaler_b.stds[0]
        );
        assert_eq!(scaler_b.n_samples(), 1000);
    }

    /// Welford accumulation should stay accurate with a large additive offset
    /// (a naive sum-of-squares approach would lose catastrophically here).
    #[test]
    fn test_standard_scaler_welford_stability() {
        let offset = 1e8_f32;
        let data: Vec<f32> = (0..100).map(|i| offset + i as f32).collect();
        let num_features = 1;
        let mut scaler = StandardScaler::new();
        scaler.partial_fit(&data, num_features).unwrap();
        let expected_mean = offset + 49.5;
        assert!(
            (scaler.means[0] - expected_mean).abs() < 1.0,
            "Mean with large offset: got {}, expected {}",
            scaler.means[0],
            expected_mean
        );
    }

    /// Merging two partially-fitted scalers sums counts and pools the means.
    #[test]
    fn test_standard_scaler_merge() {
        let num_features = 2;
        let mut scaler_a = StandardScaler::new();
        scaler_a
            .partial_fit(&[1.0, 10.0, 2.0, 20.0], num_features)
            .unwrap();
        let mut scaler_b = StandardScaler::new();
        scaler_b
            .partial_fit(&[3.0, 30.0, 4.0, 40.0], num_features)
            .unwrap();
        scaler_a.merge(&scaler_b).unwrap();
        assert_eq!(scaler_a.n_samples(), 4);
        assert!((scaler_a.means[0] - 2.5).abs() < 1e-5);
        assert!((scaler_a.means[1] - 25.0).abs() < 1e-4);
    }

    /// partial_fit only widens the bounds; earlier minima are kept.
    #[test]
    fn test_minmax_scaler_incremental() {
        let num_features = 2;
        let mut scaler = MinMaxScaler::new();
        scaler
            .partial_fit(&[0.0, 0.0, 50.0, 100.0], num_features)
            .unwrap();
        assert_eq!(scaler.mins, vec![0.0, 0.0]);
        assert_eq!(scaler.maxs, vec![50.0, 100.0]);
        scaler
            .partial_fit(&[25.0, 50.0, 100.0, 200.0], num_features)
            .unwrap();
        assert_eq!(scaler.mins, vec![0.0, 0.0]);
        assert_eq!(scaler.maxs, vec![100.0, 200.0]);
        assert_eq!(scaler.n_samples(), 4);
    }

    /// merge takes elementwise min/max and adds the sample counts.
    #[test]
    fn test_minmax_scaler_merge() {
        let num_features = 1;
        let mut scaler_a = MinMaxScaler::new();
        scaler_a.partial_fit(&[10.0, 20.0], num_features).unwrap();
        let mut scaler_b = MinMaxScaler::new();
        scaler_b.partial_fit(&[5.0, 30.0], num_features).unwrap();
        scaler_a.merge(&scaler_b).unwrap();
        assert_eq!(scaler_a.mins, vec![5.0]);
        assert_eq!(scaler_a.maxs, vec![30.0]);
        assert_eq!(scaler_a.n_samples(), 4);
    }

    /// Forget factor can be set at construction or toggled afterwards.
    #[test]
    fn test_standard_scaler_forget_factor_creation() {
        let scaler = StandardScaler::with_forget_factor(0.1);
        assert_eq!(scaler.forget_factor(), Some(0.1));
        let mut scaler2 = StandardScaler::new();
        assert_eq!(scaler2.forget_factor(), None);
        scaler2.set_forget_factor(Some(0.5));
        assert_eq!(scaler2.forget_factor(), Some(0.5));
        scaler2.set_forget_factor(None);
        assert_eq!(scaler2.forget_factor(), None);
    }

    /// Out-of-range forget factors are clamped into [0, 1].
    #[test]
    fn test_standard_scaler_forget_factor_clamping() {
        let scaler = StandardScaler::with_forget_factor(-0.5);
        assert_eq!(scaler.forget_factor(), Some(0.0));
        let scaler2 = StandardScaler::with_forget_factor(1.5);
        assert_eq!(scaler2.forget_factor(), Some(1.0));
    }

    /// The first EMA batch seeds the mean from plain batch statistics.
    #[test]
    fn test_standard_scaler_ema_single_batch() {
        let num_features = 1;
        let data = vec![10.0, 20.0, 30.0, 40.0];
        let mut scaler = StandardScaler::with_forget_factor(0.1);
        scaler.partial_fit(&data, num_features).unwrap();
        assert!(scaler.is_fitted());
        assert!((scaler.means()[0] - 25.0).abs() < 0.01);
    }

    /// Second batch blends: 0.7 * 10 + 0.3 * 100 = 37.
    #[test]
    fn test_standard_scaler_ema_decay() {
        let num_features = 1;
        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![98.0, 100.0, 102.0];
        let mut scaler = StandardScaler::with_forget_factor(0.3);
        scaler.partial_fit(&batch1, num_features).unwrap();
        let mean_after_batch1 = scaler.means()[0];
        assert!((mean_after_batch1 - 10.0).abs() < 0.01);
        scaler.partial_fit(&batch2, num_features).unwrap();
        let mean_after_batch2 = scaler.means()[0];
        assert!(
            (mean_after_batch2 - 37.0).abs() < 0.5,
            "Expected ~37, got {}",
            mean_after_batch2
        );
    }

    /// With alpha = 0.5 and equal-size batches, EMA and cumulative averaging
    /// happen to agree (both land on the midpoint of the two batch means).
    #[test]
    fn test_standard_scaler_ema_vs_cumulative() {
        let num_features = 1;
        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![98.0, 100.0, 102.0];
        let mut cumulative = StandardScaler::new();
        cumulative.partial_fit(&batch1, num_features).unwrap();
        cumulative.partial_fit(&batch2, num_features).unwrap();
        let mut ema = StandardScaler::with_forget_factor(0.5);
        ema.partial_fit(&batch1, num_features).unwrap();
        ema.partial_fit(&batch2, num_features).unwrap();
        let cumulative_mean = cumulative.means()[0];
        let ema_mean = ema.means()[0];
        assert!((cumulative_mean - 55.0).abs() < 1.0);
        assert!((ema_mean - 55.0).abs() < 1.0);
    }

    /// Under steady drift, each EMA step halves the gap to the latest batch
    /// mean (alpha = 0.5): 10 -> 20 -> 35 -> 52.5 -> 71.25.
    #[test]
    fn test_standard_scaler_ema_adapts_to_drift() {
        let num_features = 1;
        let batch1 = vec![8.0, 10.0, 12.0];
        let batch2 = vec![28.0, 30.0, 32.0];
        let batch3 = vec![48.0, 50.0, 52.0];
        let batch4 = vec![68.0, 70.0, 72.0];
        let batch5 = vec![88.0, 90.0, 92.0];
        let mut scaler = StandardScaler::with_forget_factor(0.5);
        scaler.partial_fit(&batch1, num_features).unwrap();
        assert!((scaler.means()[0] - 10.0).abs() < 1.0);
        scaler.partial_fit(&batch2, num_features).unwrap();
        assert!(
            (scaler.means()[0] - 20.0).abs() < 1.0,
            "Expected ~20, got {}",
            scaler.means()[0]
        );
        scaler.partial_fit(&batch3, num_features).unwrap();
        assert!(
            (scaler.means()[0] - 35.0).abs() < 1.0,
            "Expected ~35, got {}",
            scaler.means()[0]
        );
        scaler.partial_fit(&batch4, num_features).unwrap();
        assert!(
            (scaler.means()[0] - 52.5).abs() < 1.0,
            "Expected ~52.5, got {}",
            scaler.means()[0]
        );
        scaler.partial_fit(&batch5, num_features).unwrap();
        assert!(
            (scaler.means()[0] - 71.25).abs() < 1.5,
            "Expected ~71.25, got {}",
            scaler.means()[0]
        );
    }

    /// The EMA variance should grow when a high-variance batch follows a
    /// low-variance one.
    #[test]
    fn test_standard_scaler_ema_variance_decay() {
        let num_features = 1;
        let batch1 = vec![9.9, 10.0, 10.1];
        let batch2 = vec![0.0, 10.0, 20.0];
        let mut scaler = StandardScaler::with_forget_factor(0.3);
        scaler.partial_fit(&batch1, num_features).unwrap();
        let std_after_batch1 = scaler.stds()[0];
        assert!(
            std_after_batch1 < 1.0,
            "Std should be small after low-variance batch"
        );
        scaler.partial_fit(&batch2, num_features).unwrap();
        let std_after_batch2 = scaler.stds()[0];
        assert!(
            std_after_batch2 > std_after_batch1,
            "Std should increase after high-variance batch"
        );
    }
}