Skip to main content

aft/
fuzzy_match.rs

1//! Fuzzy string matching for edit_match, inspired by opencode's 4-pass approach.
2//!
3//! When exact matching fails, progressively relaxes comparison:
4//!   Pass 1: Exact match (str::find / match_indices)
5//!   Pass 2: Trim trailing whitespace per line
6//!   Pass 3: Trim both ends per line
7//!   Pass 4: Normalize Unicode punctuation + trim
8
/// A match result: byte offset in source and the matched byte length.
///
/// Equality compares all three fields, so the same byte span reported by two
/// different passes does not compare equal.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FuzzyMatch {
    /// Byte offset into the haystack at which the match begins.
    pub byte_start: usize,
    /// Number of haystack bytes covered by the match.
    pub byte_len: usize,
    /// Which pass found the match (1=exact, 2=rstrip, 3=trim, 4=unicode)
    pub pass: u8,
}
17
18/// Find all occurrences of `needle` in `haystack` using progressive fuzzy matching.
19/// Returns matches in order of their byte position in the source.
20pub fn find_all_fuzzy(haystack: &str, needle: &str) -> Vec<FuzzyMatch> {
21    // Pass 1: exact match (fast path)
22    let exact: Vec<FuzzyMatch> = haystack
23        .match_indices(needle)
24        .map(|(idx, _)| FuzzyMatch {
25            byte_start: idx,
26            byte_len: needle.len(),
27            pass: 1,
28        })
29        .collect();
30
31    if !exact.is_empty() {
32        return exact;
33    }
34
35    // For fuzzy passes, work line-by-line
36    let needle_lines: Vec<&str> = needle.lines().collect();
37    if needle_lines.is_empty() {
38        return vec![];
39    }
40
41    let haystack_lines: Vec<&str> = haystack.lines().collect();
42    let line_byte_offsets = compute_line_offsets(haystack);
43
44    // Pass 2: rstrip (trim trailing whitespace)
45    let rstrip_matches = find_line_matches(
46        &haystack_lines,
47        &needle_lines,
48        &line_byte_offsets,
49        haystack,
50        |a, b| a.trim_end() == b.trim_end(),
51        2,
52    );
53    if !rstrip_matches.is_empty() {
54        return rstrip_matches;
55    }
56
57    // Pass 3: trim (both ends)
58    let trim_matches = find_line_matches(
59        &haystack_lines,
60        &needle_lines,
61        &line_byte_offsets,
62        haystack,
63        |a, b| a.trim() == b.trim(),
64        3,
65    );
66    if !trim_matches.is_empty() {
67        return trim_matches;
68    }
69
70    // Pass 4: normalized Unicode + trim
71    let normalized_matches = find_line_matches(
72        &haystack_lines,
73        &needle_lines,
74        &line_byte_offsets,
75        haystack,
76        |a, b| normalize_unicode(a.trim()) == normalize_unicode(b.trim()),
77        4,
78    );
79    normalized_matches
80}
81
/// Compute byte offset of each line start in the source string.
///
/// The result always begins with `0`, followed by the offset just past every
/// `'\n'` in `source`. When `source` ends with a newline the final entry equals
/// `source.len()`, i.e. it marks the start of an empty trailing line.
fn compute_line_offsets(source: &str) -> Vec<usize> {
    // '\n' is a single byte in UTF-8, so the next line starts at `idx + 1`, which
    // can never exceed `source.len()`.
    std::iter::once(0)
        .chain(source.match_indices('\n').map(|(idx, _)| idx + 1))
        .collect()
}
92
93/// Find all positions where `needle_lines` matches a contiguous sequence in `haystack_lines`.
94fn find_line_matches<F>(
95    haystack_lines: &[&str],
96    needle_lines: &[&str],
97    line_offsets: &[usize],
98    haystack: &str,
99    compare: F,
100    pass: u8,
101) -> Vec<FuzzyMatch>
102where
103    F: Fn(&str, &str) -> bool,
104{
105    let mut matches = Vec::new();
106    if needle_lines.len() > haystack_lines.len() {
107        return matches;
108    }
109
110    'outer: for i in 0..=(haystack_lines.len() - needle_lines.len()) {
111        for j in 0..needle_lines.len() {
112            if !compare(haystack_lines[i + j], needle_lines[j]) {
113                continue 'outer;
114            }
115        }
116        // Found a match at line `i` spanning `needle_lines.len()` lines
117        let byte_start = line_offsets[i];
118        let end_line = i + needle_lines.len();
119        let byte_end = if end_line < line_offsets.len() {
120            // Include the newline after the last matched line
121            line_offsets[end_line]
122        } else {
123            haystack.len()
124        };
125        matches.push(FuzzyMatch {
126            byte_start,
127            byte_len: byte_end - byte_start,
128            pass,
129        });
130    }
131
132    matches
133}
134
/// Normalize Unicode punctuation to ASCII equivalents.
///
/// Single quotes U+2018..=U+201B become `'`, double quotes U+201C..=U+201F become
/// `"`, hyphens and dashes U+2010..=U+2015 become `-`, the no-break space U+00A0
/// becomes a plain space, and the ellipsis U+2026 expands to three dots. Every
/// other character is copied through unchanged.
fn normalize_unicode(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\u{2018}'..='\u{201B}' => out.push('\''),
            '\u{201C}'..='\u{201F}' => out.push('"'),
            '\u{2010}'..='\u{2015}' => out.push('-'),
            '\u{00A0}' => out.push(' '),
            // The only replacement that changes the character count.
            '\u{2026}' => out.push_str("..."),
            other => out.push(other),
        }
    }
    out
}
148
#[cfg(test)]
mod tests {
    use super::*;

