burn-store 0.21.0

Storage and serialization infrastructure for Burn
Documentation
use burn_core as burn;

use crate::{ModuleSnapshot, SafetensorsStore};
use burn_core::module::Module;
use burn_nn::{
    BatchNorm, BatchNormConfig, Linear, LinearConfig, PaddingConfig2d, Relu,
    conv::{Conv2d, Conv2dConfig},
};
use burn_tensor::Tensor;
use burn_tensor::backend::Backend;

type TestBackend = burn_flex::Flex;

#[derive(Module, Debug)]
pub struct Net<B: Backend> {
    conv1: Conv2d<B>,
    norm1: BatchNorm<B>,
    fc1: Linear<B>,
    relu: Relu,
}

impl<B: Backend> Net<B> {
    pub fn new(device: &B::Device) -> Self {
        Self {
            conv1: Conv2dConfig::new([3, 4], [3, 3])
                .with_padding(PaddingConfig2d::Explicit(1, 1, 1, 1))
                .init(device),
            norm1: BatchNormConfig::new(4).init(device),
            fc1: LinearConfig::new(4 * 8 * 8, 16).init(device),
            relu: Relu::new(),
        }
    }

    /// Forward pass of the model.
    pub fn forward(&self, x: Tensor<B, 4>) -> Tensor<B, 2> {
        let x = self.conv1.forward(x);
        let x = self.norm1.forward(x);
        let x = self.relu.forward(x);
        // Flatten all dimensions except the batch dimension
        let x = x.flatten(1, 3);
        self.fc1.forward(x)
    }
}

#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn multi_layer_model_import() {
    let device = Default::default();

    // Reference the safetensors file from burn-import
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Load the model using SafetensorsStore
    // Note: PyTorch and Burn have different conventions for linear layer weights
    // PyTorch stores as [out_features, in_features], Burn as [in_features, out_features]
    // Also, tensor names may differ (e.g., PyTorch uses different names for BatchNorm params)
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format
        .allow_partial(true); // Allow partial loading due to naming differences
    let mut model = Net::<TestBackend>::new(&device);

    let result = model.load_from(&mut store).unwrap();

    // With the adapter, weights should load correctly
    assert!(!result.applied.is_empty());
    assert!(
        result.errors.is_empty(),
        "Should have no errors with adapter: {:?}",
        result.errors
    );

    // Test forward pass with the loaded weights
    // Note: Due to shape mismatches (PyTorch vs Burn conventions for linear layers),
    // we can't directly compare outputs with PyTorch model.
    // This test mainly verifies that the loading mechanism works.
    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
    let _output = model.forward(input);

    // Verify that some tensors were loaded successfully
    // Conv and BatchNorm layers should load correctly
    assert!(result.applied.iter().any(|n| n.contains("conv1")));
    assert!(result.applied.iter().any(|n| n.contains("norm1")));
}

#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn safetensors_round_trip_with_pytorch_model() {
    let device = Default::default();

    // Reference the safetensors file from burn-import
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Load the model from PyTorch safetensors
    let mut load_store = SafetensorsStore::from_file(safetensors_path)
        .with_from_adapter(crate::PyTorchToBurnAdapter) // Use adapter to handle PyTorch format
        .allow_partial(true); // Allow partial loading due to naming differences
    let mut model = Net::<TestBackend>::new(&device);
    let load_result = model.load_from(&mut load_store).unwrap();
    // With the adapter, weights should load correctly
    assert!(!load_result.applied.is_empty());
    assert!(
        load_result.errors.is_empty(),
        "Should have no errors with adapter: {:?}",
        load_result.errors
    );

    // Save the model to memory
    // Note: format, producer and version are automatically added
    let mut save_store = SafetensorsStore::from_bytes(None).metadata("source", "pytorch");
    model.save_into(&mut save_store).unwrap();

    // Load into a new model
    let mut model2 = Net::<TestBackend>::new(&device);
    let mut load_store2 = SafetensorsStore::from_bytes(None);
    if let SafetensorsStore::Memory(ref mut p) = load_store2
        && let SafetensorsStore::Memory(ref p_save) = save_store
    {
        p.set_data(p_save.data().unwrap().as_ref().clone());
    }

    let result = model2.load_from(&mut load_store2).unwrap();
    assert!(!result.applied.is_empty());

    // Verify both models produce the same output
    let input = Tensor::<TestBackend, 4>::ones([1, 3, 8, 8], &device);
    let output1 = model.forward(input.clone());
    let output2 = model2.forward(input);

    // Check outputs are identical
    let output1_data = output1.to_data().to_vec::<f32>().unwrap();
    let output2_data = output2.to_data().to_vec::<f32>().unwrap();

    for (a, b) in output1_data.iter().zip(output2_data.iter()) {
        assert!((a - b).abs() < 1e-7, "Outputs differ after round trip");
    }
}

#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn partial_load_from_pytorch_model() {
    let device = Default::default();

    // Reference the safetensors file from burn-import
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Load only conv1 and norm1 parameters (not fc1)
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .validate(false) // Disable validation due to shape differences
        .allow_partial(true);

    let mut model = Net::<TestBackend>::new(&device);

    // Save initial fc1 weights for comparison
    let _initial_fc1_weight = model.fc1.weight.val().to_data();

    let result = model.load_from(&mut store).unwrap();

    // Should load available tensors (with some errors due to shape mismatch)
    assert!(!result.applied.is_empty());

    // fc1 weight should remain unchanged if not in the file
    // or should be updated if it is in the file
    // This test verifies that partial loading works correctly
}

#[test]
#[cfg(all(feature = "std", target_has_atomic = "ptr"))]
fn verify_tensor_names_from_pytorch() {
    let device = Default::default();

    // Reference the safetensors file from burn-import
    let safetensors_path = concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/safetensors-tests/tests/multi_layer/multi_layer.safetensors"
    );

    // Create a model and load from PyTorch
    let mut model = Net::<TestBackend>::new(&device);
    let mut store = SafetensorsStore::from_file(safetensors_path)
        .validate(false) // Disable validation due to shape differences
        .allow_partial(true); // Allow partial loading due to naming differences
    let result = model.load_from(&mut store).unwrap();

    // Check that we loaded some tensors (with errors due to shape mismatch)
    assert!(!result.applied.is_empty());

    // Collect tensor names from the model
    let views = model.collect(None, None, false);
    let tensor_names: Vec<String> = views.iter().map(|v| v.full_path()).collect();

    // Verify expected tensor names are present
    assert!(tensor_names.iter().any(|n| n.contains("conv1")));
    assert!(tensor_names.iter().any(|n| n.contains("norm1")));
    assert!(tensor_names.iter().any(|n| n.contains("fc1")));
}