burn_core/module/param/base.rs
1use super::ParamId;
2use alloc::boxed::Box;
3use alloc::format;
4use burn_common::stub::RwLock;
5use core::cell::OnceCell;
6use core::ops::Deref;
7
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
/// training, and loaded during inference. If you don't want to save the tensors with a record
/// and/or don't want to update them during training, you don't need this type to wrap your tensor.
///
/// # Laziness
///
/// The initialization of parameters can be lazy when created using
/// [uninitialized](Self::uninitialized), which can be done using an [initializer](crate::nn::Initializer).
///
/// This reduces the amount of allocations done when loading a model for inference without having
/// to create a custom initialization function only for inference.
///
/// ## Example
///
/// ```rust, ignore
/// let device = Device::default();
/// let config = ModuleConfig::default();
/// let record = Recorder::new().load("/path/to/module", &device);
///
/// // No tensor allocation
/// let module = config.init(device);
/// // Will use the tensor allocated for the record if the same device is used.
/// let module = module.load_record(record);
/// ```
pub struct Param<T: Parameter> {
    /// The unique ID of this parameter. This is used by e.g. optimizers to associate a gradient
    /// with a specific parameter.
    pub id: ParamId,
    /// The parameter value. The cell is filled at construction for initialized parameters and
    /// on first access for lazy ones.
    state: OnceCell<T>,
    /// The pending lazy initializer, if any: `None` when the parameter was created already
    /// initialized, otherwise a lock around the initializer, which is taken out (leaving `None`
    /// inside the lock) when the value is first computed.
    ///
    /// The locking is only required because of `lazy_device` and `lazy_is_require_grad`.
    ///
    /// Because of once cell, we have a guarantee that the initialization will only be called once,
    /// but it may be called at the same time as `lazy_device` and `lazy_is_require_grad`, which is
    /// when the lock is actually useful, waiting for the initialization to be completed before
    /// returning the value.
    initialization: Option<RwLock<Option<Uninitialized<T>>>>,
}
45
46impl<T: Parameter> core::fmt::Display for Param<T> {
47    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
48        f.write_str(format!("Param: {}", self.id).as_str())
49    }
50}
51
52impl<T: Parameter> core::fmt::Debug for Param<T> {
53    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
54        f.write_str(format!("Param: {}", self.id).as_str())
55    }
56}
57
/// Trait that defines what is necessary for a type to be a parameter.
///
/// Implementors must be cloneable, debuggable and sendable across threads, and must expose the
/// device they live on as well as their gradient requirement.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type to be used.
    type Device: Clone;

    /// Fetch the device.
    fn device(&self) -> Self::Device;

    /// Fetch the gradient requirement.
    fn is_require_grad(&self) -> bool;

    /// Set the gradient requirement.
    ///
    /// Consumes the value and returns the updated one.
    fn set_require_grad(self, require_grad: bool) -> Self;
}
72
/// Everything needed to build a parameter value later: the initializer itself and the arguments
/// it will be called with.
// NOTE(review): the lint is presumably silenced because of the boxed `FnOnce` field type — confirm.
#[allow(clippy::type_complexity)]
struct Uninitialized<P: Parameter> {
    /// One-shot initializer; called with the device and the gradient requirement stored below.
    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
    /// Device passed to `init`.
    device: P::Device,
    /// Gradient requirement passed to `init`.
    is_require_grad: bool,
}
79
80impl<P: Parameter> Uninitialized<P> {
81    fn initialize(self) -> P {
82        let init = self.init;
83        init(&self.device, self.is_require_grad)
84    }
85}
86
87impl<T: Parameter> Param<T> {
88    /// Create a new parameter that is already initialized.
89    pub fn initialized(id: ParamId, value: T) -> Self {
90        Self {
91            id,
92            state: OnceCell::from(value),
93            initialization: None,
94        }
95    }
96
97    /// Create a new parameter that is not already initialized.
98    pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
99    where
100        F: FnOnce(&T::Device, bool) -> T + Send + 'static,
101    {
102        Self {
103            id,
104            state: OnceCell::new(),
105            initialization: Some(RwLock::new(Some(Uninitialized {
106                init: Box::new(init),
107                device,
108                is_require_grad,
109            }))),
110        }
111    }
112
113    /// Gets the parameter value.
114    pub fn val(&self) -> T {
115        self.state
116            .get_or_init(|| {
117                let mut result = self
118                    .initialization
119                    .as_ref()
120                    .expect("Should have an initialization when no state provided.")
121                    .write()
122                    .unwrap();
123                let state = result.take().expect("Should exist when not initialized");
124                state.initialize()
125            })
126            .clone()
127    }
128
129    /// Gets the parameter's value while consuming the parameter.
130    pub fn into_value(self) -> T {
131        self.consume().1
132    }
133
134    /// Gets the parameter id and value while consuming the parameter.
135    pub fn consume(self) -> (ParamId, T) {
136        let tensor = self.val();
137
138        core::mem::drop(self.state);
139
140        (self.id, tensor)
141    }
142
143    /// Execute the given function on the inner value.
144    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
145        let (id, tensor) = self.consume();
146        let tensor = func(tensor);
147
148        Self {
149            id,
150            state: OnceCell::from(tensor),
151            initialization: None,
152        }
153    }
154
155    /// The device on which the parameter is or will be initialized.
156    ///
157    /// This should be used instead of [crate::tensor::Tensor::device], since using the tensor
158    /// function requires a dereference, which triggers the initialization. This is only useful
159    /// when the device is used for updating the tensor value, which has potentially not been
160    /// initialized yet, like loading a record.
161    ///
162    /// # Notes
163    ///
164    /// This is a crate-private function, since users are not expected to use the device of an
165    /// uninitialized module to then override its value. All low-level functions should be provided
166    /// by `burn` and should handle those details.
167    pub(crate) fn lazy_device(&self) -> T::Device {
168        let initialization = match &self.initialization {
169            Some(init) => init,
170            None => return self.device(),
171        };
172
173        let init = initialization.read().unwrap();
174
175        match init.as_ref() {
176            Some(value) => value.device.clone(),
177            None => self.device(),
178        }
179    }
180
181    /// The gradient requirement on which the parameter is or will be initialized.
182    ///
183    /// This should be used instead of [crate::tensor::Tensor::is_require_grad], since using the tensor
184    /// function requires a dereference, which triggers the initialization. This is only useful
185    /// when the boolean is used for updating the tensor value, which has potentially not been
186    /// initialized yet, like loading a record.
187    ///
188    /// # Notes
189    ///
190    /// This is a crate-private function, since users are not expected to use `is_require_grad` of an
191    /// uninitialized module to then override its value. All low-level functions should be provided
192    /// by `burn` and should handle those details.
193    pub(crate) fn lazy_is_require_grad(&self) -> bool {
194        let initialization = match &self.initialization {
195            Some(init) => init,
196            None => return self.is_require_grad(),
197        };
198
199        let init = initialization.read().unwrap();
200
201        match init.as_ref() {
202            Some(value) => value.is_require_grad,
203            None => self.is_require_grad(),
204        }
205    }
206
207    /// Override the gradient requirement for the current parameter.
208    pub fn set_require_grad(self, require_grad: bool) -> Self {
209        let initialization = match &self.initialization {
210            Some(init) => init,
211            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
212        };
213
214        let mut init = initialization.write().unwrap();
215        let mut is_lazy = false;
216
217        if let Some(value) = init.as_mut() {
218            is_lazy = true;
219            value.is_require_grad = require_grad;
220        };
221
222        core::mem::drop(init);
223
224        if is_lazy {
225            return self;
226        }
227
228        self.map(|tensor| tensor.set_require_grad(require_grad))
229    }
230}
231
232impl<T: Parameter> Clone for Param<T> {
233    fn clone(&self) -> Self {
234        Param::initialized(self.id, self.val())
235    }
236}
237
238impl<T: Parameter> Deref for Param<T> {
239    type Target = T;
240
241    fn deref(&self) -> &Self::Target {
242        self.state.get_or_init(|| {
243            let mut result = self
244                .initialization
245                .as_ref()
246                .expect("Should have an initialization when no state provided.")
247                .write()
248                .unwrap();
249
250            let state = result.take().expect("Should exist when not initialized");
251            state.initialize()
252        })
253    }
254}