harper_core/spell/
fst_dictionary.rs

1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use lazy_static::lazy_static;
5use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
6use std::borrow::Cow;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
14/// An immutable dictionary allowing for very fast spellchecking.
15///
16/// For dictionaries with changing contents, such as user and file dictionaries, prefer
17/// [`MutableDictionary`].
18pub struct FstDictionary {
19    /// Underlying [`super::MutableDictionary`] used for everything except fuzzy finding
20    mutable_dict: Arc<MutableDictionary>,
21    /// Used for fuzzy-finding the index of words or metadata
22    word_map: FstMap<Vec<u8>>,
23    /// Used for fuzzy-finding the index of words or metadata
24    words: Vec<(CharString, DictWordMetadata)>,
25}
26
/// Maximum edit distance whose automaton builder is created eagerly for every thread.
const EXPECTED_DISTANCE: u8 = 3;
/// Whether a transposition of two adjacent characters counts as a single edit
/// (Damerau-Levenshtein) rather than two.
const TRANSPOSITION_COST_ONE: bool = true;
29
30lazy_static! {
31    static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
32}
33
34thread_local! {
35    // Builders are computationally expensive and do not depend on the word, so we store a
36    // collection of builders and the associated edit distance here.
37    // Currently, the edit distance we use is three, but a value that does not exist in this
38    // collection will create a new builder of that distance and push it to the collection.
39    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
40        EXPECTED_DISTANCE,
41        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
42    )]);
43}
44
45impl PartialEq for FstDictionary {
46    fn eq(&self, other: &Self) -> bool {
47        self.mutable_dict == other.mutable_dict
48    }
49}
50
51impl FstDictionary {
52    /// Create a dictionary from the curated dictionary included
53    /// in the Harper binary.
54    pub fn curated() -> Arc<Self> {
55        (*DICT).clone()
56    }
57
58    /// Construct a new [`FstDictionary`] using a wordlist as a source.
59    /// This can be expensive, so only use this if fast fuzzy searches are worth it.
60    pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
61        words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
62        words.dedup_by(|(a, _), (b, _)| a == b);
63
64        let mut builder = fst::MapBuilder::memory();
65        for (index, (word, _)) in words.iter().enumerate() {
66            let word = word.iter().collect::<String>();
67            builder
68                .insert(word, index as u64)
69                .expect("Insertion not in lexicographical order!");
70        }
71
72        let mut mutable_dict = MutableDictionary::new();
73        mutable_dict.extend_words(words.iter().cloned());
74
75        let fst_bytes = builder.into_inner().unwrap();
76        let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
77
78        FstDictionary {
79            mutable_dict: Arc::new(mutable_dict),
80            word_map,
81            words,
82        }
83    }
84}
85
86fn build_dfa(max_distance: u8, query: &str) -> DFA {
87    // Insert if it does not exist
88    AUTOMATON_BUILDERS.with_borrow_mut(|v| {
89        if !v.iter().any(|t| t.0 == max_distance) {
90            v.push((
91                max_distance,
92                LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
93            ));
94        }
95    });
96
97    AUTOMATON_BUILDERS.with_borrow(|v| {
98        v.iter()
99            .find(|a| a.0 == max_distance)
100            .unwrap()
101            .1
102            .build_dfa(query)
103    })
104}
105
106/// Consumes a DFA stream and emits the index-edit distance pairs it produces.
107fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
108    let mut word_index_pairs = Vec::new();
109    while let Some((_, v, s)) = stream.next() {
110        word_index_pairs.push((v, dfa.distance(s).to_u8()));
111    }
112
113    word_index_pairs
114}
115
116impl Dictionary for FstDictionary {
117    fn contains_word(&self, word: &[char]) -> bool {
118        self.mutable_dict.contains_word(word)
119    }
120
121    fn contains_word_str(&self, word: &str) -> bool {
122        self.mutable_dict.contains_word_str(word)
123    }
124
125    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
126        self.mutable_dict.get_word_metadata(word)
127    }
128
129    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
130        self.mutable_dict.get_word_metadata_str(word)
131    }
132
133    fn fuzzy_match(
134        &'_ self,
135        word: &[char],
136        max_distance: u8,
137        max_results: usize,
138    ) -> Vec<FuzzyMatchResult<'_>> {
139        let misspelled_word_charslice = word.normalized();
140        let misspelled_word_string = misspelled_word_charslice.to_string();
141
142        // Actual FST search
143        let dfa = build_dfa(max_distance, &misspelled_word_string);
144        let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
145        let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
146        let mut word_indexes_lowercase_stream = self
147            .word_map
148            .search_with_state(&dfa_lowercase)
149            .into_stream();
150
151        let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
152        let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
153
154        // Merge the two results, keeping the smallest distance when both DFAs match.
155        // The uppercase and lowercase searches can return different result counts, so
156        // we can't simply zip the vectors without losing matches.
157        let mut merged = Vec::with_capacity(upper_dists.len().max(lower_dists.len()));
158        let mut best_distances = HashMap::<u64, u8>::new();
159
160        for (idx, dist) in upper_dists.into_iter().chain(lower_dists.into_iter()) {
161            best_distances
162                .entry(idx)
163                .and_modify(|existing| *existing = (*existing).min(dist))
164                .or_insert(dist);
165        }
166
167        for (index, edit_distance) in best_distances {
168            let (word, metadata) = &self.words[index as usize];
169            merged.push(FuzzyMatchResult {
170                word,
171                edit_distance,
172                metadata: Cow::Borrowed(metadata),
173            });
174        }
175
176        // Ignore exact matches
177        merged.retain(|v| v.edit_distance > 0);
178        merged.sort_unstable_by(|a, b| {
179            a.edit_distance
180                .cmp(&b.edit_distance)
181                .then_with(|| a.word.cmp(b.word))
182        });
183        merged.truncate(max_results);
184
185        merged
186    }
187
188    fn fuzzy_match_str(
189        &'_ self,
190        word: &str,
191        max_distance: u8,
192        max_results: usize,
193    ) -> Vec<FuzzyMatchResult<'_>> {
194        self.fuzzy_match(
195            word.chars().collect::<Vec<_>>().as_slice(),
196            max_distance,
197            max_results,
198        )
199    }
200
201    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
202        self.mutable_dict.words_iter()
203    }
204
205    fn word_count(&self) -> usize {
206        self.mutable_dict.word_count()
207    }
208
209    fn contains_exact_word(&self, word: &[char]) -> bool {
210        self.mutable_dict.contains_exact_word(word)
211    }
212
213    fn contains_exact_word_str(&self, word: &str) -> bool {
214        self.mutable_dict.contains_exact_word_str(word)
215    }
216
217    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
218        self.mutable_dict.get_correct_capitalization_of(word)
219    }
220
221    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
222        self.mutable_dict.get_word_from_id(id)
223    }
224
225    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
226        self.mutable_dict.find_words_with_prefix(prefix)
227    }
228
229    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
230        self.mutable_dict.find_words_with_common_prefix(word)
231    }
232}
233
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word served by the mutable dictionary must also be reachable through the FST.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            // Report failures through assertion messages instead of `dbg!`-ing every word
            // in the curated dictionary.
            let key = word.normalized().to_string();

            assert!(!key.is_empty(), "dictionary contains an empty word");
            assert!(
                dict.word_map.contains_key(&key),
                "word {key:?} is missing from the FST map"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let misspelled_normalized = word.normalized();
        let misspelled_word = misspelled_normalized.to_string();
        let misspelled_lower = misspelled_normalized.to_lower().to_string();

        assert!(dict.contains_word(&misspelled_normalized));
        assert!(
            dict.word_map.contains_key(misspelled_lower)
                || dict.word_map.contains_key(misspelled_word)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction:?} is not in the curated dictionary"));
            assert!(
                metadata.derived_from.is_none(),
                "{contraction:?} should not be marked as derived"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("llamas")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("llama")
        )
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("cats")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("cat")
        );
    }

    #[test]
    fn unhappy_derived_from_happy() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("unhappy")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("happy")
        );
    }

    #[test]
    fn quickly_derived_from_quick() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("quickly")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("quick")
        );
    }
}