use super::tensor_collection::*;
use crate::{shapes::*, tensor::*, tensor_ops::Device};
/// Visitor state used by [`ModelEMA::try_ema`] to blend the tensors of one
/// model into another.
struct ModelEMAOp {
/// Coefficient passed to `try_axpy` for the destination tensor; the source
/// tensor is passed `1.0 - decay`.
decay: f64,
}
/// Visits pairs of matching tensors (a mutable destination and a read-only
/// source) and folds the source into the destination in place.
impl<E: Dtype, D: Device<E>> TensorVisitor<E, D> for ModelEMAOp {
    type Viewer = (ViewTensorMut, ViewTensorRef);
    type Err = D::Err;
    type E2 = E;
    type D2 = D;

    /// Updates `averaged` in place via
    /// `averaged.try_axpy(decay, latest, 1.0 - decay)`.
    ///
    /// Tensors whose options have `do_gradient_update` unset are skipped and
    /// keep their current contents. Always yields `Ok(None)` on success, so no
    /// new tensor is produced by this visitor.
    ///
    /// # Errors
    ///
    /// Propagates any device error returned by `try_axpy`.
    fn visit<S: Shape>(
        &mut self,
        opts: TensorOptions<S, E, D>,
        (averaged, latest): (&mut Tensor<S, E, D>, &Tensor<S, E, D>),
    ) -> Result<Option<Tensor<S, E, D>>, Self::Err> {
        // Guard: only tensors flagged for gradient updates take part.
        if !opts.do_gradient_update {
            return Ok(None);
        }

        let keep = self.decay;
        let blend = 1.0 - keep;
        averaged.try_axpy(keep, latest, blend)?;
        Ok(None)
    }
}
/// Exponential-moving-average update for any [`TensorCollection`].
///
/// Every tensor of `self` that is flagged with `do_gradient_update` is updated
/// in place from the matching tensor of `other` through
/// `try_axpy(decay, other_tensor, 1.0 - decay)`; all remaining tensors are left
/// as they are.
pub trait ModelEMA<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
    /// Infallible wrapper around [`ModelEMA::try_ema`].
    ///
    /// # Panics
    ///
    /// Panics if [`ModelEMA::try_ema`] returns an error.
    fn ema(&mut self, other: &Self, decay: impl Into<f64>) {
        let outcome = self.try_ema(other, decay);
        outcome.unwrap();
    }

    /// Blends `other` into `self`, weighting `self` by `decay` and `other` by
    /// `1.0 - decay`.
    ///
    /// # Errors
    ///
    /// Returns the device error raised while visiting the tensors, if any.
    fn try_ema(&mut self, other: &Self, decay: impl Into<f64>) -> Result<(), D::Err> {
        let mut op = ModelEMAOp {
            decay: decay.into(),
        };
        // Walk both models together: `self` is viewed mutably, `other` by reference.
        let mut walker = RecursiveWalker {
            m: (self, other),
            f: &mut op,
        };
        Self::iter_tensors(&mut walker)?;
        Ok(())
    }
}
// Blanket implementation: every `TensorCollection` gets `ema` / `try_ema`
// through the trait's default methods.
impl<E: Dtype, D: Device<E>, M: TensorCollection<E, D>> ModelEMA<E, D> for M {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{nn::builders::*, tensor_ops::axpy, tests::*};

    /// Checks that `ema` leaves the batch-norm running statistics untouched and
    /// that every other parameter equals `axpy(old, decay, model, 1.0 - decay)`.
    #[test]
    fn test_model_ema() {
        type Model = (Linear<3, 5>, (Linear<5, 10>, BatchNorm2D<3>));

        let dev: TestDevice = Default::default();
        let distr = rand_distr::Standard;
        let decay = 0.5;

        let model = dev.build_module::<Model, TestDtype>();
        let mut ema1 = dev.build_module::<Model, TestDtype>();
        // Randomise the running statistics so an accidental update would be visible.
        ema1.1 .1.running_mean.fill_with_distr(distr);
        ema1.1 .1.running_var.fill_with_distr(distr);
        let ema0 = ema1.clone();

        ema1.ema(&model, decay);

        let (src_lin_a, (src_lin_b, src_bn)) = &model;
        let (old_lin_a, (old_lin_b, old_bn)) = &ema0;
        let (new_lin_a, (new_lin_b, new_bn)) = &ema1;

        // Running statistics must be unchanged by the update.
        assert_eq!(new_bn.running_mean.array(), old_bn.running_mean.array());
        assert_eq!(new_bn.running_var.array(), old_bn.running_var.array());

        // First linear layer.
        assert_eq!(
            axpy(&old_lin_a.weight, decay, &src_lin_a.weight, 1.0 - decay).array(),
            new_lin_a.weight.array()
        );
        assert_eq!(
            axpy(&old_lin_a.bias, decay, &src_lin_a.bias, 1.0 - decay).array(),
            new_lin_a.bias.array()
        );

        // Second linear layer.
        assert_eq!(
            axpy(&old_lin_b.weight, decay, &src_lin_b.weight, 1.0 - decay).array(),
            new_lin_b.weight.array()
        );
        assert_eq!(
            axpy(&old_lin_b.bias, decay, &src_lin_b.bias, 1.0 - decay).array(),
            new_lin_b.bias.array()
        );

        // Batch-norm scale and bias.
        assert_eq!(
            axpy(&old_bn.scale, decay, &src_bn.scale, 1.0 - decay).array(),
            new_bn.scale.array()
        );
        assert_eq!(
            axpy(&old_bn.bias, decay, &src_bn.bias, 1.0 - decay).array(),
            new_bn.bias.array()
        );
    }
}