burn_core/module/base.rs

use super::{ParamId, Quantizer};
use crate::{
    record::Record,
    tensor::backend::{AutodiffBackend, Backend},
};
use alloc::vec::Vec;
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor, ops::Device};

/// Type alias for `Vec<B::Device>` that also works in `no_std` environments by relying on the
/// `alloc` crate.
pub type Devices<B> = Vec<Device<B>>;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its
// development. We may consider making it public in the future.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                _id: ParamId,
                tensor: Tensor<B, D>,
            ) -> Tensor<B, D> {
                let func = $item;
                func(tensor)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
                let func = $item;
                func(tensor, &mut self.state)
            }
        }
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}

/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// [into_record](Module::into_record) and [load_record](Module::load_record).
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the
/// code necessary to optimize and train the module on any backend.
///
/// ```no_run
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     nn,
///     module::Module,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///     my_param: nn::Linear<B>,
///     my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    /// Type to save and load the module.
    type Record: Record<B>;

    /// Return all the devices found in the module tree, appended to the given vector without
    /// duplicates.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;

    /// Return all the devices found in the module tree, without duplicates.
    fn devices(&self) -> Devices<B> {
        self.collect_devices(Devices::<B>::new())
    }

    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
    /// new device will have its own autodiff graph.
    fn fork(self, device: &B::Device) -> Self;

    /// Move the module and all of its sub-modules to the given device.
    ///
    /// # Warnings
    ///
    /// The operation supports autodiff, and the move is registered in the autodiff graph when
    /// autodiff is enabled. However, this may not be what you want: the output module is an
    /// intermediary module, meaning that you can't optimize it with gradient descent. If you want
    /// to optimize the output module on the target device, use [fork](Module::fork) instead.
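    ///
    /// # Example
    ///
    /// A minimal sketch (`model` and `device` are placeholders for an existing module and target
    /// device):
    ///
    /// ```ignore
    /// // Intermediary module on the target device: fine for forward passes, but use
    /// // `fork` instead if you want to train it there.
    /// let model = model.to_device(&device);
    /// ```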
    fn to_device(self, device: &B::Device) -> Self;

    /// Mark each tensor in the module tree as not requiring gradients.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference; use [valid](AutodiffModule::valid) when working with
    /// AD modules. This is mostly useful for partial finetuning, i.e. updating only a small
    /// fraction of the parameters instead of finetuning all of them.
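    ///
    /// # Example
    ///
    /// A minimal sketch of partial finetuning (the `encoder` field is illustrative):
    ///
    /// ```ignore
    /// // Freeze the encoder while keeping the rest of the model trainable.
    /// model.encoder = model.encoder.no_grad();
    /// ```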
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize {
        module!(
            visit_float = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }

    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);

    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;

    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;

    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;

    #[cfg(feature = "std")]
    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// List of supported file recorders:
    ///
    /// * [default](crate::record::DefaultFileRecorder)
    /// * [bincode](crate::record::BinFileRecorder)
    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
    /// * [json pretty](crate::record::PrettyJsonFileRecorder)
    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
    /// * [named mpk](crate::record::NamedMpkFileRecorder)
    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
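    ///
    /// # Example
    ///
    /// A minimal sketch using the named MessagePack recorder (`model` and the file path are
    /// placeholders, and the imports assume the `burn` facade):
    ///
    /// ```ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// model
    ///     .save_file("/path/to/model", &recorder)
    ///     .expect("Should be able to save the model");
    /// ```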
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = Self::into_record(self);
        recorder.record(record, file_path.into())
    }

    #[cfg(feature = "std")]
    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// The recorder should be the same as the one used to save the module; see
    /// [save_file](Self::save_file).
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
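    ///
    /// # Example
    ///
    /// A minimal sketch using the named MessagePack recorder (`model`, `device` and the file path
    /// are placeholders, and the imports assume the `burn` facade):
    ///
    /// ```ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// let model = model
    ///     .load_file("/path/to/model", &recorder, &device)
    ///     .expect("Should be able to load the model");
    /// ```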
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &B::Device,
    ) -> Result<Self, crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = recorder.load(file_path.into(), device)?;

        Ok(self.load_record(record))
    }

    /// Quantize the weights of the module.
    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        self.map(quantizer)
    }
}

/// Module visitor trait.
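///
/// # Example
///
/// A minimal sketch of a visitor that counts float parameters, mirroring what
/// [num_params](Module::num_params) does internally (the struct name is illustrative):
///
/// ```no_run
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::module::{ModuleVisitor, ParamId};
/// use burn::tensor::{backend::Backend, Tensor};
///
/// struct ParamCounter {
///     count: usize,
/// }
///
/// impl<B: Backend> ModuleVisitor<B> for ParamCounter {
///     fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
///         self.count += tensor.shape().num_elements();
///     }
/// }
/// ```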
pub trait ModuleVisitor<B: Backend> {
    /// Visit a float tensor in the module.
    fn visit_float<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D>) {}
    /// Visit an int tensor in the module.
    fn visit_int<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Int>) {}
    /// Visit a bool tensor in the module.
    fn visit_bool<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Bool>) {}
}

/// Module mapper trait.
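///
/// # Example
///
/// A minimal sketch of a mapper that rescales every float parameter (the struct name and the
/// scaling factor are illustrative):
///
/// ```no_run
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::module::{ModuleMapper, ParamId};
/// use burn::tensor::{backend::Backend, Tensor};
///
/// struct Rescale;
///
/// impl<B: Backend> ModuleMapper<B> for Rescale {
///     fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
///         tensor.mul_scalar(0.5)
///     }
/// }
/// ```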
pub trait ModuleMapper<B: Backend> {
    /// Map a float tensor in the module.
    fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
        tensor
    }
    /// Map an int tensor in the module.
    fn map_int<const D: usize>(
        &mut self,
        _id: ParamId,
        tensor: Tensor<B, D, Int>,
    ) -> Tensor<B, D, Int> {
        tensor
    }
    /// Map a bool tensor in the module.
    fn map_bool<const D: usize>(
        &mut self,
        _id: ParamId,
        tensor: Tensor<B, D, Bool>,
    ) -> Tensor<B, D, Bool> {
        tensor
    }
}

/// Module with auto-differentiation backend.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Get the same module, but on the inner backend without auto-differentiation.
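    ///
    /// # Example
    ///
    /// A minimal sketch (`model` is assumed to live on an autodiff backend):
    ///
    /// ```ignore
    /// // Inference-only copy of the module on the inner backend.
    /// let model_valid = model.valid();
    /// ```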
    fn valid(&self) -> Self::InnerModule;
}