use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::node::NodeId;
/// Identifier of a checkpoint: the key used by every [`CheckpointStore`] and, for
/// [`FileCheckpointStore`], the stem of the `<id>.json` file.
pub type CheckpointId = String;
/// Execution state of a single node, as recorded inside a [`Checkpoint`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeState {
/// Numeric node id; also the key under which `Checkpoint::add_node_state` stores this state.
/// NOTE(review): presumably the inner value of a `NodeId` — confirm.
pub node_id: usize,
/// Whether the node has finished (`false` for `NodeState::pending`).
pub completed: bool,
/// Whether the node finished successfully; meaningless while `completed` is `false`.
pub success: bool,
/// Opaque serialized output of the node, if any; not interpreted in this file.
pub output_data: Option<Vec<u8>>,
/// Optional human-readable summary of the output. `#[serde(default)]` lets checkpoints
/// serialized without this field still deserialize (as `None`).
#[serde(default)]
pub output_summary: Option<String>,
}
impl NodeState {
    /// Builds the state of a node that has finished, successfully or not, with no output yet.
    pub fn completed(node_id: usize, success: bool) -> Self {
        Self {
            completed: true,
            success,
            ..Self::pending(node_id)
        }
    }

    /// Builds the state of a node that has not finished: not completed, not successful, no output.
    pub fn pending(node_id: usize) -> Self {
        Self {
            node_id,
            completed: false,
            success: false,
            output_data: None,
            output_summary: None,
        }
    }

    /// Returns `self` with `output_data` replaced by `data`.
    pub fn with_output_data(self, data: Vec<u8>) -> Self {
        Self {
            output_data: Some(data),
            ..self
        }
    }

