use std::collections::{HashMap, HashSet};
use std::path::{Path, PathBuf};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{SystemTime, UNIX_EPOCH};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use crate::{DagrsError, DagrsResult, ErrorCode, node::NodeId};
/// Identifier under which a checkpoint is stored in a [`CheckpointStore`].
pub type CheckpointId = String;
/// Process-wide counter. [`Checkpoint::new`] appends its value to every generated id, so two
/// checkpoints created with the same timestamp and `pc` still receive distinct ids.
static CHECKPOINT_SEQUENCE: AtomicU64 = AtomicU64::new(0);
/// Extracts the numeric suffix of a checkpoint id for use as an ordering hint.
///
/// The suffix is the text after the last `_`, or the whole id when it contains no `_`.
/// Returns `0` when that text does not parse as a `u64`.
fn checkpoint_sequence_hint(id: &str) -> u64 {
    // `_` is a single-byte character, so `pos + 1` is always a valid slice boundary.
    let tail = match id.rfind('_') {
        Some(pos) => &id[pos + 1..],
        None => id,
    };
    tail.parse::<u64>().unwrap_or(0)
}
/// Orders two checkpoints from older to newer.
///
/// Checkpoints are compared by `timestamp` first. Ties are broken by the numeric suffix of
/// the id (see `checkpoint_sequence_hint`) and finally by the id string itself.
pub(crate) fn checkpoint_cmp(left: &Checkpoint, right: &Checkpoint) -> std::cmp::Ordering {
    let by_time = left.timestamp.cmp(&right.timestamp);
    if by_time.is_ne() {
        return by_time;
    }

    let left_hint = checkpoint_sequence_hint(&left.id);
    let right_hint = checkpoint_sequence_hint(&right.id);
    let by_sequence = left_hint.cmp(&right_hint);
    if by_sequence.is_ne() {
        return by_sequence;
    }

    left.id.cmp(&right.id)
}
/// Execution status recorded for a node inside a checkpoint's [`NodeState`].
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum NodeExecStatus {
/// Status assigned by [`NodeState::pending`].
Pending,
/// Status assigned by [`NodeState::running`].
Running,
/// Status assigned by [`NodeState::succeeded`] and by `NodeState::completed(_, true)`.
Succeeded,
/// Status assigned by [`NodeState::failed`] and by `NodeState::completed(_, false)`.
Failed,
/// Status assigned by [`NodeState::skipped`].
Skipped,
}
/// Tag stored next to a node's serialized output in [`NodeState::output_kind`].
///
/// NOTE(review): the variant names mirror Rust scalar and `Vec` types, so this presumably
/// records which type `NodeState::output_data` decodes to — confirm against the code that
/// writes and reads `output_data`.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq)]
pub enum StoredOutputKind {
String,
I32,
I64,
U32,
U64,
F64,
Bool,
VecString,
VecI32,
VecI64,
}
/// Per-node record stored in [`Checkpoint::node_states`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeState {
/// Raw node identifier; also the key used by [`Checkpoint::add_node_state`].
pub node_id: usize,
/// Execution status of the node at checkpoint time.
pub status: NodeExecStatus,
/// Optional output payload as raw bytes. The encoding is not defined in this file.
pub output_data: Option<Vec<u8>>,
/// Optional tag describing `output_data`; falls back to `None` when the field is missing
/// from the serialized input.
#[serde(default)]
pub output_kind: Option<StoredOutputKind>,
/// Optional free-form summary text; falls back to `None` when the field is missing from
/// the serialized input.
#[serde(default)]
pub output_summary: Option<String>,
}
impl NodeState {
pub fn succeeded(node_id: usize) -> Self {
Self {
node_id,
status: NodeExecStatus::Succeeded,
output_data: None,
output_kind: None,
output_summary: None,
}
}
pub fn failed(node_id: usize) -> Self {
Self {
node_id,
status: NodeExecStatus::Failed,
output_data: None,
output_kind: None,
output_summary: None,
}
}
pub fn pending(node_id: usize) -> Self {
Self {
node_id,
status: NodeExecStatus::Pending,
output_data: None,
output_kind: None,
output_summary: None,
}
}
pub fn running(node_id: usize) -> Self {
Self {
node_id,
status: NodeExecStatus::Running,
output_data: None,
output_kind: None,
output_summary: None,
}
}
pub fn skipped(node_id: usize) -> Self {
Self {
node_id,
status: NodeExecStatus::Skipped,
output_data: None,
output_kind: None,
output_summary: None,
}
}
pub fn completed(node_id: usize, success: bool) -> Self {
if success {
Self::succeeded(node_id)
} else {
Self::failed(node_id)
}
}
pub fn with_output_data(mut self, data: Vec<u8>) -> Self {
self.output_data = Some(data);
self
}
pub fn with_output_kind(mut self, kind: StoredOutputKind) -> Self {
self.output_kind = Some(kind);
self
}
pub fn with_summary(mut self, summary: impl Into<String>) -> Self {
self.output_summary = Some(summary.into());
self
}
}
/// Serializable snapshot of execution progress, persisted through a [`CheckpointStore`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
/// Identifier of this checkpoint; the stores in this file use it as the storage key.
pub id: CheckpointId,
/// Creation time in nanoseconds since the Unix epoch, as set by the constructors.
pub timestamp: u64,
/// NOTE(review): presumably the scheduler's position ("program counter") at snapshot
/// time — confirm against the caller. This file only stores the value and embeds it in
/// generated ids.
pub pc: usize,
/// NOTE(review): presumably the number of loop iterations completed — confirm against
/// the caller. This file only stores the value.
pub loop_count: usize,
/// Raw inner values of the active `NodeId`s (see `set_active_nodes`).
pub active_nodes: HashSet<usize>,
/// Per-node states keyed by `NodeState::node_id`.
pub node_states: HashMap<usize, NodeState>,
/// Optional opaque environment payload; the encoding is not defined in this file.
pub env_data: Option<Vec<u8>>,
/// Free-form string key/value annotations.
pub metadata: HashMap<String, String>,
}
impl Checkpoint {
fn current_timestamp_nanos() -> u64 {
SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64
}
pub fn new(pc: usize, loop_count: usize) -> Self {
let timestamp = Self::current_timestamp_nanos();
let sequence = CHECKPOINT_SEQUENCE.fetch_add(1, Ordering::Relaxed);
let id = format!("ckpt_{}_{}_{}", timestamp, pc, sequence);
Self {
id,
timestamp,
pc,
loop_count,
active_nodes: HashSet::new(),
node_states: HashMap::new(),
env_data: None,
metadata: HashMap::new(),
}
}
pub fn with_id(id: impl Into<String>, pc: usize, loop_count: usize) -> Self {
let timestamp = Self::current_timestamp_nanos();
Self {
id: id.into(),
timestamp,
pc,
loop_count,
active_nodes: HashSet::new(),
node_states: HashMap::new(),
env_data: None,
metadata: HashMap::new(),
}
}
pub fn set_active_nodes(&mut self, nodes: &HashSet<NodeId>) {
self.active_nodes = nodes.iter().map(|id| id.0).collect();
}
pub fn get_active_nodes(&self) -> HashSet<NodeId> {
self.active_nodes.iter().map(|id| NodeId(*id)).collect()
}
pub fn add_node_state(&mut self, state: NodeState) {
self.node_states.insert(state.node_id, state);
}
pub fn add_metadata(&mut self, key: impl Into<String>, value: impl Into<String>) {
self.metadata.insert(key.into(), value.into());
}
}
/// Asynchronous storage backend for [`Checkpoint`]s.
///
/// Implementors must be `Send + Sync`. The per-method notes below describe how the two stores
/// in this file (`MemoryCheckpointStore`, `FileCheckpointStore`) behave.
#[async_trait]
pub trait CheckpointStore: Send + Sync {
/// Persists `checkpoint` under `checkpoint.id`.
async fn save(&self, checkpoint: &Checkpoint) -> DagrsResult<()>;
/// Returns the checkpoint stored under `id`; the stores in this file report a
/// `DgChk0002CheckpointNotFound` error when none exists.
async fn load(&self, id: &CheckpointId) -> DagrsResult<Checkpoint>;
/// Removes the checkpoint stored under `id`; the stores in this file succeed even when no
/// such checkpoint exists.
async fn delete(&self, id: &CheckpointId) -> DagrsResult<()>;
/// Returns the ids of all stored checkpoints; the stores in this file do not sort them.
async fn list(&self) -> DagrsResult<Vec<CheckpointId>>;
/// Returns the newest stored checkpoint (by `checkpoint_cmp` in this file's stores), or
/// `None` when the store is empty.
async fn latest(&self) -> DagrsResult<Option<Checkpoint>>;
/// Removes every stored checkpoint.
async fn clear(&self) -> DagrsResult<()>;
}
/// [`CheckpointStore`] that keeps checkpoints in a lock-protected in-process map.
#[derive(Default)]
pub struct MemoryCheckpointStore {
// Checkpoints keyed by id, guarded by a standard-library `RwLock`.
checkpoints: std::sync::RwLock<HashMap<CheckpointId, Checkpoint>>,
}
impl MemoryCheckpointStore {
pub fn new() -> Self {
Self {
checkpoints: std::sync::RwLock::new(HashMap::new()),
}
}
}
#[async_trait]
impl CheckpointStore for MemoryCheckpointStore {
    /// Stores a clone of `checkpoint` under its id, replacing any previous entry.
    ///
    /// Like every method here, fails with a checkpoint I/O error when the lock is poisoned.
    async fn save(&self, checkpoint: &Checkpoint) -> DagrsResult<()> {
        match self.checkpoints.write() {
            Ok(mut guard) => {
                guard.insert(checkpoint.id.clone(), checkpoint.clone());
                Ok(())
            }
            Err(poisoned) => Err(checkpoint_io_error(poisoned.to_string())),
        }
    }

