/// Momentum-aware extension of [`ApplyGradient`].
///
/// Implementors expose gradient updates that, in addition to the gradient and
/// learning rate, take a momentum coefficient and a caller-supplied, mutable
/// velocity buffer of type `Self::Velocity`.
pub trait ApplyGradientExt<Delta>
where
    Self: ApplyGradient<Delta>,
{
    /// State passed (mutably) into every momentum-based update.
    ///
    /// NOTE(review): presumably the running velocity accumulator; its exact
    /// representation is chosen by the implementor — confirm per impl.
    type Velocity;

    /// Applies the gradient `grad` with learning rate `lr` and momentum
    /// coefficient `momentum`, using the caller-held `velocity` buffer.
    ///
    /// # Errors
    /// Returns `Err` when the implementation cannot apply the update; the
    /// specific failure conditions are implementation-defined.
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output>;

    /// Like [`ApplyGradientExt::apply_gradient_with_momentum`], but also
    /// takes a `decay` coefficient.
    ///
    /// NOTE(review): `decay` is presumably a weight-decay factor; how it is
    /// combined with the update is up to the implementor — confirm per impl.
    ///
    /// # Errors
    /// Returns `Err` when the implementation cannot apply the update; the
    /// specific failure conditions are implementation-defined.
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: Self::Elem,
        decay: Self::Elem,
        momentum: Self::Elem,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output>;
}
This trait extends the `ApplyGradient` trait with momentum-based optimization: each update also takes a momentum coefficient and a caller-supplied velocity buffer.