Trait ApplyGradientExt

Source
/// Extension of [`ApplyGradient`] that adds momentum-based update methods.
///
/// Both methods take an explicit velocity buffer of type
/// [`Self::Velocity`] as `&mut`, so the momentum state is owned by the
/// caller rather than stored inside `self`.
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    /// Type of the velocity buffer that carries the momentum state passed to
    /// [`apply_gradient_with_momentum`](Self::apply_gradient_with_momentum) and
    /// [`apply_gradient_with_decay_and_momentum`](Self::apply_gradient_with_decay_and_momentum).
    type Velocity;

    // Required methods

    /// Applies the gradient `grad` to `self` using momentum.
    ///
    /// # Parameters
    /// - `grad`: the gradient to apply.
    /// - `lr`: the learning rate (step size).
    /// - `momentum`: the momentum coefficient.
    /// - `velocity`: mutable velocity buffer supplied by the caller;
    ///   implementations may update it in place so that the momentum state
    ///   carries over to the next call.
    ///
    /// # Errors
    /// Returns `Err` when the update cannot be performed. The exact failure
    /// conditions are defined by each implementation.
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output>;

    /// Applies the gradient `grad` to `self` using both decay and momentum.
    ///
    /// # Parameters
    /// - `grad`: the gradient to apply.
    /// - `lr`: the learning rate (step size).
    /// - `decay`: the decay coefficient. NOTE(review): presumably weight
    ///   decay; confirm against the implementations.
    /// - `momentum`: the momentum coefficient.
    /// - `velocity`: mutable velocity buffer supplied by the caller;
    ///   implementations may update it in place so that the momentum state
    ///   carries over to the next call.
    ///
    /// # Errors
    /// Returns `Err` when the update cannot be performed. The exact failure
    /// conditions are defined by each implementation.
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output>;
}
Expand description

This trait extends the `ApplyGradient` trait with methods for momentum-based optimization, optionally combined with decay.

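Required Associated Types

type Velocity

The type of the velocity buffer that holds the momentum state passed (as `&mut Self::Velocity`) to both required methods.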

Required Methods

Source

fn apply_gradient_with_momentum(&mut self, grad: &Delta, lr: T, momentum: T, velocity: &mut Self::Velocity) -> Result<Self::Output>

Source

fn apply_gradient_with_decay_and_momentum(&mut self, grad: &Delta, lr: T, decay: T, momentum: T, velocity: &mut Self::Velocity) -> Result<Self::Output>

Implementations on Foreign Types

Source§

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,

Source§

type Velocity = ArrayBase<OwnedRepr<A>, D>

Source§

fn apply_gradient_with_momentum(&mut self, grad: &ArrayBase<T, D>, lr: A, momentum: A, velocity: &mut Self::Velocity) -> Result<Self::Output>

Source§

fn apply_gradient_with_decay_and_momentum(&mut self, grad: &ArrayBase<T, D>, lr: A, decay: A, momentum: A, velocity: &mut Self::Velocity) -> Result<Self::Output>

Implementors

Source§

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,