use burn::tensor::backend::Backend;
use burn::tensor::Tensor;
/// Exponential-moving-average (EMA) blender for a "target" value that tracks an
/// "online" value.
///
/// Every update computes `m * target + (1 - m) * online`, where `m` is the
/// momentum in effect for the requested step.
#[derive(Debug, Clone)]
pub struct Ema {
    /// Constant momentum; used whenever `schedule` is `None`.
    pub momentum: f64,
    /// Optional step-dependent schedule; when present it takes precedence over
    /// `momentum`.
    pub schedule: Option<MomentumSchedule>,
}

impl Ema {
    /// Creates an updater that applies the same `momentum` at every step.
    pub fn new(momentum: f64) -> Self {
        Self {
            schedule: None,
            momentum,
        }
    }

    /// Creates an updater driven by a cosine schedule that starts at
    /// `base_momentum`, is configured over `total_steps` steps, and has its
    /// final momentum fixed at `1.0`.
    ///
    /// `base_momentum` is also stored in the `momentum` field.
    pub fn with_cosine_schedule(base_momentum: f64, total_steps: usize) -> Self {
        let cosine = CosineMomentumSchedule {
            base_momentum,
            final_momentum: 1.0,
            total_steps,
        };
        Self {
            momentum: base_momentum,
            schedule: Some(MomentumSchedule::Cosine(cosine)),
        }
    }

    /// Returns the momentum in effect at `step`: the schedule's value when a
    /// schedule is set, otherwise the constant `momentum`.
    pub fn get_momentum(&self, step: usize) -> f64 {
        self.schedule
            .as_ref()
            .map_or(self.momentum, |schedule| schedule.get_momentum(step))
    }

    /// Blends two scalars and returns `m * target + (1 - m) * online`, with `m`
    /// taken from [`Ema::get_momentum`] at `step`.
    pub fn step(&self, target: f64, online: f64, step: usize) -> f64 {
        let momentum = self.get_momentum(step);
        let retained = momentum * target;
        let absorbed = (1.0 - momentum) * online;
        retained + absorbed
    }

    /// Tensor counterpart of [`Ema::step`].
    ///
    /// Consumes `target`, clones `online`, and returns the blended tensor
    /// `target * m + online * (1 - m)`.
    pub fn update_tensor<B: Backend, const D: usize>(
        &self,
        target: Tensor<B, D>,
        online: &Tensor<B, D>,
        step: usize,
    ) -> Tensor<B, D> {
        let momentum = self.get_momentum(step);
        let retained = target * momentum;
        let absorbed = online.clone() * (1.0 - momentum);
        retained + absorbed
    }

    /// Blends every `(target, online)` pair using the momentum for `step`.
    ///
    /// The momentum is looked up once and shared by all pairs; results are
    /// returned in the same order as `pairs`.
    pub fn update_tensor_pairs<B: Backend, const D: usize>(
        &self,
        pairs: Vec<(Tensor<B, D>, Tensor<B, D>)>,
        step: usize,
    ) -> Vec<Tensor<B, D>> {
        let momentum = self.get_momentum(step);
        let mut blended = Vec::with_capacity(pairs.len());
        for (target, online) in pairs {
            blended.push(target * momentum + online * (1.0 - momentum));
        }
        blended
    }
}
/// Step-dependent momentum schedules.
#[derive(Debug, Clone)]
pub enum MomentumSchedule {
    /// Momentum computed by a [`CosineMomentumSchedule`].
    Cosine(CosineMomentumSchedule),
}

impl MomentumSchedule {
    /// Returns the momentum for `step` by delegating to the wrapped schedule.
    pub fn get_momentum(&self, step: usize) -> f64 {
        // The enum has a single variant, so this destructuring is irrefutable;
        // adding a variant will turn it into a compile error that forces an update.
        let Self::Cosine(cosine) = self;
        cosine.get_momentum(step)
    }
}
/// Cosine-shaped momentum schedule that moves from `base_momentum` at step 0 to
/// `final_momentum` once `total_steps` steps have elapsed.
#[derive(Debug, Clone)]
pub struct CosineMomentumSchedule {
    /// Momentum returned at step 0 (when `total_steps > 0`).
    pub base_momentum: f64,
    /// Momentum returned for every step at or beyond `total_steps`.
    pub final_momentum: f64,
    /// Number of steps over which the momentum is annealed.
    pub total_steps: usize,
}

impl CosineMomentumSchedule {
    /// Returns the momentum for `step`.
    ///
    /// For `step < total_steps` the value is
    /// `final - (final - base) * (1 + cos(pi * step / total_steps)) / 2`.
    /// For `step >= total_steps` — which includes every step when
    /// `total_steps == 0` — the result is exactly `final_momentum`.
    pub fn get_momentum(&self, step: usize) -> f64 {
        // Completed (or empty) schedules settle on the final value. Clamping to
        // `total_steps - 1` instead would never reach `final_momentum` and would
        // leave a one-step schedule stuck at `base_momentum`.
        if step >= self.total_steps {
            return self.final_momentum;
        }
        // `total_steps > 0` is guaranteed here, so the division is well-defined.
        let progress = step as f64 / self.total_steps as f64;
        self.final_momentum
            - (self.final_momentum - self.base_momentum)
                * (1.0 + (progress * std::f64::consts::PI).cos())
                / 2.0
    }
}
#[cfg(test)]
mod tests {
    // Unit and property tests for the scalar/tensor EMA updates and the cosine
    // momentum schedule.
    use super::*;
    use burn_ndarray::NdArray;
    use proptest::prelude::*;

