pub trait Module<B: Backend>: Clone + Send + Sync + Debug {
    type Record: Record;

    // Required methods
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
    fn load_record(self, record: Self::Record) -> Self;
    fn into_record(self) -> Self::Record;

    // Provided methods
    fn devices(&self) -> Vec<B::Device> { ... }
    fn fork(self, device: &B::Device) -> Self { ... }
    fn to_device(self, device: &B::Device) -> Self { ... }
    fn no_grad(self) -> Self { ... }
    fn num_params(&self) -> usize { ... }
}
Trait for all neural network modules.
Modules should be created using the derive attribute, #[derive(Module)]. This makes your module trainable, savable, and loadable via into_record and load_record.
Example
A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.
// Not necessary when using the burn crate directly.
use burn_core as burn;
use burn::{
    nn,
    module::Module,
    tensor::Tensor,
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: nn::Linear<B>,
    my_other_field: usize,
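}
Required Associated Types
type Record: Record
The record type that holds the module's state: into_record produces it and load_record consumes it.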
Required Methods
fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V)
Visit each tensor in the module with a visitor.
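To illustrate how a visitor walks a module tree, here is a minimal std-only sketch. It does not use Burn: ToyParam, ToyVisitor, ToyModule, ToyLinear, ToyMlp and SquaredNorm are hypothetical names standing in for a parameter tensor, ModuleVisitor, Module and two modules. The assumption is only that visit hands every parameter to the visitor once, with container modules recursing into their sub-modules.

```rust
// Hypothetical stand-in for a parameter tensor; not Burn's Param type.
#[derive(Clone, Debug)]
struct ToyParam {
    values: Vec<f32>,
}

// Hypothetical visitor trait: called once per parameter in the module tree.
trait ToyVisitor {
    fn visit(&mut self, param: &ToyParam);
}

// Hypothetical module trait reduced to the visit method.
trait ToyModule {
    fn visit<V: ToyVisitor>(&self, visitor: &mut V);
}

#[derive(Clone, Debug)]
struct ToyLinear {
    weight: ToyParam,
    bias: ToyParam,
}

impl ToyModule for ToyLinear {
    fn visit<V: ToyVisitor>(&self, visitor: &mut V) {
        // A leaf module hands each of its own parameters to the visitor.
        visitor.visit(&self.weight);
        visitor.visit(&self.bias);
    }
}

#[derive(Clone, Debug)]
struct ToyMlp {
    layers: Vec<ToyLinear>,
}

impl ToyModule for ToyMlp {
    fn visit<V: ToyVisitor>(&self, visitor: &mut V) {
        // A container module recurses into its sub-modules.
        for layer in &self.layers {
            layer.visit(visitor);
        }
    }
}

// Visitor that accumulates the sum of squared parameter values.
struct SquaredNorm {
    total: f32,
}

impl ToyVisitor for SquaredNorm {
    fn visit(&mut self, param: &ToyParam) {
        self.total += param.values.iter().map(|v| v * v).sum::<f32>();
    }
}

fn main() {
    let mlp = ToyMlp {
        layers: vec![
            ToyLinear {
                weight: ToyParam { values: vec![1.0, 2.0] },
                bias: ToyParam { values: vec![0.5] },
            },
            ToyLinear {
                weight: ToyParam { values: vec![3.0] },
                bias: ToyParam { values: vec![0.0] },
            },
        ],
    };
    let mut norm = SquaredNorm { total: 0.0 };
    mlp.visit(&mut norm);
    println!("squared norm = {}", norm.total);
}
```

Because visit takes &self, a visitor can only read the parameters; the visitor itself carries any accumulated state.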
fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self
Map each tensor in the module with a mapper.
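Unlike visit, map consumes the module and returns a rebuilt one. The std-only sketch below shows that shape with a clamping mapper. ToyParam, ToyMapper, ToyModule, ToyLinear and Clamp are hypothetical names, not Burn's ModuleMapper API; the sketch assumes only that each parameter is passed to the mapper by value and replaced with whatever the mapper returns.

```rust
// Hypothetical stand-in for a parameter tensor; not Burn's Param type.
#[derive(Clone, Debug, PartialEq)]
struct ToyParam {
    values: Vec<f32>,
}

// Hypothetical mapper trait: takes each parameter by value and returns its replacement.
trait ToyMapper {
    fn map(&mut self, param: ToyParam) -> ToyParam;
}

// Hypothetical module trait reduced to the map method.
trait ToyModule: Sized {
    fn map<M: ToyMapper>(self, mapper: &mut M) -> Self;
}

#[derive(Clone, Debug, PartialEq)]
struct ToyLinear {
    weight: ToyParam,
    bias: ToyParam,
}

impl ToyModule for ToyLinear {
    fn map<M: ToyMapper>(self, mapper: &mut M) -> Self {
        // The module is consumed and rebuilt from the mapped parameters.
        ToyLinear {
            weight: mapper.map(self.weight),
            bias: mapper.map(self.bias),
        }
    }
}

// Mapper that clamps every value into [-limit, limit] and counts the tensors it saw.
struct Clamp {
    limit: f32,
    mapped: usize,
}

impl ToyMapper for Clamp {
    fn map(&mut self, param: ToyParam) -> ToyParam {
        self.mapped += 1;
        let limit = self.limit;
        ToyParam {
            values: param.values.into_iter().map(|v| v.clamp(-limit, limit)).collect(),
        }
    }
}

fn main() {
    let layer = ToyLinear {
        weight: ToyParam { values: vec![-3.0, 0.5] },
        bias: ToyParam { values: vec![2.0] },
    };
    let mut clamp = Clamp { limit: 1.0, mapped: 0 };
    let layer = layer.map(&mut clamp);
    println!("{:?} ({} tensors mapped)", layer, clamp.mapped);
}
```

Taking self by value lets the mapper move each parameter out and back in without cloning.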
fn load_record(self, record: Self::Record) -> Self
Load the module state from a record.
fn into_record(self) -> Self::Record
Convert the module into a record containing the state.
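load_record and into_record form a round trip through the associated Record type. The std-only sketch below shows the idea with a hypothetical ToyModule trait, a ToyLinear module and a ToyLinearRecord type; none of these are Burn types, and serialization to disk (which Burn handles through recorders) is left out entirely.

```rust
// Hypothetical module trait reduced to the record methods.
trait ToyModule: Sized {
    // The record type holding this module's state, analogous to Module::Record.
    type Record;
    fn load_record(self, record: Self::Record) -> Self;
    fn into_record(self) -> Self::Record;
}

#[derive(Clone, Debug, PartialEq)]
struct ToyLinear {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

// Hypothetical plain-data record mirroring ToyLinear's state.
#[derive(Clone, Debug, PartialEq)]
struct ToyLinearRecord {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl ToyModule for ToyLinear {
    type Record = ToyLinearRecord;

    fn load_record(self, record: Self::Record) -> Self {
        // Replace the current state with the one stored in the record.
        ToyLinear { weight: record.weight, bias: record.bias }
    }

    fn into_record(self) -> Self::Record {
        // Move the state out of the module into the record.
        ToyLinearRecord { weight: self.weight, bias: self.bias }
    }
}

fn main() {
    let trained = ToyLinear { weight: vec![0.3, -0.7], bias: vec![0.1] };
    let record = trained.clone().into_record();

    // A freshly initialized module with the same structure...
    let fresh = ToyLinear { weight: vec![0.0, 0.0], bias: vec![0.0] };
    // ...takes on the trained state once the record is loaded.
    let restored = fresh.load_record(record);
    assert_eq!(restored, trained);
    println!("{:?}", restored);
}
```

Keeping the record a separate plain-data type decouples what gets saved from how the live module is structured.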
Provided Methods
fn devices(&self) -> Vec<B::Device>
Get the device list of the module and all of its sub-modules.
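As a rough illustration of collecting devices across sub-modules, here is a std-only sketch. ToyDevice, ToyParam, ToyLinear and ToyModel are hypothetical; Burn uses the backend's B::Device. The sketch deduplicates the list and keeps first-seen order, which is a design choice of this sketch, not a documented guarantee of Burn's devices.

```rust
// Hypothetical device type; Burn uses the backend's B::Device instead.
#[derive(Clone, Copy, Debug, PartialEq)]
enum ToyDevice {
    Cpu,
    Gpu(usize),
}

// Hypothetical parameter that only tracks where it lives.
struct ToyParam {
    device: ToyDevice,
}

struct ToyLinear {
    weight: ToyParam,
    bias: ToyParam,
}

impl ToyLinear {
    // Append the devices of this module's parameters, skipping duplicates.
    fn collect_devices(&self, out: &mut Vec<ToyDevice>) {
        for param in [&self.weight, &self.bias] {
            if !out.contains(&param.device) {
                out.push(param.device);
            }
        }
    }
}

struct ToyModel {
    encoder: ToyLinear,
    head: ToyLinear,
}

impl ToyModel {
    // Walk every sub-module and return the distinct devices in visit order.
    fn devices(&self) -> Vec<ToyDevice> {
        let mut out = Vec::new();
        self.encoder.collect_devices(&mut out);
        self.head.collect_devices(&mut out);
        out
    }
}

fn main() {
    let cpu = ToyDevice::Cpu;
    let gpu = ToyDevice::Gpu(0);
    let model = ToyModel {
        encoder: ToyLinear { weight: ToyParam { device: cpu }, bias: ToyParam { device: cpu } },
        head: ToyLinear { weight: ToyParam { device: gpu }, bias: ToyParam { device: gpu } },
    };
    println!("{:?}", model.devices());
}
```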
fn to_device(self, device: &B::Device) -> Self
Move the module and all of its sub-modules to the given device.
Warnings
Device transfers are registered in the autodiff graph. Therefore, call backward only once, even if the same module lives on multiple devices. If you need to call backward multiple times, consider using fork instead.
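The difference the warning describes can be modeled with a graph id attached to each parameter. This std-only sketch is a conceptual toy, not Burn's autodiff: ToyParam, ToyDevice and the NEXT_GRAPH_ID counter are hypothetical. It assumes, following the warning above, that to_device keeps copies attached to the original graph while fork gives each copy a graph of its own.

```rust
use std::sync::atomic::{AtomicUsize, Ordering};

// Hypothetical counter standing in for "a new autodiff graph".
static NEXT_GRAPH_ID: AtomicUsize = AtomicUsize::new(0);

fn new_graph_id() -> usize {
    NEXT_GRAPH_ID.fetch_add(1, Ordering::Relaxed)
}

#[derive(Clone, Copy, Debug, PartialEq)]
enum ToyDevice {
    Cpu,
    Gpu(usize),
}

// Hypothetical parameter tagged with the id of the graph it belongs to.
#[derive(Clone, Debug)]
struct ToyParam {
    device: ToyDevice,
    graph_id: usize,
    values: Vec<f32>,
}

impl ToyParam {
    fn new(values: Vec<f32>) -> Self {
        ToyParam { device: ToyDevice::Cpu, graph_id: new_graph_id(), values }
    }

    // Modeled on to_device: the transfer stays in the original graph.
    fn to_device(self, device: ToyDevice) -> Self {
        ToyParam { device, ..self }
    }

    // Modeled on fork: same values on the new device, but a fresh graph.
    fn fork(self, device: ToyDevice) -> Self {
        ToyParam { device, graph_id: new_graph_id(), values: self.values }
    }
}

fn main() {
    let param = ToyParam::new(vec![1.0, 2.0]);

    // Two to_device copies share one graph, so backward should run only once.
    let a = param.clone().to_device(ToyDevice::Gpu(0));
    let b = param.clone().to_device(ToyDevice::Gpu(1));
    println!("to_device: graphs {} {} on {:?} {:?}", a.graph_id, b.graph_id, a.device, b.device);

    // Two forks get independent graphs, so each copy can run its own backward.
    let c = param.clone().fork(ToyDevice::Gpu(0));
    let d = param.fork(ToyDevice::Gpu(1));
    println!("fork: graphs {} {} (values {:?})", c.graph_id, d.graph_id, c.values);
}
```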
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
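Counting parameters amounts to summing the element counts of every parameter tensor in the tree. In the std-only sketch below, ToyParam, ToyLinear and ToyMlp are hypothetical types, and the [d_input, d_output] weight shape is an assumption of the sketch; the count does not depend on that layout.

```rust
// Hypothetical parameter described only by its shape.
struct ToyParam {
    shape: Vec<usize>,
}

impl ToyParam {
    // Number of scalar elements: the product of the dimensions.
    fn num_elements(&self) -> usize {
        self.shape.iter().product()
    }
}

struct ToyLinear {
    weight: ToyParam,
    bias: Option<ToyParam>,
}

impl ToyLinear {
    fn new(d_input: usize, d_output: usize, with_bias: bool) -> Self {
        ToyLinear {
            weight: ToyParam { shape: vec![d_input, d_output] },
            bias: if with_bias { Some(ToyParam { shape: vec![d_output] }) } else { None },
        }
    }

    fn num_params(&self) -> usize {
        // An absent bias contributes zero parameters.
        self.weight.num_elements() + self.bias.as_ref().map_or(0, |b| b.num_elements())
    }
}

struct ToyMlp {
    layers: Vec<ToyLinear>,
}

impl ToyMlp {
    // Sum over all sub-modules.
    fn num_params(&self) -> usize {
        self.layers.iter().map(|layer| layer.num_params()).sum()
    }
}

fn main() {
    let mlp = ToyMlp {
        layers: vec![ToyLinear::new(4, 3, true), ToyLinear::new(3, 2, true)],
    };
    // (4*3 + 3) + (3*2 + 2) = 15 + 8 = 23
    println!("{}", mlp.num_params());
}
```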