Skip to main content

hematite/agent/
diff_tracker.rs

use similar::TextDiff;
use std::collections::HashMap;
use std::fs;
use std::io;
use std::path::{Path, PathBuf};
5
6/// Authoritative Turn Diff Tracker.
7/// Enables Hematite to proactively capture workspace mutations and
8/// generate high-precision unified diffs for human-in-the-loop verification.
9
10pub struct TurnDiffTracker {
11    /// Baseline snapshots: Path -> Original Content
12    baselines: HashMap<PathBuf, Vec<u8>>,
13}
14
15impl TurnDiffTracker {
16    pub fn new() -> Self {
17        Self {
18            baselines: HashMap::new(),
19        }
20    }
21
22    /// Capture a baseline snapshot of a file if it hasn't been seen yet this turn.
23    pub fn on_file_access(&mut self, path: &Path) {
24        if !self.baselines.contains_key(path) {
25            if path.exists() {
26                if let Ok(content) = fs::read(path) {
27                    self.baselines.insert(path.to_path_buf(), content);
28                }
29            } else {
30                // For new files, the baseline is empty
31                self.baselines.insert(path.to_path_buf(), Vec::new());
32            }
33        }
34    }
35
36    pub fn reset(&mut self) {
37        self.baselines.clear();
38    }
39
40    /// Generate an aggregated unified diff of all modifications tracked this turn.
41    pub fn generate_diff(&self) -> Result<String, String> {
42        if self.baselines.is_empty() {
43            return Ok(String::new());
44        }
45
46        let mut aggregated = String::new();
47        let mut sorted_paths: Vec<_> = self.baselines.keys().collect();
48        sorted_paths.sort();
49
50        for path in sorted_paths {
51            let original_bytes = self.baselines.get(path).unwrap();
52            let current_bytes = fs::read(path).unwrap_or_default();
53
54            if original_bytes == &current_bytes {
55                continue;
56            }
57
58            let original_text = String::from_utf8_lossy(original_bytes);
59            let current_text = String::from_utf8_lossy(&current_bytes);
60
61            let diff = TextDiff::from_lines(&original_text, &current_text);
62            let rel_path = path.to_string_lossy();
63
64            let unified = diff
65                .unified_diff()
66                .header(&format!("a/{}", rel_path), &format!("b/{}", rel_path))
67                .to_string();
68
69            aggregated.push_str(&unified);
70            aggregated.push('\n');
71        }
72
73        Ok(aggregated)
74    }
75}
76
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn test_diff_generation() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("test.txt");
        fs::write(&file_path, "original line\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);

        fs::write(&file_path, "modified line\n").unwrap();

        let diff = tracker.generate_diff().expect("Should have a diff");
        assert!(diff.contains("-original line"));
        assert!(diff.contains("+modified line"));
    }

    #[test]
    fn unchanged_file_produces_no_diff() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("same.txt");
        fs::write(&file_path, "stable\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);

        assert_eq!(tracker.generate_diff().unwrap(), "");
    }

    #[test]
    fn new_file_shows_added_lines() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("new.txt");

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "brand new\n").unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("+brand new"));
    }

    #[test]
    fn deleted_file_shows_removed_lines() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("gone.txt");
        fs::write(&file_path, "gone\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::remove_file(&file_path).unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("-gone"));
    }

    #[test]
    fn first_access_baseline_wins() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("multi.txt");
        fs::write(&file_path, "v1\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "v2\n").unwrap();
        // A second access must not overwrite the original snapshot.
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "v3\n").unwrap();

        let diff = tracker.generate_diff().unwrap();
        assert!(diff.contains("-v1"));
        assert!(diff.contains("+v3"));
        assert!(!diff.contains("v2"));
    }

    #[test]
    fn reset_discards_baselines() {
        let dir = tempdir().unwrap();
        let file_path = dir.path().join("reset.txt");
        fs::write(&file_path, "before\n").unwrap();

        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&file_path);
        fs::write(&file_path, "after\n").unwrap();
        tracker.reset();

        assert_eq!(tracker.generate_diff().unwrap(), "");
    }
}