use std::path::{Path, PathBuf};
use burn::module::Module;
use burn::prelude::*;
use burn::record::{FullPrecisionSettings, NamedMpkFileRecorder, Recorder};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use crate::error::{Result, TrainError};
use crate::learner::{LearnerConfig, TrainingHistory, TrainingState};
// File names used inside an exported bundle directory; `save_model_bundle`
// writes them and `LearnerExport` reads them back.

/// Export metadata written from an [`ExportMetadata`] value.
const METADATA_FILE: &str = "export_meta.json";
/// Serialized learner configuration.
const CONFIG_FILE: &str = "config.json";
/// Serialized training state (epoch counter, history, best loss).
const STATE_FILE: &str = "state.json";
/// Model weights written by the named MessagePack recorder.
const MODEL_WEIGHTS_FILE: &str = "model.mpk";
/// Descriptive metadata stored next to an exported model.
///
/// The struct is (de)serialized with serde, so the declaration order of the fields
/// below is also the key order of the JSON written for it.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExportMetadata {
    /// Metadata format version; [`ExportMetadata::new`] sets it to `"1.0"`.
    pub version: String,
    /// Name of the model architecture, e.g. `"InceptionTimePlus"`.
    pub arch: String,
    /// Number of output classes, when recorded.
    pub n_classes: Option<usize>,
    /// Input sequence length, when recorded.
    pub seq_len: Option<usize>,
    /// Number of input variables, when recorded.
    pub n_vars: Option<usize>,
    /// Creation time; [`ExportMetadata::new`] stores whole seconds since the Unix epoch.
    pub timestamp: String,
    /// Best validation loss reported by the training state, when recorded.
    pub best_val_loss: Option<f32>,
    /// Validation accuracy taken from the training history, when recorded.
    pub best_val_acc: Option<f32>,
    /// Number of epochs the model was trained for.
    pub epochs_trained: usize,
    /// Free-form string annotations (e.g. dataset name).
    pub extra: std::collections::HashMap<String, String>,
}
impl ExportMetadata {
    /// Creates metadata for the architecture `arch` with every optional field unset.
    ///
    /// `version` is `"1.0"`. `timestamp` is the current time in whole seconds since the
    /// Unix epoch, or `"0"` if the system clock reports a time before the epoch.
    pub fn new(arch: impl Into<String>) -> Self {
        let timestamp = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_secs().to_string())
            .unwrap_or_else(|_| "0".to_string());
        Self {
            version: "1.0".to_string(),
            arch: arch.into(),
            n_classes: None,
            seq_len: None,
            n_vars: None,
            timestamp,
            best_val_loss: None,
            best_val_acc: None,
            epochs_trained: 0,
            extra: std::collections::HashMap::new(),
        }
    }

    /// Records the number of output classes.
    #[must_use]
    pub fn with_n_classes(mut self, n: usize) -> Self {
        self.n_classes = Some(n);
        self
    }

    /// Records the input sequence length.
    #[must_use]
    pub fn with_seq_len(mut self, len: usize) -> Self {
        self.seq_len = Some(len);
        self
    }

    /// Records the number of input variables.
    #[must_use]
    pub fn with_n_vars(mut self, n: usize) -> Self {
        self.n_vars = Some(n);
        self
    }

    /// Copies training statistics from `state`.
    ///
    /// - `epochs_trained` becomes `state.epoch`.
    /// - `best_val_loss` becomes `state.best_valid_loss` if it is finite, otherwise `None`.
    ///   `serde_json` writes non-finite floats as `null`, so storing them as `None` keeps
    ///   this value identical to what a save/reload round-trip would produce.
    /// - `best_val_acc` becomes the highest finite `"accuracy"` value recorded in any
    ///   epoch of `state.history.metrics`. Using the maximum rather than the last epoch's
    ///   value keeps the field truthful when the final epoch regressed. If no epoch
    ///   recorded an accuracy, the field is left unchanged.
    #[must_use]
    pub fn with_training_stats(mut self, state: &TrainingState) -> Self {
        self.epochs_trained = state.epoch;
        self.best_val_loss = Some(state.best_valid_loss).filter(|loss| loss.is_finite());

        let best_acc = state
            .history
            .metrics
            .iter()
            .filter_map(|epoch_metrics| epoch_metrics.get("accuracy").copied())
            .filter(|acc| acc.is_finite())
            .reduce(f32::max);
        if let Some(acc) = best_acc {
            self.best_val_acc = Some(acc);
        }
        self
    }

    /// Adds a free-form annotation, overwriting any previous value stored under `key`.
    #[must_use]
    pub fn with_extra(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
        self.extra.insert(key.into(), value.into());
        self
    }
}
/// A training bundle read back from disk by [`LearnerExport::load`].
///
/// Only the JSON side-files are parsed eagerly; the model weights stay on disk until
/// [`LearnerExport::load_model`] is called.
#[derive(Debug)]
pub struct LearnerExport {
    /// Directory the bundle was loaded from.
    pub path: PathBuf,
    /// Parsed contents of the export metadata file.
    pub metadata: ExportMetadata,
    /// Parsed contents of the training state file.
    pub state: TrainingState,
    /// Parsed contents of the learner configuration file.
    pub config: LearnerConfig,
}
impl LearnerExport {
    /// Opens a bundle directory written by [`save_model_bundle`].
    ///
    /// The metadata, state and config JSON files are read, in that order, from `path`.
    /// The model weights are not touched; use [`LearnerExport::load_model`] for them.
    ///
    /// # Errors
    ///
    /// Returns the first read or deserialization error encountered.
    pub fn load(path: impl AsRef<Path>) -> Result<Self> {
        let dir = path.as_ref();
        let metadata = load_json::<ExportMetadata>(&dir.join(METADATA_FILE))?;
        let state = load_json::<TrainingState>(&dir.join(STATE_FILE))?;
        let config = load_json::<LearnerConfig>(&dir.join(CONFIG_FILE))?;
        Ok(Self {
            path: dir.to_path_buf(),
            metadata,
            state,
            config,
        })
    }