    /// Returns a clone of the checkpoint stored under `id`, or a not-found error.
    async fn load(&self, id: &CheckpointId) -> DagrsResult<Checkpoint> {
        let guard = self
            .checkpoints
            .read()
            .map_err(|poisoned| checkpoint_io_error(poisoned.to_string()))?;
        let found = guard.get(id).cloned();
        found.ok_or_else(|| checkpoint_not_found(id))
    }

    /// Removes the entry for `id`; a missing entry is not an error.
    async fn delete(&self, id: &CheckpointId) -> DagrsResult<()> {
        match self.checkpoints.write() {
            Ok(mut guard) => {
                guard.remove(id);
                Ok(())
            }
            Err(poisoned) => Err(checkpoint_io_error(poisoned.to_string())),
        }
    }

    /// Returns all stored ids in the map's iteration order (unspecified).
    async fn list(&self) -> DagrsResult<Vec<CheckpointId>> {
        let guard = self
            .checkpoints
            .read()
            .map_err(|poisoned| checkpoint_io_error(poisoned.to_string()))?;
        let ids: Vec<CheckpointId> = guard.keys().cloned().collect();
        Ok(ids)
    }

    /// Returns a clone of the greatest checkpoint according to `checkpoint_cmp`, if any.
    async fn latest(&self) -> DagrsResult<Option<Checkpoint>> {
        let guard = self
            .checkpoints
            .read()
            .map_err(|poisoned| checkpoint_io_error(poisoned.to_string()))?;
        let newest = guard
            .values()
            .max_by(|a, b| checkpoint_cmp(a, b))
            .cloned();
        Ok(newest)
    }

