harper_core/spell/
fst_dictionary.rs

1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use lazy_static::lazy_static;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::{cell::RefCell, sync::Arc};
6
7use crate::{CharString, CharStringExt, WordMetadata};
8
9use super::Dictionary;
10use super::FuzzyMatchResult;
11
/// An immutable dictionary allowing for very fast spellchecking.
///
/// For dictionaries with changing contents, such as user and file dictionaries, prefer
/// [`MutableDictionary`].
///
/// Exact lookups are delegated to an inner [`MutableDictionary`]; only fuzzy matching
/// goes through the FST stored in `word_map`.
pub struct FstDictionary {
    /// Underlying [`super::MutableDictionary`] used for everything except fuzzy finding
    full_dict: Arc<MutableDictionary>,
    /// FST keyed by each word's string form, whose value is the word's index in `words`.
    /// Searched during fuzzy matching.
    word_map: FstMap<Vec<u8>>,
    /// Sorted, deduplicated word list. The values stored in `word_map` index into this
    /// vector to recover a fuzzy match's word and metadata.
    words: Vec<(CharString, WordMetadata)>,
}
24
/// Edit distance for which a Levenshtein automaton builder is created up front
/// (see `AUTOMATON_BUILDERS`).
const EXPECTED_DISTANCE: u8 = 3;
/// Second argument of every `LevenshteinAutomatonBuilder::new` call in this file.
/// NOTE(review): presumably `false` means a transposition is not counted as a single
/// edit; confirm against the `levenshtein_automata` documentation.
const TRANSPOSITION_COST_ONE: bool = false;

