scud/attractor/
checkpoint.rs1use anyhow::{Context, Result};
4use serde::{Deserialize, Serialize};
5use std::collections::HashMap;
6use std::path::Path;
7
8use super::context::ContextSnapshot;
9use super::outcome::StageStatus;
10
11#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14 pub timestamp: String,
16 pub current_node: String,
18 pub completed_nodes: Vec<String>,
20 pub node_retries: HashMap<String, u32>,
22 pub node_statuses: HashMap<String, StageStatus>,
24 pub context: ContextSnapshot,
26 pub log: Vec<LogEntry>,
28}
29
30#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct LogEntry {
33 pub timestamp: String,
34 pub node_id: String,
35 pub message: String,
36}
37
38impl Checkpoint {
39 pub fn new(current_node: &str, context: ContextSnapshot) -> Self {
41 Self {
42 timestamp: chrono::Utc::now().to_rfc3339(),
43 current_node: current_node.to_string(),
44 completed_nodes: vec![],
45 node_retries: HashMap::new(),
46 node_statuses: HashMap::new(),
47 context,
48 log: vec![],
49 }
50 }
51
52 pub fn mark_completed(&mut self, node_id: &str, status: StageStatus) {
54 if !self.completed_nodes.contains(&node_id.to_string()) {
55 self.completed_nodes.push(node_id.to_string());
56 }
57 self.node_statuses.insert(node_id.to_string(), status);
58 }
59
60 pub fn increment_retry(&mut self, node_id: &str) -> u32 {
62 let count = self.node_retries.entry(node_id.to_string()).or_insert(0);
63 *count += 1;
64 *count
65 }
66
67 pub fn retry_count(&self, node_id: &str) -> u32 {
69 self.node_retries.get(node_id).copied().unwrap_or(0)
70 }
71
72 pub fn log(&mut self, node_id: &str, message: impl Into<String>) {
74 self.log.push(LogEntry {
75 timestamp: chrono::Utc::now().to_rfc3339(),
76 node_id: node_id.to_string(),
77 message: message.into(),
78 });
79 }
80
81 pub fn save(&self, path: &Path) -> Result<()> {
83 let json = serde_json::to_string_pretty(self).context("Failed to serialize checkpoint")?;
84 std::fs::write(path, json).context("Failed to write checkpoint file")?;
85 Ok(())
86 }
87
88 pub fn load(path: &Path) -> Result<Self> {
90 let json = std::fs::read_to_string(path).context("Failed to read checkpoint file")?;
91 let checkpoint: Self =
92 serde_json::from_str(&json).context("Failed to deserialize checkpoint")?;
93 Ok(checkpoint)
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a context snapshot holding a single `"key" -> "value"` entry.
    fn sample_snapshot() -> ContextSnapshot {
        let mut ctx_values = HashMap::new();
        ctx_values.insert("key".into(), serde_json::json!("value"));
        ContextSnapshot::from(ctx_values)
    }

    #[test]
    fn test_checkpoint_roundtrip() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        let mut cp = Checkpoint::new("node_a", sample_snapshot());
        cp.mark_completed("node_a", StageStatus::Success);
        cp.increment_retry("node_b");
        cp.log("node_a", "Did something");

        cp.save(&path).unwrap();
        let loaded = Checkpoint::load(&path).unwrap();

        assert_eq!(loaded.timestamp, cp.timestamp);
        assert_eq!(loaded.current_node, "node_a");
        assert_eq!(loaded.completed_nodes, vec!["node_a"]);
        assert!(loaded.node_statuses.contains_key("node_a"));
        assert_eq!(loaded.retry_count("node_b"), 1);
        assert_eq!(loaded.retry_count("node_a"), 0);
        assert_eq!(loaded.log.len(), 1);
        assert_eq!(loaded.log[0].node_id, "node_a");
        assert_eq!(loaded.log[0].message, "Did something");
    }

    #[test]
    fn test_mark_completed_is_idempotent() {
        let mut cp = Checkpoint::new("node_a", sample_snapshot());
        cp.mark_completed("node_a", StageStatus::Success);
        cp.mark_completed("node_a", StageStatus::Success);

        assert_eq!(cp.completed_nodes, vec!["node_a"]);
        assert_eq!(cp.node_statuses.len(), 1);
    }

    #[test]
    fn test_increment_retry_counts_up() {
        let mut cp = Checkpoint::new("node_a", sample_snapshot());

        assert_eq!(cp.retry_count("node_a"), 0);
        assert_eq!(cp.increment_retry("node_a"), 1);
        assert_eq!(cp.increment_retry("node_a"), 2);
        assert_eq!(cp.retry_count("node_a"), 2);
    }

    #[test]
    fn test_save_overwrites_and_leaves_no_temp_file() {
        let dir = tempfile::tempdir().unwrap();
        let path = dir.path().join("checkpoint.json");

        Checkpoint::new("node_a", sample_snapshot()).save(&path).unwrap();
        Checkpoint::new("node_b", sample_snapshot()).save(&path).unwrap();

        assert_eq!(Checkpoint::load(&path).unwrap().current_node, "node_b");
        // Only the checkpoint itself should remain in the directory.
        assert_eq!(std::fs::read_dir(dir.path()).unwrap().count(), 1);
    }

    #[test]
    fn test_load_missing_file_fails() {
        let dir = tempfile::tempdir().unwrap();
        assert!(Checkpoint::load(&dir.path().join("missing.json")).is_err());
    }
}