Skip to main content

aft/
fuzzy_match.rs

//! Fuzzy string matching for edit_match, inspired by opencode's 4-pass approach.
//!
//! When exact matching fails, the comparison is progressively relaxed:
//!   Pass 1: Exact match (`str::match_indices`)
//!   Pass 2: Trim trailing whitespace per line
//!   Pass 3: Trim both ends per line
//!   Pass 4: Normalize Unicode punctuation + trim
8
/// A match result: byte offset in source and the matched byte length.
///
/// `PartialEq`/`Eq` are derived so that results can be compared directly
/// (for example with `assert_eq!`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct FuzzyMatch {
    /// Byte offset into the source at which the match begins.
    pub byte_start: usize,
    /// Number of source bytes covered by the match.
    ///
    /// For an exact match this equals the needle length. The line-based fuzzy
    /// passes cover whole lines, including the newline that ends the last
    /// matched line when there is one.
    pub byte_len: usize,
    /// Which pass found the match (1=exact, 2=rstrip, 3=trim, 4=unicode)
    pub pass: u8,
}
17
18/// Find all occurrences of `needle` in `haystack` using progressive fuzzy matching.
19/// Returns matches in order of their byte position in the source.
20pub fn find_all_fuzzy(haystack: &str, needle: &str) -> Vec<FuzzyMatch> {
21    // Pass 1: exact match (fast path)
22    let exact: Vec<FuzzyMatch> = haystack
23        .match_indices(needle)
24        .map(|(idx, _)| FuzzyMatch {
25            byte_start: idx,
26            byte_len: needle.len(),
27            pass: 1,
28        })
29        .collect();
30
31    if !exact.is_empty() {
32        return exact;
33    }
34
35    // For fuzzy passes, work line-by-line
36    let needle_lines: Vec<&str> = needle.lines().collect();
37    if needle_lines.is_empty() {
38        return vec![];
39    }
40
41    let haystack_lines: Vec<&str> = haystack.lines().collect();
42    let line_byte_offsets = compute_line_offsets(haystack);
43
44    // Pass 2: rstrip (trim trailing whitespace)
45    let rstrip_matches = find_line_matches(
46        &haystack_lines,
47        &needle_lines,
48        &line_byte_offsets,
49        haystack,
50        |a, b| a.trim_end() == b.trim_end(),
51        2,
52    );
53    if !rstrip_matches.is_empty() {
54        return rstrip_matches;
55    }
56
57    // Pass 3: trim (both ends)
58    let trim_matches = find_line_matches(
59        &haystack_lines,
60        &needle_lines,
61        &line_byte_offsets,
62        haystack,
63        |a, b| a.trim() == b.trim(),
64        3,
65    );
66    if !trim_matches.is_empty() {
67        return trim_matches;
68    }
69
70    // Pass 4: normalized Unicode + trim. Normalize each line once instead of
71    // allocating inside the O(haystack_lines × needle_lines) comparison loop.
72    let normalized_haystack_lines: Vec<String> = haystack_lines
73        .iter()
74        .map(|line| normalize_unicode(line.trim()))
75        .collect();
76    let normalized_needle_lines: Vec<String> = needle_lines
77        .iter()
78        .map(|line| normalize_unicode(line.trim()))
79        .collect();
80    let normalized_haystack_refs: Vec<&str> = normalized_haystack_lines
81        .iter()
82        .map(String::as_str)
83        .collect();
84    let normalized_needle_refs: Vec<&str> =
85        normalized_needle_lines.iter().map(String::as_str).collect();
86    let normalized_matches = find_line_matches(
87        &normalized_haystack_refs,
88        &normalized_needle_refs,
89        &line_byte_offsets,
90        haystack,
91        |a, b| a == b,
92        4,
93    );
94    normalized_matches
95}
96
/// Compute byte offset of each line start in the source string.
///
/// The first entry is always `0`; every `'\n'` contributes the offset of the
/// byte that follows it. A trailing newline therefore adds a final entry equal
/// to `source.len()`.
fn compute_line_offsets(source: &str) -> Vec<usize> {
    std::iter::once(0)
        .chain(source.match_indices('\n').map(|(newline_at, _)| newline_at + 1))
        .collect()
}
107
108/// Find all positions where `needle_lines` matches a contiguous sequence in `haystack_lines`.
109fn find_line_matches<F>(
110    haystack_lines: &[&str],
111    needle_lines: &[&str],
112    line_offsets: &[usize],
113    haystack: &str,
114    compare: F,
115    pass: u8,
116) -> Vec<FuzzyMatch>
117where
118    F: Fn(&str, &str) -> bool,
119{
120    let mut matches = Vec::new();
121    if needle_lines.len() > haystack_lines.len() {
122        return matches;
123    }
124
125    'outer: for i in 0..=(haystack_lines.len() - needle_lines.len()) {
126        for j in 0..needle_lines.len() {
127            if !compare(haystack_lines[i + j], needle_lines[j]) {
128                continue 'outer;
129            }
130        }
131        // Found a match at line `i` spanning `needle_lines.len()` lines
132        let byte_start = line_offsets[i];
133        let end_line = i + needle_lines.len();
134        let byte_end = if end_line < line_offsets.len() {
135            // Include the newline after the last matched line
136            line_offsets[end_line]
137        } else {
138            haystack.len()
139        };
140        matches.push(FuzzyMatch {
141            byte_start,
142            byte_len: byte_end - byte_start,
143            pass,
144        });
145    }
146
147    matches
148}
149
/// Normalize Unicode punctuation to ASCII equivalents.
///
/// Single quotes U+2018..=U+201B become `'`, double quotes U+201C..=U+201F
/// become `"`, hyphens and dashes U+2010..=U+2015 become `-`, the no-break
/// space U+00A0 becomes a plain space, and the ellipsis U+2026 expands to
/// `...`. All other characters are copied through unchanged.
fn normalize_unicode(s: &str) -> String {
    let mut ascii = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '\u{2018}'..='\u{201B}' => ascii.push('\''),
            '\u{201C}'..='\u{201F}' => ascii.push('"'),
            '\u{2010}'..='\u{2015}' => ascii.push('-'),
            '\u{00A0}' => ascii.push(' '),
            '\u{2026}' => ascii.push_str("..."),
            other => ascii.push(other),
        }
    }
    ascii
}
163
#[cfg(test)]
mod tests {
    use super::*;

