use super::ParamId;
use alloc::boxed::Box;
use alloc::format;
use burn_common::stub::RwLock;
use core::cell::OnceCell;
use core::ops::Deref;
/// A module parameter whose value is either provided up front or produced lazily on first access.
pub struct Param<T>
where
    T: Parameter,
{
    /// Identifier of this parameter.
    pub id: ParamId,
    /// Holds the value once it has been provided or lazily materialized.
    state: OnceCell<T>,
    /// Pending lazy initializer; `None` when the parameter was constructed already initialized.
    /// The inner `Option` is taken (left as `None`) once the initializer has run.
    initialization: Option<RwLock<Option<Uninitialized<T>>>>,
}
impl<T> core::fmt::Display for Param<T>
where
    T: Parameter,
{
    /// Renders the parameter as `Param: <id>`; the value itself is never touched,
    /// so formatting does not trigger lazy initialization.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {}", self.id);
        f.write_str(&rendered)
    }
}
impl<T> core::fmt::Debug for Param<T>
where
    T: Parameter,
{
    /// Same text as the `Display` form (`Param: <id>`); the value is not read,
    /// so debugging output does not force lazy initialization.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str(&format!("Param: {}", self.id))
    }
}
/// Behavior a value must provide to be stored inside a [`Param`].
pub trait Parameter
where
    Self: Clone + core::fmt::Debug + Send,
{
    /// Device type the value lives on.
    type Device: Clone;

    /// Returns the device of this value.
    fn device(&self) -> Self::Device;

    /// Returns whether gradients are required for this value.
    fn is_require_grad(&self) -> bool;

    /// Consumes the value and returns it with the gradient requirement set to `require_grad`.
    fn set_require_grad(self, require_grad: bool) -> Self;
}
/// A deferred initializer together with the arguments it will be called with.
// The boxed `FnOnce` field type trips clippy's type-complexity lint; it is kept inline
// because it is used in exactly one place.
#[allow(clippy::type_complexity)]
struct Uninitialized<P>
where
    P: Parameter,
{
    /// Builds the value from a device and a gradient-requirement flag.
    init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
    /// Device passed to `init`.
    device: P::Device,
    /// Gradient requirement passed to `init`; may be updated before initialization runs.
    is_require_grad: bool,
}
impl<P> Uninitialized<P>
where
    P: Parameter,
{
    /// Consumes the pending initializer and runs it with the stored device and flag.
    fn initialize(self) -> P {
        let Self {
            init,
            device,
            is_require_grad,
        } = self;
        init(&device, is_require_grad)
    }
}
impl<T: Parameter> Param<T> {
    /// Creates a parameter whose value is already available.
    pub fn initialized(id: ParamId, value: T) -> Self {
        Self {
            id,
            state: OnceCell::from(value),
            initialization: None,
        }
    }

    /// Creates a parameter whose value is produced by `init(&device, is_require_grad)`
    /// the first time it is accessed.
    pub fn uninitialized<F>(id: ParamId, init: F, device: T::Device, is_require_grad: bool) -> Self
    where
        F: FnOnce(&T::Device, bool) -> T + Send + 'static,
    {
        Self {
            id,
            state: OnceCell::new(),
            initialization: Some(RwLock::new(Some(Uninitialized {
                init: Box::new(init),
                device,
                is_require_grad,
            }))),
        }
    }

    /// Returns a clone of the value, running lazy initialization first if it is still pending.
    ///
    /// # Panics
    ///
    /// Panics under the same conditions as dereferencing the parameter (missing initializer,
    /// or a poisoned initialization lock).
    pub fn val(&self) -> T {
        // Reuse the `Deref` implementation so the lazy-initialization logic lives in one place.
        (**self).clone()
    }

    /// Consumes the parameter and returns its value.
    pub fn into_value(self) -> T {
        self.consume().1
    }

    /// Consumes the parameter and returns its id and value.
    ///
    /// Any pending lazy initialization runs first. The value is then moved out of the
    /// cell rather than cloned.
    pub fn consume(self) -> (ParamId, T) {
        // Force lazy initialization (if still pending) so the cell is guaranteed to be filled.
        let _ = self.deref();
        let Self { id, state, .. } = self;
        let tensor = state
            .into_inner()
            .expect("State should be initialized after forcing lazy initialization.");
        (id, tensor)
    }

    /// Applies `func` to the value and returns a new, eagerly initialized parameter
    /// with the same id.
    pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
        let (id, tensor) = self.consume();
        Self::initialized(id, func(tensor))
    }

    /// Returns the device without forcing initialization when it is still pending.
    ///
    /// Falls back to the materialized value's device otherwise.
    pub(crate) fn lazy_device(&self) -> T::Device {
        if let Some(initialization) = &self.initialization {
            if let Some(pending) = initialization.read().unwrap().as_ref() {
                return pending.device.clone();
            }
        }
        // The read guard above is already released here, so the fallback never
        // runs while this thread holds the lock.
        self.device()
    }

    /// Returns the gradient requirement without forcing initialization when it is still pending.
    ///
    /// Falls back to the materialized value's flag otherwise.
    pub(crate) fn lazy_is_require_grad(&self) -> bool {
        if let Some(initialization) = &self.initialization {
            if let Some(pending) = initialization.read().unwrap().as_ref() {
                return pending.is_require_grad;
            }
        }
        // The read guard above is already released here.
        self.is_require_grad()
    }

    /// Sets the gradient requirement.
    ///
    /// If initialization is still pending, only the stored flag is updated and the value
    /// stays lazy. Otherwise the materialized value is updated through [`Parameter::set_require_grad`].
    pub fn set_require_grad(self, require_grad: bool) -> Self {
        // Evaluate inside a block so the write guard (which borrows `self`) is released
        // before `self` is moved below.
        let updated_lazily = match &self.initialization {
            Some(initialization) => {
                let mut guard = initialization.write().unwrap();
                match guard.as_mut() {
                    Some(pending) => {
                        pending.is_require_grad = require_grad;
                        true
                    }
                    None => false,
                }
            }
            None => false,
        };

        if updated_lazily {
            self
        } else {
            self.map(|tensor| tensor.set_require_grad(require_grad))
        }
    }
}
impl<T> Clone for Param<T>
where
    T: Parameter,
{
    /// Clones into an eagerly initialized parameter with the same id.
    ///
    /// Cloning a lazy parameter forces its initialization first.
    fn clone(&self) -> Self {
        let value = self.val();
        Self::initialized(self.id, value)
    }
}
impl<T> Deref for Param<T>
where
    T: Parameter,
{
    type Target = T;

    /// Borrows the value, running the pending initializer on first access.
    ///
    /// The initializer runs while the initialization write lock is held.
    ///
    /// # Panics
    ///
    /// Panics if no value and no initializer are present, or if the lock is poisoned.
    fn deref(&self) -> &T {
        self.state.get_or_init(|| {
            let lock = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.");
            let mut slot = lock.write().unwrap();
            slot.take()
                .expect("Should exist when not initialized")
                .initialize()
        })
    }
}