burn_core/module/base.rs
1use super::{Param, ParamId, Quantizer};
2use crate::{
3 record::Record,
4 tensor::backend::{AutodiffBackend, Backend},
5};
6use alloc::{string::String, vec::Vec};
7pub use burn_derive::Module;
8use burn_tensor::{Bool, Int, Tensor, ops::Device};
9
/// List of devices for the backend `B`, i.e. a `Vec<Device<B>>`.
///
/// `Vec` is imported from the `alloc` crate, so this alias is also usable in `no_std`
/// environments.
pub type Devices<B> = Vec<Device<B>>;
13
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
//
// `module!` generates a throwaway `ModuleMapper` / `ModuleVisitor` implementation so that a simple
// per-float-tensor operation can be written as a closure. The `ops` closure is expanded inside the
// generated method body, so it may name that method's generic parameters `B` (backend) and `D`
// (tensor rank).
//
// Arms:
// - `map = <module>, ops = <closure>`: applies the closure to the tensor of every float parameter
//   and returns the result of `<module>.map(..)`. Only `map_float` is overridden; int and bool
//   parameters go through the `ModuleMapper` default implementations.
// - `visit_float = <module>, ops = <closure>, state = <type>, init = <closure>`: builds the state
//   with `init()`, calls the closure with each float parameter's value and `&mut` state, and
//   returns the final state.
macro_rules! module {
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                // Rebuild the parameter from its original id and mapper around the new tensor.
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                // Reborrow the stored `&mut` so the closure receives `&mut $state_ty` directly
                // rather than a `&mut &mut $state_ty` that depends on deref coercion.
                func(&param.val(), &mut *self.state)
            }
        }
        // `init` is usually passed as a closure literal, which clippy would flag when called
        // in place.
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
54
55/// Trait for all neural network modules.
56///
57/// Modules should be created using the [derive](burn_derive::Module) attribute.
58/// This will make your module trainable, savable and loadable via
59/// `state` and `load`.
60///
61/// # Example
62///
63/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
64/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
65/// necessary to optimize and train the module on any backend.
66///
67/// ```rust, ignore
68/// // Not necessary when using the burn crate directly.
69/// use burn_core as burn;
70///
71/// use burn::{
72/// module::Module,
73/// nn::Linear,
74/// tensor::Tensor,
75/// tensor::backend::Backend,
76/// };
77///
78/// #[derive(Module, Debug)]
79/// struct MyModule<B: Backend> {
80/// my_param: Linear<B>,
81/// my_other_field: usize,
82/// }
83/// ```
84pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
85 /// Type to save and load the module.
86 type Record: Record<B>;
87
88 /// Return all the devices found in the underneath module tree added to the given vector
89 /// without duplicates.
90 fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
91
92 /// Return all the devices found in the underneath module tree without duplicates.
93 fn devices(&self) -> Devices<B> {
94 self.collect_devices(Devices::<B>::new())
95 }
96
97 /// Fork the module and all of its sub-modules to the given device.
98 ///
99 /// # Notes
100 ///
101 /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
102 /// new device will have its own autodiff graph.
103 fn fork(self, device: &B::Device) -> Self;
104
105 /// Move the module and all of its sub-modules to the given device.
106 ///
107 /// # Warnings
108 ///
109 /// The operation supports autodiff and it will be registered when activated. However, this may
110 /// not be what you want. The output model will be an intermediary model, meaning that you
111 /// can't optimize it with gradient descent. If you want to optimize the output network on the
112 /// target device, use [fork](Module::fork) instead.
113 fn to_device(self, device: &B::Device) -> Self;
114
115 /// Each tensor in the module tree will not require grad.
116 ///
117 /// # Warnings
118 ///
119 /// This should not be used for inference, use [valid](AutodiffModule::valid) when using
120 /// AD modules. This is mostly useful when performing partial finetuning, which is updating only
121 /// a small fraction of the parameters instead of finetuning all of them.
122 fn no_grad(self) -> Self {
123 module!(
124 map = self,
125 ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
126 )
127 }
128
129 /// Get the number of parameters the module has, including all of its sub-modules.
130 fn num_params(&self) -> usize {
131 module!(
132 visit_float = self,
133 ops = |tensor: &Tensor<B, D>, state: &mut usize| {
134 *state += tensor.shape().num_elements();
135 },
136 state = usize,
137 init = || 0
138 )
139 }
140 /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
141 fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
142
143 /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
144 fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
145
146 /// Load the module state from a record.
147 fn load_record(self, record: Self::Record) -> Self;
148
149 /// Convert the module into a record containing the state.
150 fn into_record(self) -> Self::Record;
151
152 #[cfg(feature = "std")]
153 /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
154 ///
155 /// List of supported file recorders:
156 ///
157 /// * [default](crate::record::DefaultFileRecorder)
158 /// * [bincode](crate::record::BinFileRecorder)
159 /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
160 /// * [json pretty](crate::record::PrettyJsonFileRecorder)
161 /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
162 /// * [named mpk](crate::record::NamedMpkFileRecorder)
163 /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
164 ///
165 /// ## Notes
166 ///
167 /// The file extension is automatically added depending on the file recorder provided, you
168 /// don't have to specify it.
169 fn save_file<FR, PB>(
170 self,
171 file_path: PB,
172 recorder: &FR,
173 ) -> Result<(), crate::record::RecorderError>
174 where
175 FR: crate::record::FileRecorder<B>,
176 PB: Into<std::path::PathBuf>,
177 {
178 let record = Self::into_record(self);
179 recorder.record(record, file_path.into())
180 }
181
182 #[cfg(feature = "std")]
183 /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
184 ///
185 /// The recorder should be the same as the one used to save the module, see
186 /// [save_file](Self::save_file).
187 ///
188 /// ## Notes
189 ///
190 /// The file extension is automatically added depending on the file recorder provided, you
191 /// don't have to specify it.
192 fn load_file<FR, PB>(
193 self,
194 file_path: PB,
195 recorder: &FR,
196 device: &B::Device,
197 ) -> Result<Self, crate::record::RecorderError>
198 where
199 FR: crate::record::FileRecorder<B>,
200 PB: Into<std::path::PathBuf>,
201 {
202 let record = recorder.load(file_path.into(), device)?;
203
204 Ok(self.load_record(record))
205 }
206
207 /// Quantize the weights of the module.
208 fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
209 self.map(quantizer)
210 }
211}
212
213/// Module visitor trait for traversing and inspecting module parameters.
214pub trait ModuleVisitor<B: Backend> {
215 /// Visit a float parameter in the module.
216 ///
217 /// # Parameters
218 /// - `param`: The float parameter to visit
219 #[allow(unused_variables)]
220 fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}
221
222 /// Visit an int parameter in the module.
223 ///
224 /// # Parameters
225 /// - `param`: The integer parameter to visit
226 #[allow(unused_variables)]
227 fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}
228
229 /// Visit a bool parameter in the module.
230 ///
231 /// # Parameters
232 /// - `param`: The boolean parameter to visit
233 #[allow(unused_variables)]
234 fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}
235
236 /// Called when entering a submodule.
237 ///
238 /// # Parameters
239 /// - `name`: The name of the submodule being entered
240 /// - `container_type`: The type of the container with format:
241 /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
242 /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
243 /// - For Vec containers: "Vec" (name is the index)
244 /// - For Tuple containers: "Tuple" (name is the index)
245 /// - For Array containers: "Array" (name is the index)
246 ///
247 /// Note: Option containers do not call enter_module/exit_module to preserve
248 /// the field name in the path (e.g., "bias" instead of "bias.Some")
249 #[allow(unused_variables)]
250 fn enter_module(&mut self, name: &str, container_type: &str) {}
251
252 /// Called when exiting a submodule.
253 ///
254 /// # Parameters
255 /// - `name`: The name of the submodule being exited
256 /// - `container_type`: The type of the container with format:
257 /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
258 /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
259 /// - For Vec containers: "Vec" (name is the index)
260 /// - For Tuple containers: "Tuple" (name is the index)
261 /// - For Array containers: "Array" (name is the index)
262 ///
263 /// Note: Option containers do not call enter_module/exit_module to preserve
264 /// the field name in the path (e.g., "bias" instead of "bias.Some")
265 #[allow(unused_variables)]
266 fn exit_module(&mut self, name: &str, container_type: &str) {}
267
268 /// Visit a float tensor with its full module path.
269 ///
270 /// # Parameters
271 /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
272 /// Each element represents a module name in the hierarchy, with the final element
273 /// being the parameter name. This allows efficient reuse of the path stack.
274 /// - `id`: The unique identifier of the parameter
275 /// - `tensor`: The float tensor to visit
276 #[allow(unused_variables)]
277 fn visit_float_with_path<const D: usize>(
278 &mut self,
279 path: &[String],
280 id: ParamId,
281 tensor: &Tensor<B, D>,
282 ) {
283 }
284
285 /// Visit an int tensor with its full module path.
286 ///
287 /// # Parameters
288 /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
289 /// Each element represents a module name in the hierarchy, with the final element
290 /// being the parameter name. This allows efficient reuse of the path stack.
291 /// - `id`: The unique identifier of the parameter
292 /// - `tensor`: The integer tensor to visit
293 #[allow(unused_variables)]
294 fn visit_int_with_path<const D: usize>(
295 &mut self,
296 path: &[String],
297 id: ParamId,
298 tensor: &Tensor<B, D, Int>,
299 ) {
300 }
301
302 /// Visit a bool tensor with its full module path.
303 ///
304 /// # Parameters
305 /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
306 /// Each element represents a module name in the hierarchy, with the final element
307 /// being the parameter name. This allows efficient reuse of the path stack.
308 /// - `id`: The unique identifier of the parameter
309 /// - `tensor`: The boolean tensor to visit
310 #[allow(unused_variables)]
311 fn visit_bool_with_path<const D: usize>(
312 &mut self,
313 path: &[String],
314 id: ParamId,
315 tensor: &Tensor<B, D, Bool>,
316 ) {
317 }
318}
319
320/// Module mapper trait for transforming module parameters.
321pub trait ModuleMapper<B: Backend> {
322 /// Called when entering a submodule.
323 ///
324 /// # Parameters
325 /// - `name`: The name of the submodule being entered
326 /// - `container_type`: The type of the container with format:
327 /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
328 /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
329 /// - For Vec containers: "Vec" (name is the index)
330 /// - For Tuple containers: "Tuple" (name is the index)
331 /// - For Array containers: "Array" (name is the index)
332 ///
333 /// Note: Option containers do not call enter_module/exit_module to preserve
334 /// the field name in the path (e.g., "bias" instead of "bias.Some")
335 #[allow(unused_variables)]
336 fn enter_module(&mut self, name: &str, container_type: &str) {}
337
338 /// Called when exiting a submodule.
339 ///
340 /// # Parameters
341 /// - `name`: The name of the submodule being exited
342 /// - `container_type`: The type of the container with format:
343 /// - For user-defined structs: "Struct:TypeName" (e.g., "Struct:Linear")
344 /// - For user-defined enums: "Enum:TypeName" (e.g., "Enum:MyEnum")
345 /// - For Vec containers: "Vec" (name is the index)
346 /// - For Tuple containers: "Tuple" (name is the index)
347 /// - For Array containers: "Array" (name is the index)
348 ///
349 /// Note: Option containers do not call enter_module/exit_module to preserve
350 /// the field name in the path (e.g., "bias" instead of "bias.Some")
351 #[allow(unused_variables)]
352 fn exit_module(&mut self, name: &str, container_type: &str) {}
353
354 /// Map a float parameter in the module.
355 ///
356 /// # Parameters
357 /// - `param`: The float parameter to transform
358 ///
359 /// # Returns
360 /// The transformed parameter
361 #[allow(unused_variables)]
362 fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
363 let (id, tensor, mapper) = param.consume();
364 Param::from_mapped_value(id, tensor, mapper)
365 }
366
367 /// Map an int parameter in the module.
368 ///
369 /// # Parameters
370 /// - `param`: The integer parameter to transform
371 ///
372 /// # Returns
373 /// The transformed parameter
374 #[allow(unused_variables)]
375 fn map_int<const D: usize>(
376 &mut self,
377 param: Param<Tensor<B, D, Int>>,
378 ) -> Param<Tensor<B, D, Int>> {
379 let (id, tensor, mapper) = param.consume();
380 Param::from_mapped_value(id, tensor, mapper)
381 }
382
383 /// Map a bool parameter in the module.
384 ///
385 /// # Parameters
386 /// - `param`: The boolean parameter to transform
387 ///
388 /// # Returns
389 /// The transformed parameter
390 #[allow(unused_variables)]
391 fn map_bool<const D: usize>(
392 &mut self,
393 param: Param<Tensor<B, D, Bool>>,
394 ) -> Param<Tensor<B, D, Bool>> {
395 let (id, tensor, mapper) = param.consume();
396 Param::from_mapped_value(id, tensor, mapper)
397 }
398}
399
400/// Module with auto-differentiation backend.
401pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
402 /// Inner module without auto-differentiation.
403 type InnerModule: Module<B::InnerBackend>;
404
405 /// Get the same module, but on the inner backend without auto-differentiation.
406 fn valid(&self) -> Self::InnerModule;
407}