use std::path::PathBuf;
use burn::module::Module;
use burn::prelude::*;
use burn::record::{
BinFileRecorder, CompactRecorder, DefaultRecorder, FullPrecisionSettings,
HalfPrecisionSettings, NamedMpkFileRecorder, Recorder,
};
use crate::config::AttnResConfig;
use crate::model::AttnResTransformer;
/// Errors produced when persisting or restoring an `AttnResTransformer`.
#[derive(Debug)]
pub enum SerializationError {
    /// Writing a model to `path` failed; `detail` carries the underlying error text.
    SaveFailed { path: String, detail: String },
    /// Reading a model from `path` failed; `detail` carries the underlying error text.
    LoadFailed { path: String, detail: String },
    /// Generic recorder failure that carries only a message (no associated path).
    RecorderError(String),
}

impl std::fmt::Display for SerializationError {
    /// Renders a single human-readable line; path-bearing variants quote the path.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            SerializationError::SaveFailed { path, detail } => write!(f, "Failed to save model to '{path}': {detail}"),
            SerializationError::LoadFailed { path, detail } => write!(f, "Failed to load model from '{path}': {detail}"),
            SerializationError::RecorderError(msg) => write!(f, "Serialization error: {msg}"),
        }
    }
}

// No `source()` override: every variant stores plain strings, so there is no inner error to expose.
impl std::error::Error for SerializationError {}
impl From<burn::record::RecorderError> for SerializationError {
fn from(err: burn::record::RecorderError) -> Self {
Self::RecorderError(format!("{err:?}"))
}
}
impl<B: Backend> AttnResTransformer<B> {
pub fn save(&self, path: &str, _device: &B::Device) -> Result<(), SerializationError> {
let recorder = DefaultRecorder::default();
recorder
.record(self.clone().into_record(), PathBuf::from(path))
.map_err(|e| SerializationError::SaveFailed {
path: path.to_string(),
detail: format!("{e:?}"),
})?;
Ok(())
}
pub fn load(
path: &str,
config: &AttnResConfig,
device: &B::Device,
) -> Result<Self, SerializationError> {
let recorder = DefaultRecorder::default();
let record = recorder.load(PathBuf::from(path), device).map_err(|e| {
SerializationError::LoadFailed {
path: path.to_string(),
detail: format!("{e:?}"),
}
})?;
let model = config.init_model::<B>(device).load_record(record);
Ok(model)
}
pub fn save_compact(&self, path: &str) -> Result<(), SerializationError> {
let recorder = CompactRecorder::default();
recorder
.record(self.clone().into_record(), PathBuf::from(path))
.map_err(|e| SerializationError::SaveFailed {
path: path.to_string(),
detail: format!("{e:?}"),
})?;
Ok(())
}
pub fn load_compact(
path: &str,
config: &AttnResConfig,
device: &B::Device,
) -> Result<Self, SerializationError> {
let recorder = NamedMpkFileRecorder::<HalfPrecisionSettings>::default();
let record = recorder.load(PathBuf::from(path), device).map_err(|e| {
SerializationError::LoadFailed {
path: path.to_string(),
detail: format!("{e:?}"),
}
})?;
let model = config.init_model::<B>(device).load_record(record);
Ok(model)
}
pub fn save_binary(&self, path: &str) -> Result<(), SerializationError> {
let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
recorder
.record(self.clone().into_record(), PathBuf::from(path))
.map_err(|e| SerializationError::SaveFailed {
path: path.to_string(),
detail: format!("{e:?}"),
})?;
Ok(())
}
pub fn load_binary(
path: &str,
config: &AttnResConfig,
device: &B::Device,
) -> Result<Self, SerializationError> {
let recorder = BinFileRecorder::<FullPrecisionSettings>::default();
let record = recorder.load(PathBuf::from(path), device).map_err(|e| {
SerializationError::LoadFailed {
path: path.to_string(),
detail: format!("{e:?}"),
}
})?;
let model = config.init_model::<B>(device).load_record(record);
Ok(model)
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use burn::backend::NdArray;

    type TestBackend = NdArray;

    /// Small model configuration shared by every test so they stay fast and consistent.
    fn small_config() -> AttnResConfig {
        AttnResConfig::new(32, 4, 2)
            .with_num_heads(4)
            .with_vocab_size(50)
    }

    #[test]
    fn test_save_load_roundtrip() {
        let device = Default::default();
        let config = small_config();
        let model: AttnResTransformer<TestBackend> = config.init_model(&device);
        let input = Tensor::<TestBackend, 2, Int>::zeros([1, 8], &device);
        let out_before = model.forward(input.clone(), None);

        let path = std::env::temp_dir().join("attnres_test_save_load");
        let path_str = path.to_str().unwrap();
        model.save(path_str, &device).expect("Failed to save");
        let loaded: AttnResTransformer<TestBackend> =
            AttnResTransformer::load(path_str, &config, &device).expect("Failed to load");
        // Clean up before asserting so a failed comparison doesn't leave the artifact behind.
        let _ = std::fs::remove_file(format!("{path_str}.mpk"));

        let out_after = loaded.forward(input, None);
        let diff: f32 = (out_before - out_after).abs().max().into_scalar();
        assert!(
            diff < 1e-6,
            "Loaded model should produce identical output, diff={diff}"
        );
    }

    #[test]
    fn test_save_load_compact_roundtrip() {
        let device = Default::default();
        let config = small_config();
        let model: AttnResTransformer<TestBackend> = config.init_model(&device);
        let input = Tensor::<TestBackend, 2, Int>::zeros([1, 8], &device);
        let out_before = model.forward(input.clone(), None);

        let path = std::env::temp_dir().join("attnres_test_save_load_compact");
        let path_str = path.to_str().unwrap();
        model
            .save_compact(path_str)
            .expect("Failed to save compact");
        let loaded: AttnResTransformer<TestBackend> =
            AttnResTransformer::load_compact(path_str, &config, &device)
                .expect("Failed to load compact");
        let _ = std::fs::remove_file(format!("{path_str}.mpk"));

        let out_after = loaded.forward(input, None);
        let diff: f32 = (out_before - out_after).abs().max().into_scalar();
        // Half-precision storage, so only approximate equality is expected.
        assert!(
            diff < 1e-2,
            "Compact-loaded model should produce similar output, diff={diff}"
        );
    }

    #[test]
    fn test_save_load_binary_roundtrip() {
        let device = Default::default();
        let config = small_config();
        let model: AttnResTransformer<TestBackend> = config.init_model(&device);
        let input = Tensor::<TestBackend, 2, Int>::zeros([1, 8], &device);
        let out_before = model.forward(input.clone(), None);

        let path = std::env::temp_dir().join("attnres_test_save_load_bin");
        let path_str = path.to_str().unwrap();
        model.save_binary(path_str).expect("Failed to save binary");
        let loaded: AttnResTransformer<TestBackend> =
            AttnResTransformer::load_binary(path_str, &config, &device)
                .expect("Failed to load binary");
        let _ = std::fs::remove_file(format!("{path_str}.bin"));

        let out_after = loaded.forward(input, None);
        let diff: f32 = (out_before - out_after).abs().max().into_scalar();
        assert!(
            diff < 1e-6,
            "Binary-loaded model should produce identical output, diff={diff}"
        );
    }

    #[test]
    fn test_load_nonexistent_returns_error() {
        let device = Default::default();
        let config = small_config();
        // Use the platform temp dir rather than a hard-coded `/tmp`, which doesn't exist
        // everywhere (e.g. Windows).
        let path = std::env::temp_dir().join("nonexistent_attnres_model_xyz");
        let path_str = path.to_str().unwrap();
        let result = AttnResTransformer::<TestBackend>::load(path_str, &config, &device);
        assert!(result.is_err(), "Loading nonexistent file should fail");
        let err = result.unwrap_err();
        let msg = format!("{err}");
        assert!(
            msg.contains("nonexistent_attnres_model_xyz"),
            "Error should contain the path, got: {msg}"
        );
    }

    #[test]
    fn test_serialization_error_display() {
        let err = SerializationError::SaveFailed {
            path: "test/path".to_string(),
            detail: "disk full".to_string(),
        };
        assert_eq!(
            format!("{err}"),
            "Failed to save model to 'test/path': disk full"
        );
        let err = SerializationError::LoadFailed {
            path: "model.mpk".to_string(),
            detail: "corrupted".to_string(),
        };
        assert_eq!(
            format!("{err}"),
            "Failed to load model from 'model.mpk': corrupted"
        );
        let err = SerializationError::RecorderError("bad header".to_string());
        assert_eq!(format!("{err}"), "Serialization error: bad header");
    }
}