    /// Removes every stored checkpoint.
    async fn clear(&self) -> DagrsResult<()> {
        match self.checkpoints.write() {
            Ok(mut guard) => {
                guard.clear();
                Ok(())
            }
            Err(poisoned) => Err(checkpoint_io_error(poisoned.to_string())),
        }
    }
}
/// [`CheckpointStore`] that writes each checkpoint as `<id>.json` inside a base directory.
pub struct FileCheckpointStore {
// Directory that holds the checkpoint files; created on demand by `ensure_dir`.
base_path: PathBuf,
}
impl FileCheckpointStore {
pub fn new(base_path: impl AsRef<Path>) -> Self {
Self {
base_path: base_path.as_ref().to_path_buf(),
}
}
fn checkpoint_path(&self, id: &CheckpointId) -> DagrsResult<PathBuf> {
if id.contains('/') || id.contains('\\') || id.contains("..") {
return Err(DagrsError::new(
ErrorCode::DgChk0003InvalidCheckpoint,
"checkpoint id contains invalid characters",
)
.with_checkpoint(id.clone()));
}
Ok(self.base_path.join(format!("{}.json", id)))
}
async fn ensure_dir(&self) -> DagrsResult<()> {
tokio::fs::create_dir_all(&self.base_path)
.await
.map_err(|e| checkpoint_io_error(format!("Failed to create checkpoint directory: {e}")))
}
}
#[async_trait]
impl CheckpointStore for FileCheckpointStore {
async fn save(&self, checkpoint: &Checkpoint) -> DagrsResult<()> {
self.ensure_dir().await?;
let json = serde_json::to_string_pretty(checkpoint)
.map_err(|e| checkpoint_io_error(format!("Failed to serialize checkpoint: {e}")))?;
let path = self.checkpoint_path(&checkpoint.id)?;
tokio::fs::write(&path, json)
.await
.map_err(|e| checkpoint_io_error(format!("Failed to write checkpoint file: {e}")))?;
Ok(())
}
async fn load(&self, id: &CheckpointId) -> DagrsResult<Checkpoint> {
let path = self.checkpoint_path(id)?;
match tokio::fs::metadata(&path).await {
Ok(_) => {}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
return Err(checkpoint_not_found(id));
}
Err(e) => {
return Err(checkpoint_io_error(format!(
"Failed to check checkpoint file: {e}"
)));
}
}
let json = tokio::fs::read_to_string(&path)
.await
.map_err(|e| checkpoint_io_error(format!("Failed to read checkpoint file: {e}")))?;
serde_json::from_str(&json).map_err(|e| {
DagrsError::new(
ErrorCode::DgChk0003InvalidCheckpoint,
format!("Failed to deserialize checkpoint: {e}"),
)
.with_checkpoint(id.clone())
})
}
async fn delete(&self, id: &CheckpointId) -> DagrsResult<()> {
let path = self.checkpoint_path(id)?;
match tokio::fs::metadata(&path).await {
Ok(_) => {
tokio::fs::remove_file(&path).await.map_err(|e| {
checkpoint_io_error(format!("Failed to delete checkpoint file: {e}"))
})?;
}
Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
}
Err(e) => {
return Err(checkpoint_io_error(format!(
"Failed to check checkpoint file: {e}"
)));
}
}
Ok(())
}
async fn list(&self) -> DagrsResult<Vec<CheckpointId>> {
self.ensure_dir().await?;
let mut entries = tokio::fs::read_dir(&self.base_path).await.map_err(|e| {
checkpoint_io_error(format!("Failed to read checkpoint directory: {e}"))
})?;
let mut ids = Vec::new();
while let Some(entry) = entries
.next_entry()
.await
.map_err(|e| checkpoint_io_error(format!("Failed to read directory entry: {e}")))?
{
if let Some(name) = entry.file_name().to_str()
&& name.ends_with(".json")
{
ids.push(name.trim_end_matches(".json").to_string());
}
}
Ok(ids)
}
async fn latest(&self) -> DagrsResult<Option<Checkpoint>> {
let ids = self.list().await?;
let mut latest: Option<Checkpoint> = None;
for id in ids {
if let Ok(checkpoint) = self.load(&id).await
&& latest
.as_ref()
.is_none_or(|current| checkpoint_cmp(&checkpoint, current).is_gt())
{
latest = Some(checkpoint);
}
}
Ok(latest)
}
async fn clear(&self) -> DagrsResult<()> {
let ids = self.list().await?;
for id in ids {
self.delete(&id).await?;
}
Ok(())
}
}
/// Settings describing whether and how often checkpoints should be taken.
///
/// This file only stores the values; the code that interprets them lives elsewhere.
#[derive(Debug, Clone)]
pub struct CheckpointConfig {
    /// Master switch. Defaults to `false`.
    pub enabled: bool,
    /// Optional node-count interval. Defaults to `Some(10)`.
    /// NOTE(review): presumably "checkpoint every N nodes" — confirm against the consumer.
    pub interval_nodes: Option<usize>,
    /// Optional time interval in seconds. Defaults to `None`.
    pub interval_seconds: Option<u64>,
    /// Flag for checkpointing on loop iterations. Defaults to `true`.
    pub on_loop_iteration: bool,
    /// Checkpoint count limit. Defaults to `5`.
    /// NOTE(review): presumably the number of checkpoints retained — confirm against the
    /// consumer.
    pub max_checkpoints: usize,
}

