pub trait Module<Input>: ResetParams + CanUpdateWithGradients {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
    fn forward_mut(&mut self, input: Input) -> Self::Output { ... }
}
A unit of a neural network. Acts on the generic Input and produces Module::Output.
- Immutable version: Module::forward()
- Mutable version: Module::forward_mut()
The generic Input parameter means you can implement Module for multiple input types on the same struct. For example, super::Linear implements Module for both 1d and 2d inputs.
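As a rough illustration of one struct serving several input types, the sketch below uses a simplified, hypothetical copy of the Module trait (without the ResetParams and CanUpdateWithGradients supertraits) and plain f32 arrays in place of dfdx tensors. The Scale layer is hypothetical and is not how super::Linear is actually written.

```rust
/// Simplified, hypothetical stand-in for the Module trait: no supertraits,
/// and plain arrays instead of tensors, so the sketch compiles on its own.
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Hypothetical layer that multiplies each of its 3 features by a fixed weight.
struct Scale {
    weights: [f32; 3],
}

/// 1d input: a single sample of 3 features.
impl Module<[f32; 3]> for Scale {
    type Output = [f32; 3];
    fn forward(&self, input: [f32; 3]) -> Self::Output {
        let mut out = input;
        for (o, w) in out.iter_mut().zip(self.weights.iter()) {
            *o *= *w;
        }
        out
    }
}

/// 2d input: a batch of B samples, handled by reusing the 1d impl row by row.
impl<const B: usize> Module<[[f32; 3]; B]> for Scale {
    type Output = [[f32; 3]; B];
    fn forward(&self, input: [[f32; 3]; B]) -> Self::Output {
        input.map(|row| <Self as Module<[f32; 3]>>::forward(self, row))
    }
}
```

With this shape, layer.forward(x) picks the 1d or 2d impl from the argument type alone, which mirrors how the Linear usage example calls forward on both Tensor1D and Tensor2D inputs.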
Additionally, a module can behave differently depending on whether it is called through the mutable forward (Module::forward_mut()) or the immutable forward (Module::forward()). For example, a Dropout layer, which uses an rng under the hood, does not modify the input tensor in Module::forward(), since it cannot mutate its underlying rng through &self.
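A minimal sketch of the Dropout behavior described above, again assuming a simplified, stand-alone Module trait over plain arrays. The Dropout struct and its xorshift64 rng are hypothetical stand-ins (not dfdx's Dropout), and the default forward_mut that delegates to forward is an assumption of the sketch. Because forward only has &self, it cannot advance rng_state and returns the input untouched; forward_mut has &mut self and can.

```rust
/// Simplified, hypothetical Module trait (no supertraits, plain arrays).
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
    /// Assumed default: fall back to the immutable forward.
    fn forward_mut(&mut self, input: Input) -> Self::Output {
        self.forward(input)
    }
}

/// Hypothetical dropout layer; rng_state must be seeded with a nonzero value.
struct Dropout {
    p: f32,
    rng_state: u64,
}

impl Dropout {
    /// Advance the xorshift64 state and return a value in [0, 1).
    fn next_uniform(&mut self) -> f32 {
        let mut x = self.rng_state;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.rng_state = x;
        // Keep the top 24 bits so the value is exactly representable as f32.
        (x >> 40) as f32 / (1u64 << 24) as f32
    }
}

impl<const N: usize> Module<[f32; N]> for Dropout {
    type Output = [f32; N];

    /// Immutable forward: rng_state is a plain field, so it cannot be
    /// advanced through &self; the input passes through unchanged.
    fn forward(&self, input: [f32; N]) -> Self::Output {
        input
    }

    /// Mutable forward: each element is zeroed with probability p and the
    /// survivors are scaled by 1 / (1 - p).
    fn forward_mut(&mut self, input: [f32; N]) -> Self::Output {
        let scale = 1.0 / (1.0 - self.p);
        let mut out = input;
        for x in out.iter_mut() {
            if self.next_uniform() < self.p {
                *x = 0.0;
            } else {
                *x *= scale;
            }
        }
        out
    }
}
```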
Required Associated Types
type Output
Required Methods
fn forward(&self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
This should never change self. See Module::forward_mut() for a version that can mutate self.
Example Usage
use dfdx::prelude::*;

let model: Linear<7, 2> = Default::default();
let y1: Tensor1D<2> = model.forward(Tensor1D::zeros());
let y2: Tensor2D<10, 2> = model.forward(Tensor2D::zeros());
Example Implementation
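use dfdx::prelude::*;

// A layer that multiplies its 5-element input element-wise by a fixed scale.
// Note: Module also requires ResetParams and CanUpdateWithGradients, so a
// complete layer needs impls of those traits as well (omitted here).
struct MyMulLayer {
    // NoneTape: no gradient tape is attached to the stored parameter.
    scale: Tensor1D<5, NoneTape>,
}

impl Module<Tensor1D<5>> for MyMulLayer {
    type Output = Tensor1D<5>;
    fn forward(&self, input: Tensor1D<5>) -> Self::Output {
        // Element-wise product of the input with the stored scale.
        mul(input, &self.scale)
    }
}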
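The example implementation covers only Module, but the trait declaration also requires ResetParams and CanUpdateWithGradients as supertraits. The stand-alone sketch below shows the shape of a complete layer under that requirement. The two supertraits here are hypothetical, simplified stand-ins with made-up signatures (reset_params without an rng, update taking a plain gradient slice); they are not the dfdx traits, and plain arrays replace tensors.

```rust
// Hypothetical, simplified stand-ins for the supertraits; the real dfdx
// ResetParams and CanUpdateWithGradients have different signatures.
trait ResetParams {
    fn reset_params(&mut self);
}

trait CanUpdateWithGradients {
    // Assumption: gradients arrive as a flat slice, already scaled by the learning rate.
    fn update(&mut self, grads: &[f32]);
}

// Simplified Module trait that keeps the supertrait bounds from the declaration.
trait Module<Input>: ResetParams + CanUpdateWithGradients {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

struct MyMulLayer {
    scale: [f32; 5],
}

impl ResetParams for MyMulLayer {
    // Reset every scale factor to 1.0, i.e. an identity multiply.
    fn reset_params(&mut self) {
        self.scale = [1.0; 5];
    }
}

impl CanUpdateWithGradients for MyMulLayer {
    // Plain gradient-descent step: subtract each gradient from its parameter.
    fn update(&mut self, grads: &[f32]) {
        for (s, g) in self.scale.iter_mut().zip(grads.iter()) {
            *s -= *g;
        }
    }
}

impl Module<[f32; 5]> for MyMulLayer {
    type Output = [f32; 5];
    // Element-wise product of the input and the stored scale.
    fn forward(&self, input: [f32; 5]) -> Self::Output {
        let mut out = input;
        for (o, s) in out.iter_mut().zip(self.scale.iter()) {
            *o *= *s;
        }
        out
    }
}
```

Without the two supertrait impls, the impl Module block would not compile, which is the point the sketch is meant to make.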
Provided Methods
fn forward_mut(&mut self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
This can change self. See Module::forward() for the immutable version.
Example Usage
use dfdx::prelude::*;

let mut model: Linear<7, 2> = Default::default();
let y1: Tensor1D<2> = model.forward_mut(Tensor1D::zeros());
let y2: Tensor2D<10, 2> = model.forward_mut(Tensor2D::zeros());
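One way to see why both methods exist: generic code can bound on Module and call forward for inference, where a shared borrow is enough, and forward_mut for training, where a mutable borrow lets layers update internal state. The sketch below assumes a simplified, hypothetical Module trait whose default forward_mut delegates to forward; predict, train_step, AddBias, and CountingIdentity are illustrative names only.

```rust
/// Simplified, hypothetical Module trait (no supertraits).
trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
    /// Assumed default: without an override, the mutable path matches the immutable one.
    fn forward_mut(&mut self, input: Input) -> Self::Output {
        self.forward(input)
    }
}

/// Hypothetical layer that adds a bias and relies on the default forward_mut.
struct AddBias {
    bias: f32,
}

impl Module<f32> for AddBias {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        x + self.bias
    }
}

/// Hypothetical layer that overrides forward_mut to count training calls.
struct CountingIdentity {
    calls: usize,
}

impl Module<f32> for CountingIdentity {
    type Output = f32;
    fn forward(&self, x: f32) -> f32 {
        x
    }
    fn forward_mut(&mut self, x: f32) -> f32 {
        self.calls += 1;
        x
    }
}

/// Inference path: only needs a shared borrow of the model.
fn predict<M: Module<f32, Output = f32>>(model: &M, x: f32) -> f32 {
    model.forward(x)
}

/// Training path: takes a mutable borrow so layers may update internal state.
fn train_step<M: Module<f32, Output = f32>>(model: &mut M, x: f32) -> f32 {
    model.forward_mut(x)
}
```

Here predict leaves CountingIdentity::calls untouched, while each train_step call increments it; AddBias gives the same result on both paths because it keeps the default forward_mut.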