corrode_mcp/mcp/
patch.rs

1use std::{borrow::Cow, str::FromStr};
2
3use anyhow::{Context as _, Result};
4
/// One side (old or new) of a hunk header's line range, e.g. the `12,7` in `@@ -12,7 +12,8 @@`.
#[derive(Debug, Clone)]
pub struct HeaderRange {
    /// First line covered by the range.
    ///
    /// Ranges parsed from a patch header hold the number exactly as written (1-based), while
    /// ranges rebuilt from matches in the file hold 0-based line indices; rendering adds 1.
    pub start: usize,
    /// Number of lines the range spans.
    pub range: usize,
}
13
14/// Represents the header of a hunk in a patch
15#[derive(Clone, Debug)]
16pub struct HunkHeader {
17    pub source: HeaderRange,
18    #[allow(dead_code)]
19    pub dest: HeaderRange,
20
21    // Optional values after fixing the ranges
22    pub fixed_source: Option<HeaderRange>,
23    pub fixed_dest: Option<HeaderRange>,
24}
25
26/// Represents a line in a hunk
27#[derive(Clone, Debug, strum_macros::EnumIs)]
28pub enum HunkLine {
29    Context(String),
30    Added(String),
31    Removed(String),
32}
33
34impl HunkLine {
35    pub fn content(&self) -> &str {
36        match self {
37            HunkLine::Removed(s) | HunkLine::Context(s) | HunkLine::Added(s) => s,
38        }
39    }
40
41    pub fn as_patch_line(&self) -> Cow<str> {
42        match self {
43            HunkLine::Context(s) => Cow::Owned(format!(" {s}")),
44            HunkLine::Added(s) => Cow::Owned(format!("+{s}")),
45            HunkLine::Removed(s) => Cow::Owned(format!("-{s}")),
46        }
47    }
48}
49
50/// Represents a hunk in a patch
51#[derive(Clone, Debug)]
52pub struct Hunk {
53    /// The parsed header of the hunk
54    pub header: HunkHeader,
55
56    /// Parsed lines of the hunk
57    pub lines: Vec<HunkLine>,
58
59    /// The original full hunk body
60    pub body: String,
61}
62
63impl<'a> From<&'a Hunk> for Cow<'a, Hunk> {
64    fn from(val: &'a Hunk) -> Self {
65        Cow::Borrowed(val)
66    }
67}
68
69impl From<Hunk> for Cow<'_, Hunk> {
70    fn from(val: Hunk) -> Self {
71        Cow::Owned(val)
72    }
73}
74
75impl Hunk {
76    fn matchable_lines(&self) -> impl Iterator<Item = &HunkLine> {
77        self.lines
78            .iter()
79            .filter(|l| l.is_removed() || l.is_context())
80    }
81
82    /// Inserts a line at the given index on matchable lines. Converts the index to the actual
83    /// underlying index
84    pub fn insert_line_at(&mut self, line: HunkLine, index: usize) {
85        self.lines.insert(self.real_index(index), line);
86    }
87
88    pub fn real_index(&self, index: usize) -> usize {
89        self.lines
90            .iter()
91            .enumerate()
92            .filter(|(_, l)| l.is_removed() || l.is_context())
93            .nth(index)
94            .map_or_else(|| self.lines.len(), |(i, _)| i)
95    }
96
97    pub fn matches(&self, line: &str, index: usize, log: bool) -> bool {
98        let expected = self
99            .matchable_lines()
100            .skip(index)
101            .map(HunkLine::content)
102            .next();
103
104        // let outcome = expected.map(str::trim) == Some(line.trim());
105        let outcome = expected == Some(line);
106
107        if log {
108            if outcome {
109                // Calculate mismatching leading whitespace
110                tracing::trace!(line, expected, "Matched line");
111            } else {
112                tracing::trace!(line, expected, "Did not match line");
113            }
114        }
115        outcome
116    }
117
118    pub fn render_updated(&self) -> Result<String> {
119        // Extract any context after the second @@ block to add to the new header line
120        // i.e. with `@@ -1,2 +2,1 @@ my_function()` we want my_function() to be included
121        let header_context = self
122            .body
123            .lines()
124            .next()
125            .unwrap_or_default()
126            .rsplit("@@")
127            .next()
128            .unwrap_or_default();
129
130        let source = self
131            .header
132            .fixed_source
133            .as_ref()
134            .context("Expected updated source")?;
135        let dest = self
136            .header
137            .fixed_dest
138            .as_ref()
139            .context("Expected updated dest")?;
140
141        let mut updated = format!(
142            "@@ -{},{} +{},{} @@{header_context}\n",
143            source.start + 1,
144            source.range,
145            dest.start + 1,
146            dest.range
147        );
148
149        for line in &self.lines {
150            updated.push_str(&line.as_patch_line());
151            updated.push('\n');
152        }
153
154        Ok(updated.to_string())
155    }
156}
157
158/// A hunk that is found in a file
159#[derive(Clone, Debug)]
160pub struct Candidate<'a> {
161    /// The line number in the file we started at
162    start: usize,
163
164    /// The current line we are matching against
165    current_line: usize,
166
167    hunk: Cow<'a, Hunk>,
168}
169
170impl<'a> Candidate<'a> {
171    pub fn new(line: usize, hunk: impl Into<Cow<'a, Hunk>>) -> Self {
172        Self {
173            start: line,
174            current_line: 0,
175            hunk: hunk.into(),
176        }
177    }
178
179    /// Number difference in visible lines between the source and destination for the next hunk
180    ///
181    /// If lines were added, the following hunk will start at an increased line number, if lines
182    /// were removed, the following hunk will start at a decreased line number
183    #[allow(clippy::cast_possible_wrap)]
184    pub fn offset(&self) -> isize {
185        self.hunk.lines.iter().filter(|l| l.is_added()).count() as isize
186            - self.hunk.lines.iter().filter(|l| l.is_removed()).count() as isize
187    }
188
189    pub fn next_line_matches(&self, line: &str) -> bool {
190        self.hunk.matches(line, self.current_line, true)
191    }
192
193    pub fn is_complete(&self) -> bool {
194        // We increment one over the current line, so if we are at the end of the hunk, we are done
195        self.current_line == self.hunk.matchable_lines().count()
196    }
197
198    pub fn updated_source_header(&self) -> HeaderRange {
199        let source_lines = self
200            .hunk
201            .lines
202            .iter()
203            .filter(|l| l.is_removed() || l.is_context())
204            .count();
205
206        let source_start = self.start;
207
208        HeaderRange {
209            start: source_start,
210            range: source_lines,
211        }
212    }
213
214    pub fn updated_dest_header(&self, offset: isize) -> HeaderRange {
215        let dest_lines = self
216            .hunk
217            .lines
218            .iter()
219            .filter(|l| l.is_added() || l.is_context())
220            .count();
221
222        // The offset is the sum off removed and added lines by preceding hunks
223        let dest_start = self.start.saturating_add_signed(offset);
224
225        HeaderRange {
226            start: dest_start,
227            range: dest_lines,
228        }
229    }
230}
231
232impl FromStr for Hunk {
233    type Err = anyhow::Error;
234
235    fn from_str(s: &str) -> Result<Self, Self::Err> {
236        let header: HunkHeader = s.parse()?;
237        let lines = s
238            .lines()
239            .skip(1)
240            .map(FromStr::from_str)
241            .collect::<Result<Vec<HunkLine>>>()?;
242
243        Ok(Hunk {
244            header,
245            lines,
246            body: s.into(),
247        })
248    }
249}
250
251impl FromStr for HunkLine {
252    type Err = anyhow::Error;
253
254    fn from_str(s: &str) -> Result<Self, Self::Err> {
255        if let Some(line) = s.strip_prefix('+') {
256            Ok(HunkLine::Added(line.into()))
257        } else if let Some(line) = s.strip_prefix('-') {
258            Ok(HunkLine::Removed(line.into()))
259        } else {
260            let s = s.strip_prefix(' ').unwrap_or(s);
261            Ok(HunkLine::Context(s.into()))
262        }
263    }
264}
265
266impl FromStr for HunkHeader {
267    type Err = anyhow::Error;
268
269    fn from_str(s: &str) -> Result<Self, Self::Err> {
270        if !s.starts_with("@@") {
271            anyhow::bail!("Hunk header must start with @@");
272        }
273
274        let parts: Vec<&str> = s.split_whitespace().collect();
275        if parts.len() < 4 {
276            anyhow::bail!("Invalid hunk header format");
277        }
278
279        let old_range = parts[1].split(',').collect::<Vec<&str>>();
280        let new_range = parts[2].split(',').collect::<Vec<&str>>();
281
282        if old_range.len() != 2 || new_range.len() != 2 {
283            anyhow::bail!("Invalid range format in hunk header");
284        }
285
286        let old_lines = HeaderRange {
287            start: old_range[0]
288                .replace('-', "")
289                .parse()
290                .context("Invalid old start line")?,
291            range: old_range[1].parse().context("Invalid old range")?,
292        };
293
294        let new_lines = HeaderRange {
295            start: new_range[0]
296                .replace('+', "")
297                .parse()
298                .context("Invalid new start line")?,
299            range: new_range[1].parse().context("Invalid new range")?,
300        };
301
302        Ok(HunkHeader {
303            source: old_lines,
304            dest: new_lines,
305            fixed_source: None,
306            fixed_dest: None,
307        })
308    }
309}
310
311/// Parses the hunks from a patch
312pub fn parse_hunks(patch: &str) -> Result<Vec<Hunk>> {
313    let mut hunks = Vec::new();
314    let mut current_hunk_lines = Vec::new();
315
316    for line in patch.lines() {
317        if line.starts_with("@@") {
318            if !current_hunk_lines.is_empty() {
319                let hunk = Hunk::from_str(&current_hunk_lines.join("\n"))?;
320                hunks.push(hunk);
321            }
322
323            current_hunk_lines = vec![line];
324        } else if !current_hunk_lines.is_empty() {
325            current_hunk_lines.push(line);
326        }
327    }
328
329    if !current_hunk_lines.is_empty() {
330        let hunk = Hunk::from_str(&current_hunk_lines.join("\n"))?;
331        hunks.push(hunk);
332    }
333
334    Ok(hunks)
335}
336
337/// For each hunks, finds potential candidates in the file
338///
339/// llms are dumb and cannot count
340///
341/// However, with a patch we can reasonably fix the headers
342/// by searching in the neighboring lines of the original hunk header
343pub fn find_candidates<'a>(content: &str, hunks: &'a [Hunk]) -> Vec<Candidate<'a>> {
344    let mut candidates = Vec::new();
345
346    for (line_n, line) in content.lines().enumerate() {
347        // 1. Check if a hunk matches the line, then create a candidate if it does
348        if let Some(hunk) = hunks.iter().find(|h| h.matches(line, 0, false)) {
349            tracing::trace!(line, "Found hunk match; creating new candidate");
350            candidates.push(Candidate::new(line_n, hunk));
351        }
352
353        // 2. For each active candidate, check if the next line matches. If it does, increment the
354        // the index of the candidate. Otherwise, remove the candidate
355        let mut new_candidates = Vec::new();
356        candidates.retain_mut(|c| {
357            if c.is_complete() {
358                true
359            } else if c.next_line_matches(line) {
360                tracing::trace!(line, "Candidate matched line");
361                c.current_line += 1;
362                true
363            } else if line.trim().is_empty() {
364                tracing::trace!(line, "Current line is empty; keeping candidate around");
365                // We create a new candidate with a whitespace line added at the index of this
366                // candidate. This helps with LLMs misjudging whitespace in the context
367                let mut new_hunk: Hunk = c.hunk.clone().into_owned();
368                new_hunk.insert_line_at(HunkLine::Context(line.into()), c.current_line);
369                let mut new_candidate = Candidate::new(c.start, new_hunk);
370                new_candidate.current_line = c.current_line + 1;
371
372                new_candidates.push(new_candidate);
373                false
374            } else if c
375                .hunk
376                .lines.iter()
377                .skip(c.hunk.real_index(c.current_line + 1))
378                .all(HunkLine::is_context)
379            {
380                // If the following remaining lines, including this one, are context only, accept
381                // the current AI overlords incompetence and add a finished candidate without the
382                // remaining lines.
383                tracing::trace!(line, "Mismatch; remaining is context only, adding finished candidate without the remaining lines");
384                let real_index = c.hunk.real_index(c.current_line);
385                let mut new_hunk = c.hunk.clone().into_owned();
386                new_hunk.lines = new_hunk
387                    .lines
388                    .iter()
389                    .take(real_index)
390                    .cloned()
391                    .collect();
392
393                let mut new_candidate = Candidate::new(c.start, new_hunk);
394                new_candidate.current_line = c.current_line;
395                new_candidates.push(new_candidate);
396                false
397            } else {
398                tracing::trace!(line, "Removing candidate");
399                false
400            }
401        });
402        candidates.append(&mut new_candidates);
403    }
404
405    candidates
406}
407
408/// Takes a list of candidates and rebuits the hunk headers
409///
410/// Filters out duplicates. The resulting hunks should result in a valid patch.
411pub fn rebuild_hunks(candidates: &[Candidate<'_>]) -> Vec<Hunk> {
412    // Assume that the candidates are sorted by the start line
413    // Then we can just iterate over the candidates and update the ranges
414
415    let mut current_offset: isize = 0;
416    let mut hunks: Vec<Hunk> = Vec::new();
417
418    for candidate in candidates {
419        let source_header = candidate.updated_source_header();
420
421        let dest_header = candidate.updated_dest_header(current_offset);
422        current_offset += candidate.offset();
423
424        // Could probably continue the cow, but at this point the number of hunks should be small
425        let mut hunk = candidate.hunk.clone().into_owned();
426        hunk.header.fixed_source = Some(source_header);
427        hunk.header.fixed_dest = Some(dest_header);
428
429        // Filter duplicates. A hunk is a duplicate if the hunk body is the same. If a duplicate
430        // is detected, prefer the one with the fixed_source closest to the original source line
431        // If so, we swap it with the existing hunk.
432
433        if let Some(existing) = hunks.iter_mut().find(|h| *h.body == hunk.body) {
434            let (Some(existing_source), Some(new_source)) =
435                (&existing.header.fixed_source, &hunk.header.fixed_source)
436            else {
437                tracing::warn!("Potential bad duplicate when rebuilding patch; could be a bug, please check the edit");
438                continue;
439            };
440
441            #[allow(clippy::cast_possible_wrap)]
442            if ((existing_source.start as isize)
443                .saturating_sub_unsigned(existing.header.source.start))
444            .abs()
445                < ((new_source.start as isize).saturating_sub_unsigned(hunk.header.source.start))
446                    .abs()
447            {
448                continue;
449            }
450            *existing = hunk;
451        } else {
452            hunks.push(hunk);
453        }
454    }
455
456    hunks
457}
458
459/// Takes the file lines from the original patch if possible, then rebuilds the patch
460pub fn rebuild_patch(original: &str, hunks: &[Hunk]) -> Result<String> {
461    let mut new_patch = original.lines().take(2).collect::<Vec<_>>().join("\n");
462    new_patch.push('\n');
463
464    debug_assert!(
465        !new_patch.is_empty(),
466        "Original file lines in patch tools are empty"
467    );
468
469    for hunk in hunks {
470        new_patch.push_str(&hunk.render_updated()?);
471    }
472
473    Ok(new_patch)
474}