/// Trait for all neural network modules.
///
/// Modules should be created using the derive attribute, which makes them trainable,
/// savable and loadable via [`Module::update_params`], [`Module::state`] and
/// [`Module::load`]. Concrete module types define their parameters via the `Param` struct.
pub trait Module: Send + Sync + Debug + Display {
    /// The backend whose `Device` and `Elem` types this module's methods are expressed in.
    type Backend: Backend;

    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>;

    /// Move the module and all of its sub-modules to the given device.
    fn to_device(&mut self, device: <Self::Backend as Backend>::Device);

    /// Load the module state.
    ///
    /// # Errors
    ///
    /// Returns a [`LoadingError`] when `state` cannot be loaded into this module.
    fn load(
        &mut self,
        state: &State<<Self::Backend as Backend>::Elem>,
    ) -> Result<(), LoadingError>;

    /// Get the module state (the counterpart of [`Module::load`]).
    fn state(&self) -> State<<Self::Backend as Backend>::Elem>;

    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize;

    /// Update the module's parameters from `grads` using `optim`.
    ///
    /// Only available when the module's backend is an `ADBackend`.
    fn update_params<O>(&mut self, grads: &Gradients, optim: &mut O)
    where
        O: Optimizer<Backend = Self::Backend>,
        Self::Backend: ADBackend;

    /// Load the optimizer state in `state_optim` into `optim`.
    ///
    /// NOTE(review): presumably the inverse of `register_optim_state`; confirm.
    fn load_optim_state<O>(
        &self,
        optim: &mut O,
        state_optim: &StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        O: Optimizer<Backend = Self::Backend>,
        Self::Backend: ADBackend;

    /// Record the state held by `optim` into `state_optim`.
    ///
    /// NOTE(review): presumably scoped to this module's parameters; confirm.
    fn register_optim_state<O>(
        &self,
        optim: &O,
        state_optim: &mut StateNamed<<Self::Backend as Backend>::Elem>,
    ) where
        O: Optimizer<Backend = Self::Backend>,
        Self::Backend: ADBackend;
}
Trait for all neural network modules.
Modules should be created using the derive attribute (`#[derive(Module)]`). This makes your module trainable, savable and loadable via `update_params`, `state` and `load`.
Module concrete types should define their parameters via the Param struct.
Example
A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.
Also, to define the forward pass of your module, you should implement Forward.
use burn::nn;
use burn::module::{Param, Module};
use burn::module::Forward;
use burn::tensor::Tensor;
use burn::tensor::backend::Backend;
/// Example module that applies a single linear layer `repeat` times in sequence.
#[derive(Module, Debug)]
struct MyModule<B: Backend> {
/// Trainable sub-module, wrapped in `Param` as module parameters should be.
my_param: Param<nn::Linear<B>>,
/// Number of times `forward` applies `my_param`; a plain field, not a `Param`.
repeat: usize,
}
impl<B: Backend> Forward<Tensor<B, 2>, Tensor<B, 2>> for MyModule<B> {
    /// Runs `input` through the wrapped linear layer `self.repeat` times.
    ///
    /// With `repeat == 0` the input tensor is returned unchanged.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        // Each repetition feeds the previous output back into the same layer.
        (0..self.repeat).fold(input, |acc, _| self.my_param.forward(acc))
    }
}
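Required Associated Types
type Backend: Backend
The backend whose `Device` and `Elem` types the module's methods are expressed in.
Required Methods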
fn devices(&self) -> Vec<<Self::Backend as Backend>::Device>
Get the device list of the module and all of its sub-modules.
fn to_device(&mut self, device: <Self::Backend as Backend>::Device)
Move the module and all of its sub-modules to the given device.
fn load(&mut self, state: &State<<Self::Backend as Backend>::Elem>) -> Result<(), LoadingError>
Load the module state.
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.