use crate::{Scirs2Exec, TlBackendError, TlBackendResult};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, BufWriter, Write};
use std::path::Path;
use std::time::{SystemTime, UNIX_EPOCH};
/// Options controlling how a checkpoint is captured, saved, and verified.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
/// Compress tensor data on save.
/// NOTE(review): no code path in this file acts on this flag yet — confirm intended.
pub enable_compression: bool,
/// Include the autodiff tape in the checkpoint.
/// NOTE(review): not consumed by the serialization code in this file — confirm intended.
pub include_tape: bool,
/// Compute a checksum on save and verify it on load.
pub verify_checksum: bool,
/// Write incremental (delta) checkpoints rather than full snapshots.
/// NOTE(review): not consumed by the code in this file.
pub incremental: bool,
}
impl Default for CheckpointConfig {
fn default() -> Self {
Self {
enable_compression: false,
include_tape: false,
verify_checksum: true,
incremental: false,
}
}
}
impl CheckpointConfig {
    /// Preset for training runs: keep the tape so optimization can resume.
    pub fn for_training() -> Self {
        Self {
            include_tape: true,
            ..Self::default()
        }
    }

    /// Preset for inference snapshots: compress weights, drop the tape.
    pub fn for_inference() -> Self {
        Self {
            enable_compression: true,
            ..Self::default()
        }
    }

    /// Preset for incremental (delta) checkpoints taken during training.
    pub fn incremental() -> Self {
        Self {
            include_tape: true,
            incremental: true,
            ..Self::default()
        }
    }
}
/// Descriptive information stored alongside the tensor payload of a checkpoint.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct CheckpointMetadata {
/// Training iteration at which the checkpoint was taken.
pub iteration: usize,
/// Unix timestamp (seconds since the epoch) of checkpoint creation.
pub timestamp: u64,
/// Checkpoint format version string written by the producer.
pub version: String,
/// Number of serialized tensors in the payload.
pub tensor_count: usize,
/// Total payload size in bytes (sum of f64 element data).
pub total_bytes: usize,
/// Free-form user metadata (e.g. hyperparameters); see `Checkpoint::add_metadata`.
pub custom: HashMap<String, String>,
/// Content checksum; `None` when verification was disabled at save time.
pub checksum: Option<String>,
}
/// On-disk form of a single tensor: name, shape, and flattened f64 data.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct SerializedTensor {
/// Name under which the tensor is registered in the executor.
name: String,
/// Dimensions; their product must equal `data.len()` for restore to succeed.
shape: Vec<usize>,
/// Flattened element data in the tensor's logical iteration order.
data: Vec<f64>,
}
/// An in-memory checkpoint: metadata plus the serialized tensors of an executor.
#[derive(Debug, Clone)]
pub struct Checkpoint {
/// Metadata describing this checkpoint (iteration, sizes, checksum, ...).
pub metadata: CheckpointMetadata,
// Serialized tensor payload; rebuilt via `restore`/`restore_into`.
tensors: Vec<SerializedTensor>,
// Config used when creating/loading; retained but not read after construction.
#[allow(dead_code)]
config: CheckpointConfig,
}
impl Checkpoint {
    /// Capture a checkpoint of every tensor currently held by `executor`,
    /// using [`CheckpointConfig::default`].
    pub fn from_executor(executor: &Scirs2Exec, iteration: usize) -> TlBackendResult<Self> {
        Self::from_executor_with_config(executor, iteration, &CheckpointConfig::default())
    }

    /// Capture a checkpoint with an explicit configuration.
    ///
    /// Every tensor is copied out of the executor as flattened `f64` data.
    /// A checksum over the payload is recorded when `config.verify_checksum`
    /// is set. `enable_compression`, `include_tape`, and `incremental` are
    /// stored with the checkpoint but not acted upon by this implementation.
    ///
    /// # Errors
    /// Returns an error if the system clock reports a time before the Unix epoch.
    pub fn from_executor_with_config(
        executor: &Scirs2Exec,
        iteration: usize,
        config: &CheckpointConfig,
    ) -> TlBackendResult<Self> {
        let mut tensors = Vec::with_capacity(executor.tensors.len());
        let mut total_bytes = 0usize;
        for (name, tensor) in &executor.tensors {
            let data: Vec<f64> = tensor.iter().copied().collect();
            total_bytes += data.len() * std::mem::size_of::<f64>();
            tensors.push(SerializedTensor {
                name: name.clone(),
                shape: tensor.shape().to_vec(),
                data,
            });
        }
        let timestamp = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map_err(|e| TlBackendError::execution(format!("Failed to get timestamp: {}", e)))?
            .as_secs();
        // Only pay for hashing when verification is requested.
        let checksum = config
            .verify_checksum
            .then(|| Self::compute_checksum(&tensors));
        let metadata = CheckpointMetadata {
            iteration,
            timestamp,
            version: "0.1.0".to_string(),
            tensor_count: tensors.len(),
            total_bytes,
            custom: HashMap::new(),
            checksum,
        };
        Ok(Checkpoint {
            metadata,
            tensors,
            config: config.clone(),
        })
    }

