use super::ParamId;
use super::sync_once_cell::SyncOnceCell;

#[cfg(not(target_has_atomic = "ptr"))]
use alloc::boxed::Box;
use burn_std::stub::RwLock;
use burn_tensor::Shape;
use core::ops::Deref;

#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;

#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;

// A shared, thread-safe mapping function applied to a parameter value.
//
// On targets without atomic pointer support, `portable_atomic_util::Arc` does
// not support coercion to unsized trait objects, so the closure is boxed first.
#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;

#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;

#[cfg(target_has_atomic = "ptr")]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(func)
}

#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
    Arc::new(Box::new(func))
}

// A shared, thread-safe initialization function: given a device and the
// `require_grad` flag, it produces the initial parameter value.
#[cfg(target_has_atomic = "ptr")]
type InitFn<P> = Arc<dyn Fn(&<P as Parameter>::Device, bool) -> P + Send + Sync>;

#[cfg(not(target_has_atomic = "ptr"))]
type InitFn<P> = Arc<Box<dyn Fn(&<P as Parameter>::Device, bool) -> P + Send + Sync>>;

#[cfg(target_has_atomic = "ptr")]
fn new_init_fn<P: Parameter, F: Fn(&P::Device, bool) -> P + Send + Sync + 'static>(
    func: F,
) -> InitFn<P> {
    Arc::new(func)
}

#[cfg(not(target_has_atomic = "ptr"))]
fn new_init_fn<P: Parameter, F: Fn(&P::Device, bool) -> P + Send + Sync + 'static>(
    func: F,
) -> InitFn<P> {
    Arc::new(Box::new(func))
}

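/// A module parameter: a value of type `T` tracked with a stable [`ParamId`].
///
/// A parameter is either created eagerly from a value with
/// [`Param::initialized`], or lazily with [`Param::uninitialized`], in which
/// case the recorded init function runs on first access to the value.
///
/// A minimal usage sketch (not compiled as a doc test), assuming a float
/// tensor value:
///
/// ```ignore
/// let param = Param::initialized(ParamId::new(), tensor);
/// let value = param.val(); // clones the inner value
/// ```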
pub struct Param<T: Parameter> {
    /// The unique identifier of the parameter.
    pub id: ParamId,
    /// The value, written exactly once when the parameter is initialized.
    pub(crate) state: SyncOnceCell<T>,
    /// The recorded lazy initialization, taken when the value is first created.
    pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
    /// Mappings applied when the parameter is loaded or saved.
    pub(crate) param_mapper: ParamMapper<T>,
    /// Whether the parameter requires gradients.
    pub(crate) require_grad: bool,
}

/// Optional transformations applied to a parameter's value when it is loaded
/// or saved.
#[derive(Clone)]
pub struct ParamMapper<T: Parameter> {
    load: Option<Mapper<T>>,
    save: Option<Mapper<T>>,
}

impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_fmt(format_args!(
            "ParamMapper {{ load: {}, save: {} }}",
            self.load.is_some(),
            self.save.is_some()
        ))
    }
}

impl<T: Parameter> ParamMapper<T> {
    /// Applies the `load` mapping, if any, returning the value unchanged otherwise.
    pub fn on_load(&self, param: T) -> T {
        match &self.load {
            Some(mapper) => mapper(param),
            None => param,
        }
    }

    /// Applies the `save` mapping, if any, returning the value unchanged otherwise.
    pub fn on_save(&self, param: T) -> T {
        match &self.save {
            Some(mapper) => mapper(param),
            None => param,
        }
    }
}

impl<T: Parameter> Default for ParamMapper<T> {
    fn default() -> Self {
        Self {
            load: None,
            save: None,
        }
    }
}

impl<T: Parameter> core::fmt::Display for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Param: {}", self.id)
    }
}

impl<T: Parameter> core::fmt::Debug for Param<T> {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        write!(f, "Param: {} - {:?}", self.id, self.param_mapper)
    }
}

/// The operations required for a type to be used as a parameter value.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// The device type on which the value lives.
    type Device: Clone;

    /// Returns the device of the current value.
    fn device(&self) -> Self::Device;

    /// Returns whether the value requires gradients.
    fn is_require_grad(&self) -> bool;

    /// Returns the value with its gradient requirement updated.
    fn set_require_grad(self, require_grad: bool) -> Self;
}

/// The recorded state of a lazily initialized parameter.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
    /// The function producing the initial value.
    init: InitFn<P>,
    /// The device on which to initialize the value.
    pub(crate) device: P::Device,
    /// Whether the value should require gradients.
    pub(crate) is_require_grad: bool,
    /// The shape of the value to initialize.
    pub(crate) shape: Shape,
}

impl<P: Parameter> Clone for Uninitialized<P> {
    fn clone(&self) -> Self {
        Self {
            init: self.init.clone(),
            device: self.device.clone(),
            is_require_grad: self.is_require_grad,
            shape: self.shape.clone(),
        }
    }
}

impl<P: Parameter> Uninitialized<P> {
    /// Runs the recorded init function on the recorded device.
    fn initialize(&self) -> P {
        (self.init)(&self.device, self.is_require_grad)
    }
}

impl<T: Parameter> Param<T> {
    /// Creates a parameter that is already initialized with the given value.
    pub fn initialized(id: ParamId, value: T) -> Self {
        let require_grad = value.is_require_grad();
        Self {
            id,
            state: SyncOnceCell::initialized(value),
            initialization: None,
            param_mapper: Default::default(),
            require_grad,
        }
    }

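    /// Creates a lazily initialized parameter: `init` is stored and only runs
    /// (with the recorded device and `is_require_grad` flag) when the value is
    /// first accessed.
    ///
    /// A minimal usage sketch (not compiled as a doc test), assuming a 2-D
    /// float tensor and a backend `B` such as the `burn_flex::Flex` used in
    /// the tests below:
    ///
    /// ```ignore
    /// let param: Param<Tensor<B, 2>> = Param::uninitialized(
    ///     ParamId::new(),
    ///     |device, _require_grad| Tensor::zeros([2, 3], device),
    ///     device,
    ///     false,
    ///     [2, 3].into(),
    /// );
    /// ```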
    pub fn uninitialized<F>(
        id: ParamId,
        init: F,
        device: T::Device,
        is_require_grad: bool,
        shape: Shape,
    ) -> Self
    where
        F: Fn(&T::Device, bool) -> T + Send + Sync + 'static,
    {
        Self {
            id,
            state: SyncOnceCell::new(),
            initialization: Some(RwLock::new(Some(Uninitialized {
                init: new_init_fn(init),
                device,
                is_require_grad,
                shape,
            }))),
            param_mapper: Default::default(),
            require_grad: is_require_grad,
        }
    }

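    /// Returns a clone of the parameter's value, running the recorded lazy
    /// initialization first if the value has not been created yet.
    ///
    /// A minimal usage sketch (not compiled as a doc test); the init closure
    /// runs at most once, even under concurrent access (see the `std` test
    /// below):
    ///
    /// ```ignore
    /// let value = param.val(); // first call initializes
    /// let again = param.val(); // later calls clone the cached value
    /// ```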
    pub fn val(&self) -> T {
        self.state
            .get_or_init(|| {
                let mut result = self
                    .initialization
                    .as_ref()
                    .expect("Should have an initialization when no state provided.")
                    .write()
                    .unwrap();
                let state = result.take().expect("Should exist when not initialized");
                state.initialize()
            })
            .clone()
    }

    /// Returns whether the value has been created yet.
    pub fn is_initialized(&self) -> bool {
        self.state.get().is_some()
    }

    /// Consumes the parameter and returns its value, initializing it if needed.
    pub fn into_value(self) -> T {
        self.consume().1
    }

    /// Consumes the parameter and returns its id, value, and mapper,
    /// initializing the value if needed.
    pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
        let tensor = self.val();

        core::mem::drop(self.state);

        (self.id, tensor, self.param_mapper)
    }

    /// Applies `func` to the value (initializing it if needed) and returns an
    /// initialized parameter holding the result, preserving the id and mapper.
    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
        let (id, tensor, param_mapper) = self.consume();
        let tensor = func(tensor);
        let require_grad = tensor.is_require_grad();

        Self {
            id,
            state: SyncOnceCell::initialized(tensor),
            initialization: None,
            param_mapper,
            require_grad,
        }
    }

    /// Creates an initialized parameter from a value that has already been
    /// mapped, keeping the given mapper for subsequent load/save operations.
    pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
        let require_grad = value.is_require_grad();
        Self {
            id,
            state: SyncOnceCell::initialized(value),
            initialization: None,
            param_mapper,
            require_grad,
        }
    }

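    /// Registers a mapping applied to the value whenever the parameter is
    /// loaded (see [`ParamMapper::on_load`]).
    ///
    /// A minimal usage sketch (not compiled as a doc test); the rescaling is
    /// an arbitrary illustration, not part of the API:
    ///
    /// ```ignore
    /// let param = param.load_mapper(|tensor| tensor * 2.0);
    /// ```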
    pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
        self.param_mapper.load = Some(new_mapper(func));

        self
    }

    /// Registers a mapping applied to the value whenever the parameter is
    /// saved (see [`ParamMapper::on_save`]).
    pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
        self.param_mapper.save = Some(new_mapper(func));

        self
    }

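    /// Registers a mapping applied when the parameter is initialized.
    ///
    /// For a lazily initialized parameter, `func` is composed with the stored
    /// init function and runs when the value is first materialized; for an
    /// already initialized parameter, it is applied immediately via
    /// [`Param::map`].
    ///
    /// A minimal usage sketch (not compiled as a doc test):
    ///
    /// ```ignore
    /// let param = param.init_mapper(|tensor| tensor * 0.5);
    /// ```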
    pub fn init_mapper<F: Fn(T) -> T + Send + Sync + 'static>(self, func: F) -> Self
    where
        T: 'static,
    {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.map(func),
        };

        let mut init = initialization.write().unwrap();

        match init.as_mut() {
            Some(value) => {
                // Still lazy: compose `func` with the previous init function.
                let prev = value.init.clone();

                value.init = new_init_fn(move |device, is_require_grad| {
                    let tensor = prev(device, is_require_grad);
                    func(tensor)
                });
                // Release the guard borrowing `self` before returning it.
                core::mem::drop(init);
                self
            }
            None => {
                // Already initialized: apply the mapping eagerly.
                core::mem::drop(init);
                self.map(func)
            }
        }
    }

    /// Returns the device on which the value is or will be initialized,
    /// without triggering initialization.
    pub fn lazy_device(&self) -> T::Device {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.device(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.device.clone(),
            None => self.device(),
        }
    }

    /// Returns whether the value does or will require gradients, without
    /// triggering initialization.
    pub(crate) fn lazy_is_require_grad(&self) -> bool {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.is_require_grad(),
        };

        let init = initialization.read().unwrap();

        match init.as_ref() {
            Some(value) => value.is_require_grad,
            None => self.is_require_grad(),
        }
    }

    /// Overrides the gradient requirement of the parameter, updating the lazy
    /// state in place when possible and mapping the value otherwise.
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        let initialization = match &self.initialization {
            Some(init) => init,
            None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
        };

        let mut init = initialization.write().unwrap();
        let mut is_lazy = false;

        if let Some(value) = init.as_mut() {
            is_lazy = true;
            value.is_require_grad = require_grad;
        }

        // Release the guard borrowing `self` before returning or mapping it.
        core::mem::drop(init);

        if is_lazy {
            return self;
        }

        self.map(|tensor| tensor.set_require_grad(require_grad))
    }
}

impl<T: Parameter> Clone for Param<T> {
    fn clone(&self) -> Self {
        // If the parameter is still lazy, clone the recorded initialization
        // instead of forcing the value to be created.
        if let Some(init_lock) = &self.initialization {
            let init_guard = init_lock.read().unwrap();
            if let Some(uninit) = init_guard.as_ref() {
                return Self {
                    id: self.id,
                    state: SyncOnceCell::new(),
                    initialization: Some(RwLock::new(Some(uninit.clone()))),
                    param_mapper: self.param_mapper.clone(),
                    require_grad: self.require_grad,
                };
            }
        }

        // Otherwise clone the (possibly just-initialized) value.
        let mut param = Param::initialized(self.id, self.val());
        param.param_mapper = self.param_mapper.clone();
        param
    }
}

impl<T: Parameter> Deref for Param<T> {
    type Target = T;

    fn deref(&self) -> &Self::Target {
        self.state.get_or_init(|| {
            let mut result = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.")
                .write()
                .unwrap();

            let state = result.take().expect("Should exist when not initialized");
            state.initialize()
        })
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use burn_tensor::{Tensor, backend::Backend};

    // Compile-time check that a type is `Sync`.
    fn _assert_sync<T: Sync>() {}

    #[test]
    fn param_is_sync() {
        fn check<B: Backend>() {
            _assert_sync::<Param<Tensor<B, 2>>>();
        }
        check::<burn_flex::Flex>();
    }

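    #[test]
    fn init_mapper_composes_with_lazy_init() {
        // A sketch of the lazy-mapping contract, reusing the `burn_flex::Flex`
        // backend from the test above: a mapper registered with `init_mapper`
        // composes with the stored init function and runs when the value is
        // first materialized.
        type B = burn_flex::Flex;
        let device: <B as Backend>::Device = Default::default();

        let param: Param<Tensor<B, 2>> = Param::uninitialized(
            ParamId::new(),
            |device, _require_grad| Tensor::ones([2, 2], device),
            device.clone(),
            false,
            [2, 2].into(),
        )
        .init_mapper(|tensor| tensor + 1.0);

        let expected = Tensor::<B, 2>::ones([2, 2], &device) + 1.0;
        assert_eq!(param.val().to_data(), expected.to_data());
    }
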
    #[cfg(feature = "std")]
    #[test]
    fn param_concurrent_lazy_init() {
        use alloc::vec::Vec;

        type B = burn_flex::Flex;
        let device = Default::default();

        let param: Param<Tensor<B, 2>> = Param::uninitialized(
            ParamId::new(),
            |device, _require_grad| Tensor::zeros([2, 3], device),
            device,
            false,
            [2, 3].into(),
        );

        // All threads race to initialize the same lazy parameter; the init
        // function must run at most once, so every thread observes the same
        // value.
        std::thread::scope(|s| {
            let handles: Vec<_> = (0..4).map(|_| s.spawn(|| param.val())).collect();

            let results: Vec<_> = handles.into_iter().map(|h| h.join().unwrap()).collect();

            let expected = results[0].to_data();
            for result in &results[1..] {
                assert_eq!(result.to_data(), expected);
            }
        });
    }
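
    #[test]
    fn param_mapper_applies_on_load_and_save() {
        // A sketch of the load/save mapping contract: `on_load` and `on_save`
        // forward to the closures registered with `load_mapper` and
        // `save_mapper`. The scaling factors are arbitrary illustrations, not
        // part of the API.
        type B = burn_flex::Flex;
        let device: <B as Backend>::Device = Default::default();

        let tensor = Tensor::<B, 2>::ones([2, 2], &device);
        let param = Param::initialized(ParamId::new(), tensor.clone())
            .load_mapper(|t| t * 2.0)
            .save_mapper(|t| t * 0.5);

        let loaded = param.param_mapper.on_load(tensor.clone());
        assert_eq!(loaded.to_data(), (tensor.clone() * 2.0).to_data());

        let saved = param.param_mapper.on_save(tensor.clone());
        assert_eq!(saved.to_data(), (tensor * 0.5).to_data());
    }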
}