use std::collections::HashMap;
use std::path::{Path, PathBuf};
/// The snapshots captured for one committed turn.
///
/// Each key is a file path; the value is the file's bytes as they were when
/// the snapshot was taken, or `None` when the file did not exist at that time
/// (in which case undoing the turn deletes the file).
#[derive(Debug, Clone)]
pub struct UndoEntry {
    files: HashMap<PathBuf, Option<Vec<u8>>>,
}

/// A stack of per-turn file snapshots that can be rolled back one turn at a time.
///
/// Call [`UndoStack::snapshot`] before a file is modified, then
/// [`UndoStack::commit_turn`] to group everything snapshotted so far into one
/// entry. [`UndoStack::undo`] pops the most recent entry and writes the saved
/// state back to disk.
#[derive(Debug, Default)]
pub struct UndoStack {
    /// Committed turns; the last element is the most recent one.
    entries: Vec<UndoEntry>,
    /// Snapshots taken since the last `commit_turn`, keyed by resolved path.
    pending: HashMap<PathBuf, Option<Vec<u8>>>,
}

impl UndoStack {
    /// Creates an empty stack with no committed turns and no pending snapshots.
    pub fn new() -> Self {
        Self::default()
    }

    /// Resolves `path` to a stable key for the snapshot maps.
    ///
    /// `canonicalize` fails for a file that does not exist yet, so in that case
    /// the parent directory is canonicalized and the file name re-attached.
    /// This yields the same key the file will canonicalize to once it has been
    /// created, which keeps a later snapshot of the same file in the same turn
    /// from being recorded as a second, conflicting entry.
    fn resolve(path: &Path) -> PathBuf {
        if let Ok(canonical) = std::fs::canonicalize(path) {
            return canonical;
        }
        if let Some(file_name) = path.file_name() {
            // A bare file name has an empty parent; treat it as the current directory.
            let parent = match path.parent() {
                Some(dir) if !dir.as_os_str().is_empty() => dir,
                _ => Path::new("."),
            };
            if let Ok(dir) = std::fs::canonicalize(parent) {
                return dir.join(file_name);
            }
        }
        // The parent is missing too: at least anchor a relative path so the key
        // does not change meaning if the working directory changes before undo.
        if path.is_relative() {
            if let Ok(cwd) = std::env::current_dir() {
                return cwd.join(path);
            }
        }
        path.to_path_buf()
    }

    /// Records the current on-disk state of `path` for the turn in progress.
    ///
    /// Only the first snapshot of a given file per turn is kept, so the state
    /// restored by `undo` is the one from before the turn's first modification.
    /// A file that does not exist is recorded as `None`. If the file exists but
    /// cannot be read (permission error, it is a directory, ...), nothing is
    /// recorded, so that `undo` never deletes something it could not capture.
    pub fn snapshot(&mut self, path: &Path) {
        let abs = Self::resolve(path);
        let slot = match self.pending.entry(abs) {
            std::collections::hash_map::Entry::Vacant(slot) => slot,
            // Already snapshotted this turn; keep the earliest state.
            std::collections::hash_map::Entry::Occupied(_) => return,
        };
        let content = match std::fs::read(slot.key()) {
            Ok(bytes) => Some(bytes),
            Err(e) if e.kind() == std::io::ErrorKind::NotFound => None,
            // Unknown prior state: recording `None` here would make undo delete the file.
            Err(_) => return,
        };
        slot.insert(content);
    }

    /// Moves all pending snapshots into a new entry on top of the stack.
    ///
    /// Does nothing when no snapshot was taken since the previous commit, so
    /// turns that touched no files do not add empty entries.
    pub fn commit_turn(&mut self) {
        if self.pending.is_empty() {
            return;
        }
        self.entries.push(UndoEntry {
            files: std::mem::take(&mut self.pending),
        });
    }

    /// Pops the most recent committed turn and restores every file in it.
    ///
    /// Files with saved content are rewritten with that content; files recorded
    /// as previously non-existent are deleted. A file of the latter kind that
    /// is already gone counts as restored rather than as a failure.
    ///
    /// Returns `None` when there is nothing to undo. Otherwise returns a
    /// summary: a header line followed by one sorted status line per file.
    /// Per-file I/O errors are reported in the summary and do not stop the
    /// remaining files from being processed.
    pub fn undo(&mut self) -> Option<String> {
        let entry = self.entries.pop()?;
        let mut restored: Vec<String> = entry
            .files
            .iter()
            .map(|(path, original)| match original {
                Some(content) => match std::fs::write(path, content) {
                    Ok(()) => format!(" ↩ {} (restored)", path.display()),
                    Err(e) => format!(" ❌ {} (write failed: {e})", path.display()),
                },
                None => match std::fs::remove_file(path) {
                    Ok(()) => format!(
                        " ↩ {} (removed — was newly created)",
                        path.display()
                    ),
                    // The pre-turn state (no file) already holds.
                    Err(e) if e.kind() == std::io::ErrorKind::NotFound => format!(
                        " ↩ {} (already absent — was newly created)",
                        path.display()
                    ),
                    Err(e) => format!(" ❌ {} (delete failed: {e})", path.display()),
                },
            })
            .collect();
        // HashMap iteration order is arbitrary; sort for stable output.
        restored.sort();
        Some(format!(
            "Undid {} file(s) from last turn:\n{}",
            entry.files.len(),
            restored.join("\n")
        ))
    }

