#[cfg(feature = "std")]
use crate::KeyRemapper;
use crate::burnpack::store::BurnpackStore;
use crate::{ModuleAdapter, ModuleSnapshot, ModuleStore, PathFilter};
use burn_core as burn;
use burn_core::module::{Module, Param};
use burn_tensor::shape;
use burn_tensor::{Tensor, backend::Backend};
/// Backend used by every test in this module.
type TestBackend = burn_flex::Flex;
/// Four-parameter test module: top-level `weight` and `bias`, plus
/// `nested.gamma` / `nested.beta` contributed by [`NestedModule`].
#[derive(Module, Debug)]
struct TestModule<B: Backend> {
    weight: Param<Tensor<B, 2>>,
    bias: Param<Tensor<B, 1>>,
    nested: NestedModule<B>,
}
/// Submodule providing the dotted-path parameters `nested.gamma` and
/// `nested.beta` used to test nested tensor addressing.
#[derive(Module, Debug)]
struct NestedModule<B: Backend> {
    gamma: Param<Tensor<B, 1>>,
    beta: Param<Tensor<B, 1>>,
}
impl<B: Backend> TestModule<B> {
fn new(device: &B::Device) -> Self {
Self {
weight: Param::from_data([[1.0, 2.0], [3.0, 4.0]], device),
bias: Param::from_data([0.1, 0.2], device),
nested: NestedModule {
gamma: Param::from_data([1.0, 1.0], device),
beta: Param::from_data([0.0, 0.0], device),
},
}
}
fn new_zeros(device: &B::Device) -> Self {
Self {
weight: Param::from_tensor(Tensor::zeros([2, 2], device)),
bias: Param::from_tensor(Tensor::zeros([2], device)),
nested: NestedModule {
gamma: Param::from_tensor(Tensor::zeros([2], device)),
beta: Param::from_tensor(Tensor::zeros([2], device)),
},
}
}
fn new_uninitialized(device: &B::Device) -> Self {
use burn_core::module::ParamId;
let device_clone = device.clone();
let device_clone2 = device.clone();
let device_clone3 = device.clone();
let device_clone4 = device.clone();
Self {
weight: Param::uninitialized(
ParamId::new(),
move |d, _| Tensor::zeros([2, 2], d),
device_clone,
true,
[2, 2].into(),
),
bias: Param::uninitialized(
ParamId::new(),
move |d, _| Tensor::zeros([2], d),
device_clone2,
true,
[2].into(),
),
nested: NestedModule {
gamma: Param::uninitialized(
ParamId::new(),
move |d, _| Tensor::zeros([2], d),
device_clone3,
true,
[2].into(),
),
beta: Param::uninitialized(
ParamId::new(),
move |d, _| Tensor::zeros([2], d),
device_clone4,
true,
[2].into(),
),
},
}
}
}
/// Saving to in-memory bytes and loading back must restore every tensor.
#[test]
fn test_store_from_bytes_round_trip() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    // Serialize the module into an in-memory burnpack blob.
    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Load the blob into a zero-initialized module.
    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);
    assert!(outcome.errors.is_empty());

    let expected = source.weight.val().to_data().to_vec::<f32>().unwrap();
    let restored = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(expected, restored);
}
/// Metadata entries attached to the save store do not interfere with loading.
#[test]
fn test_store_with_metadata() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None)
        .metadata("version", "1.0.0")
        .metadata("model_name", "test_model")
        .metadata("author", "burn_team");
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);
}
/// A regex filter on load applies only the matching top-level tensors.
#[test]
#[cfg(feature = "std")]
fn test_store_with_path_filter() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Only `weight` and `bias` pass the filter; the nested pair is skipped.
    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload)).with_regex("^(weight|bias)$");
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 2);
    assert_eq!(outcome.skipped.len(), 2);

    // The filtered-in tensor took the saved values...
    let restored_weight = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(restored_weight, vec![1.0, 2.0, 3.0, 4.0]);

    // ...while the skipped nested tensor kept its zeros.
    let restored_gamma = target
        .nested
        .gamma
        .val()
        .to_data()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(restored_gamma, vec![0.0, 0.0]);
}
/// Remapping keys on load: the renamed tensors no longer match any module
/// path, so they end up unused and their destinations are reported missing.
#[test]
#[cfg(feature = "std")]
fn test_store_with_key_remapping() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Rename both nested tensors to names the module does not have.
    let remapper = KeyRemapper::new()
        .add_pattern(r"nested\.gamma", "nested.new_gamma")
        .unwrap()
        .add_pattern(r"nested\.beta", "nested.new_beta")
        .unwrap();
    let mut reader = BurnpackStore::from_bytes(Some(payload))
        .remap(remapper)
        .allow_partial(true);

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let outcome = reader.apply_to(&mut target).unwrap();

    assert_eq!(outcome.applied.len(), 2);
    assert_eq!(outcome.unused.len(), 2);
    assert_eq!(outcome.missing.len(), 2);
}
/// Saving a filtered subset and loading with `allow_partial` reports the
/// absent tensors as missing instead of failing.
#[test]
fn test_store_allow_partial() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    // Persist only the two top-level tensors.
    let filter = PathFilter::new()
        .with_full_path("weight")
        .with_full_path("bias");
    let mut writer = BurnpackStore::from_bytes(None).with_filter(filter);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload)).allow_partial(true);
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 2);
    assert_eq!(outcome.missing.len(), 2);

    let restored = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(restored, vec![1.0, 2.0, 3.0, 4.0]);
}
/// `match_all` on the save store persists every tensor; the load applies
/// all of them with nothing missing, unused, or failed.
#[test]
fn test_store_match_all() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None).match_all();
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);
    assert!(outcome.errors.is_empty());
    assert!(outcome.missing.is_empty());
    assert!(outcome.unused.is_empty());
}
/// Explicit full-path filters on load apply exactly the named tensors.
#[test]
fn test_store_with_full_path() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload))
        .with_full_path("weight")
        .with_full_path("nested.gamma");
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 2);
    assert_eq!(outcome.skipped.len(), 2);
}
/// Builder calls chain freely (metadata, regex, `match_all`); all four
/// tensors still end up saved and applied.
#[test]
#[cfg(feature = "std")]
fn test_store_chain_multiple_patterns() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None)
        .metadata("version", "1.0")
        .metadata("format", "burnpack")
        .with_regex(r"^(weight|nested\.)")
        .match_all();
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);
}
/// `with_remap_pattern` renames stored keys on load; the renamed nested
/// keys no longer match the module and are reported unused.
#[test]
#[cfg(feature = "std")]
fn test_store_with_remap_pattern() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload))
        .with_remap_pattern(r"^nested\.", "sub_module.")
        .allow_partial(true);
    let outcome = reader.apply_to(&mut target).unwrap();

    assert_eq!(outcome.applied.len(), 2);
    assert_eq!(outcome.unused.len(), 2);
}
/// A store with no explicit metadata still round-trips successfully.
#[test]
fn test_store_default_metadata() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// Adding custom metadata entries does not break the round-trip.
#[test]
fn test_store_default_metadata_with_custom() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None)
        .metadata("custom_field", "custom_value")
        .metadata("author", "test_author");
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// Calling `clear_metadata` on the save store still yields a loadable blob.
#[test]
fn test_store_clear_metadata() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None).clear_metadata();
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// With validation left at its default, a clean round-trip has no errors.
#[test]
fn test_store_validate_enabled() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let outcome = reader.apply_to(&mut target).unwrap();

    assert!(outcome.is_success());
    assert!(outcome.errors.is_empty());
}
/// `validate(false)` on the load store: a clean blob still loads fine.
#[test]
fn test_store_validate_disabled() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_bytes(Some(payload)).validate(false);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// Loading an incomplete store fails by default but succeeds — with the
/// gaps listed under `missing` — once `allow_partial` is set.
#[test]
fn test_store_allow_partial_missing_tensors() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    // Save only `weight`, leaving three tensors out of the blob.
    let filter = PathFilter::new().with_full_path("weight");
    let mut writer = BurnpackStore::from_bytes(None).with_filter(filter);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Strict load: missing tensors are an error.
    let mut strict_reader = BurnpackStore::from_bytes(Some(payload.clone()));
    let mut strict_target = TestModule::<TestBackend>::new_zeros(&device);
    assert!(strict_reader.apply_to(&mut strict_target).is_err());

    // Partial load: the one stored tensor applies, the rest are missing.
    let mut lenient_reader = BurnpackStore::from_bytes(Some(payload)).allow_partial(true);
    let mut lenient_target = TestModule::<TestBackend>::new_zeros(&device);
    let outcome = lenient_reader.apply_to(&mut lenient_target).unwrap();

    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 1);
    assert!(!outcome.missing.is_empty());
}
/// Saving to a file and loading it back restores every tensor.
#[test]
#[cfg(feature = "std")]
fn test_store_file_round_trip() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("test_file_round_trip.bpk");

    let mut writer = BurnpackStore::from_file(&path).metadata("test", "value");
    writer.collect_from(&source).unwrap();
    assert!(path.exists());

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path);
    let outcome = reader.apply_to(&mut target).unwrap();
    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);

    let expected = source.weight.val().to_data().to_vec::<f32>().unwrap();
    let restored = target.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(expected, restored);
}
/// Writing to an existing file fails unless `overwrite(true)` is set.
#[test]
#[cfg(feature = "std")]
fn test_store_overwrite_protection() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("test_model.bpk");

    // First save creates the file.
    let mut first_writer = BurnpackStore::from_file(&path);
    first_writer.collect_from(&source).unwrap();
    assert!(path.exists());

    // A second save without overwrite is rejected.
    let mut second_writer = BurnpackStore::from_file(&path);
    let outcome = second_writer.collect_from(&source);
    assert!(outcome.is_err());
    assert!(
        outcome
            .unwrap_err()
            .to_string()
            .contains("File already exists")
    );

    // With overwrite enabled the save goes through and stays loadable.
    let mut third_writer = BurnpackStore::from_file(&path).overwrite(true);
    third_writer.collect_from(&source).unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// Overwriting a file-backed store with new metadata leaves it loadable.
