burn_core/module/base.rs
1use super::{ParamId, Quantizer};
2use crate::{
3 record::Record,
4 tensor::backend::{AutodiffBackend, Backend},
5};
6use alloc::vec::Vec;
7pub use burn_derive::Module;
8use burn_tensor::{Bool, Int, Tensor, ops::Device};
9
10/// Type alias to `Vec<B::Device>` which supports `no_std` environments, but automatically using
11/// the `alloc` crate.
12pub type Devices<B> = Vec<Device<B>>;
13
// At the moment, our plan is to continue experimenting with the macro internally and monitor its development.
// We may consider making it public in the future.
//
// Internal helper that builds a throwaway `ModuleMapper` or `ModuleVisitor` from a closure and
// runs it over a module. Two forms are supported:
//
// * `module!(map = <module>, ops = <closure>)` evaluates to `<module>.map(..)` with a mapper whose
//   `map_float` applies `<closure>` to every float tensor.
// * `module!(visit_float = <module>, ops = <closure>, state = <type>, init = <init>)` creates a
//   state value with `<init>()`, calls `<closure>(tensor, state)` for every float tensor visited by
//   `<module>.visit(..)`, and evaluates to the final state.
//
// The closure tokens are expanded inside the generated `map_float` / `visit_float` methods. A
// closure can therefore name `B` (the generated impl's backend parameter) and `D` (the method's
// const rank parameter) in its argument types, as in `|tensor: Tensor<B, D>| ...`. Renaming those
// generic parameters below would break every call site.
//
// The expansion also refers to `ModuleMapper`, `ModuleVisitor`, `ParamId`, `Backend` and `Tensor`
// unqualified, so those names must be in scope wherever the macro is invoked.
macro_rules! module {
    // Mapper form: evaluates to the module returned by `$module.map(..)`.
    (map=$module:ident, ops=$item:expr) => {{
        // Zero-sized mapper. Only `map_float` is overridden, so int and bool tensors go through the
        // `ModuleMapper` default methods, which return them unchanged.
        struct Mapper;
        impl<B: Backend> ModuleMapper<B> for Mapper {
            fn map_float<const D: usize>(
                &mut self,
                _id: ParamId,
                tensor: Tensor<B, D>,
            ) -> Tensor<B, D> {
                // `$item` is expanded here, inside the generic method, which is what lets the
                // closure's argument type mention this impl's `B` and this method's `D`.
                let func = $item;
                func(tensor)
            }
        }
        let mut mapper = Mapper;
        $module.map(&mut mapper)
    }};
    // Visitor form: folds every float tensor into a state value and evaluates to that state.
    (visit_float=$module:ident, ops=$item:expr, state=$state_ty:ty, init=$init:expr) => {{
        // Mutably borrows the `state` local created below for the duration of the visit. The
        // `PhantomData` field is what makes the otherwise unused backend parameter `B` legal.
        struct Visitor<'a, B: Backend> {
            state: &'a mut $state_ty,
            backend: core::marker::PhantomData<B>,
        }
        impl<'a, B: Backend> ModuleVisitor<B> for Visitor<'a, B> {
            // Only float tensors are visited; `visit_int` / `visit_bool` keep their no-op defaults.
            fn visit_float<const D: usize>(&mut self, _id: ParamId, tensor: &Tensor<B, D>) {
                let func = $item;
                func(tensor, &mut self.state)
            }
        }
        // `$init` is expected to be a closure (e.g. `|| 0`) that is invoked immediately, which
        // clippy reports as a redundant closure call.
        #[allow(clippy::redundant_closure_call)]
        let mut state = $init();
        let mut visitor = Visitor {
            state: &mut state,
            backend: core::marker::PhantomData,
        };
        $module.visit(&mut visitor);
        state
    }};
}
53
54/// Trait for all neural network modules.
55///
56/// Modules should be created using the [derive](burn_derive::Module) attribute.
57/// This will make your module trainable, savable and loadable via
58/// `state` and `load`.
59///
60/// # Example
61///
62/// A module should have a [backend](crate::tensor::backend::Backend) defined as a generic
63/// parameter B. This will be used by the [derive](burn_derive::Module) attribute to generate the code
64/// necessary to optimize and train the module on any backend.
65///
66/// ```no_run
67/// // Not necessary when using the burn crate directly.
68/// use burn_core as burn;
69///
70/// use burn::{
71/// nn,
72/// module::Module,
73/// tensor::Tensor,
74/// tensor::backend::Backend,
75/// };
76///
77/// #[derive(Module, Debug)]
78/// struct MyModule<B: Backend> {
79/// my_param: nn::Linear<B>,
80/// my_other_field: usize,
81/// }
82/// ```
83pub trait Module<B: Backend>: Clone + Send + core::fmt::Debug {
84 /// Type to save and load the module.
85 type Record: Record<B>;
86
87 /// Return all the devices found in the underneath module tree added to the given vector
88 /// without duplicates.
89 fn collect_devices(&self, devices: Devices<B>) -> Devices<B>;
90
91 /// Return all the devices found in the underneath module tree without duplicates.
92 fn devices(&self) -> Devices<B> {
93 self.collect_devices(Devices::<B>::new())
94 }
95
96 /// Fork the module and all of its sub-modules to the given device.
97 ///
98 /// # Notes
99 ///
100 /// This is similar to [to_device](Module::to_device), but it ensures the output module on the
101 /// new device will have its own autodiff graph.
102 fn fork(self, device: &B::Device) -> Self;
103
104 /// Move the module and all of its sub-modules to the given device.
105 ///
106 /// # Warnings
107 ///
108 /// The operation supports autodiff and it will be registered when activated. However, this may
109 /// not be what you want. The output model will be an intermediary model, meaning that you
110 /// can't optimize it with gradient descent. If you want to optimize the output network on the
111 /// target device, use [fork](Module::fork) instead.
112 fn to_device(self, device: &B::Device) -> Self;
113
114 /// Each tensor in the module tree will not require grad.
115 ///
116 /// # Warnings
117 ///
118 /// This should not be used for inference, use [valid](AutodiffModule::valid) when using
119 /// AD modules. This is mostly useful when performing partial finetuning, which is updating only
120 /// a small fraction of the parameters instead of finetuning all of them.
121 fn no_grad(self) -> Self {
122 module!(
123 map = self,
124 ops = |tensor: Tensor<B, D>| tensor.set_require_grad(false)
125 )
126 }
127
128 /// Get the number of parameters the module has, including all of its sub-modules.
129 fn num_params(&self) -> usize {
130 module!(
131 visit_float = self,
132 ops = |tensor: &Tensor<B, D>, state: &mut usize| {
133 *state += tensor.shape().num_elements();
134 },
135 state = usize,
136 init = || 0
137 )
138 }
139 /// Visit each tensor parameter in the module with a [visitor](ModuleVisitor).
140 fn visit<Visitor: ModuleVisitor<B>>(&self, visitor: &mut Visitor);
141
142 /// Map each tensor parameter in the module with a [mapper](ModuleMapper).
143 fn map<Mapper: ModuleMapper<B>>(self, mapper: &mut Mapper) -> Self;
144
145 /// Load the module state from a record.
146 fn load_record(self, record: Self::Record) -> Self;
147
148 /// Convert the module into a record containing the state.
149 fn into_record(self) -> Self::Record;
150
151 #[cfg(feature = "std")]
152 /// Save the module to a file using the provided [file recorder](crate::record::FileRecorder).
153 ///
154 /// List of supported file recorders:
155 ///
156 /// * [default](crate::record::DefaultFileRecorder)
157 /// * [bincode](crate::record::BinFileRecorder)
158 /// * [bincode compressed with gzip](crate::record::BinGzFileRecorder)
159 /// * [json pretty](crate::record::PrettyJsonFileRecorder)
160 /// * [json compressed with gzip](crate::record::JsonGzFileRecorder)
161 /// * [named mpk](crate::record::NamedMpkFileRecorder)
162 /// * [named mpk compressed with gzip](crate::record::NamedMpkGzFileRecorder)
163 ///
164 /// ## Notes
165 ///
166 /// The file extension is automatically added depending on the file recorder provided, you
167 /// don't have to specify it.
168 fn save_file<FR, PB>(
169 self,
170 file_path: PB,
171 recorder: &FR,
172 ) -> Result<(), crate::record::RecorderError>
173 where
174 FR: crate::record::FileRecorder<B>,
175 PB: Into<std::path::PathBuf>,
176 {
177 let record = Self::into_record(self);
178 recorder.record(record, file_path.into())
179 }
180
181 #[cfg(feature = "std")]
182 /// Load the module from a file using the provided [file recorder](crate::record::FileRecorder).
183 ///
184 /// The recorder should be the same as the one used to save the module, see
185 /// [save_file](Self::save_file).
186 ///
187 /// ## Notes
188 ///
189 /// The file extension is automatically added depending on the file recorder provided, you
190 /// don't have to specify it.
191 fn load_file<FR, PB>(
192 self,
193 file_path: PB,
194 recorder: &FR,
195 device: &B::Device,
196 ) -> Result<Self, crate::record::RecorderError>
197 where
198 FR: crate::record::FileRecorder<B>,
199 PB: Into<std::path::PathBuf>,
200 {
201 let record = recorder.load(file_path.into(), device)?;
202
203 Ok(self.load_record(record))
204 }
205
206 /// Quantize the weights of the module.
207 fn quantize_weights(self, quantizer: &mut Quantizer) -> Self {
208 self.map(quantizer)
209 }
210}
211
212/// Module visitor trait.
213pub trait ModuleVisitor<B: Backend> {
214 /// Visit a float tensor in the module.
215 fn visit_float<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D>) {}
216 /// Visit an int tensor in the module.
217 fn visit_int<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Int>) {}
218 /// Visit a bool tensor in the module.
219 fn visit_bool<const D: usize>(&mut self, _id: ParamId, _tensor: &Tensor<B, D, Bool>) {}
220}
221
222/// Module mapper trait.
223pub trait ModuleMapper<B: Backend> {
224 /// Map a float tensor in the module.
225 fn map_float<const D: usize>(&mut self, _id: ParamId, tensor: Tensor<B, D>) -> Tensor<B, D> {
226 tensor
227 }
228 /// Map an int tensor in the module.
229 fn map_int<const D: usize>(
230 &mut self,
231 _id: ParamId,
232 tensor: Tensor<B, D, Int>,
233 ) -> Tensor<B, D, Int> {
234 tensor
235 }
236 /// Map a bool tensor in the module.
237 fn map_bool<const D: usize>(
238 &mut self,
239 _id: ParamId,
240 tensor: Tensor<B, D, Bool>,
241 ) -> Tensor<B, D, Bool> {
242 tensor
243 }
244}
245
246/// Module with auto-differentiation backend.
247pub trait AutodiffModule<B: AutodiffBackend>: Module<B> + Send + core::fmt::Debug {
248 /// Inner module without auto-differentiation.
249 type InnerModule: Module<B::InnerBackend>;
250
251 /// Get the same module, but on the inner backend without auto-differentiation.
252 fn valid(&self) -> Self::InnerModule;
253}