Skip to main content

burn_core/module/param/
base.rs

1use super::ParamId;
2use super::sync_once_cell::SyncOnceCell;
3use alloc::format;
4
5#[cfg(not(target_has_atomic = "ptr"))]
6use alloc::boxed::Box;
7use burn_std::stub::RwLock;
8use burn_tensor::Shape;
9use core::ops::Deref;
10
11#[cfg(target_has_atomic = "ptr")]
12use alloc::sync::Arc;
13
14#[cfg(not(target_has_atomic = "ptr"))]
15use portable_atomic_util::Arc;
16
/// A shared, type-erased transformation applied to a parameter on load or save.
#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;

/// A shared, type-erased transformation applied to a parameter on load or save.
///
/// On targets without pointer-sized atomics, `portable_atomic_util::Arc` is used, and it
/// needs `Box` indirection for unsized types, so the closure is boxed first.
#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;

/// Wraps `func` into a `Mapper`.
#[cfg(target_has_atomic = "ptr")]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(func)
}

/// Wraps `func` into a `Mapper`, boxing it first for `portable_atomic_util::Arc`.
#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(Box::new(func))
}
32
/// Type alias for the init function stored in `Uninitialized`.
/// On targets without atomics, `portable_atomic_util::Arc` needs `Box` indirection
/// for unsized types, mirroring the `Mapper` pattern above.
///
/// The function is called as `(device, is_require_grad) -> P`.
#[cfg(target_has_atomic = "ptr")]
type InitFn<P> = Arc<dyn Fn(&<P as Parameter>::Device, bool) -> P + Send + Sync>;

/// Same signature as the atomic-target `InitFn`, with the closure boxed for
/// `portable_atomic_util::Arc`.
#[cfg(not(target_has_atomic = "ptr"))]
type InitFn<P> = Arc<Box<dyn Fn(&<P as Parameter>::Device, bool) -> P + Send + Sync>>;

/// Wraps `func` into an `InitFn`.
#[cfg(target_has_atomic = "ptr")]
fn new_init_fn<P: Parameter, F: Fn(&P::Device, bool) -> P + Send + Sync + 'static>(
    func: F,
) -> InitFn<P> {
    Arc::new(func)
}