    /// Serialize this checkpoint to `path` as JSON.
    ///
    /// # Errors
    /// Returns an error if the file cannot be created, serialization fails,
    /// or the buffered writer cannot be flushed.
    pub fn save<P: AsRef<Path>>(&self, path: P) -> TlBackendResult<()> {
        let file = File::create(path.as_ref()).map_err(|e| {
            TlBackendError::execution(format!("Failed to create checkpoint file: {}", e))
        })?;
        let mut writer = BufWriter::new(file);
        let checkpoint_data = CheckpointData {
            metadata: self.metadata.clone(),
            tensors: self.tensors.clone(),
        };
        serde_json::to_writer(&mut writer, &checkpoint_data).map_err(|e| {
            TlBackendError::execution(format!("Failed to serialize checkpoint: {}", e))
        })?;
        // Flush explicitly: a flush error during Drop would be silently swallowed.
        writer
            .flush()
            .map_err(|e| TlBackendError::execution(format!("Failed to flush checkpoint: {}", e)))?;
        Ok(())
    }

    /// Load a checkpoint from `path` using the default configuration
    /// (checksum verification enabled).
    pub fn load<P: AsRef<Path>>(path: P) -> TlBackendResult<Self> {
        Self::load_with_config(path, &CheckpointConfig::default())
    }

    /// Load a checkpoint from `path` with an explicit configuration.
    ///
    /// When `config.verify_checksum` is set AND the file carries a checksum,
    /// the payload is re-hashed and compared; a mismatch is an error. Files
    /// saved without a checksum load without verification.
    pub fn load_with_config<P: AsRef<Path>>(
        path: P,
        config: &CheckpointConfig,
    ) -> TlBackendResult<Self> {
        let file = File::open(path.as_ref()).map_err(|e| {
            TlBackendError::execution(format!("Failed to open checkpoint file: {}", e))
        })?;
        let reader = BufReader::new(file);
        let checkpoint_data: CheckpointData = serde_json::from_reader(reader).map_err(|e| {
            TlBackendError::execution(format!("Failed to deserialize checkpoint: {}", e))
        })?;
        if config.verify_checksum {
            if let Some(ref expected_checksum) = checkpoint_data.metadata.checksum {
                let actual_checksum = Self::compute_checksum(&checkpoint_data.tensors);
                if &actual_checksum != expected_checksum {
                    return Err(TlBackendError::execution(
                        "Checkpoint checksum verification failed",
                    ));
                }
            }
        }
        Ok(Checkpoint {
            metadata: checkpoint_data.metadata,
            tensors: checkpoint_data.tensors,
            config: config.clone(),
        })
    }

    /// Build a fresh executor containing all tensors from this checkpoint.
    pub fn restore(&self) -> TlBackendResult<Scirs2Exec> {
        let mut executor = Scirs2Exec::new();
        // Delegate instead of duplicating the restore loop (previously copy-pasted).
        self.restore_into(&mut executor)?;
        Ok(executor)
    }

    /// Insert (or overwrite) all checkpoint tensors into an existing executor.
    ///
    /// # Errors
    /// Returns an error if a tensor's recorded shape does not match its data length.
    pub fn restore_into(&self, executor: &mut Scirs2Exec) -> TlBackendResult<()> {
        for serialized in &self.tensors {
            let tensor = Self::deserialize_tensor(serialized)?;
            executor.add_tensor(&serialized.name, tensor);
        }
        Ok(())
    }

    /// Rebuild an `ArrayD` from its serialized form (shared by both restore paths).
    fn deserialize_tensor(
        serialized: &SerializedTensor,
    ) -> TlBackendResult<scirs2_core::ndarray::ArrayD<f64>> {
        scirs2_core::ndarray::ArrayD::from_shape_vec(
            serialized.shape.clone(),
            serialized.data.clone(),
        )
        .map_err(|e| {
            TlBackendError::execution(format!(
                "Failed to restore tensor {}: {}",
                serialized.name, e
            ))
        })
    }

    /// Attach a custom key/value pair to the checkpoint metadata.
    pub fn add_metadata(&mut self, key: String, value: String) {
        self.metadata.custom.insert(key, value);
    }

