// koda_core/undo.rs

//! Undo stack for file mutations.
//!
//! Snapshots file contents before Write/Edit/Delete/Overwrite tool execution.
//! Each turn's mutations are grouped into a single undo entry.
//!
//! ## How it works
//!
//! 1. Before any file mutation, the current file contents are snapshotted
//!    (only the first snapshot of a given file within a turn is kept, so undo
//!    restores the state from before the turn began)
//! 2. All mutations in a single turn are grouped into one undo entry
//! 3. `/undo` restores all files from the most recent entry
//! 4. Stack depth is unlimited — undo as many turns as needed
//!
//! ## What gets tracked
//!
//! - **Write** / **Overwrite**: snapshots the file if it existed, or marks it as "created"
//! - **Edit**: snapshots the file before the edit
//! - **Delete**: snapshots the file contents before deletion
//!
//! Git checkpointing provides a separate safety net via `git stash`-style
//! snapshots before each turn. `/undo` is faster (in-memory) but git
//! checkpoints survive process crashes.

use std::collections::HashMap;
use std::io::ErrorKind;
use std::path::{Path, PathBuf};

/// A snapshot of file states before a turn's mutations.
#[derive(Debug, Clone)]
pub struct UndoEntry {
    /// Map of normalized absolute path → previous content (`None` = file didn't exist).
    files: HashMap<PathBuf, Option<Vec<u8>>>,
}

/// Stack of undo entries, one per turn (the last entry is the most recent turn).
#[derive(Debug, Default)]
pub struct UndoStack {
    entries: Vec<UndoEntry>,
    /// Accumulates snapshots for the current (in-progress) turn.
    pending: HashMap<PathBuf, Option<Vec<u8>>>,
}

impl UndoStack {
    /// Create an empty undo stack.
    pub fn new() -> Self {
        Self::default()
    }

    /// Snapshot a file before mutation. Call before Write/Edit/Delete.
    ///
    /// Only snapshots the first time per file per turn (preserves original state).
    ///
    /// The path is normalized to an absolute, canonical form even when the file
    /// does not exist yet. A file snapshotted before creation and again after
    /// creation in the same turn therefore maps to a single key. Without this,
    /// undo would both delete and rewrite the file in arbitrary order.
    ///
    /// If the file exists but cannot be read (e.g. permission denied), it is not
    /// tracked. Recording it as "didn't exist" would make `/undo` delete it.
    pub fn snapshot(&mut self, path: &Path) {
        let key = normalize_path(path);

        // Only snapshot the first mutation per file per turn
        if self.pending.contains_key(&key) {
            return;
        }

        let previous = match std::fs::read(&key) {
            Ok(bytes) => Some(bytes),
            Err(e) if e.kind() == ErrorKind::NotFound => None,
            // Existing-but-unreadable: skip rather than risk deleting it on undo.
            Err(_) => return,
        };
        self.pending.insert(key, previous);
    }

    /// Finalize the current turn's snapshots into an undo entry.
    ///
    /// Call at the end of each inference turn (after all tool calls complete).
    /// Does nothing if no mutations were snapshotted.
    pub fn commit_turn(&mut self) {
        if self.pending.is_empty() {
            return;
        }
        self.entries.push(UndoEntry {
            files: std::mem::take(&mut self.pending),
        });
    }

    /// Undo the last turn's file mutations.
    ///
    /// Each snapshotted file is either rewritten with its previous contents or,
    /// if it did not exist before the turn, removed. Per-file failures are
    /// reported in the summary rather than aborting the undo.
    ///
    /// Returns a summary of what was restored, or `None` if there is nothing to undo.
    pub fn undo(&mut self) -> Option<String> {
        let entry = self.entries.pop()?;

        let mut restored: Vec<String> = entry
            .files
            .iter()
            .map(|(path, original)| restore_file(path, original.as_deref()))
            .collect();
        // HashMap order is arbitrary; sort for a stable, readable summary.
        restored.sort();

        Some(format!(
            "Undid {} file(s) from last turn:\n{}",
            entry.files.len(),
            restored.join("\n")
        ))
    }

    /// How many turns can be undone.
    pub fn depth(&self) -> usize {
        self.entries.len()
    }
}