    /// A single exact hit reports the needle's own offset and length.
    #[test]
    fn test_exact_match() {
        let matches = find_all_fuzzy("hello world", "world");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].byte_len, 5);
        assert_eq!(matches[0].pass, 1);
    }

    /// Repeated exact hits come back in source order.
    #[test]
    fn test_exact_match_multiple() {
        let matches = find_all_fuzzy("foo bar foo baz foo", "foo");
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[1].byte_start, 8);
        assert_eq!(matches[2].byte_start, 16);
        assert!(matches.iter().all(|m| m.byte_len == 3 && m.pass == 1));
    }

    /// Trailing whitespace differences are absorbed by pass 2; the span covers
    /// both full lines including the final newline.
    #[test]
    fn test_rstrip_match() {
        let source = "  hello  \n  world  \n";
        let needle = "  hello\n  world";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2); // rstrip pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    /// Indentation differences are absorbed by pass 3.
    #[test]
    fn test_trim_match() {
        let source = "    function foo() {\n      return 1;\n    }\n";
        let needle = "function foo() {\n  return 1;\n}";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 3); // trim pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    /// Curly quotes in the source match straight quotes in the needle via pass 4.
    #[test]
    fn test_unicode_normalize() {
        let source = "let msg = \u{201C}hello\u{201D}\n";
        let needle = "let msg = \"hello\"";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 4); // unicode pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_no_match() {
        let matches = find_all_fuzzy("hello world", "xyz");
        assert!(matches.is_empty());
    }

    /// A multi-line needle found verbatim stays on the exact pass.
    #[test]
    fn test_multiline_exact() {
        let source = "line1\nline2\nline3\nline4\n";
        let needle = "line2\nline3";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].byte_len, needle.len());
        assert_eq!(matches[0].pass, 1);
    }

    /// A fuzzy hit in the middle of the file starts at its first line's offset
    /// and extends through the newline after its last line.
    #[test]
    fn test_fuzzy_match_mid_file_offsets() {
        let source = "first\n  second  \nthird\n";
        let needle = "  second\nthird";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].byte_len, 17);
        assert_eq!(
            &source[matches[0].byte_start..matches[0].byte_start + matches[0].byte_len],
            "  second  \nthird\n"
        );
    }

    /// When the matched line is the last one and has no trailing newline, the
    /// span ends at the end of the source.
    #[test]
    fn test_fuzzy_match_last_line_without_newline() {
        let source = "x\n  b";
        let needle = "b  ";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 3);
        assert_eq!(matches[0].byte_start, 2);
        assert_eq!(matches[0].byte_len, 3);
    }
}