// oxi_agent/tools/edit_diff.rs

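//! Edit diff computation engine.
//!
//! Computes unified diffs for edit previews. Supports:
//! - Multiple non-overlapping edits in one call
//! - Line ending normalization (CRLF and lone CR → LF)
//! - BOM detection and stripping
//!
//! Matching is exact: each edit's `old_text` must occur exactly once in the
//! content. NOTE(review): an earlier version of this header advertised a
//! "fuzzy matching fallback", but no fuzzy matching is implemented in this module.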
7use std::fmt;
8
/// A single search-and-replace edit.
///
/// `old_text` must match exactly one location in the (LF-normalized) content;
/// that location is replaced by `new_text`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Edit {
    /// Exact text to search for; must be non-empty and occur exactly once.
    pub old_text: String,
    /// Replacement text inserted in place of `old_text`.
    pub new_text: String,
}
17
/// Result of computing diffs for a set of edits.
#[derive(Debug, Clone)]
pub struct EditDiffResult {
    /// The diff text: one `@@ -a,b +c,d @@` header per hunk followed by
    /// lines prefixed with `' '`, `'-'` or `'+'`; consecutive hunks are
    /// separated by a blank line. Empty when there are no changes.
    pub diff: String,
    /// 0-indexed line in the new content where the first hunk starts, or
    /// `None` when there are no changes.
    ///
    /// NOTE(review): the hunk start includes leading context lines, so this can
    /// be up to `context_lines` before the first actually changed line — confirm
    /// that is what editor navigation expects.
    pub first_changed_line: Option<usize>,
}
26
/// Error produced while applying edits or computing their diff.
///
/// The message describes why an edit was rejected and, where possible,
/// how to fix it.
#[derive(Debug, Clone)]
pub struct EditDiffError {
    /// Human-readable description of the failure.
    pub message: String,
}

impl fmt::Display for EditDiffError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str(&self.message)
    }
}

// Lets the error flow through `?` into `Box<dyn Error>` and similar error types.
impl std::error::Error for EditDiffError {}
39
40/// Detect line ending type of content
41pub fn detect_line_ending(content: &str) -> LineEnding {
42    if content.contains("\r\n") {
43        LineEnding::Crlf
44    } else if content.contains('\r') {
45        LineEnding::Cr
46    } else {
47        LineEnding::Lf
48    }
49}
50
/// Line-ending convention of a text, as reported by `detect_line_ending`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LineEnding {
    /// Line feed only: `"\n"`.
    Lf,
    /// Carriage return followed by line feed: `"\r\n"`.
    Crlf,
    /// Lone carriage return: `"\r"`.
    Cr,
}
61
/// Convert every line break in `content` to a bare `'\n'`.
///
/// `"\r\n"` pairs collapse to one `'\n'`, and any remaining lone `'\r'`
/// becomes `'\n'` as well, so the result contains no carriage returns.
pub fn normalize_to_lf(content: &str) -> String {
    let mut normalized = String::with_capacity(content.len());
    let mut chars = content.chars().peekable();
    while let Some(ch) = chars.next() {
        if ch == '\r' {
            // Swallow the LF of a CRLF pair; a lone CR is still a line break.
            chars.next_if_eq(&'\n');
            normalized.push('\n');
        } else {
            normalized.push(ch);
        }
    }
    normalized
}
66
67/// Restore original line endings
68pub fn restore_line_endings(content: &str, ending: LineEnding) -> String {
69    match ending {
70        LineEnding::Crlf => content.replace('\n', "\r\n"),
71        LineEnding::Cr => content.replace('\n', "\r"),
72        LineEnding::Lf => content.to_string(),
73    }
74}
75
/// Return `content` without a single leading byte-order mark (`U+FEFF`).
///
/// Content without a BOM is returned unchanged; at most one BOM is removed.
pub fn strip_bom(content: &str) -> &str {
    match content.strip_prefix('\u{feff}') {
        Some(rest) => rest,
        None => content,
    }
}
80
/// Report whether `content` begins with a byte-order mark (`U+FEFF`).
pub fn has_bom(content: &str) -> bool {
    content.chars().next() == Some('\u{feff}')
}
85
86/// Apply multiple edits to normalized (LF) content.
87/// Validates that edits don't overlap, are unique, and all find matches.
88pub fn apply_edits_to_normalized_content(
89    content: &str,
90    edits: &[Edit],
91) -> Result<String, EditDiffError> {
92    if edits.is_empty() {
93        return Ok(content.to_string());
94    }
95
96    // Find all match positions first
97    let mut matches: Vec<(usize, usize, &Edit)> = Vec::new(); // (start, end, edit)
98
99    for edit in edits {
100        // Reject empty old_text
101        if edit.old_text.is_empty() {
102            return Err(EditDiffError {
103                message: "old_text cannot be empty. Match must be unique in the file.".to_string(),
104            });
105        }
106
107        // Check for multiple occurrences of this edit's old_text
108        let first_pos = content.find(&edit.old_text).ok_or_else(|| EditDiffError {
109            message: "Text to replace not found in file. Make sure to match the exact text including whitespace and newlines.".to_string(),
110        })?;
111
112        // Count total occurrences
113        let mut search_start = 0;
114        let mut occurrence_count = 0;
115        let mut first_found = false;
116        while let Some(pos) = content[search_start..].find(&edit.old_text) {
117            let actual_pos = search_start + pos;
118            if !first_found {
119                // Already found first match at first_pos
120                first_found = true;
121            }
122            occurrence_count += 1;
123            search_start = actual_pos + 1;
124        }
125
126        // If more than one occurrence, reject as ambiguous
127        if occurrence_count > 1 {
128            return Err(EditDiffError {
129                message: format!(
130                    "Edit rejected: '{}' appears {} times in the file. Matches must be unique. Provide more context to disambiguate.",
131                    edit.old_text.chars().take(50).collect::<String>(),
132                    occurrence_count
133                ),
134            });
135        }
136
137        let end = first_pos + edit.old_text.len();
138
139        // Check for overlaps with existing matches
140        for &(existing_start, existing_end, _) in &matches {
141            if first_pos < existing_end && end > existing_start {
142                return Err(EditDiffError {
143                    message: "Edits overlap — merge nearby edits into one.".to_string(),
144                });
145            }
146        }
147
148        matches.push((first_pos, end, edit));
149    }
150
151    // Sort by position (reverse order for safe replacement)
152    matches.sort_by_key(|b| std::cmp::Reverse(b.0));
153
154    let mut result = content.to_string();
155    for (start, end, edit) in matches {
156        result.replace_range(start..end, &edit.new_text);
157    }
158
159    Ok(result)
160}
161
162/// Compute a unified diff between original and modified content.
163/// Returns the diff string and the first changed line number.
164pub fn compute_edits_diff(original: &str, modified: &str, context_lines: usize) -> EditDiffResult {
165    let orig_lines: Vec<&str> = original.lines().collect();
166    let mod_lines: Vec<&str> = modified.lines().collect();
167
168    let mut diff = String::new();
169    let mut first_changed_line: Option<usize> = None;
170
171    // Simple line-by-line diff using longest common subsequence approach
172    let lcs = compute_lcs_table(&orig_lines, &mod_lines);
173    let mut diff_ops = Vec::new();
174    build_diff_ops(
175        &lcs,
176        &orig_lines,
177        &mod_lines,
178        orig_lines.len(),
179        mod_lines.len(),
180        &mut diff_ops,
181    );
182
183    // Group into hunks with context
184    let hunks = group_into_hunks(&diff_ops, &orig_lines, &mod_lines, context_lines);
185
186    for (i, hunk) in hunks.iter().enumerate() {
187        if i > 0 {
188            diff.push('\n');
189        }
190
191        if first_changed_line.is_none() {
192            first_changed_line = Some(hunk.new_start);
193        }
194
195        diff.push_str(&format!(
196            "@@ -{},{} +{},{} @@\n",
197            hunk.old_start + 1,
198            hunk.old_count,
199            hunk.new_start + 1,
200            hunk.new_count,
201        ));
202
203        for line in &hunk.lines {
204            match line {
205                DiffLine::Context(s) => diff.push_str(&format!(" {}\n", s)),
206                DiffLine::Remove(s) => diff.push_str(&format!("-{}\n", s)),
207                DiffLine::Add(s) => diff.push_str(&format!("+{}\n", s)),
208            }
209        }
210    }
211
212    EditDiffResult {
213        diff,
214        first_changed_line,
215    }
216}
217
218/// Generate a diff string from a set of edits applied to content
219pub fn generate_diff_string(
220    content: &str,
221    edits: &[Edit],
222    context_lines: usize,
223) -> Result<EditDiffResult, EditDiffError> {
224    let normalized = normalize_to_lf(strip_bom(content));
225    let modified = apply_edits_to_normalized_content(&normalized, edits)?;
226    Ok(compute_edits_diff(&normalized, &modified, context_lines))
227}
228
/// Build the longest-common-subsequence length table for two line slices.
///
/// `table[i][j]` holds the LCS length of `a[..i]` and `b[..j]`; the table has
/// `a.len() + 1` rows of `b.len() + 1` columns, so memory is O(a.len() * b.len()).
fn compute_lcs_table(a: &[&str], b: &[&str]) -> Vec<Vec<usize>> {
    let mut table = vec![vec![0usize; b.len() + 1]; a.len() + 1];
    for (row, line_a) in a.iter().enumerate() {
        for (col, line_b) in b.iter().enumerate() {
            table[row + 1][col + 1] = if line_a == line_b {
                table[row][col] + 1
            } else {
                table[row][col + 1].max(table[row + 1][col])
            };
        }
    }
    table
}
247
/// One step of a line-level edit script derived from the LCS table.
///
/// Indices are 0-based positions into the old and new line slices.
#[derive(Debug, Clone)]
enum DiffOp {
    /// Line present in both inputs: `(old_idx, new_idx)`.
    Equal(usize, usize),
    /// Line present only in the old input: `old_idx`.
    Remove(usize),
    /// Line present only in the new input: `new_idx`.
    Add(usize),
}
254
255fn build_diff_ops(
256    dp: &[Vec<usize>],
257    a: &[&str],
258    b: &[&str],
259    i: usize,
260    j: usize,
261    ops: &mut Vec<DiffOp>,
262) {
263    if i > 0 && j > 0 && a[i - 1] == b[j - 1] {
264        build_diff_ops(dp, a, b, i - 1, j - 1, ops);
265        ops.push(DiffOp::Equal(i - 1, j - 1));
266    } else if j > 0 && (i == 0 || dp[i][j - 1] >= dp[i - 1][j]) {
267        build_diff_ops(dp, a, b, i, j - 1, ops);
268        ops.push(DiffOp::Add(j - 1));
269    } else if i > 0 {
270        build_diff_ops(dp, a, b, i - 1, j, ops);
271        ops.push(DiffOp::Remove(i - 1));
272    }
273}
274
/// A rendered line of a hunk, borrowing its text from the split inputs.
#[derive(Debug)]
enum DiffLine<'a> {
    /// Unchanged line shown for context (rendered with a `' '` prefix).
    Context(&'a str),
    /// Line present only in the original (rendered with a `'-'` prefix).
    Remove(&'a str),
    /// Line present only in the modified content (rendered with a `'+'` prefix).
    Add(&'a str),
}
281
282struct Hunk<'a> {
283    old_start: usize,
284    old_count: usize,
285    new_start: usize,
286    new_count: usize,
287    lines: Vec<DiffLine<'a>>,
288}
289
290fn group_into_hunks<'a>(
291    ops: &[DiffOp],
292    old_lines: &[&'a str],
293    new_lines: &[&'a str],
294    context: usize,
295) -> Vec<Hunk<'a>> {
296    // Find change boundaries
297    let mut changes: Vec<usize> = Vec::new();
298    for (i, op) in ops.iter().enumerate() {
299        match op {
300            DiffOp::Remove(_) | DiffOp::Add(_) => changes.push(i),
301            DiffOp::Equal(_, _) => {}
302        }
303    }
304
305    if changes.is_empty() {
306        return Vec::new();
307    }
308
309    // Group changes into hunks
310    let mut hunks: Vec<Hunk<'a>> = Vec::new();
311    let mut hunk_start = changes[0];
312    let mut hunk_end = changes[0];
313
314    for &change_idx in &changes[1..] {
315        if change_idx <= hunk_end + 2 * context {
316            // Within context range, extend current hunk
317            hunk_end = change_idx;
318        } else {
319            // Start a new hunk
320            hunks.push(build_hunk(
321                ops, old_lines, new_lines, hunk_start, hunk_end, context,
322            ));
323            hunk_start = change_idx;
324            hunk_end = change_idx;
325        }
326    }
327    hunks.push(build_hunk(
328        ops, old_lines, new_lines, hunk_start, hunk_end, context,
329    ));
330
331    hunks
332}
333
334fn build_hunk<'a>(
335    ops: &[DiffOp],
336    old_lines: &[&'a str],
337    new_lines: &[&'a str],
338    change_start: usize,
339    change_end: usize,
340    context: usize,
341) -> Hunk<'a> {
342    let start = change_start.saturating_sub(context);
343    let end = (change_end + context + 1).min(ops.len());
344
345    let mut lines = Vec::new();
346    let mut _old_pos = usize::MAX;
347    let mut _new_pos = usize::MAX;
348    let mut old_count = 0;
349    let mut new_count = 0;
350    let mut first_old = None;
351    let mut first_new = None;
352
353    for op in ops.iter().take(end).skip(start) {
354        match op {
355            DiffOp::Equal(oi, ni) => {
356                if first_old.is_none() {
357                    first_old = Some(*oi);
358                }
359                if first_new.is_none() {
360                    first_new = Some(*ni);
361                }
362                _old_pos = *oi;
363                _new_pos = *ni;
364                lines.push(DiffLine::Context(old_lines[*oi]));
365                old_count += 1;
366                new_count += 1;
367            }
368            DiffOp::Remove(oi) => {
369                if first_old.is_none() {
370                    first_old = Some(*oi);
371                }
372                _old_pos = *oi;
373                lines.push(DiffLine::Remove(old_lines[*oi]));
374                old_count += 1;
375            }
376            DiffOp::Add(ni) => {
377                if first_new.is_none() {
378                    first_new = Some(*ni);
379                }
380                _new_pos = *ni;
381                lines.push(DiffLine::Add(new_lines[*ni]));
382                new_count += 1;
383            }
384        }
385    }
386
387    Hunk {
388        old_start: first_old.unwrap_or(0),
389        old_count,
390        new_start: first_new.unwrap_or(0),
391        new_count,
392        lines,
393    }
394}
395
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_normalize_to_lf() {
        assert_eq!(normalize_to_lf("a\r\nb\r\n"), "a\nb\n");
        assert_eq!(normalize_to_lf("a\nb\n"), "a\nb\n");
    }

    #[test]
    fn test_detect_line_ending() {
        assert_eq!(detect_line_ending("a\r\nb"), LineEnding::Crlf);
        assert_eq!(detect_line_ending("a\nb"), LineEnding::Lf);
        assert_eq!(detect_line_ending("a\rb"), LineEnding::Cr);
    }

    #[test]
    fn test_restore_line_endings() {
        assert_eq!(restore_line_endings("a\nb\n", LineEnding::Crlf), "a\r\nb\r\n");
        assert_eq!(restore_line_endings("a\nb\n", LineEnding::Cr), "a\rb\r");
        assert_eq!(restore_line_endings("a\nb\n", LineEnding::Lf), "a\nb\n");
    }

    #[test]
    fn test_line_ending_round_trip() {
        let original = "one\r\ntwo\r\n";
        let ending = detect_line_ending(original);
        assert_eq!(restore_line_endings(&normalize_to_lf(original), ending), original);
    }

    #[test]
    fn test_strip_bom() {
        assert_eq!(strip_bom("\u{feff}hello"), "hello");
        assert_eq!(strip_bom("hello"), "hello");
    }

    #[test]
    fn test_has_bom() {
        assert!(has_bom("\u{feff}hello"));
        assert!(!has_bom("hello"));
        assert!(!has_bom(""));
    }

    #[test]
    fn test_apply_edits_simple() {
        let content = "hello world\nfoo bar\n";
        let edits = vec![Edit {
            old_text: "hello world".to_string(),
            new_text: "hello Rust".to_string(),
        }];
        let result = apply_edits_to_normalized_content(content, &edits).unwrap();
        assert_eq!(result, "hello Rust\nfoo bar\n");
    }

    #[test]
    fn test_apply_no_edits_is_identity() {
        let result = apply_edits_to_normalized_content("abc\n", &[]).unwrap();
        assert_eq!(result, "abc\n");
    }

    #[test]
    fn test_apply_multiple_edits() {
        let content = "aaa\nbbb\nccc\n";
        let edits = vec![
            Edit {
                old_text: "aaa".to_string(),
                new_text: "AAA".to_string(),
            },
            Edit {
                old_text: "ccc".to_string(),
                new_text: "CCC".to_string(),
            },
        ];
        let result = apply_edits_to_normalized_content(content, &edits).unwrap();
        assert_eq!(result, "AAA\nbbb\nCCC\n");
    }

    #[test]
    fn test_apply_overlapping_edits_fails() {
        let content = "aaa\nbbb\nccc\n";
        let edits = vec![
            Edit {
                old_text: "aaa\nbbb".to_string(),
                new_text: "AAA".to_string(),
            },
            Edit {
                old_text: "bbb\nccc".to_string(),
                new_text: "CCC".to_string(),
            },
        ];
        let result = apply_edits_to_normalized_content(content, &edits);
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_not_found_fails() {
        let content = "hello world";
        let edits = vec![Edit {
            old_text: "not found".to_string(),
            new_text: "replacement".to_string(),
        }];
        let result = apply_edits_to_normalized_content(content, &edits);
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_empty_old_text_fails() {
        let edits = vec![Edit {
            old_text: String::new(),
            new_text: "x".to_string(),
        }];
        let result = apply_edits_to_normalized_content("abc", &edits);
        assert!(result.is_err());
    }

    #[test]
    fn test_apply_ambiguous_edit_fails() {
        let edits = vec![Edit {
            old_text: "foo".to_string(),
            new_text: "bar".to_string(),
        }];
        let err = apply_edits_to_normalized_content("foo\nfoo\n", &edits).unwrap_err();
        assert!(err.message.contains("appears 2 times"));
    }

    #[test]
    fn test_compute_diff() {
        let original = "line1\nline2\nline3\n";
        let modified = "line1\nmodified\nline3\n";
        let result = compute_edits_diff(original, modified, 1);
        assert!(result.diff.contains("-line2"));
        assert!(result.diff.contains("+modified"));
        assert_eq!(result.first_changed_line, Some(0)); // 0-indexed
    }

    #[test]
    fn test_compute_diff_header() {
        let result = compute_edits_diff("line1\nline2\nline3\n", "line1\nmodified\nline3\n", 1);
        assert!(result.diff.starts_with("@@ -1,3 +1,3 @@\n"));
    }

    #[test]
    fn test_compute_diff_identical_is_empty() {
        let result = compute_edits_diff("a\nb\n", "a\nb\n", 3);
        assert!(result.diff.is_empty());
        assert_eq!(result.first_changed_line, None);
    }

    #[test]
    fn test_generate_diff_string() {
        let content = "hello world\nfoo bar\n";
        let edits = vec![Edit {
            old_text: "hello world".to_string(),
            new_text: "hello Rust".to_string(),
        }];
        let result = generate_diff_string(content, &edits, 2).unwrap();
        assert!(result.diff.contains("-hello world"));
        assert!(result.diff.contains("+hello Rust"));
    }

    #[test]
    fn test_generate_diff_string_strips_bom() {
        let content = "\u{feff}hello\nworld\n";
        let edits = vec![Edit {
            old_text: "hello".to_string(),
            new_text: "bye".to_string(),
        }];
        let result = generate_diff_string(content, &edits, 1).unwrap();
        assert!(result.diff.contains("-hello"));
        assert!(result.diff.contains("+bye"));
        assert!(!result.diff.contains('\u{feff}'));
    }
}