use oxify_model::{ExecutionContext, NodeExecutionResult, NodeId, WorkflowId};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
use uuid::Uuid;
/// Identifier of a single checkpoint (a v4 UUID assigned at creation).
pub type CheckpointId = Uuid;
/// A point-in-time snapshot of a workflow execution that can be serialized,
/// persisted via a [`CheckpointStore`], and later restored to resume work.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ExecutionCheckpoint {
    /// Unique id of this checkpoint (assigned in [`ExecutionCheckpoint::new`]).
    pub id: CheckpointId,
    /// Workflow this execution belongs to.
    pub workflow_id: WorkflowId,
    /// The specific execution run being checkpointed.
    pub execution_id: Uuid,
    /// Execution context captured at checkpoint time.
    pub context: ExecutionContext,
    /// Nodes that had finished when the checkpoint was taken.
    pub completed_nodes: Vec<NodeId>,
    /// Per-node results accumulated so far (starts empty; see `add_node_result`).
    pub node_results: HashMap<NodeId, NodeExecutionResult>,
    /// Index of the execution level in progress when the checkpoint was taken.
    pub current_level: usize,
    /// Whether the execution was paused (initialized to `false` by `new`).
    pub paused: bool,
    /// Wall-clock creation time; used by `load_latest` to pick the newest checkpoint.
    pub created_at: std::time::SystemTime,
    /// Human-readable explanation of why the checkpoint was taken.
    pub reason: String,
}
impl ExecutionCheckpoint {
    /// Builds a fresh checkpoint for the given execution state.
    ///
    /// A random id and the current wall-clock time are assigned; the
    /// checkpoint starts with no node results and `paused == false`.
    pub fn new(
        workflow_id: WorkflowId,
        execution_id: Uuid,
        context: ExecutionContext,
        completed_nodes: Vec<NodeId>,
        current_level: usize,
        reason: String,
    ) -> Self {
        let id = Uuid::new_v4();
        let created_at = std::time::SystemTime::now();
        Self {
            id,
            workflow_id,
            execution_id,
            context,
            completed_nodes,
            node_results: HashMap::new(),
            current_level,
            paused: false,
            created_at,
            reason,
        }
    }

    /// Records (or overwrites) the result produced by a single node.
    pub fn add_node_result(&mut self, node_id: NodeId, result: NodeExecutionResult) {
        self.node_results.insert(node_id, result);
    }

    /// Returns `true` when `node_id` appears in the completed-node list.
    pub fn is_node_completed(&self, node_id: NodeId) -> bool {
        self.completed_nodes.iter().any(|done| *done == node_id)
    }
}
/// Storage backend for execution checkpoints.
///
/// Implementations must be thread-safe (`Send + Sync`); errors are reported
/// as human-readable `String`s.
pub trait CheckpointStore: Send + Sync {
    /// Persists a checkpoint and returns its id.
    fn save(&self, checkpoint: &ExecutionCheckpoint) -> Result<CheckpointId, String>;
    /// Loads the checkpoint with the given id, or errors if unknown.
    fn load(&self, id: CheckpointId) -> Result<ExecutionCheckpoint, String>;
    /// Loads the most recently created checkpoint for an execution.
    fn load_latest(&self, execution_id: Uuid) -> Result<ExecutionCheckpoint, String>;
    /// All checkpoints belonging to a workflow (unspecified order).
    fn list_by_workflow(&self, workflow_id: WorkflowId) -> Vec<ExecutionCheckpoint>;
    /// All checkpoints belonging to an execution (unspecified order).
    fn list_by_execution(&self, execution_id: Uuid) -> Vec<ExecutionCheckpoint>;
    /// Deletes one checkpoint by id.
    fn delete(&self, id: CheckpointId) -> Result<(), String>;
    /// Deletes every checkpoint recorded for an execution.
    fn delete_by_execution(&self, execution_id: Uuid) -> Result<(), String>;
}
/// File-backed checkpoint store: each checkpoint is written as
/// `<base_path>/<id>.json`, with a full in-memory cache for reads.
pub struct FileCheckpointStore {
    /// Directory holding the `*.json` checkpoint files.
    base_path: PathBuf,
    /// In-memory mirror of everything on disk, keyed by checkpoint id.
    checkpoints: Arc<RwLock<HashMap<CheckpointId, ExecutionCheckpoint>>>,
}
impl FileCheckpointStore {
    /// Creates a store rooted at `base_path`, creating the directory if
    /// needed and pre-loading any existing `*.json` checkpoints into the
    /// in-memory cache.
    ///
    /// # Errors
    /// Returns an error string if the directory cannot be created or read.
    pub fn new<P: AsRef<Path>>(base_path: P) -> Result<Self, String> {
        let base_path = base_path.as_ref().to_path_buf();
        fs::create_dir_all(&base_path)
            .map_err(|e| format!("Failed to create checkpoint directory: {}", e))?;
        // `load_all` only needs interior mutability, so no `mut` binding here.
        let store = Self {
            base_path,
            checkpoints: Arc::new(RwLock::new(HashMap::new())),
        };
        store.load_all()?;
        Ok(store)
    }