/// Restore one file to its pre-turn state and describe the outcome as a summary line.
fn restore_file(path: &Path, original: Option<&[u8]>) -> String {
    match original {
        Some(content) => {
            // The turn may also have removed the parent directory; recreate it so the
            // write has somewhere to land. Any real failure is reported by the write.
            if let Some(parent) = path.parent() {
                let _ = std::fs::create_dir_all(parent);
            }
            match std::fs::write(path, content) {
                Ok(()) => format!("  ↩ {} (restored)", path.display()),
                Err(e) => format!("  ❌ {} (write failed: {e})", path.display()),
            }
        }
        // File didn't exist before — delete it
        None => match std::fs::remove_file(path) {
            Ok(()) => format!("  ↩ {} (removed — was newly created)", path.display()),
            // Already gone (e.g. the tool failed before creating it): that *is* the
            // pre-turn state, so this is not an error.
            Err(e) if e.kind() == ErrorKind::NotFound => {
                format!("  ↩ {} (already absent)", path.display())
            }
            Err(e) => format!("  ❌ {} (delete failed: {e})", path.display()),
        },
    }
}

/// Produce a stable map key for `path`, whether or not the file exists yet.
///
/// Relative paths are resolved against the current directory. Then the deepest
/// existing ancestor is canonicalized (resolving symlinks and `..`) and the
/// not-yet-existing tail is re-appended. The result equals what
/// `fs::canonicalize` returns once the file has been created. If no ancestor can
/// be canonicalized, the absolute (non-canonical) path is used.
fn normalize_path(path: &Path) -> PathBuf {
    let absolute = if path.is_absolute() {
        path.to_path_buf()
    } else {
        match std::env::current_dir() {
            Ok(cwd) => cwd.join(path),
            Err(_) => path.to_path_buf(),
        }
    };
    canonicalize_existing_prefix(&absolute).unwrap_or(absolute)
}

/// Canonicalize the longest existing prefix of `path` and re-join the missing tail.
///
/// Returns `None` if no prefix canonicalizes, or if a component cannot be split
/// off (e.g. a trailing `..`).
fn canonicalize_existing_prefix(path: &Path) -> Option<PathBuf> {
    let mut missing = Vec::new();
    let mut cursor = path;
    loop {
        if let Ok(canonical) = std::fs::canonicalize(cursor) {
            // `missing` was collected leaf-first, so re-append in reverse.
            return Some(
                missing
                    .iter()
                    .rev()
                    .fold(canonical, |acc: PathBuf, part| acc.join(part)),
            );
        }
        missing.push(cursor.file_name()?);
        cursor = cursor.parent()?;
    }
}

/// Report whether `name` is a file-mutating tool whose target must be snapshotted.
///
/// The recognized tools are `Write`, `Edit`, `Delete` and `Overwrite`. Matching is
/// exact and case-sensitive.
///
/// # Examples
///
/// ```
/// use koda_core::undo::is_mutating_tool;
///
/// assert!(is_mutating_tool("Write"));
/// assert!(is_mutating_tool("Edit"));
/// assert!(is_mutating_tool("Delete"));
/// assert!(is_mutating_tool("Overwrite"));
/// assert!(!is_mutating_tool("Read"));
/// assert!(!is_mutating_tool("Grep"));
/// ```
pub fn is_mutating_tool(name: &str) -> bool {
    const MUTATING_TOOLS: [&str; 4] = ["Write", "Edit", "Delete", "Overwrite"];
    MUTATING_TOOLS.contains(&name)
}

