Trait RevModule

pub trait RevModule<X>: Module<X> {
    type SelfGrads: Gradients;

    // Required methods
    fn reverse(
        &self,
        inputs: &X,
        grads_wrt_output: &<Self as Module<X>>::Output,
    ) -> (X, Self::SelfGrads);
    fn apply(
        &mut self,
        applyer: &mut impl GradApplyer,
        updates: Self::SelfGrads,
    ) -> Result<(), Error>;
}

A unit of computation that can perform backpropagation without knowledge of any additional state.

This trait is meant to be implemented by individual network layers, not by compositions of them.
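
As a minimal sketch of what implementing RevModule for a single layer can look like, the block below declares simplified stand-ins for Module, Gradients, GradApplyer and Error (their real definitions are not shown on this page; in particular, modeling GradApplyer as something that supplies one scale factor is an assumption) and implements the trait for a hypothetical Bias layer. Note that reverse works from its arguments alone, which is the sense in which no additional state is needed.

```rust
// Simplified stand-ins for the crate's traits (assumptions, not the real definitions).
trait Module<X> {
    type Output;
    fn forward(&self, input: &X) -> Self::Output;
}
trait Gradients {}
trait GradApplyer {
    // Assumed: the applyer supplies the scalar that gradients are multiplied by.
    fn scale(&self) -> f32;
}
#[derive(Debug)]
struct Error;

trait RevModule<X>: Module<X> {
    type SelfGrads: Gradients;
    fn reverse(&self, inputs: &X, grads_wrt_output: &<Self as Module<X>>::Output)
        -> (X, Self::SelfGrads);
    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: Self::SelfGrads)
        -> Result<(), Error>;
}

// Hypothetical single layer: adds a learned bias, y = x + b.
struct Bias<const N: usize> {
    b: [f32; N],
}

// Gradient of the loss with respect to the bias vector.
struct BiasGrads<const N: usize>([f32; N]);
impl<const N: usize> Gradients for BiasGrads<N> {}

impl<const N: usize> Module<[f32; N]> for Bias<N> {
    type Output = [f32; N];
    fn forward(&self, input: &[f32; N]) -> [f32; N] {
        std::array::from_fn(|i| input[i] + self.b[i])
    }
}

impl<const N: usize> RevModule<[f32; N]> for Bias<N> {
    type SelfGrads = BiasGrads<N>;

    // Everything needed arrives as arguments; no activations are cached in the layer.
    fn reverse(&self, _inputs: &[f32; N], grads: &[f32; N]) -> ([f32; N], BiasGrads<N>) {
        // dy/dx and dy/db are both the identity, so both gradients equal the output gradient.
        (*grads, BiasGrads(*grads))
    }

    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: BiasGrads<N>) -> Result<(), Error> {
        // parameter += scalar * gradient
        let s = applyer.scale();
        for (b, g) in self.b.iter_mut().zip(updates.0) {
            *b += s * g;
        }
        Ok(())
    }
}

// Hypothetical applyer with a fixed scalar (for descent, a negated learning rate).
struct FixedScale(f32);
impl GradApplyer for FixedScale {
    fn scale(&self) -> f32 {
        self.0
    }
}
```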

Required Associated Types

type SelfGrads: Gradients

The type describing gradients with respect to the module’s own parameters.

Required Methods

fn reverse(&self, inputs: &X, grads_wrt_output: &<Self as Module<X>>::Output) -> (X, Self::SelfGrads)

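Given the inputs of the forward pass and the gradient of the loss with respect to the module’s output, returns the gradient with respect to the input and the gradients with respect to the module’s own parameters.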
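
For a layer computing y = W x, with W of shape O × I, the chain rule gives both results that reverse returns: the input gradient Wᵀg and the parameter gradient g xᵀ, where g is grads_wrt_output. The sketch below implements this for a hypothetical Linear layer over fixed-size arrays, using simplified stand-ins for the crate's traits (assumptions); it is not the crate's Dense implementation, whose internals are not shown here.

```rust
// Simplified stand-ins for the crate's traits (assumptions, not the real definitions).
trait Module<X> {
    type Output;
    fn forward(&self, input: &X) -> Self::Output;
}
trait Gradients {}
trait GradApplyer {
    fn scale(&self) -> f32;
}
#[derive(Debug)]
struct Error;

trait RevModule<X>: Module<X> {
    type SelfGrads: Gradients;
    fn reverse(&self, inputs: &X, grads_wrt_output: &<Self as Module<X>>::Output)
        -> (X, Self::SelfGrads);
    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: Self::SelfGrads)
        -> Result<(), Error>;
}

// Hypothetical dense layer without bias: y = W x, W stored as O rows of I columns.
struct Linear<const I: usize, const O: usize> {
    w: [[f32; I]; O],
}
struct LinearGrads<const I: usize, const O: usize>([[f32; I]; O]);
impl<const I: usize, const O: usize> Gradients for LinearGrads<I, O> {}

impl<const I: usize, const O: usize> Module<[f32; I]> for Linear<I, O> {
    type Output = [f32; O];
    fn forward(&self, x: &[f32; I]) -> [f32; O] {
        std::array::from_fn(|o| (0..I).map(|i| self.w[o][i] * x[i]).sum::<f32>())
    }
}

impl<const I: usize, const O: usize> RevModule<[f32; I]> for Linear<I, O> {
    type SelfGrads = LinearGrads<I, O>;

    fn reverse(&self, x: &[f32; I], g: &[f32; O]) -> ([f32; I], LinearGrads<I, O>) {
        // Input gradient: dL/dx = W^T g.
        let dx: [f32; I] =
            std::array::from_fn(|i| (0..O).map(|o| self.w[o][i] * g[o]).sum::<f32>());
        // Parameter gradient: dL/dW = g x^T (outer product).
        let dw: [[f32; I]; O] = std::array::from_fn(|o| std::array::from_fn(|i| g[o] * x[i]));
        (dx, LinearGrads(dw))
    }

    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: LinearGrads<I, O>) -> Result<(), Error> {
        // parameter += scalar * gradient, entry by entry
        let s = applyer.scale();
        for (row, grad_row) in self.w.iter_mut().zip(updates.0) {
            for (w, g) in row.iter_mut().zip(grad_row) {
                *w += s * g;
            }
        }
        Ok(())
    }
}

// Hypothetical applyer with a fixed scalar.
struct FixedScale(f32);
impl GradApplyer for FixedScale {
    fn scale(&self) -> f32 {
        self.0
    }
}
```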

fn apply(&mut self, applyer: &mut impl GradApplyer, updates: Self::SelfGrads) -> Result<(), Error>

Applies a gradient update step through the given applyer: adds the product of the provided gradients and a scalar to the module’s parameters.

Dyn Compatibility

This trait is not dyn compatible.

In older versions of Rust, dyn compatibility was called "object safety", so this trait is not object safe.
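
One reason is visible in the signature: apply takes &mut impl GradApplyer, which makes it a generic method, and generic methods cannot be called through a trait object. Code that handles arbitrary layers can use generics (static dispatch) instead. A minimal sketch, with simplified stand-ins for the crate's traits and a hypothetical Offset layer and backward_and_update helper:

```rust
// Simplified stand-ins for the crate's traits (assumptions, not the real definitions).
trait Module<X> {
    type Output;
    fn forward(&self, input: &X) -> Self::Output;
}
trait Gradients {}
trait GradApplyer {
    fn scale(&self) -> f32;
}
#[derive(Debug)]
struct Error;

trait RevModule<X>: Module<X> {
    type SelfGrads: Gradients;
    fn reverse(&self, inputs: &X, grads_wrt_output: &<Self as Module<X>>::Output)
        -> (X, Self::SelfGrads);
    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: Self::SelfGrads)
        -> Result<(), Error>;
}

// Hypothetical helper: backward pass plus parameter update for any RevModule,
// resolved at compile time. A `&mut dyn RevModule<..>` parameter would be rejected
// because `apply` is generic over its applyer.
fn backward_and_update<X, M, A>(
    module: &mut M,
    input: &X,
    grad_out: &<M as Module<X>>::Output,
    applyer: &mut A,
) -> Result<X, Error>
where
    M: RevModule<X>,
    A: GradApplyer,
{
    let (grad_in, self_grads) = module.reverse(input, grad_out);
    module.apply(applyer, self_grads)?;
    Ok(grad_in)
}

// Hypothetical layer adding one learned scalar to every element: y_i = x_i + c.
struct Offset {
    c: f32,
}
struct OffsetGrads(f32);
impl Gradients for OffsetGrads {}

impl<const N: usize> Module<[f32; N]> for Offset {
    type Output = [f32; N];
    fn forward(&self, x: &[f32; N]) -> [f32; N] {
        std::array::from_fn(|i| x[i] + self.c)
    }
}

impl<const N: usize> RevModule<[f32; N]> for Offset {
    type SelfGrads = OffsetGrads;
    fn reverse(&self, _x: &[f32; N], g: &[f32; N]) -> ([f32; N], OffsetGrads) {
        // dL/dx = g and dL/dc = sum(g)
        (*g, OffsetGrads(g.iter().sum::<f32>()))
    }
    fn apply(&mut self, applyer: &mut impl GradApplyer, updates: OffsetGrads) -> Result<(), Error> {
        self.c += applyer.scale() * updates.0;
        Ok(())
    }
}

// Hypothetical applyer with a fixed scalar.
struct FixedScale(f32);
impl GradApplyer for FixedScale {
    fn scale(&self) -> f32 {
        self.0
    }
}
```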

Implementors

impl<E: Dtype + MatMulImpl, const I: usize, const O: usize> RevModule<[E; I]> for Dense<E, I, O>

impl<E: Dtype + MatMulImpl, const I: usize, const O: usize, C: Conv1dKernel<E, Const<I>, Const<O>>> RevModule<[E; I]> for Conv1d<E, I, O, C>

impl<E: Dtype, const I: usize> RevModule<[E; I]> for Bias1d<E, I>

impl<E: Dtype, const I: usize> RevModule<[E; I]> for Diag<E, I>

impl<E: Dtype, const I: usize> RevModule<[E; I]> for ScalarScale<E>

impl<E: Float + MatMulImpl, const I: usize> RevModule<[E; I]> for RMSDiv<E, I>

impl<E: Float, const I: usize> RevModule<[E; I]> for Activation<E>

impl<E: Float, const I: usize> RevModule<[E; I]> for Softmax

impl<E: Float, const I: usize> RevModule<[E; I]> for Swish<E, I>
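
For the simpler elementwise implementors, both outputs of reverse follow directly from the chain rule. The rules below are a sketch that assumes Bias1d computes y_i = x_i + b_i, Diag computes y_i = d_i x_i, and ScalarScale computes y_i = s x_i; the listing gives only type signatures, so these forward definitions are inferred from the names. Here g_i = ∂L/∂y_i is the incoming grads_wrt_output, and the first column is the input gradient while the second is the parameter gradient.

```latex
\begin{aligned}
\text{Bias1d:}\quad & \frac{\partial L}{\partial x_i} = g_i, \qquad & \frac{\partial L}{\partial b_i} &= g_i \\
\text{Diag:}\quad & \frac{\partial L}{\partial x_i} = d_i\, g_i, \qquad & \frac{\partial L}{\partial d_i} &= x_i\, g_i \\
\text{ScalarScale:}\quad & \frac{\partial L}{\partial x_i} = s\, g_i, \qquad & \frac{\partial L}{\partial s} &= \sum_{j=1}^{I} x_j\, g_j
\end{aligned}
```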