pub trait Module<Input>: ResetParams + CanUpdateWithGradients {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}
A unit of a neural network. Acts on the generic Input and produces Module::Output.
The generic Input means you can implement Module for multiple input types on the same struct. For example, super::Linear implements Module for both 1D and 2D inputs.
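As a dependency-free illustration of this idea (not dfdx's actual Linear), the sketch below defines a simplified, hypothetical copy of the trait without its supertraits and a hypothetical ToyLinear layer over plain f32 arrays. One struct gets two impls: a single sample [f32; I] maps to [f32; O], and a batch [[f32; I]; B] maps to [[f32; O]; B].

```rust
// Hypothetical, simplified stand-in for the trait above: the ResetParams and
// CanUpdateWithGradients supertraits are omitted so the sketch has no dependencies.
pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Hypothetical fully connected layer mapping I input features to O outputs,
/// using plain arrays instead of dfdx tensors.
pub struct ToyLinear<const I: usize, const O: usize> {
    weight: [[f32; I]; O], // one row of I weights per output feature
    bias: [f32; O],
}

impl<const I: usize, const O: usize> ToyLinear<I, O> {
    // Computes y = W x + b for a single sample.
    fn apply(&self, x: &[f32; I]) -> [f32; O] {
        let mut y = self.bias;
        for (o, row) in self.weight.iter().enumerate() {
            for (w, xi) in row.iter().zip(x.iter()) {
                y[o] += *w * *xi;
            }
        }
        y
    }
}

// 1D input: a single sample of I features.
impl<const I: usize, const O: usize> Module<[f32; I]> for ToyLinear<I, O> {
    type Output = [f32; O];
    fn forward(&self, input: [f32; I]) -> Self::Output {
        self.apply(&input)
    }
}

// 2D input: a batch of B samples; the batch dimension passes through unchanged.
impl<const B: usize, const I: usize, const O: usize> Module<[[f32; I]; B]> for ToyLinear<I, O> {
    type Output = [[f32; O]; B];
    fn forward(&self, input: [[f32; I]; B]) -> Self::Output {
        let mut out = [[0.0; O]; B];
        for (dst, x) in out.iter_mut().zip(input.iter()) {
            *dst = self.apply(x);
        }
        out
    }
}
```

Which impl runs is chosen at compile time from the argument's type, so the same struct serves both shapes without any runtime dispatch.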
Required Associated Types
type Output
The type that forward produces for this Input.
Required Methods
fn forward(&self, input: Input) -> Self::Output
Pass an Input through the unit and produce Self::Output. Can be implemented for multiple Input types.
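Because Output is an associated type, generic code can name whatever a layer produces for a given input. In the minimal sketch below, the trait is again a simplified, hypothetical stand-in (supertraits omitted), and AddConst, SumAll and chain are illustrative names rather than library items. The chain function feeds one layer's output into another, and the compiler checks that the types line up.

```rust
// Hypothetical, dependency-free stand-in for the trait above (supertraits omitted).
pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Hypothetical layer that adds a constant to every element.
pub struct AddConst(pub f32);

/// Hypothetical layer that sums a vector into a single scalar.
pub struct SumAll;

impl<const N: usize> Module<[f32; N]> for AddConst {
    type Output = [f32; N];
    fn forward(&self, input: [f32; N]) -> Self::Output {
        input.map(|x| x + self.0)
    }
}

impl<const N: usize> Module<[f32; N]> for SumAll {
    type Output = f32;
    fn forward(&self, input: [f32; N]) -> Self::Output {
        input.iter().sum()
    }
}

/// Runs `a`, then `b`. The bound `B: Module<A::Output>` ties the second layer's
/// input type to whatever the first layer produces for `I`.
pub fn chain<I, A, B>(a: &A, b: &B, input: I) -> B::Output
where
    A: Module<I>,
    B: Module<A::Output>,
{
    b.forward(a.forward(input))
}
```

Passing layers whose types do not line up (for example, SumAll first and then AddConst, since neither accepts a bare f32 here) is rejected at compile time rather than failing at run time.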
Example Usage
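use dfdx::prelude::*;
let model: Linear<7, 2> = Default::default();
// A single sample of 7 features maps to 2 outputs.
let y1: Tensor1D<2> = model.forward(Tensor1D::zeros());
// A batch of 10 samples maps to a batch of 10 outputs.
let y2: Tensor2D<10, 2> = model.forward(Tensor2D::zeros());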
Example Implementation
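use dfdx::prelude::*;

// Module also requires ResetParams and CanUpdateWithGradients, so a real
// layer must implement those traits as well; they are omitted here.
struct MyMulLayer {
    scale: Tensor1D<5, NoneTape>,
}

impl Module<Tensor1D<5>> for MyMulLayer {
    type Output = Tensor1D<5>;
    // Scales the input element-wise by the stored tensor.
    fn forward(&self, input: Tensor1D<5>) -> Self::Output {
        mul(input, &self.scale)
    }
}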
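To run the idea without the library, here is a plain-array analogue of MyMulLayer under the same simplifying assumption: a hypothetical trait without the supertraits, with [f32; 5] standing in for Tensor1D<5> and an explicit loop standing in for mul.

```rust
// Hypothetical stand-in for the trait above; the ResetParams and
// CanUpdateWithGradients supertraits are omitted.
pub trait Module<Input> {
    type Output;
    fn forward(&self, input: Input) -> Self::Output;
}

/// Plain-array analogue of MyMulLayer: a per-element scale factor.
pub struct MyMulLayer {
    scale: [f32; 5],
}

impl Module<[f32; 5]> for MyMulLayer {
    type Output = [f32; 5];
    fn forward(&self, input: [f32; 5]) -> Self::Output {
        // Element-wise product, the role played by `mul` in the example above.
        let mut out = input;
        for (o, s) in out.iter_mut().zip(self.scale.iter()) {
            *o *= *s;
        }
        out
    }
}
```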