burn_core/module/base.rs
use super::{Param, ParamId, Quantizer};
use crate::{
    record::Record,
    tensor::backend::{AutodiffBackend, Backend},
};
use alloc::{string::String, vec::Vec};
pub use burn_derive::Module;
use burn_tensor::{Bool, Int, Tensor, ops::Device};

/// Type alias for `Vec<B::Device>` that supports `no_std` environments by relying on
/// the `alloc` crate.
pub type Devices<B> = Vec<Device<B>>;

// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                func(&param.val(), &mut self.state)
            }
        }
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}

/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// `state` and `load`.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the
/// code necessary to optimize and train the module on any backend.
///
/// ```rust, ignore
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     module::Module,
///     nn::Linear,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///     my_param: Linear<B>,
///     my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    /// Type to save and load the module.
    type Record: Record<B>;

    /// Return all the devices found in the module tree, appended to the given vector
    /// without duplicates.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;

    /// Return all the devices found in the module tree, without duplicates.
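    ///
    /// # Example
    ///
    /// A short usage sketch:
    ///
    /// ```rust, ignore
    /// let devices = module.devices();
    /// let device = devices.first().expect("module should have at least one device");
    /// ```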
    fn devices(&self) -> Devices<B> {
        self.collect_devices(Devices::<B>::new())
    }

    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
    /// new device will have its own autodiff graph.
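    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `model` implements [Module] and `device` is the target device:
    ///
    /// ```rust, ignore
    /// // The forked module gets a fresh autodiff graph on the new device,
    /// // so it can be optimized there.
    /// let model = model.fork(&device);
    /// ```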
    fn fork(self, device: &B::Device) -> Self;

    /// Move the module and all of its sub-modules to the given device.
    ///
    /// # Warnings
    ///
    /// The operation supports autodiff, and the move will be registered in the autodiff graph when
    /// it is active. However, this may not be what you want: the output module is an intermediary
    /// module, meaning that you can't optimize it with gradient descent. If you want to optimize
    /// the output module on the target device, use [fork](Module::fork) instead.
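    ///
    /// # Example
    ///
    /// A short sketch:
    ///
    /// ```rust, ignore
    /// // Fine for inference; prefer `fork` if you plan to optimize on the new device.
    /// let model = model.to_device(&device);
    /// ```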
    fn to_device(self, device: &B::Device) -> Self;

    /// Each tensor in the module tree will not require grad.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference; use [valid](AutodiffModule::valid) instead when
    /// working with AD modules. This is mostly useful for partial finetuning, i.e. updating only
    /// a small fraction of the parameters instead of finetuning all of them.
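    ///
    /// # Example
    ///
    /// A minimal sketch of partial finetuning, assuming a hypothetical `model` with an `encoder`
    /// sub-module:
    ///
    /// ```rust, ignore
    /// // Freeze the encoder; only the remaining parameters keep requiring gradients.
    /// model.encoder = model.encoder.no_grad();
    /// ```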
    fn no_grad(self) -> Self {
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Move the module and all of its sub-modules to the autodiff backend.
    ///
    /// # Notes
    ///
    /// * Only plain modules (not already on an autodiff backend) can be moved.
    /// * Calling `train()` on a module that is already on an autodiff backend
    ///   will result in a type error, because the module's inner backend does not match.
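    ///
    /// # Example
    ///
    /// A minimal sketch, assuming `Autodiff<B>` is the autodiff wrapper around the module's
    /// current backend `B`:
    ///
    /// ```rust, ignore
    /// // `module` lives on the plain backend `B`; move it to `Autodiff<B>` for training.
    /// let module = module.train::<Autodiff<B>>();
    /// ```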
    fn train<AB>(self) -> <Self as HasAutodiffModule<AB>>::TrainModule
    where
        AB: AutodiffBackend<InnerBackend = B>,
        Self: HasAutodiffModule<AB>,
    {
        <Self as HasAutodiffModule<AB>>::TrainModule::from_inner(self)
    }

    /// Get the number of parameters the module has, including all of its sub-modules.
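    ///
    /// # Example
    ///
    /// A short sketch:
    ///
    /// ```rust, ignore
    /// let num_params = module.num_params();
    /// ```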
    fn num_params(&self) -> usize {
        module!(
            visit_float = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }

    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);

    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;

    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;

    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;

    #[cfg(feature = "std")]
    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// List of supported file recorders:
    ///
    /// * [default](crate::record::DefaultFileRecorder)
    /// * [bincode](crate::record::BinFileRecorder)
    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
    /// * [json pretty](crate::record::PrettyJsonFileRecorder)
    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
    /// * [named mpk](crate::record::NamedMpkFileRecorder)
    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
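    ///
    /// # Example
    ///
    /// A minimal sketch, using the named MessagePack recorder with full precision:
    ///
    /// ```rust, ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// module.save_file("/path/to/model", &recorder)?;
    /// ```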
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = Self::into_record(self);
        recorder.record(record, file_path.into())
    }

    #[cfg(feature = "std")]
    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// The recorder should be the same as the one used to save the module, see
    /// [save_file](Self::save_file).
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided; you
    /// don't have to specify it.
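    ///
    /// # Example
    ///
    /// A minimal sketch, using the same recorder type as in [save_file](Self::save_file):
    ///
    /// ```rust, ignore
    /// use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder};
    ///
    /// let recorder = NamedMpkFileRecorder::<FullPrecisionSettings>::new();
    /// let module = module.load_file("/path/to/model", &recorder, &device)?;
    /// ```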
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &B::Device,
    ) -> Result<Self, crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = recorder.load(file_path.into(), device)?;

        Ok(self.load_record(record))
    }

    /// Quantize the weights of the module.
    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        self.map(quantizer)
    }
}

/// Module visitor trait for traversing and inspecting module parameters.
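///
/// # Example
///
/// A minimal sketch of a visitor that counts every float parameter element:
///
/// ```rust, ignore
/// struct ParamCounter {
///     count: usize,
/// }
///
/// impl<B: Backend> ModuleVisitor<B> for ParamCounter {
///     fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
///         self.count += param.val().shape().num_elements();
///     }
/// }
///
/// let mut visitor = ParamCounter { count: 0 };
/// module.visit(&mut visitor);
/// assert_eq!(visitor.count, module.num_params());
/// ```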
pub trait ModuleVisitor<B: Backend> {
    /// Visit a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to visit
    #[allow(unused_variables)]
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

    /// Visit an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to visit
    #[allow(unused_variables)]
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

    /// Visit a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to visit
    #[allow(unused_variables)]
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Visit a float tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The float tensor to visit
    #[allow(unused_variables)]
    fn visit_float_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D>,
    ) {
    }

    /// Visit an int tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The integer tensor to visit
    #[allow(unused_variables)]
    fn visit_int_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Int>,
    ) {
    }

    /// Visit a bool tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The boolean tensor to visit
    #[allow(unused_variables)]
    fn visit_bool_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Bool>,
    ) {
    }
}

/// Module mapper trait for transforming module parameters.
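///
/// # Example
///
/// A minimal sketch of a mapper that scales every float parameter:
///
/// ```rust, ignore
/// struct Scale {
///     factor: f32,
/// }
///
/// impl<B: Backend> ModuleMapper<B> for Scale {
///     fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
///         let (id, tensor, mapper) = param.consume();
///         Param::from_mapped_value(id, tensor * self.factor, mapper)
///     }
/// }
///
/// let module = module.map(&mut Scale { factor: 0.5 });
/// ```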
pub trait ModuleMapper<B: Backend> {
    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container with format:
    ///   - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
    ///   - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
    ///   - For Vec containers: "Vec" (name is the index)
    ///   - For Tuple containers: "Tuple" (name is the index)
    ///   - For Array containers: "Array" (name is the index)
    ///
    /// Note: Option containers do not call enter_module/exit_module to preserve
    /// the field name in the path (e.g., "bias" instead of "bias.Some")
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Map a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }

    /// Map an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_int<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Int>>,
    ) -> Param<Tensor<B, D, Int>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }

    /// Map a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to transform
    ///
    /// # Returns
    /// The transformed parameter
    #[allow(unused_variables)]
    fn map_bool<const D: usize>(
        &mut self,
        param: Param<Tensor<B, D, Bool>>,
    ) -> Param<Tensor<B, D, Bool>> {
        let (id, tensor, mapper) = param.consume();
        Param::from_mapped_value(id, tensor, mapper)
    }
}

/// Module with auto-differentiation backend.
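///
/// # Example
///
/// A short sketch: derive an inference-only copy of a trained module.
///
/// ```rust, ignore
/// // Same weights on the inner backend, but without an autodiff graph.
/// let model_valid = model.valid();
/// ```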
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Returns the same module, but on the inner backend without auto-differentiation.
    fn valid(&self) -> Self::InnerModule;

    /// Wraps an inner module back into an auto-diff module.
    fn from_inner(module: Self::InnerModule) -> Self;
}

/// Helper trait to associate a module with its autodiff version.
pub trait HasAutodiffModule<B: AutodiffBackend> {
    /// The module with auto-differentiation.
    type TrainModule: AutodiffModule<B, InnerModule = Self>;
}

#[cfg(test)]
mod tests {
    use super::*;

    use crate::TestAutodiffBackend;
    use crate::test_utils::SimpleLinear;

    #[test]
    fn test_module_val_train_stateful() {
        let device = Default::default();
        let module = SimpleLinear::<TestAutodiffBackend>::new(4, 4, &device);

        assert!(module.weight.is_require_grad());
        assert!(module.weight.require_grad);

        let module = module.valid();
        assert!(!module.weight.is_require_grad());
        assert!(module.weight.require_grad); // stateful

        // Without `HasAutodiffModule`, we would need to specify the module type as well,
        // which would be annoying:
        // let module: SimpleLinear<TestAutodiffBackend> = module.train();
        let module = module.train::<TestAutodiffBackend>();
        assert!(module.weight.is_require_grad());
        assert!(module.weight.require_grad); // stateful

        let module = module.no_grad();
        assert!(!module.weight.is_require_grad());
        assert!(!module.weight.require_grad); // stateful

        let module = module.valid();
        assert!(!module.weight.is_require_grad()); // always
        assert!(!module.weight.require_grad); // stateful

        let module = module.train::<TestAutodiffBackend>();
        assert!(!module.weight.is_require_grad());
        assert!(!module.weight.require_grad); // stateful
    }
}