use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use similar::TextDiff;
use std::collections::HashMap;
use std::sync::RwLock;
use crate::error::{read_or_recover, write_or_recover};
/// A stored copy of a file's content at one point in its edit history.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FileSnapshot {
    /// Per-file version number assigned by [`FileHistory::save_snapshot`].
    pub version: usize,
    /// Path of the file; the same string is used as the history key.
    pub path: String,
    /// File content recorded for this version.
    pub content: String,
    /// Time (UTC) at which the snapshot was recorded.
    pub timestamp: DateTime<Utc>,
    /// Name of the tool that produced this content (e.g. `"write"`, `"edit"`, `"patch"`).
    pub tool_name: String,
}
/// Metadata describing one snapshot, without carrying its content.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VersionSummary {
    /// Per-file version number of the summarized snapshot.
    pub version: usize,
    /// Path of the file the snapshot belongs to.
    pub path: String,
    /// Time (UTC) at which the snapshot was recorded.
    pub timestamp: DateTime<Utc>,
    /// Name of the tool that produced the snapshot's content.
    pub tool_name: String,
    /// Length of the snapshot content in bytes (`String::len`, not characters).
    pub size: usize,
}
impl From<&FileSnapshot> for VersionSummary {
fn from(snapshot: &FileSnapshot) -> Self {
Self {
version: snapshot.version,
path: snapshot.path.clone(),
timestamp: snapshot.timestamp,
tool_name: snapshot.tool_name.clone(),
size: snapshot.content.len(),
}
}
}
/// In-memory history of file snapshots, keyed by path, with a global capacity limit.
///
/// All state sits behind a single `RwLock`, so every method takes `&self`.
pub struct FileHistory {
    /// Snapshots per file path; each `Vec` is in insertion order (oldest first).
    snapshots: RwLock<HashMap<String, Vec<FileSnapshot>>>,
    /// Maximum number of snapshots retained in total, summed across all files.
    max_snapshots: usize,
}
impl FileHistory {
    /// Creates an empty history that retains at most `max_snapshots` snapshots in total,
    /// counted across all files.
    pub fn new(max_snapshots: usize) -> Self {
        Self {
            snapshots: RwLock::new(HashMap::new()),
            max_snapshots,
        }
    }

    /// Records `content` as the newest snapshot of `path` and returns its version number.
    ///
    /// Version numbers of a file start at 0 and grow by one per save. They are derived from
    /// the newest retained snapshot, so evicting older versions never causes a number to be
    /// handed out twice. If a file loses all of its snapshots (through eviction,
    /// [`FileHistory::clear_file`] or [`FileHistory::clear_all`]), its numbering restarts at 0.
    ///
    /// After saving, the oldest snapshots across all files are evicted until at most
    /// `max_snapshots` remain. The snapshot created by this call is exempt from eviction
    /// unless `max_snapshots` is 0, so the returned version can be looked up right away.
    pub fn save_snapshot(&self, path: &str, content: &str, tool_name: &str) -> usize {
        let mut snapshots = write_or_recover(&self.snapshots);
        let file_versions = snapshots.entry(path.to_string()).or_default();
        // Not `file_versions.len()`: once older versions have been evicted the length is
        // smaller than the newest version number and would re-issue a number still in use.
        let version = Self::next_version(file_versions);
        file_versions.push(FileSnapshot {
            version,
            path: path.to_string(),
            content: content.to_string(),
            timestamp: Utc::now(),
            tool_name: tool_name.to_string(),
        });
        self.evict_if_needed(&mut snapshots, Some((path, version)));
        version
    }

    /// Lists metadata for every retained snapshot of `path`, oldest first.
    ///
    /// Returns an empty `Vec` when `path` has no history.
    pub fn list_versions(&self, path: &str) -> Vec<VersionSummary> {
        let snapshots = read_or_recover(&self.snapshots);
        snapshots
            .get(path)
            .map(|versions| versions.iter().map(VersionSummary::from).collect())
            .unwrap_or_default()
    }

    /// Lists every file that has retained history, with its snapshot count, sorted by path.
    pub fn list_files(&self) -> Vec<(String, usize)> {
        let snapshots = read_or_recover(&self.snapshots);
        let mut files: Vec<(String, usize)> = snapshots
            .iter()
            .map(|(path, versions)| (path.clone(), versions.len()))
            .collect();
        // `HashMap` iteration order is arbitrary; sort so callers get a stable listing.
        files.sort_unstable_by(|a, b| a.0.cmp(&b.0));
        files
    }