    type TestBackend = NdArray<f32>;

    /// CPU device on which every test tensor is created.
    fn device() -> burn_ndarray::NdArrayDevice {
        burn_ndarray::NdArrayDevice::Cpu
    }

    /// Cosine schedule from 0.996 to 1.0 over `total_steps` steps.
    fn cosine_schedule(total_steps: usize) -> CosineMomentumSchedule {
        CosineMomentumSchedule {
            base_momentum: 0.996,
            final_momentum: 1.0,
            total_steps,
        }
    }

    /// Copies a tensor's elements out into a flat `Vec<f32>`.
    fn to_f32s<const D: usize>(tensor: Tensor<TestBackend, D>) -> Vec<f32> {
        tensor.into_data().to_vec().unwrap()
    }

    /// `(target, online)` 2x2 fixtures shared by the momentum 0 / 1 tensor tests.
    fn matrix_pair() -> (Tensor<TestBackend, 2>, Tensor<TestBackend, 2>) {
        let dev = device();
        let target = Tensor::<TestBackend, 2>::from_floats([[1.0, 2.0], [3.0, 4.0]], &dev);
        let online = Tensor::<TestBackend, 2>::from_floats([[10.0, 20.0], [30.0, 40.0]], &dev);
        (target, online)
    }

    #[test]
    fn test_ema_momentum_1_keeps_target_unchanged() {
        let blended = Ema::new(1.0).step(5.0, 10.0, 0);
        assert!((blended - 5.0).abs() < 1e-10);
    }

    #[test]
    fn test_ema_momentum_0_copies_online() {
        let blended = Ema::new(0.0).step(5.0, 10.0, 0);
        assert!((blended - 10.0).abs() < 1e-10);
    }

    #[test]
    fn test_ema_typical_momentum() {
        let blended = Ema::new(0.996).step(0.0, 1.0, 0);
        assert!((blended - 0.004).abs() < 1e-10);
    }

    #[test]
    fn test_ema_converges_to_online() {
        let ema = Ema::new(0.99);
        let online = 1.0;
        let target = (0..1000).fold(0.0, |acc, step| ema.step(acc, online, step));
        assert!(
            (target - 1.0).abs() < 0.01,
            "expected convergence to 1.0, got {target}"
        );
    }

    #[test]
    fn test_cosine_schedule_at_start() {
        let m = cosine_schedule(10000).get_momentum(0);
        assert!(
            (m - 0.996).abs() < 1e-6,
            "expected 0.996 at step 0, got {m}"
        );
    }

    #[test]
    fn test_cosine_schedule_at_end() {
        let m = cosine_schedule(10000).get_momentum(9999);
        assert!(
            (m - 1.0).abs() < 1e-3,
            "expected ~1.0 at final step, got {m}"
        );
    }

    #[test]
    fn test_cosine_schedule_midpoint() {
        let m = cosine_schedule(10000).get_momentum(5000);
        assert!(
            m > 0.997 && m < 0.999,
            "expected ~0.998 at midpoint, got {m}"
        );
    }

    #[test]
    fn test_cosine_schedule_is_monotonically_increasing() {
        let schedule = cosine_schedule(1000);
        let momenta: Vec<f64> = (0..1000).map(|s| schedule.get_momentum(s)).collect();
        // Window `i` covers steps `i` and `i + 1`; report the later step on failure.
        for (index, window) in momenta.windows(2).enumerate() {
            let (prev, curr) = (window[0], window[1]);
            let step = index + 1;
            assert!(
                curr >= prev - 1e-10,
                "schedule not monotonic at step {step}: {prev} -> {curr}"
            );
        }
    }

    #[test]
    fn test_ema_with_schedule() {
        let ema = Ema::with_cosine_schedule(0.996, 10000);
        let (m0, m_end) = (ema.get_momentum(0), ema.get_momentum(9999));
        assert!((m0 - 0.996).abs() < 1e-6);
        assert!((m_end - 1.0).abs() < 1e-3);
    }

