Skip to main content

hematite/agent/
diff_tracker.rs

1use similar::TextDiff;
2use std::collections::HashMap;
3use std::fs;
4use std::path::{Path, PathBuf};
5
/// Tracks per-turn file baselines so workspace mutations can be shown as unified diffs.
///
/// The first time a path is reported through [`TurnDiffTracker::on_file_access`] its
/// bytes are snapshotted; [`TurnDiffTracker::generate_diff`] later compares each
/// snapshot with what is on disk, giving a human reviewer a diff of the turn's changes.
#[derive(Debug)]
pub struct TurnDiffTracker {
    /// Baseline snapshots keyed by path: the file's bytes when it was first seen this
    /// turn (empty for a path that did not exist yet).
    baselines: HashMap<PathBuf, Vec<u8>>,
}
13
14impl Default for TurnDiffTracker {
15    fn default() -> Self {
16        Self::new()
17    }
18}
19
20impl TurnDiffTracker {
21    pub fn new() -> Self {
22        Self {
23            baselines: HashMap::new(),
24        }
25    }
26
27    /// Capture a baseline snapshot of a file if it hasn't been seen yet this turn.
28    pub fn on_file_access(&mut self, path: &Path) {
29        if !self.baselines.contains_key(path) {
30            if path.exists() {
31                if let Ok(content) = fs::read(path) {
32                    self.baselines.insert(path.to_path_buf(), content);
33                }
34            } else {
35                // For new files, the baseline is empty
36                self.baselines.insert(path.to_path_buf(), Vec::new());
37            }
38        }
39    }
40
41    pub fn reset(&mut self) {
42        self.baselines.clear();
43    }
44
45    /// Generate an aggregated unified diff of all modifications tracked this turn.
46    pub fn generate_diff(&self) -> Result<String, String> {
47        if self.baselines.is_empty() {
48            return Ok(String::new());
49        }
50
51        let mut aggregated = String::with_capacity(self.baselines.len() * 512);
52        let mut sorted_paths: Vec<_> = self.baselines.keys().collect();
53        sorted_paths.sort_unstable();
54
55        for path in sorted_paths {
56            let original_bytes = self.baselines.get(path).unwrap();
57            let current_bytes = fs::read(path).unwrap_or_default();
58
59            if original_bytes == &current_bytes {
60                continue;
61            }
62
63            let original_text = String::from_utf8_lossy(original_bytes);
64            let current_text = String::from_utf8_lossy(&current_bytes);
65
66            let diff = TextDiff::from_lines(&original_text, &current_text);
67            let rel_path = path.to_string_lossy();
68
69            let unified = diff
70                .unified_diff()
71                .header(&format!("a/{}", rel_path), &format!("b/{}", rel_path))
72                .to_string();
73
74            aggregated.push_str(&unified);
75            aggregated.push('\n');
76        }
77
78        Ok(aggregated)
79    }
80}
81
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// A tracked file rewritten after its baseline was captured must appear in the
    /// diff as one removed original line and one added modified line.
    #[test]
    fn test_diff_generation() {
        let workspace = tempdir().unwrap();
        let target = workspace.path().join("test.txt");
        fs::write(&target, "original line\n").unwrap();

        // Snapshot the baseline before mutating the file on disk.
        let mut tracker = TurnDiffTracker::new();
        tracker.on_file_access(&target);
        fs::write(&target, "modified line\n").unwrap();

        let rendered = tracker.generate_diff().expect("Should have a diff");
        for expected in ["-original line", "+modified line"] {
            assert!(rendered.contains(expected));
        }
    }
}