use axonml_nn::Parameter;
use axonml_tensor::Tensor;
/// Exponential moving average (EMA) of a model's parameters.
///
/// Keeps one flat `f32` shadow buffer per tracked parameter. On every
/// [`ModelEMA::update`] the buffers are blended towards the parameters' live values.
#[derive(Debug, Clone)]
pub struct ModelEMA {
    /// Averaged copy of each parameter's data, in the order the parameters were
    /// passed at construction.
    shadow: Vec<Vec<f32>>,
    /// Asymptotic decay factor; the per-step value is ramped up by `effective_decay`.
    decay: f32,
    /// Number of `update` calls performed so far; drives the warmup ramp.
    num_updates: usize,
}
impl ModelEMA {
    /// Time constant (in update steps) of the exponential warmup ramp used by
    /// [`ModelEMA::effective_decay`].
    const WARMUP_TAU: f32 = 2000.0;

    /// Creates an EMA tracker whose shadow buffers start as copies of the
    /// current values of `params`.
    ///
    /// `decay` is the asymptotic decay factor. The decay actually applied on each
    /// update ramps up towards it; see [`ModelEMA::effective_decay`].
    pub fn new(params: &[Parameter], decay: f32) -> Self {
        let shadow: Vec<Vec<f32>> = params
            .iter()
            .map(|p| p.variable().data().to_vec())
            .collect();
        Self {
            shadow,
            decay,
            num_updates: 0,
        }
    }

    /// Shorthand for [`ModelEMA::new`] with `decay = 0.9999`.
    ///
    /// The warmup ramp in [`ModelEMA::effective_decay`] applies to every instance,
    /// not only to ones built through this constructor.
    pub fn with_warmup(params: &[Parameter]) -> Self {
        Self::new(params, 0.9999)
    }

    /// Folds the current values of `params` into the shadow buffers.
    ///
    /// The update counter is incremented first. Then every shadow element becomes
    /// `d * shadow + (1 - d) * param`, where `d` is [`ModelEMA::effective_decay`]
    /// for the new count.
    ///
    /// Parameters are paired with shadow buffers by position. Iteration stops at the
    /// shorter of the two lists (and of each buffer pair), so `params` must be passed
    /// in the same order as at construction.
    pub fn update(&mut self, params: &[Parameter]) {
        self.num_updates += 1;
        let d = self.effective_decay();
        for (shadow, param) in self.shadow.iter_mut().zip(params.iter()) {
            let param_data = param.variable().data().to_vec();
            for (s, &p) in shadow.iter_mut().zip(param_data.iter()) {
                *s = d * *s + (1.0 - d) * p;
            }
        }
    }

    /// Overwrites each parameter's data with its shadow buffer (positional pairing,
    /// as in [`ModelEMA::update`]). The parameter's current shape is kept.
    ///
    /// # Panics
    /// Panics if `Tensor::from_vec` rejects a shadow buffer for the parameter's
    /// current shape. Presumably this happens when the element count no longer
    /// matches, e.g. the parameter was resized after the EMA was created.
    pub fn apply_to(&self, params: &[Parameter]) {
        for (shadow, param) in self.shadow.iter().zip(params.iter()) {
            let tensor = Tensor::from_vec(shadow.clone(), param.data().shape())
                .expect("EMA shadow buffer must match the parameter's shape");
            param.update_data(tensor);
        }
    }

    /// Temporarily swaps the EMA weights into `params`, runs `f`, then restores the
    /// original parameter data and returns `f`'s result.
    ///
    /// Restoration is performed by a drop guard. The original weights are therefore
    /// put back even if `f` (or [`ModelEMA::apply_to`]) panics and the panic is
    /// later caught.
    pub fn apply_and_restore<F, R>(&self, params: &[Parameter], f: F) -> R
    where
        F: FnOnce() -> R,
    {
        /// Holds a snapshot of parameter data and writes it back when dropped.
        struct RestoreGuard<'a> {
            params: &'a [Parameter],
            originals: Vec<Vec<f32>>,
        }

