Trait Module

pub trait Module: Send + Sync {
    // Required method
    fn forward(&self, input: &Tensor) -> Tensor;

    // Provided methods
    fn parameters(&self) -> Vec<&Tensor> { ... }
    fn parameters_mut(&mut self) -> Vec<&mut Tensor> { ... }
    fn train(&mut self) { ... }
    fn eval(&mut self) { ... }
    fn training(&self) -> bool { ... }
    fn zero_grad(&mut self) { ... }
    fn num_parameters(&self) -> usize { ... }
}

Base trait for all neural network modules.

Every layer, activation function, and container implements this trait, providing a uniform interface for:

  • Forward computation
  • Parameter access (for optimizers)
  • Training/evaluation mode switching

§Example

use aprender::nn::{Module, Linear};
use aprender::autograd::Tensor;

let layer = Linear::new(10, 5);
let x = Tensor::randn(&[32, 10]);
let output = layer.forward(&x);  // [32, 5]

// Access parameters for gradient descent
for param in layer.parameters() {
    println!("Shape: {:?}", param.shape());
}

Required Methods§

fn forward(&self, input: &Tensor) -> Tensor

Perform forward computation.

This is the main computation method. Given an input tensor, it returns the output tensor. The computation graph is automatically recorded for backpropagation.
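
As a hedged sketch, a custom implementor only has to supply forward; the provided methods can be overridden so optimizers can reach any inner parameters. MyBlock below is hypothetical and simply delegates to a Linear from this crate; a real layer would hold its own tensors.

use aprender::nn::{Module, Linear};
use aprender::autograd::Tensor;

// Hypothetical wrapper module that delegates to an inner Linear layer.
struct MyBlock {
    inner: Linear,
}

impl Module for MyBlock {
    fn forward(&self, input: &Tensor) -> Tensor {
        // The inner layer records its part of the computation graph.
        self.inner.forward(input)
    }

    // Override the provided accessors so the inner parameters are visible.
    fn parameters(&self) -> Vec<&Tensor> {
        self.inner.parameters()
    }

    fn parameters_mut(&mut self) -> Vec<&mut Tensor> {
        self.inner.parameters_mut()
    }
}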

Provided Methods§

fn parameters(&self) -> Vec<&Tensor>

Get references to all learnable parameters.

Used by optimizers to iterate over parameters for gradient updates. Parameters are returned in a deterministic order.
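
For instance, a sketch of what an optimizer sees before an update step; describe_parameters is a hypothetical helper, using only the accessors documented on this page:

use aprender::nn::Module;

// Inspect every learnable tensor exactly as an optimizer would.
fn describe_parameters(model: &dyn Module) {
    // Deterministic order: index i names the same tensor on every call.
    for (i, param) in model.parameters().iter().enumerate() {
        println!("param {i}: shape {:?}", param.shape());
    }
    println!("total learnable parameters: {}", model.num_parameters());
}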

fn parameters_mut(&mut self) -> Vec<&mut Tensor>

Get mutable references to all learnable parameters.

Used by optimizers to update parameters in-place.
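
A hedged sketch of how an update loop might drive this method; grad_step is a placeholder, not an aprender API, standing in for whatever gradient accessor and element-wise math the crate actually provides:

use aprender::nn::Module;
use aprender::autograd::Tensor;

// Hypothetical optimizer step: walk every parameter and update it in place.
fn sgd_step(model: &mut dyn Module, lr: f32) {
    for param in model.parameters_mut() {
        grad_step(param, lr);
    }
}

// Placeholder: a real implementation would subtract `lr * grad` from the
// parameter in place, using the crate's tensor ops (not documented here).
fn grad_step(_param: &mut Tensor, _lr: f32) {}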

fn train(&mut self)

Set the module to training mode.

This affects layers like Dropout (active during training) and BatchNorm (uses batch statistics during training).

fn eval(&mut self)

Set the module to evaluation mode.

This affects layers like Dropout (disabled during eval) and BatchNorm (uses running statistics during eval).
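
A small sketch of the usual pattern: switch to evaluation mode for inference, then restore training mode. The evaluate helper is hypothetical; only the trait methods shown on this page are used.

use aprender::nn::Module;
use aprender::autograd::Tensor;

// Run inference with Dropout disabled and BatchNorm running statistics,
// then switch back so the next training step behaves normally.
fn evaluate<M: Module>(model: &mut M, batch: &Tensor) -> Tensor {
    model.eval();
    let output = model.forward(batch);
    model.train();
    output
}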

fn training(&self) -> bool

Check if the module is in training mode.

fn zero_grad(&mut self)

Zero out gradients for all parameters.

Should be called before each training iteration.
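
A sketch of where the call sits in one training iteration; training_step is a hypothetical helper, and the loss/backward steps are only described in comments because that API is not part of this trait:

use aprender::nn::Module;
use aprender::autograd::Tensor;

// One training iteration, showing where zero_grad belongs.
fn training_step<M: Module>(model: &mut M, input: &Tensor) {
    // 1. Clear gradients accumulated by the previous iteration.
    model.zero_grad();

    // 2. Forward pass; the computation graph is recorded for backpropagation.
    let _output = model.forward(input);

    // 3. A loss would be computed from `_output` here, followed by a backward
    //    pass and an in-place update over model.parameters_mut().
}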

fn num_parameters(&self) -> usize

Get the number of learnable parameters.

Implementors§

impl Module for EdgeConv
impl Module for GATConv
impl Module for GCNConv
impl Module for GINConv
impl Module for GraphSAGEConv
impl Module for FakeQuantize
impl Module for AlphaDropout
impl Module for AvgPool2d
impl Module for BatchNorm1d
impl Module for Bidirectional
impl Module for Conv1d
impl Module for Conv2d
impl Module for DropBlock
impl Module for DropConnect
impl Module for Dropout2d
impl Module for Dropout
impl Module for Flatten
impl Module for GELU
impl Module for GRU
impl Module for GlobalAvgPool2d
impl Module for GroupNorm
impl Module for GroupedQueryAttention
impl Module for InstanceNorm
impl Module for LSTM
impl Module for LayerNorm
impl Module for LeakyReLU
impl Module for Linear
impl Module for LinearAttention
impl Module for MaxPool1d
impl Module for MaxPool2d
impl Module for MultiHeadAttention
impl Module for PositionalEncoding
impl Module for RMSNorm
impl Module for ReLU
impl Module for Sequential
impl Module for Sigmoid
impl Module for Softmax
impl Module for Tanh
impl Module for TransformerDecoderLayer
impl Module for TransformerEncoderLayer
impl Module for VAE
impl<E: TransferEncoder> Module for DomainAdapter<E>
impl<E: TransferEncoder> Module for MultiTaskHead<E>
impl<M: Module> Module for TransferableEncoder<M>