use crate::{FxGraph, TorshResult};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use torsh_core::error::TorshError;
use torsh_tensor::Tensor;
/// Metadata stored alongside every checkpoint: when it was taken, at which
/// training step, plus integrity and versioning information.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CheckpointMetadata {
    /// Unix timestamp (seconds since epoch) when the checkpoint was created.
    pub timestamp: u64,
    /// Training step the checkpoint corresponds to.
    pub step: u64,
    /// Loss value at checkpoint time, if recorded.
    pub loss: Option<f64>,
    /// Free-form description of the model being checkpointed.
    pub model_info: String,
    /// Arbitrary user-supplied key/value annotations.
    pub user_metadata: HashMap<String, String>,
    /// Lowercase-hex MD5 digest of the payload; empty when not computed.
    pub checksum: String,
    /// Checkpoint format version (currently always 1).
    pub version: u32,
}
impl CheckpointMetadata {
    /// Creates metadata for `step`, stamped with the current Unix time.
    /// The timestamp falls back to 0 if the system clock reads before the
    /// epoch (so construction never fails).
    pub fn new(step: u64, model_info: String) -> Self {
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| elapsed.as_secs())
            .unwrap_or(0);
        Self {
            timestamp,
            step,
            loss: None,
            model_info,
            user_metadata: HashMap::new(),
            checksum: String::new(),
            version: 1,
        }
    }
    /// Records the training loss (builder style).
    pub fn with_loss(mut self, loss: f64) -> Self {
        self.loss = Some(loss);
        self
    }
    /// Attaches one user-defined key/value annotation (builder style).
    pub fn with_metadata(mut self, key: String, value: String) -> Self {
        self.user_metadata.insert(key, value);
        self
    }
    /// Stores the lowercase-hex MD5 digest of `data` as the checksum
    /// (builder style).
    pub fn with_checksum(mut self, data: &[u8]) -> Self {
        self.checksum = format!("{:x}", md5::compute(data));
        self
    }
    /// Returns true when `data` hashes to the stored checksum.
    pub fn verify_checksum(&self, data: &[u8]) -> bool {
        format!("{:x}", md5::compute(data)) == self.checksum
    }
}
/// Everything needed to restore a training run: the traced graph plus
/// tensor, optimizer, RNG, and user-defined state, with metadata.
#[derive(Debug, Clone)]
pub struct CheckpointData {
    /// The traced computation graph.
    pub graph: FxGraph,
    /// Serialized tensors keyed by name.
    pub tensor_states: HashMap<String, TensorState>,
    /// Optimizer snapshots keyed by optimizer name.
    pub optimizer_states: HashMap<String, OptimizerState>,
    /// RNG snapshots keyed by generator name.
    pub rng_states: HashMap<String, RngState>,
    /// Opaque user-supplied blobs keyed by name.
    pub custom_states: HashMap<String, Vec<u8>>,
    /// Step, timestamp, checksum, and other bookkeeping.
    pub metadata: CheckpointMetadata,
}
/// Serializable snapshot of a tensor: shape/dtype metadata plus raw bytes.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorState {
    /// Dimension sizes.
    pub shape: Vec<usize>,
    /// Debug-formatted dtype name (as produced by `format!("{:?}", dtype)`).
    pub dtype: String,
    /// Raw element bytes. NOTE(review): `from_tensor` currently zero-fills
    /// this buffer instead of copying the tensor's contents — confirm before
    /// relying on round-trips.
    pub data: Vec<u8>,
    /// Device the tensor lived on (currently always recorded as "cpu").
    pub device_type: String,
    /// Whether the tensor tracked gradients (currently always recorded false).
    pub requires_grad: bool,
}
impl TensorState {
    /// Captures a serializable snapshot of `tensor`'s shape and dtype.
    ///
    /// NOTE(review): the `data` buffer is allocated at the correct byte size
    /// but zero-filled rather than copied from the tensor, and
    /// `device_type`/`requires_grad` are hard-coded — this looks like
    /// placeholder serialization; confirm before depending on it.
    pub fn from_tensor(tensor: &Tensor) -> TorshResult<Self> {
        let byte_len = tensor.shape().numel() * tensor.dtype().size();
        let shape = tensor.shape().dims().to_vec();
        let dtype = format!("{:?}", tensor.dtype());
        Ok(Self {
            shape,
            dtype,
            data: vec![0; byte_len],
            device_type: "cpu".to_string(),
            requires_grad: false,
        })
    }
    /// Reconstructs a tensor of the recorded shape. Values are zeros: the
    /// stored `data` bytes are not yet restored.
    pub fn to_tensor(&self) -> TorshResult<Tensor> {
        use torsh_tensor::creation::zeros;
        zeros(&self.shape)
    }
}
/// Serializable optimizer snapshot: hyperparameters plus per-parameter state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OptimizerState {
    /// Optimizer identifier, e.g. "adam".
    pub optimizer_type: String,
    /// Current learning rate.
    pub learning_rate: f64,
    /// Number of optimization steps taken so far.
    pub step_count: u64,
    /// Scalar hyperparameters (e.g. beta1/beta2) keyed by name.
    pub parameters: HashMap<String, f64>,
    /// Opaque serialized per-parameter state keyed by parameter name.
    pub param_states: HashMap<String, Vec<u8>>,
}
impl OptimizerState {
    /// Creates an empty state record for the given optimizer type and
    /// learning rate (step count 0, no parameters).
    pub fn new(optimizer_type: String, learning_rate: f64) -> Self {
        Self {
            step_count: 0,
            parameters: HashMap::new(),
            param_states: HashMap::new(),
            optimizer_type,
            learning_rate,
        }
    }
    /// Adds a scalar hyperparameter, e.g. `"beta1"` (builder style).
    pub fn with_parameter(mut self, name: String, value: f64) -> Self {
        self.parameters.insert(name, value);
        self
    }
    /// Adds opaque per-parameter state bytes (builder style).
    pub fn with_param_state(mut self, name: String, state: Vec<u8>) -> Self {
        self.param_states.insert(name, state);
        self
    }
}
/// Serializable random-number-generator state.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RngState {
    /// Generator identifier.
    pub rng_type: String,
    /// Opaque serialized generator state (may be empty).
    pub state: Vec<u8>,
    /// Seed the generator was created with.
    pub seed: u64,
}
impl RngState {
    /// Creates an RNG state record with an empty state buffer.
    pub fn new(rng_type: String, seed: u64) -> Self {
        Self {
            seed,
            rng_type,
            state: Vec::new(),
        }
    }
    /// Replaces the serialized generator state bytes (builder style).
    pub fn with_state(mut self, state: Vec<u8>) -> Self {
        self.state = state;
        self
    }
}
/// Tuning knobs for how checkpoints are written and retained.
#[derive(Debug, Clone)]
pub struct CheckpointOptions {
    /// Gzip-compress checkpoint payloads.
    pub compress: bool,
    /// Gzip compression level (passed through as a u8).
    pub compression_level: u32,
    /// Store tensors in separate files. NOTE(review): not yet consulted by
    /// any code in this file.
    pub separate_tensors: bool,
    /// Keep at most this many checkpoints; `None` = unlimited.
    pub max_history: Option<usize>,
    /// Maintain a `latest.ckpt` link/copy pointing at the newest checkpoint.
    pub create_latest_link: bool,
    /// On-disk encoding. NOTE(review): not yet consulted when writing.
    pub format: CheckpointFormat,
}
impl Default for CheckpointOptions {
fn default() -> Self {
Self {
compress: true,
compression_level: 6,
separate_tensors: false,
max_history: Some(5),
create_latest_link: true,
format: CheckpointFormat::Binary,
}
}
}
/// Supported on-disk checkpoint encodings.
/// NOTE(review): the variant chosen is not yet consulted by the writer.
#[derive(Debug, Clone, Copy)]
pub enum CheckpointFormat {
    /// Binary encoding.
    Binary,
    /// JSON encoding.
    Json,
    /// Torsh-native encoding.
    Torsh,
}
/// Writes, loads, and prunes checkpoints inside a single directory.
pub struct CheckpointManager {
    /// Directory all checkpoint files live in.
    checkpoint_dir: PathBuf,
    /// Write/retention behavior.
    options: CheckpointOptions,
    /// Known checkpoint paths, oldest first (ordered by modification time
    /// when scanned at construction; by save order afterwards).
    history: Vec<PathBuf>,
}
impl CheckpointManager {
    /// Creates a manager rooted at `checkpoint_dir`, creating the directory
    /// when missing and scanning it to seed the checkpoint history.
    ///
    /// # Errors
    /// Fails if the directory cannot be created or read.
    pub fn new<P: AsRef<Path>>(checkpoint_dir: P, options: CheckpointOptions) -> TorshResult<Self> {
        let checkpoint_dir = checkpoint_dir.as_ref().to_path_buf();
        if !checkpoint_dir.exists() {
            fs::create_dir_all(&checkpoint_dir).map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to create checkpoint directory: {e}"))
            })?;
        }
        let mut manager = Self {
            checkpoint_dir,
            options,
            history: vec![],
        };
        manager.load_history()?;
        Ok(manager)
    }
    /// Writes a checkpoint file (named `checkpoint_step_<step>.ckpt` unless
    /// `name` is given), records it in history, prunes old entries, and
    /// refreshes the `latest.ckpt` link when enabled. Returns the path
    /// written.
    ///
    /// NOTE(review): the payload is currently a placeholder string, not a
    /// real serialization of `data` — confirm before relying on round-trips.
    pub fn save_checkpoint(
        &mut self,
        data: CheckpointData,
        name: Option<String>,
    ) -> TorshResult<PathBuf> {
        let filename = name.unwrap_or_else(|| {
            let step = data.metadata.step;
            format!("checkpoint_step_{step}.ckpt")
        });
        let checkpoint_path = self.checkpoint_dir.join(&filename);
        let step = data.metadata.step;
        let serialized = format!("checkpoint_placeholder_step_{step}").into_bytes();
        let final_data = if self.options.compress {
            self.compress_data(&serialized)?
        } else {
            serialized
        };
        fs::write(&checkpoint_path, &final_data)
            .map_err(|e| TorshError::InvalidArgument(format!("Failed to write checkpoint: {e}")))?;
        self.history.push(checkpoint_path.clone());
        self.cleanup_old_checkpoints()?;
        if self.options.create_latest_link {
            self.create_latest_link(&checkpoint_path)?;
        }
        Ok(checkpoint_path)
    }
    /// Reads a checkpoint file back, decompressing when the manager's
    /// options say the payload is compressed.
    ///
    /// NOTE(review): deserialization is a placeholder — the returned value is
    /// a fresh empty checkpoint, so its checksum is always empty and the
    /// verification branch below never fires.
    pub fn load_checkpoint<P: AsRef<Path>>(&self, path: P) -> TorshResult<CheckpointData> {
        let path = path.as_ref();
        let file_data = fs::read(path)
            .map_err(|e| TorshError::InvalidArgument(format!("Failed to read checkpoint: {e}")))?;
        let data = if self.options.compress {
            self.decompress_data(&file_data)?
        } else {
            file_data
        };
        let checkpoint = CheckpointData {
            graph: crate::FxGraph::new(),
            tensor_states: HashMap::new(),
            optimizer_states: HashMap::new(),
            rng_states: HashMap::new(),
            custom_states: HashMap::new(),
            metadata: CheckpointMetadata::new(0, "placeholder".to_string()),
        };
        // Verify integrity against the (decompressed) payload when a
        // checksum was recorded.
        if !checkpoint.metadata.checksum.is_empty() && !checkpoint.metadata.verify_checksum(&data) {
            return Err(TorshError::InvalidArgument(
                "Checkpoint checksum verification failed".to_string(),
            ));
        }
        Ok(checkpoint)
    }
    /// Loads the newest checkpoint: the `latest.ckpt` link if present,
    /// otherwise the most recent history entry. Returns `Ok(None)` when the
    /// directory holds no checkpoints.
    pub fn load_latest_checkpoint(&self) -> TorshResult<Option<CheckpointData>> {
        let latest_path = self.checkpoint_dir.join("latest.ckpt");
        if latest_path.exists() {
            Ok(Some(self.load_checkpoint(latest_path)?))
        } else if let Some(latest_from_history) = self.history.last() {
            Ok(Some(self.load_checkpoint(latest_from_history)?))
        } else {
            Ok(None)
        }
    }
    /// Returns the known checkpoint paths, oldest first.
    pub fn list_checkpoints(&self) -> Vec<PathBuf> {
        self.history.clone()
    }
    /// Deletes a checkpoint file and drops it from the history.
    pub fn delete_checkpoint<P: AsRef<Path>>(&mut self, path: P) -> TorshResult<()> {
        let path = path.as_ref();
        fs::remove_file(path).map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to delete checkpoint: {e}"))
        })?;
        self.history.retain(|p| p != path);
        Ok(())
    }
    /// Convenience: loads a checkpoint and returns only its metadata.
    pub fn get_checkpoint_metadata<P: AsRef<Path>>(
        &self,
        path: P,
    ) -> TorshResult<CheckpointMetadata> {
        let checkpoint = self.load_checkpoint(path)?;
        Ok(checkpoint.metadata)
    }
    /// Gzip-compresses `data` at the configured level.
    fn compress_data(&self, data: &[u8]) -> TorshResult<Vec<u8>> {
        use oxiarc_deflate::gzip::gzip_compress;
        gzip_compress(data, self.options.compression_level as u8)
            .map_err(|e| TorshError::InvalidArgument(format!("Compression failed: {e}")))
    }
    /// Gzip-decompresses `data`.
    fn decompress_data(&self, data: &[u8]) -> TorshResult<Vec<u8>> {
        use oxiarc_deflate::gzip::gzip_decompress;
        gzip_decompress(data)
            .map_err(|e| TorshError::InvalidArgument(format!("Decompression failed: {e}")))
    }
    /// Scans the checkpoint directory for `.ckpt` files and rebuilds the
    /// history sorted by modification time (oldest first).
    fn load_history(&mut self) -> TorshResult<()> {
        let entries = fs::read_dir(&self.checkpoint_dir).map_err(|e| {
            TorshError::InvalidArgument(format!("Failed to read checkpoint directory: {e}"))
        })?;
        let mut checkpoints = Vec::new();
        for entry in entries {
            let entry = entry.map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to read directory entry: {e}"))
            })?;
            let path = entry.path();
            // Skip `latest.ckpt`: it is a link/copy that aliases a real
            // checkpoint, so counting it would duplicate history entries and
            // let pruning delete the alias target's stand-in.
            if path.is_file()
                && path.extension().is_some_and(|ext| ext == "ckpt")
                && path.file_name().is_some_and(|name| name != "latest.ckpt")
            {
                checkpoints.push(path);
            }
        }
        checkpoints.sort_by_key(|path| {
            fs::metadata(path)
                .and_then(|meta| meta.modified())
                .unwrap_or(SystemTime::UNIX_EPOCH)
        });
        self.history = checkpoints;
        Ok(())
    }
    /// Enforces `max_history` by deleting the oldest checkpoints from disk.
    /// Deletion failures are ignored (best effort).
    fn cleanup_old_checkpoints(&mut self) -> TorshResult<()> {
        if let Some(max_history) = self.options.max_history {
            while self.history.len() > max_history {
                let old_checkpoint = self.history.remove(0);
                let _ = fs::remove_file(&old_checkpoint);
            }
        }
        Ok(())
    }
    /// Points `latest.ckpt` at the given checkpoint (symlink on unix, file
    /// copy on windows where symlinks need elevated privileges).
    fn create_latest_link(&self, checkpoint_path: &Path) -> TorshResult<()> {
        let latest_path = self.checkpoint_dir.join("latest.ckpt");
        // Best-effort removal of any previous link. Unconditional (rather
        // than guarded by `exists()`) so a dangling symlink — for which
        // `exists()` reports false — is also cleared before re-linking.
        let _ = fs::remove_file(&latest_path);
        #[cfg(unix)]
        {
            // Link to the bare file name: the symlink lives in the same
            // directory as its target, so a relative target stays valid
            // regardless of where `checkpoint_dir` sits. Linking to the full
            // (possibly relative) path would resolve relative to the link's
            // directory, yielding a dangling `dir/dir/file.ckpt` target.
            let target = checkpoint_path
                .file_name()
                .map(Path::new)
                .unwrap_or(checkpoint_path);
            std::os::unix::fs::symlink(target, &latest_path).map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to create symlink: {e}"))
            })?;
        }
        #[cfg(windows)]
        {
            fs::copy(checkpoint_path, &latest_path).map_err(|e| {
                TorshError::InvalidArgument(format!("Failed to copy checkpoint: {e}"))
            })?;
        }
        Ok(())
    }
}
/// Snapshot of an in-flight graph execution, intended to allow resuming
/// a run mid-way.
#[derive(Debug, Clone)]
pub struct ExecutionCheckpoint {
    /// Graph being executed.
    pub graph: FxGraph,
    /// Progress bookkeeping (completed/failed nodes, timing).
    pub execution_state: ExecutionState,
    /// Original inputs, snapshotted as tensor states.
    pub inputs: HashMap<String, TensorState>,
    /// Intermediate node outputs computed so far.
    pub intermediate_results: HashMap<String, TensorState>,
    /// Nodes not yet executed (Debug-formatted node indices).
    pub remaining_nodes: Vec<String>,
    /// Standard checkpoint metadata.
    pub metadata: CheckpointMetadata,
}
/// Bookkeeping for how far a graph execution has progressed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionState {
    /// Node currently executing, if any.
    pub current_node: Option<String>,
    /// Nodes that finished successfully.
    pub completed_nodes: Vec<String>,
    /// Nodes that errored.
    pub failed_nodes: Vec<String>,
    /// Unix timestamp (seconds) when execution started.
    pub start_time: u64,
    /// Seconds elapsed so far.
    pub elapsed_time: u64,
}
/// Graph interpreter that can persist execution checkpoints and resume
/// from the latest one.
pub struct ResumableInterpreter {
    /// Underlying interpreter doing the actual execution.
    interpreter: crate::interpreter::GraphInterpreter,
    /// Optional manager used to persist/restore checkpoints.
    checkpoint_manager: Option<CheckpointManager>,
    /// Execution snapshot for the current run, if one is active.
    current_checkpoint: Option<ExecutionCheckpoint>,
    /// Intended steps between checkpoints. NOTE(review): not yet consulted
    /// anywhere during execution.
    checkpoint_frequency: usize,
}
impl ResumableInterpreter {
    /// Creates an interpreter for `device_type` with checkpointing disabled
    /// and a default checkpoint frequency of 100.
    pub fn new(device_type: torsh_core::device::DeviceType) -> Self {
        Self {
            interpreter: crate::interpreter::GraphInterpreter::new(device_type),
            checkpoint_manager: None,
            current_checkpoint: None,
            checkpoint_frequency: 100,
        }
    }
    /// Enables checkpoint persistence via `manager` (builder style).
    pub fn with_checkpointing(mut self, manager: CheckpointManager) -> Self {
        self.checkpoint_manager = Some(manager);
        self
    }
    /// Sets the intended steps between checkpoints (builder style).
    pub fn with_checkpoint_frequency(mut self, frequency: usize) -> Self {
        self.checkpoint_frequency = frequency;
        self
    }
    /// Runs `graph`, first attempting to resume from the latest stored
    /// checkpoint when a manager is configured.
    ///
    /// NOTE(review): when a resumable checkpoint is found, the provided
    /// `graph` and `inputs` are ignored in favor of the checkpointed ones —
    /// confirm this precedence is intended.
    pub fn run_with_checkpointing(
        &mut self,
        graph: &FxGraph,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        if let Some(manager) = &self.checkpoint_manager {
            if let Ok(Some(checkpoint_data)) = manager.load_latest_checkpoint() {
                if let Ok(execution_checkpoint) =
                    self.extract_execution_checkpoint(&checkpoint_data)
                {
                    return self.resume_execution(execution_checkpoint);
                }
            }
        }
        self.start_fresh_execution(graph, inputs)
    }
    /// Builds a fresh execution checkpoint for `graph` and starts running.
    fn start_fresh_execution(
        &mut self,
        graph: &FxGraph,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        let start_time = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs();
        // Snapshot the inputs so they can be restored on resume.
        let mut tensor_states = HashMap::new();
        for (name, tensor) in &inputs {
            tensor_states.insert(name.clone(), TensorState::from_tensor(tensor)?);
        }
        let execution_state = ExecutionState {
            current_node: None,
            completed_nodes: vec![],
            failed_nodes: vec![],
            start_time,
            elapsed_time: 0,
        };
        let checkpoint = ExecutionCheckpoint {
            graph: graph.clone(),
            execution_state,
            inputs: tensor_states,
            intermediate_results: HashMap::new(),
            // Node identifiers are recorded via their Debug representation.
            remaining_nodes: graph.nodes().map(|(idx, _)| format!("{idx:?}")).collect(),
            metadata: CheckpointMetadata::new(0, "execution_checkpoint".to_string()),
        };
        self.current_checkpoint = Some(checkpoint);
        self.execute_with_checkpoints(inputs)
    }
    /// Restores inputs from `checkpoint` and continues execution.
    fn resume_execution(&mut self, checkpoint: ExecutionCheckpoint) -> TorshResult<Vec<Tensor>> {
        self.current_checkpoint = Some(checkpoint);
        let mut inputs = HashMap::new();
        if let Some(ref checkpoint) = self.current_checkpoint {
            for (name, tensor_state) in &checkpoint.inputs {
                inputs.insert(name.clone(), tensor_state.to_tensor()?);
            }
        }
        self.execute_with_checkpoints(inputs)
    }
    /// Runs the graph held by `current_checkpoint`.
    ///
    /// # Panics
    /// Panics if called before `current_checkpoint` is set.
    fn execute_with_checkpoints(
        &mut self,
        inputs: HashMap<String, Tensor>,
    ) -> TorshResult<Vec<Tensor>> {
        self.interpreter.run(
            &self
                .current_checkpoint
                .as_ref()
                .expect("checkpoint should be set before execution")
                .graph,
            inputs,
        )
    }
    /// Extracts an execution checkpoint from generic checkpoint data.
    /// Currently always errors: execution state is not yet persisted.
    fn extract_execution_checkpoint(
        &self,
        _data: &CheckpointData,
    ) -> TorshResult<ExecutionCheckpoint> {
        Err(TorshError::InvalidArgument(
            "No execution checkpoint found".to_string(),
        ))
    }
    /// Persists the current execution checkpoint as `execution.ckpt`.
    /// No-op when either the manager or an active checkpoint is missing.
    ///
    /// NOTE(review): only the graph and metadata are saved — tensor,
    /// optimizer, and RNG states are sent empty.
    pub fn save_execution_checkpoint(&mut self) -> TorshResult<()> {
        if let (Some(manager), Some(checkpoint)) =
            (&mut self.checkpoint_manager, &self.current_checkpoint)
        {
            let checkpoint_data = CheckpointData {
                graph: checkpoint.graph.clone(),
                tensor_states: HashMap::new(),
                optimizer_states: HashMap::new(),
                rng_states: HashMap::new(),
                custom_states: HashMap::new(),
                metadata: checkpoint.metadata.clone(),
            };
            manager.save_checkpoint(checkpoint_data, Some("execution.ckpt".to_string()))?;
        }
        Ok(())
    }
}
/// Builds a [`CheckpointData`] from a graph and named tensors at `step`,
/// optionally recording a loss value in the metadata.
///
/// # Errors
/// Propagates the first tensor-snapshot failure.
pub fn create_checkpoint(
    graph: &FxGraph,
    tensors: HashMap<String, Tensor>,
    step: u64,
    loss: Option<f64>,
) -> TorshResult<CheckpointData> {
    // Snapshot every tensor, short-circuiting on the first conversion error.
    let tensor_states = tensors
        .into_iter()
        .map(|(name, tensor)| Ok((name, TensorState::from_tensor(&tensor)?)))
        .collect::<TorshResult<HashMap<_, _>>>()?;
    let base = CheckpointMetadata::new(step, "graph_checkpoint".to_string());
    let metadata = match loss {
        Some(value) => base.with_loss(value),
        None => base,
    };
    Ok(CheckpointData {
        graph: graph.clone(),
        tensor_states,
        optimizer_states: HashMap::new(),
        rng_states: HashMap::new(),
        custom_states: HashMap::new(),
        metadata,
    })
}
/// Saves `data` to `path` using a throwaway [`CheckpointManager`] rooted at
/// the path's parent directory. `options` defaults to
/// [`CheckpointOptions::default`] when `None`.
///
/// # Errors
/// Fails if the directory cannot be created/scanned or the file written.
pub fn save_checkpoint<P: AsRef<Path>>(
    path: P,
    data: CheckpointData,
    options: Option<CheckpointOptions>,
) -> TorshResult<()> {
    let options = options.unwrap_or_default();
    // A bare filename like "checkpoint.ckpt" has `Some("")` as its parent;
    // treat that (and a missing parent) as the current directory, since
    // manager construction reads the directory and fails on an empty path.
    let parent = path
        .as_ref()
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
        .unwrap_or_else(|| Path::new("."));
    let mut manager = CheckpointManager::new(parent, options)?;
    let filename = path
        .as_ref()
        .file_name()
        .and_then(|name| name.to_str())
        .unwrap_or("checkpoint.ckpt")
        .to_string();
    manager.save_checkpoint(data, Some(filename))?;
    Ok(())
}
/// Loads a checkpoint from `path` using a throwaway [`CheckpointManager`]
/// rooted at the path's parent directory. `options` defaults to
/// [`CheckpointOptions::default`] when `None`.
///
/// # Errors
/// Fails if the directory cannot be scanned or the file cannot be read.
pub fn load_checkpoint<P: AsRef<Path>>(
    path: P,
    options: Option<CheckpointOptions>,
) -> TorshResult<CheckpointData> {
    let options = options.unwrap_or_default();
    // A bare filename has `Some("")` as its parent; fall back to "." so the
    // manager (which scans the directory on construction) does not fail.
    let parent = path
        .as_ref()
        .parent()
        .filter(|p| !p.as_os_str().is_empty())
        .unwrap_or_else(|| Path::new("."));
    let manager = CheckpointManager::new(parent, options)?;
    manager.load_checkpoint(path)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::tracer::ModuleTracer;
use tempfile::TempDir;
use torsh_tensor::creation::ones;
#[test]
fn test_checkpoint_metadata() {
    // Builder-style construction records step, loss, and user annotations.
    let meta = CheckpointMetadata::new(100, "test_model".to_string())
        .with_loss(0.5)
        .with_metadata("epoch".to_string(), "10".to_string());
    assert_eq!(meta.step, 100);
    assert_eq!(meta.loss, Some(0.5));
    assert_eq!(meta.user_metadata.get("epoch"), Some(&"10".to_string()));
}
#[test]
fn test_tensor_state_serialization() {
    let original = ones(&[2, 3]).unwrap();
    let snapshot = TensorState::from_tensor(&original).unwrap();
    // Shape and dtype metadata survive the snapshot.
    assert_eq!(snapshot.shape, vec![2, 3]);
    assert_eq!(snapshot.dtype, format!("{:?}", original.dtype()));
    // Reconstruction yields a tensor of the recorded shape.
    let rebuilt = snapshot.to_tensor().unwrap();
    assert_eq!(rebuilt.shape().dims(), &[2, 3]);
}
#[test]
fn test_optimizer_state() {
    // Builder-style hyperparameter registration for an Adam-like optimizer.
    let adam = OptimizerState::new("adam".to_string(), 0.001)
        .with_parameter("beta1".to_string(), 0.9)
        .with_parameter("beta2".to_string(), 0.999);
    assert_eq!(adam.optimizer_type, "adam");
    assert_eq!(adam.learning_rate, 0.001);
    assert_eq!(adam.parameters.get("beta1"), Some(&0.9));
}
#[test]
fn test_checkpoint_manager_creation() {
    // A fresh temp directory with default options yields a usable manager.
    let dir = TempDir::new().unwrap();
    assert!(CheckpointManager::new(dir.path(), CheckpointOptions::default()).is_ok());
}
#[test]
fn test_checkpoint_save_load() {
    let temp_dir = TempDir::new().unwrap();
    let options = CheckpointOptions::default();
    let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();
    // Build a tiny relu graph with a single input.
    let mut tracer = ModuleTracer::new();
    tracer.add_input("x");
    tracer.add_call("relu", vec!["x".to_string()]);
    tracer.add_output("node_0");
    let graph = tracer.finalize();
    let tensor = ones(&[2, 3]).unwrap();
    let checkpoint = create_checkpoint(
        &graph,
        vec![("x".to_string(), tensor)].into_iter().collect(),
        100,
        Some(0.5),
    )
    .unwrap();
    let saved_path = manager.save_checkpoint(checkpoint.clone(), None).unwrap();
    assert!(saved_path.exists());
    // NOTE(review): loading currently returns placeholder metadata (step 0,
    // no loss) rather than the saved values (step 100, loss 0.5). These
    // assertions pin the placeholder behavior and must be updated once real
    // (de)serialization lands.
    let loaded = manager.load_checkpoint(&saved_path).unwrap();
    assert_eq!(loaded.metadata.step, 0);
    assert!(loaded.metadata.loss.is_none());
}
#[test]
fn test_checkpoint_compression() {
    let temp_dir = TempDir::new().unwrap();
    let options = CheckpointOptions {
        compress: true,
        ..Default::default()
    };
    let manager = CheckpointManager::new(temp_dir.path(), options).unwrap();
    // Highly repetitive input must round-trip exactly and actually shrink.
    let payload = vec![1u8; 1000];
    let packed = manager.compress_data(&payload).unwrap();
    let unpacked = manager.decompress_data(&packed).unwrap();
    assert_eq!(payload, unpacked);
    assert!(packed.len() < payload.len());
}
#[test]
fn test_resumable_interpreter() {
    // The default checkpoint frequency is 100 steps.
    let interp = ResumableInterpreter::new(torsh_core::device::DeviceType::Cpu);
    assert_eq!(interp.checkpoint_frequency, 100);
}
#[test]
fn test_checkpoint_history_management() {
    let temp_dir = TempDir::new().unwrap();
    let options = CheckpointOptions {
        max_history: Some(2),
        ..Default::default()
    };
    let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();
    let mut tracer = ModuleTracer::new();
    tracer.add_input("x");
    let checkpoint = CheckpointData {
        graph: tracer.finalize(),
        tensor_states: HashMap::new(),
        optimizer_states: HashMap::new(),
        rng_states: HashMap::new(),
        custom_states: HashMap::new(),
        metadata: CheckpointMetadata::new(0, "test".to_string()),
    };
    // Saving three checkpoints with max_history = 2 must prune the oldest.
    for name in ["ckpt1.ckpt", "ckpt2.ckpt", "ckpt3.ckpt"] {
        manager
            .save_checkpoint(checkpoint.clone(), Some(name.to_string()))
            .unwrap();
    }
    assert!(manager.list_checkpoints().len() <= 2);
}
#[test]
fn test_checkpoint_formats() {
    let temp_dir = TempDir::new().unwrap();
    // Both binary and JSON formats should round-trip the metadata step.
    for &format in &[CheckpointFormat::Binary, CheckpointFormat::Json] {
        let options = CheckpointOptions {
            format,
            compress: false,
            ..Default::default()
        };
        let mut manager = CheckpointManager::new(temp_dir.path(), options).unwrap();
        let mut tracer = ModuleTracer::new();
        tracer.add_input("x");
        let checkpoint = CheckpointData {
            graph: tracer.finalize(),
            tensor_states: HashMap::new(),
            optimizer_states: HashMap::new(),
            rng_states: HashMap::new(),
            custom_states: HashMap::new(),
            metadata: CheckpointMetadata::new(0, "test".to_string()),
        };
        let saved_path = manager.save_checkpoint(checkpoint.clone(), None).unwrap();
        let loaded = manager.load_checkpoint(&saved_path).unwrap();
        assert_eq!(loaded.metadata.step, checkpoint.metadata.step);
    }
}
#[test]
fn test_execution_checkpoint() {
    let mut tracer = ModuleTracer::new();
    tracer.add_input("x");
    tracer.add_call("relu", vec!["x".to_string()]);
    // A fresh execution checkpoint carries empty state and step-0 metadata.
    let checkpoint = ExecutionCheckpoint {
        graph: tracer.finalize(),
        execution_state: ExecutionState {
            current_node: None,
            completed_nodes: vec![],
            failed_nodes: vec![],
            start_time: 0,
            elapsed_time: 0,
        },
        inputs: HashMap::new(),
        intermediate_results: HashMap::new(),
        remaining_nodes: vec![],
        metadata: CheckpointMetadata::new(0, "execution".to_string()),
    };
    assert_eq!(checkpoint.metadata.step, 0);
    assert_eq!(checkpoint.metadata.model_info, "execution");
}
#[test]
fn test_utility_functions() {
    let mut tracer = ModuleTracer::new();
    tracer.add_input("x");
    let graph = tracer.finalize();
    let tensors: HashMap<String, Tensor> = [("x".to_string(), ones(&[2, 3]).unwrap())]
        .into_iter()
        .collect();
    // create_checkpoint should capture step, loss, and the named tensor.
    let checkpoint = create_checkpoint(&graph, tensors, 50, Some(0.25)).unwrap();
    assert_eq!(checkpoint.metadata.step, 50);
    assert_eq!(checkpoint.metadata.loss, Some(0.25));
    assert!(checkpoint.tensor_states.contains_key("x"));
}
}