Skip to main content

cersei_tools/
file_snapshot.rs

1//! File snapshot system: stores before/after content per tool call for undo.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
/// Before/after contents of one file, captured around a single tool call.
///
/// `before` is the content that undo writes back to disk; `after` is the
/// content the tool produced.
#[derive(Clone, Debug)]
pub struct FileSnapshot {
    /// File that was modified.
    pub path: PathBuf,
    /// Content prior to the modification (what undo restores).
    pub before: String,
    /// Content the tool wrote.
    pub after: String,
    /// Name of the tool that performed the write (e.g. `"Edit"`).
    pub tool_name: String,
    /// Identifier of the tool call that produced this snapshot.
    pub tool_call_id: String,
    /// Creation time; `SnapshotManager::record` sets it to whole seconds since the Unix epoch.
    pub timestamp: u64,
}
16
17/// Session-level snapshot manager for undo support.
18#[derive(Debug, Clone, Default)]
19pub struct SnapshotManager {
20    /// All snapshots, ordered by time.
21    snapshots: Vec<FileSnapshot>,
22}
23
24impl SnapshotManager {
25    pub fn new() -> Self {
26        Self::default()
27    }
28
29    /// Record a file modification. Call BEFORE writing the new content.
30    pub fn record(
31        &mut self,
32        path: &Path,
33        before: &str,
34        after: &str,
35        tool_name: &str,
36        tool_call_id: &str,
37    ) {
38        self.snapshots.push(FileSnapshot {
39            path: path.to_path_buf(),
40            before: before.to_string(),
41            after: after.to_string(),
42            tool_name: tool_name.to_string(),
43            tool_call_id: tool_call_id.to_string(),
44            timestamp: now_secs(),
45        });
46    }
47
48    /// Undo the last modification to a file. Returns the restored content.
49    pub fn undo_last(&mut self, path: &Path) -> Option<String> {
50        // Find the most recent snapshot for this path
51        if let Some(idx) = self.snapshots.iter().rposition(|s| s.path == path) {
52            let snapshot = self.snapshots.remove(idx);
53            // Write the before content back to disk
54            if std::fs::write(&snapshot.path, &snapshot.before).is_ok() {
55                return Some(snapshot.before);
56            }
57        }
58        None
59    }
60
61    /// Undo all changes from a specific tool call ID.
62    pub fn undo_tool_call(&mut self, tool_call_id: &str) -> Vec<PathBuf> {
63        let mut reverted = Vec::new();
64        let matching: Vec<usize> = self
65            .snapshots
66            .iter()
67            .enumerate()
68            .filter(|(_, s)| s.tool_call_id == tool_call_id)
69            .map(|(i, _)| i)
70            .collect();
71
72        // Process in reverse order to handle multiple edits to same file
73        for idx in matching.into_iter().rev() {
74            let snapshot = &self.snapshots[idx];
75            if std::fs::write(&snapshot.path, &snapshot.before).is_ok() {
76                reverted.push(snapshot.path.clone());
77            }
78        }
79
80        // Remove the matching snapshots
81        self.snapshots.retain(|s| s.tool_call_id != tool_call_id);
82        reverted
83    }
84
85    /// Undo ALL changes in this session (nuclear option).
86    pub fn undo_all(&mut self) -> Vec<PathBuf> {
87        let mut reverted = Vec::new();
88
89        // Group by file, restore each to its ORIGINAL state (first snapshot's before)
90        let mut first_states: HashMap<PathBuf, String> = HashMap::new();
91        for snapshot in &self.snapshots {
92            first_states
93                .entry(snapshot.path.clone())
94                .or_insert_with(|| snapshot.before.clone());
95        }
96
97        for (path, original) in &first_states {
98            if std::fs::write(path, original).is_ok() {
99                reverted.push(path.clone());
100            }
101        }
102
103        self.snapshots.clear();
104        reverted
105    }
106
107    /// Get recent snapshots (newest first).
108    pub fn recent(&self, limit: usize) -> Vec<&FileSnapshot> {
109        self.snapshots.iter().rev().take(limit).collect()
110    }
111
112    /// Get all snapshots for a file.
113    pub fn for_file(&self, path: &Path) -> Vec<&FileSnapshot> {
114        self.snapshots.iter().filter(|s| s.path == path).collect()
115    }
116
117    /// Total number of snapshots.
118    pub fn count(&self) -> usize {
119        self.snapshots.len()
120    }
121
122    /// List of unique files that have been modified.
123    pub fn modified_files(&self) -> Vec<PathBuf> {
124        let mut files: Vec<PathBuf> = self
125            .snapshots
126            .iter()
127            .map(|s| s.path.clone())
128            .collect::<std::collections::HashSet<_>>()
129            .into_iter()
130            .collect();
131        files.sort();
132        files
133    }
134}
135
136/// Global snapshot registry keyed by session_id.
137static SNAPSHOT_REGISTRY: once_cell::sync::Lazy<
138    dashmap::DashMap<String, std::sync::Arc<parking_lot::Mutex<SnapshotManager>>>,
139> = once_cell::sync::Lazy::new(dashmap::DashMap::new);
140
141/// Get or create the snapshot manager for a session.
142pub fn session_snapshots(session_id: &str) -> std::sync::Arc<parking_lot::Mutex<SnapshotManager>> {
143    SNAPSHOT_REGISTRY
144        .entry(session_id.to_string())
145        .or_insert_with(|| std::sync::Arc::new(parking_lot::Mutex::new(SnapshotManager::new())))
146        .clone()
147}
148
/// Current wall-clock time as whole seconds since the Unix epoch.
///
/// Yields `0` if the system clock reports a time before the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
155
#[cfg(test)]
mod tests {
    // Unit tests for `SnapshotManager`; the unused `std::io::Write` import was dropped.
    use super::*;