/// Wraps `func` into an `InitFn`, boxing it first for `portable_atomic_util::Arc`.
#[cfg(not(target_has_atomic = "ptr"))]
fn new_init_fn<P: Parameter, F: Fn(&P::Device, bool) -> P + Send + Sync + 'static>(
    func: F,
) -> InitFn<P> {
    Arc::new(Box::new(func))
}
55
/// Parameters are the fundamental building blocks of [modules](crate::module::Module) where they
/// serve as containers for [tensors](crate::tensor::Tensor) that can be updated during
/// training, and loaded during inference. If you don't want to save the tensors
/// and/or don't want to update it during training, you don't need this type to wrap your tensor.
///
/// # Core Lazy Initialization Architecture
///
/// `Param<T>` stores its value in a `SyncOnceCell<T>`, alongside optional deferred
/// initialization state.
///
/// ## State Management
///
/// **Possible states:**
///
/// 1. **Initialized**: `state: SyncOnceCell<T>` contains value, `initialization: None`
/// 2. **Uninitialized (Lazy)**: `state` is empty, `initialization: Some(RwLock<Option<Uninitialized<T>>>)`
/// 3. **Lazily initialized**: `state` contains the value produced by the init function, and
///    `initialization` is `Some(RwLock<None>)` because the pending `Uninitialized<T>` was taken
///    when initialization ran.
pub struct Param<T: Parameter> {
    /// The unique ID of this parameter. This is used by eg. optimizers to associate a gradient with a specific parameter.
    pub id: ParamId,
    /// The SyncOnceCell holding the initialized parameter value.
    /// Empty for uninitialized parameters, populated after first access or explicit initialization.
    pub(crate) state: SyncOnceCell<T>,
    /// The deferred initialization state for lazy parameters.
    ///
    /// **State Transitions:**
    /// - Initialized params: `None`
    /// - Uninitialized params: `Some(RwLock<Some(Uninitialized<T>)>)`
    /// - After lazy init triggers: `Some(RwLock<None>)` (inner Option is taken)
    pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
    /// Transformations applied by `ParamMapper::on_load` / `ParamMapper::on_save`.
    pub(crate) param_mapper: ParamMapper<T>,
    /// The parameter's gradient requirement, kept on the `Param` itself for stateful
    /// `module.valid()` <> `module.train()` transitions.
    pub(crate) require_grad: bool,
}
88
#[derive(Clone)]
/// Applies transformations when loading and saving parameters.
///
/// # Mapper System
///
/// `ParamMapper<T>` allows applying transformations during serialization and deserialization:
/// - `load: Option<Mapper<T>>` - transformation during deserialization (applied in [`ParamMapper::on_load`])
/// - `save: Option<Mapper<T>>` - transformation during serialization (applied in [`ParamMapper::on_save`])
///
/// When a mapper is `None`, the corresponding method returns the parameter unchanged.
/// Cloning a `ParamMapper` shares the underlying closures, which are reference-counted.
///
/// These are commonly used for:
/// - Quantization/dequantization
/// - Precision conversion (e.g., FP32 ↔ FP16)
/// - Custom parameter transformations
pub struct ParamMapper<T: Parameter> {
    /// Transformation applied on load, if any.
    load: Option<Mapper<T>>,
    /// Transformation applied on save, if any.
    save: Option<Mapper<T>>,
}
106
107impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
108    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
109        f.write_fmt(format_args!(
110            "ParamMapper {{ load: {}, save: {} }}",
111            self.load.is_some(),
112            self.save.is_some()
113        ))
114    }
115}
116
117impl<T: Parameter> ParamMapper<T> {
118    /// Applies the transformation when loading the given parameter.
119    pub fn on_load(&self, param: T) -> T {
120        match &self.load {
121            Some(mapper) => mapper(param),
122            None => param,
123        }
124    }
125    /// Applies the transformation when saving the given parameter.
126    pub fn on_save(&self, param: T) -> T {
127        match &self.save {
128            Some(mapper) => mapper(param),
129            None => param,
130        }
131    }
132}
133
134impl<T: Parameter> Default for ParamMapper<T> {
135    fn default() -> Self {
136        Self {
137            load: None,
138            save: None,
139        }
140    }
141}
142
143impl<T: Parameter> core::fmt::Display for Param<T> {
144    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
145        f.write_str(format!("Param: {}", self.id).as_str())
146    }
147}
148
149impl<T: Parameter> core::fmt::Debug for Param<T> {
150    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
151        f.write_str(format!("Param: {} - {:?}", self.id, self.param_mapper).as_str())
152    }
153}
154
/// Trait that defines what is necessary for a type to be a parameter.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type to be used.
    ///
    /// Must be `Clone` so a device can be stored and handed out (e.g. by lazy parameters)
    /// without consuming it.
    type Device: Clone;

    /// Fetch the device.
    fn device(&self) -> Self::Device;

    /// Fetch the gradient requirement.
    fn is_require_grad(&self) -> bool;

    /// Set the gradient requirement.
    ///
    /// Consumes `self` and returns the updated parameter.
    fn set_require_grad(self, require_grad: bool) -> Self;
}
169
/// The deferred initialization state for lazy parameters.
///
/// Holds the function that builds the parameter value, plus the metadata (device,
/// gradient requirement, shape) that can be queried before initialization happens.
// NOTE(review): this allow may be stale now that `init` goes through the `InitFn` alias;
// confirm with clippy before removing it.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
    /// The initialization function. Called with `(device, is_require_grad) -> Parameter`.
    /// Wrapped in `Arc` so that cloning a `Param` preserves the lazy state without
    /// triggering initialization. Each clone holds its own `Uninitialized` state and
    /// will run the init function separately on first access (producing independent values).
    init: InitFn<P>,
    /// The target device on which the parameter should be initialized.
    /// Used by `lazy_device()` to provide device information without triggering initialization.
    pub(crate) device: P::Device,
    /// The gradient requirement for the parameter.
    /// Used by `lazy_is_require_grad()` to provide gradient settings without triggering initialization.
    pub(crate) is_require_grad: bool,
    /// The shape of the tensor parameter.
    /// Used by `lazy_shape()` to provide shape information without triggering initialization.
    pub(crate) shape: Shape,
}
188
189impl<P: Parameter> Clone for Uninitialized<P> {
190    fn clone(&self) -> Self {
191        Self {
192            init: self.init.clone(),
193            device: self.device.clone(),
194            is_require_grad: self.is_require_grad,
195            shape: self.shape.clone(),
196        }
197    }
198}
199
200impl<P: Parameter> Uninitialized<P> {
201    /// Runs the initialization function.
202    ///
203    /// This is called by [Param::val] when accessing an uninitialized parameter for the first time.
204    /// The function is given the stored device and gradient requirement, and returns the initialized parameter.
205    ///
206    /// Although this takes `&self` (the `Arc<dyn Fn>` is callable multiple times), callers
207    /// are expected to invoke this only once per `Param` instance. The caller (`val()`) takes
208    /// the `Uninitialized` out of its `Option` via `take()` to enforce single-initialization.
209    fn initialize(&self) -> P {
210        (self.init)(&self.device, self.is_require_grad)
211    }
212}
213
214impl<T: Parameter> Param<T> {
215    /// Create a new parameter that is already initialized.
216    pub fn initialized(id: ParamId, value: T) -> Self {
217        let require_grad = value.is_require_grad();
218        Self {
219            id,
220            state: SyncOnceCell::initialized(value),
221            initialization: None,
222            param_mapper: Default::default(),
223            require_grad,
224        }
225    }
226
227    /// Create a new parameter that is not already initialized.
228    pub fn uninitialized<F>(
229        id: ParamId,
230        init: F,
231        device: T::Device,
232        is_require_grad: bool,
233        shape: Shape,
234    ) -> Self
235    where
236        F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
237    {
238        Self {
239            id,
240            state: SyncOnceCell::new(),
241            initialization: Some(RwLock::new(Some(Uninitialized {
242                init: new_init_fn(init),
243                device,
244                is_require_grad,
245                shape,
246            }))),
247            param_mapper: Default::default(),
248            require_grad: is_require_grad,
249        }
250    }
251
252    /// Gets the parameter value, initializing it lazily if needed.
253    ///
254    /// For initialized parameters, this returns a clone of the cached value.
255    /// For uninitialized parameters, this triggers initialization:
256    pub fn val(&self) -> T {
257        self.state
258            .get_or_init(|| {
259                let mut result = self
260                    .initialization
261                    .as_ref()
262                    .expect("Should have an initialization when no state provided.")
263                    .write()
264                    .unwrap();
265                let state = result.take().expect("Should exist when not initialized");
266                state.initialize()
267            })
268            .clone()
269    }
270
271    /// Check if the parameter has been initialized.
272    ///
273    /// Returns `true` if the parameter's value has been computed and cached,
274    /// `false` if it's still lazy and will be initialized on first access.
275    pub fn is_initialized(&self) -> bool {
276        self.state.get().is_some()
277    }
278
279    /// Gets the parameter's value while consuming the parameter.
280    pub fn into_value(self) -> T {
281        self.consume().1
282    }
283
284    /// Gets the parameter id and value while consuming the parameter.
285    pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
286        let tensor = self.val();
287
288        core::mem::drop(self.state);
289
290        (self.id, tensor, self.param_mapper)
291    }
292
293    /// Execute the given function on the inner value.
294    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
295        let (id, tensor, param_mapper) = self.consume();
296        let tensor = func(tensor);
297        let require_grad = tensor.is_require_grad();
298
299        Self {
300            id,
301            state: SyncOnceCell::initialized(tensor),
302            initialization: None,
303            param_mapper,
304            require_grad,
305        }
306    }
307
308    /// Create an initialized parameter with the given id, value, and param mapper.
309    ///
310    /// This is a helper method for creating parameters while preserving the param mapper,
311    /// typically used in ModuleMapper implementations.
312    pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
313        let require_grad = value.is_require_grad();
314        Self {
315            id,
316            state: SyncOnceCell::initialized(value),
317            initialization: None,
318            param_mapper,
319            require_grad,
320        }
321    }
322
323    /// Runs a transformation on the parameter when loading.
324    pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
325        self.param_mapper.load = Some(new_mapper(func));
326
327        self
328    }
329
330    /// Runs a transformation on the parameter when saving.
331    pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
332        self.param_mapper.save = Some(new_mapper(func));
333
334        self
335    }
336
337    /// Execute the given function on the inner value.
338    pub fn init_mapper<F: Fn(T) -> T + Send + Sync + 'static>(self, func: F) -> Self
339    where
340        T: 'static,
341    {
342        let initialization = match &self.initialization {
343            Some(init) => init,
344            None => return self.map(func),
345        };
346
347        let mut init = initialization.write().unwrap();
348
349        match init.as_mut() {
350            Some(value) => {
351                let prev = value.init.clone();
352
353                value.init = new_init_fn(move |a, b| {
354                    let tensor = prev(a, b);
355                    func(tensor)
356                });
357                core::mem::drop(init);
358                self
359            }
360            None => {
361                core::mem::drop(init);
362                self.map(func)
363            }
364        }
365    }
366
367    /// The device on which the parameter is or will be initialized, **without triggering initialization**.
368    ///
369    /// This is critical for the load optimization: when loading tensors into an uninitialized parameter,
370    /// we need to know the target device to move the loaded tensor appropriately, but we don't want to
371    /// trigger the initialization function (which would allocate an unnecessary tensor).
372    ///
373    /// Use this instead of [crate::tensor::Tensor::device] when you need the device but want to
374    /// preserve lazy initialization.
375    pub fn lazy_device(&self) -> T::Device {
376        let initialization = match &self.initialization {
377            Some(init) => init,
378            None => return self.device(),
379        };
380
381        let init = initialization.read().unwrap();
382
383        match init.as_ref() {
384            Some(value) => value.device.clone(),
385            None => self.device(),
386        }
387    }
388
389    /// The gradient requirement on which the parameter is or will be initialized, **without triggering initialization**.
390    ///
391    /// Similar to [lazy_device](Self::lazy_device), this is critical for the load optimization.
392    /// When loading tensors into an uninitialized parameter, we need to apply the correct gradient
393    /// setting to the loaded tensor without triggering the initialization function.
394    ///
395    /// # Notes
396    ///
397    /// This is a crate-private function, since users are not expected to use `is_require_grad` of an
398    /// uninitialized module to then override its value. All low-level functions should be provided
399    /// by `burn` and should handle those details.
400    pub(crate) fn lazy_is_require_grad(&self) -> bool {
401        let initialization = match &self.initialization {
402            Some(init) => init,
403            None => return self.is_require_grad(),
404        };
405
406        let init = initialization.read().unwrap();
407
408        match init.as_ref() {
409            Some(value) => value.is_require_grad,
410            None => self.is_require_grad(),
411        }
412    }
413
414    /// Override the gradient requirement for the current parameter.
415    pub fn set_require_grad(self, require_grad: bool) -> Self {
416        let initialization = match &self.initialization {
417            Some(init) => init,
418            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
419        };
420
421        let mut init = initialization.write().unwrap();
422        let mut is_lazy = false;
423
424        if let Some(value) = init.as_mut() {
425            is_lazy = true;
426            value.is_require_grad = require_grad;
427        };
428
429        core::mem::drop(init);
430
431        if is_lazy {
432            return self;
433        }
434
435        self.map(|tensor| tensor.set_require_grad(require_grad))
436    }
437}
438
439impl<T: Parameter> Clone for Param<T> {
440    fn clone(&self) -> Self {
441        // If uninitialized, clone the lazy state without triggering initialization.
442        // This avoids allocating tensor memory for params that may never be used
443        // (e.g., when cloning a module just to load weights into it).
444        // The clone gets its own SyncOnceCell and RwLock, so initializing one
445        // does not affect the other.
446        if let Some(init_lock) = &self.initialization {
447            let init_guard = init_lock.read().unwrap();
448            if let Some(uninit) = init_guard.as_ref() {
449                return Self {
450                    id: self.id,
451                    state: SyncOnceCell::new(),
452                    initialization: Some(RwLock::new(Some(uninit.clone()))),
453                    param_mapper: self.param_mapper.clone(),
454                    require_grad: self.require_grad,
455                };
456            }
457        }
458
459        // Already initialized (or init was already consumed): clone the value.
460        let mut param = Param::initialized(self.id, self.val());
461        param.param_mapper = self.param_mapper.clone();
462        param
463    }
464}
465
466impl<T: Parameter> Deref for Param<T> {
467    type Target = T;
468
469    fn deref(&self) -> &Self::Target {
470        self.state.get_or_init(|| {
471            let mut result = self
472                .initialization
473                .as_ref()
474                .expect("Should have an initialization when no state provided.")
475                .write()
476                .unwrap();
477
478            let state = result.take().expect("Should exist when not initialized");
479            state.initialize()
480        })
481    }
482}
483
#[cfg(test)]
mod tests {
    use super::*;
    use burn_tensor::{Tensor, backend::Backend};

