pub trait ModelEMA<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
// Provided methods
fn ema(&mut self, other: &Self, decay: impl Into<f64>) { ... }
fn try_ema(
&mut self,
other: &Self,
decay: impl Into<f64>
) -> Result<(), D::Err> { ... }
}
Expand description
Performs model exponential moving average on two modules.
Only updates trainable parameters. For example, batch normalization running parameters are not updated.
type Model = Linear<2, 5>;
let model = dev.build_module::<Model, f32>();
let mut model_ema = model.clone();
model_ema.ema(&model, 0.001);
Provided Methods
fn ema(&mut self, other: &Self, decay: impl Into<f64>)
Does `self = self * decay + other * (1 - decay)`, using `crate::tensor_ops::axpy()` on parameters.
Only updates trainable parameters. For example, batch normalization running parameters are not updated.