use super::ParamId;
use alloc::{boxed::Box, format};
use burn_std::stub::RwLock;
use burn_tensor::Shape;
use core::cell::OnceCell;
use core::ops::Deref;
#[cfg(target_has_atomic = "ptr")]
use alloc::sync::Arc;
#[cfg(not(target_has_atomic = "ptr"))]
use portable_atomic_util::Arc;
/// Shared handle to a `Send + Sync` function mapping a `T` to another `T`.
#[cfg(target_has_atomic = "ptr")]
type Mapper<T> = Arc<dyn Fn(T) -> T + Send + Sync>;

/// Variant used on targets without pointer-sized atomics: the function is
/// boxed before being stored in the `Arc`.
#[cfg(not(target_has_atomic = "ptr"))]
type Mapper<T> = Arc<Box<dyn Fn(T) -> T + Send + Sync>>;

/// Wraps `func` into a [`Mapper`].
#[cfg(target_has_atomic = "ptr")]
fn new_mapper<T, F>(func: F) -> Mapper<T>
where
    F: Fn(T) -> T + Send + Sync + 'static,
{
    Arc::new(func)
}

/// Wraps `func` into a [`Mapper`], boxing it first.
#[cfg(not(target_has_atomic = "ptr"))]
fn new_mapper<T, F>(func: F) -> Mapper<T>
where
    F: Fn(T) -> T + Send + Sync + 'static,
{
    let boxed: Box<dyn Fn(T) -> T + Send + Sync> = Box::new(func);
    Arc::new(boxed)
}
/// Container pairing a [`Parameter`] value with an identifier and optional
/// load/save mappers.
///
/// The value is either supplied up front or produced on first access from a
/// stored initialization closure.
pub struct Param<T: Parameter> {
/// Identifier of this parameter.
pub id: ParamId,
/// Cell holding the value; filled at construction or on the first access that needs it.
pub(crate) state: OnceCell<T>,
/// Pending lazy-initialization data. `None` when the parameter was built directly
/// from a value; the inner `Option` is emptied once the initializer has been taken.
pub(crate) initialization: Option<RwLock<Option<Uninitialized<T>>>>,
/// Mapping functions attached to this parameter (see [`ParamMapper`]).
pub(crate) param_mapper: ParamMapper<T>,
}
/// Optional functions applied to a parameter value by [`ParamMapper::on_load`]
/// and [`ParamMapper::on_save`].
#[derive(Clone)]
pub struct ParamMapper<T: Parameter> {
/// Function applied by `on_load`, if any.
load: Option<Mapper<T>>,
/// Function applied by `on_save`, if any.
save: Option<Mapper<T>>,
}
impl<T: Parameter> core::fmt::Debug for ParamMapper<T> {
    /// Prints only whether each mapper is set; the closures themselves are not shown.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let has_load = self.load.is_some();
        let has_save = self.save.is_some();
        write!(f, "ParamMapper {{ load: {}, save: {} }}", has_load, has_save)
    }
}
impl<T: Parameter> ParamMapper<T> {
    /// Runs the load mapper on `param`, or hands `param` back unchanged when
    /// no load mapper is set.
    pub fn on_load(&self, param: T) -> T {
        let Some(mapper) = &self.load else {
            return param;
        };
        mapper(param)
    }

