// cersei_tools/file_snapshot.rs
use std::collections::HashMap;
use std::path::{Path, PathBuf};

/// One recorded edit of a single file made by a tool call.
///
/// Undo operations write `before` back to `path` verbatim, so `before` is expected to be the
/// complete file content prior to the edit.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FileSnapshot {
    /// File that was edited.
    pub path: PathBuf,
    /// Content of the file before the edit; restored by undo.
    pub before: String,
    /// Content of the file after the edit.
    pub after: String,
    /// Name of the tool that made the edit (e.g. `"Edit"`, `"Write"`).
    pub tool_name: String,
    /// Identifier of the tool call that made the edit; used to undo a whole call at once.
    pub tool_call_id: String,
    /// Recording time; `SnapshotManager::record` sets this to seconds since the Unix epoch.
    pub timestamp: u64,
}

17#[derive(Debug, Clone, Default)]
19pub struct SnapshotManager {
20 snapshots: Vec<FileSnapshot>,
22}
23
24impl SnapshotManager {
25 pub fn new() -> Self {
26 Self::default()
27 }
28
29 pub fn record(
31 &mut self,
32 path: &Path,
33 before: &str,
34 after: &str,
35 tool_name: &str,
36 tool_call_id: &str,
37 ) {
38 self.snapshots.push(FileSnapshot {
39 path: path.to_path_buf(),
40 before: before.to_string(),
41 after: after.to_string(),
42 tool_name: tool_name.to_string(),
43 tool_call_id: tool_call_id.to_string(),
44 timestamp: now_secs(),
45 });
46 }
47
48 pub fn undo_last(&mut self, path: &Path) -> Option<String> {
50 if let Some(idx) = self.snapshots.iter().rposition(|s| s.path == path) {
52 let snapshot = self.snapshots.remove(idx);
53 if std::fs::write(&snapshot.path, &snapshot.before).is_ok() {
55 return Some(snapshot.before);
56 }
57 }
58 None
59 }
60
61 pub fn undo_tool_call(&mut self, tool_call_id: &str) -> Vec<PathBuf> {
63 let mut reverted = Vec::new();
64 let matching: Vec<usize> = self
65 .snapshots
66 .iter()
67 .enumerate()
68 .filter(|(_, s)| s.tool_call_id == tool_call_id)
69 .map(|(i, _)| i)
70 .collect();
71
72 for idx in matching.into_iter().rev() {
74 let snapshot = &self.snapshots[idx];
75 if std::fs::write(&snapshot.path, &snapshot.before).is_ok() {
76 reverted.push(snapshot.path.clone());
77 }
78 }
79
80 self.snapshots.retain(|s| s.tool_call_id != tool_call_id);
82 reverted
83 }
84
85 pub fn undo_all(&mut self) -> Vec<PathBuf> {
87 let mut reverted = Vec::new();
88
89 let mut first_states: HashMap<PathBuf, String> = HashMap::new();
91 for snapshot in &self.snapshots {
92 first_states
93 .entry(snapshot.path.clone())
94 .or_insert_with(|| snapshot.before.clone());
95 }
96
97 for (path, original) in &first_states {
98 if std::fs::write(path, original).is_ok() {
99 reverted.push(path.clone());
100 }
101 }
102
103 self.snapshots.clear();
104 reverted
105 }
106
107 pub fn recent(&self, limit: usize) -> Vec<&FileSnapshot> {
109 self.snapshots.iter().rev().take(limit).collect()
110 }
111
112 pub fn for_file(&self, path: &Path) -> Vec<&FileSnapshot> {
114 self.snapshots.iter().filter(|s| s.path == path).collect()
115 }
116
117 pub fn count(&self) -> usize {
119 self.snapshots.len()
120 }
121
122 pub fn modified_files(&self) -> Vec<PathBuf> {
124 let mut files: Vec<PathBuf> = self
125 .snapshots
126 .iter()
127 .map(|s| s.path.clone())
128 .collect::<std::collections::HashSet<_>>()
129 .into_iter()
130 .collect();
131 files.sort();
132 files
133 }
134}
135
136static SNAPSHOT_REGISTRY: once_cell::sync::Lazy<
138 dashmap::DashMap<String, std::sync::Arc<parking_lot::Mutex<SnapshotManager>>>,
139> = once_cell::sync::Lazy::new(dashmap::DashMap::new);
140
141pub fn session_snapshots(session_id: &str) -> std::sync::Arc<parking_lot::Mutex<SnapshotManager>> {
143 SNAPSHOT_REGISTRY
144 .entry(session_id.to_string())
145 .or_insert_with(|| std::sync::Arc::new(parking_lot::Mutex::new(SnapshotManager::new())))
146 .clone()
147}
148
/// Seconds elapsed since the Unix epoch, or `0` if the system clock reads earlier than the epoch.
fn now_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};

    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_record_and_undo() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original").unwrap();

        let mut mgr = SnapshotManager::new();
        mgr.record(&file, "original", "modified", "Edit", "call-1");
        std::fs::write(&file, "modified").unwrap();

        assert_eq!(mgr.count(), 1);
        let restored = mgr.undo_last(&file);
        assert_eq!(restored, Some("original".to_string()));
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "original");
    }

    #[test]
    fn test_undo_last_without_snapshots() {
        let mut mgr = SnapshotManager::new();
        assert_eq!(mgr.undo_last(Path::new("/never-recorded.txt")), None);
        assert_eq!(mgr.count(), 0);
    }

    #[test]
    fn test_undo_tool_call() {
        let dir = tempfile::tempdir().unwrap();
        let f1 = dir.path().join("a.txt");
        let f2 = dir.path().join("b.txt");
        std::fs::write(&f1, "a-orig").unwrap();
        std::fs::write(&f2, "b-orig").unwrap();

        let mut mgr = SnapshotManager::new();
        mgr.record(&f1, "a-orig", "a-new", "Edit", "call-1");
        mgr.record(&f2, "b-orig", "b-new", "Edit", "call-1");
        std::fs::write(&f1, "a-new").unwrap();
        std::fs::write(&f2, "b-new").unwrap();

        let reverted = mgr.undo_tool_call("call-1");
        assert_eq!(reverted.len(), 2);
        assert_eq!(std::fs::read_to_string(&f1).unwrap(), "a-orig");
        assert_eq!(std::fs::read_to_string(&f2).unwrap(), "b-orig");
    }

    #[test]
    fn test_undo_all_restores_first_state() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("c.txt");
        std::fs::write(&file, "v2").unwrap();

        let mut mgr = SnapshotManager::new();
        mgr.record(&file, "v0", "v1", "Edit", "call-1");
        mgr.record(&file, "v1", "v2", "Edit", "call-2");

        let reverted = mgr.undo_all();
        assert_eq!(reverted, vec![file.clone()]);
        assert_eq!(std::fs::read_to_string(&file).unwrap(), "v0");
        assert_eq!(mgr.count(), 0);
    }

    #[test]
    fn test_recent_and_for_file() {
        let mut mgr = SnapshotManager::new();
        mgr.record(Path::new("/a.rs"), "0", "1", "Edit", "c1");
        mgr.record(Path::new("/b.rs"), "0", "1", "Edit", "c2");
        mgr.record(Path::new("/a.rs"), "1", "2", "Edit", "c3");

        let recent = mgr.recent(2);
        let ids: Vec<&str> = recent.iter().map(|s| s.tool_call_id.as_str()).collect();
        assert_eq!(ids, ["c3", "c2"]);
        assert_eq!(mgr.for_file(Path::new("/a.rs")).len(), 2);
    }

    #[test]
    fn test_modified_files() {
        let mut mgr = SnapshotManager::new();
        mgr.record(Path::new("/a.rs"), "x", "y", "Edit", "c1");
        mgr.record(Path::new("/b.rs"), "x", "y", "Write", "c2");
        mgr.record(Path::new("/a.rs"), "y", "z", "Edit", "c3");
        assert_eq!(
            mgr.modified_files(),
            vec![PathBuf::from("/a.rs"), PathBuf::from("/b.rs")]
        );
    }
}