Skip to main content

aft/
checkpoint.rs

1use std::collections::HashMap;
2use std::path::PathBuf;
3
4use crate::backup::BackupStore;
5use crate::error::AftError;
6
/// Metadata about a checkpoint, returned by list/create/restore.
///
/// Derives `PartialEq`/`Eq` so callers (and tests) can compare results directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CheckpointInfo {
    /// Checkpoint name, unique within its session.
    pub name: String,
    /// Number of distinct files captured in the checkpoint.
    pub file_count: usize,
    /// Creation time in whole seconds since the Unix epoch.
    pub created_at: u64,
}
14
/// One named snapshot: the full text of every file captured at creation time.
#[derive(Debug, Clone)]
struct Checkpoint {
    /// Name the checkpoint was created under (unique within its session).
    name: String,
    /// Captured file text, keyed by the path it was read from.
    file_contents: HashMap<PathBuf, String>,
    /// Seconds since the Unix epoch at which the snapshot was taken.
    created_at: u64,
}
22
23/// Workspace-wide, per-session checkpoint store.
24///
25/// Partitioned by session (issue #14): two OpenCode sessions sharing one bridge
26/// can both create checkpoints named `snap1` without collision, and restoring
27/// from one session does not leak the other's file set. Checkpoints are kept
28/// in memory only — a bridge crash drops all of them, which is a deliberate
29/// trade-off to keep this refactor bounded. Durable checkpoints are a possible
30/// follow-up.
31#[derive(Debug)]
32pub struct CheckpointStore {
33    /// session -> name -> checkpoint
34    checkpoints: HashMap<String, HashMap<String, Checkpoint>>,
35}
36
37impl CheckpointStore {
38    pub fn new() -> Self {
39        CheckpointStore {
40            checkpoints: HashMap::new(),
41        }
42    }
43
44    /// Create a checkpoint by reading the given files, scoped to `session`.
45    ///
46    /// If `files` is empty, snapshots all tracked files for **that session**
47    /// from the BackupStore (other sessions' tracked files are not visible).
48    /// Overwrites any existing checkpoint with the same name in this session.
49    pub fn create(
50        &mut self,
51        session: &str,
52        name: &str,
53        files: Vec<PathBuf>,
54        backup_store: &BackupStore,
55    ) -> Result<CheckpointInfo, AftError> {
56        let file_list = if files.is_empty() {
57            backup_store.tracked_files(session)
58        } else {
59            files
60        };
61
62        let mut file_contents = HashMap::new();
63        for path in &file_list {
64            let content = std::fs::read_to_string(path).map_err(|_| AftError::FileNotFound {
65                path: path.display().to_string(),
66            })?;
67            file_contents.insert(path.clone(), content);
68        }
69
70        let created_at = current_timestamp();
71        let file_count = file_contents.len();
72
73        let checkpoint = Checkpoint {
74            name: name.to_string(),
75            file_contents,
76            created_at,
77        };
78
79        self.checkpoints
80            .entry(session.to_string())
81            .or_default()
82            .insert(name.to_string(), checkpoint);
83
84        log::info!("checkpoint created: {} ({} files)", name, file_count);
85
86        Ok(CheckpointInfo {
87            name: name.to_string(),
88            file_count,
89            created_at,
90        })
91    }
92
93    /// Restore a checkpoint by overwriting files with stored content.
94    pub fn restore(&self, session: &str, name: &str) -> Result<CheckpointInfo, AftError> {
95        let checkpoint = self.get(session, name)?;
96
97        for (path, content) in &checkpoint.file_contents {
98            std::fs::write(path, content).map_err(|_| AftError::FileNotFound {
99                path: path.display().to_string(),
100            })?;
101        }
102
103        log::info!("checkpoint restored: {}", name);
104
105        Ok(CheckpointInfo {
106            name: checkpoint.name.clone(),
107            file_count: checkpoint.file_contents.len(),
108            created_at: checkpoint.created_at,
109        })
110    }
111
112    /// Restore a checkpoint using a caller-validated path list.
113    pub fn restore_validated(
114        &self,
115        session: &str,
116        name: &str,
117        validated_paths: &[PathBuf],
118    ) -> Result<CheckpointInfo, AftError> {
119        let checkpoint = self.get(session, name)?;
120
121        for path in validated_paths {
122            let content =
123                checkpoint
124                    .file_contents
125                    .get(path)
126                    .ok_or_else(|| AftError::FileNotFound {
127                        path: path.display().to_string(),
128                    })?;
129            std::fs::write(path, content).map_err(|_| AftError::FileNotFound {
130                path: path.display().to_string(),
131            })?;
132        }
133
134        log::info!("checkpoint restored: {}", name);
135
136        Ok(CheckpointInfo {
137            name: checkpoint.name.clone(),
138            file_count: checkpoint.file_contents.len(),
139            created_at: checkpoint.created_at,
140        })
141    }
142
143    /// Return the file paths stored for a checkpoint.
144    pub fn file_paths(&self, session: &str, name: &str) -> Result<Vec<PathBuf>, AftError> {
145        let checkpoint = self.get(session, name)?;
146        Ok(checkpoint.file_contents.keys().cloned().collect())
147    }
148
149    /// List all checkpoints for this session with metadata.
150    pub fn list(&self, session: &str) -> Vec<CheckpointInfo> {
151        self.checkpoints
152            .get(session)
153            .map(|s| {
154                s.values()
155                    .map(|cp| CheckpointInfo {
156                        name: cp.name.clone(),
157                        file_count: cp.file_contents.len(),
158                        created_at: cp.created_at,
159                    })
160                    .collect()
161            })
162            .unwrap_or_default()
163    }
164
165    /// Total checkpoint count across all sessions (for `/aft-status`).
166    pub fn total_count(&self) -> usize {
167        self.checkpoints.values().map(|s| s.len()).sum()
168    }
169
170    /// Remove checkpoints older than `ttl_hours` across all sessions.
171    /// Empty session entries are pruned after cleanup.
172    pub fn cleanup(&mut self, ttl_hours: u32) {
173        let now = current_timestamp();
174        let ttl_secs = ttl_hours as u64 * 3600;
175        self.checkpoints.retain(|_, session_cps| {
176            session_cps.retain(|_, cp| now.saturating_sub(cp.created_at) < ttl_secs);
177            !session_cps.is_empty()
178        });
179    }
180
181    fn get(&self, session: &str, name: &str) -> Result<&Checkpoint, AftError> {
182        self.checkpoints
183            .get(session)
184            .and_then(|s| s.get(name))
185            .ok_or_else(|| AftError::CheckpointNotFound {
186                name: name.to_string(),
187            })
188    }
189}
190
/// Current system time as whole seconds since the Unix epoch.
///
/// If the system clock reads earlier than 1970, this returns `0` instead of
/// failing.
fn current_timestamp() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
197
#[cfg(test)]
mod tests {
    use super::*;
    use crate::protocol::DEFAULT_SESSION_ID;
    use std::fs;

