use crate::error::{SslError, SslResult};
/// Bookkeeping state for exponential-moving-average (EMA) parameter updates.
///
/// The updater owns no parameters itself; callers hand the slices to
/// [`EmaUpdater::update`] on every call.
#[derive(Clone, Debug)]
pub struct EmaUpdater {
    /// Momentum recorded by the most recent successful [`EmaUpdater::update`]
    /// call (`0.99` for a freshly constructed updater).
    pub last_momentum: f32,
}
impl Default for EmaUpdater {
fn default() -> Self {
Self {
last_momentum: 0.99,
}
}
}
impl EmaUpdater {
    /// Creates an updater in its default state (see the `Default` impl).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Blends `online` into `target` in place using an exponential moving average:
    /// `target[i] = momentum * target[i] + (1 - momentum) * online[i]`.
    ///
    /// The two boundary momenta are handled exactly rather than arithmetically:
    /// `0.0` overwrites `target` with a copy of `online`, and `1.0` leaves
    /// `target` untouched. On success `momentum` is stored in `last_momentum`.
    ///
    /// # Errors
    ///
    /// * [`SslError::InvalidMomentum`] if `momentum` is NaN, infinite, or outside `[0, 1]`.
    /// * [`SslError::DimensionMismatch`] if the slices differ in length
    ///   (`expected` is `target.len()`, `got` is `online.len()`).
    ///
    /// On error neither `target` nor `last_momentum` is modified.
    pub fn update(&mut self, target: &mut [f32], online: &[f32], momentum: f32) -> SslResult<()> {
        if !(momentum.is_finite() && (0.0..=1.0).contains(&momentum)) {
            return Err(SslError::InvalidMomentum { momentum });
        }
        if target.len() != online.len() {
            return Err(SslError::DimensionMismatch {
                expected: target.len(),
                got: online.len(),
            });
        }

        // `0.0 * NaN` and `0.0 * inf` are NaN, so evaluating the blend formula at
        // the boundaries would let a non-finite value in the operand that should
        // be ignored poison the result. Treat both boundaries explicitly.
        if momentum == 0.0 {
            // Lengths were checked above, so this cannot panic.
            target.copy_from_slice(online);
        } else if momentum < 1.0 {
            let one_minus_m = 1.0 - momentum;
            for (t, &o) in target.iter_mut().zip(online) {
                *t = momentum * *t + one_minus_m * o;
            }
        }
        // momentum == 1.0: the target keeps its current values.

        self.last_momentum = momentum;
        Ok(())
    }
}
/// Evaluates a cosine schedule that moves the momentum from `m_base` to `m_end`.
///
/// With `t = min(step, max_steps) / max_steps`, the result is
/// `m_end - (m_end - m_base) * (cos(pi * t) + 1) / 2`, so step `0` yields
/// `m_base` and any step at or beyond `max_steps` yields (approximately) `m_end`.
///
/// # Errors
///
/// * [`SslError::InvalidMomentum`] if `m_base` or `m_end` is NaN, infinite, or
///   outside `[0, 1]`; `m_base` is checked first and the offending value is reported.
/// * [`SslError::Internal`] if `max_steps` is zero.
pub fn cosine_momentum(step: usize, max_steps: usize, m_base: f32, m_end: f32) -> SslResult<f32> {
    // Validate the bounds in declaration order so the first bad one is reported.
    for bound in [m_base, m_end] {
        let usable = bound.is_finite() && (0.0..=1.0).contains(&bound);
        if !usable {
            return Err(SslError::InvalidMomentum { momentum: bound });
        }
    }
    if max_steps == 0 {
        return Err(SslError::Internal("cosine_momentum: max_steps == 0".into()));
    }

    // Steps past the end of the schedule are clamped to its final value.
    let progress = step.min(max_steps) as f32 / max_steps as f32;
    let raised_cosine = (std::f32::consts::PI * progress).cos() + 1.0;
    let span = m_end - m_base;
    Ok(m_end - span * raised_cosine / 2.0)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Momentum `0` must turn the target into a copy of the online values.
    #[test]
    fn ema_zero_momentum_replaces_target() {
        let mut ema = EmaUpdater::new();
        let mut target = [0.0_f32, 1.0, 2.0];
        let online = [10.0_f32, 20.0, 30.0];
        let outcome = ema.update(&mut target, &online, 0.0);
        outcome.expect("update should succeed");
        assert_eq!(target, online);
        assert!(ema.last_momentum.abs() < 1e-7);
    }

    /// Momentum `1` must leave the target exactly as it was.
    #[test]
    fn ema_full_momentum_keeps_target() {
        let mut ema = EmaUpdater::new();
        let before = [0.0_f32, 1.0, 2.0];
        let mut target = before;
        let online = [10.0_f32, 20.0, 30.0];
        let outcome = ema.update(&mut target, &online, 1.0);
        outcome.expect("update should succeed");
        assert_eq!(target, before);
    }

    /// Momentum `0.5` must land halfway between target and online.
    #[test]
    fn ema_half_momentum_averages() {
        let mut ema = EmaUpdater::new();
        let mut target = [0.0_f32; 3];
        let online = [2.0_f32, 4.0, 6.0];
        let outcome = ema.update(&mut target, &online, 0.5);
        outcome.expect("update should succeed");
        assert_eq!(target, [1.0_f32, 2.0, 3.0]);
    }

    /// Negative, above-one and NaN momenta are all rejected.
    #[test]
    fn ema_rejects_invalid_momentum() {
        let mut ema = EmaUpdater::new();
        let mut target = [0.0_f32; 3];
        let online = [1.0_f32; 3];
        for bad in [-0.1_f32, 1.5, f32::NAN] {
            assert!(ema.update(&mut target, &online, bad).is_err());
        }
    }

    /// Slices of different lengths are rejected.
    #[test]
    fn ema_rejects_dim_mismatch() {
        let mut ema = EmaUpdater::new();
        let mut target = [0.0_f32; 3];
        let online = [1.0_f32; 5];
        let outcome = ema.update(&mut target, &online, 0.5);
        assert!(outcome.is_err());
    }

    /// The schedule starts at `m_base`.
    #[test]
    fn cosine_momentum_at_step_zero_equals_base() {
        let value = cosine_momentum(0, 100, 0.5, 1.0).expect("cosine_momentum should succeed");
        let error = (value - 0.5).abs();
        assert!(error < 1e-6);
    }

    /// The schedule finishes at `m_end`.
    #[test]
    fn cosine_momentum_at_max_step_equals_end() {
        let value = cosine_momentum(100, 100, 0.5, 1.0).expect("cosine_momentum should succeed");
        let error = (value - 1.0).abs();
        assert!(error < 1e-5);
    }

    /// With `m_base < m_end` later steps yield strictly larger momenta.
    #[test]
    fn cosine_momentum_monotone() {
        let samples: Vec<f32> = [10, 50, 90]
            .iter()
            .map(|&step| {
                cosine_momentum(step, 100, 0.5, 1.0).expect("cosine_momentum should succeed")
            })
            .collect();
        assert!(samples.windows(2).all(|pair| pair[0] < pair[1]));
    }

    /// Out-of-range schedule bounds are rejected.
    #[test]
    fn cosine_momentum_rejects_invalid_bounds() {
        let bad_base = cosine_momentum(0, 100, -0.1, 1.0);
        let bad_end = cosine_momentum(0, 100, 0.5, 1.5);
        assert!(bad_base.is_err());
        assert!(bad_end.is_err());
    }

    /// A schedule with zero total steps is rejected.
    #[test]
    fn cosine_momentum_rejects_zero_steps() {
        let outcome = cosine_momentum(0, 0, 0.5, 1.0);
        assert!(outcome.is_err());
    }
}