    /// Look up a custom metadata value by key.
    pub fn get_metadata(&self, key: &str) -> Option<&String> {
        self.metadata.custom.get(key)
    }

    /// Hash the names, shapes, and bit-exact element values of all tensors.
    ///
    /// NOTE(review): `DefaultHasher`'s algorithm is not guaranteed to be
    /// stable across Rust releases, so checksums may fail to verify for
    /// checkpoints written by a binary built with a different toolchain.
    /// Consider a stable hash (e.g. CRC32/SHA-256) if cross-version
    /// durability is required.
    fn compute_checksum(tensors: &[SerializedTensor]) -> String {
        use std::collections::hash_map::DefaultHasher;
        use std::hash::{Hash, Hasher};
        let mut hasher = DefaultHasher::new();
        for tensor in tensors {
            tensor.name.hash(&mut hasher);
            tensor.shape.hash(&mut hasher);
            // f64 does not implement Hash; hash the exact bit pattern instead.
            for &value in &tensor.data {
                value.to_bits().hash(&mut hasher);
            }
        }
        format!("{:x}", hasher.finish())
    }

    /// Payload size in bytes, as recorded at capture time.
    pub fn size_bytes(&self) -> usize {
        self.metadata.total_bytes
    }

    /// Human-readable payload size (bytes / KB / MB / GB, binary units).
    pub fn size_human_readable(&self) -> String {
        const KB: usize = 1024;
        const MB: usize = KB * 1024;
        const GB: usize = MB * 1024;
        let bytes = self.metadata.total_bytes;
        if bytes < KB {
            format!("{} bytes", bytes)
        } else if bytes < MB {
            format!("{:.2} KB", bytes as f64 / KB as f64)
        } else if bytes < GB {
            format!("{:.2} MB", bytes as f64 / MB as f64)
        } else {
            format!("{:.2} GB", bytes as f64 / GB as f64)
        }
    }
}
/// Serialization envelope: the exact JSON structure written to / read from disk.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
struct CheckpointData {
metadata: CheckpointMetadata,
tensors: Vec<SerializedTensor>,
}
/// Manages a directory of JSON checkpoint files: saving, rotation, and lookup.
pub struct CheckpointManager {
/// Directory all checkpoint files are written to.
checkpoint_dir: std::path::PathBuf,
/// Retain at most this many checkpoints (`None` = keep all).
max_checkpoints: Option<usize>,
/// Filename template; the literal `{}` is replaced by the iteration number.
filename_pattern: String,
}
impl CheckpointManager {
    /// Create a manager rooted at `checkpoint_dir`, creating the directory if
    /// needed. Defaults: keep at most 5 checkpoints, filename pattern
    /// `checkpoint_iter_{}.json`.
    ///
    /// # Errors
    /// Returns an error if the directory cannot be created.
    pub fn new<P: AsRef<Path>>(checkpoint_dir: P) -> TlBackendResult<Self> {
        let checkpoint_dir = checkpoint_dir.as_ref().to_path_buf();
        if !checkpoint_dir.exists() {
            std::fs::create_dir_all(&checkpoint_dir).map_err(|e| {
                TlBackendError::execution(format!("Failed to create checkpoint directory: {}", e))
            })?;
        }
        Ok(Self {
            checkpoint_dir,
            max_checkpoints: Some(5),
            filename_pattern: "checkpoint_iter_{}.json".to_string(),
        })
    }

    /// Set the retention limit (`None` disables rotation).
    pub fn set_max_checkpoints(&mut self, max: Option<usize>) {
        self.max_checkpoints = max;
    }

    /// Set the filename template; `{}` is replaced by the iteration number.
    pub fn set_filename_pattern(&mut self, pattern: String) {
        self.filename_pattern = pattern;
    }

    /// Snapshot `executor` at `iteration`, write it to disk, and rotate old
    /// checkpoints if a retention limit is set. Returns the written path.
    pub fn save_checkpoint(
        &self,
        executor: &Scirs2Exec,
        iteration: usize,
    ) -> TlBackendResult<std::path::PathBuf> {
        let checkpoint = Checkpoint::from_executor(executor, iteration)?;
        let filename = self.filename_pattern.replace("{}", &iteration.to_string());
        let path = self.checkpoint_dir.join(filename);
        checkpoint.save(&path)?;
        if let Some(max) = self.max_checkpoints {
            self.cleanup_old_checkpoints(max)?;
        }
        Ok(path)
    }

    /// Load the most recently modified checkpoint in the directory.
    pub fn load_latest(&self) -> TlBackendResult<Checkpoint> {
        let latest_path = self.find_latest_checkpoint()?;
        Checkpoint::load(latest_path)
    }

