use alloc::rc::Rc;
use alloc::string::String;
use alloc::string::ToString;
use alloc::vec::Vec;
use burn_core::module::ParamId;
use burn_tensor::quantization::{QPARAM_ALIGN, QuantParam, params_shape};
use burn_tensor::{Bool, DType, Int, Shape, Tensor, TensorData, backend::Backend};
use half::f16;
/// Size in bytes of a single quantization parameter stored with the given encoding.
///
/// `F32` uses 4 bytes, the 16-bit float encodings (`F16`, `BF16`) use 2 bytes and the
/// 8-bit encodings (`UE8M0`, `UE4M3`) use 1 byte.
const fn quant_param_size(param: QuantParam) -> usize {
    use core::mem::size_of;

    match param {
        // 8-bit encodings.
        QuantParam::UE8M0 | QuantParam::UE4M3 => size_of::<u8>(),
        // 16-bit floats; BF16 shares the width of `f16`.
        QuantParam::F16 | QuantParam::BF16 => size_of::<f16>(),
        // Full-precision float.
        QuantParam::F32 => size_of::<f32>(),
    }
}
/// Error returned when materialising the data behind a tensor snapshot fails.
///
/// Every variant carries a human-readable description. `PartialEq`/`Eq` are derived
/// so that callers (and tests) can compare errors directly instead of pattern-matching.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum TensorSnapshotError {
    /// An I/O failure; displayed as `I/O error: <msg>`.
    IoError(String),
    /// A data-related failure; displayed as `Data error: <msg>`.
    DataError(String),
    /// The data-producing closure panicked; displayed as `Panic error: <msg>`.
    PanicError(String),
}
/// Formats the error as `<category>: <message>`.
impl core::fmt::Display for TensorSnapshotError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Map each variant to its category label, then format once.
        let (label, detail) = match self {
            Self::IoError(msg) => ("I/O error", msg),
            Self::DataError(msg) => ("Data error", msg),
            Self::PanicError(msg) => ("Panic error", msg),
        };
        write!(f, "{}: {}", label, detail)
    }
}
/// Marker impl so `TensorSnapshotError` can be used as a standard error type
/// (it relies on the `Debug` derive and the `Display` impl for its output).
impl core::error::Error for TensorSnapshotError {}
/// A lazily materialised view of a tensor plus metadata locating it in a module tree.
///
/// Metadata (`dtype`, `shape`, paths, id) is stored eagerly. The tensor data itself is
/// only produced when the private `data_fn` closure is invoked (via `to_data`).
pub struct TensorSnapshot {
    // Closure that produces the tensor data on demand; shared (not re-created) on clone.
    data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
    /// Element type of the tensor.
    pub dtype: burn_tensor::DType,
    /// Shape of the tensor.
    pub shape: Shape,
    /// Path segments identifying the tensor; `full_path` joins them with `.`.
    pub path_stack: Option<Vec<String>>,
    /// Container type names along the path; `container_path` joins them with `.`.
    pub container_stack: Option<Vec<String>>,
    /// Parameter id associated with the tensor, if any.
    pub tensor_id: Option<ParamId>,
}
impl TensorSnapshot {
    /// Creates a snapshot of a float tensor.
    ///
    /// `dtype` and `shape` are read immediately. A clone of `tensor` is captured and its
    /// data is only read back when [`TensorSnapshot::to_data`] is called.
    ///
    /// * `path_stack` - path segments identifying the tensor.
    /// * `container_stack` - container type names along that path.
    /// * `tensor_id` - parameter id associated with the tensor.
    pub fn from_float<B: Backend, const D: usize>(
        tensor: &Tensor<B, D>,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        let dtype = tensor.dtype();
        let shape = tensor.shape();
        let tensor = tensor.clone();
        Self::from_closure(
            Rc::new(move || Ok(tensor.to_data())),
            dtype,
            shape,
            path_stack,
            container_stack,
            tensor_id,
        )
    }

    /// Creates a snapshot of an integer tensor.
    ///
    /// Same semantics as [`TensorSnapshot::from_float`], for `Int` tensors.
    pub fn from_int<B: Backend, const D: usize>(
        tensor: &Tensor<B, D, Int>,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        let dtype = tensor.dtype();
        let shape = tensor.shape();
        let tensor = tensor.clone();
        Self::from_closure(
            Rc::new(move || Ok(tensor.to_data())),
            dtype,
            shape,
            path_stack,
            container_stack,
            tensor_id,
        )
    }

    /// Creates a snapshot of a boolean tensor.
    ///
    /// Same semantics as [`TensorSnapshot::from_float`], for `Bool` tensors.
    pub fn from_bool<B: Backend, const D: usize>(
        tensor: &Tensor<B, D, Bool>,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        let dtype = tensor.dtype();
        let shape = tensor.shape();
        let tensor = tensor.clone();
        Self::from_closure(
            Rc::new(move || Ok(tensor.to_data())),
            dtype,
            shape,
            path_stack,
            container_stack,
            tensor_id,
        )
    }