    /// Runs the save mapper on `param`, or hands `param` back unchanged when
    /// no save mapper is set.
    pub fn on_save(&self, param: T) -> T {
        let Some(mapper) = &self.save else {
            return param;
        };
        mapper(param)
    }
}
impl<T: Parameter> Default for ParamMapper<T> {
    /// Builds a mapper set with neither a load nor a save function.
    fn default() -> Self {
        let (load, save) = (None, None);
        Self { load, save }
    }
}
impl<T: Parameter> core::fmt::Display for Param<T> {
    /// Writes the parameter as `Param: <id>`.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {}", self.id);
        f.write_str(&rendered)
    }
}
impl<T: Parameter> core::fmt::Debug for Param<T> {
    /// Writes the parameter id followed by the debug form of its mapper.
    /// The wrapped value itself is not printed.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let rendered = format!("Param: {} - {:?}", self.id, self.param_mapper);
        f.write_str(&rendered)
    }
}
/// Behaviour required from a value stored inside a [`Param`].
pub trait Parameter: Clone + core::fmt::Debug + Send {
/// Device type returned by [`Parameter::device`].
type Device: Clone;
/// Returns the device of this value.
fn device(&self) -> Self::Device;
/// Returns the current require-grad flag of this value.
fn is_require_grad(&self) -> bool;
/// Consumes the value and returns it with the require-grad flag set to `require_grad`.
fn set_require_grad(self, require_grad: bool) -> Self;
}
/// Data needed to build a parameter value lazily.
#[allow(clippy::type_complexity)]
pub(crate) struct Uninitialized<P: Parameter> {
/// One-shot closure that builds the value from a device and a require-grad flag.
init: Box<dyn FnOnce(&P::Device, bool) -> P + Send>,
/// Device handed to `init` when the value is built.
pub(crate) device: P::Device,
/// Require-grad flag handed to `init` when the value is built.
pub(crate) is_require_grad: bool,
/// Shape recorded alongside the initializer. Not read by `initialize`;
/// presumably lets callers inspect the shape before the value exists (TODO confirm).
pub(crate) shape: Shape,
}
impl<P: Parameter> Uninitialized<P> {
    /// Consumes the pending state and runs the stored closure with the
    /// recorded device and require-grad flag, returning the built value.
    fn initialize(self) -> P {
        let Self {
            init,
            device,
            is_require_grad,
            ..
        } = self;
        init(&device, is_require_grad)
    }
}
impl<T: Parameter> Param<T> {
pub fn initialized(id: ParamId, value: T) -> Self {
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper: Default::default(),
}
}
pub fn uninitialized<F>(
id: ParamId,
init: F,
device: T::Device,
is_require_grad: bool,
shape: Shape,
) -> Self
where
F: FnOnce(&T::Device, bool) -> T + Send + 'static,
{
Self {
id,
state: OnceCell::new(),
initialization: Some(RwLock::new(Some(Uninitialized {
init: Box::new(init),
device,
is_require_grad,
shape,
}))),
param_mapper: Default::default(),
}
}
pub fn val(&self) -> T {
self.state
.get_or_init(|| {
let mut result = self
.initialization
.as_ref()
.expect("Should have an initialization when no state provided.")
.write()
.unwrap();
let state = result.take().expect("Should exist when not initialized");
state.initialize()
})
.clone()
}
pub fn is_initialized(&self) -> bool {
self.state.get().is_some()
}
pub fn into_value(self) -> T {
self.consume().1
}
pub fn consume(self) -> (ParamId, T, ParamMapper<T>) {
let tensor = self.val();
core::mem::drop(self.state);
(self.id, tensor, self.param_mapper)
}
pub fn map<F: FnOnce(T) -> T>(self, func: F) -> Self {
let (id, tensor, param_mapper) = self.consume();
let tensor = func(tensor);
Self {
id,
state: OnceCell::from(tensor),
initialization: None,
param_mapper,
}
}
pub fn from_mapped_value(id: ParamId, value: T, param_mapper: ParamMapper<T>) -> Self {
Self {
id,
state: OnceCell::from(value),
initialization: None,
param_mapper,
}
}
pub fn load_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
self.param_mapper.load = Some(new_mapper(func));
self
}
pub fn save_mapper<F: Fn(T) -> T + Send + Sync + 'static>(mut self, func: F) -> Self {
self.param_mapper.save = Some(new_mapper(func));
self
}
pub fn init_mapper<F: FnOnce(T) -> T + Send + 'static>(self, func: F) -> Self
where
T: 'static,
{
let initialization = match &self.initialization {
Some(init) => init,
None => return self.map(func),
};
let mut init = initialization.write().unwrap();
match init.as_mut() {
Some(value) => {
#[allow(clippy::type_complexity)]
let mut prev: Box<dyn FnOnce(&T::Device, bool) -> T + Send> =
Box::new(|_, _| panic!("Fake func to not have null ref."));
core::mem::swap(&mut prev, &mut value.init);
value.init = Box::new(|a, b| {
let tensor = prev(a, b);
func(tensor)
});
core::mem::drop(init);
self
}
None => {
core::mem::drop(init);
self.map(func)
}
}
}
pub fn lazy_device(&self) -> T::Device {
let initialization = match &self.initialization {
Some(init) => init,
None => return self.device(),
};
let init = initialization.read().unwrap();
match init.as_ref() {
Some(value) => value.device.clone(),
None => self.device(),
}
}
pub(crate) fn lazy_is_require_grad(&self) -> bool {
let initialization = match &self.initialization {
Some(init) => init,
None => return self.is_require_grad(),
};
let init = initialization.read().unwrap();
match init.as_ref() {
Some(value) => value.is_require_grad,
None => self.is_require_grad(),
}
}
pub fn set_require_grad(self, require_grad: bool) -> Self {
let initialization = match &self.initialization {
Some(init) => init,
None => return self.map(|tensor| tensor.set_require_grad(require_grad)),
};
let mut init = initialization.write().unwrap();
let mut is_lazy = false;
if let Some(value) = init.as_mut() {
is_lazy = true;
value.is_require_grad = require_grad;
};
core::mem::drop(init);
if is_lazy {
return self;
}
self.map(|tensor| tensor.set_require_grad(require_grad))
}
}
impl<T: Parameter> Clone for Param<T> {
    /// Clones the parameter into an initialized one with the same id and
    /// mapper. Cloning forces the value to be initialized if it was pending.
    fn clone(&self) -> Self {
        let value = self.val();
        let param_mapper = self.param_mapper.clone();
        Param::from_mapped_value(self.id, value, param_mapper)
    }
}
impl<T: Parameter> Deref for Param<T> {
    type Target = T;

    /// Borrows the value, running the pending initializer on first access.
    ///
    /// # Panics
    ///
    /// Panics when the value cell is empty and there is no initializer left
    /// to run, or when acquiring the write lock fails.
    fn deref(&self) -> &Self::Target {
        let materialize = || {
            let lock = self
                .initialization
                .as_ref()
                .expect("Should have an initialization when no state provided.");
            let mut guard = lock.write().unwrap();
            // Take the initializer out so it can only ever run once.
            let pending = guard.take().expect("Should exist when not initialized");
            pending.initialize()
        };
        self.state.get_or_init(materialize)
    }
}