burn_core/module/
base.rs

1use super::{Param, ParamId, Quantizer};
2use crate::{
3    record::Record,
4    tensor::backend::{AutodiffBackend, Backend},
5};
6use alloc::{string::String, vec::Vec};
7pub use burn_derive::Module;
8use burn_tensor::{Bool, Int, Tensor, ops::Device};
9
10/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
11/// the `alloc` crate.
12pub type Devices<B> = Vec<Device<B>>;
13
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
//
// Two forms are supported:
// - `module!(map = m, ops = |tensor: Tensor<B, D>| ...)` builds a `ModuleMapper` whose
//   `map_float` runs the closure on every float parameter of `m` (keeping each parameter's id
//   and mapper) and returns the mapped module. Int and bool parameters use the trait defaults.
// - `module!(visit_float = m, ops = |tensor: &Tensor<B, D>, state: &mut S| ..., state = S,
//   init = || ...)` builds a `ModuleVisitor` that feeds every float parameter of `m` into a
//   state value created by `init`, then returns that state.
//
// The `ops` closure is expanded inside the generated `impl`, so the `B` and `D` it names refer
// to that impl's backend parameter and the method's const rank parameter.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<B: Backend> ModuleVisitor<B> for Visitor<'_, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                // Reborrow explicitly so the closure receives `&mut $state_ty` directly rather
                // than a `&mut &mut $state_ty` that only type-checks through deref coercion.
                func(&param.val(), &mut *self.state)
            }
        }
        // `init` is normally a closure literal, so calling it in place triggers this lint.
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
54
55/// Trait for all neural network modules.
56///
57/// Modules should be created using the [derive](burn_derive::Module) attribute.
58/// This will make your module trainable, savable and loadable via
59/// `state` and `load`.
60///
61/// # Example
62///
63/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
64/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
65/// necessary to optimize and train the module on any backend.
66///
67/// ```rust, ignore
68/// // Not necessary when using the burn crate directly.
69/// use burn_core as burn;
70///
71/// use burn::{
72///     module::Module,
73///     nn::Linear,
74///     tensor::Tensor,
75///     tensor::backend::Backend,
76/// };
77///
78/// #[derive(Module, Debug)]
79/// struct MyModule<B: Backend> {
80///   my_param: Linear<B>,
81///   my_other_field: usize,
82/// }
83/// ```
84pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
85    /// Type to save and load the module.
86    type Record: Record<B>;
87
88    /// Return all the devices found in the underneath module tree added to the given vector
89    /// without duplicates.
90    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
91
92    /// Return all the devices found in the underneath module tree without duplicates.
93    fn devices(&self) -> Devices<B> {
94        self.collect_devices(Devices::<B>::new())
95    }
96
97    /// Fork the module and all of its sub-modules to the given device.
98    ///
99    /// # Notes
100    ///
101    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
102    /// new device will have its own autodiff graph.
103    fn fork(self, device: &B::Device) -> Self;
104
105    /// Move the module and all of its sub-modules to the given device.
106    ///
107    /// # Warnings
108    ///
109    /// The operation supports autodiff and it will be registered when activated. However, this may
110    /// not be what you want. The output model will be an intermediary model, meaning that you
111    /// can't optimize it with gradient descent. If you want to optimize the output network on the
112    /// target device, use [fork](Module::fork) instead.
113    fn to_device(self, device: &B::Device) -> Self;
114
115    /// Each tensor in the module tree will not require grad.
116    ///
117    /// # Warnings
118    ///
119    /// This should not be used for inference, use [valid](AutodiffModule::valid) when using
120    /// AD modules. This is mostly useful when performing partial finetuning, which is updating only
121    /// a small fraction of the parameters instead of finetuning all of them.
122    fn no_grad(self) -> Self {
123        module!(
124            map = self,
125            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
126        )
127    }
128
129    /// Get the number of parameters the module has, including all of its sub-modules.
130    fn num_params(&self) -> usize {
131        module!(
132            visit_float = self,
133            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
134                *state += tensor.shape().num_elements();
135            },
136            state = usize,
137            init = || 0
138        )
139    }
140    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
141    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
142
143    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
144    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
145
146    /// Load the module state from a record.
147    fn load_record(self, record: Self::Record) -> Self;
148
149    /// Convert the module into a record containing the state.
150    fn into_record(self) -> Self::Record;
151
152    #[cfg(feature = "std")]
153    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
154    ///
155    /// List of supported file recorders:
156    ///
157    /// * [default](crate::record::DefaultFileRecorder)
158    /// * [bincode](crate::record::BinFileRecorder)
159    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
160    /// * [json pretty](crate::record::PrettyJsonFileRecorder)
161    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
162    /// * [named mpk](crate::record::NamedMpkFileRecorder)
163    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
164    ///
165    /// ## Notes
166    ///
167    /// The file extension is automatically added depending on the file recorder provided, you
168    /// don't have to specify it.
169    fn save_file<FR, PB>(
170        self,
171        file_path: PB,
172        recorder: &FR,
173    ) -> Result<(), crate::record::RecorderError>
174    where
175        FR: crate::record::FileRecorder<B>,
176        PB: Into<std::path::PathBuf>,
177    {
178        let record = Self::into_record(self);
179        recorder.record(record, file_path.into())
180    }
181
182    #[cfg(feature = "std")]
183    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
184    ///
185    /// The recorder should be the same as the one used to save the module, see
186    /// [save_file](Self::save_file).
187    ///
188    /// ## Notes
189    ///
190    /// The file extension is automatically added depending on the file recorder provided, you
191    /// don't have to specify it.
192    fn load_file<FR, PB>(
193        self,
194        file_path: PB,
195        recorder: &FR,
196        device: &B::Device,
197    ) -> Result<Self, crate::record::RecorderError>
198    where
199        FR: crate::record::FileRecorder<B>,
200        PB: Into<std::path::PathBuf>,
201    {
202        let record = recorder.load(file_path.into(), device)?;
203
204        Ok(self.load_record(record))
205    }
206
207    /// Quantize the weights of the module.
208    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
209        self.map(quantizer)
210    }
211}
212
213/// Module visitor trait for traversing and inspecting module parameters.
214pub trait ModuleVisitor<B: Backend> {
215    /// Visit a float parameter in the module.
216    ///
217    /// # Parameters
218    /// - `param`: The float parameter to visit
219    #[allow(unused_variables)]
220    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}
221
222    /// Visit an int parameter in the module.
223    ///
224    /// # Parameters
225    /// - `param`: The integer parameter to visit
226    #[allow(unused_variables)]
227    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}
228
229    /// Visit a bool parameter in the module.
230    ///
231    /// # Parameters
232    /// - `param`: The boolean parameter to visit
233    #[allow(unused_variables)]
234    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}
235
236    /// Called when entering a submodule.
237    ///
238    /// # Parameters
239    /// - `name`: The name of the submodule being entered
240    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
241    #[allow(unused_variables)]
242    fn enter_module(&mut self, name: &str, container_type: &str) {}
243
244    /// Called when exiting a submodule.
245    ///
246    /// # Parameters
247    /// - `name`: The name of the submodule being exited
248    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
249    #[allow(unused_variables)]
250    fn exit_module(&mut self, name: &str, container_type: &str) {}
251
252    /// Visit a float tensor with its full module path.
253    ///
254    /// # Parameters
255    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
256    ///   Each element represents a module name in the hierarchy, with the final element
257    ///   being the parameter name. This allows efficient reuse of the path stack.
258    /// - `id`: The unique identifier of the parameter
259    /// - `tensor`: The float tensor to visit
260    #[allow(unused_variables)]
261    fn visit_float_with_path<const D: usize>(
262        &mut self,
263        path: &[String],
264        id: ParamId,
265        tensor: &Tensor<B, D>,
266    ) {
267    }
268
269    /// Visit an int tensor with its full module path.
270    ///
271    /// # Parameters
272    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
273    ///   Each element represents a module name in the hierarchy, with the final element
274    ///   being the parameter name. This allows efficient reuse of the path stack.
275    /// - `id`: The unique identifier of the parameter
276    /// - `tensor`: The integer tensor to visit
277    #[allow(unused_variables)]
278    fn visit_int_with_path<const D: usize>(
279        &mut self,
280        path: &[String],
281        id: ParamId,
282        tensor: &Tensor<B, D, Int>,
283    ) {
284    }
285
286    /// Visit a bool tensor with its full module path.
287    ///
288    /// # Parameters
289    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
290    ///   Each element represents a module name in the hierarchy, with the final element
291    ///   being the parameter name. This allows efficient reuse of the path stack.
292    /// - `id`: The unique identifier of the parameter
293    /// - `tensor`: The boolean tensor to visit
294    #[allow(unused_variables)]
295    fn visit_bool_with_path<const D: usize>(
296        &mut self,
297        path: &[String],
298        id: ParamId,
299        tensor: &Tensor<B, D, Bool>,
300    ) {
301    }
302}
303
304/// Module mapper trait for transforming module parameters.
305pub trait ModuleMapper<B: Backend> {
306    /// Called when entering a submodule.
307    ///
308    /// # Parameters
309    /// - `name`: The name of the submodule being entered
310    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
311    #[allow(unused_variables)]
312    fn enter_module(&mut self, name: &str, container_type: &str) {}
313
314    /// Called when exiting a submodule.
315    ///
316    /// # Parameters
317    /// - `name`: The name of the submodule being exited
318    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
319    #[allow(unused_variables)]
320    fn exit_module(&mut self, name: &str, container_type: &str) {}
321
322    /// Map a float parameter in the module.
323    ///
324    /// # Parameters
325    /// - `param`: The float parameter to transform
326    ///
327    /// # Returns
328    /// The transformed parameter
329    #[allow(unused_variables)]
330    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
331        let (id, tensor, mapper) = param.consume();
332        Param::from_mapped_value(id, tensor, mapper)
333    }
334
335    /// Map an int parameter in the module.
336    ///
337    /// # Parameters
338    /// - `param`: The integer parameter to transform
339    ///
340    /// # Returns
341    /// The transformed parameter
342    #[allow(unused_variables)]
343    fn map_int<const D: usize>(
344        &mut self,
345        param: Param<Tensor<B, D, Int>>,
346    ) -> Param<Tensor<B, D, Int>> {
347        let (id, tensor, mapper) = param.consume();
348        Param::from_mapped_value(id, tensor, mapper)
349    }
350
351    /// Map a bool parameter in the module.
352    ///
353    /// # Parameters
354    /// - `param`: The boolean parameter to transform
355    ///
356    /// # Returns
357    /// The transformed parameter
358    #[allow(unused_variables)]
359    fn map_bool<const D: usize>(
360        &mut self,
361        param: Param<Tensor<B, D, Bool>>,
362    ) -> Param<Tensor<B, D, Bool>> {
363        let (id, tensor, mapper) = param.consume();
364        Param::from_mapped_value(id, tensor, mapper)
365    }
366}
367
368/// Module with auto-differentiation backend.
369pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
370    /// Inner module without auto-differentiation.
371    type InnerModule: Module<B::InnerBackend>;
372
373    /// Get the same module, but on the inner backend without auto-differentiation.
374    fn valid(&self) -> Self::InnerModule;
375}