use std::collections::HashMap;
use super::adapter::BurnModuleAdapter;
use super::de::Deserializer;
use super::error::Error;
use super::ser::Serializer;
use crate::record::{PrecisionSettings, Record};
use crate::tensor::backend::Backend;
use alloc::fmt;
use burn_tensor::Bytes;
use num_traits::cast::ToPrimitive;
use regex::Regex;
use serde::Deserialize;
/// A dynamically-typed value that can nest other values through maps and vectors.
///
/// [`unflatten`] builds a tree of these from dot-separated keys, and
/// [`NestedValue::try_into_record`] deserializes such a tree into a record.
#[derive(Clone)]
pub enum NestedValue {
/// A placeholder carrying an optional string (labelled `origin` in the `Debug` output).
///
/// NOTE(review): presumably marks a value that should be filled with its default,
/// with the string naming where it came from — confirm against the deserializer.
Default(Option<String>),
/// A boolean value.
Bool(bool),
/// An owned string value.
String(String),
/// A 32-bit floating point value.
F32(f32),
/// A 64-bit floating point value.
F64(f64),
/// A signed 16-bit integer value.
I16(i16),
/// A signed 32-bit integer value.
I32(i32),
/// A signed 64-bit integer value.
I64(i64),
/// An unsigned 8-bit integer value.
U8(u8),
/// An unsigned 16-bit integer value.
U16(u16),
/// An unsigned 64-bit integer value.
U64(u64),
/// A map from string keys to nested values.
Map(HashMap<String, NestedValue>),
/// An ordered list of nested values.
Vec(Vec<NestedValue>),
/// A vector of unsigned 8-bit integers.
U8s(Vec<u8>),
/// A vector of unsigned 16-bit integers.
U16s(Vec<u16>),
/// A vector of 32-bit floating point values.
F32s(Vec<f32>),
/// An opaque byte buffer.
Bytes(Bytes),
}
impl NestedValue {
    /// Returns the inner map if this is a [`NestedValue::Map`], otherwise `None`.
    pub fn as_map(self) -> Option<HashMap<String, NestedValue>> {
        match self {
            NestedValue::Map(map) => Some(map),
            _ => None,
        }
    }

