pub trait Module<Input>: ResetParams + CanUpdateWithGradients {
    type Output;

    fn forward(&self, input: Input) -> Self::Output;

    fn forward_mut(&mut self, input: Input) -> Self::Output { ... }
}

A unit of a neural network. Acts on the generic Input and produces Module::Output.

A generic Input means you can implement Module for multiple input types on the same struct. For example, super::Linear implements Module for both 1d and 2d inputs.
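Implementing one trait for several input types on the same struct can be sketched without dfdx. In this minimal, std-only version, Module is re-declared without its supertraits, plain arrays stand in for tensors, and Scale is a hypothetical layer. The batched impl reuses the single-input impl row by row.

```rust
// Hypothetical, std-only stand-in for Module (supertraits omitted),
// used only to show one struct implementing it for several input types.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

// Hypothetical layer that multiplies every element by a fixed factor.
struct Scale {
    factor: f32,
}

// Single 1d input: [f32; 3] -> [f32; 3].
impl Module<[f32; 3]> for Scale {
    type Output = [f32; 3];
    fn forward(&self, input: [f32; 3]) -> Self::Output {
        input.map(|x| x * self.factor)
    }
}

// Batched 2d input: apply the 1d impl to each of the B rows.
impl<const B: usize> Module<[[f32; 3]; B]> for Scale {
    type Output = [[f32; 3]; B];
    fn forward(&self, input: [[f32; 3]; B]) -> Self::Output {
        input.map(|row| self.forward(row))
    }
}

fn main() {
    let layer = Scale { factor: 2.0 };
    let y1: [f32; 3] = layer.forward([1.0, 2.0, 3.0]);
    let y2: [[f32; 3]; 2] = layer.forward([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]]);
    println!("{:?} {:?}", y1, y2);
}
```

The impl that runs is selected by the argument's type, so the caller writes the same forward call for both shapes.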

Additionally, modules can specify different behavior depending on whether they are called through the mutable forward (Module::forward_mut()) or the non-mutable forward (Module::forward()). For example, a Dropout layer, which uses an rng internally, cannot mutate that rng inside Module::forward(), so it may leave the input tensor unmodified there.
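The split between forward and forward_mut can be sketched without dfdx. In this hypothetical, std-only version, Module is re-declared without supertraits, the provided forward_mut is assumed to simply delegate to forward, and Dropout stores a small xorshift32 generator as a stand-in for the rng a real Dropout would use. forward takes &self, so it cannot advance the generator and returns the input unchanged; forward_mut takes &mut self, so it can draw random numbers and actually apply dropout.

```rust
// Hypothetical, std-only stand-in for Module with a provided forward_mut.
// Assumption: the default forward_mut just delegates to forward.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
    fn forward_mut(&mut self, input: Input) -> Self::Output {
        self.forward(input)
    }
}

// Hypothetical dropout layer whose only mutable state is its RNG.
struct Dropout {
    p: f32,         // probability of zeroing an element
    rng_state: u32, // xorshift32 state, must be nonzero
}

impl Dropout {
    // Advance the xorshift32 state and return a value in [0, 1).
    fn next_uniform(&mut self) -> f32 {
        let mut x = self.rng_state;
        x ^= x << 13;
        x ^= x >> 17;
        x ^= x << 5;
        self.rng_state = x;
        (x >> 8) as f32 / (1u32 << 24) as f32
    }
}

impl<const N: usize> Module<[f32; N]> for Dropout {
    type Output = [f32; N];

    // &self: the RNG cannot advance, so the input passes through unchanged.
    fn forward(&self, input: [f32; N]) -> Self::Output {
        input
    }

    // &mut self: the RNG can advance, so elements are zeroed with
    // probability p and survivors are rescaled by 1 / (1 - p).
    fn forward_mut(&mut self, input: [f32; N]) -> Self::Output {
        let scale = 1.0 / (1.0 - self.p);
        let mut out = input;
        for x in out.iter_mut() {
            if self.next_uniform() < self.p {
                *x = 0.0;
            } else {
                *x *= scale;
            }
        }
        out
    }
}

fn main() {
    let mut dropout = Dropout { p: 0.5, rng_state: 12345 };
    let x = [1.0f32; 8];
    println!("forward:     {:?}", dropout.forward(x));
    println!("forward_mut: {:?}", dropout.forward_mut(x));
}
```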

Required Associated Types

type Output

The type that this unit produces given Input.

Required Methods

fn forward(&self, input: Input) -> Self::Output

Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.

This should never change self. See Module::forward_mut() for a version that can mutate self.

Example Usage
use dfdx::prelude::*;

let model: Linear<7, 2> = Default::default();
let y1: Tensor1D<2> = model.forward(Tensor1D::zeros());
let y2: Tensor2D<10, 2> = model.forward(Tensor2D::zeros());
Example Implementation
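use dfdx::prelude::*;

// Multiplies its input element-wise by a stored scale vector.
// Because of the supertrait bounds on Module, MyMulLayer must also
// implement ResetParams and CanUpdateWithGradients (not shown here).
struct MyMulLayer {
    scale: Tensor1D<5, NoneTape>,
}

impl Module<Tensor1D<5>> for MyMulLayer {
    type Output = Tensor1D<5>;
    fn forward(&self, input: Tensor1D<5>) -> Self::Output {
        mul(input, &self.scale)
    }
}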
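A type implementing Module must also satisfy its supertraits, ResetParams and CanUpdateWithGradients. The std-only sketch below mirrors MyMulLayer to show the shape of that requirement. Both supertraits here are hypothetical, simplified stand-ins (one reset method and one plain SGD step); the real dfdx traits have different signatures, and arrays replace tensors.

```rust
// Hypothetical, simplified stand-ins for the supertraits; the real
// ResetParams and CanUpdateWithGradients have different signatures.
trait ResetParams {
    fn reset_params(&mut self);
}

trait CanUpdateWithGradients {
    // Hypothetical: apply one plain SGD step with the given gradient.
    fn update(&mut self, grad: &[f32], lr: f32);
}

// Module keeps the supertrait bounds from the signature above.
trait Module<Input>: ResetParams + CanUpdateWithGradients {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

// Element-wise multiplication by a learnable scale vector.
struct MyMulLayer {
    scale: [f32; 5],
}

impl ResetParams for MyMulLayer {
    fn reset_params(&mut self) {
        self.scale = [1.0; 5]; // identity scaling
    }
}

impl CanUpdateWithGradients for MyMulLayer {
    fn update(&mut self, grad: &[f32], lr: f32) {
        for (s, g) in self.scale.iter_mut().zip(grad) {
            *s -= lr * g;
        }
    }
}

// Without the two impls above, this impl would fail to compile.
impl Module<[f32; 5]> for MyMulLayer {
    type Output = [f32; 5];
    fn forward(&self, input: [f32; 5]) -> Self::Output {
        let mut out = input;
        for (o, s) in out.iter_mut().zip(self.scale.iter()) {
            *o *= s;
        }
        out
    }
}

fn main() {
    let mut layer = MyMulLayer { scale: [0.0; 5] };
    layer.reset_params();
    println!("{:?}", layer.forward([1.0, 2.0, 3.0, 4.0, 5.0]));
}
```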

Provided Methods

fn forward_mut(&mut self, input: Input) -> Self::Output

Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.

This can change self. See Module::forward() for the immutable version.

Example Usage
use dfdx::prelude::*;

let mut model: Linear<7, 2> = Default::default();
let y1: Tensor1D<2> = model.forward_mut(Tensor1D::zeros());
let y2: Tensor2D<10, 2> = model.forward_mut(Tensor2D::zeros());

Implementations on Foreign Types

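Module is implemented for tuples of modules, with one impl per supported tuple length. Each calls forward sequentially on each module in the tuple, passing each module's output as the input to the next.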
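Sequential composition through tuples can be sketched in std-only Rust. This is not dfdx's implementation: Module is re-declared without supertraits, only the 2- and 3-element tuples are written out by hand, and AddOne, Double and Relu are hypothetical toy layers on f32. The where clauses carry the idea: each element must accept the previous element's Output, and the tuple's Output is the last element's Output.

```rust
// Hypothetical, std-only stand-in for Module (supertraits omitted).
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

// Two-element tuple: run A, then feed its output into B.
impl<Input, A, B> Module<Input> for (A, B)
where
    A: Module<Input>,
    B: Module<A::Output>,
{
    type Output = B::Output;
    fn forward(&self, input: Input) -> Self::Output {
        let x = self.0.forward(input);
        self.1.forward(x)
    }
}

// Three-element tuple: the same chain, one step longer.
impl<Input, A, B, C> Module<Input> for (A, B, C)
where
    A: Module<Input>,
    B: Module<A::Output>,
    C: Module<B::Output>,
{
    type Output = C::Output;
    fn forward(&self, input: Input) -> Self::Output {
        let x = self.0.forward(input);
        let x = self.1.forward(x);
        self.2.forward(x)
    }
}

// Hypothetical toy layers on scalars.
struct AddOne;
struct Double;
struct Relu;

impl Module<f32> for AddOne {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        x + 1.0
    }
}

impl Module<f32> for Double {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        x * 2.0
    }
}

impl Module<f32> for Relu {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        x.max(0.0)
    }
}

fn main() {
    // (x + 1) * 2, then clamp negatives to zero.
    let model = (AddOne, Double, Relu);
    println!("{}", model.forward(1.0f32));
    println!("{}", model.forward(-3.0f32));
}
```

Writing one impl per tuple length by hand gets repetitive, so a library would usually generate them with a macro.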

Implementors