use crate::{ModuleSnapshot, SafetensorsStore};
use burn_nn::LinearConfig;
type TestBackend = burn_flex::Flex;
/// Loading weights saved from a 2x2 `Linear` into a 3x3 `Linear` must be reported.
///
/// The test exercises both load paths:
/// - With validation disabled, `load_from` returns `Ok` and the mismatch shows up in
///   `result.errors`.
/// - With validation enabled, `load_from` itself returns `Err`.
#[test]
fn shape_mismatch_errors() {
    /// Builds a fresh in-memory store (with the given `validate` setting) that holds a
    /// copy of the bytes serialized into `saved`.
    ///
    /// Panics if either store is not the in-memory variant or `saved` has no data. A
    /// silently empty load store would let the assertions below pass (or fail) for
    /// reasons unrelated to the shape mismatch under test.
    fn load_store_from(saved: &SafetensorsStore, validate: bool) -> SafetensorsStore {
        let mut store = SafetensorsStore::from_bytes(None).validate(validate);
        let (SafetensorsStore::Memory(target), SafetensorsStore::Memory(source)) =
            (&mut store, saved)
        else {
            panic!("expected in-memory safetensors stores for the shape-mismatch test");
        };
        let bytes = source
            .data()
            .expect("save store should contain serialized module bytes");
        target.set_data(bytes.as_ref().clone());
        store
    }

    let device = Default::default();

    // Source module: 2x2 weights plus bias, serialized into memory.
    let module = LinearConfig::new(2, 2)
        .with_bias(true)
        .init::<TestBackend>(&device);
    let mut save_store = SafetensorsStore::from_bytes(None);
    module.save_into(&mut save_store).unwrap();

    // Target module: 3x3, so every saved tensor has the wrong shape for it.
    let mut incompatible_module = LinearConfig::new(3, 3)
        .with_bias(true)
        .init::<TestBackend>(&device);

    // Validation off: the load call succeeds but records per-tensor errors.
    let mut load_store = load_store_from(&save_store, false);
    let result = incompatible_module.load_from(&mut load_store).unwrap();
    assert!(!result.errors.is_empty());

    // Validation on: the same mismatch makes the load call itself fail.
    let mut load_store_with_validation = load_store_from(&save_store, true);
    let validation_result = incompatible_module.load_from(&mut load_store_with_validation);
    assert!(validation_result.is_err());
}