use crate::error::{Error, Result};
use crate::format::safetensors::{SafeTensors, save_safetensors};
use numr::dtype::DType;
use numr::runtime::Runtime;
use numr::runtime::cpu::CpuRuntime;
use numr::tensor::Tensor;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::Path;
/// Checkpoint format version that `save_checkpoint` stamps into
/// `training_state.json`, regardless of the `version` value the caller passed in.
pub const CHECKPOINT_VERSION: u32 = 2;
/// Serde fallback for `TrainingState::version`.
///
/// A serialized training state that has no `version` field is treated as
/// version 1.
fn default_version() -> u32 {
    const UNVERSIONED_FORMAT: u32 = 1;
    UNVERSIONED_FORMAT
}
/// Everything `load_checkpoint` returns, as a tuple of:
///
/// 1. the model tensors keyed by name,
/// 2. the optimizer tensors keyed by name, or `None` when the checkpoint
///    directory has no `optimizer.safetensors`,
/// 3. the deserialized [`TrainingState`].
pub type CheckpointData<R> = (
HashMap<String, Tensor<R>>,
Option<HashMap<String, Tensor<R>>>,
TrainingState,
);
/// Scalar training progress stored alongside the tensors as `training_state.json`.
///
/// Only `step` is mandatory in the serialized form; every other field falls
/// back to a default when it is missing.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TrainingState {
/// Checkpoint format version. Falls back to `default_version()` (1) when the
/// field is absent; `save_checkpoint` overwrites it with `CHECKPOINT_VERSION`.
#[serde(default = "default_version")]
pub version: u32,
/// Training step counter. Has no serde default, so deserialization fails
/// when it is missing.
pub step: u64,
/// Epoch counter; defaults to 0 when absent.
#[serde(default)]
pub epoch: u64,
/// Learning rate; defaults to 0.0 when absent.
#[serde(default)]
pub learning_rate: f64,
/// String key/value pairs stored with the checkpoint; defaults to an empty
/// map when absent.
#[serde(default)]
pub metadata: HashMap<String, String>,
}
pub fn save_checkpoint<P: AsRef<Path>>(
dir: P,
model_state: &HashMap<String, Tensor<CpuRuntime>>,
optimizer_state: Option<&HashMap<String, Tensor<CpuRuntime>>>,
training_state: &TrainingState,
) -> Result<()> {
let dir = dir.as_ref();
std::fs::create_dir_all(dir).map_err(|e| Error::TrainingError {
reason: format!("failed to create checkpoint dir: {e}"),
})?;
save_safetensors(dir.join("model.safetensors"), model_state, None)?;
if let Some(opt_state) = optimizer_state {
if !opt_state.is_empty() {
save_safetensors(dir.join("optimizer.safetensors"), opt_state, None)?;
}
}
let mut state = training_state.clone();
state.version = CHECKPOINT_VERSION;
let json = serde_json::to_string_pretty(&state).map_err(|e| Error::TrainingError {
reason: format!("failed to serialize training state: {e}"),
})?;
std::fs::write(dir.join("training_state.json"), json).map_err(|e| Error::TrainingError {
reason: format!("failed to write training state: {e}"),
})?;
Ok(())
}
pub fn load_checkpoint<R: Runtime<DType = DType>, P: AsRef<Path>>(
dir: P,
device: &R::Device,
) -> Result<CheckpointData<R>> {
let dir = dir.as_ref();
let model_path = dir.join("model.safetensors");
let mut st = SafeTensors::open(&model_path)?;
let model_state = st.load_all::<R>(device)?;
let opt_path = dir.join("optimizer.safetensors");
let optimizer_state = if opt_path.exists() {
let mut opt_st = SafeTensors::open(&opt_path)?;
Some(opt_st.load_all::<R>(device)?)
} else {
None
};
let state_path = dir.join("training_state.json");
let json = std::fs::read_to_string(&state_path).map_err(|e| Error::TrainingError {
reason: format!("failed to read training state: {e}"),
})?;
let training_state: TrainingState =
serde_json::from_str(&json).map_err(|e| Error::TrainingError {
reason: format!("failed to parse training state: {e}"),
})?;
Ok((model_state, optimizer_state, training_state))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::trainer::test_helpers::*;
    use tempfile::TempDir;

    /// Full round trip: model, optimizer and training state all survive save + load.
    #[test]
    fn test_save_and_load_checkpoint() {
        let tmp = TempDir::new().unwrap();
        let ckpt_dir = tmp.path();
        let device = make_device();
        let model_state = make_model_state(&device);
        let opt_state = make_opt_state(&device);
        let training_state = TrainingState {
            version: CHECKPOINT_VERSION,
            step: 1000,
            epoch: 2,
            learning_rate: 3e-4,
            metadata: HashMap::new(),
        };

        save_checkpoint(ckpt_dir, &model_state, Some(&opt_state), &training_state).unwrap();

        // All three files must be on disk after a save that includes optimizer state.
        assert!(ckpt_dir.join("model.safetensors").exists());
        assert!(ckpt_dir.join("optimizer.safetensors").exists());
        assert!(ckpt_dir.join("training_state.json").exists());

        let (loaded_model, loaded_opt, loaded_state) =
            load_checkpoint::<CpuRuntime, _>(ckpt_dir, &device).unwrap();

        assert_eq!(loaded_model.len(), 2);
        let weight = &loaded_model["layers.0.weight"];
        assert_eq!(weight.shape(), &[2, 2]);
        let weight_values: Vec<f32> = weight.to_vec();
        assert!((weight_values[0] - 1.0).abs() < 1e-6);

        let loaded_opt = loaded_opt.unwrap();
        assert_eq!(loaded_opt.len(), 2);
        let first_moment = &loaded_opt["layers.0.weight.m"];
        let first_moment_values: Vec<f32> = first_moment.to_vec();
        assert!((first_moment_values[0] - 0.01).abs() < 1e-6);

        assert_eq!(loaded_state.step, 1000);
        assert_eq!(loaded_state.epoch, 2);
        assert!((loaded_state.learning_rate - 3e-4).abs() < 1e-10);
    }

    /// Saving with no optimizer state writes no optimizer file and loads back `None`.
    #[test]
    fn test_checkpoint_without_optimizer() {
        let tmp = TempDir::new().unwrap();
        let ckpt_dir = tmp.path();
        let device = make_device();
        let model_state = HashMap::from([(
            "w".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device),
        )]);
        let training_state = make_training_state(42);

        save_checkpoint(ckpt_dir, &model_state, None, &training_state).unwrap();
        assert!(!ckpt_dir.join("optimizer.safetensors").exists());

        let (_, loaded_opt, loaded_state) =
            load_checkpoint::<CpuRuntime, _>(ckpt_dir, &device).unwrap();
        assert!(loaded_opt.is_none());
        assert_eq!(loaded_state.step, 42);
    }

    /// A saved checkpoint always reports the current format version on load.
    #[test]
    fn test_version_round_trip() {
        let tmp = TempDir::new().unwrap();
        let ckpt_dir = tmp.path();
        let device = make_device();
        let model_state = HashMap::from([(
            "w".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device),
        )]);
        let training_state = make_training_state(10);

        save_checkpoint(ckpt_dir, &model_state, None, &training_state).unwrap();

        let (_, _, loaded_state) = load_checkpoint::<CpuRuntime, _>(ckpt_dir, &device).unwrap();
        assert_eq!(loaded_state.version, CHECKPOINT_VERSION);
    }

    /// A training-state file written without a `version` field loads as version 1.
    #[test]
    fn test_version_backward_compat() {
        let tmp = TempDir::new().unwrap();
        let ckpt_dir = tmp.path();
        let device = make_device();
        let model_state = HashMap::from([(
            "w".to_string(),
            Tensor::<CpuRuntime>::from_slice(&[1.0f32], &[1], &device),
        )]);

        // Build the checkpoint by hand so the JSON can omit `version`.
        save_safetensors(ckpt_dir.join("model.safetensors"), &model_state, None).unwrap();
        let unversioned_json =
            r#"{"step": 100, "epoch": 5, "learning_rate": 0.001, "metadata": {}}"#;
        std::fs::write(ckpt_dir.join("training_state.json"), unversioned_json).unwrap();

        let (_, _, loaded_state) = load_checkpoint::<CpuRuntime, _>(ckpt_dir, &device).unwrap();
        assert_eq!(loaded_state.version, 1);
        assert_eq!(loaded_state.step, 100);
        assert_eq!(loaded_state.epoch, 5);
    }
}