fuzzy_matcher/
clangd.rs

//! The fuzzy matching algorithm used in clangd.
//! https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
//!
//! # Examples
//! ```edition2018
//! use fuzzy_matcher::FuzzyMatcher;
//! use fuzzy_matcher::clangd::ClangdMatcher;
//!
//! let matcher = ClangdMatcher::default();
//!
//! assert_eq!(None, matcher.fuzzy_match("abc", "abx"));
//! assert!(matcher.fuzzy_match("axbycz", "abc").is_some());
//! assert!(matcher.fuzzy_match("axbycz", "xyz").is_some());
//!
//! let (_score, indices) = matcher.fuzzy_indices("axbycz", "abc").unwrap();
//! assert_eq!(indices, [0, 2, 4]);
//!
//! ```
//!
//! The algorithm is a modified version of
//! https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
//! See also: https://github.com/lewang/flx/issues/98
23use crate::util::*;
24use crate::{FuzzyMatcher, IndexType, ScoreType};
25use std::cell::RefCell;
26use std::cmp::max;
27use thread_local::CachedThreadLocal;
28
/// How letter case is treated when comparing pattern and choice characters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CaseMatching {
    /// Always compare case-sensitively.
    Respect,
    /// Always compare case-insensitively.
    Ignore,
    /// Case-sensitive only when the pattern contains an (ASCII) uppercase letter.
    Smart,
}
35
36pub struct ClangdMatcher {
37    case: CaseMatching,
38
39    use_cache: bool,
40
41    c_cache: CachedThreadLocal<RefCell<Vec<char>>>, // vector to store the characters of choice
42    p_cache: CachedThreadLocal<RefCell<Vec<char>>>, // vector to store the characters of pattern
43}
44
45impl Default for ClangdMatcher {
46    fn default() -> Self {
47        Self {
48            case: CaseMatching::Ignore,
49            use_cache: true,
50            c_cache: CachedThreadLocal::new(),
51            p_cache: CachedThreadLocal::new(),
52        }
53    }
54}
55
56impl ClangdMatcher {
57    pub fn ignore_case(mut self) -> Self {
58        self.case = CaseMatching::Ignore;
59        self
60    }
61
62    pub fn smart_case(mut self) -> Self {
63        self.case = CaseMatching::Smart;
64        self
65    }
66
67    pub fn respect_case(mut self) -> Self {
68        self.case = CaseMatching::Respect;
69        self
70    }
71
72    pub fn use_cache(mut self, use_cache: bool) -> Self {
73        self.use_cache = use_cache;
74        self
75    }
76
77    fn contains_upper(&self, string: &str) -> bool {
78        for ch in string.chars() {
79            if ch.is_ascii_uppercase() {
80                return true;
81            }
82        }
83
84        false
85    }
86
87    fn is_case_sensitive(&self, pattern: &str) -> bool {
88        match self.case {
89            CaseMatching::Respect => true,
90            CaseMatching::Ignore => false,
91            CaseMatching::Smart => self.contains_upper(pattern),
92        }
93    }
94}
95
96impl FuzzyMatcher for ClangdMatcher {
97    fn fuzzy_indices(&self, choice: &str, pattern: &str) -> Option<(ScoreType, Vec<IndexType>)> {
98        let case_sensitive = self.is_case_sensitive(pattern);
99
100        let mut choice_chars = self
101            .c_cache
102            .get_or(|| RefCell::new(Vec::new()))
103            .borrow_mut();
104        let mut pattern_chars = self
105            .p_cache
106            .get_or(|| RefCell::new(Vec::new()))
107            .borrow_mut();
108
109        choice_chars.clear();
110        for char in choice.chars() {
111            choice_chars.push(char);
112        }
113
114        pattern_chars.clear();
115        for char in pattern.chars() {
116            pattern_chars.push(char);
117        }
118
119        if cheap_matches(&choice_chars, &pattern_chars, case_sensitive).is_none() {
120            return None;
121        }
122
123        let num_pattern_chars = pattern_chars.len();
124        let num_choice_chars = choice_chars.len();
125
126        let dp = build_graph(&choice_chars, &pattern_chars, false, case_sensitive);
127
128        // search backwards for the matched indices
129        let mut indices_reverse = Vec::with_capacity(num_pattern_chars);
130        let cell = dp[num_pattern_chars][num_choice_chars];
131
132        let (mut last_action, score) = if cell.match_score > cell.miss_score {
133            (Action::Match, cell.match_score)
134        } else {
135            (Action::Miss, cell.miss_score)
136        };
137
138        let mut row = num_pattern_chars;
139        let mut col = num_choice_chars;
140
141        while row > 0 || col > 0 {
142            if last_action == Action::Match {
143                indices_reverse.push((col - 1) as IndexType);
144            }
145
146            let cell = &dp[row][col];
147            if last_action == Action::Match {
148                last_action = cell.last_action_match;
149                row -= 1;
150                col -= 1;
151            } else {
152                last_action = cell.last_action_miss;
153                col -= 1;
154            }
155        }
156
157        if !self.use_cache {
158            // drop the allocated memory
159            self.c_cache.get().map(|cell| cell.replace(vec![]));
160            self.p_cache.get().map(|cell| cell.replace(vec![]));
161        }
162
163        indices_reverse.reverse();
164        Some((adjust_score(score, num_choice_chars), indices_reverse))
165    }
166
167    fn fuzzy_match(&self, choice: &str, pattern: &str) -> Option<ScoreType> {
168        let case_sensitive = self.is_case_sensitive(pattern);
169
170        let mut choice_chars = self
171            .c_cache
172            .get_or(|| RefCell::new(Vec::new()))
173            .borrow_mut();
174        let mut pattern_chars = self
175            .p_cache
176            .get_or(|| RefCell::new(Vec::new()))
177            .borrow_mut();
178
179        choice_chars.clear();
180        for char in choice.chars() {
181            choice_chars.push(char);
182        }
183
184        pattern_chars.clear();
185        for char in pattern.chars() {
186            pattern_chars.push(char);
187        }
188
189        if cheap_matches(&choice_chars, &pattern_chars, case_sensitive).is_none() {
190            return None;
191        }
192
193        let num_pattern_chars = pattern_chars.len();
194        let num_choice_chars = choice_chars.len();
195
196        let dp = build_graph(&choice_chars, &pattern_chars, true, case_sensitive);
197
198        let cell = dp[num_pattern_chars & 1][num_choice_chars];
199        let score = max(cell.match_score, cell.miss_score);
200
201        if !self.use_cache {
202            // drop the allocated memory
203            self.c_cache.get().map(|cell| cell.replace(vec![]));
204            self.p_cache.get().map(|cell| cell.replace(vec![]));
205        }
206
207        Some(adjust_score(score, num_choice_chars))
208    }
209}
210
211/// fuzzy match `line` with `pattern`, returning the score and indices of matches
212pub fn fuzzy_indices(line: &str, pattern: &str) -> Option<(ScoreType, Vec<IndexType>)> {
213    ClangdMatcher::default()
214        .ignore_case()
215        .fuzzy_indices(line, pattern)
216}
217
218/// fuzzy match `line` with `pattern`, returning the score(the larger the better) on match
219pub fn fuzzy_match(line: &str, pattern: &str) -> Option<ScoreType> {
220    ClangdMatcher::default()
221        .ignore_case()
222        .fuzzy_match(line, pattern)
223}
224
225// checkout https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
226// for the description
227fn build_graph(
228    line: &[char],
229    pattern: &[char],
230    compressed: bool,
231    case_sensitive: bool,
232) -> Vec<Vec<Score>> {
233    let num_line_chars = line.len();
234    let num_pattern_chars = pattern.len();
235    let max_rows = if compressed { 2 } else { num_pattern_chars + 1 };
236
237    let mut dp: Vec<Vec<Score>> = Vec::with_capacity(max_rows);
238
239    for _ in 0..max_rows {
240        dp.push(vec![Score::default(); num_line_chars + 1]);
241    }
242
243    dp[0][0].miss_score = 0;
244
245    // first line
246    for (idx, &ch) in line.iter().enumerate() {
247        dp[0][idx + 1] = Score {
248            miss_score: dp[0][idx].miss_score - skip_penalty(idx, ch, Action::Miss),
249            last_action_miss: Action::Miss,
250            match_score: AWFUL_SCORE,
251            last_action_match: Action::Miss,
252        };
253    }
254
255    // build the matrix
256    let mut pat_prev_ch = '\0';
257    for (pat_idx, &pat_ch) in pattern.iter().enumerate() {
258        let current_row_idx = if compressed {
259            (pat_idx + 1) & 1
260        } else {
261            pat_idx + 1
262        };
263        let prev_row_idx = if compressed { pat_idx & 1 } else { pat_idx };
264
265        let mut line_prev_ch = '\0';
266        for (line_idx, &line_ch) in line.iter().enumerate() {
267            if line_idx < pat_idx {
268                line_prev_ch = line_ch;
269                continue;
270            }
271
272            // what if we skip current line character?
273            // we need to calculate the cases where the pre line character is matched/missed
274            let pre_miss = &dp[current_row_idx][line_idx];
275            let mut match_miss_score = pre_miss.match_score;
276            let mut miss_miss_score = pre_miss.miss_score;
277            if pat_idx < num_pattern_chars - 1 {
278                match_miss_score -= skip_penalty(line_idx, line_ch, Action::Match);
279                miss_miss_score -= skip_penalty(line_idx, line_ch, Action::Miss);
280            }
281
282            let (miss_score, last_action_miss) = if match_miss_score > miss_miss_score {
283                (match_miss_score, Action::Match)
284            } else {
285                (miss_miss_score, Action::Miss)
286            };
287
288            // what if we want to match current line character?
289            // so we need to calculate the cases where the pre pattern character is matched/missed
290            let pre_match = &dp[prev_row_idx][line_idx];
291            let match_match_score = if allow_match(pat_ch, line_ch, case_sensitive) {
292                pre_match.match_score
293                    + match_bonus(
294                        pat_idx,
295                        pat_ch,
296                        pat_prev_ch,
297                        line_idx,
298                        line_ch,
299                        line_prev_ch,
300                        Action::Match,
301                    )
302            } else {
303                AWFUL_SCORE
304            };
305
306            let miss_match_score = if allow_match(pat_ch, line_ch, case_sensitive) {
307                pre_match.miss_score
308                    + match_bonus(
309                        pat_idx,
310                        pat_ch,
311                        pat_prev_ch,
312                        line_idx,
313                        line_ch,
314                        line_prev_ch,
315                        Action::Match,
316                    )
317            } else {
318                AWFUL_SCORE
319            };
320
321            let (match_score, last_action_match) = if match_match_score > miss_match_score {
322                (match_match_score, Action::Match)
323            } else {
324                (miss_match_score, Action::Miss)
325            };
326
327            dp[current_row_idx][line_idx + 1] = Score {
328                miss_score,
329                last_action_miss,
330                match_score,
331                last_action_match,
332            };
333
334            line_prev_ch = line_ch;
335        }
336
337        pat_prev_ch = pat_ch;
338    }
339
340    dp
341}
342
343fn adjust_score(score: ScoreType, num_line_chars: usize) -> ScoreType {
344    // line width will affect 10 scores
345    score - (((num_line_chars + 1) as f64).ln().floor() as ScoreType)
346}
347
348const AWFUL_SCORE: ScoreType = -(1 << 30);
349
/// One step of a DP path: the line character is either skipped or matched.
#[derive(Clone, Copy, Debug, PartialEq)]
enum Action {
    /// The line character was skipped.
    Miss,
    /// The line character was matched with the current pattern character.
    Match,
}
355
356#[derive(Debug, Clone, Copy)]
357struct Score {
358    pub last_action_miss: Action,
359    pub last_action_match: Action,
360    pub miss_score: ScoreType,
361    pub match_score: ScoreType,
362}
363
364impl Default for Score {
365    fn default() -> Self {
366        Self {
367            last_action_miss: Action::Miss,
368            last_action_match: Action::Miss,
369            miss_score: AWFUL_SCORE,
370            match_score: AWFUL_SCORE,
371        }
372    }
373}
374
375fn skip_penalty(_ch_idx: usize, ch: char, last_action: Action) -> ScoreType {
376    let mut score = 1;
377    if last_action == Action::Match {
378        // Non-consecutive match.
379        score += 3;
380    }
381
382    if char_type_of(ch) == CharType::NonWord {
383        // skip separator
384        score += 6;
385    }
386
387    score
388}
389
390fn allow_match(pat_ch: char, line_ch: char, case_sensitive: bool) -> bool {
391    char_equal(pat_ch, line_ch, case_sensitive)
392}
393
394fn match_bonus(
395    pat_idx: usize,
396    pat_ch: char,
397    pat_prev_ch: char,
398    line_idx: usize,
399    line_ch: char,
400    line_prev_ch: char,
401    last_action: Action,
402) -> ScoreType {
403    let mut score = 10;
404    let pat_role = char_role(pat_prev_ch, pat_ch);
405    let line_role = char_role(line_prev_ch, line_ch);
406
407    // Bonus: pattern so far is a (case-insensitive) prefix of the word.
408    if pat_idx == line_idx {
409        score += 10;
410    }
411
412    // Bonus: case match
413    if pat_ch == line_ch {
414        score += 8;
415    }
416
417    // Bonus: match header
418    if line_role == CharRole::Head {
419        score += 9;
420    }
421
422    // Bonus: a Head in the pattern aligns with one in the word.
423    if pat_role == CharRole::Head && line_role == CharRole::Head {
424        score += 10;
425    }
426
427    // Penalty: matching inside a segment (and previous char wasn't matched).
428    if line_role == CharRole::Tail && pat_idx > 0 && last_action == Action::Miss {
429        score -= 30;
430    }
431
432    // Penalty: a Head in the pattern matches in the middle of a word segment.
433    if pat_role == CharRole::Head && line_role == CharRole::Tail {
434        score -= 10;
435    }
436
437    // Penalty: matching the first pattern character in the middle of a segment.
438    if pat_idx == 0 && line_role == CharRole::Tail {
439        score -= 40;
440    }
441
442    score
443}
444
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::{assert_order, wrap_matches};

    /// Runs the module-level `fuzzy_indices` and renders the matched positions with
    /// `wrap_matches`.
    fn wrap_fuzzy_match(line: &str, pattern: &str) -> Option<String> {
        fuzzy_indices(line, pattern).map(|(_score, indices)| wrap_matches(line, &indices))
    }

    #[test]
    fn test_match_or_not() {
        // patterns that are not subsequences of the line are rejected
        assert_eq!(None, fuzzy_match("abcdefaghi", "中"));
        assert_eq!(None, fuzzy_match("abc", "abx"));

        // scattered subsequences are accepted
        assert!(fuzzy_match("axbycz", "abc").is_some());
        assert!(fuzzy_match("axbycz", "xyz").is_some());

        // (line, pattern, expected highlighting)
        let highlighted = [
            ("axbycz", "abc", "[a]x[b]y[c]z"),
            ("axbycz", "xyz", "a[x]b[y]c[z]"),
            ("Hello, 世界", "H世", "[H]ello, [世]界"),
        ];
        for &(line, pattern, expected) in &highlighted {
            assert_eq!(expected, &wrap_fuzzy_match(line, pattern).unwrap());
        }
    }

    #[test]
    fn test_match_quality() {
        let matcher = ClangdMatcher::default();

        // (pattern, choices ordered from best to worst match)
        let cases: &[(&str, &[&'static str])] = &[
            // case
            ("monad", &["monad", "Monad", "mONAD"]),
            // initials
            ("ab", &["ab", "aoo_boo", "acb"]),
            ("CC", &["CamelCase", "camelCase", "camelcase"]),
            ("cC", &["camelCase", "CamelCase", "camelcase"]),
            (
                "cc",
                &[
                    "camel case",
                    "camelCase",
                    "camelcase",
                    "CamelCase",
                    "camel ace",
                ],
            ),
            (
                "Da.Te",
                &["Data.Text", "Data.Text.Lazy", "Data.Aeson.Encoding.text"],
            ),
            ("foobar.h", &["foobar.h", "foo/bar.h"]),
            // prefix
            ("is", &["isIEEE", "inSuf"]),
            // shorter
            ("ma", &["map", "many", "maximum"]),
            ("print", &["printf", "sprintf"]),
            // a scattered, mid-word match ranks last
            ("ast", &["ast", "AST", "INT_FAST16_MAX"]),
            // a mid-word match still matches, but ranks last
            ("Int", &["int", "INT", "PRINT"]),
        ];

        for &(pattern, choices) in cases {
            assert_order(&matcher, pattern, choices);
        }
    }
}
508
509#[allow(dead_code)]
510fn print_dp(line: &str, pattern: &str, dp: &[Vec<Score>]) {
511    let num_line_chars = line.chars().count();
512    let num_pattern_chars = pattern.chars().count();
513
514    print!("\t");
515    for (idx, ch) in line.chars().enumerate() {
516        print!("\t\t{}/{}", idx + 1, ch);
517    }
518
519    for (row_num, row) in dp.iter().enumerate().take(num_pattern_chars + 1) {
520        print!("\n{}\t", row_num);
521        for cell in row.iter().take(num_line_chars + 1) {
522            print!(
523                "({},{})/({},{})\t",
524                cell.miss_score,
525                if cell.last_action_miss == Action::Miss {
526                    'X'
527                } else {
528                    'O'
529                },
530                cell.match_score,
531                if cell.last_action_match == Action::Miss {
532                    'X'
533                } else {
534                    'O'
535                }
536            );
537        }
538    }
539}