quickmatch/
lib.rs

1use std::{marker::PhantomData, ptr};
2
3use rustc_hash::{FxHashMap, FxHashSet};
4
5mod config;
6
7pub use config::*;
8
9pub struct QuickMatch<'a> {
10    config: QuickMatchConfig,
11    max_word_count: usize,
12    max_word_len: usize,
13    max_query_len: usize,
14    word_index: FxHashMap<String, FxHashSet<*const str>>,
15    trigram_index: FxHashMap<[char; 3], FxHashSet<*const str>>,
16    _phantom: PhantomData<&'a str>,
17}
18
19unsafe impl<'a> Send for QuickMatch<'a> {}
20unsafe impl<'a> Sync for QuickMatch<'a> {}
21
22impl<'a> QuickMatch<'a> {
23    /// Expect the items to be pre-formatted (lowercase)
24    pub fn new(items: &[&'a str]) -> Self {
25        Self::new_with(items, QuickMatchConfig::default())
26    }
27
28    /// Expect the items to be pre-formatted (lowercase)
29    pub fn new_with(items: &[&'a str], config: QuickMatchConfig) -> Self {
30        let mut word_index: FxHashMap<String, FxHashSet<*const str>> = FxHashMap::default();
31        let mut trigram_index: FxHashMap<[char; 3], FxHashSet<*const str>> = FxHashMap::default();
32        let mut max_word_len = 0;
33        let mut max_query_len = 0;
34        let mut max_words = 0;
35        let separators = config.separators();
36
37        for &item in items {
38            max_query_len = max_query_len.max(item.len());
39            let mut word_count = 0;
40            for word in item.split(separators) {
41                word_count += 1;
42                if word.is_empty() {
43                    continue;
44                }
45
46                max_word_len = max_word_len.max(item.len());
47
48                word_index.entry(word.to_string()).or_default().insert(item);
49
50                if word.len() >= 3 {
51                    let chars = word.chars().collect::<Vec<_>>();
52                    for window in chars.windows(3) {
53                        trigram_index
54                            .entry(unsafe { ptr::read(window.as_ptr() as *const [char; 3]) })
55                            .or_default()
56                            .insert(item);
57                    }
58                }
59            }
60            max_words = max_words.max(word_count);
61        }
62
63        Self {
64            max_query_len: max_query_len + 6,
65            max_word_len: max_word_len + 4,
66            max_word_count: max_word_len + 2,
67            word_index,
68            trigram_index,
69            config,
70            _phantom: PhantomData,
71        }
72    }
73
74    ///
75    /// `limit`: max number of returned matches
76    ///
77    /// `max_trigrams`: max number of processed trigrams in unknown words (0-10 recommended)
78    ///
79    pub fn matches(&self, query: &str) -> Vec<&'a str> {
80        self.matches_with(query, &self.config)
81    }
82
83    ///
84    /// `limit`: max number of returned matches
85    ///
86    /// `max_trigrams`: max number of processed trigrams in unknown words (0-10 recommended)
87    ///
88    pub fn matches_with(&self, query: &str, config: &QuickMatchConfig) -> Vec<&'a str> {
89        let limit = config.limit();
90        let trigram_budget = config.trigram_budget();
91        let query_len = query.len();
92
93        if query.is_empty() || query_len > self.max_query_len {
94            return vec![];
95        }
96
97        let query = query
98            .trim()
99            .chars()
100            .filter(|c| c.is_ascii())
101            .collect::<String>()
102            .to_ascii_lowercase();
103
104        let words = query
105            .split(config.separators())
106            .filter(|w| !w.is_empty() && w.len() <= self.max_word_len)
107            .collect::<FxHashSet<_>>();
108
109        if words.is_empty() || words.len() > self.max_word_count {
110            return vec![];
111        }
112
113        let min_len = query_len.saturating_sub(3);
114
115        let mut pool: Option<FxHashSet<*const str>> = None;
116        let mut unknown_words = Vec::new();
117
118        let mut words_to_intersect = vec![];
119        for word in words {
120            if let Some(items) = self.word_index.get(word) {
121                words_to_intersect.push(items)
122            } else if word.len() >= 3 && unknown_words.len() < trigram_budget {
123                unknown_words.push(word.chars().collect::<Vec<_>>())
124            }
125        }
126
127        if !words_to_intersect.is_empty() {
128            words_to_intersect.sort_unstable_by_key(|set| -(set.len() as i64));
129
130            let mut intersect = words_to_intersect.pop().cloned().unwrap();
131
132            for other_set in words_to_intersect.iter().rev() {
133                intersect.retain(|ptr| other_set.contains(ptr));
134                if intersect.is_empty() {
135                    break;
136                }
137            }
138
139            pool = Some(intersect);
140        }
141        let some_pool = pool.is_some();
142
143        if unknown_words.is_empty() || trigram_budget == 0 {
144            let mut results: Vec<_> = pool
145                .unwrap_or_default()
146                .into_iter()
147                .map(|item| unsafe { &*item as &str })
148                .collect();
149
150            if results.len() > limit {
151                results.select_nth_unstable_by_key(limit, |item| item.len());
152                results.truncate(limit);
153            }
154
155            results.sort_unstable_by_key(|item| item.len());
156
157            return results;
158        }
159
160        let mut scores: FxHashMap<*const str, usize> = FxHashMap::default();
161        scores.reserve(256);
162        if let Some(pool) = &pool {
163            for &item in pool {
164                scores.insert(item, 1);
165            }
166        }
167
168        let mut trigram_count = 0;
169        let mut visited: FxHashSet<[char; 3]> = FxHashSet::default();
170
171        'outer: for round in 0..trigram_budget {
172            let mut processed_trigrams = false;
173
174            for chars in &unknown_words {
175                if trigram_count >= trigram_budget {
176                    break 'outer;
177                }
178
179                let len = chars.len();
180                let max_pos = len - 3;
181
182                let pos = if round == 0 {
183                    0
184                } else if round == 1 && max_pos > 0 {
185                    max_pos
186                } else if round == 2 && max_pos > 1 {
187                    max_pos / 2
188                } else if max_pos > 2 {
189                    // Alternate left and right of middle
190                    let mid = max_pos / 2;
191                    let offset = (round - 2) >> 1; // Faster than / 2
192                    let p = if (round & 1) == 1 {
193                        // Faster than (r - 3) % 2 == 0
194                        mid.saturating_sub(offset)
195                    } else {
196                        mid + offset
197                    };
198
199                    if p == 0 || p >= max_pos || p == mid {
200                        continue;
201                    }
202                    p
203                } else {
204                    continue;
205                };
206
207                let trigram = [chars[pos], chars[pos + 1], chars[pos + 2]];
208
209                if !visited.insert(trigram) {
210                    continue;
211                }
212
213                let Some(items) = self.trigram_index.get(&trigram) else {
214                    continue;
215                };
216
217                processed_trigrams = true;
218                trigram_count += 1;
219
220                if some_pool {
221                    for &item in items {
222                        if let Some(score) = scores.get_mut(&item) {
223                            *score += 1;
224                        }
225                    }
226                } else {
227                    for &item in items {
228                        let len = unsafe { &*item }.len();
229                        if len >= min_len {
230                            *scores.entry(item).or_default() += 1;
231                        }
232                    }
233                }
234            }
235
236            if !processed_trigrams {
237                break 'outer;
238            }
239        }
240
241        let min_score = trigram_count.div_ceil(2).max(1);
242        let mut results: Vec<_> = scores
243            .into_iter()
244            .filter(|(_, s)| *s >= min_score)
245            .map(|(item, score)| (unsafe { &*item as &str }, score))
246            .collect();
247
248        if results.len() > limit {
249            results.select_nth_unstable_by(limit, |a, b| {
250                b.1.cmp(&a.1).then_with(|| a.0.len().cmp(&b.0.len()))
251            });
252            results.truncate(limit);
253        }
254
255        results.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.len().cmp(&b.0.len())));
256
257        results
258            .into_iter()
259            .take(limit)
260            .map(|(item, _)| item)
261            .collect()
262    }
263}