Trait dfdx::nn::ModelEMA

pub trait ModelEMA<E: Dtype, D: Device<E>>: TensorCollection<E, D> {
    // Provided methods
    fn ema(&mut self, other: &Self, decay: impl Into<f64>) { ... }
    fn try_ema(
        &mut self,
        other: &Self,
        decay: impl Into<f64>
    ) -> Result<(), D::Err> { ... }
}

Performs an exponential moving average (EMA) update of one module's parameters toward another module's parameters.

Only trainable parameters are updated; non-trainable state, such as batch normalization's running statistics, is left unchanged.

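```rust
use dfdx::prelude::*;
let dev: Cpu = Default::default();
type Model = Linear<2, 5>;
let model = dev.build_module::<Model, f32>();
let mut model_ema = model.clone();
// decay = 0.001: keep 0.1% of `model_ema`, take 99.9% from `model`.
model_ema.ema(&model, 0.001);
```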
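To illustrate the trainable-only rule, here is a minimal plain-Rust sketch with no dfdx dependency. `ToyBatchNorm`, its fields, `blend`, and `ema_trainable` are hypothetical names chosen for illustration: `scale` and `bias` stand in for trainable parameters, and `running_mean` stands in for a non-trainable running statistic that the update deliberately skips.

```rust
/// Hypothetical stand-in for a batch-norm layer (not dfdx's type).
#[derive(Clone, Debug)]
struct ToyBatchNorm {
    scale: Vec<f32>,        // trainable
    bias: Vec<f32>,         // trainable
    running_mean: Vec<f32>, // non-trainable running statistic
}

/// Element-wise blend: dst = dst * decay + other * (1 - decay).
fn blend(dst: &mut [f32], other: &[f32], decay: f32) {
    for (d, o) in dst.iter_mut().zip(other) {
        *d = *d * decay + *o * (1.0 - decay);
    }
}

impl ToyBatchNorm {
    /// Hypothetical EMA update that, like `ModelEMA::ema`,
    /// touches only the trainable fields.
    fn ema_trainable(&mut self, other: &Self, decay: f32) {
        blend(&mut self.scale, &other.scale, decay);
        blend(&mut self.bias, &other.bias, decay);
        // `running_mean` is intentionally left unchanged.
    }
}

fn main() {
    let mut ema = ToyBatchNorm { scale: vec![1.0; 2], bias: vec![0.0; 2], running_mean: vec![0.0; 2] };
    let model = ToyBatchNorm { scale: vec![3.0; 2], bias: vec![2.0; 2], running_mean: vec![5.0; 2] };
    ema.ema_trainable(&model, 0.5);
    // scale -> [2.0, 2.0], bias -> [1.0, 1.0], running_mean stays [0.0, 0.0]
    println!("{:?}", ema);
}
```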

Provided Methods

fn ema(&mut self, other: &Self, decay: impl Into<f64>)

Computes `self = self * decay + other * (1 - decay)` for each trainable parameter, using `crate::tensor_ops::axpy()`.
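The per-parameter arithmetic can be sketched in plain Rust as follows. `axpy_in_place` and `ema_step` are hypothetical helpers over flat `f64` slices, written only to show the formula; they do not reproduce the signature of `crate::tensor_ops::axpy()`, which operates on dfdx tensors.

```rust
/// Hypothetical helper: a = a * alpha + b * beta, element-wise.
fn axpy_in_place(a: &mut [f64], alpha: f64, b: &[f64], beta: f64) {
    assert_eq!(a.len(), b.len(), "parameter shapes must match");
    for (x, y) in a.iter_mut().zip(b) {
        *x = *x * alpha + *y * beta;
    }
}

/// One EMA step on a single flattened parameter:
/// alpha = decay weights the EMA copy, beta = 1 - decay weights the live model.
fn ema_step(ema_param: &mut [f64], param: &[f64], decay: f64) {
    axpy_in_place(ema_param, decay, param, 1.0 - decay);
}

fn main() {
    let mut ema_w = vec![0.0_f64, 4.0];
    let w = vec![8.0_f64, 0.0];
    ema_step(&mut ema_w, &w, 0.75);
    println!("{:?}", ema_w); // [2.0, 3.0]
}
```

Because `decay` multiplies `self`, repeating the step k times against a fixed `other` leaves the original `self` weighted by decay^k and `other` by 1 - decay^k. A decay close to 1 (such as 0.999) therefore averages slowly, while a decay close to 0 makes `self` track `other` almost immediately.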

Only trainable parameters are updated; non-trainable state, such as batch normalization's running statistics, is left unchanged.

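fn try_ema(&mut self, other: &Self, decay: impl Into<f64>) -> Result<(), D::Err>

Fallible version of `ema`: performs the same update, but returns `Err(D::Err)` if a device operation fails instead of panicking.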
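The relationship between the two methods can be sketched in plain Rust, assuming (as is the usual dfdx convention) that `ema` simply unwraps the result of `try_ema`. `Model` and `ShapeMismatch` are hypothetical names; the length check is an illustrative stand-in for whatever device error `D::Err` would report.

```rust
/// Hypothetical error type standing in for a device error (`D::Err` in dfdx).
#[derive(Debug, PartialEq)]
struct ShapeMismatch;

/// Hypothetical minimal model: a single flat parameter vector.
struct Model {
    weight: Vec<f64>,
}

impl Model {
    /// Fallible update: reports an error instead of panicking.
    fn try_ema(&mut self, other: &Self, decay: impl Into<f64>) -> Result<(), ShapeMismatch> {
        if self.weight.len() != other.weight.len() {
            return Err(ShapeMismatch); // nothing is modified on error
        }
        let decay: f64 = decay.into();
        for (s, o) in self.weight.iter_mut().zip(&other.weight) {
            *s = *s * decay + *o * (1.0 - decay);
        }
        Ok(())
    }

    /// Infallible convenience wrapper: panics if `try_ema` fails.
    fn ema(&mut self, other: &Self, decay: impl Into<f64>) {
        self.try_ema(other, decay).unwrap();
    }
}

fn main() {
    let mut ema_model = Model { weight: vec![1.0, 1.0] };
    let model = Model { weight: vec![3.0, 5.0] };
    ema_model.ema(&model, 0.5);
    println!("{:?}", ema_model.weight); // [2.0, 3.0]
    let mismatched = Model { weight: vec![0.0] };
    println!("{:?}", ema_model.try_ema(&mismatched, 0.5)); // Err(ShapeMismatch)
}
```

Use the `try_` variant when the caller wants to propagate failures with `?`; use the plain variant in scripts and examples where a panic is acceptable.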

Implementors

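impl<E: Dtype, D: Device<E>, M: TensorCollection<E, D>> ModelEMA<E, D> for M

Blanket implementation: every type that implements `TensorCollection<E, D>` gets `ema` and `try_ema` automatically.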
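A plain-Rust sketch of this blanket-implementation pattern follows. `VisitParams` is a hypothetical, much simplified stand-in for `TensorCollection` that only hands out matching pairs of trainable parameter slices; `EmaExt`, `ToyLinear`, and the tuple impl are likewise illustrative names, not dfdx APIs.

```rust
/// Hypothetical, simplified stand-in for `TensorCollection`:
/// visits matching pairs of trainable parameters in `self` and `other`.
trait VisitParams {
    fn visit_pairs(&mut self, other: &Self, f: &mut dyn FnMut(&mut [f32], &[f32]));
}

/// Stand-in for `ModelEMA`: only provided methods, no required ones.
trait EmaExt: VisitParams {
    fn ema(&mut self, other: &Self, decay: f32) {
        self.visit_pairs(other, &mut |dst: &mut [f32], src: &[f32]| {
            for (d, s) in dst.iter_mut().zip(src) {
                *d = *d * decay + *s * (1.0 - decay);
            }
        });
    }
}

/// Blanket impl: every `VisitParams` type gets `ema` automatically.
impl<M: VisitParams> EmaExt for M {}

/// Hypothetical linear layer with two trainable parameters.
struct ToyLinear {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl VisitParams for ToyLinear {
    fn visit_pairs(&mut self, other: &Self, f: &mut dyn FnMut(&mut [f32], &[f32])) {
        f(&mut self.weight[..], &other.weight[..]);
        f(&mut self.bias[..], &other.bias[..]);
    }
}

/// Composite model: a pair of sub-modules is visited field by field.
impl<A: VisitParams, B: VisitParams> VisitParams for (A, B) {
    fn visit_pairs(&mut self, other: &Self, f: &mut dyn FnMut(&mut [f32], &[f32])) {
        self.0.visit_pairs(&other.0, f);
        self.1.visit_pairs(&other.1, f);
    }
}

fn main() {
    let mut ema_model = (
        ToyLinear { weight: vec![0.0], bias: vec![0.0] },
        ToyLinear { weight: vec![2.0], bias: vec![4.0] },
    );
    let model = (
        ToyLinear { weight: vec![4.0], bias: vec![4.0] },
        ToyLinear { weight: vec![6.0], bias: vec![0.0] },
    );
    // `ema` is available on the tuple only because of the blanket impl.
    ema_model.ema(&model, 0.5);
    println!("{:?} {:?}", ema_model.0.weight, ema_model.1.weight); // [2.0] [4.0]
}
```

The design choice is that models never implement the EMA trait themselves: they only describe how to walk their parameters, and the update logic lives in one provided method.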