Trait ApplyGradient

Source
/// Basic gradient-update routines for the parameters of a neural network.
///
/// `Delta` is the type of the gradient that gets applied to `Self`, and `T` is the
/// scalar type used for the learning rate and the decay coefficient.
pub trait ApplyGradient<Delta, T> {
    /// The value handed back by a successful update.
    type Output;

    /// Updates `self` in place from the gradient `grad` using the learning rate `lr`.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot apply `grad` to `self`; the
    /// exact failure conditions are implementation-defined.
    fn apply_gradient(&mut self, grad: &Delta, lr: T) -> Result<Self::Output>;

    /// Same as [`apply_gradient`](Self::apply_gradient), but additionally receives a
    /// `decay` coefficient.
    ///
    /// NOTE(review): `decay` is presumably a weight-decay factor; how it is combined
    /// with `grad` and `lr` is left to each implementation.
    ///
    /// # Errors
    ///
    /// Returns an error when the implementation cannot apply `grad` to `self`; the
    /// exact failure conditions are implementation-defined.
    fn apply_gradient_with_decay(&mut self, grad: &Delta, lr: T, decay: T) -> Result<Self::Output>;
}
Expand description

A trait declaring the basic gradient-update routines for the parameters of a neural network.

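Required Associated Types

type Output

The value returned on success by `apply_gradient` and `apply_gradient_with_decay`.

Required Methods

Source

fn apply_gradient(&mut self, grad: &Delta, lr: T) -> Result<Self::Output>

Updates `self` in place using the gradient `grad` and the learning rate `lr`.

Source

fn apply_gradient_with_decay(&mut self, grad: &Delta, lr: T, decay: T) -> Result<Self::Output>

Like `apply_gradient`, but also takes a `decay` coefficient; how `decay` is applied is up to the implementation.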

Implementations on Foreign Types

Source

impl<A, S, T, D> ApplyGradient<ArrayBase<T, D>, A> for ArrayBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,

Source

type Output = ()

Source

fn apply_gradient(&mut self, grad: &ArrayBase<T, D>, lr: A) -> Result<Self::Output>

Source

fn apply_gradient_with_decay(&mut self, grad: &ArrayBase<T, D>, lr: A, decay: A) -> Result<Self::Output>

Implementors

Source

impl<A, S, T, D> ApplyGradient<ParamsBase<T, D>, A> for ParamsBase<S, D>
where A: Float + FromPrimitive + ScalarOperand, S: DataMut<Elem = A>, T: Data<Elem = A>, D: Dimension,