use crate::{Result, TreeBoostError};
/// A fitted preprocessor whose state can be updated batch-by-batch and
/// combined with another instance of the same type.
///
/// NOTE(review): the name suggests numeric scalers; confirm against implementors.
pub trait IncrementalScaler {
/// Updates the fitted state with one more batch of `data`.
///
/// NOTE(review): `data` presumably holds samples flattened with
/// `num_features` values per sample — TODO confirm the layout (row- vs
/// column-major) against implementors.
///
/// # Errors
/// Returns the crate's error type when the implementor rejects the batch;
/// the exact conditions are implementor-defined.
fn partial_fit(&mut self, data: &[f32], num_features: usize) -> Result<()>;
/// Returns how many samples have been absorbed so far (presumably across all
/// `partial_fit` and `merge` calls — confirm in implementors).
fn n_samples(&self) -> u64;
/// Folds the fitted state of `other` into `self`.
///
/// # Errors
/// Implementor-defined (presumably e.g. incompatible feature counts — TODO confirm).
fn merge(&mut self, other: &Self) -> Result<()>;
}
/// A fitted categorical preprocessor whose state can be updated incrementally
/// from batches of string category values.
pub trait IncrementalEncoder {
/// Updates the fitted state with one more batch of category values.
///
/// # Errors
/// Returns the crate's error type when the implementor rejects the batch;
/// the exact conditions are implementor-defined.
fn partial_fit(&mut self, categories: &[&str]) -> Result<()>;
/// Returns how many samples have been absorbed so far (presumably one per
/// category value passed to `partial_fit` — confirm in implementors).
fn n_samples(&self) -> u64;
}
/// Builds the configuration error returned when `preprocessor_name` is asked
/// to fit incrementally but only supports a single, schema-freezing fit.
///
/// The message names the offending preprocessor twice: once as the subject
/// and once in the suggestion to switch to `FrequencyEncoder` or
/// `TargetEncoder`.
pub fn not_supported_error(preprocessor_name: &str) -> TreeBoostError {
    // The trailing `\` continues the literal on the next line and swallows the
    // newline plus leading indentation, so the message is one single line.
    let message = format!(
        "{name} does not support incremental fitting. Schema is frozen after first fit. \
         For incremental learning, use FrequencyEncoder or TargetEncoder instead of {name}.",
        name = preprocessor_name
    );
    TreeBoostError::Config(message)
}
/// Streaming accumulator of count, mean and sum of squared deviations for a
/// sequence of `f64` values, updated one value at a time with Welford's
/// online algorithm (`update`) and combinable with another accumulator
/// (`merge`).
///
/// Derives serde `Serialize`/`Deserialize`, so the accumulated state can be
/// stored and restored. `Default` yields the empty state (all fields zero).
#[derive(Debug, Clone, Default, serde::Serialize, serde::Deserialize)]
pub struct WelfordState {
/// Number of values accumulated so far.
pub n: u64,
/// Running mean of the accumulated values (0.0 while `n == 0`).
pub mean: f64,
/// Sum of squared deviations from the running mean; `variance()` divides
/// this by `n` (population variance).
pub m2: f64,
}
impl WelfordState {
    /// Returns an accumulator that has not seen any values yet.
    pub fn new() -> Self {
        Self {
            n: 0,
            mean: 0.0,
            m2: 0.0,
        }
    }

    /// Absorbs a single observation `x` using Welford's recurrence.
    ///
    /// The deviation is measured against the mean both before and after the
    /// mean is moved. Their product is the increment to `m2`, which avoids
    /// the catastrophic cancellation of the naive sum-of-squares formula.
    #[inline]
    pub fn update(&mut self, x: f64) {
        let count = self.n + 1;
        let dev_before = x - self.mean;
        let new_mean = self.mean + dev_before / count as f64;
        let dev_after = x - new_mean;
        self.n = count;
        self.mean = new_mean;
        self.m2 += dev_before * dev_after;
    }

    /// Population variance (`m2 / n`) of everything absorbed so far.
    ///
    /// Returns `0.0` for an empty accumulator instead of dividing by zero.
    #[inline]
    pub fn variance(&self) -> f64 {
        match self.n {
            0 => 0.0,
            count => self.m2 / count as f64,
        }
    }

    /// Population standard deviation: the square root of [`Self::variance`].
    #[inline]
    pub fn std(&self) -> f64 {
        f64::sqrt(self.variance())
    }