    // Exact pass: the span is exactly the needle.
    #[test]
    fn test_exact_match() {
        let matches = find_all_fuzzy("hello world", "world");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].byte_len, 5);
        assert_eq!(matches[0].pass, 1);
    }

    #[test]
    fn test_exact_match_multiple() {
        let matches = find_all_fuzzy("foo bar foo baz foo", "foo");
        assert_eq!(matches.len(), 3);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[1].byte_start, 8);
        assert_eq!(matches[2].byte_start, 16);
        assert!(matches.iter().all(|m| m.byte_len == 3 && m.pass == 1));
    }

    // Fuzzy passes report whole lines, including the newline after the last one.
    #[test]
    fn test_rstrip_match() {
        let source = "  hello  \n  world  \n";
        let needle = "  hello\n  world";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2); // rstrip pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_rstrip_match_multiple_in_order() {
        let source = "x\ny\nx\n";
        let needle = "x  ";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 2);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, 2);
        assert_eq!(matches[1].byte_start, 4);
        assert_eq!(matches[1].byte_len, 2);
        assert!(matches.iter().all(|m| m.pass == 2));
    }

    // Without a trailing newline the span runs to the end of the haystack.
    #[test]
    fn test_rstrip_match_without_trailing_newline() {
        let source = "a  \nb  ";
        let matches = find_all_fuzzy(source, "a\nb");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    // CRLF line endings in the haystack still match an LF-only needle.
    #[test]
    fn test_crlf_haystack_matches_lf_needle() {
        let source = "foo  \r\nbar\r\n";
        let matches = find_all_fuzzy(source, "foo\nbar");
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 2);
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_trim_match() {
        let source = "    function foo() {\n      return 1;\n    }\n";
        let needle = "function foo() {\n  return 1;\n}";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 3); // trim pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_unicode_normalize() {
        let source = "let msg = \u{201C}hello\u{201D}\n";
        let needle = "let msg = \"hello\"";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 4); // unicode pass
        assert_eq!(matches[0].byte_start, 0);
        assert_eq!(matches[0].byte_len, source.len());
    }

    #[test]
    fn test_unicode_normalize_multiline_variants() {
        let source = "alpha\n  let title = \u{201C}hello\u{201D}\u{2026}\n  let slug = foo\u{2014}bar\u{00A0}baz\nomega\n";
        let needle = "let title = \"hello\"...\nlet slug = foo-bar baz";
        let matches = find_all_fuzzy(source, needle);

        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].pass, 4);

        let expected_start = source.find("  let title").unwrap();
        let expected_end = source.find("omega").unwrap();
        assert_eq!(matches[0].byte_start, expected_start);
        assert_eq!(matches[0].byte_len, expected_end - expected_start);
    }

    #[test]
    fn test_no_match() {
        let matches = find_all_fuzzy("hello world", "xyz");
        assert!(matches.is_empty());
    }

    // A needle with more lines than the haystack cannot match in any pass.
    #[test]
    fn test_needle_longer_than_haystack() {
        let matches = find_all_fuzzy("a", "a\nb\nc");
        assert!(matches.is_empty());
    }

    #[test]
    fn test_multiline_exact() {
        let source = "line1\nline2\nline3\nline4\n";
        let needle = "line2\nline3";
        let matches = find_all_fuzzy(source, needle);
        assert_eq!(matches.len(), 1);
        assert_eq!(matches[0].byte_start, 6);
        assert_eq!(matches[0].byte_len, needle.len());
        assert_eq!(matches[0].pass, 1);
    }
}