139/// Extract the target file path from tool arguments.
140///
141/// # Examples
142///
143/// ```
144/// use koda_core::undo::extract_file_path;
145///
146/// let args = serde_json::json!({"file_path": "src/main.rs"});
147/// assert_eq!(extract_file_path("Write", &args), Some("src/main.rs".into()));
148/// assert_eq!(extract_file_path("Read", &args), None);
149///
150/// // Also accepts "path" as an alias:
151/// let args = serde_json::json!({"path": "lib.rs"});
152/// assert_eq!(extract_file_path("Edit", &args), Some("lib.rs".into()));
153/// ```
154pub fn extract_file_path(name: &str, args: &serde_json::Value) -> Option<String> {
155    match name {
156        "Write" | "Edit" | "Delete" | "Overwrite" => args
157            .get("file_path")
158            .or_else(|| args.get("path"))
159            .and_then(|v| v.as_str())
160            .map(|s| s.to_string()),
161        _ => None,
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    fn setup() -> (UndoStack, TempDir) {
        (UndoStack::new(), TempDir::new().unwrap())
    }

    #[test]
    fn test_undo_restores_overwritten_file() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "original").unwrap();

        // Snapshot before mutation
        stack.snapshot(&path);
        std::fs::write(&path, "modified").unwrap();
        stack.commit_turn();

        // Undo
        let result = stack.undo();
        assert!(result.is_some());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "original");
    }

    #[test]
    fn test_undo_removes_newly_created_file() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("new.txt");

        // Snapshot before creation (file doesn't exist)
        stack.snapshot(&path);
        std::fs::write(&path, "created").unwrap();
        stack.commit_turn();

        // Undo
        stack.undo();
        assert!(!path.exists());
    }

    #[test]
    fn test_undo_empty_stack() {
        let mut stack = UndoStack::new();
        assert!(stack.undo().is_none());
    }

    #[test]
    fn test_multiple_files_per_turn() {
        let (mut stack, tmp) = setup();
        let a = tmp.path().join("a.txt");
        let b = tmp.path().join("b.txt");
        std::fs::write(&a, "aaa").unwrap();
        std::fs::write(&b, "bbb").unwrap();

        stack.snapshot(&a);
        stack.snapshot(&b);
        std::fs::write(&a, "AAA").unwrap();
        std::fs::write(&b, "BBB").unwrap();
        stack.commit_turn();

        stack.undo();
        assert_eq!(std::fs::read_to_string(&a).unwrap(), "aaa");
        assert_eq!(std::fs::read_to_string(&b).unwrap(), "bbb");
    }

    #[test]
    fn test_only_first_snapshot_per_file_per_turn() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();

        stack.snapshot(&path); // Captures "v1"
        std::fs::write(&path, "v2").unwrap();
        stack.snapshot(&path); // Should NOT overwrite snapshot
        std::fs::write(&path, "v3").unwrap();
        stack.commit_turn();

        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v1");
    }

    #[test]
    fn test_multi_turn_undo() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();

        // Turn 1
        stack.snapshot(&path);
        std::fs::write(&path, "v2").unwrap();
        stack.commit_turn();

        // Turn 2
        stack.snapshot(&path);
        std::fs::write(&path, "v3").unwrap();
        stack.commit_turn();

        assert_eq!(stack.depth(), 2);

        // Undo turn 2
        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v2");

        // Undo turn 1
        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v1");
    }

    #[test]
    fn test_undo_pops_one_turn_and_reports_count() {
        let (mut stack, tmp) = setup();
        let existing = tmp.path().join("a.txt");
        let created = tmp.path().join("b.txt");
        std::fs::write(&existing, "aaa").unwrap();

        stack.snapshot(&existing);
        stack.snapshot(&created); // does not exist yet
        std::fs::write(&existing, "AAA").unwrap();
        std::fs::write(&created, "BBB").unwrap();
        stack.commit_turn();
        assert_eq!(stack.depth(), 1);

        let summary = stack.undo().expect("one turn to undo");
        assert!(summary.starts_with("Undid 2 file(s) from last turn:"));
        assert_eq!(stack.depth(), 0);
        assert_eq!(std::fs::read_to_string(&existing).unwrap(), "aaa");
        assert!(!created.exists());
    }

    #[test]
    fn test_commit_turn_drains_pending() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();

        stack.snapshot(&path);
        stack.commit_turn();
        // A second commit with nothing new pending must not add an empty entry.
        stack.commit_turn();
        assert_eq!(stack.depth(), 1);
    }

    #[test]
    fn test_is_mutating_tool() {
        assert!(is_mutating_tool("Write"));
        assert!(is_mutating_tool("Edit"));
        assert!(is_mutating_tool("Delete"));
        assert!(is_mutating_tool("Overwrite"));
        assert!(!is_mutating_tool("Read"));
        assert!(!is_mutating_tool("Grep"));
        assert!(!is_mutating_tool("Bash"));
        // Matching is case-sensitive.
        assert!(!is_mutating_tool("write"));
        assert!(!is_mutating_tool(""));
    }

    #[test]
    fn test_extract_file_path() {
        let args = serde_json::json!({"file_path": "src/main.rs"});
        assert_eq!(
            extract_file_path("Write", &args),
            Some("src/main.rs".into())
        );
        assert_eq!(
            extract_file_path("Overwrite", &args),
            Some("src/main.rs".into())
        );
        assert_eq!(extract_file_path("Read", &args), None);

        // `path` is accepted as an alias; `file_path` wins when both are strings.
        let alias = serde_json::json!({"path": "lib.rs"});
        assert_eq!(extract_file_path("Edit", &alias), Some("lib.rs".into()));
        let both = serde_json::json!({"file_path": "a.rs", "path": "b.rs"});
        assert_eq!(extract_file_path("Delete", &both), Some("a.rs".into()));

        // Missing or non-string paths yield None.
        assert_eq!(extract_file_path("Write", &serde_json::json!({})), None);
        assert_eq!(
            extract_file_path("Write", &serde_json::json!({"file_path": 42})),
            None
        );
    }

    #[test]
    fn test_no_commit_if_no_snapshots() {
        let mut stack = UndoStack::new();
        stack.commit_turn(); // Nothing pending
        assert_eq!(stack.depth(), 0);
    }
}