    #[test]
    fn test_record_and_undo() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original").unwrap();

        let mut mgr = SnapshotManager::new();
        mgr.record(&file, "original", "modified", "Edit", "call-1");
        std::fs::write(&file, "modified").unwrap();

        assert_eq!(mgr.count(), 1);
        let restored = mgr.undo_last(&file);
        assert_eq!(restored, Some("original".to_string()));
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original");
        // A successful undo consumes the snapshot.
        assert_eq!(mgr.count(), 0);
    }

    #[test]
    fn test_undo_tool_call() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.txt");
        let f2 = dir.path().join("b.txt");
        std::fs::write(&f1, "a-orig").unwrap();
        std::fs::write(&f2, "b-orig").unwrap();

        let mut mgr = SnapshotManager::new();
        mgr.record(&f1, "a-orig", "a-new", "Edit", "call-1");
        mgr.record(&f2, "b-orig", "b-new", "Edit", "call-1");
        std::fs::write(&f1, "a-new").unwrap();
        std::fs::write(&f2, "b-new").unwrap();

        let reverted = mgr.undo_tool_call("call-1");
        assert_eq!(reverted.len(), 2);
        assert_eq!(std::fs::read_to_string(&f1).unwrap(), "a-orig");
        assert_eq!(std::fs::read_to_string(&f2).unwrap(), "b-orig");
        assert_eq!(mgr.count(), 0);
    }

    #[test]
    fn test_undo_all_restores_originals() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.txt");
        let f2 = dir.path().join("b.txt");

        let mut mgr = SnapshotManager::new();
        mgr.record(&f1, "a0", "a1", "Edit", "c1");
        mgr.record(&f2, "b0", "b1", "Write", "c2");
        mgr.record(&f1, "a1", "a2", "Edit", "c3");
        std::fs::write(&f1, "a2").unwrap();
        std::fs::write(&f2, "b1").unwrap();

        // Order of the returned paths is not part of this test's contract.
        let mut reverted = mgr.undo_all();
        reverted.sort();
        assert_eq!(reverted, vec![f1.clone(), f2.clone()]);
        // Each file goes back to its FIRST snapshot's `before`.
        assert_eq!(std::fs::read_to_string(&f1).unwrap(), "a0");
        assert_eq!(std::fs::read_to_string(&f2).unwrap(), "b0");
        assert_eq!(mgr.count(), 0);
    }

    #[test]
    fn test_modified_files() {
        let mut mgr = SnapshotManager::new();
        mgr.record(Path::new("/a.rs"), "x", "y", "Edit", "c1");
        mgr.record(Path::new("/b.rs"), "x", "y", "Write", "c2");
        mgr.record(Path::new("/a.rs"), "y", "z", "Edit", "c3");
        assert_eq!(
            mgr.modified_files(),
            vec![PathBuf::from("/a.rs"), PathBuf::from("/b.rs")]
        );
    }
}