burn_core/module/base.rs

use super::{Param, ParamId, Quantizer};
use crate::{
    record::Record,
    tensor::backend::{AutodiffBackend, Backend},
};
use alloc::{string::String, vec::Vec};
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor, ops::Device};

/// Type alias for `Vec<B::Device>` that supports `no_std` environments by automatically using
/// the `alloc` crate.
pub type Devices<B> = Vec<Device<B>>;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                func(&param.val(), &mut self.state)
            }
        }
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}

/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable, and loadable via
/// [into_record](Module::into_record) and [load_record](Module::load_record).
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
/// necessary to optimize and train the module on any backend.
///
/// ```rust, ignore
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     module::Module,
///     nn::Linear,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///   my_param: Linear<B>,
///   my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    /// Type to save and load the module.
    type Record: Record<B>;

    /// Return all the devices found in the underlying module tree, appended to the given vector
    /// without duplicates.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;

    /// Return all the devices found in the underlying module tree, without duplicates.
    fn devices(&self) -> Devices<B> {
        self.collect_devices(Devices::<B>::new())
    }

    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
    /// new device will have its own autodiff graph.
    fn fork(self, device: &B::Device) -> Self;

    /// Move the module and all of its sub-modules to the given device.
    ///
    /// # Warnings
    ///
    /// The operation supports autodiff, and it will be registered when activated. However, this may
    /// not be what you want. The output model will be an intermediary model, meaning that you
    /// can't optimize it with gradient descent. If you want to optimize the output network on the
    /// target device, use [fork](Module::fork) instead.
    fn to_device(self, device: &B::Device) -> Self;

    /// Mark each tensor in the module tree as not requiring grad.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference; use [valid](AutodiffModule::valid) when working with
    /// AD modules. This is mostly useful for partial finetuning, where only a small fraction of
    /// the parameters is updated instead of all of them.
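    ///
    /// # Example
    ///
    /// A minimal sketch of partial finetuning, freezing one sub-module while the rest of the
    /// model stays trainable (`MyModel` and its `encoder` field are hypothetical):
    ///
    /// ```rust, ignore
    /// // Hypothetical model with a public `encoder` sub-module.
    /// let mut model = MyModel::<B>::new(&device);
    /// // Freeze the encoder; its tensors will no longer require grad.
    /// model.encoder = model.encoder.no_grad();
    /// ```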
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Move the module and all of its sub-modules to the autodiff backend.
    ///
    /// # Notes
    ///
    /// * Only plain modules (not already on an autodiff backend) can be moved.
    /// * Calling `train()` on a module that is already on an autodiff backend
    ///   will result in a type error, because the module's inner backend does not match.
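    ///
    /// # Example
    ///
    /// A sketch of switching between the inner and autodiff backends (`MyAutodiffBackend` stands
    /// in for any autodiff backend whose inner backend matches the module's backend):
    ///
    /// ```rust, ignore
    /// // Drop to the inner backend for validation or inference.
    /// let model = model.valid();
    /// // Move back to the autodiff backend for training.
    /// let model = model.train::<MyAutodiffBackend>();
    /// ```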
    fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
    where
        AB: AutodiffBackend<InnerBackend = B>,
        Self: HasAutodiffModule<AB>,
    {
        <Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
    }

    /// Get the number of parameters the module has, including all of its sub-modules.
    fn num_params(&self) -> usize {
        module!(
            visit_float = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }

    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);

    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;

    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;

    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;
    #[cfg(feature = "std")]
    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// List of supported file recorders:
    ///
    /// * [default](crate::record::DefaultFileRecorder)
    /// * [bincode](crate::record::BinFileRecorder)
    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
    /// * [json pretty](crate::record::PrettyJsonFileRecorder)
    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
    /// * [named mpk](crate::record::NamedMpkFileRecorder)
    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
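    ///
    /// ## Example
    ///
    /// A minimal sketch using the named MessagePack recorder (the model, the path, and the
    /// `burn::record` import path are placeholders; adjust them to your setup):
    ///
    /// ```rust, ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// model
    ///     .save_file("/path/to/model", &recorder)
    ///     .expect("should save the model");
    /// ```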
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = Self::into_record(self);
        recorder.record(record, file_path.into())
    }

    #[cfg(feature = "std")]
    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// The recorder should be the same as the one used to save the module; see
    /// [save_file](Self::save_file).
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
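    ///
    /// ## Example
    ///
    /// A minimal sketch mirroring the [save_file](Self::save_file) example (the model, device,
    /// path, and the `burn::record` import path are placeholders):
    ///
    /// ```rust, ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// let model = model
    ///     .load_file("/path/to/model", &recorder, &device)
    ///     .expect("should load the model");
    /// ```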
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &B::Device,
    ) -> Result<Self, crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = recorder.load(file_path.into(), device)?;

        Ok(self.load_record(record))
    }

    /// Quantize the weights of the module.
    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        self.map(quantizer)
    }
}

/// Module visitor trait for traversing and inspecting module parameters.
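///
/// All methods have no-op default implementations, so a visitor only needs to override the hooks
/// it cares about. For example, a sketch of a visitor that counts float parameters (`module` is
/// any module instance):
///
/// ```rust, ignore
/// struct ParamCounter {
///     count: usize,
/// }
///
/// impl<B: Backend> ModuleVisitor<B> for ParamCounter {
///     fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
///         self.count += param.val().shape().num_elements();
///     }
/// }
///
/// let mut visitor = ParamCounter { count: 0 };
/// module.visit(&mut visitor);
/// ```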
pub trait ModuleVisitor<B: Backend> {
    /// Visit a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to visit
    #[allow(unused_variables)]
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

    /// Visit an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to visit
    #[allow(unused_variables)]
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

    /// Visit a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to visit
    #[allow(unused_variables)]
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Visit a float tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The float tensor to visit
    #[allow(unused_variables)]
    fn visit_float_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D>,
    ) {
    }

    /// Visit an int tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The integer tensor to visit
    #[allow(unused_variables)]
    fn visit_int_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Int>,
    ) {
    }

    /// Visit a bool tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The boolean tensor to visit
    #[allow(unused_variables)]
    fn visit_bool_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Bool>,
    ) {
    }
}

/// Module mapper trait for transforming module parameters.
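///
/// Each `map_*` method has a default implementation that returns the parameter unchanged, so a
/// mapper only needs to override the kinds of parameters it wants to transform. For example, a
/// sketch of a mapper that scales every float parameter by a constant factor (`module` is any
/// module instance):
///
/// ```rust, ignore
/// struct Scale {
///     factor: f32,
/// }
///
/// impl<B: Backend> ModuleMapper<B> for Scale {
///     fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
///         let (id, tensor, mapper) = param.consume();
///         Param::from_mapped_value(id, tensor.mul_scalar(self.factor), mapper)
///     }
/// }
///
/// let mut mapper = Scale { factor: 0.5 };
/// let module = module.map(&mut mapper);
/// ```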
pub trait ModuleMapper<B: Backend> {
    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Map a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }

    /// Map an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_int<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Int>>,
    ) -> Param<Tensor<B, D, Int>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }

    /// Map a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_bool<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Bool>>,
    ) -> Param<Tensor<B, D, Bool>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }
}

/// Module with auto-differentiation backend.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Returns the same module, but on the inner backend without auto-differentiation.
    fn valid(&self) -> Self::InnerModule;

    /// Wraps an inner module back into an auto-diff module.
    fn from_inner(module: Self::InnerModule) -> Self;
}

/// Helper trait to associate a module with its autodiff version.
pub trait HasAutodiffModule<B: AutodiffBackend> {
    /// The module with auto-differentiation.
    type TrainModule: AutodiffModule<B, InnerModule = Self>;
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::TestAutodiffBackend;
    use crate::test_utils::SimpleLinear;

    #[test]
    fn test_module_val_train_stateful() {
        let device = Default::default();
        let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);

        assert!(module.weight.is_require_grad());
        assert!(module.weight.require_grad);

        let module = module.valid();
        assert!(!module.weight.is_require_grad());
        assert!(module.weight.require_grad); // stateful

        // Without `HasAutodiffModule`, we would need to specify the module type as well, which would be annoying
        // let module: SimpleLinear<TestAutodiffBackend> = module.train();
        let module = module.train::<TestAutodiffBackend>();
        assert!(module.weight.is_require_grad());
        assert!(module.weight.require_grad); // stateful

        let module = module.no_grad();
        assert!(!module.weight.is_require_grad());
        assert!(!module.weight.require_grad); // stateful

        let module = module.valid();
        assert!(!module.weight.is_require_grad()); // always
        assert!(!module.weight.require_grad); // stateful

        let module = module.train::<TestAutodiffBackend>();
        assert!(!module.weight.is_require_grad());
        assert!(!module.weight.require_grad); // stateful
    }
}