1use super::ParamId;
2use alloc::{boxed::Box, format};
3use burn_common::stub::RwLock;
4use burn_tensor::Shape;
5use core::cell::OnceCell;
6use core::ops::Deref;
7
8#[cfg(target_has_atomic = "ptr")]
9use alloc::sync::Arc;
10
11#[cfg(not(target_has_atomic = "ptr"))]
12use portable_atomic_util::Arc;
13
/// Shared, thread-safe function that maps a parameter value to a new value of the same type.
///
/// On targets with pointer-sized atomics the closure is stored directly as an
/// `Arc<dyn Fn>` trait object.
#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;

/// Shared, thread-safe function that maps a parameter value to a new value of the same type.
///
/// On targets without pointer-sized atomics `Arc` is the one from `portable_atomic_util`,
/// and the closure is boxed before being wrapped.
// NOTE(review): presumably the extra `Box` exists because that `Arc` cannot be built
// directly around an unsized trait object — confirm.
#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;
19
20#[cfg(target_has_atomic = "ptr")]
21fn new_mapper<T, F: Fn(T) -> T + Send + Sync + 'static>(func: F) -> Mapper<T> {
22 Arc::new(func)
23}
24
/// Wraps `func` in the shared [`Mapper`] representation used on targets without
/// pointer-sized atomics, boxing the closure before placing it in the `Arc`.
#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F>(func: F) -> Mapper<T>
where
    F: Fn(T) -> T + Send + Sync + 'static,
{
    let boxed = Box::new(func);
    Arc::new(boxed)
}
29
/// A parameter: an identifier paired with a value that may be created lazily.
///
/// A `Param` is built either from an existing value ([`Param::initialized`]) or from an
/// initialization function that only runs the first time the value is needed
/// ([`Param::uninitialized`]).
pub struct Param<T: Parameter> {
    /// Identifier of this parameter.
    pub id: ParamId,
    /// The parameter value; empty until a lazily-initialized parameter is first accessed.
    pub(crate) state: OnceCell<T>,
    /// Pending lazy initializer.
    ///
    /// `None` when the parameter was constructed from a value. Otherwise the inner
    /// `Option` holds the initializer until it is taken and run on first access.
    pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
    /// Optional load/save transformations attached to this parameter.
    pub(crate) param_mapper: ParamMapper<T>,
}
60
/// Optional pair of functions that transform a parameter value on load and on save.
///
/// Each function is held behind an `Arc`, so cloning a `ParamMapper` shares the same
/// functions rather than duplicating them.
#[derive(Clone)]
pub struct ParamMapper<T: Parameter> {
    /// Function applied by [`ParamMapper::on_load`], if any.
    load: Option<Mapper<T>>,
    /// Function applied by [`ParamMapper::on_save`], if any.
    save: Option<Mapper<T>>,
}
78
79impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
80 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
81 f.write_fmt(format_args!(
82 "ParamMapper {{ load: {}, save: {} }}",
83 self.load.is_some(),
84 self.save.is_some()
85 ))
86 }
87}
88
89impl<T: Parameter> ParamMapper<T> {
90 pub fn on_load(&self, param: T) -> T {
92 match &self.load {
93 Some(mapper) => mapper(param),
94 None => param,
95 }
96 }
97 pub fn on_save(&self, param: T) -> T {
99 match &self.save {
100 Some(mapper) => mapper(param),
101 None => param,
102 }
103 }
104}
105
106impl<T: Parameter> Default for ParamMapper<T> {
107 fn default() -> Self {
108 Self {
109 load: None,
110 save: None,
111 }
112 }
113}
114
115impl<T: Parameter> core::fmt::Display for Param<T> {
116 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117 f.write_str(format!("Param: {}", self.id).as_str())
118 }
119}
120
121impl<T: Parameter> core::fmt::Debug for Param<T> {
122 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
123 f.write_str(format!("Param: {} - {:?}", self.id, self.param_mapper).as_str())
124 }
125}
126
/// Behaviour required of a value stored inside a [`Param`].
///
/// Implementors must be cloneable, debuggable and sendable across threads.
pub trait Parameter: Clone + core::fmt::Debug + Send {
    /// Device type associated with the parameter.
    type Device: Clone;

    /// Returns the device of this parameter.
    fn device(&self) -> Self::Device;

    /// Returns whether this parameter requires gradients.
    fn is_require_grad(&self) -> bool;

