// hematite/agent/diff_tracker.rs
use similar::TextDiff;
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};

/// Remembers the contents files had when first accessed, so that a unified
/// diff of everything changed since then can be produced with
/// `generate_diff`. Calling `reset` discards all recorded baselines.
#[derive(Debug, Clone)]
pub struct TurnDiffTracker {
    /// Contents of each file as first observed, keyed by the path passed to
    /// `on_file_access`. Paths that did not exist at first access map to an
    /// empty vector.
    baselines: HashMap<PathBuf, Vec<u8>>,
}

14impl Default for TurnDiffTracker {
15 fn default() -> Self {
16 Self::new()
17 }
18}
19
20impl TurnDiffTracker {
21 pub fn new() -> Self {
22 Self {
23 baselines: HashMap::new(),
24 }
25 }
26
27 pub fn on_file_access(&mut self, path: &Path) {
29 if !self.baselines.contains_key(path) {
30 if path.exists() {
31 if let Ok(content) = fs::read(path) {
32 self.baselines.insert(path.to_path_buf(), content);
33 }
34 } else {
35 self.baselines.insert(path.to_path_buf(), Vec::new());
37 }
38 }
39 }
40
41 pub fn reset(&mut self) {
42 self.baselines.clear();
43 }
44
45 pub fn generate_diff(&self) -> Result<String, String> {
47 if self.baselines.is_empty() {
48 return Ok(String::new());
49 }
50
51 let mut aggregated = String::with_capacity(self.baselines.len() * 512);
52 let mut sorted_paths: Vec<_> = self.baselines.keys().collect();
53 sorted_paths.sort_unstable();
54
55 for path in sorted_paths {
56 let original_bytes = self.baselines.get(path).unwrap();
57 let current_bytes = fs::read(path).unwrap_or_default();
58
59 if original_bytes == ¤t_bytes {
60 continue;
61 }
62
63 let original_text = String::from_utf8_lossy(original_bytes);
64 let current_text = String::from_utf8_lossy(¤t_bytes);
65
66 let diff = TextDiff::from_lines(&original_text, ¤t_text);
67 let rel_path = path.to_string_lossy();
68
69 let unified = diff
70 .unified_diff()
71 .header(&format!("a/{}", rel_path), &format!("b/{}", rel_path))
72 .to_string();
73
74 aggregated.push_str(&unified);
75 aggregated.push('\n');
76 }
77
78 Ok(aggregated)
79 }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_diff_generation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        fs::write(&file_path, "original line\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);

        fs::write(&file_path, "modified line\n").unwrap();

        let diff = tracker.generate_diff().expect("Should have a diff");
        assert!(diff.contains("-original line"));
        assert!(diff.contains("+modified line"));
    }

    #[test]
    fn unchanged_file_produces_empty_diff() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("same.txt");
        fs::write(&file_path, "unchanged\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);

        assert_eq!(tracker.generate_diff().unwrap(), "");
    }

    #[test]
    fn only_first_access_sets_baseline() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("sticky.txt");
        fs::write(&file_path, "first line\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "second line\n").unwrap();
        // A repeated access must not overwrite the original baseline.
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "third line\n").unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("-first line"));
        assert!(diff.contains("+third line"));
        assert!(!diff.contains("second line"));
    }

    #[test]
    fn created_file_is_reported_as_added() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("new.txt");

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "brand new\n").unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("+brand new"));
    }

    #[test]
    fn deleted_file_is_reported_as_removed() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("doomed.txt");
        fs::write(&file_path, "doomed line\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::remove_file(&file_path).unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("-doomed line"));
    }

    #[test]
    fn file_that_never_exists_is_ignored() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("ghost.txt");

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);

        assert_eq!(tracker.generate_diff().unwrap(), "");
    }

    #[test]
    fn reset_forgets_baselines() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("reset.txt");
        fs::write(&file_path, "before\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "after\n").unwrap();
        tracker.reset();

        assert_eq!(tracker.generate_diff().unwrap(), "");
    }
}