lazy_static! {
    // Shared FST dictionary converted (via `.into()`) from a clone of the curated
    // `MutableDictionary`. `FstDictionary::curated` hands out clones of this `Arc`.
    static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
}
31
thread_local! {
    // Builders are computationally expensive and do not depend on the word, so each
    // thread keeps a collection of builders paired with the edit distance they were
    // created for.
    // The collection starts with a single builder for `EXPECTED_DISTANCE`. When
    // `build_dfa` is asked for a distance that is not in the collection, it creates a
    // builder for that distance and pushes it here for later reuse.
    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
        EXPECTED_DISTANCE,
        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
    )]);
}
42
43impl PartialEq for FstDictionary {
44    fn eq(&self, other: &Self) -> bool {
45        self.full_dict == other.full_dict
46    }
47}
48
49impl FstDictionary {
50    /// Create a dictionary from the curated dictionary included
51    /// in the Harper binary.
52    pub fn curated() -> Arc<Self> {
53        (*DICT).clone()
54    }
55
56    /// Construct a new [`FstDictionary`] using a word list as a source.
57    /// This can be expensive, so only use this if fast fuzzy searches are worth it.
58    pub fn new(mut words: Vec<(CharString, WordMetadata)>) -> Self {
59        words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
60        words.dedup_by(|(a, _), (b, _)| a == b);
61
62        let mut builder = fst::MapBuilder::memory();
63        for (index, (word, _)) in words.iter().enumerate() {
64            let word = word.iter().collect::<String>();
65            builder
66                .insert(word, index as u64)
67                .expect("Insertion not in lexicographical order!");
68        }
69
70        let mut full_dict = MutableDictionary::new();
71        full_dict.extend_words(words.iter().cloned());
72
73        let fst_bytes = builder.into_inner().unwrap();
74        let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
75
76        FstDictionary {
77            full_dict: Arc::new(full_dict),
78            word_map,
79            words,
80        }
81    }
82}
83
84fn build_dfa(max_distance: u8, query: &str) -> DFA {
85    // Insert if it does not exist
86    AUTOMATON_BUILDERS.with_borrow_mut(|v| {
87        if !v.iter().any(|t| t.0 == max_distance) {
88            v.push((
89                max_distance,
90                LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
91            ));
92        }
93    });
94
95    AUTOMATON_BUILDERS.with_borrow(|v| {
96        v.iter()
97            .find(|a| a.0 == max_distance)
98            .unwrap()
99            .1
100            .build_dfa(query)
101    })
102}
103
104/// Consumes a DFA stream and emits the index-edit distance pairs it produces.
105fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
106    let mut word_index_pairs = Vec::new();
107    while let Some((_, v, s)) = stream.next() {
108        word_index_pairs.push((v, dfa.distance(s).to_u8()));
109    }
110
111    word_index_pairs
112}
113
114impl Dictionary for FstDictionary {
115    fn contains_word(&self, word: &[char]) -> bool {
116        self.full_dict.contains_word(word)
117    }
118
119    fn contains_word_str(&self, word: &str) -> bool {
120        self.full_dict.contains_word_str(word)
121    }
122
123    fn get_word_metadata(&self, word: &[char]) -> Option<&WordMetadata> {
124        self.full_dict.get_word_metadata(word)
125    }
126
127    fn get_word_metadata_str(&self, word: &str) -> Option<&WordMetadata> {
128        self.full_dict.get_word_metadata_str(word)
129    }
130
131    fn fuzzy_match(
132        &self,
133        word: &[char],
134        max_distance: u8,
135        max_results: usize,
136    ) -> Vec<FuzzyMatchResult> {
137        let misspelled_word_charslice = word.normalized();
138        let misspelled_word_string = misspelled_word_charslice.to_string();
139
140        // Actual FST search
141        let dfa = build_dfa(max_distance, &misspelled_word_string);
142        let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
143        let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
144        let mut word_indexes_lowercase_stream = self
145            .word_map
146            .search_with_state(&dfa_lowercase)
147            .into_stream();
148
149        let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
150        let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
151
152        let mut merged = Vec::with_capacity(upper_dists.len());
153
154        // Merge the two results
155        for ((i_u, dist_u), (i_l, dist_l)) in upper_dists.into_iter().zip(lower_dists.into_iter()) {
156            let (chosen_index, edit_distance) = if dist_u <= dist_l {
157                (i_u, dist_u)
158            } else {
159                (i_l, dist_l)
160            };
161
162            let (word, metadata) = &self.words[chosen_index as usize];
163
164            merged.push(FuzzyMatchResult {
165                word,
166                edit_distance,
167                metadata,
168            })
169        }
170
171        merged.sort_unstable_by_key(|v| v.word);
172        merged.dedup_by_key(|v| v.word);
173        merged.sort_unstable_by_key(|v| v.edit_distance);
174        merged.truncate(max_results);
175
176        merged
177    }
178
179    fn fuzzy_match_str(
180        &self,
181        word: &str,
182        max_distance: u8,
183        max_results: usize,
184    ) -> Vec<FuzzyMatchResult> {
185        self.fuzzy_match(
186            word.chars().collect::<Vec<_>>().as_slice(),
187            max_distance,
188            max_results,
189        )
190    }
191
192    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
193        self.full_dict.words_iter()
194    }
195
196    fn word_count(&self) -> usize {
197        self.full_dict.word_count()
198    }
199
200    fn contains_exact_word(&self, word: &[char]) -> bool {
201        self.full_dict.contains_exact_word(word)
202    }
203
204    fn contains_exact_word_str(&self, word: &str) -> bool {
205        self.full_dict.contains_exact_word_str(word)
206    }
207
208    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
209        self.full_dict.get_correct_capitalization_of(word)
210    }
211
212    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
213        self.full_dict.get_word_from_id(id)
214    }
215}
216
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::Dictionary;
    use crate::WordId;

    use super::FstDictionary;

    /// Asserts that the curated dictionary records `base` as the word that `derived`
    /// was derived from.
    fn assert_derived_from(derived: &str, base: &str) {
        let dict = FstDictionary::curated();
        let metadata = dict
            .get_word_metadata_str(derived)
            .unwrap_or_else(|| panic!("{derived:?} is missing from the curated dictionary"));

        assert_eq!(
            metadata.derived_from,
            Some(WordId::from_word_str(base)),
            "{derived:?} should be derived from {base:?}"
        );
    }

    /// Every word of the backing dictionary must be reachable through the FST, either
    /// as written or in lowercase.
    #[test]
    fn fst_map_contains_all_in_full_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let normalized = word.normalized();
            let as_written = normalized.to_string();
            let lowercased = normalized.to_lower().to_string();

            assert!(!as_written.is_empty());
            // The message names the offending word, so nothing is printed for the
            // words that pass.
            assert!(
                dict.word_map.contains_key(&as_written)
                    || dict.word_map.contains_key(&lowercased),
                "neither {as_written:?} nor {lowercased:?} is a key in the FST map"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let normalized = word.normalized();
        let as_written = normalized.to_string();
        let lowercased = normalized.to_lower().to_string();

        assert!(dict.contains_word(&normalized));
        assert!(
            dict.word_map.contains_key(&lowercased) || dict.word_map.contains_key(&as_written)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        let contractions = ["there's", "we're", "here's"];

        for contraction in contractions {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction:?} is missing from the curated dictionary"));

            assert!(
                metadata.derived_from.is_none(),
                "{contraction:?} should not be derived from another word"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        assert_derived_from("llamas", "llama");
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        assert_derived_from("cats", "cat");
    }

    #[test]
    fn unhappy_derived_from_happy() {
        assert_derived_from("unhappy", "happy");
    }

    #[test]
    fn quickly_derived_from_quick() {
        assert_derived_from("quickly", "quick");
    }
}