    /// Returns the number of committed turns that can still be undone.
    pub fn depth(&self) -> usize {
        self.entries.len()
    }
}
/// Returns `true` when `name` is exactly one of `"Write"`, `"Edit"`,
/// `"Delete"` or `"Overwrite"`. The comparison is case-sensitive.
pub fn is_mutating_tool(name: &str) -> bool {
    ["Write", "Edit", "Delete", "Overwrite"].contains(&name)
}
/// Pulls the target file path out of a tool call's JSON arguments.
///
/// For the tool names `"Write"`, `"Edit"`, `"Delete"` and `"Overwrite"`, looks
/// up the `"file_path"` key, or the `"path"` key when `"file_path"` is absent,
/// and returns its value as an owned `String`.
///
/// Returns `None` for any other tool name, when neither key is present, or
/// when the selected value is not a JSON string. Note that a non-string
/// `"file_path"` value does not cause a fallback to `"path"`.
pub fn extract_file_path(name: &str, args: &serde_json::Value) -> Option<String> {
    if !matches!(name, "Write" | "Edit" | "Delete" | "Overwrite") {
        return None;
    }
    let raw = args.get("file_path").or_else(|| args.get("path"))?;
    raw.as_str().map(str::to_owned)
}
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    /// Builds a fresh stack together with a temporary directory that is
    /// cleaned up when the returned `TempDir` is dropped.
    fn setup() -> (UndoStack, TempDir) {
        (UndoStack::new(), TempDir::new().unwrap())
    }

    /// Undo puts back the content a file had when it was snapshotted.
    #[test]
    fn test_undo_restores_overwritten_file() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "original").unwrap();
        stack.snapshot(&path);
        std::fs::write(&path, "modified").unwrap();
        stack.commit_turn();
        let result = stack.undo();
        assert!(result.is_some());
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "original");
    }

    /// A file that did not exist at snapshot time is deleted by undo.
    #[test]
    fn test_undo_removes_newly_created_file() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("new.txt");
        stack.snapshot(&path);
        std::fs::write(&path, "created").unwrap();
        stack.commit_turn();
        stack.undo();
        assert!(!path.exists());
    }

    /// Undo on a stack with no committed turns yields `None`.
    #[test]
    fn test_undo_empty_stack() {
        let mut stack = UndoStack::new();
        assert!(stack.undo().is_none());
    }

    /// All files snapshotted in one turn are restored by a single undo.
    #[test]
    fn test_multiple_files_per_turn() {
        let (mut stack, tmp) = setup();
        let a = tmp.path().join("a.txt");
        let b = tmp.path().join("b.txt");
        std::fs::write(&a, "aaa").unwrap();
        std::fs::write(&b, "bbb").unwrap();
        stack.snapshot(&a);
        stack.snapshot(&b);
        std::fs::write(&a, "AAA").unwrap();
        std::fs::write(&b, "BBB").unwrap();
        stack.commit_turn();
        stack.undo();
        assert_eq!(std::fs::read_to_string(&a).unwrap(), "aaa");
        assert_eq!(std::fs::read_to_string(&b).unwrap(), "bbb");
    }

    /// Repeated snapshots of one file within a turn keep only the earliest state.
    #[test]
    fn test_only_first_snapshot_per_file_per_turn() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();
        stack.snapshot(&path);
        std::fs::write(&path, "v2").unwrap();
        stack.snapshot(&path);
        std::fs::write(&path, "v3").unwrap();
        stack.commit_turn();
        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v1");
    }

    /// Each undo rolls back exactly one committed turn, most recent first.
    #[test]
    fn test_multi_turn_undo() {
        let (mut stack, tmp) = setup();
        let path = tmp.path().join("test.txt");
        std::fs::write(&path, "v1").unwrap();
        stack.snapshot(&path);
        std::fs::write(&path, "v2").unwrap();
        stack.commit_turn();
        stack.snapshot(&path);
        std::fs::write(&path, "v3").unwrap();
        stack.commit_turn();
        assert_eq!(stack.depth(), 2);
        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v2");
        stack.undo();
        assert_eq!(std::fs::read_to_string(&path).unwrap(), "v1");
    }

    /// Every listed tool name is recognised; other names are not.
    #[test]
    fn test_is_mutating_tool() {
        assert!(is_mutating_tool("Write"));
        assert!(is_mutating_tool("Edit"));
        assert!(is_mutating_tool("Delete"));
        assert!(is_mutating_tool("Overwrite"));
        assert!(!is_mutating_tool("Read"));
        assert!(!is_mutating_tool("Grep"));
        assert!(!is_mutating_tool("Bash"));
    }

    /// The path comes from `file_path`, falling back to `path`, and only for
    /// the recognised tool names.
    #[test]
    fn test_extract_file_path() {
        let args = serde_json::json!({"file_path": "src/main.rs"});
        assert_eq!(
            extract_file_path("Write", &args),
            Some("src/main.rs".into())
        );
        assert_eq!(extract_file_path("Read", &args), None);

        // `path` is used when `file_path` is absent.
        let fallback = serde_json::json!({"path": "src/lib.rs"});
        assert_eq!(
            extract_file_path("Edit", &fallback),
            Some("src/lib.rs".into())
        );

        // `file_path` wins when both keys are present.
        let both = serde_json::json!({"file_path": "a.rs", "path": "b.rs"});
        assert_eq!(extract_file_path("Overwrite", &both), Some("a.rs".into()));

        // No usable key at all.
        let empty = serde_json::json!({});
        assert_eq!(extract_file_path("Delete", &empty), None);
    }

    /// Committing with no pending snapshots does not push an empty entry.
    #[test]
    fn test_no_commit_if_no_snapshots() {
        let mut stack = UndoStack::new();
        stack.commit_turn();
        assert_eq!(stack.depth(), 0);
    }
}