use candle_core::{Result, Tensor};
use std::collections::HashMap;
/// Settings for an exponential moving average (EMA) of parameters.
#[derive(Debug, Clone)]
pub struct EMAConfig {
    /// Weight kept from the existing shadow value on each update; the incoming
    /// parameter is weighted by `1.0 - decay`.
    pub decay: f64,
}

impl Default for EMAConfig {
    /// Returns a configuration whose `decay` is `0.9999`.
    fn default() -> Self {
        EMAConfig { decay: 0.9999 }
    }
}
/// Exponential-moving-average tracker that keeps one shadow tensor per
/// parameter index.
pub struct EMA {
    /// Settings supplying the decay factor used when blending.
    config: EMAConfig,
    /// Shadow tensors keyed by the parameter's position in the slices handed
    /// to the methods of this type.
    shadow_params: HashMap<usize, Tensor>,
}
impl EMA {
    /// Creates a tracker with the given configuration and no shadow tensors.
    pub fn new(config: EMAConfig) -> Self {
        Self {
            shadow_params: HashMap::new(),
            config,
        }
    }

    /// Blends every tensor in `params` into the shadow stored under its slice
    /// index: `shadow = shadow * decay + param * (1 - decay)`.
    ///
    /// An index seen for the first time is seeded with a clone of the
    /// parameter and then blended like any other, so its shadow ends up equal
    /// to the parameter up to floating-point rounding.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by the tensor arithmetic. Shadows for
    /// earlier indices have already been updated at that point.
    // NOTE(review): if `params` are gradient-tracked variables, the results
    // stored here may keep autograd history alive across calls — confirm
    // against the candle version in use and consider detaching.
    pub fn update(&mut self, params: &[Tensor]) -> Result<()> {
        let keep = self.config.decay;
        let take = 1.0 - self.config.decay;

        for (idx, current) in params.iter().enumerate() {
            let slot = self
                .shadow_params
                .entry(idx)
                .or_insert_with(|| current.clone());

            let retained = (slot.clone() * keep)?;
            let incoming = (current * take)?;
            *slot = (retained + incoming)?;
        }

        Ok(())
    }

    /// Returns clones of the shadow tensors for indices `0..len` in ascending
    /// index order, where `len` is the number of stored shadows. Indices in
    /// that range without an entry are skipped.
    pub fn get_params(&self) -> Vec<Tensor> {
        (0..self.shadow_params.len())
            .filter_map(|idx| self.shadow_params.get(&idx).cloned())
            .collect()
    }

    /// Overwrites each slot of `params` with a clone of the shadow stored
    /// under the same index. Slots with no matching shadow are left untouched.
    ///
    /// This implementation always returns `Ok(())`.
    pub fn copy_to(&self, params: &mut [Tensor]) -> Result<()> {
        params.iter_mut().enumerate().for_each(|(idx, target)| {
            if let Some(averaged) = self.shadow_params.get(&idx) {
                *target = averaged.clone();
            }
        });

        Ok(())
    }

    /// Replaces (or creates) the shadow at every index of `params` with a
    /// clone of the corresponding tensor. Shadows stored under indices at or
    /// beyond `params.len()` are kept as they are.
    pub fn copy_from(&mut self, params: &[Tensor]) {
        // `extend` overwrites existing keys, matching a per-index `insert`.
        self.shadow_params
            .extend(params.iter().cloned().enumerate());
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use candle_core::{DType, Device};

    /// Builds a CPU `F32` tensor of the given 2-D shape filled with ones.
    fn ones_f32(shape: (usize, usize)) -> Result<Tensor> {
        Tensor::ones(shape, DType::F32, &Device::Cpu)
    }

    /// Builds a CPU `F32` tensor of the given 2-D shape filled with zeros.
    fn zeros_f32(shape: (usize, usize)) -> Result<Tensor> {
        Tensor::zeros(shape, DType::F32, &Device::Cpu)
    }

    /// Sum of absolute element-wise differences between two tensors.
    fn l1_distance(lhs: &Tensor, rhs: &Tensor) -> Result<f32> {
        (lhs - rhs)?.abs()?.sum_all()?.to_scalar::<f32>()
    }

    /// A freshly created tracker holds no shadow tensors.
    #[test]
    fn test_ema_creation() {
        let ema = EMA::new(EMAConfig::default());
        assert!(ema.shadow_params.is_empty());
    }

    /// The first update leaves the shadow equal to the parameter.
    #[test]
    fn test_ema_update() -> Result<()> {
        let param = ones_f32((10, 10))?;
        let mut ema = EMA::new(EMAConfig { decay: 0.9 });

        ema.update(&[param.clone()])?;

        let diff = l1_distance(&ema.shadow_params[&0], &param)?;
        assert!(diff < 1e-6);
        Ok(())
    }

    /// Ones followed by zeros with decay 0.9 yields a shadow mean of 0.9.
    #[test]
    fn test_ema_smoothing() -> Result<()> {
        let mut ema = EMA::new(EMAConfig { decay: 0.9 });

        ema.update(&[ones_f32((5, 5))?])?;
        ema.update(&[zeros_f32((5, 5))?])?;

        let mean_val = ema.shadow_params[&0].mean_all()?.to_scalar::<f32>()?;
        assert!((mean_val - 0.9).abs() < 1e-6);
        Ok(())
    }

    /// `copy_to` overwrites the caller's tensor with the averaged value.
    #[test]
    fn test_copy_to() -> Result<()> {
        let mut ema = EMA::new(EMAConfig { decay: 0.95 });

        ema.update(&[ones_f32((5, 5))?])?;
        ema.update(&[zeros_f32((5, 5))?])?;

        let mut params = vec![ones_f32((5, 5))?];
        ema.copy_to(&mut params)?;

        let expected = 0.95;
        let actual = params[0].mean_all()?.to_scalar::<f32>()?;
        assert!((actual - expected).abs() < 1e-6);
        Ok(())
    }

    /// `copy_from` stores the given tensor as the shadow without blending.
    #[test]
    fn test_copy_from() -> Result<()> {
        let mut ema = EMA::new(EMAConfig::default());
        let param = Tensor::full(2.0f32, (5, 5), &Device::Cpu)?;

        ema.copy_from(&[param.clone()]);

        let diff = l1_distance(&ema.shadow_params[&0], &param)?;
        assert!(diff < 1e-6);
        Ok(())
    }
}