    /// Consumes the parameter and returns it with its gradient requirement set to
    /// `require_grad`.
    fn set_require_grad(self, require_grad: bool) -> Self;
}
141
/// Everything needed to build a parameter value later: the initialization function and
/// the arguments it will be called with.
// NOTE(review): the `allow` is presumably for the boxed `dyn FnOnce` field type below — confirm.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
    /// One-shot function that builds the value from a device and a gradient flag.
    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
    /// Device handed to `init`.
    pub(crate) device: P::Device,
    /// Gradient flag handed to `init`.
    pub(crate) is_require_grad: bool,
    /// Shape recorded for the not-yet-created parameter.
    // NOTE(review): not read anywhere in this file; presumably used by other modules in the
    // crate — confirm.
    pub(crate) shape: Shape,
}
158
159impl<P: Parameter> Uninitialized<P> {
160 fn initialize(self) -> P {
165 let init = self.init;
166 init(&self.device, self.is_require_grad)
167 }
168}
169
170impl<T: Parameter> Param<T> {
171 pub fn initialized(id: ParamId, value: T) -> Self {
173 Self {
174 id,
175 state: OnceCell::from(value),
176 initialization: None,
177 param_mapper: Default::default(),
178 }
179 }
180
181 pub fn uninitialized<F>(
183 id: ParamId,
184 init: F,
185 device: T::Device,
186 is_require_grad: bool,
187 shape: Shape,
188 ) -> Self
189 where
190 F: FnOnce(&T::Device, bool) -> T + Send + 'static,
191 {
192 Self {
193 id,
194 state: OnceCell::new(),
195 initialization: Some(RwLock::new(Some(Uninitialized {
196 init: Box::new(init),
197 device,
198 is_require_grad,
199 shape,
200 }))),
201 param_mapper: Default::default(),
202 }
203 }
204
205 pub fn val(&self) -> T {
210 self.state
211 .get_or_init(|| {
212 let mut result = self
213 .initialization
214 .as_ref()
215 .expect("Should have an initialization when no state provided.")
216 .write()
217 .unwrap();
218 let state = result.take().expect("Should exist when not initialized");
219 state.initialize()
220 })
221 .clone()
222 }
223
224 pub fn is_initialized(&self) -> bool {
229 self.state.get().is_some()
230 }
231
232 pub fn into_value(self) -> T {
234 self.consume().1
235 }
236
237 pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
239 let tensor = self.val();
240
241 core::mem::drop(self.state);
242
243 (self.id, tensor, self.param_mapper)
244 }
245
246 pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
248 let (id, tensor, param_mapper) = self.consume();
249 let tensor = func(tensor);
250
251 Self {
252 id,
253 state: OnceCell::from(tensor),
254 initialization: None,
255 param_mapper,
256 }
257 }
258
259 pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
264 Self {
265 id,
266 state: OnceCell::from(value),
267 initialization: None,
268 param_mapper,
269 }
270 }
271
272 pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
274 self.param_mapper.load = Some(new_mapper(func));
275
276 self
277 }
278
279 pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
281 self.param_mapper.save = Some(new_mapper(func));
282
283 self
284 }
285
286 pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
288 where
289 T: 'static,
290 {
291 let initialization = match &self.initialization {
292 Some(init) => init,
293 None => return self.map(func),
294 };
295
296 let mut init = initialization.write().unwrap();
297
298 match init.as_mut() {
299 Some(value) => {
300 #[allow(clippy::type_complexity)]
301 let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =
302 Box::new(|_, _| panic!("Fake func to not have null ref."));
303 core::mem::swap(&mut prev, &mut value.init);
304
305 value.init = Box::new(|a, b| {
306 let tensor = prev(a, b);
307 func(tensor)
308 });
309 core::mem::drop(init);
310 self
311 }
312 None => {
313 core::mem::drop(init);
314 self.map(func)
315 }
316 }
317 }
318
319 pub fn lazy_device(&self) -> T::Device {
328 let initialization = match &self.initialization {
329 Some(init) => init,
330 None => return self.device(),
331 };
332
333 let init = initialization.read().unwrap();
334
335 match init.as_ref() {
336 Some(value) => value.device.clone(),
337 None => self.device(),
338 }
339 }
340
341 pub(crate) fn lazy_is_require_grad(&self) -> bool {
353 let initialization = match &self.initialization {
354 Some(init) => init,
355 None => return self.is_require_grad(),
356 };
357
358 let init = initialization.read().unwrap();
359
360 match init.as_ref() {
361 Some(value) => value.is_require_grad,
362 None => self.is_require_grad(),
363 }
364 }
365
366 pub fn set_require_grad(self, require_grad: bool) -> Self {
368 let initialization = match &self.initialization {
369 Some(init) => init,
370 None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
371 };
372
373 let mut init = initialization.write().unwrap();
374 let mut is_lazy = false;
375
376 if let Some(value) = init.as_mut() {
377 is_lazy = true;
378 value.is_require_grad = require_grad;
379 };
380
381 core::mem::drop(init);
382
383 if is_lazy {
384 return self;
385 }
386
387 self.map(|tensor| tensor.set_require_grad(require_grad))
388 }
389}
390
391impl<T: Parameter> Clone for Param<T> {
392 fn clone(&self) -> Self {
393 let mut param = Param::initialized(self.id, self.val());
394 param.param_mapper = self.param_mapper.clone();
395 param
396 }
397}
398
399impl<T: Parameter> Deref for Param<T> {
400 type Target = T;
401
402 fn deref(&self) -> &Self::Target {
403 self.state.get_or_init(|| {
404 let mut result = self
405 .initialization
406 .as_ref()
407 .expect("Should have an initialization when no state provided.")
408 .write()
409 .unwrap();
410
411 let state = result.take().expect("Should exist when not initialized");
412 state.initialize()
413 })
414 }
415}