        impl Drop for RestoreGuard<'_> {
            fn drop(&mut self) {
                // Move the snapshots out so each buffer is handed to the tensor
                // without another copy.
                let originals = std::mem::take(&mut self.originals);
                for (orig, param) in originals.into_iter().zip(self.params.iter()) {
                    // Bind first so any temporary from `param.data()` is released
                    // before `update_data` runs.
                    let restored = Tensor::from_vec(orig, param.data().shape());
                    match restored {
                        Ok(tensor) => {
                            param.update_data(tensor);
                        }
                        // A second panic while already unwinding would abort the
                        // process, so only surface the error on the normal path.
                        Err(err) if !std::thread::panicking() => {
                            panic!("failed to restore original parameter data: {err:?}")
                        }
                        Err(_) => {}
                    }
                }
            }
        }

        let _restore = RestoreGuard {
            params,
            originals: params.iter().map(|p| p.data().to_vec()).collect(),
        };
        self.apply_to(params);
        // `f()` is evaluated before `_restore` drops, so the closure observes the
        // EMA weights and the originals are back in place when this returns.
        f()
    }

    /// Returns the decay for the current update count `n`:
    /// `decay * (1 - exp(-n / 2000))`.
    ///
    /// It is 0 before any update, so early updates track the parameters closely.
    /// It approaches `decay` as `n` grows.
    pub fn effective_decay(&self) -> f32 {
        let warmup = 1.0 - (-(self.num_updates as f32) / Self::WARMUP_TAU).exp();
        self.decay * warmup
    }

    /// Number of times [`ModelEMA::update`] has been called.
    pub fn num_updates(&self) -> usize {
        self.num_updates
    }

    /// The averaged parameter buffers, in construction order.
    pub fn shadow_params(&self) -> &[Vec<f32>] {
        &self.shadow
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axonml_tensor::Tensor;

    /// Two small parameters: a 2x2 matrix and a length-2 vector.
    fn make_params() -> Vec<Parameter> {
        vec![
            Parameter::new(
                Tensor::from_vec(vec![1.0, 2.0, 3.0, 4.0], &[2, 2]).unwrap(),
                true,
            ),
            Parameter::new(Tensor::from_vec(vec![0.5, 0.5], &[2]).unwrap(), true),
        ]
    }

    #[test]
    fn test_ema_creation() {
        let params = make_params();
        let ema = ModelEMA::new(&params, 0.999);
        assert_eq!(ema.num_updates(), 0);
        assert_eq!(ema.shadow.len(), 2);
        assert_eq!(ema.shadow[0], vec![1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_ema_update() {
        let params = make_params();
        let mut ema = ModelEMA::new(&params, 0.9);
        params[0].update_data(Tensor::from_vec(vec![10.0, 20.0, 30.0, 40.0], &[2, 2]).unwrap());
        ema.update(&params);
        assert_eq!(ema.num_updates(), 1);
        for &v in &ema.shadow[0] {
            assert!(v > 1.0, "Shadow should move toward new values, got {v}");
        }
    }

    #[test]
    fn test_ema_apply_and_restore() {
        let params = make_params();
        let mut ema = ModelEMA::new(&params, 0.5);
        for _ in 0..100 {
            params[0].update_data(Tensor::from_vec(vec![10.0; 4], &[2, 2]).unwrap());
            ema.update(&params);
        }
        let result = ema.apply_and_restore(&params, || {
            let data = params[0].variable().data().to_vec();
            assert!(data[0] > 5.0, "EMA values should be closer to 10.0");
            42
        });
        assert_eq!(result, 42);
        // The live weights (all 10.0) must be back after the closure returns.
        let restored = params[0].variable().data().to_vec();
        assert_eq!(restored, vec![10.0; 4]);
    }

    #[test]
    fn test_effective_decay_warmup() {
        let params = make_params();
        let mut ema = ModelEMA::new(&params, 0.9999);
        assert!(ema.effective_decay() < 0.01);
        ema.num_updates = 10000;
        let d = ema.effective_decay();
        assert!(
            d > 0.99,
            "After 10K steps, decay should be ~0.9999, got {d}"
        );
    }
}