use burn_core as burn;
use crate::{ModuleSnapshot, ModuleStore, SafetensorsStore};
use burn_core::module::{Module, Param};
use burn_nn::{Initializer, LinearConfig};
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;
use tempfile::tempdir;
// Concrete backend used to instantiate every module and tensor in the tests below.
type TestBackend = burn_flex::Flex;
/// Small two-layer model used to check that a forward pass is preserved
/// across a save/load round trip.
#[derive(Module, Debug)]
struct ForwardTestModel<B: burn_tensor::backend::Backend> {
/// First linear layer; `forward` applies it directly to the input.
linear1: burn_nn::Linear<B>,
/// Second linear layer; `forward` applies it to the GELU-activated output of `linear1`.
linear2: burn_nn::Linear<B>,
}
impl<B> ForwardTestModel<B>
where
    B: burn_tensor::backend::Backend,
{
    /// Runs `input` through `linear1`, a GELU activation, and then `linear2`,
    /// returning the result of the second layer.
    fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let hidden = self.linear1.forward(input);
        let activated = burn::tensor::activation::gelu(hidden);
        self.linear2.forward(activated)
    }
}
/// Sizes used by `init` to build a `ForwardTestModel`.
#[derive(burn::config::Config, Debug)]
struct ForwardTestModelConfig {
/// First size argument of the `LinearConfig` that builds `linear1`.
input_size: usize,
/// Second size argument for `linear1` and first size argument for `linear2`.
hidden_size: usize,
/// Second size argument of the `LinearConfig` that builds `linear2`.
output_size: usize,
}
impl ForwardTestModelConfig {
fn init<B: burn_tensor::backend::Backend>(&self, device: &B::Device) -> ForwardTestModel<B> {
ForwardTestModel {
linear1: LinearConfig::new(self.input_size, self.hidden_size)
.with_bias(true)
.init(device),
linear2: LinearConfig::new(self.hidden_size, self.output_size)
.with_bias(true)
.init(device),
}
}
}
/// Minimal module holding a single rank-2 parameter tensor.
#[derive(Module, Debug)]
pub struct ModuleBasic<B: Backend> {
/// The only parameter; `new` initializes it with shape `[20, 20]`.
weight_basic: Param<Tensor<B, 2>>,
}
impl<B: Backend> ModuleBasic<B> {
    /// Creates the module on `device` with a `[20, 20]` weight initialized by
    /// `Initializer::Normal` with mean 0.0 and standard deviation 1.0.
    fn new(device: &B::Device) -> Self {
        let initializer = Initializer::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let weight_basic = initializer.init([20, 20], device);
        Self { weight_basic }
    }
}
/// Module combining its own parameter with nested `ModuleBasic` sub-modules,
/// both as a plain field and inside a tuple.
#[derive(Module, Debug)]
pub struct ModuleComposed<B: Backend> {
/// Parameter owned directly by this module; `new` initializes it with shape `[20, 20]`.
weight: Param<Tensor<B, 2>>,
/// A single nested sub-module.
basic: ModuleBasic<B>,
/// Two further sub-modules stored as a tuple field.
tuple: (ModuleBasic<B>, ModuleBasic<B>),
}
impl<B: Backend> ModuleComposed<B> {
    /// Creates the composed module on `device`.
    ///
    /// The top-level weight is a `[20, 20]` tensor initialized by
    /// `Initializer::Normal` (mean 0.0, standard deviation 1.0); each nested
    /// `ModuleBasic` is built with `ModuleBasic::new`.
    fn new(device: &B::Device) -> Self {
        // Same construction order as the field order: weight, basic, then the tuple pair.
        let initializer = Initializer::Normal {
            mean: 0.0,
            std: 1.0,
        };
        let weight = initializer.init([20, 20], device);
        let basic = ModuleBasic::new(device);
        let tuple_first = ModuleBasic::new(device);
        let tuple_second = ModuleBasic::new(device);

        Self {
            weight,
            basic,
            tuple: (tuple_first, tuple_second),
        }
    }
}
/// Saves a `Linear` module (with bias) to a safetensors file and loads it back
/// into a freshly initialized module, expecting both tensors to be applied.
#[test]
fn file_based_loading() {
    let device = Default::default();
    let module = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);

    // Use a unique temporary directory instead of a fixed name in the shared
    // system temp dir. A stale file left behind by an earlier failed run (or a
    // concurrent run) would otherwise make the save fail, because the store
    // refuses to overwrite an existing file by default. The directory and its
    // contents are removed when `temp_dir` is dropped, even if an assertion
    // below panics.
    let temp_dir = tempdir().unwrap();
    let file_path = temp_dir.path().join("test_safetensors.st");

    // Save the module, attaching one metadata entry.
    let mut save_store = SafetensorsStore::from_file(&file_path).metadata("test", "file_loading");
    module.save_into(&mut save_store).unwrap();
    assert!(file_path.exists());

    // Load into a separately initialized module of the same shape.
    let mut load_store = SafetensorsStore::from_file(&file_path);
    let mut loaded_module = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);
    let result = loaded_module.load_from(&mut load_store).unwrap();

    assert!(result.is_success());
    // One weight tensor plus one bias tensor.
    assert_eq!(result.applied.len(), 2);
}
/// Checks that saving to an existing file is rejected by default, succeeds once
/// `overwrite(true)` is set, and that the overwritten file can still be loaded.
#[test]
fn test_store_overwrite_protection() {
    use tempfile::tempdir;

    let device = Default::default();
    let module = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("test_model.safetensors");

    // First save: the file does not exist yet, so this must succeed.
    let mut initial_store = SafetensorsStore::from_file(&path);
    initial_store.collect_from(&module).unwrap();
    assert!(path.exists());

    // Second save without the overwrite flag: must be refused.
    let mut blocked_store = SafetensorsStore::from_file(&path);
    let blocked = blocked_store.collect_from(&module);
    assert!(blocked.is_err());
    let message = blocked.unwrap_err().to_string();
    assert!(message.contains("File already exists"));

    // Third save with overwrite enabled: must succeed.
    let mut overwriting_store = SafetensorsStore::from_file(&path).overwrite(true);
    overwriting_store.collect_from(&module).unwrap();

    // The overwritten file must load cleanly into a fresh module.
    let mut reader = SafetensorsStore::from_file(&path);
    let mut restored = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);
    let outcome = reader.apply_to(&mut restored).unwrap();
    assert!(outcome.is_success());
}
/// Saves the same module twice to one path with overwrite enabled and a
/// different `model_version` metadata value each time, then checks that the
/// resulting file loads successfully.
#[test]
fn test_store_overwrite_with_metadata() {
    use tempfile::tempdir;

    let device = Default::default();
    let module = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("test_model_metadata.safetensors");

    // Write "v1" first, then replace it with "v2"; both saves allow overwriting.
    for version in ["v1", "v2"] {
        let mut writer = SafetensorsStore::from_file(&path)
            .metadata("model_version", version)
            .overwrite(true);
        writer.collect_from(&module).unwrap();
    }

    let mut reader = SafetensorsStore::from_file(&path);
    let mut restored = LinearConfig::new(4, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);
    let outcome = restored.load_from(&mut reader).unwrap();
    assert!(outcome.is_success());
}
/// Verifies that a model loaded from a saved file produces the same forward
/// output as the model that was saved, while a freshly initialized model
/// does not.
#[test]
fn test_forward_pass_preservation_after_save_load() {
    let device = Default::default();
    let config = ForwardTestModelConfig {
        input_size: 4,
        hidden_size: 8,
        output_size: 2,
    };

    // Reference model and a random input in [-1, 1).
    let reference = config.init::<TestBackend>(&device);
    let distribution = burn_tensor::Distribution::Uniform(-1.0, 1.0);
    let input = Tensor::<TestBackend, 2>::random([1, 4], distribution, &device);
    let expected = reference.forward(input.clone());

    // Persist the reference weights.
    let scratch = tempdir().unwrap();
    let path = scratch.path().join("forward_test_model.safetensors");
    let mut writer = SafetensorsStore::from_file(&path);
    writer.collect_from(&reference).unwrap();

    // A second, independently initialized model should disagree before loading.
    let mut candidate = config.init::<TestBackend>(&device);
    let before_load = candidate.forward(input.clone());
    let matches_before_load = expected
        .clone()
        .all_close(before_load, Some(1e-6), Some(1e-6));
    assert!(
        !matches_before_load,
        "output2 should differ from output1 (different random initializations)"
    );

    // Load the saved weights into the second model.
    let mut reader = SafetensorsStore::from_file(&path);
    let outcome = reader.apply_to(&mut candidate).unwrap();
    assert!(outcome.is_success());
    // Two linear layers, each with a weight and a bias.
    assert_eq!(outcome.applied.len(), 4);

    // After loading, both models must agree on the same input.
    let after_load = candidate.forward(input);
    let matches_after_load = expected.all_close(after_load, Some(1e-6), Some(1e-6));
    assert!(
        matches_after_load,
        "output3 should equal output1 after loading weights"
    );
}
/// Round-trips a `ModuleComposed` through a safetensors file and checks that
/// the top-level weight and the nested `basic` weight of the target module
/// end up equal to those of the saved module.
#[test]
fn should_save_load_compose() {
    let device = Default::default();
    let source = ModuleComposed::<TestBackend>::new(&device);
    let mut target = ModuleComposed::<TestBackend>::new(&device);

    // Sanity check: the two independently initialized modules start out different.
    assert_ne!(source.weight.to_data(), target.weight.to_data());
    assert_ne!(
        source.basic.weight_basic.to_data(),
        target.basic.weight_basic.to_data()
    );

    let scratch = tempdir().unwrap();
    let path = scratch.path().join("save_load_compose.safetensors");

    let mut writer = SafetensorsStore::from_file(&path);
    source.save_into(&mut writer).unwrap();

    let mut reader = SafetensorsStore::from_file(&path);
    let outcome = target.load_from(&mut reader).unwrap();
    assert!(outcome.is_success());

    // After loading, the compared tensors must match the saved ones.
    assert_eq!(source.weight.to_data(), target.weight.to_data());
    assert_eq!(
        source.basic.weight_basic.to_data(),
        target.basic.weight_basic.to_data()
    );
}