use burn_core as burn;
use crate::{ModuleSnapshot, SafetensorsStore};
use burn_core::module::Module;
use burn_nn::{
BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
conv::{Conv2d, Conv2dConfig},
};
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;
// Backend shared by every test in this file.
// NOTE(review): presumably a CPU-capable test backend — confirm against burn_flex.
type TestBackend = burn_flex::Flex;
/// Small multi-layer test network: Conv2d -> BatchNorm -> ReLU -> Linear.
///
/// The tests below load `multi_layer.safetensors` (exported from PyTorch)
/// into this model, so its layer names (`conv1`, `norm1`, `fc1`) are expected
/// to line up with the tensor names stored in that file.
#[derive(Module, Debug)]
pub struct Net<B: Backend> {
conv1: Conv2d<B>,
norm1: BatchNorm<B>,
fc1: Linear<B>,
relu: Relu,
}
impl<B: Backend> Net<B> {
    /// Build the network with fixed test dimensions on `device`:
    /// 3->4 channel 3x3 conv (padding 1), batch norm over 4 channels,
    /// and a 4*8*8 -> 16 linear head.
    pub fn new(device: &B::Device) -> Self {
        let conv1 = Conv2dConfig::new([3, 4], [3, 3])
            .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
            .init(device);
        let norm1 = BatchNormConfig::new(4).init(device);
        let fc1 = LinearConfig::new(4 * 8 * 8, 16).init(device);
        Self {
            conv1,
            norm1,
            fc1,
            relu: Relu::new(),
        }
    }

    /// Forward pass: conv -> batch norm -> relu, then flatten the
    /// feature map (dims 1..=3) and apply the linear head.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        let features = self.relu.forward(self.norm1.forward(self.conv1.forward(x)));
        self.fc1.forward(features.flatten(1, 3))
    }
}
#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn multi_layer_model_import() {
    let device = Default::default();
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Load PyTorch-exported weights, adapting tensor names/layout to Burn.
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .with_from_adapter(crate::PyTorchToBurnAdapter)
        .allow_partial(true);
    let mut model = Net::<TestBackend>::new(&device);
    let result = model.load_from(&mut store).unwrap();

    assert!(!result.applied.is_empty());
    assert!(
        result.errors.is_empty(),
        "Should have no errors with adapter: {:?}",
        result.errors
    );

    // A forward pass on the imported weights must not panic.
    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
    let _output = model.forward(input);

    // Both the conv and norm layers should have received tensors.
    assert!(result.applied.iter().any(|n| n.contains("conv1")));
    assert!(result.applied.iter().any(|n| n.contains("norm1")));
}
#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn safetensors_round_trip_with_pytorch_model() {
    let device = Default::default();
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Step 1: import the PyTorch weights into a fresh model.
    let mut load_store = SafetensorsStore::from_file(safetensors_path)
        .with_from_adapter(crate::PyTorchToBurnAdapter)
        .allow_partial(true);
    let mut model = Net::<TestBackend>::new(&device);
    let load_result = model.load_from(&mut load_store).unwrap();
    assert!(!load_result.applied.is_empty());
    assert!(
        load_result.errors.is_empty(),
        "Should have no errors with adapter: {:?}",
        load_result.errors
    );

    // Step 2: save the imported model to an in-memory safetensors buffer.
    let mut save_store = SafetensorsStore::from_bytes(None).metadata("source", "pytorch");
    model.save_into(&mut save_store).unwrap();

    // Step 3: copy the saved bytes into a second store and load model2 from it.
    let mut model2 = Net::<TestBackend>::new(&device);
    let mut load_store2 = SafetensorsStore::from_bytes(None);
    if let SafetensorsStore::Memory(ref mut target) = load_store2 {
        if let SafetensorsStore::Memory(ref saved) = save_store {
            target.set_data(saved.data().unwrap().as_ref().clone());
        }
    }
    let result = model2.load_from(&mut load_store2).unwrap();
    assert!(!result.applied.is_empty());

    // Step 4: both models must produce (near-)identical outputs.
    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
    let output1 = model.forward(input.clone());
    let output2 = model2.forward(input);
    let output1_data = output1.to_data().to_vec::<f32>().unwrap();
    let output2_data = output2.to_data().to_vec::<f32>().unwrap();
    for (a, b) in output1_data.iter().zip(output2_data.iter()) {
        assert!((a - b).abs() < 1e-7, "Outputs differ after round trip");
    }
}
#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn partial_load_from_pytorch_model() {
    let device = Default::default();
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );
    // Disable validation and allow a partial match so the load succeeds
    // even when not every tensor in the file maps onto the model (and
    // vice versa).
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .validate(false)
        .allow_partial(true);
    let mut model = Net::<TestBackend>::new(&device);
    // Removed: a dead `_initial_fc1_weight` snapshot was computed here but
    // never compared against anything.
    let result = model.load_from(&mut store).unwrap();
    assert!(!result.applied.is_empty());
}
#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn verify_tensor_names_from_pytorch() {
    let device = Default::default();
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Load the PyTorch weights leniently (no validation, partial allowed).
    let mut model = Net::<TestBackend>::new(&device);
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .validate(false)
        .allow_partial(true);
    let result = model.load_from(&mut store).unwrap();
    assert!(!result.applied.is_empty());

    // Collect the model's tensor paths and check every layer is represented.
    let views = model.collect(None, None, false);
    let tensor_names: Vec<String> = views.iter().map(|v| v.full_path()).collect();
    assert!(tensor_names.iter().any(|n| n.contains("conv1")));
    assert!(tensor_names.iter().any(|n| n.contains("norm1")));
    assert!(tensor_names.iter().any(|n| n.contains("fc1")));
}