    /// Combines the statistics of `other` into `self`, as if every value
    /// `other` saw had been fed through `update` on `self`.
    ///
    /// Uses the pairwise (Chan et al.) combination of means and `m2` terms.
    pub fn merge(&mut self, other: &WelfordState) {
        match (self.n, other.n) {
            // Nothing to absorb (this also covers the both-empty case).
            (_, 0) => {}
            // `self` is empty, so the combined state is exactly `other`.
            (0, _) => *self = other.clone(),
            (n_self, n_other) => {
                let n_total = n_self + n_other;
                let w_self = n_self as f64;
                let w_other = n_other as f64;
                let w_total = n_total as f64;
                let shift = other.mean - self.mean;
                let mean = self.mean + shift * (w_other / w_total);
                // Between-group term: delta^2 * n_a * n_b / (n_a + n_b).
                let m2 = self.m2 + other.m2 + shift * shift * (w_self * w_other / w_total);
                self.n = n_total;
                self.mean = mean;
                self.m2 = m2;
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds every value of `xs`, in order, into a fresh accumulator.
    fn accumulate(xs: &[f64]) -> WelfordState {
        let mut state = WelfordState::new();
        for &x in xs {
            state.update(x);
        }
        state
    }

    #[test]
    fn test_welford_basic() {
        let state = accumulate(&[1.0, 2.0, 3.0, 4.0, 5.0]);
        assert_eq!(state.n, 5);
        assert!((state.mean - 3.0).abs() < 1e-10);
        assert!((state.variance() - 2.0).abs() < 1e-10);
        assert!((state.std() - 2.0_f64.sqrt()).abs() < 1e-10);
    }

    #[test]
    fn test_welford_empty() {
        // The n == 0 guard must avoid a 0/0 NaN.
        let state = WelfordState::new();
        assert_eq!(state.n, 0);
        assert_eq!(state.variance(), 0.0);
        assert_eq!(state.std(), 0.0);
    }

    #[test]
    fn test_welford_single_sample() {
        let state = accumulate(&[7.5]);
        assert_eq!(state.n, 1);
        assert_eq!(state.mean, 7.5);
        assert_eq!(state.variance(), 0.0);
    }

    #[test]
    fn test_welford_merge() {
        let mut state_a = accumulate(&[1.0, 2.0, 3.0]);
        let state_b = accumulate(&[4.0, 5.0]);
        state_a.merge(&state_b);
        assert_eq!(state_a.n, 5);
        assert!((state_a.mean - 3.0).abs() < 1e-10);
        assert!((state_a.variance() - 2.0).abs() < 1e-10);
    }

    #[test]
    fn test_welford_merge_with_empty() {
        let populated = accumulate(&[1.0, 2.0, 3.0]);

        // Absorbing an empty accumulator must leave the state untouched.
        let mut left = populated.clone();
        left.merge(&WelfordState::new());
        assert_eq!(left.n, populated.n);
        assert_eq!(left.mean, populated.mean);
        assert_eq!(left.m2, populated.m2);

        // Merging into an empty accumulator must adopt the other side verbatim.
        let mut right = WelfordState::new();
        right.merge(&populated);
        assert_eq!(right.n, populated.n);
        assert_eq!(right.mean, populated.mean);
        assert_eq!(right.m2, populated.m2);

        // Two empty accumulators stay empty.
        let mut both_empty = WelfordState::new();
        both_empty.merge(&WelfordState::new());
        assert_eq!(both_empty.n, 0);
        assert_eq!(both_empty.variance(), 0.0);
    }

    #[test]
    fn test_welford_merge_matches_sequential() {
        let data = [2.5, -1.0, 4.0, 10.0, 0.25, 3.0, -7.5];
        let sequential = accumulate(&data);

        let (head, tail) = data.split_at(3);
        let mut merged = accumulate(head);
        merged.merge(&accumulate(tail));

        assert_eq!(merged.n, sequential.n);
        assert!((merged.mean - sequential.mean).abs() < 1e-10);
        assert!((merged.variance() - sequential.variance()).abs() < 1e-10);
    }

    #[test]
    fn test_welford_numerical_stability() {
        // A large common offset would wreck a naive sum-of-squares variance.
        let offset = 1e8_f64;
        let state = accumulate(&[offset, offset + 1.0, offset + 2.0]);
        assert_eq!(state.n, 3);
        assert!((state.mean - (offset + 1.0)).abs() < 1e-6);
        let expected_var = 2.0 / 3.0;
        assert!(
            (state.variance() - expected_var).abs() < 1e-10,
            "Welford should handle large offsets: got {} expected {}",
            state.variance(),
            expected_var
        );
    }
}