fuzzy_matcher/
clangd.rs

//! The fuzzy matching algorithm used in clangd.
//! https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
//!
//! # Example
//! ```edition2018
//! use fuzzy_matcher::FuzzyMatcher;
//! use fuzzy_matcher::clangd::ClangdMatcher;
//!
//! let matcher = ClangdMatcher::default();
//!
//! assert_eq!(None, matcher.fuzzy_match("abc", "abx"));
//! assert!(matcher.fuzzy_match("axbycz", "abc").is_some());
//! assert!(matcher.fuzzy_match("axbycz", "xyz").is_some());
//!
//! let (score, indices) = matcher.fuzzy_indices("axbycz", "abc").unwrap();
//! assert_eq!(indices, [0, 2, 4]);
//!
//! ```
//!
//! The algorithm is modified from
//! https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
//! See also: https://github.com/lewang/flx/issues/98
23use crate::util::*;
24use crate::{FuzzyMatcher, IndexType, ScoreType};
25use std::cell::RefCell;
26use std::cmp::max;
27use thread_local::ThreadLocal;
28
/// Policy for comparing letter case between pattern and choice characters.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum CaseMatching {
    /// Always compare case-sensitively.
    Respect,
    /// Always compare case-insensitively.
    Ignore,
    /// Compare case-sensitively only when the pattern has an ASCII uppercase letter.
    Smart,
}
35
36#[derive(Debug)]
37pub struct ClangdMatcher {
38    case: CaseMatching,
39
40    use_cache: bool,
41
42    c_cache: ThreadLocal<RefCell<Vec<char>>>, // vector to store the characters of choice
43    p_cache: ThreadLocal<RefCell<Vec<char>>>, // vector to store the characters of pattern
44}
45
46impl Default for ClangdMatcher {
47    fn default() -> Self {
48        Self {
49            case: CaseMatching::Ignore,
50            use_cache: true,
51            c_cache: ThreadLocal::new(),
52            p_cache: ThreadLocal::new(),
53        }
54    }
55}
56
57impl ClangdMatcher {
58    pub fn ignore_case(mut self) -> Self {
59        self.case = CaseMatching::Ignore;
60        self
61    }
62
63    pub fn smart_case(mut self) -> Self {
64        self.case = CaseMatching::Smart;
65        self
66    }
67
68    pub fn respect_case(mut self) -> Self {
69        self.case = CaseMatching::Respect;
70        self
71    }
72
73    pub fn use_cache(mut self, use_cache: bool) -> Self {
74        self.use_cache = use_cache;
75        self
76    }
77
78    fn contains_upper(&self, string: &str) -> bool {
79        for ch in string.chars() {
80            if ch.is_ascii_uppercase() {
81                return true;
82            }
83        }
84
85        false
86    }
87
88    fn is_case_sensitive(&self, pattern: &str) -> bool {
89        match self.case {
90            CaseMatching::Respect => true,
91            CaseMatching::Ignore => false,
92            CaseMatching::Smart => self.contains_upper(pattern),
93        }
94    }
95}
96
97impl FuzzyMatcher for ClangdMatcher {
98    fn fuzzy_indices(&self, choice: &str, pattern: &str) -> Option<(ScoreType, Vec<IndexType>)> {
99        let case_sensitive = self.is_case_sensitive(pattern);
100
101        let mut choice_chars = self.c_cache.get_or_default().borrow_mut();
102        let mut pattern_chars = self.p_cache.get_or_default().borrow_mut();
103
104        *choice_chars = choice.chars().collect();
105
106        *pattern_chars = pattern.chars().collect();
107
108        cheap_matches(&choice_chars, &pattern_chars, case_sensitive)?;
109
110        let num_pattern_chars = pattern_chars.len();
111        let num_choice_chars = choice_chars.len();
112
113        let dp = build_graph(&choice_chars, &pattern_chars, false, case_sensitive);
114
115        // search backwards for the matched indices
116        let mut indices_reverse = Vec::with_capacity(num_pattern_chars);
117        let cell = dp[num_pattern_chars][num_choice_chars];
118
119        let (mut last_action, score) = if cell.match_score > cell.miss_score {
120            (Action::Match, cell.match_score)
121        } else {
122            (Action::Miss, cell.miss_score)
123        };
124
125        let mut row = num_pattern_chars;
126        let mut col = num_choice_chars;
127
128        while row > 0 || col > 0 {
129            if last_action == Action::Match {
130                indices_reverse.push((col - 1) as IndexType);
131            }
132
133            let cell = &dp[row][col];
134            if last_action == Action::Match {
135                last_action = cell.last_action_match;
136                row -= 1;
137                col -= 1;
138            } else {
139                last_action = cell.last_action_miss;
140                col -= 1;
141            }
142        }
143
144        if !self.use_cache {
145            // drop the allocated memory
146            self.c_cache.get().map(|cell| cell.take());
147            self.p_cache.get().map(|cell| cell.take());
148        }
149
150        indices_reverse.reverse();
151        Some((adjust_score(score, num_choice_chars), indices_reverse))
152    }
153
154    fn fuzzy_match(&self, choice: &str, pattern: &str) -> Option<ScoreType> {
155        let case_sensitive = self.is_case_sensitive(pattern);
156
157        let mut choice_chars = self.c_cache.get_or_default().borrow_mut();
158        let mut pattern_chars = self.p_cache.get_or_default().borrow_mut();
159
160        *choice_chars = choice.chars().collect();
161
162        *pattern_chars = pattern.chars().collect();
163
164        cheap_matches(&choice_chars, &pattern_chars, case_sensitive)?;
165
166        let num_pattern_chars = pattern_chars.len();
167        let num_choice_chars = choice_chars.len();
168
169        let dp = build_graph(&choice_chars, &pattern_chars, true, case_sensitive);
170
171        let cell = dp[num_pattern_chars & 1][num_choice_chars];
172        let score = max(cell.match_score, cell.miss_score);
173
174        if !self.use_cache {
175            // drop the allocated memory
176            self.c_cache.get().map(|cell| cell.take());
177            self.p_cache.get().map(|cell| cell.take());
178        }
179
180        Some(adjust_score(score, num_choice_chars))
181    }
182}
183
184/// fuzzy match `line` with `pattern`, returning the score and indices of matches
185pub fn fuzzy_indices(line: &str, pattern: &str) -> Option<(ScoreType, Vec<IndexType>)> {
186    ClangdMatcher::default()
187        .ignore_case()
188        .fuzzy_indices(line, pattern)
189}
190
191/// fuzzy match `line` with `pattern`, returning the score(the larger the better) on match
192pub fn fuzzy_match(line: &str, pattern: &str) -> Option<ScoreType> {
193    ClangdMatcher::default()
194        .ignore_case()
195        .fuzzy_match(line, pattern)
196}
197
198// checkout https://github.com/llvm-mirror/clang-tools-extra/blob/master/clangd/FuzzyMatch.cpp
199// for the description
200fn build_graph(
201    line: &[char],
202    pattern: &[char],
203    compressed: bool,
204    case_sensitive: bool,
205) -> Vec<Vec<Score>> {
206    let num_line_chars = line.len();
207    let num_pattern_chars = pattern.len();
208    let max_rows = if compressed { 2 } else { num_pattern_chars + 1 };
209
210    let mut dp: Vec<Vec<Score>> = Vec::with_capacity(max_rows);
211
212    for _ in 0..max_rows {
213        dp.push(vec![Score::default(); num_line_chars + 1]);
214    }
215
216    dp[0][0].miss_score = 0;
217
218    // first line
219    for (idx, &ch) in line.iter().enumerate() {
220        dp[0][idx + 1] = Score {
221            miss_score: dp[0][idx].miss_score - skip_penalty(idx, ch, Action::Miss),
222            last_action_miss: Action::Miss,
223            match_score: AWFUL_SCORE,
224            last_action_match: Action::Miss,
225        };
226    }
227
228    // build the matrix
229    let mut pat_prev_ch = '\0';
230    for (pat_idx, &pat_ch) in pattern.iter().enumerate() {
231        let current_row_idx = if compressed {
232            (pat_idx + 1) & 1
233        } else {
234            pat_idx + 1
235        };
236        let prev_row_idx = if compressed { pat_idx & 1 } else { pat_idx };
237
238        let mut line_prev_ch = '\0';
239        for (line_idx, &line_ch) in line.iter().enumerate() {
240            if line_idx < pat_idx {
241                line_prev_ch = line_ch;
242                continue;
243            }
244
245            // what if we skip current line character?
246            // we need to calculate the cases where the pre line character is matched/missed
247            let pre_miss = &dp[current_row_idx][line_idx];
248            let mut match_miss_score = pre_miss.match_score;
249            let mut miss_miss_score = pre_miss.miss_score;
250            if pat_idx < num_pattern_chars - 1 {
251                match_miss_score -= skip_penalty(line_idx, line_ch, Action::Match);
252                miss_miss_score -= skip_penalty(line_idx, line_ch, Action::Miss);
253            }
254
255            let (miss_score, last_action_miss) = if match_miss_score > miss_miss_score {
256                (match_miss_score, Action::Match)
257            } else {
258                (miss_miss_score, Action::Miss)
259            };
260
261            // what if we want to match current line character?
262            // so we need to calculate the cases where the pre pattern character is matched/missed
263            let pre_match = &dp[prev_row_idx][line_idx];
264            let match_match_score = if allow_match(pat_ch, line_ch, case_sensitive) {
265                pre_match.match_score
266                    + match_bonus(
267                        pat_idx,
268                        pat_ch,
269                        pat_prev_ch,
270                        line_idx,
271                        line_ch,
272                        line_prev_ch,
273                        Action::Match,
274                    )
275            } else {
276                AWFUL_SCORE
277            };
278
279            let miss_match_score = if allow_match(pat_ch, line_ch, case_sensitive) {
280                pre_match.miss_score
281                    + match_bonus(
282                        pat_idx,
283                        pat_ch,
284                        pat_prev_ch,
285                        line_idx,
286                        line_ch,
287                        line_prev_ch,
288                        Action::Match,
289                    )
290            } else {
291                AWFUL_SCORE
292            };
293
294            let (match_score, last_action_match) = if match_match_score > miss_match_score {
295                (match_match_score, Action::Match)
296            } else {
297                (miss_match_score, Action::Miss)
298            };
299
300            dp[current_row_idx][line_idx + 1] = Score {
301                miss_score,
302                last_action_miss,
303                match_score,
304                last_action_match,
305            };
306
307            line_prev_ch = line_ch;
308        }
309
310        pat_prev_ch = pat_ch;
311    }
312
313    dp
314}
315
316fn adjust_score(score: ScoreType, num_line_chars: usize) -> ScoreType {
317    // line width will affect 10 scores
318    score - (((num_line_chars + 1) as f64).ln().floor() as ScoreType)
319}
320
321const AWFUL_SCORE: ScoreType = -(1 << 30);
322
323#[derive(Debug, PartialEq, Clone, Copy)]
324enum Action {
325    Miss,
326    Match,
327}
328
329#[derive(Debug, Clone, Copy)]
330struct Score {
331    pub last_action_miss: Action,
332    pub last_action_match: Action,
333    pub miss_score: ScoreType,
334    pub match_score: ScoreType,
335}
336
337impl Default for Score {
338    fn default() -> Self {
339        Self {
340            last_action_miss: Action::Miss,
341            last_action_match: Action::Miss,
342            miss_score: AWFUL_SCORE,
343            match_score: AWFUL_SCORE,
344        }
345    }
346}
347
348fn skip_penalty(_ch_idx: usize, ch: char, last_action: Action) -> ScoreType {
349    let mut score = 1;
350    if last_action == Action::Match {
351        // Non-consecutive match.
352        score += 3;
353    }
354
355    if char_type_of(ch) == CharType::NonWord {
356        // skip separator
357        score += 6;
358    }
359
360    score
361}
362
363fn allow_match(pat_ch: char, line_ch: char, case_sensitive: bool) -> bool {
364    char_equal(pat_ch, line_ch, case_sensitive)
365}
366
367fn match_bonus(
368    pat_idx: usize,
369    pat_ch: char,
370    pat_prev_ch: char,
371    line_idx: usize,
372    line_ch: char,
373    line_prev_ch: char,
374    last_action: Action,
375) -> ScoreType {
376    let mut score = 10;
377    let pat_role = char_role(pat_prev_ch, pat_ch);
378    let line_role = char_role(line_prev_ch, line_ch);
379
380    // Bonus: pattern so far is a (case-insensitive) prefix of the word.
381    if pat_idx == line_idx {
382        score += 10;
383    }
384
385    // Bonus: case match
386    if pat_ch == line_ch {
387        score += 8;
388    }
389
390    // Bonus: match header
391    if line_role == CharRole::Head {
392        score += 9;
393    }
394
395    // Bonus: a Head in the pattern aligns with one in the word.
396    if pat_role == CharRole::Head && line_role == CharRole::Head {
397        score += 10;
398    }
399
400    // Penalty: matching inside a segment (and previous char wasn't matched).
401    if line_role == CharRole::Tail && pat_idx > 0 && last_action == Action::Miss {
402        score -= 30;
403    }
404
405    // Penalty: a Head in the pattern matches in the middle of a word segment.
406    if pat_role == CharRole::Head && line_role == CharRole::Tail {
407        score -= 10;
408    }
409
410    // Penalty: matching the first pattern character in the middle of a segment.
411    if pat_idx == 0 && line_role == CharRole::Tail {
412        score -= 40;
413    }
414
415    score
416}
417
#[cfg(test)]
mod tests {
    use super::*;
    use crate::util::{assert_order, wrap_matches};