    /// Returns a copy of version `version` of `path`.
    ///
    /// Returns `None` if that version was never saved or has been evicted.
    pub fn get_version(&self, path: &str, version: usize) -> Option<FileSnapshot> {
        let snapshots = read_or_recover(&self.snapshots);
        let versions = snapshots.get(path)?;
        Self::find_version(versions, version).cloned()
    }

    /// Returns a copy of the newest retained snapshot of `path`, if any.
    pub fn get_latest(&self, path: &str) -> Option<FileSnapshot> {
        let snapshots = read_or_recover(&self.snapshots);
        snapshots
            .get(path)
            .and_then(|versions| versions.last().cloned())
    }

    /// Renders a unified diff between two retained versions of `path`.
    ///
    /// Returns `None` if `path` has no history or either version is not retained.
    pub fn diff(&self, path: &str, from_version: usize, to_version: usize) -> Option<String> {
        let snapshots = read_or_recover(&self.snapshots);
        let versions = snapshots.get(path)?;
        let from = Self::find_version(versions, from_version)?;
        let to = Self::find_version(versions, to_version)?;
        Some(generate_unified_diff(
            &from.content,
            &to.content,
            path,
            from_version,
            to_version,
        ))
    }

    /// Renders a unified diff from retained version `version` of `path` to `current_content`.
    ///
    /// The `current_content` side is labelled with the version number the next snapshot of
    /// `path` would receive. Returns `None` if `path` has no history or `version` is not
    /// retained.
    pub fn diff_with_current(
        &self,
        path: &str,
        version: usize,
        current_content: &str,
    ) -> Option<String> {
        let snapshots = read_or_recover(&self.snapshots);
        let versions = snapshots.get(path)?;
        let from = Self::find_version(versions, version)?;
        Some(generate_unified_diff(
            &from.content,
            current_content,
            path,
            version,
            // Not `versions.len()`: after eviction that value can equal a retained version.
            Self::next_version(versions),
        ))
    }

    /// Returns the total number of retained snapshots across all files.
    pub fn total_snapshots(&self) -> usize {
        let snapshots = read_or_recover(&self.snapshots);
        snapshots.values().map(Vec::len).sum()
    }

    /// Drops every snapshot of `path`; its version numbering restarts at 0.
    pub fn clear_file(&self, path: &str) {
        let mut snapshots = write_or_recover(&self.snapshots);
        snapshots.remove(path);
    }

    /// Drops every snapshot of every file.
    pub fn clear_all(&self) {
        let mut snapshots = write_or_recover(&self.snapshots);
        snapshots.clear();
    }

    /// Version number for the next snapshot appended to `versions`: one past the newest
    /// retained version, or 0 when the history is empty.
    fn next_version(versions: &[FileSnapshot]) -> usize {
        versions.last().map_or(0, |s| s.version + 1)
    }

    /// Finds the snapshot numbered `version` within one file's history.
    ///
    /// `versions` is sorted by strictly increasing `version`: `save_snapshot` only appends
    /// `next_version(..)`, and eviction removes entries without reordering the rest.
    fn find_version(versions: &[FileSnapshot], version: usize) -> Option<&FileSnapshot> {
        versions
            .binary_search_by_key(&version, |s| s.version)
            .ok()
            .map(|index| &versions[index])
    }