    /// Write `content` to `name` inside a per-process scratch directory and
    /// return the resulting path.
    ///
    /// The directory name includes the process id, so two test binaries running
    /// at the same time cannot clobber each other's fixtures. Within one
    /// process, every test uses its own distinct file names.
    fn temp_file(name: &str, content: &str) -> PathBuf {
        let dir = std::env::temp_dir()
            .join(format!("aft_checkpoint_tests_{}", std::process::id()));
        fs::create_dir_all(&dir).unwrap();
        let path = dir.join(name);
        fs::write(&path, content).unwrap();
        path
    }

    #[test]
    fn create_and_restore_round_trip() {
        let path1 = temp_file("cp_rt1.txt", "hello");
        let path2 = temp_file("cp_rt2.txt", "world");

        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();

        let info = store
            .create(
                DEFAULT_SESSION_ID,
                "snap1",
                vec![path1.clone(), path2.clone()],
                &backup_store,
            )
            .unwrap();
        assert_eq!(info.name, "snap1");
        assert_eq!(info.file_count, 2);

        // Modify files
        fs::write(&path1, "changed1").unwrap();
        fs::write(&path2, "changed2").unwrap();

        // Restore
        let info = store.restore(DEFAULT_SESSION_ID, "snap1").unwrap();
        assert_eq!(info.file_count, 2);
        assert_eq!(fs::read_to_string(&path1).unwrap(), "hello");
        assert_eq!(fs::read_to_string(&path2).unwrap(), "world");
    }