    /// Returns `self` with `output_summary` replaced by `summary`.
    pub fn with_summary(self, summary: impl Into<String>) -> Self {
        let text: String = summary.into();
        Self {
            output_summary: Some(text),
            ..self
        }
    }
}
/// Serializable snapshot of execution progress that a [`CheckpointStore`] can persist and reload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Store key for this checkpoint (generated by `Checkpoint::new`, caller-chosen in `with_id`).
pub id: CheckpointId,
/// Creation time in whole seconds since the Unix epoch; the stores' `latest` picks the
/// checkpoint with the largest value.
pub timestamp: u64,
/// NOTE(review): presumably the executor's program counter at snapshot time — confirm.
pub pc: usize,
/// NOTE(review): presumably the number of loop iterations executed so far — confirm.
pub loop_count: usize,
/// Raw ids of the active nodes (`NodeId.0`; see `set_active_nodes` / `get_active_nodes`).
pub active_nodes: HashSet<usize>,
/// Per-node state, keyed by `NodeState::node_id`.
pub node_states: HashMap<usize, NodeState>,
/// Opaque serialized environment, if captured; not interpreted in this file.
pub env_data: Option<Vec<u8>>,
/// Free-form string key/value annotations.
pub metadata: HashMap<String, String>,
}
impl Checkpoint {
    /// Current wall-clock time in whole seconds since the Unix epoch (0 if the clock reads
    /// earlier than the epoch).
    fn now_secs() -> u64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap_or_default()
            .as_secs()
    }

    /// Creates an empty checkpoint at `pc` / `loop_count` with a generated id of the form
    /// `ckpt_<timestamp>_<pc>_<seq>`.
    ///
    /// `timestamp` only has one-second resolution. Without `seq`, two checkpoints taken at the
    /// same `pc` within the same second would get the same id. Consecutive loop iterations can
    /// do exactly that, and the second checkpoint would silently overwrite the first in any
    /// store. `seq` is a process-wide counter, so ids generated within one process never collide.
    pub fn new(pc: usize, loop_count: usize) -> Self {
        // Relaxed is sufficient: the atomic read-modify-write alone guarantees each caller a
        // distinct value, and no other memory is synchronized through this counter.
        static NEXT_SEQ: std::sync::atomic::AtomicUsize = std::sync::atomic::AtomicUsize::new(0);
        let seq = NEXT_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let mut checkpoint = Self::with_id(String::new(), pc, loop_count);
        checkpoint.id = format!("ckpt_{}_{}_{}", checkpoint.timestamp, pc, seq);
        checkpoint
    }

    /// Creates an empty checkpoint with a caller-chosen `id`, stamped with the current time.
    pub fn with_id(id: impl Into<String>, pc: usize, loop_count: usize) -> Self {
        Self {
            id: id.into(),
            timestamp: Self::now_secs(),
            pc,
            loop_count,
            active_nodes: HashSet::new(),
            node_states: HashMap::new(),
            env_data: None,
            metadata: HashMap::new(),
        }
    }

    /// Replaces the active-node set with the raw ids of `nodes`.
    pub fn set_active_nodes(&mut self, nodes: &HashSet<NodeId>) {
        self.active_nodes = nodes.iter().map(|id| id.0).collect();
    }

    /// Returns the active-node set converted back to [`NodeId`]s.
    pub fn get_active_nodes(&self) -> HashSet<NodeId> {
        self.active_nodes.iter().map(|id| NodeId(*id)).collect()
    }

    /// Records `state`, replacing any earlier state for the same `node_id`.
    pub fn add_node_state(&mut self, state: NodeState) {
        self.node_states.insert(state.node_id, state);
    }

    /// Inserts (or overwrites) a metadata entry.
    pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
        self.metadata.insert(key.into(), value.into());
    }
}
/// Errors produced by checkpoint creation, serialization and storage.
#[derive(Debug, Clone)]
pub enum CheckpointError {
/// No checkpoint with this id exists in the store.
NotFound(CheckpointId),
/// Converting a checkpoint to JSON failed.
SerializationError(String),
/// Parsing a stored checkpoint's JSON failed (e.g. a corrupt or incompatible file).
DeserializationError(String),
/// The backing storage failed (poisoned lock, filesystem I/O error).
StorageError(String),
/// The checkpoint or its id was rejected (e.g. an id containing path separators).
InvalidCheckpoint(String),
/// NOTE(review): not constructed in this file; presumably returned when checkpointing is
/// requested but no store has been configured — confirm.
StoreNotConfigured,
}
impl std::fmt::Display for CheckpointError {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
CheckpointError::NotFound(id) => write!(f, "Checkpoint not found: {}", id),
CheckpointError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
CheckpointError::DeserializationError(msg) => {
write!(f, "Deserialization error: {}", msg)
}
CheckpointError::StorageError(msg) => write!(f, "Storage error: {}", msg),
CheckpointError::InvalidCheckpoint(msg) => write!(f, "Invalid checkpoint: {}", msg),
CheckpointError::StoreNotConfigured => write!(f, "Checkpoint store not configured"),
}
}
}
// Marker impl relying on the Debug and Display impls. No `source()` is provided because
// every variant carries only a string (or nothing), not an underlying error value.
impl std::error::Error for CheckpointError {}
/// Asynchronous persistence backend for [`Checkpoint`]s, keyed by [`CheckpointId`].
///
/// The two implementations in this module (`MemoryCheckpointStore`, `FileCheckpointStore`)
/// both overwrite on `save` of an existing id. Both report `load` of a missing id as
/// [`CheckpointError::NotFound`], and both treat `delete` of a missing id as success.
/// NOTE(review): other implementations presumably should honour the same contract — confirm.
#[async_trait]
pub trait CheckpointStore: Send + Sync {
/// Persists `checkpoint` under its `id`.
async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointError>;
/// Loads the checkpoint stored under `id`.
async fn load(&self, id: &CheckpointId) -> Result<Checkpoint, CheckpointError>;
/// Removes the checkpoint stored under `id`.
async fn delete(&self, id: &CheckpointId) -> Result<(), CheckpointError>;
/// Returns the ids of all stored checkpoints; the in-module implementations give no order.
async fn list(&self) -> Result<Vec<CheckpointId>, CheckpointError>;
/// Returns the stored checkpoint with the largest `timestamp`, or `None` if the store is empty.
async fn latest(&self) -> Result<Option<Checkpoint>, CheckpointError>;
/// Removes every stored checkpoint.
async fn clear(&self) -> Result<(), CheckpointError>;
}
/// In-process [`CheckpointStore`] backed by a `HashMap` behind a `std::sync::RwLock`.
/// Nothing is persisted: contents are lost when the store is dropped.
#[derive(Default)]
pub struct MemoryCheckpointStore {
// Checkpoints keyed by id; each stored value is an owned clone of the saved checkpoint.
checkpoints: std::sync::RwLock<HashMap<CheckpointId, Checkpoint>>,
}
impl MemoryCheckpointStore {
pub fn new() -> Self {
Self {
checkpoints: std::sync::RwLock::new(HashMap::new()),
}
}
}
impl MemoryCheckpointStore {
    /// Acquires the shared read lock, reporting a poisoned lock as
    /// [`CheckpointError::StorageError`].
    fn read_checkpoints(
        &self,
    ) -> Result<std::sync::RwLockReadGuard<'_, HashMap<CheckpointId, Checkpoint>>, CheckpointError>
    {
        self.checkpoints.read().map_err(|e| {
            CheckpointError::StorageError(format!("Failed to acquire read lock: {}", e))
        })
    }

