pub trait Module<B>: Clone + Send + Debug
where
    B: Backend,
{
    type Record: Record<B>;

    // Required methods
    fn collect_devices(
        &self,
        devices: Vec<<B as Backend>::Device>
    ) -> Vec<<B as Backend>::Device>;
    fn fork(self, device: &<B as Backend>::Device) -> Self;
    fn to_device(self, device: &<B as Backend>::Device) -> Self;
    fn visit<Visitor>(&self, visitor: &mut Visitor)
       where Visitor: ModuleVisitor<B>;
    fn map<Mapper>(self, mapper: &mut Mapper) -> Self
       where Mapper: ModuleMapper<B>;
    fn load_record(self, record: Self::Record) -> Self;
    fn into_record(self) -> Self::Record;

    // Provided methods
    fn devices(&self) -> Vec<<B as Backend>::Device> { ... }
    fn no_grad(self) -> Self { ... }
    fn num_params(&self) -> usize { ... }
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR
    ) -> Result<(), RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &<B as Backend>::Device
    ) -> Result<Self, RecorderError>
       where FR: FileRecorder<B>,
             PB: Into<PathBuf> { ... }
}
Trait for all neural network modules.
Modules should be created using the #[derive(Module)] attribute. This makes your module trainable, savable and loadable via into_record and load_record.
§Example
A module should have a backend defined as a generic parameter B. This will be used by the derive attribute to generate the code necessary to optimize and train the module on any backend.
// Not necessary when using the burn crate directly.
use burn_core as burn;
use burn::{
    nn,
    module::Module,
    tensor::Tensor,
    tensor::backend::Backend,
};

#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    my_param: nn::Linear<B>,
    my_other_field: usize,
}
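To see why the backend is a generic parameter, here is a std-only toy model. It is not Burn's API. The hypothetical Backend trait below only supplies a Device type and a default device, and the same MyModule definition is instantiated for two made-up backends, CpuBackend and GpuBackend. The names Cpu, Gpu, Linear::new and build are all illustrative assumptions.

```rust
use std::fmt::Debug;

// Hypothetical stand-in for a backend: here it only declares its device type.
trait Backend: Clone + Debug {
    type Device: Clone + Debug + PartialEq;
    fn default_device() -> Self::Device;
}

#[derive(Clone, Debug)]
struct CpuBackend;
#[derive(Clone, Debug)]
struct GpuBackend;

#[derive(Clone, Debug, PartialEq)]
struct Cpu;
#[derive(Clone, Debug, PartialEq)]
struct Gpu(usize);

impl Backend for CpuBackend {
    type Device = Cpu;
    fn default_device() -> Cpu {
        Cpu
    }
}

impl Backend for GpuBackend {
    type Device = Gpu;
    fn default_device() -> Gpu {
        Gpu(0)
    }
}

// Toy "linear layer": parameter values plus the device they live on.
#[derive(Clone, Debug)]
struct Linear<B: Backend> {
    weight: Vec<f32>,
    device: B::Device,
}

impl<B: Backend> Linear<B> {
    fn new(size: usize) -> Self {
        Linear { weight: vec![0.0; size], device: B::default_device() }
    }
}

// Written once, usable with any backend, mirroring MyModule above.
#[derive(Clone, Debug)]
struct MyModule<B: Backend> {
    my_param: Linear<B>,
    my_other_field: usize,
}

fn build<B: Backend>() -> MyModule<B> {
    MyModule { my_param: Linear::new(4), my_other_field: 7 }
}

fn main() {
    let cpu: MyModule<CpuBackend> = build();
    let gpu: MyModule<GpuBackend> = build();
    println!("{:?} {}", cpu.my_param.device, cpu.my_param.weight.len());
    println!("{:?} {}", gpu.my_param.device, gpu.my_other_field);
}
```

The backend choice is made only at the call site (`MyModule<CpuBackend>` vs `MyModule<GpuBackend>`). That separation is what lets generated code target any backend.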
Required Associated Types§
type Record: Record<B>
Required Methods§
fn collect_devices( &self, devices: Vec<<B as Backend>::Device> ) -> Vec<<B as Backend>::Device>
Return all the devices found in the module tree, appended to the given vector without duplicates.
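The contract "append to the given vector without duplicates" can be sketched with a toy tree. This is a std-only sketch that assumes a hypothetical Device enum and a DeviceTree trait rather than Burn's types. Each child receives the vector built so far, and devices() is shown as a thin wrapper that starts from an empty vector, which is one plausible way the provided method relates to collect_devices.

```rust
// Hypothetical device identifier; real Burn devices are backend-specific.
#[derive(Clone, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

// Minimal stand-in for the module tree, with only the two device methods.
trait DeviceTree {
    /// Appends every device in this subtree to `devices`, skipping duplicates.
    fn collect_devices(&self, devices: Vec<Device>) -> Vec<Device>;

    /// Convenience wrapper: start from an empty vector.
    fn devices(&self) -> Vec<Device> {
        self.collect_devices(Vec::new())
    }
}

// Leaf: a single parameter living on one device.
struct Param {
    device: Device,
}

impl DeviceTree for Param {
    fn collect_devices(&self, mut devices: Vec<Device>) -> Vec<Device> {
        if !devices.contains(&self.device) {
            devices.push(self.device.clone());
        }
        devices
    }
}

// Composite: threads the same vector through each child in order.
struct Block {
    children: Vec<Param>,
}

impl DeviceTree for Block {
    fn collect_devices(&self, devices: Vec<Device>) -> Vec<Device> {
        self.children
            .iter()
            .fold(devices, |acc, child| child.collect_devices(acc))
    }
}

fn main() {
    let block = Block {
        children: vec![
            Param { device: Device::Cpu },
            Param { device: Device::Cuda(0) },
            Param { device: Device::Cpu },
        ],
    };
    println!("{:?}", block.devices()); // [Cpu, Cuda(0)]
}
```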
fn to_device(self, device: &<B as Backend>::Device) -> Self
Move the module and all of its sub-modules to the given device.
§Warnings
Device operations are recorded in the autodiff graph, so call backward only once, even if the same module lives on multiple devices. If you need to call backward multiple times, use fork instead.
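The consuming `self -> Self` shape of to_device, and its recursion into sub-modules, can be sketched without Burn. In this toy model the Device enum, Param and TwoLayer are hypothetical, and "moving" only retags the data. It does not model the autodiff graph that the warning above is about.

```rust
// Hypothetical device type for this sketch.
#[derive(Clone, Debug, PartialEq)]
enum Device {
    Cpu,
    Cuda(usize),
}

// Toy parameter: data plus the device it lives on.
#[derive(Debug)]
struct Param {
    values: Vec<f32>,
    device: Device,
}

impl Param {
    // Takes ownership and returns the parameter tagged with the new device.
    // A real backend would transfer the data; here only the tag changes.
    fn to_device(self, device: &Device) -> Self {
        Param { values: self.values, device: device.clone() }
    }
}

// Toy module with two sub-modules.
#[derive(Debug)]
struct TwoLayer {
    first: Param,
    second: Param,
}

impl TwoLayer {
    // Recurses into every sub-module so nothing is left on the old device.
    fn to_device(self, device: &Device) -> Self {
        TwoLayer {
            first: self.first.to_device(device),
            second: self.second.to_device(device),
        }
    }
}

fn main() {
    let model = TwoLayer {
        first: Param { values: vec![1.0, 2.0], device: Device::Cpu },
        second: Param { values: vec![3.0], device: Device::Cpu },
    };
    // `model` is consumed; the returned value is the moved module.
    let model = model.to_device(&Device::Cuda(0));
    println!("{:?}", model);
}
```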
fn visit<Visitor>(&self, visitor: &mut Visitor)
where
    Visitor: ModuleVisitor<B>,
Visit each tensor parameter in the module with a visitor.
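A minimal sketch of the visitor pattern behind visit follows. It assumes a hypothetical ParamVisitor trait with a single visit_param callback, which is not the real ModuleVisitor signature. The composite module forwards the visitor to its sub-module. In this toy model the plain usize field is simply not visited.

```rust
// Hypothetical visitor trait: called once per tensor parameter (read-only).
trait ParamVisitor {
    fn visit_param(&mut self, name: &str, values: &[f32]);
}

struct Linear {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

struct MyModule {
    my_param: Linear,
    my_other_field: usize,
}

impl Linear {
    fn visit<V: ParamVisitor>(&self, visitor: &mut V) {
        visitor.visit_param("weight", &self.weight);
        visitor.visit_param("bias", &self.bias);
    }
}

impl MyModule {
    fn visit<V: ParamVisitor>(&self, visitor: &mut V) {
        // In this toy model only tensor parameters are visited;
        // `my_other_field` is a plain value and is skipped.
        self.my_param.visit(visitor);
    }
}

// Example visitor: records parameter names and the largest absolute value.
#[derive(Default)]
struct Inspector {
    names: Vec<String>,
    max_abs: f32,
}

impl ParamVisitor for Inspector {
    fn visit_param(&mut self, name: &str, values: &[f32]) {
        self.names.push(name.to_string());
        for v in values {
            self.max_abs = self.max_abs.max(v.abs());
        }
    }
}

fn main() {
    let module = MyModule {
        my_param: Linear { weight: vec![1.0, -3.0], bias: vec![0.5] },
        my_other_field: 42,
    };
    let mut inspector = Inspector::default();
    module.visit(&mut inspector);
    println!("{:?} max |w| = {}", inspector.names, inspector.max_abs);
    println!("my_other_field = {}", module.my_other_field);
}
```

The visitor takes `&self`, so it can inspect parameters but cannot change them. Changing parameters is the job of map.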
fn map<Mapper>(self, mapper: &mut Mapper) -> Self
where
    Mapper: ModuleMapper<B>,
Map each tensor parameter in the module with a mapper.
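The mapper counterpart can be sketched the same way. This std-only sketch assumes a hypothetical ParamMapper trait, not the real ModuleMapper API. Unlike visit, map consumes the module and each parameter, so the mapper can return new values. The example mapper scales every parameter and counts how many it touched.

```rust
// Hypothetical mapper trait: takes a parameter by value, returns its replacement.
trait ParamMapper {
    fn map_param(&mut self, values: Vec<f32>) -> Vec<f32>;
}

struct Linear {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

impl Linear {
    // Consumes the module and rebuilds it from the mapped parameters.
    fn map<M: ParamMapper>(self, mapper: &mut M) -> Self {
        Linear {
            weight: mapper.map_param(self.weight),
            bias: mapper.map_param(self.bias),
        }
    }
}

// Example mapper: multiplies every value by `factor` and counts calls.
struct Scale {
    factor: f32,
    calls: usize,
}

impl ParamMapper for Scale {
    fn map_param(&mut self, values: Vec<f32>) -> Vec<f32> {
        self.calls += 1;
        let factor = self.factor;
        values.into_iter().map(|v| v * factor).collect()
    }
}

fn main() {
    let layer = Linear { weight: vec![1.0, 2.0], bias: vec![0.5] };
    let mut scale = Scale { factor: 2.0, calls: 0 };
    let layer = layer.map(&mut scale);
    println!("{:?} {:?} ({} params mapped)", layer.weight, layer.bias, scale.calls);
}
```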
fn load_record(self, record: Self::Record) -> Self
Load the module state from a record.
fn into_record(self) -> Self::Record
Convert the module into a record containing the state.
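The into_record / load_record pair can be illustrated with a toy round trip. In this sketch, LinearRecord, Linear and the string device are hypothetical. The record holds only parameter values. load_record consumes an already-built module and replaces its state, which matches the `fn load_record(self, record: Self::Record) -> Self` signature. Keeping the target module's device is a choice made for this toy model only.

```rust
// Hypothetical record: the module's parameter values as plain data.
#[derive(Clone, Debug, PartialEq)]
struct LinearRecord {
    weight: Vec<f32>,
    bias: Vec<f32>,
}

#[derive(Debug)]
struct Linear {
    weight: Vec<f32>,
    bias: Vec<f32>,
    device: String, // illustrative device tag
}

impl Linear {
    // Consumes the module and keeps only its state.
    fn into_record(self) -> LinearRecord {
        LinearRecord { weight: self.weight, bias: self.bias }
    }

    // Consumes an existing module and swaps in the recorded state.
    // In this toy model the target module keeps its own device.
    fn load_record(self, record: LinearRecord) -> Self {
        Linear { weight: record.weight, bias: record.bias, device: self.device }
    }
}

fn main() {
    let trained = Linear { weight: vec![1.0, 2.0], bias: vec![3.0], device: "cpu".to_string() };
    let record = trained.into_record();
    let fresh = Linear { weight: vec![0.0, 0.0], bias: vec![0.0], device: "cuda:0".to_string() };
    let restored = fresh.load_record(record);
    println!("{:?}", restored);
}
```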
Provided Methods§
fn devices(&self) -> Vec<<B as Backend>::Device>
Return all the devices found in the module tree, without duplicates.
fn num_params(&self) -> usize
Get the number of parameters the module has, including all of its sub-modules.
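Counting parameters "including all of its sub-modules" amounts to summing element counts over the whole tree. This std-only sketch assumes a hypothetical Node type that holds parameter shapes and child nodes. Each parameter contributes the product of its shape.

```rust
// Toy tensor parameter: only its shape matters for counting.
struct Param {
    shape: Vec<usize>,
}

impl Param {
    // Number of scalar elements = product of the dimensions.
    fn numel(&self) -> usize {
        self.shape.iter().product()
    }
}

// A module node holds its own parameters plus nested sub-modules.
struct Node {
    params: Vec<Param>,
    children: Vec<Node>,
}

impl Node {
    fn num_params(&self) -> usize {
        let own: usize = self.params.iter().map(Param::numel).sum();
        let nested: usize = self.children.iter().map(Node::num_params).sum();
        own + nested
    }
}

fn main() {
    // Two toy linear layers: 4 -> 3 and 3 -> 1, each with a bias.
    let layer1 = Node {
        params: vec![Param { shape: vec![4, 3] }, Param { shape: vec![3] }],
        children: vec![],
    };
    let layer2 = Node {
        params: vec![Param { shape: vec![3, 1] }, Param { shape: vec![1] }],
        children: vec![],
    };
    let model = Node { params: vec![], children: vec![layer1, layer2] };
    println!("{}", model.num_params()); // 19
}
```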
fn save_file<FR, PB>( self, file_path: PB, recorder: &FR ) -> Result<(), RecorderError>
Save the module to a file using the provided file recorder.
List of supported file recorders:
- default
- bincode
- bincode compressed with gzip
- json pretty
- json compressed with gzip
- named mpk
- named mpk compressed with gzip
§Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
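The automatic extension can be sketched with PathBuf::set_extension. RecorderKind and resolve_path are hypothetical. The extension strings are assumptions for illustration, so check each recorder's documentation for the real values. The `P: Into<PathBuf>` bound mirrors the PB parameter of save_file.

```rust
use std::path::PathBuf;

// Hypothetical recorder kinds; the extension strings are illustrative assumptions.
#[derive(Clone, Copy, Debug)]
enum RecorderKind {
    Bincode,
    BincodeGz,
    JsonPretty,
    JsonGz,
    NamedMpk,
    NamedMpkGz,
}

impl RecorderKind {
    fn file_extension(self) -> &'static str {
        match self {
            RecorderKind::Bincode => "bin",
            RecorderKind::BincodeGz => "bin.gz",
            RecorderKind::JsonPretty => "json",
            RecorderKind::JsonGz => "json.gz",
            RecorderKind::NamedMpk => "mpk",
            RecorderKind::NamedMpkGz => "mpk.gz",
        }
    }
}

// The caller passes a base path; the recorder decides the extension.
// Note: set_extension replaces an existing extension rather than appending.
fn resolve_path<P: Into<PathBuf>>(file_path: P, recorder: RecorderKind) -> PathBuf {
    let mut path: PathBuf = file_path.into();
    path.set_extension(recorder.file_extension());
    path
}

fn main() {
    let kinds = [
        RecorderKind::Bincode,
        RecorderKind::BincodeGz,
        RecorderKind::JsonPretty,
        RecorderKind::JsonGz,
        RecorderKind::NamedMpk,
        RecorderKind::NamedMpkGz,
    ];
    for kind in kinds {
        println!("{:?} -> {}", kind, resolve_path("model", kind).display());
    }
}
```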
fn load_file<FR, PB>( self, file_path: PB, recorder: &FR, device: &<B as Backend>::Device ) -> Result<Self, RecorderError>
Load the module from a file using the provided file recorder.
The recorder should be the same as the one used to save the module; see save_file.
§Notes
The file extension is added automatically based on the file recorder provided, so you don’t have to specify it.
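This toy round trip shows why saving and loading must use the same recorder: both sides have to agree on the file format and on the extension rule. TextRecorder is a hypothetical plain-text recorder, not one of Burn's recorders. It stores f32 values as whitespace-separated text, and the device argument of load_file is omitted for brevity.

```rust
use std::fs;
use std::io;
use std::path::PathBuf;

// Hypothetical recorder: plain-text values under a fixed extension.
struct TextRecorder;

impl TextRecorder {
    const EXT: &'static str = "txt";

    // Saves to `<file_path>.txt`; the caller never writes the extension.
    fn save(&self, values: &[f32], file_path: impl Into<PathBuf>) -> io::Result<PathBuf> {
        let mut path = file_path.into();
        path.set_extension(Self::EXT);
        let text: Vec<String> = values.iter().map(|v| v.to_string()).collect();
        fs::write(&path, text.join(" "))?;
        Ok(path)
    }

    // Applies the same extension rule and parses the same format.
    fn load(&self, file_path: impl Into<PathBuf>) -> io::Result<Vec<f32>> {
        let mut path = file_path.into();
        path.set_extension(Self::EXT);
        let text = fs::read_to_string(&path)?;
        text.split_whitespace()
            .map(|s| {
                s.parse::<f32>()
                    .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))
            })
            .collect()
    }
}

fn main() -> io::Result<()> {
    let base = std::env::temp_dir().join("module_sketch_weights");
    let recorder = TextRecorder;
    let written = recorder.save(&[0.5, -2.25, 3.0], base.clone())?;
    let loaded = recorder.load(base)?;
    println!("{} -> {:?}", written.display(), loaded);
    fs::remove_file(written)?;
    Ok(())
}
```

A recorder with a different extension or format would not find the file, or would fail to parse it, which is the failure mode the note about using the same recorder guards against.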