Skip to main content

ares/agents/
checkpoint.rs

//! Agent checkpoint/crash recovery system.
//!
//! Serializes agent state to disk before each step. On restart, the latest
//! checkpoint is restored so execution can resume where it left off.
//!
//! Inspired by Octopoda-OS crash recovery patterns.

8use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10
11/// A checkpoint captures agent state at a point in time.
12#[derive(Debug, Clone, Serialize, Deserialize)]
13pub struct Checkpoint {
14    /// Unique checkpoint ID
15    pub id: String,
16    /// Agent name/type
17    pub agent_name: String,
18    /// Session ID
19    pub session_id: String,
20    /// Step number (0-indexed)
21    pub step: usize,
22    /// Conversation messages so far
23    pub messages: Vec<CheckpointMessage>,
24    /// Tool calls made and their results
25    pub tool_calls: Vec<ToolCallRecord>,
26    /// Partial results accumulated
27    pub partial_results: Vec<String>,
28    /// Timestamp (Unix epoch seconds)
29    pub timestamp: u64,
30    /// Status of this checkpoint
31    pub status: CheckpointStatus,
32}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CheckpointMessage {
36    pub role: String, // "user" | "assistant" | "system"
37    pub content: String,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ToolCallRecord {
42    pub tool_name: String,
43    pub arguments: String,
44    pub result: Option<String>,
45    pub success: bool,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub enum CheckpointStatus {
50    /// Agent is actively running
51    InProgress,
52    /// Agent completed successfully
53    Completed,
54    /// Agent failed/crashed
55    Failed(String),
56    /// Agent was halted (e.g., by loop detector)
57    Halted(String),
58}
59
/// Stores [`Checkpoint`]s as JSON files in one directory and retrieves them
/// for crash recovery.
pub struct CheckpointManager {
    /// Directory that holds every checkpoint file and "latest" pointer file.
    checkpoint_dir: PathBuf,
}

66impl CheckpointManager {
67    /// Create a new checkpoint manager.
68    pub fn new(checkpoint_dir: &Path) -> std::io::Result<Self> {
69        std::fs::create_dir_all(checkpoint_dir)?;
70        Ok(Self {
71            checkpoint_dir: checkpoint_dir.to_path_buf(),
72        })
73    }
74
75    /// Create a default checkpoint manager (~/.ares/checkpoints/).
76    pub fn default_dir() -> std::io::Result<Self> {
77        let dir = dirs_or_default().join("checkpoints");
78        Self::new(&dir)
79    }
80
81    /// Save a checkpoint to disk.
82    pub fn save(&self, checkpoint: &Checkpoint) -> std::io::Result<()> {
83        let filename = format!("{}_{}.json", checkpoint.session_id, checkpoint.step);
84        let path = self.checkpoint_dir.join(&filename);
85        let json = serde_json::to_string_pretty(checkpoint)
86            .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87        std::fs::write(&path, json)?;
88
89        // Also update the "latest" symlink/pointer
90        let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", checkpoint.session_id));
91        std::fs::write(&latest_path, &filename)?;
92
93        Ok(())
94    }
95
96    /// Load the latest checkpoint for a session.
97    pub fn load_latest(&self, session_id: &str) -> std::io::Result<Option<Checkpoint>> {
98        let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", session_id));
99        if !latest_path.exists() {
100            return Ok(None);
101        }
102
103        let filename = std::fs::read_to_string(&latest_path)?;
104        let checkpoint_path = self.checkpoint_dir.join(filename.trim());
105        if !checkpoint_path.exists() {
106            return Ok(None);
107        }
108
109        let json = std::fs::read_to_string(&checkpoint_path)?;
110        let checkpoint: Checkpoint = serde_json::from_str(&json)
111            .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
112        Ok(Some(checkpoint))
113    }
114
115    /// List all checkpoints for a session, ordered by step.
116    pub fn list_checkpoints(&self, session_id: &str) -> std::io::Result<Vec<Checkpoint>> {
117        let mut checkpoints = Vec::new();
118        let prefix = format!("{}_", session_id);
119
120        for entry in std::fs::read_dir(&self.checkpoint_dir)? {
121            let entry = entry?;
122            let name = entry.file_name().to_string_lossy().to_string();
123            if name.starts_with(&prefix) && name.ends_with(".json") && !name.contains("latest") {
124                let json = std::fs::read_to_string(entry.path())?;
125                if let Ok(cp) = serde_json::from_str::<Checkpoint>(&json) {
126                    checkpoints.push(cp);
127                }
128            }
129        }
130
131        checkpoints.sort_by_key(|c| c.step);
132        Ok(checkpoints)
133    }
134
135    /// Clean up old checkpoints for a completed session.
136    pub fn cleanup(&self, session_id: &str) -> std::io::Result<usize> {
137        let mut removed = 0;
138        let prefix = format!("{}_", session_id);
139
140        for entry in std::fs::read_dir(&self.checkpoint_dir)? {
141            let entry = entry?;
142            let name = entry.file_name().to_string_lossy().to_string();
143            if name.starts_with(&prefix) {
144                std::fs::remove_file(entry.path())?;
145                removed += 1;
146            }
147        }
148
149        Ok(removed)
150    }
151
152    /// Check if a session has a recoverable checkpoint.
153    pub fn has_checkpoint(&self, session_id: &str) -> bool {
154        let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", session_id));
155        latest_path.exists()
156    }
157}
158
159fn dirs_or_default() -> PathBuf {
160    dirs::data_dir()
161        .unwrap_or_else(|| PathBuf::from("/tmp"))
162        .join("ares")
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Fresh temporary directory, deleted when the returned guard is dropped.
    fn temp_dir() -> tempfile::TempDir {
        tempfile::tempdir().unwrap()
    }

    /// Checkpoint manager rooted at `dir`.
    fn manager_in(dir: &tempfile::TempDir) -> CheckpointManager {
        CheckpointManager::new(dir.path()).unwrap()
    }

    /// In-progress checkpoint for `session` at `step`. It carries a two-message
    /// conversation, one successful tool call, and one partial result.
    fn sample_checkpoint(session: &str, step: usize) -> Checkpoint {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let message = |role: &str, content: &str| CheckpointMessage {
            role: role.into(),
            content: content.into(),
        };
        let search_call = ToolCallRecord {
            tool_name: "search".into(),
            arguments: "query".into(),
            result: Some("found it".into()),
            success: true,
        };

        Checkpoint {
            id: format!("{}-{}", session, step),
            agent_name: "test-agent".into(),
            session_id: session.into(),
            step,
            messages: vec![message("user", "Hello"), message("assistant", "Hi there")],
            tool_calls: vec![search_call],
            partial_results: vec!["partial output".into()],
            timestamp: now,
            status: CheckpointStatus::InProgress,
        }
    }

    #[test]
    fn test_save_and_load() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();

        let loaded = mgr
            .load_latest("sess1")
            .unwrap()
            .expect("a checkpoint was just saved");
        assert_eq!(loaded.session_id, "sess1");
        assert_eq!(loaded.step, 0);
        assert_eq!(loaded.messages.len(), 2);
        assert_eq!(loaded.tool_calls.len(), 1);
    }

    #[test]
    fn test_load_nonexistent() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        assert!(mgr.load_latest("nonexistent").unwrap().is_none());
    }

    #[test]
    fn test_multiple_steps() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        for step in 0..3 {
            mgr.save(&sample_checkpoint("sess1", step)).unwrap();
        }

        let latest = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(latest.step, 2, "latest should be step 2");

        let all = mgr.list_checkpoints("sess1").unwrap();
        assert_eq!(all.len(), 3);
        assert_eq!(all.first().map(|cp| cp.step), Some(0));
        assert_eq!(all.last().map(|cp| cp.step), Some(2));
    }

    #[test]
    fn test_cleanup() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        for step in 0..2 {
            mgr.save(&sample_checkpoint("sess1", step)).unwrap();
        }
        assert!(mgr.has_checkpoint("sess1"));

        let removed = mgr.cleanup("sess1").unwrap();
        assert!(removed >= 2);
        assert!(!mgr.has_checkpoint("sess1"));
    }

    #[test]
    fn test_separate_sessions() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        for session in ["sess1", "sess2"].iter() {
            mgr.save(&sample_checkpoint(session, 0)).unwrap();
        }
        assert!(mgr.has_checkpoint("sess1"));
        assert!(mgr.has_checkpoint("sess2"));

        mgr.cleanup("sess1").unwrap();
        assert!(!mgr.has_checkpoint("sess1"));
        assert!(mgr.has_checkpoint("sess2"));
    }

    #[test]
    fn test_checkpoint_status_serialization() {
        let dir = temp_dir();
        let mgr = manager_in(&dir);
        let failed = Checkpoint {
            status: CheckpointStatus::Failed("OOM".into()),
            ..sample_checkpoint("sess1", 0)
        };
        mgr.save(&failed).unwrap();

        let loaded = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(loaded.status, CheckpointStatus::Failed("OOM".into()));
    }
}