    /// Acquires the exclusive write lock, reporting a poisoned lock as
    /// [`CheckpointError::StorageError`].
    fn write_checkpoints(
        &self,
    ) -> Result<std::sync::RwLockWriteGuard<'_, HashMap<CheckpointId, Checkpoint>>, CheckpointError>
    {
        self.checkpoints.write().map_err(|e| {
            CheckpointError::StorageError(format!("Failed to acquire write lock: {}", e))
        })
    }
}

// The lock is a `std::sync::RwLock`; none of these methods awaits while holding a guard.
#[async_trait]
impl CheckpointStore for MemoryCheckpointStore {
    /// Stores a clone of `checkpoint`, replacing any entry with the same id.
    async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointError> {
        self.write_checkpoints()?
            .insert(checkpoint.id.clone(), checkpoint.clone());
        Ok(())
    }

    /// Returns a clone of the checkpoint stored under `id`, or `NotFound`.
    async fn load(&self, id: &CheckpointId) -> Result<Checkpoint, CheckpointError> {
        let store = self.read_checkpoints()?;
        // `ok_or_else`: only clone `id` when it is actually missing.
        store
            .get(id)
            .cloned()
            .ok_or_else(|| CheckpointError::NotFound(id.clone()))
    }

    /// Removes `id`; removing a missing id is not an error.
    async fn delete(&self, id: &CheckpointId) -> Result<(), CheckpointError> {
        self.write_checkpoints()?.remove(id);
        Ok(())
    }

    /// Returns all stored ids in unspecified (hash map) order.
    async fn list(&self) -> Result<Vec<CheckpointId>, CheckpointError> {
        let store = self.read_checkpoints()?;
        Ok(store.keys().cloned().collect())
    }

    /// Returns the checkpoint with the largest `timestamp`.
    ///
    /// Timestamps have one-second resolution, so ties are common. They are broken by the larger
    /// id, which keeps the result deterministic instead of depending on hash-map iteration order.
    async fn latest(&self) -> Result<Option<Checkpoint>, CheckpointError> {
        let store = self.read_checkpoints()?;
        Ok(store
            .values()
            .max_by(|a, b| (a.timestamp, &a.id).cmp(&(b.timestamp, &b.id)))
            .cloned())
    }

    /// Removes every checkpoint.
    async fn clear(&self) -> Result<(), CheckpointError> {
        self.write_checkpoints()?.clear();
        Ok(())
    }
}
/// [`CheckpointStore`] that keeps each checkpoint as a pretty-printed JSON file
/// `<base_path>/<id>.json`.
pub struct FileCheckpointStore {
// Directory holding the checkpoint files; created on demand by `ensure_dir`.
base_path: PathBuf,
}
impl FileCheckpointStore {
    /// Creates a store rooted at `base_path`. The directory is not touched until the first
    /// operation that needs it.
    pub fn new(base_path: impl AsRef<Path>) -> Self {
        Self {
            base_path: base_path.as_ref().to_path_buf(),
        }
    }