    /// Materialises the tensor data by invoking the stored closure.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the closure. If the closure panics, the panic is
    /// caught and returned as [`TensorSnapshotError::PanicError`]. The message starts with
    /// `"Panic occurred while loading tensor data"` and, when the panic payload is a
    /// string, is followed by `": <payload>"`.
    #[cfg(feature = "std")]
    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
        std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| (self.data_fn)())).unwrap_or_else(
            |payload| {
                let mut msg = String::from("Panic occurred while loading tensor data");
                // `panic!("literal")` yields a `&'static str` payload and formatted panics
                // yield a `String`; keep whichever is present so the cause isn't lost.
                let detail = payload
                    .downcast_ref::<&str>()
                    .copied()
                    .or_else(|| payload.downcast_ref::<String>().map(String::as_str));
                if let Some(detail) = detail {
                    msg.push_str(": ");
                    msg.push_str(detail);
                }
                Err(TensorSnapshotError::PanicError(msg))
            },
        )
    }

    /// Materialises the tensor data by invoking the stored closure.
    ///
    /// Without the `std` feature panics cannot be caught, so a panicking closure
    /// propagates to the caller.
    ///
    /// # Errors
    ///
    /// Returns any error produced by the closure.
    #[cfg(not(feature = "std"))]
    pub fn to_data(&self) -> Result<TensorData, TensorSnapshotError> {
        (self.data_fn)()
    }

    /// Returns the path segments joined with `.`, or an empty string if there is no path.
    pub fn full_path(&self) -> String {
        self.path_stack
            .as_ref()
            .map(|stack| stack.join("."))
            .unwrap_or_default()
    }

    /// Returns the container types joined with `.`, or an empty string if there are none.
    pub fn container_path(&self) -> String {
        self.container_stack
            .as_ref()
            .map(|stack| stack.join("."))
            .unwrap_or_default()
    }

    /// Returns the innermost container entry starting with `"Struct:"` or `"Enum:"`.
    ///
    /// Entries that match neither prefix (e.g. `"Vec"`) are skipped. Returns `None` if no
    /// entry matches or there is no container stack.
    pub fn module_type(&self) -> Option<String> {
        self.container_stack.as_ref().and_then(|stack| {
            stack
                .iter()
                .rev()
                .find(|ct| ct.starts_with("Struct:") || ct.starts_with("Enum:"))
                .cloned()
        })
    }

    /// Returns the innermost container entry, or `"Unknown"` if the stack is absent or empty.
    pub fn container_type(&self) -> String {
        self.container_stack
            .as_ref()
            .and_then(|stack| stack.last())
            .cloned()
            .unwrap_or_else(|| "Unknown".to_string())
    }

    /// Creates a snapshot from an arbitrary data-producing closure.
    ///
    /// `dtype` and `shape` are trusted as given; they are not checked against what
    /// `data_fn` later returns.
    pub fn from_closure(
        data_fn: Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>>,
        dtype: burn_tensor::DType,
        shape: Shape,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        Self {
            data_fn,
            dtype,
            shape,
            path_stack: Some(path_stack),
            container_stack: Some(container_stack),
            tensor_id: Some(tensor_id),
        }
    }

    /// Creates a snapshot that owns already-materialised `data`.
    ///
    /// `dtype` and `shape` are taken from `data`. Each call to `to_data` returns a clone.
    pub fn from_data(
        data: TensorData,
        path_stack: Vec<String>,
        container_stack: Vec<String>,
        tensor_id: ParamId,
    ) -> Self {
        let dtype = data.dtype;
        let shape = data.shape.clone();
        Self::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape,
            path_stack,
            container_stack,
            tensor_id,
        )
    }

    /// Returns the expected size in bytes of the materialised data, computed from
    /// `dtype` and `shape` alone (the data closure is not invoked).
    ///
    /// For quantized floats the size has two parts. The first is the packed values,
    /// rounded up to a multiple of `QPARAM_ALIGN`. The second is one quantization
    /// parameter per entry of `params_shape(shape, level)`. For other dtypes it is
    /// `num_elements * dtype.size()`.
    pub fn data_len(&self) -> usize {
        const BITS_PER_BYTE: usize = 8;
        let num_elements: usize = self.shape.iter().product();
        match self.dtype {
            DType::QFloat(scheme) => {
                let num_storage_elements = num_elements.div_ceil(scheme.num_quants());
                // Count bits first and round up to whole bytes. Dividing the storage width
                // by 8 first would truncate sub-byte storage widths to 0 bytes.
                let value_bytes =
                    (num_storage_elements * scheme.size_bits_stored()).div_ceil(BITS_PER_BYTE);
                let aligned_value_bytes = value_bytes.div_ceil(QPARAM_ALIGN) * QPARAM_ALIGN;
                let num_params = params_shape(&self.shape, scheme.level).num_elements();
                let scale_bytes = num_params * quant_param_size(scheme.param);
                aligned_value_bytes + scale_bytes
            }
            _ => num_elements * self.dtype.size(),
        }
    }

    /// Returns a shared handle to the data closure (a reference-count bump, not a copy).
    pub fn clone_data_fn(&self) -> Rc<dyn Fn() -> Result<TensorData, TensorSnapshotError>> {
        self.data_fn.clone()
    }
}
impl Clone for TensorSnapshot {
fn clone(&self) -> Self {
Self {
data_fn: self.data_fn.clone(),
dtype: self.dtype,
shape: self.shape.clone(),
path_stack: self.path_stack.clone(),
container_stack: self.container_stack.clone(),
tensor_id: self.tensor_id,
}
}
}
/// Debug output lists the snapshot metadata; the opaque data closure is omitted.
impl core::fmt::Debug for TensorSnapshot {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let mut out = f.debug_struct("TensorSnapshot");
        out.field("dtype", &self.dtype);
        out.field("shape", &self.shape);
        out.field("path_stack", &self.path_stack);
        out.field("container_stack", &self.container_stack);
        out.field("tensor_id", &self.tensor_id);
        out.finish()
    }
}
// Unit tests for `TensorSnapshot`; compiled only for tests with the `std` feature
// (the panic-catching `to_data` variant requires `std`).
#[cfg(all(test, feature = "std"))]
mod tests {
    use super::*;
    // Backend used to build concrete tensors in these tests.
    type TestBackend = burn_flex::Flex;
    use alloc::string::ToString;
    use burn_tensor::{BoolStore, DType, shape};

