1use serde::{Deserialize, Serialize};
9use std::path::{Path, PathBuf};
10
/// Snapshot of an agent session at one step, persisted as JSON by `CheckpointManager`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Checkpoint {
    /// Identifier of this checkpoint (the tests use `"{session}-{step}"`).
    pub id: String,
    /// Name of the agent that produced the checkpoint.
    pub agent_name: String,
    /// Session this checkpoint belongs to. `CheckpointManager::save` uses it
    /// verbatim as the prefix of the on-disk file names.
    pub session_id: String,
    /// Step number within the session. It forms part of the file name and is
    /// the sort key of `CheckpointManager::list_checkpoints`.
    pub step: usize,
    /// Conversation messages recorded in this snapshot.
    pub messages: Vec<CheckpointMessage>,
    /// Tool invocations recorded in this snapshot.
    pub tool_calls: Vec<ToolCallRecord>,
    /// Intermediate outputs recorded in this snapshot.
    pub partial_results: Vec<String>,
    /// Creation time. The tests store seconds since the Unix epoch.
    /// NOTE(review): confirm every caller uses the same unit.
    pub timestamp: u64,
    /// Status of the run when this snapshot was taken.
    pub status: CheckpointStatus,
}
33
34#[derive(Debug, Clone, Serialize, Deserialize)]
35pub struct CheckpointMessage {
36 pub role: String, pub content: String,
38}
39
40#[derive(Debug, Clone, Serialize, Deserialize)]
41pub struct ToolCallRecord {
42 pub tool_name: String,
43 pub arguments: String,
44 pub result: Option<String>,
45 pub success: bool,
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
49pub enum CheckpointStatus {
50 InProgress,
52 Completed,
54 Failed(String),
56 Halted(String),
58}
59
/// Saves, loads, lists and deletes [`Checkpoint`]s stored as JSON files in one directory.
///
/// Derives `Debug` and `Clone`, as a public type should; cloning only copies the directory path.
#[derive(Debug, Clone)]
pub struct CheckpointManager {
    /// Directory holding every checkpoint file and every `latest` pointer file.
    checkpoint_dir: PathBuf,
}
65
66impl CheckpointManager {
67 pub fn new(checkpoint_dir: &Path) -> std::io::Result<Self> {
69 std::fs::create_dir_all(checkpoint_dir)?;
70 Ok(Self {
71 checkpoint_dir: checkpoint_dir.to_path_buf(),
72 })
73 }
74
75 pub fn default_dir() -> std::io::Result<Self> {
77 let dir = dirs_or_default().join("checkpoints");
78 Self::new(&dir)
79 }
80
81 pub fn save(&self, checkpoint: &Checkpoint) -> std::io::Result<()> {
83 let filename = format!("{}_{}.json", checkpoint.session_id, checkpoint.step);
84 let path = self.checkpoint_dir.join(&filename);
85 let json = serde_json::to_string_pretty(checkpoint)
86 .map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))?;
87 std::fs::write(&path, json)?;
88
89 let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", checkpoint.session_id));
91 std::fs::write(&latest_path, &filename)?;
92
93 Ok(())
94 }
95
96 pub fn load_latest(&self, session_id: &str) -> std::io::Result<Option<Checkpoint>> {
98 let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", session_id));
99 if !latest_path.exists() {
100 return Ok(None);
101 }
102
103 let filename = std::fs::read_to_string(&latest_path)?;
104 let checkpoint_path = self.checkpoint_dir.join(filename.trim());
105 if !checkpoint_path.exists() {
106 return Ok(None);
107 }
108
109 let json = std::fs::read_to_string(&checkpoint_path)?;
110 let checkpoint: Checkpoint = serde_json::from_str(&json)
111 .map_err(|e| std::io::Error::new(std::io::ErrorKind::InvalidData, e))?;
112 Ok(Some(checkpoint))
113 }
114
115 pub fn list_checkpoints(&self, session_id: &str) -> std::io::Result<Vec<Checkpoint>> {
117 let mut checkpoints = Vec::new();
118 let prefix = format!("{}_", session_id);
119
120 for entry in std::fs::read_dir(&self.checkpoint_dir)? {
121 let entry = entry?;
122 let name = entry.file_name().to_string_lossy().to_string();
123 if name.starts_with(&prefix) && name.ends_with(".json") && !name.contains("latest") {
124 let json = std::fs::read_to_string(entry.path())?;
125 if let Ok(cp) = serde_json::from_str::<Checkpoint>(&json) {
126 checkpoints.push(cp);
127 }
128 }
129 }
130
131 checkpoints.sort_by_key(|c| c.step);
132 Ok(checkpoints)
133 }
134
135 pub fn cleanup(&self, session_id: &str) -> std::io::Result<usize> {
137 let mut removed = 0;
138 let prefix = format!("{}_", session_id);
139
140 for entry in std::fs::read_dir(&self.checkpoint_dir)? {
141 let entry = entry?;
142 let name = entry.file_name().to_string_lossy().to_string();
143 if name.starts_with(&prefix) {
144 std::fs::remove_file(entry.path())?;
145 removed += 1;
146 }
147 }
148
149 Ok(removed)
150 }
151
152 pub fn has_checkpoint(&self, session_id: &str) -> bool {
154 let latest_path = self.checkpoint_dir.join(format!("{}_latest.json", session_id));
155 latest_path.exists()
156 }
157}
158
159fn dirs_or_default() -> PathBuf {
160 dirs::data_dir()
161 .unwrap_or_else(|| PathBuf::from("/tmp"))
162 .join("ares")
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Creates a manager over a fresh temporary directory.
    ///
    /// The returned `TempDir` guard deletes the directory when dropped, so
    /// callers bind it (as `_dir`) for the whole test.
    fn fresh_manager() -> (tempfile::TempDir, CheckpointManager) {
        let dir = tempfile::tempdir().unwrap();
        let mgr = CheckpointManager::new(dir.path()).unwrap();
        (dir, mgr)
    }

    /// Builds an `InProgress` checkpoint for `session` at `step` with:
    /// - a two-message conversation,
    /// - one successful tool call,
    /// - one partial result.
    fn sample_checkpoint(session: &str, step: usize) -> Checkpoint {
        let conversation = vec![
            CheckpointMessage { role: "user".into(), content: "Hello".into() },
            CheckpointMessage { role: "assistant".into(), content: "Hi there".into() },
        ];
        let search_call = ToolCallRecord {
            tool_name: "search".into(),
            arguments: "query".into(),
            result: Some("found it".into()),
            success: true,
        };
        let since_epoch = SystemTime::now().duration_since(UNIX_EPOCH).unwrap();

        Checkpoint {
            id: format!("{}-{}", session, step),
            agent_name: "test-agent".into(),
            session_id: session.into(),
            step,
            messages: conversation,
            tool_calls: vec![search_call],
            partial_results: vec!["partial output".into()],
            timestamp: since_epoch.as_secs(),
            status: CheckpointStatus::InProgress,
        }
    }

    #[test]
    fn test_save_and_load() {
        let (_dir, mgr) = fresh_manager();
        mgr.save(&sample_checkpoint("sess1", 0)).unwrap();

        let maybe = mgr.load_latest("sess1").unwrap();
        assert!(maybe.is_some());
        let restored = maybe.unwrap();
        assert_eq!(restored.session_id, "sess1");
        assert_eq!(restored.step, 0);
        assert_eq!(restored.messages.len(), 2);
        assert_eq!(restored.tool_calls.len(), 1);
    }

    #[test]
    fn test_load_nonexistent() {
        let (_dir, mgr) = fresh_manager();
        assert!(mgr.load_latest("nonexistent").unwrap().is_none());
    }

    #[test]
    fn test_multiple_steps() {
        let (_dir, mgr) = fresh_manager();
        for step in 0..3 {
            mgr.save(&sample_checkpoint("sess1", step)).unwrap();
        }

        let newest = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(newest.step, 2, "latest should be step 2");

        let steps: Vec<usize> = mgr
            .list_checkpoints("sess1")
            .unwrap()
            .iter()
            .map(|cp| cp.step)
            .collect();
        assert_eq!(steps.len(), 3);
        assert_eq!(steps[0], 0);
        assert_eq!(steps[2], 2);
    }

    #[test]
    fn test_cleanup() {
        let (_dir, mgr) = fresh_manager();
        for step in 0..2 {
            mgr.save(&sample_checkpoint("sess1", step)).unwrap();
        }
        assert!(mgr.has_checkpoint("sess1"));

        let deleted = mgr.cleanup("sess1").unwrap();
        assert!(deleted >= 2);
        assert!(!mgr.has_checkpoint("sess1"));
    }

    #[test]
    fn test_separate_sessions() {
        let (_dir, mgr) = fresh_manager();
        let sessions = ["sess1", "sess2"];
        for session in sessions.iter() {
            mgr.save(&sample_checkpoint(session, 0)).unwrap();
        }
        assert!(sessions.iter().all(|session| mgr.has_checkpoint(session)));

        mgr.cleanup("sess1").unwrap();
        assert!(!mgr.has_checkpoint("sess1"));
        assert!(mgr.has_checkpoint("sess2"));
    }

    #[test]
    fn test_checkpoint_status_serialization() {
        let (_dir, mgr) = fresh_manager();
        let failed = CheckpointStatus::Failed("OOM".into());
        let cp = Checkpoint {
            status: failed.clone(),
            ..sample_checkpoint("sess1", 0)
        };
        mgr.save(&cp).unwrap();

        let restored = mgr.load_latest("sess1").unwrap().unwrap();
        assert_eq!(restored.status, failed);
    }
}