    /// Maps `id` to `<base_path>/<id>.json`, rejecting ids that could address a file outside
    /// `base_path`.
    ///
    /// # Errors
    /// [`CheckpointError::InvalidCheckpoint`] when `id` is empty, or when it contains `/`, `\`,
    /// NUL or `..`. The same error applies when `<id>.json` is not a single ordinary path
    /// component, e.g. one carrying a Windows drive prefix such as `C:`. `Path::join` would treat
    /// such a name as rooted and discard `base_path`.
    fn checkpoint_path(&self, id: &CheckpointId) -> Result<PathBuf, CheckpointError> {
        if id.is_empty() {
            return Err(CheckpointError::InvalidCheckpoint(
                "Checkpoint ID must not be empty".to_string(),
            ));
        }
        let invalid = || {
            CheckpointError::InvalidCheckpoint(
                "Checkpoint ID contains invalid characters".to_string(),
            )
        };
        if id.contains(|c: char| matches!(c, '/' | '\\' | '\0')) || id.contains("..") {
            return Err(invalid());
        }
        let file_name = format!("{}.json", id);
        // Defense in depth against platform-specific path syntax the character checks miss.
        let mut components = Path::new(&file_name).components();
        if !matches!(
            (components.next(), components.next()),
            (Some(std::path::Component::Normal(_)), None)
        ) {
            return Err(invalid());
        }
        Ok(self.base_path.join(file_name))
    }

    /// Creates `base_path` (and any missing parents) if it does not exist yet.
    async fn ensure_dir(&self) -> Result<(), CheckpointError> {
        tokio::fs::create_dir_all(&self.base_path)
            .await
            .map_err(|e| {
                CheckpointError::StorageError(format!(
                    "Failed to create checkpoint directory: {}",
                    e
                ))
            })
    }
}
#[async_trait]
impl CheckpointStore for FileCheckpointStore {
    /// Writes `checkpoint` as pretty-printed JSON to `<base_path>/<id>.json`.
    ///
    /// The JSON is first written to a sibling `<id>.json.tmp` file and then renamed over the
    /// final path. A crash mid-write therefore leaves at most a stray temp file, never a
    /// truncated checkpoint that later fails to load. `list` ignores the temp file because it
    /// does not end in `.json`. No fsync is performed, so durability across power loss is not
    /// guaranteed.
    async fn save(&self, checkpoint: &Checkpoint) -> Result<(), CheckpointError> {
        // Validate the id before creating directories or writing anything.
        let path = self.checkpoint_path(&checkpoint.id)?;
        let json = serde_json::to_string_pretty(checkpoint)
            .map_err(|e| CheckpointError::SerializationError(e.to_string()))?;
        self.ensure_dir().await?;
        let tmp_path = path.with_extension("json.tmp");
        tokio::fs::write(&tmp_path, json).await.map_err(|e| {
            CheckpointError::StorageError(format!("Failed to write checkpoint file: {}", e))
        })?;
        if let Err(e) = tokio::fs::rename(&tmp_path, &path).await {
            // Best effort: don't leave the temporary file behind; the rename error is what matters.
            let _ = tokio::fs::remove_file(&tmp_path).await;
            return Err(CheckpointError::StorageError(format!(
                "Failed to write checkpoint file: {}",
                e
            )));
        }
        Ok(())
    }