    // Param<T> should be Sync so that models can be shared across threads
    // (e.g. parallel inference with rayon).
    fn _assert_sync<T: Sync>() {}

    /// Builds an uninitialized `[2, 3]` zeros parameter on the default device,
    /// with no gradient requirement.
    fn lazy_zeros() -> Param<Tensor<burn_flex::Flex, 2>> {
        Param::uninitialized(
            ParamId::new(),
            |device, _require_grad| Tensor::zeros([2, 3], device),
            Default::default(),
            false,
            [2, 3].into(),
        )
    }

    #[test]
    fn param_is_sync() {
        fn check<B: Backend>() {
            _assert_sync::<Param<Tensor<B, 2>>>();
        }
        check::<burn_flex::Flex>();
    }

    /// Concurrent lazy initialization must not panic.
    ///
    /// Multiple threads call `val()` on an uninitialized `Param` simultaneously.
    /// `SyncOnceCell::get_or_init` guarantees only one thread runs the initializer;
    /// the others block and receive the same value.
    #[cfg(feature = "std")]
    #[test]
    fn param_concurrent_lazy_init() {
        use alloc::vec::Vec;

        let param = lazy_zeros();
        assert!(!param.is_initialized());

        // Share across threads via &param (requires Sync).
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..4).map(|_| s.spawn(|| param.val())).collect();

            let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

            // All threads must get the same value.
            let expected = results[0].to_data();
            for result in &results[1..] {
                assert_eq!(result.to_data(), expected);
            }
        });

        assert!(param.is_initialized());
    }

    /// Cloning an uninitialized `Param` must not trigger initialization, and each clone
    /// initializes independently of the original.
    #[test]
    fn param_clone_keeps_lazy_state() {
        let param = lazy_zeros();
        let cloned = param.clone();

        assert!(!param.is_initialized());
        assert!(!cloned.is_initialized());

        let _ = cloned.val();
        assert!(cloned.is_initialized());
        assert!(!param.is_initialized());
    }

    /// Lazy metadata queries and `set_require_grad` on an uninitialized `Param` must not
    /// trigger initialization.
    #[test]
    fn param_lazy_queries_do_not_initialize() {
        let param = lazy_zeros();
        assert!(!param.lazy_is_require_grad());
        let _device = param.lazy_device();
        assert!(!param.is_initialized());

        let param = param.set_require_grad(true);
        assert!(param.lazy_is_require_grad());
        assert!(!param.is_initialized());
    }
}