impl Default for CheckpointConfig {
    /// Returns the disabled default configuration.
    fn default() -> Self {
        let interval_nodes = Some(10);
        let interval_seconds = None;
        Self {
            max_checkpoints: 5,
            on_loop_iteration: true,
            interval_seconds,
            interval_nodes,
            enabled: false,
        }
    }
}
fn checkpoint_not_found(id: &CheckpointId) -> DagrsError {
DagrsError::new(
ErrorCode::DgChk0002CheckpointNotFound,
"checkpoint not found",
)
.with_checkpoint(id.clone())
}
/// Builds a `DgChk0004CheckpointIo` error carrying `message`.
fn checkpoint_io_error(message: impl Into<String>) -> DagrsError {
    let text: String = message.into();
    DagrsError::new(ErrorCode::DgChk0004CheckpointIo, text)
}
impl CheckpointConfig {
    /// Returns the default configuration with `enabled` switched on.
    pub fn enabled() -> Self {
        let defaults = Self::default();
        Self {
            enabled: true,
            ..defaults
        }
    }

    /// Returns the configuration with `interval_nodes` set to `Some(interval)`.
    pub fn with_node_interval(self, interval: usize) -> Self {
        Self {
            interval_nodes: Some(interval),
            ..self
        }
    }

    /// Returns the configuration with `interval_seconds` set to `Some(seconds)`.
    pub fn with_time_interval(self, seconds: u64) -> Self {
        Self {
            interval_seconds: Some(seconds),
            ..self
        }
    }