    /// Reads and parses `<base_path>/<id>.json`.
    ///
    /// The file is read directly, with `NotFound` mapped afterwards. There is no separate
    /// existence check that could race with a concurrent delete.
    async fn load(&self, id: &CheckpointId) -> Result<Checkpoint, CheckpointError> {
        let path = self.checkpoint_path(id)?;
        let json = match tokio::fs::read_to_string(&path).await {
            Ok(json) => json,
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
                return Err(CheckpointError::NotFound(id.clone()));
            }
            Err(e) => {
                return Err(CheckpointError::StorageError(format!(
                    "Failed to read checkpoint file: {}",
                    e
                )));
            }
        };
        serde_json::from_str(&json)
            .map_err(|e| CheckpointError::DeserializationError(e.to_string()))
    }

    /// Deletes `<base_path>/<id>.json`; a missing file is not an error.
    async fn delete(&self, id: &CheckpointId) -> Result<(), CheckpointError> {
        let path = self.checkpoint_path(id)?;
        match tokio::fs::remove_file(&path).await {
            Ok(()) => Ok(()),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => Ok(()),
            Err(e) => Err(CheckpointError::StorageError(format!(
                "Failed to delete checkpoint file: {}",
                e
            ))),
        }
    }

    /// Returns the ids of all `*.json` entries in `base_path` (creating the directory if
    /// needed), in directory-listing order.
    async fn list(&self) -> Result<Vec<CheckpointId>, CheckpointError> {
        self.ensure_dir().await?;
        let mut entries = tokio::fs::read_dir(&self.base_path).await.map_err(|e| {
            CheckpointError::StorageError(format!("Failed to read checkpoint directory: {}", e))
        })?;
        let mut ids = Vec::new();
        while let Some(entry) = entries.next_entry().await.map_err(|e| {
            CheckpointError::StorageError(format!("Failed to read directory entry: {}", e))
        })? {
            let name = entry.file_name();
            // `strip_suffix` removes exactly one `.json`. `trim_end_matches` would turn the file
            // of id `a.json` (`a.json.json`) into the wrong id `a`.
            if let Some(id) = name.to_str().and_then(|n| n.strip_suffix(".json")) {
                ids.push(id.to_string());
            }
        }
        Ok(ids)
    }

    /// Returns the loadable checkpoint with the largest `timestamp`.
    ///
    /// Files that fail to load (unreadable or corrupt) are skipped so one bad file cannot hide
    /// the rest. Timestamp ties (one-second resolution) are broken by the larger id, so the
    /// result does not depend on directory-listing order.
    async fn latest(&self) -> Result<Option<Checkpoint>, CheckpointError> {
        let mut latest: Option<Checkpoint> = None;
        for id in self.list().await? {
            let Ok(checkpoint) = self.load(&id).await else {
                continue;
            };
            let is_newer = latest.as_ref().is_none_or(|current| {
                (checkpoint.timestamp, &checkpoint.id) > (current.timestamp, &current.id)
            });
            if is_newer {
                latest = Some(checkpoint);
            }
        }
        Ok(latest)
    }

    /// Deletes every checkpoint returned by `list`, stopping at the first failure.
    async fn clear(&self) -> Result<(), CheckpointError> {
        for id in self.list().await? {
            self.delete(&id).await?;
        }
        Ok(())
    }
}
/// Policy settings controlling when checkpoints are taken.
///
/// NOTE(review): the code that interprets these fields is not in this file. The per-field
/// descriptions below are inferred from the names — confirm against the executor.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
/// Master switch: `false` in `Default`, `true` via `CheckpointConfig::enabled()`.
pub enabled: bool,
/// Presumably: take a checkpoint every N executed nodes (`None` disables this trigger).
pub interval_nodes: Option<usize>,
/// Presumably: take a checkpoint once this many seconds have elapsed (`None` disables it).
pub interval_seconds: Option<u64>,
/// Presumably: take a checkpoint on every loop iteration.
pub on_loop_iteration: bool,
/// Presumably: take a checkpoint before evaluating a conditional.
pub before_conditional: bool,
/// Presumably: how many checkpoints to retain; neither store in this file enforces it.
pub max_checkpoints: usize,
}
impl Default for CheckpointConfig {
fn default() -> Self {
Self {
enabled: false,
interval_nodes: Some(10),
interval_seconds: None,
on_loop_iteration: true,
before_conditional: true,
max_checkpoints: 5,
}
}
}
impl CheckpointConfig {
    /// Returns the default configuration with checkpointing switched on.
    #[must_use]
    pub fn enabled() -> Self {
        Self {
            enabled: true,
            ..Default::default()
        }
    }

    /// Sets `interval_nodes` to `Some(interval)`.
    #[must_use]
    pub fn with_node_interval(mut self, interval: usize) -> Self {
        self.interval_nodes = Some(interval);
        self
    }

    /// Sets `interval_seconds` to `Some(seconds)`.
    #[must_use]
    pub fn with_time_interval(mut self, seconds: u64) -> Self {
        self.interval_seconds = Some(seconds);
        self
    }

    /// Sets `on_loop_iteration`.
    #[must_use]
    pub fn with_loop_checkpoint(mut self, enabled: bool) -> Self {
        self.on_loop_iteration = enabled;
        self
    }

    /// Sets `before_conditional`, which previously had no builder and could only be changed
    /// through direct field access.
    #[must_use]
    pub fn with_conditional_checkpoint(mut self, enabled: bool) -> Self {
        self.before_conditional = enabled;
        self
    }

