Skip to main content

scud/attractor/
checkpoint.rs

//! Checkpoint save/load for pipeline resumption.
2
3use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7
8use super::context::ContextSnapshot;
9use super::outcome::StageStatus;
10
11/// Serializable checkpoint for a pipeline run.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14    /// ISO 8601 timestamp of when this checkpoint was created.
15    pub timestamp: String,
16    /// Current node ID being executed (or last completed).
17    pub current_node: String,
18    /// Set of completed node IDs.
19    pub completed_nodes: Vec<String>,
20    /// Retry counts per node.
21    pub node_retries: HashMap<String, u32>,
22    /// Status of each visited node.
23    pub node_statuses: HashMap<String, StageStatus>,
24    /// Context snapshot at checkpoint time.
25    pub context: ContextSnapshot,
26    /// Execution log entries.
27    pub log: Vec<LogEntry>,
28}
29
30/// A log entry in the checkpoint.
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct LogEntry {
33    pub timestamp: String,
34    pub node_id: String,
35    pub message: String,
36}
37
38impl Checkpoint {
39    /// Create a new checkpoint at the given node.
40    pub fn new(current_node: &str, context: ContextSnapshot) -> Self {
41        Self {
42            timestamp: chrono::Utc::now().to_rfc3339(),
43            current_node: current_node.to_string(),
44            completed_nodes: vec![],
45            node_retries: HashMap::new(),
46            node_statuses: HashMap::new(),
47            context,
48            log: vec![],
49        }
50    }
51
52    /// Mark a node as completed.
53    pub fn mark_completed(&mut self, node_id: &str, status: StageStatus) {
54        if !self.completed_nodes.contains(&node_id.to_string()) {
55            self.completed_nodes.push(node_id.to_string());
56        }
57        self.node_statuses.insert(node_id.to_string(), status);
58    }
59
60    /// Increment retry count for a node.
61    pub fn increment_retry(&mut self, node_id: &str) -> u32 {
62        let count = self.node_retries.entry(node_id.to_string()).or_insert(0);
63        *count += 1;
64        *count
65    }
66
67    /// Get retry count for a node.
68    pub fn retry_count(&self, node_id: &str) -> u32 {
69        self.node_retries.get(node_id).copied().unwrap_or(0)
70    }
71
72    /// Add a log entry.
73    pub fn log(&mut self, node_id: &str, message: impl Into<String>) {
74        self.log.push(LogEntry {
75            timestamp: chrono::Utc::now().to_rfc3339(),
76            node_id: node_id.to_string(),
77            message: message.into(),
78        });
79    }
80
81    /// Save checkpoint to a file.
82    pub fn save(&self, path: &Path) -> Result<()> {
83        let json = serde_json::to_string_pretty(self).context("Failed to serialize checkpoint")?;
84        std::fs::write(path, json).context("Failed to write checkpoint file")?;
85        Ok(())
86    }
87
88    /// Load checkpoint from a file.
89    pub fn load(path: &Path) -> Result<Self> {
90        let json = std::fs::read_to_string(path).context("Failed to read checkpoint file")?;
91        let checkpoint: Self =
92            serde_json::from_str(&json).context("Failed to deserialize checkpoint")?;
93        Ok(checkpoint)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// A populated checkpoint survives a save followed by a load.
    #[test]
    fn test_checkpoint_roundtrip() {
        let tmp = tempfile::tempdir().unwrap();
        let file = tmp.path().join("checkpoint.json");

        let mut values = HashMap::new();
        values.insert("key".into(), serde_json::json!("value"));

        let mut original = Checkpoint::new("node_a", ContextSnapshot::from(values));
        original.mark_completed("node_a", StageStatus::Success);
        original.increment_retry("node_b");
        original.log("node_a", "Did something");

        original.save(&file).unwrap();
        let restored = Checkpoint::load(&file).unwrap();

        assert_eq!(restored.current_node, "node_a");
        assert_eq!(restored.completed_nodes, vec!["node_a"]);
        assert_eq!(restored.retry_count("node_b"), 1);
        assert_eq!(restored.log.len(), 1);
    }
}