    /// Renders `line` with every matched char wrapped in `[...]`, or `None`
    /// if `pattern` does not match.
    fn wrap_fuzzy_match(line: &str, pattern: &str) -> Option<String> {
        fuzzy_indices(line, pattern).map(|(_score, indices)| wrap_matches(line, &indices))
    }

    #[test]
    fn test_match_or_not() {
        // no match
        assert_eq!(None, fuzzy_match("abcdefaghi", "中"));
        assert_eq!(None, fuzzy_match("abc", "abx"));
        // match
        assert!(fuzzy_match("axbycz", "abc").is_some());
        assert!(fuzzy_match("axbycz", "xyz").is_some());

        // matched positions
        let cases = [
            ("axbycz", "abc", "[a]x[b]y[c]z"),
            ("axbycz", "xyz", "a[x]b[y]c[z]"),
            ("Hello, 世界", "H世", "[H]ello, [世]界"),
        ];
        for &(line, pattern, expected) in &cases {
            assert_eq!(expected, &wrap_fuzzy_match(line, pattern).unwrap());
        }
    }

    #[test]
    fn test_match_quality() {
        let matcher = ClangdMatcher::default();
        let expectations: &[(&str, &[&str])] = &[
            // case
            ("monad", &["monad", "Monad", "mONAD"]),
            // initials
            ("ab", &["ab", "aoo_boo", "acb"]),
            ("CC", &["CamelCase", "camelCase", "camelcase"]),
            ("cC", &["camelCase", "CamelCase", "camelcase"]),
            (
                "cc",
                &[
                    "camel case",
                    "camelCase",
                    "camelcase",
                    "CamelCase",
                    "camel ace",
                ],
            ),
            (
                "Da.Te",
                &["Data.Text", "Data.Text.Lazy", "Data.Aeson.Encoding.text"],
            ),
            ("foobar.h", &["foobar.h", "foo/bar.h"]),
            // prefix
            ("is", &["isIEEE", "inSuf"]),
            // shorter
            ("ma", &["map", "many", "maximum"]),
            ("print", &["printf", "sprintf"]),
            // score(PRINT) = kMinScore
            ("ast", &["ast", "AST", "INT_FAST16_MAX"]),
            // score(PRINT) > kMinScore
            ("Int", &["int", "INT", "PRINT"]),
        ];

        for &(pattern, ordered) in expectations {
            assert_order(&matcher, pattern, ordered);
        }
    }
}
481
482#[allow(dead_code)]
483fn print_dp(line: &str, pattern: &str, dp: &[Vec<Score>]) {
484    let num_line_chars = line.chars().count();
485    let num_pattern_chars = pattern.chars().count();
486
487    print!("\t");
488    for (idx, ch) in line.chars().enumerate() {
489        print!("\t\t{}/{}", idx + 1, ch);
490    }
491
492    for (row_num, row) in dp.iter().enumerate().take(num_pattern_chars + 1) {
493        print!("\n{}\t", row_num);
494        for cell in row.iter().take(num_line_chars + 1) {
495            print!(
496                "({},{})/({},{})\t",
497                cell.miss_score,
498                if cell.last_action_miss == Action::Miss {
499                    'X'
500                } else {
501                    'O'
502                },
503                cell.match_score,
504                if cell.last_action_match == Action::Miss {
505                    'X'
506                } else {
507                    'O'
508                }
509            );
510        }
511    }
512}