use crate::ModuleStore;
use crate::burnpack::store::BurnpackStore;
use burn_core as burn;
use burn_core::module::{Module, Param};
use burn_tensor::{AllocationProperty, Bytes, Tensor, backend::Backend};
/// Backend used by every test in this file.
type TestBackend = burn_flex::Flex;
/// Minimal two-parameter module used to round-trip weights through `BurnpackStore`.
#[derive(Module, Debug)]
struct SimpleModule<B: Backend> {
    /// Rank-2 weight parameter (built as 2x2 by the constructors below).
    weight: Param<Tensor<B, 2>>,
    /// Rank-1 bias parameter (built with length 2 by the constructors below).
    bias: Param<Tensor<B, 1>>,
}
impl<B: Backend> SimpleModule<B> {
    /// Creates a module holding fixed, distinct values so that a later load
    /// can be compared against known numbers.
    fn new(device: &B::Device) -> Self {
        let weight = Param::from_data([[1.0f32, 2.0], [3.0, 4.0]], device);
        let bias = Param::from_data([0.5f32, 1.5], device);
        Self { weight, bias }
    }

    /// Creates a module with the same shapes as [`SimpleModule::new`] but
    /// filled with zeros; used as the target that stored weights are applied to.
    fn new_zeros(device: &B::Device) -> Self {
        let weight = Param::from_tensor(Tensor::zeros([2, 2], device));
        let bias = Param::from_tensor(Tensor::zeros([2], device));
        Self { weight, bias }
    }
}
/// Round-trips a module through a `'static` byte slice loaded with
/// `BurnpackStore::from_static` and checks both parameters survive intact.
#[test]
fn test_from_static_enables_zero_copy() {
    let device = Default::default();

    // Serialize a module with known values into an in-memory burnpack buffer.
    let source = SimpleModule::<TestBackend>::new(&device);
    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let serialized = writer.get_bytes().unwrap();

    // `from_static` requires a `'static` slice; leaking the copy is acceptable in a test.
    let leaked: &'static [u8] = Box::leak(serialized.to_vec().into_boxed_slice());

    let mut reader = BurnpackStore::from_static(leaked);
    let mut target = SimpleModule::<TestBackend>::new_zeros(&device);
    reader.apply_to(&mut target).unwrap();

    let weight_values = target.weight.val().to_data();
    let bias_values = target.bias.val().to_data();
    assert_eq!(
        weight_values.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0]
    );
    assert_eq!(bias_values.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}
/// Loads from shared (`bytes::Bytes`-backed) storage with `.zero_copy(true)`
/// and checks that every parameter is restored correctly.
///
/// Both tensors are verified: a slicing/offset error in the zero-copy path
/// would typically corrupt the second tensor (`bias`) while leaving the
/// first one intact.
#[test]
fn test_zero_copy_builder_method() {
    let device = Default::default();
    let module = SimpleModule::<TestBackend>::new(&device);
    let mut save_store = BurnpackStore::from_bytes(None);
    save_store.collect_from(&module).unwrap();
    let serialized = save_store.get_bytes().unwrap();

    // Wrap the buffer in shared storage so the reader can hand out slices of it.
    let shared = bytes::Bytes::from(serialized.to_vec());
    let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);
    let mut load_store = BurnpackStore::from_bytes(Some(cubecl_bytes)).zero_copy(true);
    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
    load_store.apply_to(&mut loaded_module).unwrap();

    let loaded_weight = loaded_module.weight.val().to_data();
    let loaded_bias = loaded_module.bias.val().to_data();
    assert_eq!(
        loaded_weight.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0]
    );
    assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}
