Skip to main content

cersei_tools/tool_primitives/
diff.rs

1//! Text diffing primitives using the `similar` crate.
2//!
3//! Pure functions — no I/O, no async. Produces unified diffs,
4//! structured line diffs, and can apply patches.
5
6use similar::{ChangeTag as SimilarTag, TextDiff};
7use std::fmt;
8
/// Type of change for a diff line.
///
/// This is the crate's own tag type, so callers of the structured diff do not
/// need to depend on `similar::ChangeTag` directly. It is `Copy` and hashable,
/// which makes it usable as a map or set key (for example to count lines per
/// kind of change).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ChangeTag {
    /// The line exists only in the new text.
    Added,
    /// The line exists only in the old text.
    Removed,
    /// The line is present in both texts.
    Unchanged,
}
16
17/// A single line in a structured diff.
18#[derive(Debug, Clone)]
19pub struct DiffLine {
20    pub tag: ChangeTag,
21    pub line_number_old: Option<usize>,
22    pub line_number_new: Option<usize>,
23    pub content: String,
24}
25
/// Error when applying a patch fails.
///
/// Carries a human-readable description of what went wrong. The type is
/// cloneable and comparable so it can be stored, forwarded and asserted on.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PatchError {
    /// Description of the failure, without the `patch error: ` prefix that
    /// the `Display` implementation adds.
    pub message: String,
}

impl fmt::Display for PatchError {
    /// Formats the error as `patch error: <message>`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "patch error: {}", self.message)
    }
}

// No source error is wrapped, so the default `Error` methods are sufficient.
impl std::error::Error for PatchError {}
39
40/// Produce a unified diff string (standard format with @@ hunk headers).
41///
42/// ```rust,ignore
43/// let diff = unified_diff("hello\nworld\n", "hello\nearth\n", 3);
44/// assert!(diff.contains("-world"));
45/// assert!(diff.contains("+earth"));
46/// ```
47pub fn unified_diff(old: &str, new: &str, context_lines: usize) -> String {
48    let diff = TextDiff::from_lines(old, new);
49    diff.unified_diff()
50        .context_radius(context_lines)
51        .header("old", "new")
52        .to_string()
53}
54
55/// Return a structured per-line diff.
56///
57/// Each line includes the change tag, old/new line numbers, and content.
58pub fn line_diff(old: &str, new: &str) -> Vec<DiffLine> {
59    let diff = TextDiff::from_lines(old, new);
60    let mut result = Vec::new();
61    let mut old_line: usize = 1;
62    let mut new_line: usize = 1;
63
64    for change in diff.iter_all_changes() {
65        let tag = match change.tag() {
66            SimilarTag::Equal => ChangeTag::Unchanged,
67            SimilarTag::Insert => ChangeTag::Added,
68            SimilarTag::Delete => ChangeTag::Removed,
69        };
70
71        let (ln_old, ln_new) = match tag {
72            ChangeTag::Unchanged => {
73                let r = (Some(old_line), Some(new_line));
74                old_line += 1;
75                new_line += 1;
76                r
77            }
78            ChangeTag::Removed => {
79                let r = (Some(old_line), None);
80                old_line += 1;
81                r
82            }
83            ChangeTag::Added => {
84                let r = (None, Some(new_line));
85                new_line += 1;
86                r
87            }
88        };
89
90        result.push(DiffLine {
91            tag,
92            line_number_old: ln_old,
93            line_number_new: ln_new,
94            content: change.to_string_lossy().to_string(),
95        });
96    }
97
98    result
99}
100
101/// Apply a unified diff patch to the original text.
102///
103/// Returns the patched text, or an error if the patch doesn't apply cleanly.
104/// This is a simple line-based patch applicator — it handles standard unified
105/// diff format with `@@` hunk headers and `+`/`-`/` ` line prefixes.
106pub fn apply_patch(original: &str, patch: &str) -> Result<String, PatchError> {
107    let original_lines: Vec<&str> = original.lines().collect();
108    let mut result_lines: Vec<String> = Vec::new();
109    let mut orig_idx: usize = 0;
110
111    let patch_lines: Vec<&str> = patch.lines().collect();
112    let mut patch_idx: usize = 0;
113
114    // Skip header lines (---, +++, etc.)
115    while patch_idx < patch_lines.len() {
116        let line = patch_lines[patch_idx];
117        if line.starts_with("@@") {
118            break;
119        }
120        patch_idx += 1;
121    }
122
123    while patch_idx < patch_lines.len() {
124        let line = patch_lines[patch_idx];
125
126        if line.starts_with("@@") {
127            // Parse hunk header: @@ -old_start,old_count +new_start,new_count @@
128            let parts: Vec<&str> = line.split_whitespace().collect();
129            if parts.len() < 3 {
130                return Err(PatchError {
131                    message: format!("malformed hunk header: {}", line),
132                });
133            }
134
135            let old_part = parts[1].trim_start_matches('-');
136            let old_start: usize = old_part
137                .split(',')
138                .next()
139                .and_then(|s| s.parse().ok())
140                .unwrap_or(1);
141
142            // Copy unchanged lines before this hunk
143            while orig_idx + 1 < old_start && orig_idx < original_lines.len() {
144                result_lines.push(original_lines[orig_idx].to_string());
145                orig_idx += 1;
146            }
147
148            patch_idx += 1;
149            continue;
150        }
151
152        if line.starts_with('-') {
153            // Remove line — skip it in original
154            orig_idx += 1;
155        } else if line.starts_with('+') {
156            // Add line
157            result_lines.push(line[1..].to_string());
158        } else if line.starts_with(' ') || line.is_empty() {
159            // Context line — copy from original
160            if orig_idx < original_lines.len() {
161                result_lines.push(original_lines[orig_idx].to_string());
162                orig_idx += 1;
163            }
164        }
165
166        patch_idx += 1;
167    }
168
169    // Copy remaining original lines after last hunk
170    while orig_idx < original_lines.len() {
171        result_lines.push(original_lines[orig_idx].to_string());
172        orig_idx += 1;
173    }
174
175    Ok(result_lines.join("\n"))
176}
177
178// ─── Tests ─────────────────────────────────────────────────────────────────
179
#[cfg(test)]
mod tests {
    use super::*;

