burn_core/module/param/base.rs
use super::ParamId;
use alloc::boxed::Box;
use alloc::format;
use burn_common::stub::RwLock;
use core::cell::OnceCell;
use core::ops::Deref;

/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
/// training and loaded during inference. If you don't want to save the tensors with a record
/// and/or don't want to update them during training, you don't need this type to wrap your tensor.
///
/// # Laziness
///
/// The initialization of parameters can be lazy when created using
/// [uninitialized](Self::uninitialized), which can be done using an [initializer](crate::nn::Initializer).
///
/// This reduces the number of allocations performed when loading a model for inference, without
/// requiring a custom initialization function just for inference.
///
/// ## Example
///
/// ```rust, ignore
/// let device = Device::default();
/// let config = ModuleConfig::default();
/// let record = Recorder::new().load("/path/to/module", &device);
///
/// // No tensor allocation
/// let module = config.init(device);
/// // Will use the tensor allocated for the record if the same device is used.
/// let module = module.load_record(record);
/// ```
pub struct Param<T: Parameter> {
    /// The unique ID of this parameter. This is used by, e.g., optimizers to associate a gradient with a specific parameter.
    pub id: ParamId,
    state: OnceCell<T>,
    /// The locking is only required because of `lazy_device` and `lazy_is_require_grad`.
    ///
    /// The once cell guarantees that the initialization closure runs at most once, but it may run
    /// concurrently with `lazy_device` and `lazy_is_require_grad`, which is when the lock is
    /// actually useful: those calls wait for the initialization to complete before returning the
    /// value.
    initialization: Option<RwLock<Option<Uninitialized<T>>>>,
}

impl<T: Parameter> core::fmt::Display for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("Param: {}", self.id).as_str())
    }
}

impl<T: Parameter> core::fmt::Debug for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("Param: {}", self.id).as_str())
    }
}

/// Trait that defines what is necessary for a type to be a parameter.
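///
/// The sketch below only shows the expected shape of an implementation; the `Ticket` type and its
/// trivial device are hypothetical and for illustration only (in `burn`, this trait is implemented
/// for tensors).
///
/// ```rust, ignore
/// #[derive(Clone, Debug)]
/// struct Ticket {
///     value: f32,
///     require_grad: bool,
/// }
///
/// impl Parameter for Ticket {
///     // A real implementation would use the backend's device type.
///     type Device = ();
///
///     fn device(&self) -> Self::Device {}
///
///     fn is_require_grad(&self) -> bool {
///         self.require_grad
///     }
///
///     fn set_require_grad(mut self, require_grad: bool) -> Self {
///         self.require_grad = require_grad;
///         self
///     }
/// }
/// ```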
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type to be used.
    type Device: Clone;

    /// Fetch the device.
    fn device(&self) -> Self::Device;

    /// Fetch the gradient requirement.
    fn is_require_grad(&self) -> bool;

    /// Set the gradient requirement.
    fn set_require_grad(self, require_grad: bool) -> Self;
}

#[allow(clippy::type_complexity)]
struct Uninitialized<P: Parameter> {
    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
    device: P::Device,
    is_require_grad: bool,
}

impl<P: Parameter> Uninitialized<P> {
    fn initialize(self) -> P {
        let init = self.init;
        init(&self.device, self.is_require_grad)
    }
}

impl<T: Parameter> Param<T> {
    /// Create a new parameter that is already initialized.
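    ///
    /// A minimal usage sketch, assuming a backend `B` and a `device` in scope (illustrative names
    /// only):
    ///
    /// ```rust, ignore
    /// let value = Tensor::<B, 2>::zeros([2, 2], &device);
    /// let param = Param::initialized(ParamId::new(), value);
    /// ```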
    pub fn initialized(id: ParamId, value: T) -> Self {
        Self {
            id,
            state: OnceCell::from(value),
            initialization: None,
        }
    }

    /// Create a new parameter that is not already initialized.
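    ///
    /// A sketch of deferring the allocation until the value is first used, assuming a backend `B`
    /// and a `device` in scope (illustrative names only):
    ///
    /// ```rust, ignore
    /// let param = Param::uninitialized(
    ///     ParamId::new(),
    ///     // Only runs if the value is actually needed, e.g. on the first dereference.
    ///     move |device, require_grad| {
    ///         Tensor::<B, 2>::zeros([2, 2], device).set_require_grad(require_grad)
    ///     },
    ///     device,
    ///     true,
    /// );
    /// ```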
    pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
    where
        F: FnOnce(&T::Device, bool) -> T + Send + 'static,
    {
        Self {
            id,
            state: OnceCell::new(),
            initialization: Some(RwLock::new(Some(Uninitialized {
                init: Box::new(init),
                device,
                is_require_grad,
            }))),
        }
    }

    /// Gets the parameter value.
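    ///
    /// Initializes the parameter if it is still lazy, and returns a clone of the inner value:
    ///
    /// ```rust, ignore
    /// let tensor = param.val();
    /// ```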
    pub fn val(&self) -> T {
        self.state
            .get_or_init(|| {
                let mut result = self
                    .initialization
                    .as_ref()
                    .expect("Should have an initialization when no state provided.")
                    .write()
                    .unwrap();
                let state = result.take().expect("Should exist when not initialized");
                state.initialize()
            })
            .clone()
    }

    /// Gets the parameter's value while consuming the parameter.
    pub fn into_value(self) -> T {
        self.consume().1
    }

    /// Gets the parameter id and value while consuming the parameter.
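    ///
    /// A small usage sketch (`param` is assumed to be a `Param` over a tensor):
    ///
    /// ```rust, ignore
    /// let (id, tensor) = param.consume();
    /// ```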
    pub fn consume(self) -> (ParamId, T) {
        let tensor = self.val();

        core::mem::drop(self.state);

        (self.id, tensor)
    }

    /// Execute the given function on the inner value.
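    ///
    /// For example, to move the inner value to another device while keeping the same id (a sketch,
    /// assuming the value is a tensor and a `device` is in scope):
    ///
    /// ```rust, ignore
    /// let param = param.map(|tensor| tensor.to_device(&device));
    /// ```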
    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
        let (id, tensor) = self.consume();
        let tensor = func(tensor);

        Self {
            id,
            state: OnceCell::from(tensor),
            initialization: None,
        }
    }

    /// The device on which the parameter is or will be initialized.
    ///
    /// This should be used instead of [crate::tensor::Tensor::device], since using the tensor
    /// function requires a dereference, which triggers the initialization. This is only useful
    /// when the device is needed to update the tensor value, which may not have been initialized
    /// yet, such as when loading a record.
    ///
    /// # Notes
    ///
    /// This is a crate-private function, since users are not expected to use the device of an
    /// uninitialized module to then override its value. All low-level functions should be provided
    /// by `burn` and should handle those details.
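    ///
    /// A sketch of the intended use when loading a record (`recorded_tensor` is hypothetical, for
    /// illustration only):
    ///
    /// ```rust, ignore
    /// // Move the recorded tensor to the parameter's target device without initializing it.
    /// let tensor = recorded_tensor.to_device(&param.lazy_device());
    /// ```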
    pub(crate) fn lazy_device(&self) -> T::Device {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.device(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.device.clone(),
            None => self.device(),
        }
    }

    /// The gradient requirement on which the parameter is or will be initialized.
    ///
    /// This should be used instead of [crate::tensor::Tensor::is_require_grad], since using the tensor
    /// function requires a dereference, which triggers the initialization. This is only useful
    /// when the boolean is needed to update the tensor value, which may not have been initialized
    /// yet, such as when loading a record.
    ///
    /// # Notes
    ///
    /// This is a crate-private function, since users are not expected to use `is_require_grad` of an
    /// uninitialized module to then override its value. All low-level functions should be provided
    /// by `burn` and should handle those details.
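    ///
    /// A sketch of the intended use when swapping the value (`new_tensor` is hypothetical, for
    /// illustration only):
    ///
    /// ```rust, ignore
    /// // Preserve the gradient requirement without initializing the parameter.
    /// let tensor = new_tensor.set_require_grad(param.lazy_is_require_grad());
    /// ```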
    pub(crate) fn lazy_is_require_grad(&self) -> bool {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.is_require_grad(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.is_require_grad,
            None => self.is_require_grad(),
        }
    }

    /// Override the gradient requirement for the current parameter.
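    ///
    /// If the parameter is still lazy, only the recorded requirement is updated; otherwise the
    /// inner value is mapped. For example, to freeze a parameter (a sketch):
    ///
    /// ```rust, ignore
    /// let param = param.set_require_grad(false);
    /// ```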
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
        };

        let mut init = initialization.write().unwrap();
        let mut is_lazy = false;

        if let Some(value) = init.as_mut() {
            is_lazy = true;
            value.is_require_grad = require_grad;
        };

        core::mem::drop(init);

        if is_lazy {
            return self;
        }

        self.map(|tensor| tensor.set_require_grad(require_grad))
    }
}

impl<T: Parameter> Clone for Param<T> {
    fn clone(&self) -> Self {
        Param::initialized(self.id, self.val())
    }
}

impl<T: Parameter> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.state.get_or_init(|| {
            let mut result = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.")
                .write()
                .unwrap();

            let state = result.take().expect("Should exist when not initialized");
            state.initialize()
        })
    }
}