    #[test]
    fn overwrite_existing_name() {
        let path = temp_file("cp_overwrite.txt", "v1");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();

        store
            .create(DEFAULT_SESSION_ID, "dup", vec![path.clone()], &backup_store)
            .unwrap();
        fs::write(&path, "v2").unwrap();
        store
            .create(DEFAULT_SESSION_ID, "dup", vec![path.clone()], &backup_store)
            .unwrap();

        // Restore should give v2 (the overwritten checkpoint)
        fs::write(&path, "v3").unwrap();
        store.restore(DEFAULT_SESSION_ID, "dup").unwrap();
        assert_eq!(fs::read_to_string(&path).unwrap(), "v2");
    }

    #[test]
    fn list_returns_metadata_scoped_to_session() {
        let path = temp_file("cp_list.txt", "data");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();

        store
            .create(DEFAULT_SESSION_ID, "a", vec![path.clone()], &backup_store)
            .unwrap();
        store
            .create(DEFAULT_SESSION_ID, "b", vec![path.clone()], &backup_store)
            .unwrap();
        store
            .create("other_session", "c", vec![path.clone()], &backup_store)
            .unwrap();

        let default_list = store.list(DEFAULT_SESSION_ID);
        assert_eq!(default_list.len(), 2);
        let names: Vec<&str> = default_list.iter().map(|i| i.name.as_str()).collect();
        assert!(names.contains(&"a"));
        assert!(names.contains(&"b"));

        let other_list = store.list("other_session");
        assert_eq!(other_list.len(), 1);
        assert_eq!(other_list[0].name, "c");
    }

    #[test]
    fn sessions_isolate_checkpoint_names() {
        // Same checkpoint name in two sessions does not collide on restore.
        let path_a = temp_file("cp_isolated_a.txt", "a-original");
        let path_b = temp_file("cp_isolated_b.txt", "b-original");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();

        // Both sessions create a checkpoint with the same name but different files.
        store
            .create("session_a", "snap", vec![path_a.clone()], &backup_store)
            .unwrap();
        store
            .create("session_b", "snap", vec![path_b.clone()], &backup_store)
            .unwrap();

        fs::write(&path_a, "a-modified").unwrap();
        fs::write(&path_b, "b-modified").unwrap();

        // Restoring session A's "snap" only touches path_a.
        store.restore("session_a", "snap").unwrap();
        assert_eq!(fs::read_to_string(&path_a).unwrap(), "a-original");
        assert_eq!(fs::read_to_string(&path_b).unwrap(), "b-modified");

        // Restoring session B's "snap" only touches path_b.
        fs::write(&path_a, "a-modified").unwrap();
        store.restore("session_b", "snap").unwrap();
        assert_eq!(fs::read_to_string(&path_a).unwrap(), "a-modified");
        assert_eq!(fs::read_to_string(&path_b).unwrap(), "b-original");
    }

    #[test]
    fn cleanup_removes_expired_across_sessions() {
        let path = temp_file("cp_cleanup.txt", "data");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();

        store
            .create(
                DEFAULT_SESSION_ID,
                "recent",
                vec![path.clone()],
                &backup_store,
            )
            .unwrap();

        // Manually insert an expired checkpoint in another session.
        store
            .checkpoints
            .entry("other".to_string())
            .or_default()
            .insert(
                "old".to_string(),
                Checkpoint {
                    name: "old".to_string(),
                    file_contents: HashMap::new(),
                    created_at: 1000, // far in the past
                },
            );

        assert_eq!(store.total_count(), 2);
        store.cleanup(24); // 24 hours
        assert_eq!(store.total_count(), 1);
        assert_eq!(store.list(DEFAULT_SESSION_ID)[0].name, "recent");
        assert!(store.list("other").is_empty());
    }