#[test]
#[cfg(feature = "std")]
fn test_store_overwrite_with_metadata() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("test_model_metadata.bpk");

    // Write once with version 1.0...
    let mut first_writer = BurnpackStore::from_file(&path)
        .metadata("version", "1.0")
        .overwrite(true);
    first_writer.collect_from(&source).unwrap();

    // ...then overwrite with version 2.0.
    let mut second_writer = BurnpackStore::from_file(&path)
        .metadata("version", "2.0")
        .overwrite(true);
    second_writer.collect_from(&source).unwrap();

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// With no extension given, ".bpk" is appended automatically on both save
/// and load.
#[test]
#[cfg(feature = "std")]
fn test_store_auto_extension_default() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let bare_path = dir.path().join("model");

    let mut writer = BurnpackStore::from_file(&bare_path);
    writer.collect_from(&source).unwrap();

    // The file lands at "model.bpk", not at the extensionless path.
    assert!(dir.path().join("model.bpk").exists());
    assert!(!bare_path.exists());

    // Loading through the bare path resolves the same way.
    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&bare_path);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// A path that already ends in ".bpk" is used as-is (no double extension).
#[test]
#[cfg(feature = "std")]
fn test_store_auto_extension_with_existing_extension() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.bpk");

    let mut writer = BurnpackStore::from_file(&path);
    writer.collect_from(&source).unwrap();

    // Saved exactly where requested — no "model.bpk.bpk" appears.
    assert!(path.exists());
    assert!(!dir.path().join("model.bpk.bpk").exists());

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// A path with a different extension keeps it; ".bpk" is not appended.
#[test]
#[cfg(feature = "std")]
fn test_store_auto_extension_with_custom_extension() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("model.mpk");

    let mut writer = BurnpackStore::from_file(&path);
    writer.collect_from(&source).unwrap();

    // Saved as "model.mpk" — no extra ".bpk" suffix is added.
    assert!(path.exists());
    assert!(!dir.path().join("model.mpk.bpk").exists());

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// `auto_extension(false)` writes and reads the path exactly as given.
#[test]
#[cfg(feature = "std")]
fn test_store_auto_extension_disabled() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("model");

    let mut writer = BurnpackStore::from_file(&path).auto_extension(false);
    writer.collect_from(&source).unwrap();

    // The extensionless path is used verbatim; no ".bpk" file appears.
    assert!(path.exists());
    assert!(!dir.path().join("model.bpk").exists());

    let mut target = TestModule::<TestBackend>::new_zeros(&device);
    let mut reader = BurnpackStore::from_file(&path).auto_extension(false);
    assert!(reader.apply_to(&mut target).unwrap().is_success());
}
/// Filtered loading must leave skipped parameters lazily uninitialized:
/// only the applied tensors become initialized, and skipped ones stay lazy
/// until they are first accessed.
#[test]
#[cfg(feature = "std")]
fn test_partial_loading_preserves_lazy_initialization() {
    use tempfile::tempdir;
    let device = Default::default();
    let module = TestModule::<TestBackend>::new(&device);
    let temp_dir = tempdir().unwrap();
    let path = temp_dir.path().join("model.bpk");
    // Persist all four tensors to disk.
    let mut save_store = BurnpackStore::from_file(&path);
    save_store.collect_from(&module).unwrap();
    // The load target starts with every param lazy / uninitialized.
    let mut load_module = TestModule::<TestBackend>::new_uninitialized(&device);
    assert!(
        !load_module.weight.is_initialized(),
        "weight should be uninitialized before loading"
    );
    assert!(
        !load_module.bias.is_initialized(),
        "bias should be uninitialized before loading"
    );
    assert!(
        !load_module.nested.gamma.is_initialized(),
        "nested.gamma should be uninitialized before loading"
    );
    assert!(
        !load_module.nested.beta.is_initialized(),
        "nested.beta should be uninitialized before loading"
    );
    // Load only the two top-level tensors; the nested pair is filtered out.
    let filter = PathFilter::new().with_regex("^(weight|bias)$");
    let mut load_store = BurnpackStore::from_file(&path).filter(filter);
    let result = load_module.load_from(&mut load_store).unwrap();
    assert_eq!(result.applied.len(), 2);
    assert!(result.applied.contains(&"weight".to_string()));
    assert!(result.applied.contains(&"bias".to_string()));
    assert_eq!(result.skipped.len(), 2);
    assert!(result.skipped.contains(&"nested.gamma".to_string()));
    assert!(result.skipped.contains(&"nested.beta".to_string()));
    // The applied params are now initialized...
    assert!(
        load_module.weight.is_initialized(),
        "weight should be initialized after loading"
    );
    assert!(
        load_module.bias.is_initialized(),
        "bias should be initialized after loading"
    );
    // ...while the skipped params are still lazy.
    assert!(
        !load_module.nested.gamma.is_initialized(),
        "nested.gamma should remain uninitialized (was skipped)"
    );
    assert!(
        !load_module.nested.beta.is_initialized(),
        "nested.beta should remain uninitialized (was skipped)"
    );
    // Loaded values match what was saved.
    let weight_data = load_module.weight.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(weight_data, vec![1.0, 2.0, 3.0, 4.0]);
    let bias_data = load_module.bias.val().to_data().to_vec::<f32>().unwrap();
    assert_eq!(bias_data, vec![0.1, 0.2]);
    // A skipped param lazily materializes its zeros on first access.
    let gamma_data = load_module
        .nested
        .gamma
        .val()
        .to_data()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(gamma_data, vec![0.0, 0.0]);
    assert!(
        load_module.nested.gamma.is_initialized(),
        "nested.gamma should be initialized after first access"
    );
}
/// Two-layer MLP used to verify forward outputs survive a save/load cycle.
#[derive(Module, Debug)]
struct ForwardTestModel<B: Backend> {
    linear1: burn_nn::Linear<B>,
    linear2: burn_nn::Linear<B>,
}
impl<B: Backend> ForwardTestModel<B> {
    /// Runs the two linear layers with a GELU activation in between.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = burn::tensor::activation::gelu(self.linear1.forward(input));
        self.linear2.forward(hidden)
    }
}
/// Layer sizes for [`ForwardTestModel`].
#[derive(burn::config::Config, Debug)]
struct ForwardTestModelConfig {
    input_size: usize,
    hidden_size: usize,
    output_size: usize,
}
impl ForwardTestModelConfig {
    /// Builds the two-layer model (both layers with bias) on `device`.
    fn init<B: Backend>(&self, device: &B::Device) -> ForwardTestModel<B> {
        let layer = |input: usize, output: usize| {
            burn_nn::LinearConfig::new(input, output)
                .with_bias(true)
                .init(device)
        };
        ForwardTestModel {
            linear1: layer(self.input_size, self.hidden_size),
            linear2: layer(self.hidden_size, self.output_size),
        }
    }
}
/// Loading saved weights into a freshly initialized model makes its
/// forward pass match the original model.
#[test]
#[cfg(feature = "std")]
fn test_forward_pass_preservation_after_save_load() {
    use tempfile::tempdir;

    let device = Default::default();
    let config = ForwardTestModelConfig {
        input_size: 4,
        hidden_size: 8,
        output_size: 2,
    };
    let model1 = config.init::<TestBackend>(&device);

    let input = Tensor::<TestBackend, 2>::random(
        [1, 4],
        burn_tensor::Distribution::Uniform(-1.0, 1.0),
        &device,
    );
    let output1 = model1.forward(input.clone());

    let dir = tempdir().unwrap();
    let path = dir.path().join("forward_test_model.bpk");
    let mut writer = BurnpackStore::from_file(&path);
    writer.collect_from(&model1).unwrap();

    // A second random init should not reproduce the first model's output.
    let mut model2 = config.init::<TestBackend>(&device);
    let output2 = model2.forward(input.clone());
    assert!(
        !output1
            .clone()
            .all_close(output2.clone(), Some(1e-6), Some(1e-6)),
        "output2 should differ from output1 (different random initializations)"
    );

    // After loading the saved weights, the outputs must agree.
    let mut reader = BurnpackStore::from_file(&path);
    let outcome = reader.apply_to(&mut model2).unwrap();
    assert!(outcome.is_success());
    assert_eq!(outcome.applied.len(), 4);

    let output3 = model2.forward(input.clone());
    assert!(
        output1.all_close(output3, Some(1e-6), Some(1e-6)),
        "output3 should equal output1 after loading weights"
    );
}
/// `get_all_snapshots` exposes every stored tensor keyed by full path.
#[test]
fn test_store_get_all_snapshots() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let snapshots = reader.get_all_snapshots().unwrap();
    assert_eq!(snapshots.len(), 4);
    for key in ["weight", "bias", "nested.gamma", "nested.beta"] {
        assert!(snapshots.contains_key(key));
    }
}
/// `get_snapshot` returns path, shape, and data for a stored tensor.
#[test]
fn test_store_get_snapshot_existing() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let lookup = reader.get_snapshot("weight").unwrap();
    assert!(lookup.is_some());

    let snapshot = lookup.unwrap();
    assert_eq!(snapshot.full_path(), "weight");
    assert_eq!(snapshot.shape, shape![2, 2]);

    let data = snapshot.to_data().unwrap();
    assert_eq!(data.to_vec::<f32>().unwrap(), vec![1.0, 2.0, 3.0, 4.0]);
}
/// Nested tensors are addressable by their dotted full path.
#[test]
fn test_store_get_snapshot_nested() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let lookup = reader.get_snapshot("nested.gamma").unwrap();
    assert!(lookup.is_some());

    let snapshot = lookup.unwrap();
    assert_eq!(snapshot.full_path(), "nested.gamma");
    assert_eq!(snapshot.shape, shape![2]);
}
/// Looking up an unknown path yields `Ok(None)`, not an error.
#[test]
fn test_store_get_snapshot_not_found() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    assert!(reader.get_snapshot("nonexistent").unwrap().is_none());
}
/// `keys` lists the full path of every stored tensor.
#[test]
fn test_store_keys() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    let keys = reader.keys().unwrap();
    assert_eq!(keys.len(), 4);
    for expected in ["weight", "bias", "nested.gamma", "nested.beta"] {
        assert!(keys.contains(&expected.to_string()));
    }
}
/// Snapshot access works against a file-backed store too.
#[test]
#[cfg(feature = "std")]
fn test_store_get_all_snapshots_from_file() {
    use tempfile::tempdir;

    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);
    let dir = tempdir().unwrap();
    let path = dir.path().join("test_get_all_snapshots.bpk");

    let mut writer = BurnpackStore::from_file(&path);
    writer.collect_from(&source).unwrap();

    let mut reader = BurnpackStore::from_file(&path);
    let snapshots = reader.get_all_snapshots().unwrap();
    assert_eq!(snapshots.len(), 4);

    let weight_values = snapshots
        .get("weight")
        .unwrap()
        .to_data()
        .unwrap()
        .to_vec::<f32>()
        .unwrap();
    assert_eq!(weight_values, vec![1.0, 2.0, 3.0, 4.0]);
}
/// Repeated snapshot queries on the same store return consistent results.
#[test]
fn test_store_caching_behavior() {
    let device = Default::default();
    let source = TestModule::<TestBackend>::new(&device);

    let mut writer = BurnpackStore::from_bytes(None);
    writer.collect_from(&source).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut reader = BurnpackStore::from_bytes(Some(payload));
    // Two full reads and a single lookup all see the same four tensors.
    assert_eq!(reader.get_all_snapshots().unwrap().len(), 4);
    assert_eq!(reader.get_all_snapshots().unwrap().len(), 4);
    assert!(reader.get_snapshot("weight").unwrap().is_some());
}
/// Re-collecting into the same store replaces previously read snapshots.
#[test]
fn test_store_cache_invalidation_on_save() {
    let device = Default::default();
    let first_module = TestModule::<TestBackend>::new(&device);

    let mut store = BurnpackStore::from_bytes(None);
    store.collect_from(&first_module).unwrap();

    let first_snapshots = store.get_all_snapshots().unwrap();
    assert_eq!(first_snapshots.len(), 4);
    let first_weight: Vec<f32> = first_snapshots
        .get("weight")
        .unwrap()
        .to_data()
        .unwrap()
        .to_vec()
        .unwrap();

    // Collect a module with completely different values into the same store.
    let second_module = TestModule::<TestBackend> {
        weight: Param::from_tensor(Tensor::from_data([[10.0, 20.0], [30.0, 40.0]], &device)),
        bias: Param::from_tensor(Tensor::from_data([100.0, 200.0], &device)),
        nested: NestedModule {
            gamma: Param::from_tensor(Tensor::from_data([1000.0, 2000.0], &device)),
            beta: Param::from_tensor(Tensor::from_data([3000.0, 4000.0], &device)),
        },
    };
    store.collect_from(&second_module).unwrap();

    // The snapshots now reflect the second module, not stale cached data.
    let second_snapshots = store.get_all_snapshots().unwrap();
    assert_eq!(second_snapshots.len(), 4);
    let second_weight: Vec<f32> = second_snapshots
        .get("weight")
        .unwrap()
        .to_data()
        .unwrap()
        .to_vec()
        .unwrap();
    assert_ne!(first_weight, second_weight);
    assert_eq!(second_weight, vec![10.0, 20.0, 30.0, 40.0]);
}
/// A quantized (Q8S, per-tensor) linear module saves without error and its
/// snapshot reports the original shape.
#[test]
fn test_store_quantized_module_round_trip() {
    use burn_core::module::Quantizer;
    use burn_nn::LinearConfig;
    use burn_tensor::quantization::{
        Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue,
    };
    let device = Default::default();
    // No bias, so the store holds exactly one tensor.
    let linear = LinearConfig::new(512, 512)
        .with_bias(false)
        .init::<TestBackend>(&device);
    // Backend default scheme narrowed to 8-bit symmetric values,
    // per-tensor granularity, and f32 quantization parameters.
    let scheme = <<TestBackend as burn_tensor::backend::BackendTypes>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::Tensor)
        .with_param(QuantParam::F32);
    let calibration = Calibration::MinMax;
    let mut quantizer = Quantizer {
        calibration,
        scheme,
    };
    let quantized_linear = linear.quantize_weights(&mut quantizer);
    let mut save_store = BurnpackStore::from_bytes(None);
    let result = save_store.collect_from(&quantized_linear);
    assert!(
        result.is_ok(),
        "Failed to save quantized module: {:?}",
        result.err()
    );
    let bytes = save_store.get_bytes().expect("Failed to get bytes");
    let mut load_store = BurnpackStore::from_bytes(Some(bytes));
    let snapshots = load_store
        .get_all_snapshots()
        .expect("Failed to get snapshots");
    assert_eq!(snapshots.len(), 1, "Expected 1 tensor (weight)");
    assert!(snapshots.contains_key("weight"), "Expected 'weight' tensor");
    let weight_snapshot = snapshots.get("weight").unwrap();
    assert_eq!(weight_snapshot.shape, shape![512, 512]);
    // Materialized data keeps the original 512x512 shape.
    let weight_data = weight_snapshot
        .to_data()
        .expect("Failed to load tensor data");
    assert_eq!(weight_data.shape, shape![512, 512]);
}
/// Tensors stored through the half-precision adapter are F16 on disk and
/// close to the originals after a load round-trip.
#[test]
fn test_store_half_precision_round_trip() {
    use crate::HalfPrecisionAdapter;
    use burn_nn::{Linear, LinearConfig};
    use burn_tensor::DType;

    #[derive(Module, Debug)]
    struct HalfModel<B: Backend> {
        linear: Linear<B>,
    }

    let device = Default::default();
    let model = HalfModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
    };

    let adapter = HalfPrecisionAdapter::new();
    let mut writer = BurnpackStore::from_bytes(None).with_to_adapter(adapter.clone());
    writer.collect_from(&model).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Every tensor in the stored blob was downcast to F16.
    let mut inspector = BurnpackStore::from_bytes(Some(payload.clone()));
    for snapshot in inspector.get_all_snapshots().unwrap().values() {
        assert_eq!(snapshot.dtype, DType::F16, "Expected F16 in stored data");
    }

    let mut model2 = HalfModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
    };
    let mut reader = BurnpackStore::from_bytes(Some(payload)).with_from_adapter(adapter);
    let outcome = reader.apply_to(&mut model2).unwrap();
    assert!(outcome.is_success());

    // F16 loses precision, so compare with a loose tolerance.
    let original = model.linear.weight.val().to_data().to_vec::<f32>().unwrap();
    let reloaded = model2
        .linear
        .weight
        .val()
        .to_data()
        .to_vec::<f32>()
        .unwrap();
    for (a, b) in original.iter().zip(reloaded.iter()) {
        assert!(
            (a - b).abs() < 0.01,
            "Weight values differ too much after F16 round-trip: {} vs {}",
            a,
            b
        );
    }
}
/// The half-precision adapter converts Linear tensors to F16 but leaves
/// BatchNorm tensors in F32.
#[test]
fn test_store_half_precision_batch_norm_excluded() {
    use crate::HalfPrecisionAdapter;
    use burn_nn::{BatchNorm, BatchNormConfig, Linear, LinearConfig};
    use burn_tensor::DType;

    #[derive(Module, Debug)]
    struct BnModel<B: Backend> {
        linear: Linear<B>,
        bn: BatchNorm<B>,
    }

    let device = Default::default();
    let model = BnModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
        bn: BatchNormConfig::new(2).init(&device),
    };

    let mut writer =
        BurnpackStore::from_bytes(None).with_to_adapter(HalfPrecisionAdapter::new());
    writer.collect_from(&model).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut inspector = BurnpackStore::from_bytes(Some(payload));
    for (name, snapshot) in inspector.get_all_snapshots().unwrap().iter() {
        if name.starts_with("linear") {
            assert_eq!(
                snapshot.dtype,
                DType::F16,
                "Linear tensor '{}' should be F16",
                name
            );
        } else if name.starts_with("bn") {
            assert_eq!(
                snapshot.dtype,
                DType::F32,
                "BatchNorm tensor '{}' should stay F32",
                name
            );
        }
    }
}
/// `without_module("LayerNorm")` exempts LayerNorm tensors from the F16
/// conversion while Linear tensors are still downcast.
#[test]
fn test_store_half_precision_without_module() {
    use crate::HalfPrecisionAdapter;
    use burn_nn::{LayerNorm, LayerNormConfig, Linear, LinearConfig};
    use burn_tensor::DType;

    #[derive(Module, Debug)]
    struct MixedModel<B: Backend> {
        linear: Linear<B>,
        norm: LayerNorm<B>,
    }

    let device = Default::default();
    let model = MixedModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
        norm: LayerNormConfig::new(2).init(&device),
    };

    let adapter = HalfPrecisionAdapter::new().without_module("LayerNorm");
    let mut writer = BurnpackStore::from_bytes(None).with_to_adapter(adapter);
    writer.collect_from(&model).unwrap();
    let payload = writer.get_bytes().unwrap();

    let mut inspector = BurnpackStore::from_bytes(Some(payload));
    for (name, snapshot) in inspector.get_all_snapshots().unwrap().iter() {
        if name.starts_with("linear") {
            assert_eq!(
                snapshot.dtype,
                DType::F16,
                "Linear tensor '{}' should be F16",
                name
            );
        } else if name.starts_with("norm") {
            assert_eq!(
                snapshot.dtype,
                DType::F32,
                "LayerNorm tensor '{}' should stay F32",
                name
            );
        }
    }
}
/// Adapters chain on both paths: Burn→PyTorch plus F16 on save, then
/// F16 plus PyTorch→Burn on load, and the module still applies cleanly.
#[test]
fn test_store_half_precision_chained_with_pytorch() {
    use crate::{HalfPrecisionAdapter, PyTorchToBurnAdapter};
    use burn_nn::{Linear, LinearConfig};
    use burn_tensor::DType;

    #[derive(Module, Debug)]
    struct ChainModel<B: Backend> {
        linear: Linear<B>,
    }

    let device = Default::default();
    let model = ChainModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
    };

    // Save through Burn→PyTorch followed by F16 downcasting.
    let save_adapter = crate::BurnToPyTorchAdapter.chain(HalfPrecisionAdapter::new());
    let mut writer = BurnpackStore::from_bytes(None).with_to_adapter(save_adapter);
    writer.collect_from(&model).unwrap();
    let payload = writer.get_bytes().unwrap();

    // Stored weight is F16 with shape [2, 4] (the 4→2 linear transposed).
    let mut inspector = BurnpackStore::from_bytes(Some(payload.clone()));
    let snapshots = inspector.get_all_snapshots().unwrap();
    let weight = snapshots.get("linear.weight").unwrap();
    assert_eq!(weight.dtype, DType::F16);
    assert_eq!(weight.shape, shape![2, 4]);

    // Load back through the inverse chain.
    let load_adapter = HalfPrecisionAdapter::new().chain(PyTorchToBurnAdapter);
    let mut reader = BurnpackStore::from_bytes(Some(payload)).with_from_adapter(load_adapter);
    let mut model2 = ChainModel::<TestBackend> {
        linear: LinearConfig::new(4, 2).with_bias(true).init(&device),
    };
    assert!(reader.apply_to(&mut model2).unwrap().is_success());
}
/// A linear module quantized at block level (blocks of 32, Q8S) saves
/// without error and its snapshot reports the original shape.
#[test]
fn test_store_quantized_module_block_level() {
    use burn_core::module::Quantizer;
    use burn_nn::LinearConfig;
    use burn_tensor::quantization::{
        Calibration, QTensorPrimitive, QuantLevel, QuantParam, QuantValue,
    };
    let device = Default::default();
    // No bias, so the store holds exactly one tensor.
    let linear = LinearConfig::new(128, 128)
        .with_bias(false)
        .init::<TestBackend>(&device);
    // Backend default scheme narrowed to 8-bit symmetric values,
    // block-level granularity (blocks of 32), and f32 quant params.
    let scheme = <<TestBackend as burn_tensor::backend::BackendTypes>::QuantizedTensorPrimitive as QTensorPrimitive>::default_scheme()
        .with_value(QuantValue::Q8S)
        .with_level(QuantLevel::block([32]))
        .with_param(QuantParam::F32);
    let calibration = Calibration::MinMax;
    let mut quantizer = Quantizer {
        calibration,
        scheme,
    };
    let quantized_linear = linear.quantize_weights(&mut quantizer);
    let mut save_store = BurnpackStore::from_bytes(None);
    let result = save_store.collect_from(&quantized_linear);
    assert!(
        result.is_ok(),
        "Failed to save quantized module with block-level quantization: {:?}",
        result.err()
    );
    let bytes = save_store.get_bytes().expect("Failed to get bytes");
    let mut load_store = BurnpackStore::from_bytes(Some(bytes));
    let snapshots = load_store
        .get_all_snapshots()
        .expect("Failed to get snapshots");
    assert_eq!(snapshots.len(), 1);
    let weight_snapshot = snapshots.get("weight").unwrap();
    assert_eq!(weight_snapshot.shape, shape![128, 128]);
}