use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use serde::Deserialize;
use serde::Serialize;
use super::{PrecisionSettings, Record};
use crate::module::{Param, ParamId};
use burn_tensor::{DataSerialize, Element};
use hashbrown::HashMap;
/// The unit type carries no state, so its record item is also `()`.
impl Record for () {
    type Item<S: PrecisionSettings> = ();

    /// Returns the unit value unchanged.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        self
    }

    /// Returns the unit value unchanged.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
        item
    }
}
impl<T: Record> Record for Vec<T> {
type Item<S: PrecisionSettings> = Vec<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.into_iter().map(Record::into_item).collect()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.into_iter().map(Record::from_item).collect()
}
}
impl<T: Record> Record for Option<T> {
type Item<S: PrecisionSettings> = Option<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.map(Record::into_item)
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.map(Record::from_item)
}
}
/// A fixed-size array of records is stored as a `Vec` of items and rebuilt into
/// an array of exactly `N` elements when loaded.
impl<const N: usize, T: Record> Record for [T; N] {
    type Item<S: PrecisionSettings> = Vec<T::Item<S>>;

    /// Converts each element with `T::into_item`, keeping the array order.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        self.map(Record::into_item).into_iter().collect()
    }

    /// Rebuilds each element with `T::from_item` and packs them into `[T; N]`.
    ///
    /// # Panics
    ///
    /// Panics if `item` does not contain exactly `N` elements.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
        let records: Vec<T> = item.into_iter().map(T::from_item::<S>).collect();

        // `TryFrom<Vec<T>> for [T; N]` hands the vector back on a length mismatch,
        // so the panic can report the actual count without requiring `T: Debug`.
        records.try_into().unwrap_or_else(|records: Vec<T>| {
            panic!(
                "Expected {N} items to build an array record, got {}",
                records.len()
            )
        })
    }
}
/// A map keyed by `ParamId` is stored with the ids rendered as strings
/// (via `to_string`) and parsed back with `ParamId::from` when loaded.
impl<T: Record> Record for HashMap<ParamId, T> {
    type Item<S: PrecisionSettings> = HashMap<String, T::Item<S>>;

    /// Stringifies every key and converts every value into its item.
    fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
        self.into_iter()
            .map(|(param_id, value)| (param_id.to_string(), value.into_item::<S>()))
            .collect()
    }

    /// Turns every string key back into a `ParamId` and rebuilds every value.
    fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
        item.into_iter()
            .map(|(key, value)| (ParamId::from(key), T::from_item::<S>(value)))
            .collect()
    }
}
impl<E: Element> Record for DataSerialize<E> {
type Item<S: PrecisionSettings> = DataSerialize<S::FloatElem>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
self.convert()
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
item.convert()
}
}
/// Serializable form of a `Param`: its id as a string plus the converted value.
///
/// Field names are used as serialization keys, and field order defines the
/// argument order of the derived `new` constructor — keep both stable.
#[derive(new, Debug, Clone, Serialize, Deserialize)]
pub struct ParamSerde<T> {
    /// String form of the parameter's `ParamId`.
    id: String,
    /// The parameter's value, already converted to its record item.
    param: T,
}
impl<T: Record> Record for Param<T> {
type Item<S: PrecisionSettings> = ParamSerde<T::Item<S>>;
fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
ParamSerde::new(self.id.into_string(), self.value.into_item())
}
fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
Param::new(ParamId::from(item.id), T::from_item(item.param))
}
}
// Implements `Record` for a type that is stored as-is: its item type is the
// type itself, and both conversions are the identity.
macro_rules! primitive {
    ($ty:ty) => {
        impl Record for $ty {
            type Item<S: PrecisionSettings> = Self;

            fn from_item<S: PrecisionSettings>(item: Self::Item<S>) -> Self {
                item
            }

            fn into_item<S: PrecisionSettings>(self) -> Self::Item<S> {
                self
            }
        }
    };
}
// Identity `Record` impls for text and boolean values.
primitive!(alloc::string::String);
primitive!(bool);
// Floating-point primitives; the `half` types are only available with the `std` feature.
primitive!(f64);
primitive!(f32);
#[cfg(feature = "std")]
primitive!(half::bf16);
#[cfg(feature = "std")]
primitive!(half::f16);
// Unsigned integers.
primitive!(usize);
primitive!(u64);
primitive!(u32);
primitive!(u16);
primitive!(u8);
// Signed integers.
primitive!(i64);
primitive!(i32);
primitive!(i16);
primitive!(i8);