    /// Select the diff lines that carry the given tag.
    fn with_tag(lines: &[DiffLine], tag: ChangeTag) -> Vec<&DiffLine> {
        lines.iter().filter(|line| line.tag == tag).collect()
    }

    /// A one-line replacement shows up as a removal, an addition and a hunk header.
    #[test]
    fn test_unified_diff_basic() {
        let rendered = unified_diff("hello\nworld\n", "hello\nearth\n", 3);
        assert!(rendered.contains("-world"));
        assert!(rendered.contains("+earth"));
        assert!(rendered.contains("@@"));
    }

    /// Identical inputs must not produce any hunk.
    #[test]
    fn test_unified_diff_identical() {
        let same = "same\ncontent\n";
        let rendered = unified_diff(same, same, 3);
        assert!(rendered.is_empty() || !rendered.contains("@@"));
    }

    /// Replacing one line yields exactly one removed and one added entry.
    #[test]
    fn test_line_diff_basic() {
        let changes = line_diff("a\nb\nc\n", "a\nB\nc\n");

        let removed = with_tag(&changes, ChangeTag::Removed);
        let added = with_tag(&changes, ChangeTag::Added);

        assert_eq!(removed.len(), 1);
        assert_eq!(added.len(), 1);
        assert!(removed[0].content.contains('b'));
        assert!(added[0].content.contains('B'));
    }

    /// Diffing two empty texts yields no lines at all.
    #[test]
    fn test_line_diff_empty() {
        assert!(line_diff("", "").is_empty());
    }

    /// A patch generated by `unified_diff` can be applied back to the old text.
    #[test]
    fn test_apply_patch_basic() {
        let before = "hello\nworld\nfoo\n";
        let after = "hello\nearth\nfoo\n";
        let patched = apply_patch(before, &unified_diff(before, after, 3)).unwrap();
        assert!(patched.contains("earth"));
        assert!(!patched.contains("world"));
    }

    /// A removed line keeps its old line number and has no new line number.
    #[test]
    fn test_line_numbers() {
        let changes = line_diff("a\nb\nc\n", "a\nc\n");

        let removed = changes
            .iter()
            .find(|line| line.tag == ChangeTag::Removed)
            .unwrap();
        assert_eq!(removed.line_number_old, Some(2));
        assert_eq!(removed.line_number_new, None);
    }
}