    #[test]
    fn restore_nonexistent_returns_error() {
        let store = CheckpointStore::new();
        let result = store.restore(DEFAULT_SESSION_ID, "nope");
        assert!(result.is_err());
        match result.unwrap_err() {
            AftError::CheckpointNotFound { name } => {
                assert_eq!(name, "nope");
            }
            other => panic!("expected CheckpointNotFound, got: {:?}", other),
        }
    }

    #[test]
    fn restore_nonexistent_in_other_session_returns_error() {
        // A "snap" that exists in session A must NOT be visible from session B.
        let path = temp_file("cp_cross_session.txt", "data");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();
        store
            .create("session_a", "only_a", vec![path], &backup_store)
            .unwrap();
        match store.restore("session_b", "only_a") {
            Err(AftError::CheckpointNotFound { name }) => assert_eq!(name, "only_a"),
            other => panic!("expected CheckpointNotFound, got: {:?}", other),
        }
    }

    #[test]
    fn create_with_empty_files_uses_backup_tracked() {
        let path = temp_file("cp_tracked.txt", "tracked_content");
        let mut backup_store = BackupStore::new();
        backup_store
            .snapshot(DEFAULT_SESSION_ID, &path, "auto")
            .unwrap();

        let mut store = CheckpointStore::new();
        let info = store
            .create(DEFAULT_SESSION_ID, "from_tracked", vec![], &backup_store)
            .unwrap();
        assert!(info.file_count >= 1);

        // Modify and restore
        fs::write(&path, "modified").unwrap();
        store.restore(DEFAULT_SESSION_ID, "from_tracked").unwrap();
        assert_eq!(fs::read_to_string(&path).unwrap(), "tracked_content");
    }

    #[test]
    fn file_paths_lists_snapshotted_files() {
        let path1 = temp_file("cp_paths1.txt", "one");
        let path2 = temp_file("cp_paths2.txt", "two");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();
        store
            .create(
                DEFAULT_SESSION_ID,
                "paths",
                vec![path1.clone(), path2.clone()],
                &backup_store,
            )
            .unwrap();

        // Order is unspecified, so compare sorted lists.
        let mut paths = store.file_paths(DEFAULT_SESSION_ID, "paths").unwrap();
        paths.sort();
        let mut expected = vec![path1, path2];
        expected.sort();
        assert_eq!(paths, expected);

        // Not visible from another session.
        assert!(store.file_paths("other_session", "paths").is_err());
    }

    #[test]
    fn restore_validated_writes_only_requested_paths() {
        let path1 = temp_file("cp_validated1.txt", "first");
        let path2 = temp_file("cp_validated2.txt", "second");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();
        store
            .create(
                DEFAULT_SESSION_ID,
                "partial",
                vec![path1.clone(), path2.clone()],
                &backup_store,
            )
            .unwrap();

        fs::write(&path1, "changed1").unwrap();
        fs::write(&path2, "changed2").unwrap();

        let info = store
            .restore_validated(DEFAULT_SESSION_ID, "partial", &[path1.clone()])
            .unwrap();
        // file_count reports the whole checkpoint, not just the restored subset.
        assert_eq!(info.file_count, 2);
        assert_eq!(fs::read_to_string(&path1).unwrap(), "first");
        assert_eq!(fs::read_to_string(&path2).unwrap(), "changed2");
    }

    #[test]
    fn restore_validated_rejects_path_outside_checkpoint() {
        let stored = temp_file("cp_validated_stored.txt", "kept");
        let outside = temp_file("cp_validated_missing.txt", "untouched");
        let backup_store = BackupStore::new();
        let mut store = CheckpointStore::new();
        store
            .create(DEFAULT_SESSION_ID, "snap", vec![stored], &backup_store)
            .unwrap();

        let result = store.restore_validated(DEFAULT_SESSION_ID, "snap", &[outside.clone()]);
        assert!(matches!(result, Err(AftError::FileNotFound { .. })));
        assert_eq!(fs::read_to_string(&outside).unwrap(), "untouched");
    }
}