ApplyGradientExt

Trait ApplyGradientExt 

Source
pub trait ApplyGradientExt<Delta, T>: ApplyGradient<Delta, T> {
    type Velocity;

    // Required methods
    fn apply_gradient_with_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output, Error>;
    fn apply_gradient_with_decay_and_momentum(
        &mut self,
        grad: &Delta,
        lr: T,
        decay: T,
        momentum: T,
        velocity: &mut Self::Velocity,
    ) -> Result<Self::Output, Error>;
}
Expand description

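This trait extends the `ApplyGradient` trait with momentum-based optimization: each update additionally receives a mutable, caller-owned velocity state (`Self::Velocity`) that persists across successive updates.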

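Required Associated Types§

type Velocity

The state passed by mutable reference (`velocity: &mut Self::Velocity`) to each momentum-based update; each implementor chooses its concrete type.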

Required Methods§

Source

fn apply_gradient_with_momentum( &mut self, grad: &Delta, lr: T, momentum: T, velocity: &mut Self::Velocity, ) -> Result<Self::Output, Error>

Source

fn apply_gradient_with_decay_and_momentum( &mut self, grad: &Delta, lr: T, decay: T, momentum: T, velocity: &mut Self::Velocity, ) -> Result<Self::Output, Error>

Implementations on Foreign Types§

Source§

impl<A, S, T, D> ApplyGradientExt<ArrayBase<T, D>, A> for ArrayBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,

Source§

type Velocity = ArrayBase<OwnedRepr<A>, D>

Source§

fn apply_gradient_with_momentum( &mut self, grad: &ArrayBase<T, D>, lr: A, momentum: A, velocity: &mut <ArrayBase<S, D> as ApplyGradientExt<ArrayBase<T, D>, A>>::Velocity, ) -> Result<<ArrayBase<S, D> as ApplyGradient<ArrayBase<T, D>, A>>::Output, Error>

Source§

fn apply_gradient_with_decay_and_momentum( &mut self, grad: &ArrayBase<T, D>, lr: A, decay: A, momentum: A, velocity: &mut <ArrayBase<S, D> as ApplyGradientExt<ArrayBase<T, D>, A>>::Velocity, ) -> Result<<ArrayBase<S, D> as ApplyGradient<ArrayBase<T, D>, A>>::Output, Error>

Implementors§

Source§

impl<A, S, T, D> ApplyGradientExt<ParamsBase<T, D>, A> for ParamsBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,