    /// Scans the base directory and caches every parseable `*.json`
    /// checkpoint. Unreadable or malformed files are skipped silently —
    /// a deliberate best-effort recovery of whatever is still valid.
    fn load_all(&self) -> Result<(), String> {
        let entries = fs::read_dir(&self.base_path)
            .map_err(|e| format!("Failed to read checkpoint directory: {}", e))?;
        // Take the write lock once for the whole scan instead of once per
        // file; during construction nothing else can contend for it anyway.
        let mut cache = self.checkpoints.write().unwrap();
        for entry in entries.flatten() {
            let path = entry.path();
            if path.extension().and_then(|s| s.to_str()) != Some("json") {
                continue;
            }
            if let Ok(content) = fs::read_to_string(&path) {
                if let Ok(checkpoint) = serde_json::from_str::<ExecutionCheckpoint>(&content) {
                    cache.insert(checkpoint.id, checkpoint);
                }
            }
        }
        Ok(())
    }

    /// Path of the on-disk JSON file backing checkpoint `id`.
    fn checkpoint_path(&self, id: CheckpointId) -> PathBuf {
        self.base_path.join(format!("{}.json", id))
    }
}
impl CheckpointStore for FileCheckpointStore {
    /// Serializes the checkpoint to pretty JSON, writes it to disk, then
    /// updates the in-memory cache.
    ///
    /// # Errors
    /// Returns an error string if serialization or the file write fails
    /// (the cache is left untouched in that case).
    fn save(&self, checkpoint: &ExecutionCheckpoint) -> Result<CheckpointId, String> {
        let path = self.checkpoint_path(checkpoint.id);
        let json = serde_json::to_string_pretty(checkpoint)
            .map_err(|e| format!("Failed to serialize checkpoint: {}", e))?;
        fs::write(&path, json).map_err(|e| format!("Failed to write checkpoint file: {}", e))?;
        self.checkpoints
            .write()
            .unwrap()
            .insert(checkpoint.id, checkpoint.clone());
        tracing::info!(
            "Saved checkpoint {} for execution {}",
            checkpoint.id,
            checkpoint.execution_id
        );
        Ok(checkpoint.id)
    }

    /// Fetches a clone of the cached checkpoint; errors if the id is unknown.
    fn load(&self, id: CheckpointId) -> Result<ExecutionCheckpoint, String> {
        self.checkpoints
            .read()
            .unwrap()
            .get(&id)
            .cloned()
            .ok_or_else(|| format!("Checkpoint {} not found", id))
    }

    /// Most recently created checkpoint for `execution_id` (by `created_at`).
    fn load_latest(&self, execution_id: Uuid) -> Result<ExecutionCheckpoint, String> {
        let checkpoints = self.list_by_execution(execution_id);
        checkpoints
            .into_iter()
            .max_by_key(|c| c.created_at)
            .ok_or_else(|| format!("No checkpoints found for execution {}", execution_id))
    }

    /// Clones of every cached checkpoint belonging to `workflow_id`.
    fn list_by_workflow(&self, workflow_id: WorkflowId) -> Vec<ExecutionCheckpoint> {
        self.checkpoints
            .read()
            .unwrap()
            .values()
            .filter(|c| c.workflow_id == workflow_id)
            .cloned()
            .collect()
    }

    /// Clones of every cached checkpoint belonging to `execution_id`.
    fn list_by_execution(&self, execution_id: Uuid) -> Vec<ExecutionCheckpoint> {
        self.checkpoints
            .read()
            .unwrap()
            .values()
            .filter(|c| c.execution_id == execution_id)
            .cloned()
            .collect()
    }

    /// Removes a checkpoint from disk and from the cache. Succeeds when
    /// the file is already gone.
    fn delete(&self, id: CheckpointId) -> Result<(), String> {
        let path = self.checkpoint_path(id);
        // Attempt removal directly instead of exists()+remove: the old
        // two-step check raced with concurrent deleters (TOCTOU). A file
        // that is already missing is not an error here.
        if let Err(e) = fs::remove_file(&path) {
            if e.kind() != std::io::ErrorKind::NotFound {
                return Err(format!("Failed to delete checkpoint file: {}", e));
            }
        }
        self.checkpoints.write().unwrap().remove(&id);
        tracing::info!("Deleted checkpoint {}", id);
        Ok(())
    }

