Trait burn_core::module::Module

source ·
/// Trait for all neural network modules.
///
/// Modules are normally created with `#[derive(Module)]`; the state of a module is
/// saved and restored through its associated [`Module::Record`] type.
pub trait Module<B: Backend>: Clone + Send + Sync + Debug {
    /// Type to save and load the module.
    type Record: Record;

    // Required methods
    /// Visit each tensor in the module with a visitor.
    fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V);
    /// Map each tensor in the module with a mapper.
    fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self;
    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;
    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;

    // Provided methods
    /// Get the device list of the module and all of its sub-modules.
    fn devices(&self) -> Vec<B::Device> { ... }
    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// Similar to `to_device`, but ensures the module has its own autodiff graph.
    fn fork(self, device: &B::Device) -> Self { ... }
    /// Move the module and all of its sub-modules to the given device.
    ///
    /// The device operations are registered in the autodiff graph, so call `backward`
    /// only once even if the same module lives on several devices; use `fork` if you
    /// need to call `backward` multiple times.
    fn to_device(self, device: &B::Device) -> Self { ... }
    /// Mark every tensor in the module tree as not requiring gradients.
    ///
    /// Not meant for inference (use `valid` with autodiff modules); mostly useful for
    /// partial fine-tuning, where only a small fraction of the parameters is updated.
    fn no_grad(self) -> Self { ... }
    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize { ... }
}
Expand description

Trait for all neural network modules.

Modules should be created using the `#[derive(Module)]` attribute. This makes your module trainable, savable and loadable through its record (see `into_record` and `load_record`).

Example

A module should declare its backend as a generic parameter `B`. The derive attribute uses it to generate the code needed to optimize and train the module on any backend.

// Not necessary when using the burn crate directly.
use burn_core as burn;

use burn::{
    nn,
    module::Module,
    tensor::Tensor,
    tensor::backend::Backend,
};

/// Example module, generic over the backend `B` so that `#[derive(Module)]` can
/// generate code that trains it on any backend.
#[derive(Module, Debug)]
struct MyModule<B: Backend> {
    /// Sub-module holding trainable state; `nn::Linear<B>` itself implements `Module<B>`.
    my_param: nn::Linear<B>,
    /// Plain configuration value; `usize` implements `Module<B>` with `type Record = ()`,
    /// so it carries no saved state of its own.
    my_other_field: usize,
}

Required Associated Types§

source

type Record: Record

Type to save and load the module.

Required Methods§

source

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V)

Visit each tensor in the module with a visitor.

source

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self

Map each tensor in the module with a mapper.

source

fn load_record(self, record: Self::Record) -> Self

Load the module state from a record.

source

fn into_record(self) -> Self::Record

Convert the module into a record containing the state.

Provided Methods§

source

fn devices(&self) -> Vec<B::Device>

Get the device list of the module and all of its sub-modules.

source

fn fork(self, device: &B::Device) -> Self

Fork the module and all of its sub-modules to the given device.

Notes

This is similar to `to_device`, but it ensures the module will have its own autodiff graph.

source

fn to_device(self, device: &B::Device) -> Self

Move the module and all of its sub-modules to the given device.

Warnings

The device operations will be registered in the autodiff graph. Therefore, be sure to call `backward` only once, even if you have the same module on multiple devices. If you need to call `backward` multiple times, use `fork` instead.

source

fn no_grad(self) -> Self

Mark every tensor in the module tree as not requiring gradients.

Warnings

This should not be used for inference; use `valid` when working with autodiff (AD) modules. It is mostly useful for partial fine-tuning, i.e. updating only a small fraction of the parameters instead of fine-tuning all of them.

source

fn num_params(&self) -> usize

Get the number of parameters the module has, including all of its sub-modules.

Implementations on Foreign Types§

source§

impl<B: Backend> Module<B> for u32

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for f64

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for u64

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for usize

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for bool

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for u8

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for i32

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for u16

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for String

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for i64

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<T, B> Module<B> for Option<T>where T: Module<B> + Debug + Send + Sync + Clone, B: Backend,

§

type Record = Option<<T as Module<B>>::Record>

source§

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self

source§

fn load_record(self, record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for i8

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<T, B> Module<B> for Vec<T>where T: Module<B> + Debug + Send + Sync + Clone, B: Backend,

§

type Record = Vec<<T as Module<B>>::Record, Global>

source§

fn num_params(&self) -> usize

source§

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self

source§

fn into_record(self) -> Self::Record

source§

fn load_record(self, record: Self::Record) -> Self

source§

impl<B: Backend> Module<B> for i16

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<const N: usize, T, B> Module<B> for [T; N]where T: Module<B> + Debug + Send + Sync + Clone + Copy, T::Record: Debug, B: Backend,

§

type Record = [<T as Module<B>>::Record; N]

source§

fn devices(&self) -> Vec<<B as Backend>::Device>

source§

fn num_params(&self) -> usize

source§

fn visit<V: ModuleVisitor<B>>(&self, visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, mapper: &mut M) -> Self

source§

fn load_record(self, record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

source§

impl<B: Backend> Module<B> for f32

§

type Record = ()

source§

fn visit<V: ModuleVisitor<B>>(&self, _visitor: &mut V)

source§

fn map<M: ModuleMapper<B>>(self, _mapper: &mut M) -> Self

source§

fn load_record(self, _record: Self::Record) -> Self

source§

fn into_record(self) -> Self::Record

Implementors§

source§

impl<B: Backend> Module<B> for Conv1dPaddingConfig

§

type Record = ()

source§

impl<B: Backend> Module<B> for Conv2dPaddingConfig

§

type Record = ()

source§

impl<B: Backend> Module<B> for MultiHeadAttention<B>

source§

impl<B: Backend> Module<B> for Conv1d<B>

source§

impl<B: Backend> Module<B> for Conv2d<B>

source§

impl<B: Backend> Module<B> for AvgPool2d

§

type Record = ()

source§

impl<B: Backend> Module<B> for MaxPool2d

§

type Record = ()

source§

impl<B: Backend> Module<B> for Dropout

§

type Record = ()

source§

impl<B: Backend> Module<B> for Embedding<B>

source§

impl<B: Backend> Module<B> for GELU

§

type Record = ()

source§

impl<B: Backend> Module<B> for LayerNorm<B>

source§

impl<B: Backend> Module<B> for Linear<B>

source§

impl<B: Backend> Module<B> for ReLU

§

type Record = ()

source§

impl<B: Backend> Module<B> for PositionWiseFeedForward<B>

source§

impl<B: Backend> Module<B> for TransformerDecoder<B>

source§

impl<B: Backend> Module<B> for TransformerDecoderLayer<B>

source§

impl<B: Backend> Module<B> for TransformerEncoder<B>

source§

impl<B: Backend> Module<B> for TransformerEncoderLayer<B>

source§

impl<B: Backend> Module<B> for bf16

§

type Record = ()

source§

impl<B: Backend> Module<B> for f16

§

type Record = ()

source§

impl<B: Backend, const D: usize> Module<B> for BatchNorm<B, D>

source§

impl<const D: usize, B: Backend> Module<B> for Param<Tensor<B, D>>

§

type Record = Param<Tensor<B, D, Float>>

source§

impl<const D: usize, B: Backend> Module<B> for RunningState<Tensor<B, D>>

§

type Record = Param<Tensor<B, D, Float>>