/// Loads from a `'static` slice with zero-copy explicitly disabled and checks
/// that the copy-based path restores every parameter.
#[test]
fn test_zero_copy_disabled_uses_copy() {
    let device = Default::default();
    let module = SimpleModule::<TestBackend>::new(&device);
    let mut save_store = BurnpackStore::from_bytes(None);
    save_store.collect_from(&module).unwrap();
    let serialized = save_store.get_bytes().unwrap();

    // `from_static` requires a `'static` slice; leaking the copy is acceptable in a test.
    let bytes_vec: Vec<u8> = serialized.to_vec();
    let static_bytes: &'static [u8] = Box::leak(bytes_vec.into_boxed_slice());
    let mut load_store = BurnpackStore::from_static(static_bytes).zero_copy(false);
    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
    load_store.apply_to(&mut loaded_module).unwrap();

    let loaded_weight = loaded_module.weight.val().to_data();
    let loaded_bias = loaded_module.bias.val().to_data();
    assert_eq!(
        loaded_weight.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0]
    );
    assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}
/// Loads from `BurnpackStore::from_bytes` without configuring zero-copy and
/// checks that every parameter is restored.
#[test]
fn test_from_bytes_uses_copy_by_default() {
    let device = Default::default();
    let module = SimpleModule::<TestBackend>::new(&device);
    let mut save_store = BurnpackStore::from_bytes(None);
    save_store.collect_from(&module).unwrap();
    let serialized = save_store.get_bytes().unwrap();

    let mut load_store = BurnpackStore::from_bytes(Some(serialized));
    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
    load_store.apply_to(&mut loaded_module).unwrap();

    let loaded_weight = loaded_module.weight.val().to_data();
    let loaded_bias = loaded_module.bias.val().to_data();
    assert_eq!(
        loaded_weight.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0]
    );
    assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}
/// Reads zero-copy snapshots directly from `BurnpackReader` and checks that
/// the sliced byte ranges decode to exactly the values that were saved.
///
/// Checking only that each slice is non-empty would still pass with wrong
/// offsets or lengths. Instead, all decoded values are gathered and compared
/// as a sorted multiset, so the check does not depend on snapshot order.
#[test]
fn test_storage_backend_slice_bytes() {
    use crate::burnpack::reader::BurnpackReader;
    let device = Default::default();
    let module = SimpleModule::<TestBackend>::new(&device);
    let mut save_store = BurnpackStore::from_bytes(None);
    save_store.collect_from(&module).unwrap();
    let serialized = save_store.get_bytes().unwrap();

    // Wrap the buffer in shared storage so the reader can hand out slices of it.
    let shared = bytes::Bytes::from(serialized.to_vec());
    let cubecl_bytes = Bytes::from_shared(shared, AllocationProperty::Other);
    let reader = BurnpackReader::from_bytes(cubecl_bytes).unwrap();
    let snapshots = reader.get_snapshots_zero_copy(true).unwrap();
    assert_eq!(snapshots.len(), 2);

    let mut values: Vec<f32> = Vec::new();
    for snapshot in &snapshots {
        let data = snapshot.to_data().unwrap();
        assert!(!data.bytes.is_empty());
        values.extend(data.to_vec::<f32>().unwrap());
    }
    values.sort_by(f32::total_cmp);
    assert_eq!(values, vec![0.5, 1.0, 1.5, 2.0, 3.0, 4.0]);
}
/// Saves a module to a temporary file, reloads it with `.zero_copy(true)`, and
/// checks that every parameter is restored.
#[test]
fn test_zero_copy_file_based_works() {
    use tempfile::NamedTempFile;
    let device = Default::default();
    let module = SimpleModule::<TestBackend>::new(&device);

    // `temp_file` is declared before the stores and module, so it is dropped
    // (and the file removed) only after everything that reads from it.
    let temp_file = NamedTempFile::new().unwrap();
    let path = temp_file.path();
    let mut save_store = BurnpackStore::from_file(path).overwrite(true);
    save_store.collect_from(&module).unwrap();

    let mut load_store = BurnpackStore::from_file(path).zero_copy(true);
    let mut loaded_module = SimpleModule::<TestBackend>::new_zeros(&device);
    load_store.apply_to(&mut loaded_module).unwrap();

    let loaded_weight = loaded_module.weight.val().to_data();
    let loaded_bias = loaded_module.bias.val().to_data();
    assert_eq!(
        loaded_weight.to_vec::<f32>().unwrap(),
        vec![1.0, 2.0, 3.0, 4.0]
    );
    assert_eq!(loaded_bias.to_vec::<f32>().unwrap(), vec![0.5, 1.5]);
}