1use std::collections::HashMap;
24use std::path::{Path, PathBuf};
25
/// One committed turn of undo history.
///
/// Holds the state every touched file had before the turn's first modification, so
/// undoing the entry puts all of those files back as they were.
#[derive(Clone, Debug)]
pub struct UndoEntry {
    /// Path → bytes captured by the first snapshot of that path in the turn.
    /// `None` means no content could be read at snapshot time (normally because the
    /// file did not exist yet); undoing such an entry deletes the file.
    files: HashMap<PathBuf, Option<Vec<u8>>>,
}
32
33#[derive(Debug, Default)]
35pub struct UndoStack {
36 entries: Vec<UndoEntry>,
37 pending: HashMap<PathBuf, Option<Vec<u8>>>,
39}
40
41impl UndoStack {
42 pub fn new() -> Self {
44 Self::default()
45 }
46
47 pub fn snapshot(&mut self, path: &Path) {
51 let abs = match std::fs::canonicalize(path) {
52 Ok(p) => p,
53 Err(_) => path.to_path_buf(), };
55
56 if self.pending.contains_key(&abs) {
58 return;
59 }
60
61 let content = std::fs::read(&abs).ok();
62 self.pending.insert(abs, content);
63 }
64
65 pub fn commit_turn(&mut self) {
70 if self.pending.is_empty() {
71 return;
72 }
73 self.entries.push(UndoEntry {
74 files: std::mem::take(&mut self.pending),
75 });
76 }
77
78 pub fn undo(&mut self) -> Option<String> {
82 let entry = self.entries.pop()?;
83 let mut restored = Vec::new();
84
85 for (path, original) in &entry.files {
86 match original {
87 Some(content) => {
88 if let Err(e) = std::fs::write(path, content) {
89 restored.push(format!(" ❌ {} (write failed: {e})", path.display()));
90 } else {
91 restored.push(format!(" ↩ {} (restored)", path.display()));
92 }
93 }
94 None => {
95 if let Err(e) = std::fs::remove_file(path) {
97 restored.push(format!(" ❌ {} (delete failed: {e})", path.display()));
98 } else {
99 restored.push(format!(
100 " ↩ {} (removed — was newly created)",
101 path.display()
102 ));
103 }
104 }
105 }
106 }
107
108 restored.sort();
109 Some(format!(
110 "Undid {} file(s) from last turn:\n{}",
111 entry.files.len(),
112 restored.join("\n")
113 ))
114 }
115
116 pub fn depth(&self) -> usize {
118 self.entries.len()
119 }
120}
121
/// Returns `true` if `name` is one of the file-mutating tools: `Write`, `Edit`,
/// `Delete` or `Overwrite`. The comparison is exact and case-sensitive.
pub fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 4] = ["Write", "Edit", "Delete", "Overwrite"];
    MUTATING_TOOLS.contains(&name)
}
138
139pub fn extract_file_path(name: &str, args: &serde_json::Value) -> Option<String> {
155 match name {
156 "Write" | "Edit" | "Delete" | "Overwrite" => args
157 .get("file_path")
158 .or_else(|| args.get("path"))
159 .and_then(|v| v.as_str())
160 .map(|s| s.to_string()),
161 _ => None,
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// A fresh, empty undo stack plus a scratch directory removed on drop.
    fn scratch() -> (UndoStack, TempDir) {
        let stack = UndoStack::new();
        let dir = TempDir::new().unwrap();
        (stack, dir)
    }

    /// Reads a file as UTF-8, panicking on any I/O error.
    fn contents(file: &std::path::Path) -> String {
        std::fs::read_to_string(file).unwrap()
    }

    #[test]
    fn test_undo_restores_overwritten_file() {
        let (mut stack, dir) = scratch();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "original").unwrap();

        stack.snapshot(&file);
        std::fs::write(&file, "modified").unwrap();
        stack.commit_turn();

        assert!(stack.undo().is_some());
        assert_eq!(contents(&file), "original");
    }

    #[test]
    fn test_undo_removes_newly_created_file() {
        let (mut stack, dir) = scratch();
        let file = dir.path().join("new.txt");

        stack.snapshot(&file);
        std::fs::write(&file, "created").unwrap();
        stack.commit_turn();

        stack.undo();
        assert!(!file.exists());
    }

    #[test]
    fn test_undo_empty_stack() {
        assert_eq!(UndoStack::new().undo(), None);
    }

    #[test]
    fn test_multiple_files_per_turn() {
        let (mut stack, dir) = scratch();
        let files = [
            (dir.path().join("a.txt"), "aaa", "AAA"),
            (dir.path().join("b.txt"), "bbb", "BBB"),
        ];

        for (file, before, _) in &files {
            std::fs::write(file, before).unwrap();
        }
        for (file, _, _) in &files {
            stack.snapshot(file);
        }
        for (file, _, after) in &files {
            std::fs::write(file, after).unwrap();
        }
        stack.commit_turn();

        stack.undo();
        for (file, before, _) in &files {
            assert_eq!(contents(file), *before);
        }
    }

    #[test]
    fn test_only_first_snapshot_per_file_per_turn() {
        let (mut stack, dir) = scratch();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "v1").unwrap();

        // Two snapshot+write rounds in the same turn; only the first snapshot counts.
        for next in ["v2", "v3"] {
            stack.snapshot(&file);
            std::fs::write(&file, next).unwrap();
        }
        stack.commit_turn();

        stack.undo();
        assert_eq!(contents(&file), "v1");
    }

    #[test]
    fn test_multi_turn_undo() {
        let (mut stack, dir) = scratch();
        let file = dir.path().join("test.txt");
        std::fs::write(&file, "v1").unwrap();

        // One turn per write.
        for next in ["v2", "v3"] {
            stack.snapshot(&file);
            std::fs::write(&file, next).unwrap();
            stack.commit_turn();
        }
        assert_eq!(stack.depth(), 2);

        // Each undo steps back exactly one turn.
        for expected in ["v2", "v1"] {
            stack.undo();
            assert_eq!(contents(&file), expected);
        }
    }

    #[test]
    fn test_is_mutating_tool() {
        for tool in ["Write", "Edit", "Delete"] {
            assert!(is_mutating_tool(tool), "{tool} should be mutating");
        }
        for tool in ["Read", "Grep", "Bash"] {
            assert!(!is_mutating_tool(tool), "{tool} should not be mutating");
        }
    }

    #[test]
    fn test_extract_file_path() {
        let args = serde_json::json!({ "file_path": "src/main.rs" });
        assert_eq!(
            extract_file_path("Write", &args).as_deref(),
            Some("src/main.rs")
        );
        assert_eq!(extract_file_path("Read", &args), None);
    }

    #[test]
    fn test_no_commit_if_no_snapshots() {
        let mut stack = UndoStack::new();
        stack.commit_turn();
        assert_eq!(stack.depth(), 0);
    }
}