    #[test]
    fn test_tensor_ema_momentum_1_keeps_target() {
        let (target, online) = matrix_pair();
        let data = to_f32s(Ema::new(1.0).update_tensor(target, &online, 0));
        assert!((data[0] - 1.0).abs() < 1e-6);
        assert!((data[3] - 4.0).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_ema_momentum_0_copies_online() {
        let (target, online) = matrix_pair();
        let data = to_f32s(Ema::new(0.0).update_tensor(target, &online, 0));
        assert!((data[0] - 10.0).abs() < 1e-6);
        assert!((data[3] - 40.0).abs() < 1e-6);
    }

    #[test]
    fn test_tensor_ema_typical_momentum() {
        let dev = device();
        let target = Tensor::<TestBackend, 1>::zeros([4], &dev);
        let online = Tensor::<TestBackend, 1>::ones([4], &dev);
        for v in to_f32s(Ema::new(0.996).update_tensor(target, &online, 0)) {
            assert!((v - 0.004).abs() < 1e-6);
        }
    }

    #[test]
    fn test_tensor_ema_convergence() {
        let ema = Ema::new(0.99);
        let dev = device();
        let online = Tensor::<TestBackend, 1>::ones([8], &dev);
        let start = Tensor::<TestBackend, 1>::zeros([8], &dev);
        let target = (0..1000).fold(start, |acc, step| ema.update_tensor(acc, &online, step));
        for v in to_f32s(target) {
            assert!(
                (v - 1.0).abs() < 0.01,
                "expected convergence to 1.0, got {v}"
            );
        }
    }

    #[test]
    fn test_tensor_ema_with_schedule() {
        let ema = Ema::with_cosine_schedule(0.996, 100);
        let dev = device();
        let target = Tensor::<TestBackend, 1>::zeros([4], &dev);
        let online = Tensor::<TestBackend, 1>::ones([4], &dev);
        let early = to_f32s(ema.update_tensor(target.clone(), &online, 0));
        let late = to_f32s(ema.update_tensor(target, &online, 99));
        assert!(
            early[0] > late[0],
            "early step ({}) should move more than late step ({})",
            early[0],
            late[0]
        );
    }

    #[test]
    fn test_tensor_pair_update() {
        let zeros = || Tensor::<TestBackend, 1>::zeros([4], &device());
        let ones = || Tensor::<TestBackend, 1>::ones([4], &device());
        let pairs = vec![(zeros(), ones()), (ones(), zeros())];
        let results = Ema::new(0.5).update_tensor_pairs(pairs, 0);
        assert_eq!(results.len(), 2);
        // With momentum 0.5 both pairs blend to the same midpoint.
        for blended in results {
            let values = to_f32s(blended);
            assert!((values[0] - 0.5).abs() < 1e-6);
        }
    }

    #[test]
    fn test_tensor_ema_3d_shape_preserved() {
        let dev = device();
        let target = Tensor::<TestBackend, 3>::zeros([2, 4, 8], &dev);
        let online = Tensor::<TestBackend, 3>::ones([2, 4, 8], &dev);
        let result = Ema::new(0.99).update_tensor(target, &online, 0);
        assert_eq!(result.dims(), [2, 4, 8]);
    }

    #[test]
    fn test_cosine_schedule_zero_total_steps() {
        let m = cosine_schedule(0).get_momentum(0);
        assert!((m - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_cosine_schedule_beyond_total_steps() {
        let m = cosine_schedule(100).get_momentum(200);
        assert!(
            (m - 1.0).abs() < 1e-3,
            "beyond-total-steps momentum should be near final: got {m}"
        );
    }

    proptest! {
        // Repeated scalar updates pull the target towards the online value.
        #[test]
        fn prop_ema_converges_to_online(
            momentum in 0.9f64..0.995,
            steps in 1000usize..10000,
        ) {
            let ema = Ema::new(momentum);
            let online = 1.0f64;
            let target = (0..steps).fold(0.0f64, |acc, s| ema.step(acc, online, s));
            prop_assert!(
                (target - online).abs() < 0.1,
                "did not converge: momentum={momentum}, steps={steps}, target={target}"
            );
        }

        // A blend with momentum in [0, 1] stays between its two inputs.
        #[test]
        fn prop_ema_momentum_bounds(
            momentum in 0.0f64..=1.0f64,
            target_val in -100.0f64..100.0,
            online_val in -100.0f64..100.0,
        ) {
            let result = Ema::new(momentum).step(target_val, online_val, 0);
            let (lo, hi) = (target_val.min(online_val), target_val.max(online_val));
            prop_assert!(
                ((lo - 1e-10)..=(hi + 1e-10)).contains(&result),
                "result {result} out of bounds [{lo}, {hi}] with momentum {momentum}"
            );
        }

        // The tensor update agrees with the scalar update on a single element.
        #[test]
        fn prop_tensor_ema_matches_scalar(
            momentum in 0.5f64..0.999,
        ) {
            let ema = Ema::new(momentum);
            let (target_val, online_val) = (3.0f32, 7.0f32);
            let scalar_result = ema.step(target_val as f64, online_val as f64, 0) as f32;
            let dev = device();
            let target = Tensor::<TestBackend, 1>::from_floats([target_val], &dev);
            let online = Tensor::<TestBackend, 1>::from_floats([online_val], &dev);
            let tensor_result = to_f32s(ema.update_tensor(target, &online, 0));
            prop_assert!(
                (tensor_result[0] - scalar_result).abs() < 1e-4,
                "scalar={scalar_result}, tensor={}", tensor_result[0]
            );
        }
    }
}