Skip to main content

aft/
fuzzy_match.rs

//! Fuzzy string matching for edit_match, inspired by opencode's 4-pass approach
//! and extended here with a fifth pass for reflowed lines.
//!
//! When exact matching fails, the comparison is relaxed one pass at a time; the
//! first pass that finds at least one match wins and later passes are skipped:
//!   Pass 1: Exact match (str::match_indices)
//!   Pass 2: Trim trailing whitespace per line
//!   Pass 3: Trim both ends per line
//!   Pass 4: Normalize Unicode punctuation + trim
//!   Pass 5: Reflowed line wraps/joins with whitespace-normalized content
9
/// A single match of a needle inside a haystack.
///
/// The matched span is `byte_start..byte_start + byte_len` in the haystack.
/// For pass 1 the span is exactly the needle text. For passes 2–5 it covers
/// whole haystack lines, including the newline that terminates the last matched
/// line when the haystack has one there.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FuzzyMatch {
    /// Byte offset of the first matched byte in the haystack.
    pub byte_start: usize,
    /// Length in bytes of the matched span.
    pub byte_len: usize,
    /// Which pass found the match (1=exact, 2=rstrip, 3=trim, 4=unicode, 5=reflow)
    pub pass: u8,
}
18
19/// Find all occurrences of `needle` in `haystack` using progressive fuzzy matching.
20/// Returns matches in order of their byte position in the source.
21pub fn find_all_fuzzy(haystack: &str, needle: &str) -> Vec<FuzzyMatch> {
22    // Pass 1: exact match (fast path)
23    let exact: Vec<FuzzyMatch> = haystack
24        .match_indices(needle)
25        .map(|(idx, _)| FuzzyMatch {
26            byte_start: idx,
27            byte_len: needle.len(),
28            pass: 1,
29        })
30        .collect();
31
32    if !exact.is_empty() {
33        return exact;
34    }
35
36    // For fuzzy passes, work line-by-line
37    let needle_lines: Vec<&str> = needle.lines().collect();
38    if needle_lines.is_empty() {
39        return vec![];
40    }
41
42    let haystack_lines: Vec<&str> = haystack.lines().collect();
43    let line_byte_offsets = compute_line_offsets(haystack);
44
45    // Pass 2: rstrip (trim trailing whitespace)
46    let rstrip_matches = find_line_matches(
47        &haystack_lines,
48        &needle_lines,
49        &line_byte_offsets,
50        haystack,
51        |a, b| a.trim_end() == b.trim_end(),
52        2,
53    );
54    if !rstrip_matches.is_empty() {
55        return rstrip_matches;
56    }
57
58    // Pass 3: trim (both ends)
59    let trim_matches = find_line_matches(
60        &haystack_lines,
61        &needle_lines,
62        &line_byte_offsets,
63        haystack,
64        |a, b| a.trim() == b.trim(),
65        3,
66    );
67    if !trim_matches.is_empty() {
68        return trim_matches;
69    }
70
71    // Pass 4: normalized Unicode + trim. Normalize each line once instead of
72    // allocating inside the O(haystack_lines × needle_lines) comparison loop.
73    let normalized_haystack_lines: Vec<String> = haystack_lines
74        .iter()
75        .map(|line| normalize_unicode(line.trim()))
76        .collect();
77    let normalized_needle_lines: Vec<String> = needle_lines
78        .iter()
79        .map(|line| normalize_unicode(line.trim()))
80        .collect();
81    let normalized_haystack_refs: Vec<&str> = normalized_haystack_lines
82        .iter()
83        .map(String::as_str)
84        .collect();
85    let normalized_needle_refs: Vec<&str> =
86        normalized_needle_lines.iter().map(String::as_str).collect();
87    let normalized_matches = find_line_matches(
88        &normalized_haystack_refs,
89        &normalized_needle_refs,
90        &line_byte_offsets,
91        haystack,
92        |a, b| a == b,
93        4,
94    );
95    if !normalized_matches.is_empty() {
96        return normalized_matches;
97    }
98
99    // Pass 5: final fallback for formatter reflows. This pass deliberately
100    // runs only after every line-contiguous pass fails, and each candidate
101    // window must have the same non-whitespace content as the needle.
102    find_reflow_matches(&haystack_lines, &needle_lines, &line_byte_offsets, haystack)
103}
104
/// Compute the byte offset at which each line of `source` starts.
///
/// The first entry is always `0`. Every `'\n'` contributes the offset of the
/// byte right after it, so a trailing newline yields a final entry equal to
/// `source.len()`. Callers use that entry as the end of the last line.
fn compute_line_offsets(source: &str) -> Vec<usize> {
    // `'\n'` is a single ASCII byte and can never occur inside a multi-byte
    // UTF-8 sequence, so searching for it yields exact byte positions.
    std::iter::once(0)
        .chain(source.match_indices('\n').map(|(newline, _)| newline + 1))
        .collect()
}
115
116/// Find all positions where `needle_lines` matches a contiguous sequence in `haystack_lines`.
117fn find_line_matches<F>(
118    haystack_lines: &[&str],
119    needle_lines: &[&str],
120    line_offsets: &[usize],
121    haystack: &str,
122    compare: F,
123    pass: u8,
124) -> Vec<FuzzyMatch>
125where
126    F: Fn(&str, &str) -> bool,
127{
128    let mut matches = Vec::new();
129    if needle_lines.len() > haystack_lines.len() {
130        return matches;
131    }
132
133    'outer: for i in 0..=(haystack_lines.len() - needle_lines.len()) {
134        for j in 0..needle_lines.len() {
135            if !compare(haystack_lines[i + j], needle_lines[j]) {
136                continue 'outer;
137            }
138        }
139        // Found a match at line `i` spanning `needle_lines.len()` lines
140        let byte_start = line_offsets[i];
141        let end_line = i + needle_lines.len();
142        let byte_end = if end_line < line_offsets.len() {
143            // Include the newline after the last matched line
144            line_offsets[end_line]
145        } else {
146            haystack.len()
147        };
148        matches.push(FuzzyMatch {
149            byte_start,
150            byte_len: byte_end - byte_start,
151            pass,
152        });
153    }
154
155    matches
156}
157
158const REFLOW_NON_WS_TOLERANCE: usize = 8;
159
160fn find_reflow_matches(
161    haystack_lines: &[&str],
162    needle_lines: &[&str],
163    line_offsets: &[usize],
164    haystack: &str,
165) -> Vec<FuzzyMatch> {
166    let needle_text = needle_lines.join("\n");
167    let normalized_needle = normalize_reflow_whitespace(&needle_text);
168    let needle_non_whitespace = strip_reflow_whitespace(&needle_text);
169    if normalized_needle.is_empty() || needle_non_whitespace.is_empty() {
170        return Vec::new();
171    }
172
173    let min_non_whitespace = needle_non_whitespace
174        .len()
175        .saturating_sub(REFLOW_NON_WS_TOLERANCE);
176    let max_non_whitespace = needle_non_whitespace.len() + REFLOW_NON_WS_TOLERANCE;
177    let line_non_whitespace_lens: Vec<usize> = haystack_lines
178        .iter()
179        .map(|line| strip_reflow_whitespace(line).len())
180        .collect();
181    let mut matches = Vec::new();
182
183    for start in 0..haystack_lines.len() {
184        if !has_reflow_content(haystack_lines[start]) {
185            continue;
186        }
187
188        let mut window_non_whitespace_len = 0usize;
189        for end in (start + 1)..=haystack_lines.len() {
190            let line = haystack_lines[end - 1];
191            window_non_whitespace_len += line_non_whitespace_lens[end - 1];
192
193            if window_non_whitespace_len > max_non_whitespace {
194                break;
195            }
196            if window_non_whitespace_len < min_non_whitespace {
197                continue;
198            }
199            if !has_reflow_content(line) {
200                continue;
201            }
202
203            let window_text = haystack_lines[start..end].join("\n");
204            let window_non_whitespace = strip_reflow_whitespace(&window_text);
205            if window_non_whitespace != needle_non_whitespace {
206                continue;
207            }
208            if normalize_reflow_whitespace(&window_text) != normalized_needle {
209                continue;
210            }
211
212            let byte_start = line_offsets[start];
213            let byte_end = if end < line_offsets.len() {
214                line_offsets[end]
215            } else {
216                haystack.len()
217            };
218            matches.push(FuzzyMatch {
219                byte_start,
220                byte_len: byte_end - byte_start,
221                pass: 5,
222            });
223        }
224    }
225
226    matches
227}
228
/// Trim `s` and collapse every internal whitespace run to a single ASCII space.
///
/// Whitespace is whatever `char::is_whitespace` accepts, so newlines, tabs and
/// no-break spaces all collapse the same way.
fn normalize_reflow_whitespace(s: &str) -> String {
    let mut collapsed = String::with_capacity(s.len());
    for word in s.split_whitespace() {
        // Words are never empty, so `collapsed` is empty only before the first one.
        if !collapsed.is_empty() {
            collapsed.push(' ');
        }
        collapsed.push_str(word);
    }
    collapsed
}
247
/// Remove every whitespace character from `s`, keeping all other characters in order.
fn strip_reflow_whitespace(s: &str) -> String {
    s.split_whitespace().collect()
}
251
/// Return `true` when `s` contains at least one non-whitespace character.
fn has_reflow_content(s: &str) -> bool {
    !s.trim_start().is_empty()
}
255
/// Map typographic punctuation to ASCII look-alikes so that smart quotes,
/// dashes, no-break spaces and ellipses compare equal to their plain forms.
///
/// U+2018–U+201B become `'`, U+201C–U+201F become `"`, U+2010–U+2015 become
/// `-`, U+00A0 becomes a space, and U+2026 expands to `...`. Every other
/// character is copied unchanged.
fn normalize_unicode(s: &str) -> String {
    let mut ascii = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\u{2018}'..='\u{201B}' => ascii.push('\''),
            '\u{201C}'..='\u{201F}' => ascii.push('"'),
            '\u{2010}'..='\u{2015}' => ascii.push('-'),
            '\u{00A0}' => ascii.push(' '),
            '\u{2026}' => ascii.push_str("..."),
            other => ascii.push(other),
        }
    }
    ascii
}
269
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_exact_match() {
        let matches = find_all_fuzzy("hello world", "world");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].pass, 1);
    }

    #[test]
    fn test_exact_match_multiple() {
        let matches = find_all_fuzzy("foo bar foo baz foo", "foo");
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[1].byte_start, 8);
        assert_eq!(matches[2].byte_start, 16);
    }

    #[test]
    fn test_rstrip_match() {
        let source = "  hello  \n  world  \n";
        let needle = "  hello\n  world";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2); // rstrip pass
    }

    #[test]
    fn test_rstrip_match_crlf_source() {
        // `str::lines` strips the `\r` of CRLF endings, while the reported span
        // still covers the raw bytes, including the final `\r\n`.
        let source = "fn a() {  \r\n    x();\r\n}\r\n";
        let needle = "fn a() {\n    x();\n}";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_trim_match() {
        let source = "    function foo() {\n      return 1;\n    }\n";
        let needle = "function foo() {\n  return 1;\n}";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 3); // trim pass
    }

    #[test]
    fn test_unicode_normalize() {
        let source = "let msg = \u{201C}hello\u{201D}\n";
        let needle = "let msg = \"hello\"";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 4); // unicode pass
    }

    #[test]
    fn test_unicode_normalize_multiline_variants() {
        let source = "alpha\n  let title = \u{201C}hello\u{201D}\u{2026}\n  let slug = foo\u{2014}bar\u{00A0}baz\nomega\n";
        let needle = "let title = \"hello\"...\nlet slug = foo-bar baz";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 4);
        assert_eq!(matches[0].byte_start, source.find("  let title").unwrap());
    }

    #[test]
    fn test_no_match() {
        let matches = find_all_fuzzy("hello world", "xyz");
        assert!(matches.is_empty());
    }

    #[test]
    fn test_multiline_exact() {
        let source = "line1\nline2\nline3\nline4\n";
        let needle = "line2\nline3";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].pass, 1);
    }

    #[test]
    fn test_line_offsets() {
        assert_eq!(compute_line_offsets(""), vec![0]);
        assert_eq!(compute_line_offsets("a\nbc\n\nd"), vec![0, 2, 5, 6]);
        // Offsets are byte offsets, not char indices.
        assert_eq!(compute_line_offsets("\u{e9}\nb\n"), vec![0, 3, 5]);
    }

    #[test]
    fn test_whitespace_helpers() {
        assert_eq!(normalize_reflow_whitespace("  a \t b\n\nc  "), "a b c");
        assert_eq!(strip_reflow_whitespace(" a b\tc\n"), "abc");
        assert!(has_reflow_content(" x "));
        assert!(!has_reflow_content(" \t\n"));
    }

    #[test]
    fn test_reflow_one_line_needle_matches_three_line_split() {
        let source = "before\nlet total = alpha +\n    beta +\n    gamma;\nafter\n";
        let needle = "let total = alpha + beta + gamma;";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 5);
        assert_eq!(matches[0].byte_start, source.find("let total").unwrap());
        assert_eq!(
            &source[matches[0].byte_start..matches[0].byte_start + matches[0].byte_len],
            "let total = alpha +\n    beta +\n    gamma;\n"
        );
    }

    #[test]
    fn test_reflow_three_line_needle_matches_one_line_join() {
        let source = "before\nlet total = alpha + beta + gamma;\nafter\n";
        let needle = "let total = alpha +\n    beta +\n    gamma;";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 5);
        assert_eq!(matches[0].byte_start, source.find("let total").unwrap());
        assert_eq!(
            &source[matches[0].byte_start..matches[0].byte_start + matches[0].byte_len],
            "let total = alpha + beta + gamma;\n"
        );
    }

    #[test]
    fn test_reflow_excludes_surrounding_blank_lines() {
        let source = "\n\nlet x = a +\n  b;\n\n\n";
        let needle = "let x = a + b;";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 5);
        assert_eq!(
            &source[matches[0].byte_start..matches[0].byte_start + matches[0].byte_len],
            "let x = a +\n  b;\n"
        );
    }

    #[test]
    fn test_reflow_requires_matching_word_boundaries() {
        // Same non-whitespace characters, but the window splits `alpha` and
        // `beta` into separate words while the needle joins them.
        let source = "let total = alpha\n  beta;\n";
        let needle = "let total = alphabeta;";
        assert!(find_all_fuzzy(source, needle).is_empty());
    }

    #[test]
    fn test_reflow_reports_all_ambiguous_windows() {
        let source =
            "let total = alpha +\n  beta +\n  gamma;\n\nlet total = alpha +\n  beta +\n  gamma;\n";
        let needle = "let total = alpha + beta + gamma;";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 2);
        assert!(matches.iter().all(|m| m.pass == 5));
    }

    #[test]
    fn test_reflow_near_miss_does_not_match() {
        let source = "let total = alpha +\n  beta +\n  gamma;\n";
        let needle = "let total = alpha + beta + delta;";
        let matches = find_all_fuzzy(source, needle);

        assert!(matches.is_empty());
    }

    #[test]
    fn test_reflow_does_not_preempt_exact_match() {
        let source = "let total = alpha +\n  beta +\n  gamma;\nlet total = alpha + beta + gamma;\n";
        let needle = "let total = alpha + beta + gamma;";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 1);
        assert_eq!(matches[0].byte_start, source.rfind("let total").unwrap());
    }
}