    /// Loads the model record from this bundle's weights file onto `device`.
    ///
    /// # Errors
    ///
    /// Returns `TrainError::CheckpointError` if the recorder fails to load the file.
    pub fn load_model<B, M>(&self, device: &B::Device) -> Result<M::Record>
    where
        B: Backend,
        M: Module<B>,
        M::Record: DeserializeOwned,
    {
        NamedMpkFileRecorder::<FullPrecisionSettings>::new()
            .load(self.model_path(), device)
            .map_err(|err| TrainError::CheckpointError(format!("Failed to load model: {}", err)))
    }

    /// Location of the weights file inside the bundle directory.
    pub fn model_path(&self) -> PathBuf {
        self.path.join(MODEL_WEIGHTS_FILE)
    }

    /// Training history stored in the bundle's state.
    pub fn history(&self) -> &TrainingHistory {
        &self.state.history
    }

    /// Best validation loss stored in the bundle's state.
    pub fn best_val_loss(&self) -> f32 {
        self.state.best_valid_loss
    }

    /// Epoch counter stored in the bundle's state.
    pub fn epochs_trained(&self) -> usize {
        self.state.epoch
    }
}
/// Writes a complete model bundle into the directory `path`.
///
/// The directory (and any missing parents) is created first. The files are then written
/// in this order: model weights, training state, learner config, export metadata.
/// The operation is not atomic: if a later step fails, files from earlier steps remain.
///
/// # Errors
///
/// Returns `TrainError::CheckpointError` if the directory, the weights or a JSON file
/// cannot be written, or `TrainError::SerializationError` if a JSON payload cannot be
/// serialized.
pub fn save_model_bundle<B, M>(
    model: &M,
    path: impl AsRef<Path>,
    config: &LearnerConfig,
    state: &TrainingState,
    metadata: &ExportMetadata,
) -> Result<()>
where
    B: Backend,
    M: Module<B>,
    M::Record: Serialize,
{
    let dir = path.as_ref();
    std::fs::create_dir_all(dir).map_err(|err| {
        TrainError::CheckpointError(format!("Failed to create directory: {}", err))
    })?;

    // `into_record` consumes the module and we only borrow `model`, so record a clone.
    let weights = model.clone().into_record();
    NamedMpkFileRecorder::<FullPrecisionSettings>::new()
        .record(weights, dir.join(MODEL_WEIGHTS_FILE))
        .map_err(|err| TrainError::CheckpointError(format!("Failed to save model: {}", err)))?;

    save_json(state, &dir.join(STATE_FILE))?;
    save_json(config, &dir.join(CONFIG_FILE))?;
    save_json(metadata, &dir.join(METADATA_FILE))
}
/// Saves only the model weights to `path` with the full-precision named MessagePack
/// recorder. No JSON side-files are written.
///
/// NOTE(review): burn's file recorders set the file extension themselves, so the file on
/// disk may not match `path` exactly; confirm against the burn version in use.
///
/// # Errors
///
/// Returns `TrainError::CheckpointError` if the recorder fails to write the file.
pub fn quick_save<B, M>(model: &M, path: impl AsRef<Path>) -> Result<()>
where
    B: Backend,
    M: Module<B>,
    M::Record: Serialize,
{
    let target = path.as_ref().to_path_buf();
    // The recorder needs an owned record; `model` is only borrowed, so clone it first.
    let weights = model.clone().into_record();
    NamedMpkFileRecorder::<FullPrecisionSettings>::new()
        .record(weights, target)
        .map_err(|err| TrainError::CheckpointError(format!("Failed to save model: {}", err)))
}
/// Loads a model record previously written by [`quick_save`] from `path` onto `device`.
///
/// # Errors
///
/// Returns `TrainError::CheckpointError` if the recorder fails to load the file.
pub fn quick_load<B, M>(path: impl AsRef<Path>, device: &B::Device) -> Result<M::Record>
where
    B: Backend,
    M: Module<B>,
    M::Record: DeserializeOwned,
{
    let source = path.as_ref().to_path_buf();
    NamedMpkFileRecorder::<FullPrecisionSettings>::new()
        .load(source, device)
        .map_err(|err| TrainError::CheckpointError(format!("Failed to load model: {}", err)))
}
/// Serializes `data` as pretty-printed JSON and writes it to `path`, creating the file
/// or replacing an existing one.
///
/// # Errors
///
/// - `TrainError::SerializationError` if `data` cannot be serialized.
/// - `TrainError::CheckpointError` if the file cannot be written.
///
/// Both messages name `path`: a bundle writes several JSON files, and the bare error
/// alone does not say which one failed.
fn save_json<T: Serialize>(data: &T, path: &Path) -> Result<()> {
    let json = serde_json::to_string_pretty(data).map_err(|e| {
        TrainError::SerializationError(format!(
            "Failed to serialize {}: {}",
            path.display(),
            e
        ))
    })?;
    std::fs::write(path, json).map_err(|e| {
        TrainError::CheckpointError(format!("Failed to write file {}: {}", path.display(), e))
    })
}
/// Reads `path` and deserializes its JSON contents into `T`.
///
/// # Errors
///
/// - `TrainError::CheckpointError` if the file cannot be read.
/// - `TrainError::SerializationError` if the contents are not valid JSON for `T`.
///
/// Both messages name `path`: `LearnerExport::load` reads several JSON files, and the
/// bare error alone does not say which one failed.
fn load_json<T: DeserializeOwned>(path: &Path) -> Result<T> {
    let json = std::fs::read_to_string(path).map_err(|e| {
        TrainError::CheckpointError(format!("Failed to read file {}: {}", path.display(), e))
    })?;
    serde_json::from_str(&json).map_err(|e| {
        TrainError::SerializationError(format!(
            "Failed to deserialize {}: {}",
            path.display(),
            e
        ))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builder setters populate the matching fields and the `extra` map.
    #[test]
    fn test_export_metadata() {
        // Each setter touches an independent field, so call order is irrelevant.
        let metadata = ExportMetadata::new("InceptionTimePlus")
            .with_extra("dataset", "NATOPS")
            .with_n_vars(3)
            .with_seq_len(100)
            .with_n_classes(5);

        assert_eq!(metadata.arch, "InceptionTimePlus");
        assert_eq!(
            (metadata.n_classes, metadata.seq_len, metadata.n_vars),
            (Some(5), Some(100), Some(3))
        );
        assert_eq!(
            metadata.extra.get("dataset").map(String::as_str),
            Some("NATOPS")
        );
    }

    /// Epoch count and best validation loss are copied from the training state.
    #[test]
    fn test_export_metadata_with_training_stats() {
        let mut training_state = TrainingState::default();
        training_state.best_valid_loss = 0.15;
        training_state.epoch = 25;

        let metadata = ExportMetadata::new("ResNetPlus").with_training_stats(&training_state);

        assert_eq!(metadata.best_val_loss, Some(0.15));
        assert_eq!(metadata.epochs_trained, 25);
    }
}