    /// All `.json` entries in the checkpoint directory (unsorted).
    ///
    /// Shared scan used by `find_latest_checkpoint`, `cleanup_old_checkpoints`,
    /// and `list_checkpoints`, which previously triplicated this logic.
    fn json_entries(&self) -> TlBackendResult<Vec<std::fs::DirEntry>> {
        let entries = std::fs::read_dir(&self.checkpoint_dir).map_err(|e| {
            TlBackendError::execution(format!("Failed to read checkpoint directory: {}", e))
        })?;
        Ok(entries
            .filter_map(|e| e.ok())
            .filter(|e| {
                e.path()
                    .extension()
                    .and_then(|s| s.to_str())
                    .map(|s| s == "json")
                    .unwrap_or(false)
            })
            .collect())
    }

    /// `.json` entries sorted oldest-first by filesystem modification time.
    ///
    /// NOTE(review): mtime has coarse granularity on some filesystems, so
    /// checkpoints written within the same second may sort arbitrarily.
    fn json_entries_oldest_first(&self) -> TlBackendResult<Vec<std::fs::DirEntry>> {
        let mut entries = self.json_entries()?;
        entries.sort_by_key(|e| {
            e.metadata()
                .ok()
                .and_then(|m| m.modified().ok())
                .unwrap_or(SystemTime::UNIX_EPOCH)
        });
        Ok(entries)
    }

    /// Path of the most recently modified checkpoint file.
    ///
    /// # Errors
    /// Fails if the directory cannot be read or contains no checkpoints.
    fn find_latest_checkpoint(&self) -> TlBackendResult<std::path::PathBuf> {
        self.json_entries_oldest_first()?
            .last()
            .map(|e| e.path())
            .ok_or_else(|| TlBackendError::execution("No checkpoints found"))
    }

    /// Delete the oldest checkpoints so at most `max` remain.
    /// Individual delete failures are ignored (best-effort cleanup).
    fn cleanup_old_checkpoints(&self, max: usize) -> TlBackendResult<()> {
        let checkpoints = self.json_entries_oldest_first()?;
        let to_remove = checkpoints.len().saturating_sub(max);
        for entry in checkpoints.iter().take(to_remove) {
            std::fs::remove_file(entry.path()).ok();
        }
        Ok(())
    }

    /// All checkpoint paths in the directory, sorted lexically by path.
    pub fn list_checkpoints(&self) -> TlBackendResult<Vec<std::path::PathBuf>> {
        let mut checkpoints: Vec<_> = self
            .json_entries()?
            .into_iter()
            .map(|e| e.path())
            .collect();
        checkpoints.sort();
        Ok(checkpoints)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::ArrayD;

    // Default config: only checksum verification is enabled.
    #[test]
    fn test_checkpoint_config_default() {
        let cfg = CheckpointConfig::default();
        assert!(!cfg.enable_compression);
        assert!(!cfg.include_tape);
        assert!(cfg.verify_checksum);
        assert!(!cfg.incremental);
    }

    // Training preset keeps the tape.
    #[test]
    fn test_checkpoint_config_training() {
        let cfg = CheckpointConfig::for_training();
        assert!(!cfg.enable_compression);
        assert!(cfg.include_tape);
        assert!(cfg.verify_checksum);
    }

    // Inference preset compresses and drops the tape.
    #[test]
    fn test_checkpoint_config_inference() {
        let cfg = CheckpointConfig::for_inference();
        assert!(cfg.enable_compression);
        assert!(!cfg.include_tape);
        assert!(cfg.verify_checksum);
    }

    // Capturing an executor records iteration, tensor count, and byte size.
    #[test]
    fn test_checkpoint_from_executor() {
        let mut exec = Scirs2Exec::new();
        let t = ArrayD::from_shape_vec(vec![2, 3], vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0])
            .expect("unwrap");
        exec.add_tensor("test_tensor", t);
        let ckpt = Checkpoint::from_executor(&exec, 1).expect("unwrap");
        assert_eq!(ckpt.metadata.iteration, 1);
        assert_eq!(ckpt.metadata.tensor_count, 1);
        assert!(ckpt.metadata.total_bytes > 0);
    }