    /// Evicts the oldest snapshots (by timestamp) until at most `max_snapshots` remain.
    ///
    /// `protected` names a `(path, version)` that must survive, typically the snapshot that
    /// was just saved. It is ignored when `max_snapshots` is 0, because then nothing may be
    /// kept. Files left without snapshots are removed from the map.
    fn evict_if_needed(
        &self,
        snapshots: &mut HashMap<String, Vec<FileSnapshot>>,
        protected: Option<(&str, usize)>,
    ) {
        let total: usize = snapshots.values().map(Vec::len).sum();
        if total <= self.max_snapshots {
            return;
        }
        let to_remove = total - self.max_snapshots;
        let protected = if self.max_snapshots == 0 {
            None
        } else {
            protected
        };

        let mut candidates: Vec<(DateTime<Utc>, String, usize)> = Vec::with_capacity(total);
        for (path, versions) in snapshots.iter() {
            for snapshot in versions {
                if protected == Some((path.as_str(), snapshot.version)) {
                    continue;
                }
                candidates.push((snapshot.timestamp, path.clone(), snapshot.version));
            }
        }
        // Stable sort: entries with equal timestamps keep their per-file insertion order, so a
        // file's older versions go before its newer ones. NOTE: ordering relies on wall-clock
        // timestamps, so a clock step backwards can make a newer snapshot look older.
        candidates.sort_by_key(|(timestamp, _, _)| *timestamp);

        for (_, path, version) in candidates.into_iter().take(to_remove) {
            if let Some(versions) = snapshots.get_mut(&path) {
                // Version numbers are unique per file, so this removes exactly one entry.
                versions.retain(|s| s.version != version);
                if versions.is_empty() {
                    snapshots.remove(&path);
                }
            }
        }
    }
}
/// Renders a unified diff from `old` to `new`, labelling both sides with `path` and the
/// given version numbers.
///
/// The `---`/`+++` header lines are always emitted, even when the inputs are identical (in
/// which case no hunks follow). Hunks carry three lines of context.
fn generate_unified_diff(
    old: &str,
    new: &str,
    path: &str,
    from_version: usize,
    to_version: usize,
) -> String {
    use std::fmt::Write as _;

    let diff = TextDiff::from_lines(old, new);
    let mut output = String::new();
    // Format straight into `output` instead of building a temporary `String` per piece.
    // Writing into a `String` cannot fail, so the returned `fmt::Result`s are ignored.
    let _ = writeln!(output, "--- a/{} (version {})", path, from_version);
    let _ = writeln!(output, "+++ b/{} (version {})", path, to_version);
    for hunk in diff.unified_diff().context_radius(3).iter_hunks() {
        let _ = write!(output, "{}", hunk);
    }
    output
}
/// Returns `true` when `tool_name` is exactly one of the file-modifying tools
/// (`"write"`, `"edit"` or `"patch"`); the comparison is case-sensitive.
pub fn is_file_modifying_tool(tool_name: &str) -> bool {
    const MODIFYING_TOOLS: [&str; 3] = ["write", "edit", "patch"];
    MODIFYING_TOOLS.contains(&tool_name)
}
/// Returns the `file_path` argument of a file-modifying tool call as an owned `String`.
///
/// Yields `None` when `tool_name` is not a file-modifying tool (see
/// [`is_file_modifying_tool`]), when `args` has no `file_path` entry, or when that entry is
/// not a JSON string.
pub fn extract_file_path(tool_name: &str, args: &serde_json::Value) -> Option<String> {
    if !is_file_modifying_tool(tool_name) {
        return None;
    }
    let file_path = args.get("file_path")?;
    file_path.as_str().map(String::from)
}
// Unit tests covering snapshot storage, version numbering, lookups, unified diffs,
// capacity-based eviction, and the tool-name helpers.
#[cfg(test)]
mod tests {
    use super::*;

    // A fresh history holds no snapshots and lists no files.
    #[test]
    fn test_new_history() {
        let history = FileHistory::new(100);
        assert_eq!(history.total_snapshots(), 0);
        assert!(history.list_files().is_empty());
    }

    // The first snapshot of a file gets version 0.
    #[test]
    fn test_save_snapshot() {
        let history = FileHistory::new(100);
        let v = history.save_snapshot("test.rs", "fn main() {}", "write");
        assert_eq!(v, 0);
        assert_eq!(history.total_snapshots(), 1);
    }

    // Successive saves of one file get consecutive version numbers.
    #[test]
    fn test_save_multiple_snapshots() {
        let history = FileHistory::new(100);
        let v0 = history.save_snapshot("test.rs", "version 0", "write");
        let v1 = history.save_snapshot("test.rs", "version 1", "edit");
        let v2 = history.save_snapshot("test.rs", "version 2", "patch");
        assert_eq!(v0, 0);
        assert_eq!(v1, 1);
        assert_eq!(v2, 2);
        assert_eq!(history.total_snapshots(), 3);
    }

    // Snapshots of different files are tracked under separate keys.
    #[test]
    fn test_save_multiple_files() {
        let history = FileHistory::new(100);
        history.save_snapshot("a.rs", "content a", "write");
        history.save_snapshot("b.rs", "content b", "write");
        assert_eq!(history.total_snapshots(), 2);
        assert_eq!(history.list_files().len(), 2);
    }

    #[test]
    fn test_list_versions_empty() {
        let history = FileHistory::new(100);
        assert!(history.list_versions("nonexistent.rs").is_empty());
    }