    /// Deletes every checkpoint recorded for `execution_id`, stopping at
    /// the first failure.
    fn delete_by_execution(&self, execution_id: Uuid) -> Result<(), String> {
        let checkpoints = self.list_by_execution(execution_id);
        for checkpoint in checkpoints {
            self.delete(checkpoint.id)?;
        }
        Ok(())
    }
}
/// Purely in-memory checkpoint store; contents are lost when dropped.
/// Useful for tests and for executions that do not need durability.
pub struct InMemoryCheckpointStore {
    /// All stored checkpoints, keyed by checkpoint id.
    checkpoints: Arc<RwLock<HashMap<CheckpointId, ExecutionCheckpoint>>>,
}
impl InMemoryCheckpointStore {
    /// Creates an empty store backed by a shared, lock-protected map.
    pub fn new() -> Self {
        let checkpoints = Arc::new(RwLock::new(HashMap::new()));
        Self { checkpoints }
    }
}
impl Default for InMemoryCheckpointStore {
fn default() -> Self {
Self::new()
}
}
impl CheckpointStore for InMemoryCheckpointStore {
    /// Inserts (or replaces) the checkpoint in the map.
    fn save(&self, checkpoint: &ExecutionCheckpoint) -> Result<CheckpointId, String> {
        let mut map = self.checkpoints.write().unwrap();
        map.insert(checkpoint.id, checkpoint.clone());
        Ok(checkpoint.id)
    }

    /// Fetches a clone of the stored checkpoint; errors if the id is unknown.
    fn load(&self, id: CheckpointId) -> Result<ExecutionCheckpoint, String> {
        let map = self.checkpoints.read().unwrap();
        match map.get(&id) {
            Some(found) => Ok(found.clone()),
            None => Err(format!("Checkpoint {} not found", id)),
        }
    }

    /// Most recently created checkpoint for `execution_id` (by `created_at`).
    fn load_latest(&self, execution_id: Uuid) -> Result<ExecutionCheckpoint, String> {
        self.list_by_execution(execution_id)
            .into_iter()
            .max_by_key(|c| c.created_at)
            .ok_or_else(|| format!("No checkpoints found for execution {}", execution_id))
    }

    /// Clones of every checkpoint belonging to `workflow_id`.
    fn list_by_workflow(&self, workflow_id: WorkflowId) -> Vec<ExecutionCheckpoint> {
        let map = self.checkpoints.read().unwrap();
        map.values()
            .filter(|c| c.workflow_id == workflow_id)
            .cloned()
            .collect()
    }

    /// Clones of every checkpoint belonging to `execution_id`.
    fn list_by_execution(&self, execution_id: Uuid) -> Vec<ExecutionCheckpoint> {
        let map = self.checkpoints.read().unwrap();
        map.values()
            .filter(|c| c.execution_id == execution_id)
            .cloned()
            .collect()
    }

    /// Removes the checkpoint; succeeds even when the id is unknown.
    fn delete(&self, id: CheckpointId) -> Result<(), String> {
        self.checkpoints.write().unwrap().remove(&id);
        Ok(())
    }

    /// Removes all checkpoints recorded for `execution_id`.
    fn delete_by_execution(&self, execution_id: Uuid) -> Result<(), String> {
        for checkpoint in self.list_by_execution(execution_id) {
            self.delete(checkpoint.id)?;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built checkpoint carries the constructor arguments and
    /// the documented defaults (`paused == false`).
    #[test]
    fn test_checkpoint_creation() {
        let context = ExecutionContext::new(Uuid::new_v4());
        let cp = ExecutionCheckpoint::new(
            Uuid::new_v4(),
            Uuid::new_v4(),
            context,
            Vec::new(),
            0,
            String::from("test"),
        );
        assert_eq!(cp.current_level, 0);
        assert!(!cp.paused);
        assert_eq!(cp.reason, "test");
    }

    /// Round-trips a checkpoint through the in-memory store: save, load,
    /// list, delete, and confirm the load fails afterwards.
    #[test]
    fn test_in_memory_store() {
        let store = InMemoryCheckpointStore::new();
        let context = ExecutionContext::new(Uuid::new_v4());
        let exec_id = Uuid::new_v4();
        let cp = ExecutionCheckpoint::new(
            Uuid::new_v4(),
            exec_id,
            context,
            Vec::new(),
            0,
            String::from("test"),
        );
        let saved_id = store.save(&cp).unwrap();
        let round_trip = store.load(saved_id).unwrap();
        assert_eq!(round_trip.id, cp.id);
        assert_eq!(round_trip.execution_id, exec_id);
        assert_eq!(store.list_by_execution(exec_id).len(), 1);
        store.delete(saved_id).unwrap();
        assert!(store.load(saved_id).is_err());
    }

    /// `load_latest` returns the checkpoint with the newest `created_at`.
    #[test]
    fn test_load_latest() {
        let store = InMemoryCheckpointStore::new();
        let context = ExecutionContext::new(Uuid::new_v4());
        let exec_id = Uuid::new_v4();
        for level in 0..3 {
            let mut cp = ExecutionCheckpoint::new(
                Uuid::new_v4(),
                exec_id,
                context.clone(),
                Vec::new(),
                level,
                format!("checkpoint_{}", level),
            );
            // Sleep + re-stamp so each checkpoint has a strictly newer
            // timestamp than the previous one.
            std::thread::sleep(std::time::Duration::from_millis(10));
            cp.created_at = std::time::SystemTime::now();
            store.save(&cp).unwrap();
        }
        assert_eq!(store.load_latest(exec_id).unwrap().current_level, 2);
    }
}