    /// Sets `max_checkpoints`.
    #[must_use]
    pub fn with_max_checkpoints(mut self, max: usize) -> Self {
        self.max_checkpoints = max;
        self
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Returns a fresh directory path under the system temp dir, unique per test label and
    /// process, with any leftovers from a previous run removed.
    fn temp_store_dir(label: &str) -> std::path::PathBuf {
        let dir = std::env::temp_dir().join(format!(
            "checkpoint_test_{}_{}",
            label,
            std::process::id()
        ));
        let _ = std::fs::remove_dir_all(&dir);
        dir
    }

    #[tokio::test]
    async fn test_memory_checkpoint_store() {
        let store = MemoryCheckpointStore::new();
        let mut checkpoint = Checkpoint::new(5, 2);
        checkpoint.add_metadata("test_key", "test_value");
        checkpoint.active_nodes.insert(1);
        checkpoint.active_nodes.insert(2);
        store.save(&checkpoint).await.unwrap();
        let loaded = store.load(&checkpoint.id).await.unwrap();
        assert_eq!(loaded.pc, 5);
        assert_eq!(loaded.loop_count, 2);
        assert!(loaded.active_nodes.contains(&1));
        assert!(loaded.active_nodes.contains(&2));
        assert_eq!(
            loaded.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );
        let ids = store.list().await.unwrap();
        assert_eq!(ids.len(), 1);
        let latest = store.latest().await.unwrap();
        assert!(latest.is_some());
        store.delete(&checkpoint.id).await.unwrap();
        let ids = store.list().await.unwrap();
        assert_eq!(ids.len(), 0);
    }

    #[tokio::test]
    async fn test_memory_load_missing_is_not_found() {
        let store = MemoryCheckpointStore::new();
        let result = store.load(&"missing".to_string()).await;
        assert!(matches!(result, Err(CheckpointError::NotFound(id)) if id == "missing"));
    }

    #[tokio::test]
    async fn test_file_checkpoint_store_roundtrip() {
        let dir = temp_store_dir("roundtrip");
        let store = FileCheckpointStore::new(&dir);
        let mut checkpoint = Checkpoint::with_id("file_ckpt", 7, 1);
        checkpoint.add_metadata("k", "v");
        checkpoint.add_node_state(NodeState::completed(3, true).with_summary("ok"));
        store.save(&checkpoint).await.unwrap();

        let loaded = store.load(&checkpoint.id).await.unwrap();
        assert_eq!(loaded.pc, 7);
        assert_eq!(loaded.loop_count, 1);
        assert_eq!(loaded.metadata.get("k").map(String::as_str), Some("v"));
        let state = loaded.node_states.get(&3).expect("node state 3 was saved");
        assert!(state.completed && state.success);
        assert_eq!(state.output_summary.as_deref(), Some("ok"));

        assert_eq!(store.list().await.unwrap(), vec!["file_ckpt".to_string()]);
        assert_eq!(
            store.latest().await.unwrap().map(|c| c.id),
            Some("file_ckpt".to_string())
        );

        store.delete(&checkpoint.id).await.unwrap();
        // Deleting an already-missing checkpoint is a no-op.
        store.delete(&checkpoint.id).await.unwrap();
        assert!(matches!(
            store.load(&checkpoint.id).await,
            Err(CheckpointError::NotFound(_))
        ));
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[tokio::test]
    async fn test_file_checkpoint_store_rejects_path_traversal() {
        let dir = temp_store_dir("traversal");
        let store = FileCheckpointStore::new(&dir);
        for bad in ["../escape", "a/b", "a\\b"] {
            let result = store.load(&bad.to_string()).await;
            assert!(matches!(result, Err(CheckpointError::InvalidCheckpoint(_))));
        }
        let _ = std::fs::remove_dir_all(&dir);
    }

    #[test]
    fn test_checkpoint_creation() {
        let checkpoint = Checkpoint::new(10, 3);
        assert_eq!(checkpoint.pc, 10);
        assert_eq!(checkpoint.loop_count, 3);
        assert!(checkpoint.id.starts_with("ckpt_"));
    }

    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig::enabled()
            .with_node_interval(5)
            .with_time_interval(60)
            .with_max_checkpoints(10);
        assert!(config.enabled);
        assert_eq!(config.interval_nodes, Some(5));
        assert_eq!(config.interval_seconds, Some(60));
        assert_eq!(config.max_checkpoints, 10);
    }
}