    // Summaries come back oldest first; `size` is the content's byte length ("v0" -> 2).
    #[test]
    fn test_list_versions() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "v0", "write");
        history.save_snapshot("test.rs", "v1", "edit");
        let versions = history.list_versions("test.rs");
        assert_eq!(versions.len(), 2);
        assert_eq!(versions[0].version, 0);
        assert_eq!(versions[0].tool_name, "write");
        assert_eq!(versions[0].size, 2);
        assert_eq!(versions[1].version, 1);
        assert_eq!(versions[1].tool_name, "edit");
    }

    #[test]
    fn test_get_version() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "original", "write");
        history.save_snapshot("test.rs", "modified", "edit");
        let v0 = history.get_version("test.rs", 0).unwrap();
        assert_eq!(v0.content, "original");
        assert_eq!(v0.tool_name, "write");
        let v1 = history.get_version("test.rs", 1).unwrap();
        assert_eq!(v1.content, "modified");
    }

    // Unknown paths and out-of-range versions both yield None.
    #[test]
    fn test_get_version_nonexistent() {
        let history = FileHistory::new(100);
        assert!(history.get_version("test.rs", 0).is_none());
        history.save_snapshot("test.rs", "content", "write");
        assert!(history.get_version("test.rs", 99).is_none());
    }

    #[test]
    fn test_get_latest() {
        let history = FileHistory::new(100);
        assert!(history.get_latest("test.rs").is_none());
        history.save_snapshot("test.rs", "v0", "write");
        history.save_snapshot("test.rs", "v1", "edit");
        let latest = history.get_latest("test.rs").unwrap();
        assert_eq!(latest.content, "v1");
        assert_eq!(latest.version, 1);
    }

    // Diff output carries version-labelled headers plus -/+ lines for the changed line.
    #[test]
    fn test_diff_between_versions() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "line1\nline2\nline3\n", "write");
        history.save_snapshot("test.rs", "line1\nmodified\nline3\n", "edit");
        let diff = history.diff("test.rs", 0, 1).unwrap();
        assert!(diff.contains("--- a/test.rs (version 0)"));
        assert!(diff.contains("+++ b/test.rs (version 1)"));
        assert!(diff.contains("-line2"));
        assert!(diff.contains("+modified"));
    }

    #[test]
    fn test_diff_nonexistent_version() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "content", "write");
        assert!(history.diff("test.rs", 0, 5).is_none());
        assert!(history.diff("nonexistent.rs", 0, 1).is_none());
    }

    // Diffing a version against itself still emits headers but no removed lines.
    #[test]
    fn test_diff_same_version() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "same content\n", "write");
        let diff = history.diff("test.rs", 0, 0).unwrap();
        assert!(diff.contains("--- a/test.rs"));
        assert!(!diff.contains("-same content"));
    }

    #[test]
    fn test_diff_with_current() {
        let history = FileHistory::new(100);
        history.save_snapshot("test.rs", "old\n", "write");
        let diff = history.diff_with_current("test.rs", 0, "new\n").unwrap();
        assert!(diff.contains("-old"));
        assert!(diff.contains("+new"));
    }

    // `list_files` order is not asserted; entries are looked up by path.
    #[test]
    fn test_list_files() {
        let history = FileHistory::new(100);
        history.save_snapshot("a.rs", "a", "write");
        history.save_snapshot("b.rs", "b1", "write");
        history.save_snapshot("b.rs", "b2", "edit");
        let files = history.list_files();
        assert_eq!(files.len(), 2);
        let a_count = files.iter().find(|(p, _)| p == "a.rs").unwrap().1;
        let b_count = files.iter().find(|(p, _)| p == "b.rs").unwrap().1;
        assert_eq!(a_count, 1);
        assert_eq!(b_count, 2);
    }

    // Clearing one file leaves other files' history intact.
    #[test]
    fn test_clear_file() {
        let history = FileHistory::new(100);
        history.save_snapshot("a.rs", "a", "write");
        history.save_snapshot("b.rs", "b", "write");
        history.clear_file("a.rs");
        assert_eq!(history.total_snapshots(), 1);
        assert!(history.list_versions("a.rs").is_empty());
        assert_eq!(history.list_versions("b.rs").len(), 1);
    }

    #[test]
    fn test_clear_all() {
        let history = FileHistory::new(100);
        history.save_snapshot("a.rs", "a", "write");
        history.save_snapshot("b.rs", "b", "write");
        history.clear_all();
        assert_eq!(history.total_snapshots(), 0);
        assert!(history.list_files().is_empty());
    }

    // With capacity 3, a 4th save of the same file must not grow the total past 3.
    #[test]
    fn test_eviction_when_over_limit() {
        let history = FileHistory::new(3);
        history.save_snapshot("test.rs", "v0", "write");
        history.save_snapshot("test.rs", "v1", "edit");
        history.save_snapshot("test.rs", "v2", "edit");
        assert_eq!(history.total_snapshots(), 3);
        history.save_snapshot("test.rs", "v3", "edit");
        assert!(history.total_snapshots() <= 3);
    }

    // The capacity is global: saves to different files also trigger eviction.
    #[test]
    fn test_eviction_across_files() {
        let history = FileHistory::new(3);
        history.save_snapshot("a.rs", "a0", "write");
        history.save_snapshot("b.rs", "b0", "write");
        history.save_snapshot("c.rs", "c0", "write");
        history.save_snapshot("d.rs", "d0", "write");
        assert!(history.total_snapshots() <= 3);
    }

    // `size` is the byte length of the content ("hello world" is 11 bytes).
    #[test]
    fn test_version_summary_from_snapshot() {
        let snapshot = FileSnapshot {
            version: 5,
            path: "test.rs".to_string(),
            content: "hello world".to_string(),
            timestamp: Utc::now(),
            tool_name: "edit".to_string(),
        };
        let summary = VersionSummary::from(&snapshot);
        assert_eq!(summary.version, 5);
        assert_eq!(summary.path, "test.rs");
        assert_eq!(summary.tool_name, "edit");
        assert_eq!(summary.size, 11);
    }

    #[test]
    fn test_is_file_modifying_tool() {
        assert!(is_file_modifying_tool("write"));
        assert!(is_file_modifying_tool("edit"));
        assert!(is_file_modifying_tool("patch"));
        assert!(!is_file_modifying_tool("read"));
        assert!(!is_file_modifying_tool("bash"));
        assert!(!is_file_modifying_tool("grep"));
        assert!(!is_file_modifying_tool("glob"));
        assert!(!is_file_modifying_tool("ls"));
    }

    // Only file-modifying tools yield a path, even when `file_path` is present.
    #[test]
    fn test_extract_file_path() {
        let args = serde_json::json!({"file_path": "src/main.rs", "content": "hello"});
        assert_eq!(
            extract_file_path("write", &args),
            Some("src/main.rs".to_string())
        );
        assert_eq!(
            extract_file_path("edit", &args),
            Some("src/main.rs".to_string())
        );
        assert_eq!(
            extract_file_path("patch", &args),
            Some("src/main.rs".to_string())
        );
        assert_eq!(extract_file_path("read", &args), None);
        assert_eq!(extract_file_path("bash", &args), None);
    }

    #[test]
    fn test_extract_file_path_missing() {
        let args = serde_json::json!({"content": "hello"});
        assert_eq!(extract_file_path("write", &args), None);
    }

    #[test]
    fn test_generate_unified_diff() {
        let old = "line1\nline2\nline3\n";
        let new = "line1\nchanged\nline3\n";
        let diff = generate_unified_diff(old, new, "test.rs", 0, 1);
        assert!(diff.contains("--- a/test.rs (version 0)"));
        assert!(diff.contains("+++ b/test.rs (version 1)"));
        assert!(diff.contains("-line2"));
        assert!(diff.contains("+changed"));
    }

    // Identical inputs: headers are still present, but there is no hunk ("@@") line.
    #[test]
    fn test_generate_unified_diff_no_changes() {
        let content = "same\n";
        let diff = generate_unified_diff(content, content, "test.rs", 0, 0);
        assert!(diff.contains("--- a/test.rs"));
        assert!(!diff.contains("@@"));
    }

    #[test]
    fn test_generate_unified_diff_addition() {
        let old = "line1\nline3\n";
        let new = "line1\nline2\nline3\n";
        let diff = generate_unified_diff(old, new, "test.rs", 0, 1);
        assert!(diff.contains("+line2"));
    }

    #[test]
    fn test_generate_unified_diff_deletion() {
        let old = "line1\nline2\nline3\n";
        let new = "line1\nline3\n";
        let diff = generate_unified_diff(old, new, "test.rs", 0, 1);
        assert!(diff.contains("-line2"));
    }
}