use crate::{ModuleSnapshot, SafetensorsStore};
use burn_nn::LinearConfig;
/// Backend used to instantiate every `Linear` module in these tests.
type TestBackend = burn_flex::Flex;
/// The default metadata map identifies the file format and producer and carries a version.
#[test]
fn default_metadata_included() {
    let metadata = SafetensorsStore::default_metadata();

    // Entries with fixed, well-known values, checked in the same order as before.
    let fixed_entries = [("format", "safetensors"), ("producer", "burn")];
    for (key, expected) in fixed_entries {
        assert_eq!(metadata.get(key).unwrap(), expected);
    }

    // The version entry must exist and be non-empty; its exact value is not pinned here.
    assert!(metadata.contains_key("version"));
    let version_is_blank = metadata.get("version").unwrap().is_empty();
    assert!(!version_is_blank);
}
#[test]
fn metadata_preservation() {
let device = Default::default();
let module = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let mut save_store = SafetensorsStore::from_bytes(None)
.metadata("model_type", "linear")
.metadata("custom_field", "test_value");
module.save_into(&mut save_store).unwrap();
let mut load_store = SafetensorsStore::from_bytes(None);
if let SafetensorsStore::Memory(ref mut p) = load_store
&& let SafetensorsStore::Memory(ref p_save) = save_store
{
let data_arc = p_save.data().unwrap();
p.set_data(data_arc.as_ref().clone());
}
let mut module2 = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let result = module2.load_from(&mut load_store).unwrap();
assert!(result.is_success());
}
#[test]
fn clear_metadata_removes_all() {
let device = Default::default();
let module = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let mut save_store = SafetensorsStore::from_bytes(None)
.metadata("model_type", "linear")
.metadata("custom_field", "test_value")
.clear_metadata();
module.save_into(&mut save_store).unwrap();
let mut load_store = SafetensorsStore::from_bytes(None);
if let SafetensorsStore::Memory(ref mut p) = load_store
&& let SafetensorsStore::Memory(ref p_save) = save_store
{
let data_arc = p_save.data().unwrap();
p.set_data(data_arc.as_ref().clone());
}
let mut module2 = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let result = module2.load_from(&mut load_store).unwrap();
assert!(result.is_success());
}
#[test]
fn clear_then_add_custom_metadata() {
let device = Default::default();
let module = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let mut save_store = SafetensorsStore::from_bytes(None)
.clear_metadata()
.metadata("only_custom", "value");
module.save_into(&mut save_store).unwrap();
let mut load_store = SafetensorsStore::from_bytes(None);
if let SafetensorsStore::Memory(ref mut p) = load_store
&& let SafetensorsStore::Memory(ref p_save) = save_store
{
let data_arc = p_save.data().unwrap();
p.set_data(data_arc.as_ref().clone());
}
let mut module2 = LinearConfig::new(4, 2)
.with_bias(true)
.init::<TestBackend>(&device);
let result = module2.load_from(&mut load_store).unwrap();
assert!(result.is_success());
}