    // Float tensor: metadata is captured eagerly and paths are joined with '.'.
    #[test]
    fn tensor_view_float() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec!["test".to_string(), "weight".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![2, 2]);
        assert_eq!(snapshot.full_path(), "test.weight");
        assert_eq!(snapshot.container_path(), "TestModule.Param");
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::F32);
    }

    // Int tensor: this backend's default int dtype is expected to be I32.
    #[test]
    fn tensor_view_int() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 2, Int>::from_data([[1, 2], [3, 4]], &device);
        let snapshot = TensorSnapshot::from_int(
            &tensor,
            vec!["test".to_string(), "int".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::I32);
        assert_eq!(snapshot.shape, shape![2, 2]);
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::I32);
    }

    // Bool tensor: this backend is expected to report natively stored bools.
    #[test]
    fn tensor_view_bool() {
        let device = Default::default();
        let tensor =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let snapshot = TensorSnapshot::from_bool(
            &tensor,
            vec!["test".to_string(), "bool".to_string()],
            vec!["TestModule".to_string(), "Param".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::Bool(BoolStore::Native));
        assert_eq!(snapshot.shape, shape![2, 2]);
        let data = snapshot.to_data().unwrap();
        assert_eq!(data.shape, shape![2, 2]);
        assert_eq!(data.dtype, DType::Bool(BoolStore::Native));
    }

    // data_len = element count * element size for non-quantized dtypes:
    // 4 x f32 = 16 bytes, 8 x i32 = 32 bytes, 4 native bools = 4 bytes.
    #[test]
    fn data_len() {
        let device = Default::default();
        let tensor_f32 = Tensor::<TestBackend, 2>::from_data([[1.0, 2.0], [3.0, 4.0]], &device);
        let view_f32 = TensorSnapshot::from_float(
            &tensor_f32,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_f32.data_len(), 16);
        let tensor_int =
            Tensor::<TestBackend, 3, Int>::from_data([[[1, 2], [3, 4]], [[5, 6], [7, 8]]], &device);
        let view_int = TensorSnapshot::from_int(
            &tensor_int,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_int.data_len(), 32);
        let tensor_bool =
            Tensor::<TestBackend, 2, Bool>::from_data([[true, false], [false, true]], &device);
        let view_bool = TensorSnapshot::from_bool(
            &tensor_bool,
            vec!["test".to_string()],
            vec!["Module".to_string()],
            ParamId::new(),
        );
        assert_eq!(view_bool.data_len(), 4); }

    // A snapshot built from a user closure reports the given metadata and yields its data.
    #[test]
    fn from_closure() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0]);
        let dtype = data.dtype;
        let shape = data.shape.clone();
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(move || Ok(data.clone())),
            dtype,
            shape.clone(),
            vec!["model".to_string(), "layer".to_string()],
            vec!["Model".to_string(), "Layer".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, DType::F32);
        assert_eq!(snapshot.shape, shape![4]);
        assert_eq!(snapshot.full_path(), "model.layer");
        assert_eq!(snapshot.data_len(), 16);
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, shape![4]);
    }

    // A snapshot built from owned data takes dtype/shape from that data;
    // container_type() is the last container entry.
    #[test]
    fn from_data() {
        let data = TensorData::from([1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
        let original_dtype = data.dtype;
        let original_shape = data.shape.clone();
        let snapshot = TensorSnapshot::from_data(
            data,
            vec!["encoder".to_string(), "weight".to_string()],
            vec!["Struct:Encoder".to_string(), "Struct:Dense".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot.dtype, original_dtype);
        assert_eq!(snapshot.shape, original_shape);
        assert_eq!(snapshot.full_path(), "encoder.weight");
        assert_eq!(snapshot.container_type(), "Struct:Dense");
        assert_eq!(snapshot.data_len(), 24);
        let materialized = snapshot.to_data().unwrap();
        assert_eq!(materialized.shape, original_shape);
    }

    // A panicking data closure is turned into a PanicError instead of unwinding.
    // (The extra cfg is redundant with the module-level cfg but harmless.)
    #[test]
    #[cfg(feature = "std")]
    fn panic_catching_in_to_data() {
        use alloc::rc::Rc;
        let snapshot = TensorSnapshot {
            data_fn: Rc::new(|| panic!("Test panic in data_fn")),
            dtype: DType::F32,
            shape: shape![2, 2],
            path_stack: Some(vec!["test".to_string()]),
            container_stack: Some(vec!["Test".to_string()]),
            tensor_id: Some(ParamId::new()),
        };
        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::PanicError(msg)) => {
                assert!(msg.contains("Panic occurred"));
            }
            _ => panic!("Expected PanicError with panic message"),
        }
    }

    // Errors returned by the closure are passed through unchanged.
    #[test]
    fn error_propagation_in_closure() {
        use alloc::rc::Rc;
        let snapshot = TensorSnapshot::from_closure(
            Rc::new(|| Err(TensorSnapshotError::IoError("Simulated IO error".into()))),
            DType::F32,
            shape![2, 2],
            vec!["error_test".into()],
            vec![],
            ParamId::new(),
        );
        let result = snapshot.to_data();
        assert!(result.is_err());
        match result {
            Err(TensorSnapshotError::IoError(msg)) => {
                assert!(msg.contains("Simulated IO error"));
            }
            _ => panic!("Expected IoError"),
        }
    }

    // container_type / module_type / container_path on an all-"Struct:" stack.
    #[test]
    fn container_type_extraction() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layer1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Struct:Conv2d".to_string(),
                "Struct:Param".to_string(),
            ],
            ParamId::new(),
        );
        assert_eq!(snapshot.container_type(), "Struct:Param");
        assert_eq!(snapshot.module_type(), Some("Struct:Param".to_string()));
        assert_eq!(
            snapshot.container_path(),
            "Struct:Model.Struct:Conv2d.Struct:Param"
        );
        assert_eq!(snapshot.full_path(), "model.layer1.weight");
    }

    // container_type is the raw last entry; module_type skips non-"Struct:"/"Enum:"
    // entries such as "Vec"/"Array" and is None when no entry qualifies.
    #[test]
    fn container_type_vs_module_type() {
        let device = Default::default();
        let tensor = Tensor::<TestBackend, 1>::from_data([1.0, 2.0, 3.0], &device);
        let snapshot = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layers".to_string(),
                "0".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Vec".to_string(),
                "Struct:Linear".to_string(),
            ],
            ParamId::new(),
        );
        assert_eq!(snapshot.container_type(), "Struct:Linear");
        assert_eq!(snapshot.module_type(), Some("Struct:Linear".to_string()));
        let snapshot2 = TensorSnapshot::from_float(
            &tensor,
            vec!["data".to_string(), "0".to_string()],
            vec!["Vec".to_string()],
            ParamId::new(),
        );
        assert_eq!(snapshot2.container_type(), "Vec");
        assert_eq!(snapshot2.module_type(), None);
        let snapshot3 = TensorSnapshot::from_float(
            &tensor,
            vec![
                "model".to_string(),
                "layers".to_string(),
                "0".to_string(),
                "sublayers".to_string(),
                "1".to_string(),
                "weight".to_string(),
            ],
            vec![
                "Struct:Model".to_string(),
                "Vec".to_string(),
                "Array".to_string(),
                "Struct:Linear".to_string(),
            ],
            ParamId::new(),
        );
        assert_eq!(snapshot3.container_type(), "Struct:Linear");
        assert_eq!(snapshot3.module_type(), Some("Struct:Linear".to_string()));
    }
}