    /// Returns the configuration with `on_loop_iteration` set to `enabled`.
    pub fn with_loop_checkpoint(self, enabled: bool) -> Self {
        Self {
            on_loop_iteration: enabled,
            ..self
        }
    }

    /// Returns the configuration with `max_checkpoints` set to `max`.
    pub fn with_max_checkpoints(self, max: usize) -> Self {
        Self {
            max_checkpoints: max,
            ..self
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one checkpoint through the in-memory store: save, load, list, latest,
    /// delete.
    #[tokio::test]
    async fn test_memory_checkpoint_store() {
        let store = MemoryCheckpointStore::new();

        let mut original = Checkpoint::new(5, 2);
        original.add_metadata("test_key", "test_value");
        original.active_nodes.extend([1, 2]);
        store.save(&original).await.unwrap();

        let restored = store.load(&original.id).await.unwrap();
        assert_eq!(restored.pc, 5);
        assert_eq!(restored.loop_count, 2);
        assert!(restored.active_nodes.contains(&1));
        assert!(restored.active_nodes.contains(&2));
        assert_eq!(
            restored.metadata.get("test_key"),
            Some(&"test_value".to_string())
        );

        let ids_after_save = store.list().await.unwrap();
        assert_eq!(ids_after_save.len(), 1);

        let newest = store.latest().await.unwrap();
        assert!(newest.is_some());

        store.delete(&original.id).await.unwrap();
        let ids_after_delete = store.list().await.unwrap();
        assert_eq!(ids_after_delete.len(), 0);
    }

    /// Generated checkpoints keep their inputs, use the `ckpt_` prefix, and get unique ids.
    #[test]
    fn test_checkpoint_creation() {
        let first = Checkpoint::new(10, 3);
        let second = Checkpoint::new(10, 3);

        assert_eq!(first.pc, 10);
        assert_eq!(first.loop_count, 3);
        assert!(first.id.starts_with("ckpt_"));
        assert_ne!(first.id, second.id);
        assert!(second.timestamp >= first.timestamp);
    }

    /// Builder methods override the defaults; also spot-checks two `NodeState` constructors.
    #[test]
    fn test_checkpoint_config() {
        let config = CheckpointConfig::enabled()
            .with_node_interval(5)
            .with_time_interval(60)
            .with_max_checkpoints(10);

        assert!(config.enabled);
        assert_eq!(config.interval_nodes, Some(5));
        assert_eq!(config.interval_seconds, Some(60));
        assert_eq!(config.max_checkpoints, 10);

        let running = NodeState::running(1);
        let skipped = NodeState::skipped(2);
        assert_eq!(running.status, NodeExecStatus::Running);
        assert_eq!(skipped.status, NodeExecStatus::Skipped);
    }
}