//! burn_core/module/base.rs
1use super::{Param, ParamId, Quantizer};
2use crate::{
3 record::Record,
4 tensor::backend::{AutodiffBackend, Backend},
5};
6use alloc::{string::String, vec::Vec};
7pub use burn_derive::Module;
8use burn_tensor::{Bool, Int, Tensor, ops::Device};
9
10/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
11/// the `alloc` crate.
12pub type Devices<B> = Vec<Device<B>>;
13
14// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
15// We may consider making it public in the future.
macro_rules! module {
    // `map` variant: applies `ops` (a `Tensor -> Tensor` closure) to every float
    // parameter of `$module`, returning the transformed module.
    (map=$module:ident, ops=$item:expr) => {{
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                param: Param<Tensor<B, D>>,
            ) -> Param<Tensor<B, D>> {
                // Take the param apart so the transformed tensor is rebuilt with
                // its original id and mapper.
                let (id, tensor, mapper) = param.consume();
                let func = $item;
                let tensor = func(tensor);
                Param::from_mapped_value(id, tensor, mapper)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    // `visit_float` variant: folds every float parameter into a mutable state of
    // type `$state_ty`, seeded by `init`, and evaluates to the final state.
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {
                let func = $item;
                // Fix: `&param` had been corrupted to `¶m` (mojibake of the
                // HTML entity `&para;`), which does not compile.
                func(&param.val(), &mut self.state)
            }
        }
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
54
/// Trait for all neural network modules.
///
/// Modules should be created using the [derive](burn_derive::Module) attribute.
/// This will make your module trainable, savable and loadable via
/// `state` and `load`.
///
/// # Example
///
/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
/// necessary to optimize and train the module on any backend.
///
/// ```rust, ignore
/// // Not necessary when using the burn crate directly.
/// use burn_core as burn;
///
/// use burn::{
///     module::Module,
///     nn::Linear,
///     tensor::Tensor,
///     tensor::backend::Backend,
/// };
///
/// #[derive(Module, Debug)]
/// struct MyModule<B: Backend> {
///     my_param: Linear<B>,
///     my_other_field: usize,
/// }
/// ```
pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
    /// Type to save and load the module.
    type Record: Record<B>;

    /// Return all the devices found in the underneath module tree added to the given vector
    /// without duplicates.
    fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;

    /// Return all the devices found in the underneath module tree without duplicates.
    fn devices(&self) -> Devices<B> {
        self.collect_devices(Devices::<B>::new())
    }

    /// Fork the module and all of its sub-modules to the given device.
    ///
    /// # Notes
    ///
    /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
    /// new device will have its own autodiff graph.
    fn fork(self, device: &B::Device) -> Self;

    /// Move the module and all of its sub-modules to the given device.
    ///
    /// # Warnings
    ///
    /// The operation supports autodiff and it will be registered when activated. However, this may
    /// not be what you want. The output model will be an intermediary model, meaning that you
    /// can't optimize it with gradient descent. If you want to optimize the output network on the
    /// target device, use [fork](Module::fork) instead.
    fn to_device(self, device: &B::Device) -> Self;

    /// Each tensor in the module tree will not require grad.
    ///
    /// # Warnings
    ///
    /// This should not be used for inference, use [valid](AutodiffModule::valid) when using
    /// AD modules. This is mostly useful when performing partial finetuning, which is updating only
    /// a small fraction of the parameters instead of finetuning all of them.
    fn no_grad(self) -> Self {
        // Maps every float parameter through `set_require_grad(false)`.
        module!(
            map = self,
            ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
        )
    }

    /// Get the number of parameters the module has, including all of its sub-modules.
    ///
    /// # Notes
    ///
    /// Only float parameters are counted: the underlying visitor implements
    /// `visit_float` only, so int and bool parameters do not contribute.
    fn num_params(&self) -> usize {
        module!(
            visit_float = self,
            ops = |tensor: &Tensor<B, D>, state: &mut usize| {
                *state += tensor.shape().num_elements();
            },
            state = usize,
            init = || 0
        )
    }

    /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
    fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);

    /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
    fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;

    /// Load the module state from a record.
    fn load_record(self, record: Self::Record) -> Self;

    /// Convert the module into a record containing the state.
    fn into_record(self) -> Self::Record;

    #[cfg(feature = "std")]
    /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// List of supported file recorders:
    ///
    /// * [default](crate::record::DefaultFileRecorder)
    /// * [bincode](crate::record::BinFileRecorder)
    /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
    /// * [json pretty](crate::record::PrettyJsonFileRecorder)
    /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
    /// * [named mpk](crate::record::NamedMpkFileRecorder)
    /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided, you
    /// don't have to specify it. This consumes the module: it is converted into its
    /// [record](Module::into_record) before being written.
    fn save_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
    ) -> Result<(), crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = Self::into_record(self);
        recorder.record(record, file_path.into())
    }

    #[cfg(feature = "std")]
    /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
    ///
    /// The recorder should be the same as the one used to save the module, see
    /// [save_file](Self::save_file).
    ///
    /// ## Notes
    ///
    /// The file extension is automatically added depending on the file recorder provided, you
    /// don't have to specify it. The `device` is forwarded to the recorder when
    /// deserializing the record.
    fn load_file<FR, PB>(
        self,
        file_path: PB,
        recorder: &FR,
        device: &B::Device,
    ) -> Result<Self, crate::record::RecorderError>
    where
        FR: crate::record::FileRecorder<B>,
        PB: Into<std::path::PathBuf>,
    {
        let record = recorder.load(file_path.into(), device)?;

        Ok(self.load_record(record))
    }

    /// Quantize the weights of the module.
    ///
    /// The [`Quantizer`] acts as a [mapper](ModuleMapper) applied over the whole
    /// parameter tree.
    fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
        self.map(quantizer)
    }
}
212
/// Module visitor trait for traversing and inspecting module parameters.
///
/// Every method has a no-op default implementation, so implementors only need to
/// override the hooks relevant to them (e.g. just [`visit_float`](Self::visit_float)).
pub trait ModuleVisitor<B: Backend> {
    /// Visit a float parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The float parameter to visit
    #[allow(unused_variables)]
    fn visit_float<const D: usize>(&mut self, param: &Param<Tensor<B, D>>) {}

