burn_core/module/param/
base.rs

use super::ParamId;
use alloc::{boxed::Box, format};
use burn_common::stub::RwLock;
use burn_tensor::Shape;
use core::cell::OnceCell;
use core::ops::Deref;

#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;

#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;

#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;

#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;

#[cfg(target_has_atomic = "ptr")]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(func)
}

#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(Box::new(func))
}

/// Parameters are the fundamental building blocks of [modules](crate::module::Module), where they
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
/// training and loaded during inference. If you don't want to save the tensors
/// and/or don't want to update them during training, you don't need this type to wrap your tensor.
///
/// # Core Lazy Initialization Architecture
///
/// `Param<T>` has a dual-state design using `OnceCell<T>`:
///
/// ## State Management
///
/// **Two possible states:**
///
/// 1. **Initialized**: `state: OnceCell<T>` contains the value, `initialization: None`
/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock<Some(Uninitialized<T>)>)`
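///
/// # Examples
///
/// A minimal sketch of the two states (hypothetical usage with a toy `Parameter`
/// implementation; `ParamId::new()` and `Shape::from([1])` are assumed constructors):
///
/// ```rust,ignore
/// // Toy parameter type, defined only for illustration.
/// #[derive(Clone, Debug)]
/// struct Weight(f32);
///
/// impl Parameter for Weight {
///     type Device = ();
///     fn device(&self) -> Self::Device {}
///     fn is_require_grad(&self) -> bool {
///         true
///     }
///     fn set_require_grad(self, _require_grad: bool) -> Self {
///         self
///     }
/// }
///
/// // Initialized: the value is stored in the `OnceCell` right away.
/// let eager = Param::initialized(ParamId::new(), Weight(1.0));
/// assert!(eager.is_initialized());
///
/// // Uninitialized (lazy): the init closure only runs on first access.
/// let lazy = Param::uninitialized(
///     ParamId::new(),
///     |_device, _require_grad| Weight(0.0),
///     (),
///     true,
///     Shape::from([1]),
/// );
/// assert!(!lazy.is_initialized());
/// let _value = lazy.val(); // triggers initialization and caches the result
/// assert!(lazy.is_initialized());
/// ```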
pub struct Param<T: Parameter> {
    /// The unique ID of this parameter. This is used by e.g. optimizers to associate a gradient with a specific parameter.
    pub id: ParamId,
    /// The `OnceCell` holding the initialized parameter value.
    /// Empty for uninitialized parameters, populated after first access or explicit initialization.
    pub(crate) state: OnceCell<T>,
    /// The deferred initialization state for lazy parameters.
    ///
    /// **State Transitions:**
    /// - Initialized params: `None`
    /// - Uninitialized params: `Some(RwLock<Some(Uninitialized<T>)>)`
    /// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)
    pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
    pub(crate) param_mapper: ParamMapper<T>,
}

#[derive(Clone)]
/// Applies transformations when loading and saving parameters.
///
/// # Mapper System
///
/// `ParamMapper<T>` allows applying transformations during serialization and deserialization:
/// - `load: Option<Mapper<T>>` - transformation during deserialization (applied in `transform_for_load()`)
/// - `save: Option<Mapper<T>>` - transformation during serialization (applied in `transform_for_save()`)
///
/// These are commonly used for:
/// - Quantization/dequantization
/// - Precision conversion (e.g., FP32 ↔ FP16)
/// - Custom parameter transformations
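///
/// # Examples
///
/// A minimal sketch (hypothetical usage; it reuses the toy `Weight` parameter from the
/// [`Param`] example and stands in for what the (de)serialization path does):
///
/// ```rust,ignore
/// // Register save/load transformations on a parameter.
/// let param = Param::initialized(ParamId::new(), Weight(2.0))
///     .save_mapper(|w| Weight(w.0 * 0.5))
///     .load_mapper(|w| Weight(w.0 * 2.0));
///
/// // `consume` hands back the mapper along with the id and value; serialization code
/// // applies `on_save` before writing and `on_load` after reading.
/// let (_id, value, mapper) = param.consume();
/// assert_eq!(mapper.on_save(value.clone()).0, 1.0);
/// assert_eq!(mapper.on_load(value).0, 4.0);
/// ```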
pub struct ParamMapper<T: Parameter> {
    load: Option<Mapper<T>>,
    save: Option<Mapper<T>>,
}

impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!(
            "ParamMapper {{ load: {}, save: {} }}",
            self.load.is_some(),
            self.save.is_some()
        ))
    }
}

impl<T: Parameter> ParamMapper<T> {
    /// Applies the transformation when loading the given parameter.
    pub fn on_load(&self, param: T) -> T {
        match &self.load {
            Some(mapper) => mapper(param),
            None => param,
        }
    }

    /// Applies the transformation when saving the given parameter.
    pub fn on_save(&self, param: T) -> T {
        match &self.save {
            Some(mapper) => mapper(param),
            None => param,
        }
    }
}

impl<T: Parameter> Default for ParamMapper<T> {
    fn default() -> Self {
        Self {
            load: None,
            save: None,
        }
    }
}

impl<T: Parameter> core::fmt::Display for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("Param: {}", self.id).as_str())
    }
}

impl<T: Parameter> core::fmt::Debug for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(format!("Param: {} - {:?}", self.id, self.param_mapper).as_str())
    }
}

/// Trait that defines what is necessary for a type to be a parameter.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type to be used.
    type Device: Clone;

    /// Fetch the device.
    fn device(&self) -> Self::Device;

    /// Fetch the gradient requirement.
    fn is_require_grad(&self) -> bool;

    /// Set the gradient requirement.
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// The deferred initialization state for lazy parameters.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
    /// The initialization function. Called with `(device, is_require_grad) -> Parameter`.
    /// This function is consumed during initialization via `FnOnce`.
    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
    /// The target device on which the parameter should be initialized.
    /// Used by `lazy_device()` to provide device information without triggering initialization.
    pub(crate) device: P::Device,
    /// The gradient requirement for the parameter.
    /// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization.
    pub(crate) is_require_grad: bool,
    /// The shape of the tensor parameter.
    /// Used by `lazy_shape()` to provide shape information without triggering initialization.
    pub(crate) shape: Shape,
}

impl<P: Parameter> Uninitialized<P> {
    /// Consumes the uninitialized state and runs the initialization function.
    ///
    /// This is called by [Param::val] when accessing an uninitialized parameter for the first time.
    /// The function is given the stored device and gradient requirement, and returns the initialized parameter.
    fn initialize(self) -> P {
        let init = self.init;
        init(&self.device, self.is_require_grad)
    }
}

impl<T: Parameter> Param<T> {
    /// Create a new parameter that is already initialized.
    pub fn initialized(id: ParamId, value: T) -> Self {
        Self {
            id,
            state: OnceCell::from(value),
            initialization: None,
            param_mapper: Default::default(),
        }
    }

    /// Create a new parameter that is not already initialized.
    pub fn uninitialized<F>(
        id: ParamId,
        init: F,
        device: T::Device,
        is_require_grad: bool,
        shape: Shape,
    ) -> Self
    where
        F: FnOnce(&T::Device, bool) -> T + Send + 'static,
    {
        Self {
            id,
            state: OnceCell::new(),
            initialization: Some(RwLock::new(Some(Uninitialized {
                init: Box::new(init),
                device,
                is_require_grad,
                shape,
            }))),
            param_mapper: Default::default(),
        }
    }

    /// Gets the parameter value, initializing it lazily if needed.
    ///
    /// For initialized parameters, this returns a clone of the cached value.
    /// For uninitialized parameters, this triggers initialization: the stored init function
    /// runs once with the recorded device and gradient requirement, and the result is cached.
    pub fn val(&self) -> T {
        self.state
            .get_or_init(|| {
                let mut result = self
                    .initialization
                    .as_ref()
                    .expect("Should have an initialization when no state provided.")
                    .write()
                    .unwrap();
                let state = result.take().expect("Should exist when not initialized");
                state.initialize()
            })
            .clone()
    }

    /// Check if the parameter has been initialized.
    ///
    /// Returns `true` if the parameter's value has been computed and cached,
    /// `false` if it's still lazy and will be initialized on first access.
    pub fn is_initialized(&self) -> bool {
        self.state.get().is_some()
    }

    /// Gets the parameter's value while consuming the parameter.
    pub fn into_value(self) -> T {
        self.consume().1
    }

    /// Gets the parameter id, value, and mapper while consuming the parameter.
    pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
        let tensor = self.val();

        core::mem::drop(self.state);

        (self.id, tensor, self.param_mapper)
    }

    /// Execute the given function on the inner value.
    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
        let (id, tensor, param_mapper) = self.consume();
        let tensor = func(tensor);

        Self {
            id,
            state: OnceCell::from(tensor),
            initialization: None,
            param_mapper,
        }
    }

    /// Create an initialized parameter with the given id, value, and param mapper.
    ///
    /// This is a helper method for creating parameters while preserving the param mapper,
    /// typically used in `ModuleMapper` implementations.
    pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
        Self {
            id,
            state: OnceCell::from(value),
            initialization: None,
            param_mapper,
        }
    }

    /// Runs a transformation on the parameter when loading.
    pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
        self.param_mapper.load = Some(new_mapper(func));

        self
    }

    /// Runs a transformation on the parameter when saving.
    pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
        self.param_mapper.save = Some(new_mapper(func));

        self
    }

    /// Execute the given function on the inner value, deferring it until initialization
    /// when the parameter is still lazy.
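    ///
    /// A minimal sketch (hypothetical usage with the toy `Weight` parameter from the type-level
    /// docs; `ParamId::new()` and `Shape::from([1])` are assumed constructors):
    ///
    /// ```rust,ignore
    /// let lazy = Param::uninitialized(
    ///     ParamId::new(),
    ///     |_device, _require_grad| Weight(1.0),
    ///     (),
    ///     true,
    ///     Shape::from([1]),
    /// )
    /// .init_mapper(|w| Weight(w.0 + 1.0));
    ///
    /// assert!(!lazy.is_initialized()); // still lazy: the mapping is composed with the init closure
    /// assert_eq!(lazy.val().0, 2.0); // init runs first, then the mapping is applied
    /// ```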
    pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
    where
        T: 'static,
    {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.map(func),
        };

        let mut init = initialization.write().unwrap();

        match init.as_mut() {
            Some(value) => {
                // Swap a placeholder into `value.init` to take ownership of the previous
                // init closure, then store a new closure that composes it with `func`.
                #[allow(clippy::type_complexity)]
                let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =
                    Box::new(|_, _| panic!("Fake func to not have null ref."));
                core::mem::swap(&mut prev, &mut value.init);

                value.init = Box::new(|a, b| {
                    let tensor = prev(a, b);
                    func(tensor)
                });
                core::mem::drop(init);
                self
            }
            None => {
                // The lazy state was already taken; apply the function to the cached value.
                core::mem::drop(init);
                self.map(func)
            }
        }
    }

    /// The device on which the parameter is or will be initialized, **without triggering initialization**.
    ///
    /// This is critical for the load optimization: when loading tensors into an uninitialized parameter,
    /// we need to know the target device to move the loaded tensor appropriately, but we don't want to
    /// trigger the initialization function (which would allocate an unnecessary tensor).
    ///
    /// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to
    /// preserve lazy initialization.
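    ///
    /// A minimal sketch (hypothetical usage with the toy `Weight` parameter from the type-level
    /// docs; `ParamId::new()` and `Shape::from([1])` are assumed constructors):
    ///
    /// ```rust,ignore
    /// let lazy = Param::uninitialized(
    ///     ParamId::new(),
    ///     |_device, _require_grad| Weight(0.0),
    ///     (), // target device
    ///     true,
    ///     Shape::from([1]),
    /// );
    /// let _device = lazy.lazy_device(); // does NOT run the init closure
    /// assert!(!lazy.is_initialized());
    /// ```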
    pub fn lazy_device(&self) -> T::Device {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.device(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.device.clone(),
            None => self.device(),
        }
    }

    /// The gradient requirement with which the parameter is or will be initialized, **without triggering initialization**.
    ///
    /// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization.
    /// When loading tensors into an uninitialized parameter, we need to apply the correct gradient
    /// setting to the loaded tensor without triggering the initialization function.
    ///
    /// # Notes
    ///
    /// This is a crate-private function, since users are not expected to read `is_require_grad` from an
    /// uninitialized module and then override its value. All low-level functions should be provided
    /// by `burn` and should handle those details.
    pub(crate) fn lazy_is_require_grad(&self) -> bool {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.is_require_grad(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.is_require_grad,
            None => self.is_require_grad(),
        }
    }

    /// Override the gradient requirement for the current parameter.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
        };

        let mut init = initialization.write().unwrap();
        let mut is_lazy = false;

        if let Some(value) = init.as_mut() {
            is_lazy = true;
            value.is_require_grad = require_grad;
        };

        core::mem::drop(init);

        if is_lazy {
            return self;
        }

        self.map(|tensor| tensor.set_require_grad(require_grad))
    }
}

impl<T: Parameter> Clone for Param<T> {
    fn clone(&self) -> Self {
        // Cloning forces initialization: the cloned parameter is always in the initialized state.
        let mut param = Param::initialized(self.id, self.val());
        param.param_mapper = self.param_mapper.clone();
        param
    }
}

impl<T: Parameter> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        // Same lazy-initialization path as `val`, but returns a reference to the cached value.
        self.state.get_or_init(|| {
            let mut result = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.")
                .write()
                .unwrap();

            let state = result.take().expect("Should exist when not initialized");
            state.initialize()
        })
    }
}