/// Trait for all neural network modules.
///
/// Modules should be created using the derive attribute, which generates the
/// code necessary to make a module trainable, savable and loadable via
/// `update_params`, `state` and `load`. Concrete module types define their
/// parameters via the `Param` struct.
pub trait Module: Send + Sync + Debug + Display {
    /// The backend the module runs on (determines element and device types).
    type Backend: Backend;

    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;

    /// Move the module and all of its sub-modules to the given device.
    fn to_device(&mut self, device: <Self::Backend as Backend>::Device);

    /// Load the module state.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadingError`] when the given state cannot be applied to
    /// this module.
    fn load(
        &mut self,
        state: &State<<Self::Backend as Backend>::Elem>,
    ) -> Result<(), LoadingError>;

    /// Get the module state.
    fn state(&self) -> State<<Self::Backend as Backend>::Elem>;

    /// Get the number of parameters the module has, including all of its
    /// sub-modules.
    fn num_params(&self) -> usize;

    /// Update the module parameters with the given gradients and optimizer.
    ///
    /// Only available when the backend supports automatic differentiation
    /// (`ADBackend`).
    fn update_params<O: Optimizer<Backend = Self::Backend>>(
        &mut self,
        grads: &Gradients,
        optim: &mut O,
    ) where
        Self::Backend: ADBackend;

    /// Load the optimizer state for the module, including all of its
    /// sub-modules.
    ///
    /// # Note
    ///
    /// This method should only be called by generated code; see `load` to load
    /// the state of the optimizer.
    fn load_optim_state<O: Optimizer<Backend = Self::Backend>>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        Self::Backend: ADBackend;

    /// Register the optimizer state for the module, including all of its
    /// sub-modules.
    ///
    /// # Note
    ///
    /// This method should only be called by generated code; see `state` to get
    /// the state of the optimizer.
    fn register_optim_state<O: Optimizer<Backend = Self::Backend>>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        Self::Backend: ADBackend;
}
Expand description

Trait for all neural network modules.

Modules should be created using the derive attribute. This will make your module trainable, savable and loadable via update_params, state and load.

Module concrete types should define their parameters via the Param struct.

Example

A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.

Also, to define the forward pass of your module, you should implement Forward.

use burn::nn;
use burn::module::{Param, Module};
use burn::module::Forward;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;

/// Example module: a single linear layer applied `repeat` times in sequence.
#[derive(Module, Debug)]
struct MyModule<B: Backend> {
  /// Trainable linear layer, wrapped in `Param` so the derive can track it.
  my_param: Param<nn::Linear<B>>,
  /// Number of times the linear layer is applied in `forward`.
  repeat: usize,
}

impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MyModule<B> {
    /// Feed the input through the linear layer `self.repeat` times,
    /// threading each output into the next application.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        (0..self.repeat).fold(input, |acc, _| self.my_param.forward(acc))
    }
}

Required Associated Types

Required Methods

Get the device list of the module and all of its sub-modules.

Move the module and all of its sub-modules to the given device.

Load the module state.

Get the module state.

Get the number of parameters the module has, including all of its sub-modules.

Update the module parameters with the given gradients and optimizer.

Load the optimizer state for the module, including all of its sub-modules.

Note

This method should only be called by generated code, see load to load the state of the optimizer.

Register the optimizer state for the module, including all of its sub-modules.

Note

This method should only be called by generated code, see state to get the state of the optimizer.

Implementors