    /// Returns the inner boolean if this is a [`NestedValue::Bool`], otherwise `None`.
    pub fn as_bool(self) -> Option<bool> {
        match self {
            NestedValue::Bool(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the inner string if this is a [`NestedValue::String`], otherwise `None`.
    pub fn as_string(self) -> Option<String> {
        match self {
            NestedValue::String(value) => Some(value),
            _ => None,
        }
    }

    /// Returns the value as an `f32`.
    ///
    /// Accepts `F32` directly and `F64` through `ToPrimitive::to_f32`; every other
    /// variant yields `None`.
    pub fn as_f32(self) -> Option<f32> {
        match self {
            NestedValue::F32(value) => Some(value),
            NestedValue::F64(value) => value.to_f32(),
            _ => None,
        }
    }

    /// Returns the value as an `f64`.
    ///
    /// Accepts `F64` directly and `F32` through `ToPrimitive::to_f64`; every other
    /// variant yields `None`.
    pub fn as_f64(self) -> Option<f64> {
        match self {
            NestedValue::F64(value) => Some(value),
            NestedValue::F32(value) => value.to_f64(),
            _ => None,
        }
    }

    /// Returns the value as an `i16`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// does not fit in an `i16` or when the variant is not an integer.
    pub fn as_i16(self) -> Option<i16> {
        match self {
            NestedValue::I16(value) => Some(value),
            NestedValue::I32(value) => value.to_i16(),
            NestedValue::I64(value) => value.to_i16(),
            // Widening from `u8` is lossless, so it is accepted like the other integers.
            NestedValue::U8(value) => value.to_i16(),
            NestedValue::U16(value) => value.to_i16(),
            NestedValue::U64(value) => value.to_i16(),
            _ => None,
        }
    }

    /// Returns the value as an `i32`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// does not fit in an `i32` or when the variant is not an integer.
    pub fn as_i32(self) -> Option<i32> {
        match self {
            NestedValue::I32(value) => Some(value),
            NestedValue::I16(value) => value.to_i32(),
            NestedValue::I64(value) => value.to_i32(),
            // Widening from `u8` is lossless, so it is accepted like the other integers.
            NestedValue::U8(value) => value.to_i32(),
            NestedValue::U16(value) => value.to_i32(),
            NestedValue::U64(value) => value.to_i32(),
            _ => None,
        }
    }

    /// Returns the value as an `i64`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// does not fit in an `i64` or when the variant is not an integer.
    pub fn as_i64(self) -> Option<i64> {
        match self {
            NestedValue::I64(value) => Some(value),
            NestedValue::I16(value) => value.to_i64(),
            NestedValue::I32(value) => value.to_i64(),
            // Widening from `u8` is lossless, so it is accepted like the other integers.
            NestedValue::U8(value) => value.to_i64(),
            NestedValue::U16(value) => value.to_i64(),
            NestedValue::U64(value) => value.to_i64(),
            _ => None,
        }
    }

    /// Returns the value as a `u8`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// does not fit in a `u8` or when the variant is not an integer.
    pub fn as_u8(self) -> Option<u8> {
        match self {
            NestedValue::U8(value) => Some(value),
            NestedValue::I16(value) => value.to_u8(),
            NestedValue::I32(value) => value.to_u8(),
            NestedValue::I64(value) => value.to_u8(),
            NestedValue::U16(value) => value.to_u8(),
            NestedValue::U64(value) => value.to_u8(),
            _ => None,
        }
    }

    /// Returns the value as a `u16`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// does not fit in a `u16` or when the variant is not an integer.
    pub fn as_u16(self) -> Option<u16> {
        match self {
            NestedValue::U16(value) => Some(value),
            NestedValue::I16(value) => value.to_u16(),
            NestedValue::I32(value) => value.to_u16(),
            NestedValue::I64(value) => value.to_u16(),
            // Widening from `u8` is lossless, so it is accepted like the other integers.
            NestedValue::U8(value) => value.to_u16(),
            NestedValue::U64(value) => value.to_u16(),
            _ => None,
        }
    }

    /// Returns the value as a `u64`.
    ///
    /// Any integer variant is accepted; the result is `None` when the stored value
    /// is negative or when the variant is not an integer.
    pub fn as_u64(self) -> Option<u64> {
        match self {
            NestedValue::U64(value) => Some(value),
            NestedValue::I16(value) => value.to_u64(),
            NestedValue::I32(value) => value.to_u64(),
            NestedValue::I64(value) => value.to_u64(),
            // Widening from `u8` is lossless, so it is accepted like the other integers.
            NestedValue::U8(value) => value.to_u64(),
            NestedValue::U16(value) => value.to_u64(),
            _ => None,
        }
    }

    /// Returns the value as [`Bytes`].
    ///
    /// Accepts `Bytes` directly and `U8s` by wrapping the vector with
    /// `Bytes::from_elems`; every other variant yields `None`.
    pub fn as_bytes(self) -> Option<Bytes> {
        match self {
            NestedValue::Bytes(bytes) => Some(bytes),
            NestedValue::U8s(elems) => Some(Bytes::from_elems(elems)),
            _ => None,
        }
    }

    /// Deserializes this nested value into a record of type `T`.
    ///
    /// The value is first deserialized into `T::Item` through a [`Deserializer`]
    /// parameterized by the adapter `A`, and the record is then built with
    /// `T::from_item::<PS>` on `device`.
    ///
    /// # Errors
    ///
    /// Returns the [`Error`] produced by deserialization when the nested value does
    /// not match the shape expected by `T::Item`.
    pub fn try_into_record<T, PS, A, B>(self, device: &B::Device) -> Result<T, Error>
    where
        B: Backend,
        T: Record<B>,
        PS: PrecisionSettings,
        A: BurnModuleAdapter,
    {
        // NOTE(review): the meaning of the `false` flag is defined by `Deserializer::new`,
        // which is not visible from this file — confirm before changing it.
        let deserializer = Deserializer::<A>::new(self, false);
        let item = T::Item::deserialize(deserializer)?;
        Ok(T::from_item::<PS>(item, device))
    }
}
/// Renames the keys of `tensors` by applying every `(pattern, replacement)` rule in order.
///
/// Each rule is applied to the result of the previous one, and only when its pattern
/// matches the current name. When `key_remap` is empty the map is returned untouched.
///
/// Returns the remapped map together with a list of `(new_name, original_name)` pairs.
/// If two original keys are rewritten to the same new name, the later insertion
/// replaces the earlier one in the returned map.
pub fn remap<T>(
    mut tensors: HashMap<String, T>,
    key_remap: Vec<(Regex, String)>,
) -> (HashMap<String, T>, Vec<(String, String)>) {
    // Nothing to rewrite: every key maps to itself.
    if key_remap.is_empty() {
        let identity = tensors
            .keys()
            .map(|key| (key.clone(), key.clone()))
            .collect();
        return (tensors, identity);
    }

    let mut renamed_tensors = HashMap::new();
    let mut name_pairs = Vec::new();

    for (original, tensor) in tensors.drain() {
        // Thread the name through each rule, rewriting it whenever the rule matches.
        let renamed = key_remap
            .iter()
            .fold(original.clone(), |current, (pattern, replacement)| {
                if pattern.is_match(&current) {
                    pattern
                        .replace_all(&current, replacement.as_str())
                        .to_string()
                } else {
                    current
                }
            });

        name_pairs.push((renamed.clone(), original));
        renamed_tensors.insert(renamed, tensor);
    }

    (renamed_tensors, name_pairs)
}
/// Stores `value` inside `current` at the location described by `keys`.
///
/// Each key selects a child: a map entry when `current` is a map, or an index when
/// `current` is a vector. Missing map entries are created as a vector when the
/// following key parses as `usize`, and as a map otherwise. Vectors are grown with
/// empty maps so that the requested index exists. An empty `keys` slice overwrites
/// `current` with `value`.
///
/// # Panics
///
/// Panics if a key addressing a vector does not parse as `usize`, or if there are
/// keys left while `current` is neither a map nor a vector.
fn insert_nested_value(current: &mut NestedValue, keys: &[&str], value: NestedValue) {
    let (head, rest) = match keys.split_first() {
        Some(parts) => parts,
        None => {
            // End of the path: this is where the value belongs.
            *current = value;
            return;
        }
    };

    match current {
        NestedValue::Map(map) => {
            let child = map.entry((*head).to_string()).or_insert_with(|| {
                // A numeric next key means the child is addressed by index.
                let next_is_index = rest
                    .first()
                    .map_or(false, |key| key.parse::<usize>().is_ok());
                if next_is_index {
                    NestedValue::Vec(Vec::new())
                } else {
                    NestedValue::Map(HashMap::new())
                }
            });
            insert_nested_value(child, rest, value);
        }
        NestedValue::Vec(vec) => {
            let index = head.parse::<usize>().unwrap();
            if vec.len() <= index {
                vec.resize_with(index + 1, || NestedValue::Map(HashMap::new()));
            }
            insert_nested_value(&mut vec[index], rest, value);
        }
        _ => panic!("Invalid structure encountered"),
    }
}
/// A value that can be converted into a [`NestedValue`] using a [`Serializer`].
///
/// [`unflatten`] requires the values of its input map to implement this trait.
pub trait Serializable {
/// Serializes `self` into a [`NestedValue`] using the given `serializer`.
///
/// `PS` supplies the [`PrecisionSettings`] available to the implementation.
/// NOTE(review): how `PS` influences the output is up to each implementor — confirm there.
///
/// # Errors
///
/// Returns an [`Error`] when the implementation fails to serialize the value.
fn serialize<PS>(&self, serializer: Serializer) -> Result<NestedValue, Error>
where
PS: PrecisionSettings;
}
pub fn unflatten<PS, T>(input: HashMap<String, T>) -> Result<NestedValue, Error>
where
PS: PrecisionSettings,
T: Serializable,
{
let mut result = NestedValue::Map(HashMap::new());
for (key, value) in input {
let parts: Vec<&str> = key.split('.').collect();
let st = value.serialize::<PS>(Serializer::new())?;
insert_nested_value(&mut result, &parts, st);
}
cleanup_empty_maps(&mut result);
Ok(result)
}
/// Recursively removes empty maps from every vector in the tree rooted at `current`.
///
/// Children are cleaned before their parent vector is filtered. Empty maps that are
/// values of a map (rather than elements of a vector) are left in place.
fn cleanup_empty_maps(current: &mut NestedValue) {
    match current {
        NestedValue::Map(children) => {
            for child in children.values_mut() {
                cleanup_empty_maps(child);
            }
        }
        NestedValue::Vec(items) => {
            for item in items.iter_mut() {
                cleanup_empty_maps(item);
            }
            // Keep everything except maps that hold no entries.
            items.retain(|item| match item {
                NestedValue::Map(entries) => !entries.is_empty(),
                _ => true,
            });
        }
        _ => {}
    }
}
/// Writes a shortened debug representation of `vec` to `f`.
///
/// At most the first three elements are printed, followed by `, ...` when further
/// elements were omitted, and finally the total length. For example a five element
/// slice is rendered as `Vec([1, 2, 3, ...] len=5)`, while a slice with three or
/// fewer elements is printed in full without the ellipsis.
///
/// # Errors
///
/// Propagates any error returned by the underlying formatter.
fn write_vec_truncated<T: fmt::Debug>(vec: &[T], f: &mut fmt::Formatter<'_>) -> fmt::Result {
    /// Number of leading elements shown before the output is cut off.
    const MAX_SHOWN: usize = 3;

    write!(f, "Vec([")?;
    for (i, v) in vec.iter().take(MAX_SHOWN).enumerate() {
        if i > 0 {
            write!(f, ", ")?;
        }
        write!(f, "{v:?}")?;
    }
    // Only claim truncation when elements were actually left out; this also avoids
    // a dangling separator for an empty slice.
    if vec.len() > MAX_SHOWN {
        write!(f, ", ...")?;
    }
    write!(f, "] len={})", vec.len())
}
/// Debug formatting that keeps large sequences short.
///
/// Scalars print as a tuple named after their variant, maps print as debug maps, and
/// sequence-like variants print as debug lists unless they hold more than three
/// elements, in which case only a truncated preview is written.
impl fmt::Debug for NestedValue {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::Default(origin) => f.debug_tuple("Default").field(origin).finish(),
            Self::Bool(flag) => f.debug_tuple("Bool").field(flag).finish(),
            Self::String(text) => f.debug_tuple("String").field(text).finish(),
            Self::F32(number) => f.debug_tuple("F32").field(number).finish(),
            Self::F64(number) => f.debug_tuple("F64").field(number).finish(),
            Self::I16(number) => f.debug_tuple("I16").field(number).finish(),
            Self::I32(number) => f.debug_tuple("I32").field(number).finish(),
            Self::I64(number) => f.debug_tuple("I64").field(number).finish(),
            Self::U8(number) => f.debug_tuple("U8").field(number).finish(),
            Self::U16(number) => f.debug_tuple("U16").field(number).finish(),
            Self::U64(number) => f.debug_tuple("U64").field(number).finish(),
            Self::Map(entries) => f.debug_map().entries(entries.iter()).finish(),
            // Sequences: show a preview once they exceed three elements.
            Self::Vec(items) => {
                if items.len() > 3 {
                    write_vec_truncated(items, f)
                } else {
                    f.debug_list().entries(items.iter()).finish()
                }
            }
            Self::U8s(items) => {
                if items.len() > 3 {
                    write_vec_truncated(items, f)
                } else {
                    f.debug_list().entries(items.iter()).finish()
                }
            }
            Self::U16s(items) => {
                if items.len() > 3 {
                    write_vec_truncated(items, f)
                } else {
                    f.debug_list().entries(items.iter()).finish()
                }
            }
            Self::F32s(items) => {
                if items.len() > 3 {
                    write_vec_truncated(items, f)
                } else {
                    f.debug_list().entries(items.iter()).finish()
                }
            }
            Self::Bytes(bytes) => {
                if bytes.len() > 3 {
                    write_vec_truncated(bytes, f)
                } else {
                    f.debug_list().entries(bytes.iter()).finish()
                }
            }
        }
    }
}