    /// Visit an int parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The integer parameter to visit
    #[allow(unused_variables)]
    fn visit_int<const D: usize>(&mut self, param: &Param<Tensor<B, D, Int>>) {}

    /// Visit a bool parameter in the module.
    ///
    /// # Parameters
    /// - `param`: The boolean parameter to visit
    #[allow(unused_variables)]
    fn visit_bool<const D: usize>(&mut self, param: &Param<Tensor<B, D, Bool>>) {}

    /// Called when entering a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being entered
    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
    #[allow(unused_variables)]
    fn enter_module(&mut self, name: &str, container_type: &str) {}

    /// Called when exiting a submodule.
    ///
    /// # Parameters
    /// - `name`: The name of the submodule being exited
    /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
    #[allow(unused_variables)]
    fn exit_module(&mut self, name: &str, container_type: &str) {}

    /// Visit a float tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The float tensor to visit
    #[allow(unused_variables)]
    fn visit_float_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D>,
    ) {
    }

    /// Visit an int tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The integer tensor to visit
    #[allow(unused_variables)]
    fn visit_int_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Int>,
    ) {
    }

    /// Visit a bool tensor with its full module path.
    ///
    /// # Parameters
    /// - `path`: The path components to the tensor as a slice (e.g., &["encoder", "layer1", "weight"]).
    ///   Each element represents a module name in the hierarchy, with the final element
    ///   being the parameter name. This allows efficient reuse of the path stack.
    /// - `id`: The unique identifier of the parameter
    /// - `tensor`: The boolean tensor to visit
    #[allow(unused_variables)]
    fn visit_bool_with_path<const D: usize>(
        &mut self,
        path: &[String],
        id: ParamId,
        tensor: &Tensor<B, D, Bool>,
    ) {
    }
}
303
304/// Module mapper trait for transforming module parameters.
305pub trait ModuleMapper<B: Backend> {
306 /// Called when entering a submodule.
307 ///
308 /// # Parameters
309 /// - `name`: The name of the submodule being entered
310 /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
311 #[allow(unused_variables)]
312 fn enter_module(&mut self, name: &str, container_type: &str) {}
313
314 /// Called when exiting a submodule.
315 ///
316 /// # Parameters
317 /// - `name`: The name of the submodule being exited
318 /// - `container_type`: The type of the container (e.g., "Module", "Vec", etc.)
319 #[allow(unused_variables)]
320 fn exit_module(&mut self, name: &str, container_type: &str) {}
321
322 /// Map a float parameter in the module.
323 ///
324 /// # Parameters
325 /// - `param`: The float parameter to transform
326 ///
327 /// # Returns
328 /// The transformed parameter
329 #[allow(unused_variables)]
330 fn map_float<const D: usize>(&mut self, param: Param<Tensor<B, D>>) -> Param<Tensor<B, D>> {
331 let (id, tensor, mapper) = param.consume();
332 Param::from_mapped_value(id, tensor, mapper)
333 }
334
335 /// Map an int parameter in the module.
336 ///
337 /// # Parameters
338 /// - `param`: The integer parameter to transform
339 ///
340 /// # Returns
341 /// The transformed parameter
342 #[allow(unused_variables)]
343 fn map_int<const D: usize>(
344 &mut self,
345 param: Param<Tensor<B, D, Int>>,
346 ) -> Param<Tensor<B, D, Int>> {
347 let (id, tensor, mapper) = param.consume();
348 Param::from_mapped_value(id, tensor, mapper)
349 }
350
351 /// Map a bool parameter in the module.
352 ///
353 /// # Parameters
354 /// - `param`: The boolean parameter to transform
355 ///
356 /// # Returns
357 /// The transformed parameter
358 #[allow(unused_variables)]
359 fn map_bool<const D: usize>(
360 &mut self,
361 param: Param<Tensor<B, D, Bool>>,
362 ) -> Param<Tensor<B, D, Bool>> {
363 let (id, tensor, mapper) = param.consume();
364 Param::from_mapped_value(id, tensor, mapper)
365 }
366}
367
/// Module with auto-differentiation backend.
pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
    /// Inner module without auto-differentiation.
    type InnerModule: Module<B::InnerBackend>;

    /// Get the same module, but on the inner backend without auto-differentiation.
    ///
    /// Typically used to obtain an inference/validation view of a trainable module.
    fn valid(&self) -> Self::InnerModule;
}