    // A round-trip through disk preserves the metadata.
    #[test]
    fn test_checkpoint_save_and_load() {
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "weights",
            ArrayD::from_shape_vec(vec![3], vec![1.0, 2.0, 3.0]).expect("unwrap"),
        );
        let ckpt = Checkpoint::from_executor(&exec, 5).expect("unwrap");
        let path = std::env::temp_dir().join("test_checkpoint.json");
        ckpt.save(&path).expect("unwrap");
        let reloaded = Checkpoint::load(&path).expect("unwrap");
        assert_eq!(reloaded.metadata.iteration, 5);
        assert_eq!(reloaded.metadata.tensor_count, 1);
        std::fs::remove_file(path).ok();
    }

    // restore() rebuilds an executor with identical tensor contents.
    #[test]
    fn test_checkpoint_restore() {
        let mut exec = Scirs2Exec::new();
        let original = ArrayD::from_shape_vec(vec![2], vec![10.0, 20.0]).expect("unwrap");
        exec.add_tensor("params", original.clone());
        let ckpt = Checkpoint::from_executor(&exec, 1).expect("unwrap");
        let rebuilt = ckpt.restore().expect("unwrap");
        let restored = rebuilt.get_tensor("params").expect("unwrap");
        assert_eq!(restored.shape(), original.shape());
        assert_eq!(restored[[0]], 10.0);
        assert_eq!(restored[[1]], 20.0);
    }

    // Custom metadata entries are stored and retrievable; unknown keys yield None.
    #[test]
    fn test_checkpoint_metadata() {
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "x",
            ArrayD::from_shape_vec(vec![1], vec![1.0]).expect("unwrap"),
        );
        let mut ckpt = Checkpoint::from_executor(&exec, 10).expect("unwrap");
        ckpt.add_metadata("learning_rate".to_string(), "0.001".to_string());
        ckpt.add_metadata("optimizer".to_string(), "adam".to_string());
        assert_eq!(
            ckpt.get_metadata("learning_rate"),
            Some(&"0.001".to_string())
        );
        assert_eq!(ckpt.get_metadata("optimizer"), Some(&"adam".to_string()));
        assert_eq!(ckpt.get_metadata("missing"), None);
    }

    // 1000 f64s (8 KB) land in the KB range of the human-readable formatter.
    #[test]
    fn test_checkpoint_size_human_readable() {
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "big_tensor",
            ArrayD::from_shape_vec(vec![1000], vec![1.0; 1000]).expect("unwrap"),
        );
        let ckpt = Checkpoint::from_executor(&exec, 1).expect("unwrap");
        let rendered = ckpt.size_human_readable();
        assert!(rendered.contains("KB") || rendered.contains("bytes"));
    }

    // The manager writes files into its directory and lists them back.
    #[test]
    fn test_checkpoint_manager() {
        let dir = std::env::temp_dir().join("test_checkpoints");
        let manager = CheckpointManager::new(&dir).expect("unwrap");
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "data",
            ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).expect("unwrap"),
        );
        let saved = manager.save_checkpoint(&exec, 1).expect("unwrap");
        assert!(saved.exists());
        let listed = manager.list_checkpoints().expect("unwrap");
        assert_eq!(listed.len(), 1);
        std::fs::remove_dir_all(dir).ok();
    }

    // Rotation keeps at most `max_checkpoints` files on disk.
    #[test]
    fn test_checkpoint_manager_cleanup() {
        let dir = std::env::temp_dir().join("test_checkpoints_cleanup");
        let mut manager = CheckpointManager::new(&dir).expect("unwrap");
        manager.set_max_checkpoints(Some(3));
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "x",
            ArrayD::from_shape_vec(vec![1], vec![1.0]).expect("unwrap"),
        );
        for iter in 1..=5 {
            manager.save_checkpoint(&exec, iter).expect("unwrap");
        }
        let listed = manager.list_checkpoints().expect("unwrap");
        assert!(listed.len() <= 3);
        std::fs::remove_dir_all(dir).ok();
    }

    // A checksum is produced when requested and survives a disk round-trip.
    #[test]
    fn test_checkpoint_checksum_verification() {
        let mut exec = Scirs2Exec::new();
        exec.add_tensor(
            "data",
            ArrayD::from_shape_vec(vec![2], vec![1.0, 2.0]).expect("unwrap"),
        );
        let cfg = CheckpointConfig {
            verify_checksum: true,
            ..Default::default()
        };
        let ckpt = Checkpoint::from_executor_with_config(&exec, 1, &cfg).expect("unwrap");
        assert!(ckpt.metadata.checksum.is_some());
        let path = std::env::temp_dir().join("test_checksum.json");
        ckpt.save(&path).expect("unwrap");
        let reloaded = Checkpoint::load_with_config(&path, &cfg).expect("unwrap");
        assert_eq!(reloaded.metadata.checksum, ckpt.metadata.checksum);
        std::fs::remove_file(path).ok();
    }
}