harper_core/spell/
fst_dictionary.rs

1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use lazy_static::lazy_static;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::{cell::RefCell, sync::Arc};
7
8use crate::{CharString, CharStringExt, DictWordMetadata};
9
10use super::Dictionary;
11use super::FuzzyMatchResult;
12
/// An immutable dictionary allowing for very fast spellchecking.
///
/// Exact lookups are forwarded to an inner [`MutableDictionary`]; fuzzy matching is served by
/// an FST built once in [`FstDictionary::new`].
///
/// For dictionaries with changing contents, such as user and file dictionaries, prefer
/// [`MutableDictionary`].
pub struct FstDictionary {
    /// Underlying [`super::MutableDictionary`] used for everything except fuzzy finding
    mutable_dict: Arc<MutableDictionary>,
    /// FST mapping each word (as a UTF-8 string key) to its index in `words`.
    /// Fuzzy matching searches this map with Levenshtein automata.
    word_map: FstMap<Vec<u8>>,
    /// Sorted, deduplicated word list. The values stored in `word_map` index into this vector
    /// to recover the matched word and its metadata.
    words: Vec<(CharString, DictWordMetadata)>,
}
25
/// Edit distance for which a Levenshtein automaton builder is created eagerly in
/// `AUTOMATON_BUILDERS`; builders for other distances are created on first use.
const EXPECTED_DISTANCE: u8 = 3;
/// Second argument to every `LevenshteinAutomatonBuilder::new` call in this file: when `true`,
/// swapping two adjacent characters counts as a single edit (see the
/// `damerau_transposition_costs_one` test).
const TRANSPOSITION_COST_ONE: bool = true;
28
lazy_static! {
    // Shared instance handed out by `FstDictionary::curated`. It is built on first access by
    // cloning the curated `MutableDictionary` and converting the clone with `.into()`.
    static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
}
32
thread_local! {
    // Builders are computationally expensive and do not depend on the word, so each thread
    // keeps a collection of `(edit distance, builder)` pairs here and reuses them.
    // The collection starts with a single builder for `EXPECTED_DISTANCE` (currently three).
    // `build_dfa` pushes a new builder the first time it is asked for a distance that is not
    // yet in the collection.
    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
        EXPECTED_DISTANCE,
        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
    )]);
}
43
44impl PartialEq for FstDictionary {
45    fn eq(&self, other: &Self) -> bool {
46        self.mutable_dict == other.mutable_dict
47    }
48}
49
50impl FstDictionary {
51    /// Create a dictionary from the curated dictionary included
52    /// in the Harper binary.
53    pub fn curated() -> Arc<Self> {
54        (*DICT).clone()
55    }
56
57    /// Construct a new [`FstDictionary`] using a wordlist as a source.
58    /// This can be expensive, so only use this if fast fuzzy searches are worth it.
59    pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60        words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61        words.dedup_by(|(a, _), (b, _)| a == b);
62
63        let mut builder = fst::MapBuilder::memory();
64        for (index, (word, _)) in words.iter().enumerate() {
65            let word = word.iter().collect::<String>();
66            builder
67                .insert(word, index as u64)
68                .expect("Insertion not in lexicographical order!");
69        }
70
71        let mut mutable_dict = MutableDictionary::new();
72        mutable_dict.extend_words(words.iter().cloned());
73
74        let fst_bytes = builder.into_inner().unwrap();
75        let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77        FstDictionary {
78            mutable_dict: Arc::new(mutable_dict),
79            word_map,
80            words,
81        }
82    }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86    // Insert if it does not exist
87    AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88        if !v.iter().any(|t| t.0 == max_distance) {
89            v.push((
90                max_distance,
91                LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92            ));
93        }
94    });
95
96    AUTOMATON_BUILDERS.with_borrow(|v| {
97        v.iter()
98            .find(|a| a.0 == max_distance)
99            .unwrap()
100            .1
101            .build_dfa(query)
102    })
103}
104
105/// Consumes a DFA stream and emits the index-edit distance pairs it produces.
106fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107    let mut word_index_pairs = Vec::new();
108    while let Some((_, v, s)) = stream.next() {
109        word_index_pairs.push((v, dfa.distance(s).to_u8()));
110    }
111
112    word_index_pairs
113}
114
115impl Dictionary for FstDictionary {
116    fn contains_word(&self, word: &[char]) -> bool {
117        self.mutable_dict.contains_word(word)
118    }
119
120    fn contains_word_str(&self, word: &str) -> bool {
121        self.mutable_dict.contains_word_str(word)
122    }
123
124    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
125        self.mutable_dict.get_word_metadata(word)
126    }
127
128    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
129        self.mutable_dict.get_word_metadata_str(word)
130    }
131
132    fn fuzzy_match(
133        &'_ self,
134        word: &[char],
135        max_distance: u8,
136        max_results: usize,
137    ) -> Vec<FuzzyMatchResult<'_>> {
138        let misspelled_word_charslice = word.normalized();
139        let misspelled_word_string = misspelled_word_charslice.to_string();
140
141        // Actual FST search
142        let dfa = build_dfa(max_distance, &misspelled_word_string);
143        let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
144        let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
145        let mut word_indexes_lowercase_stream = self
146            .word_map
147            .search_with_state(&dfa_lowercase)
148            .into_stream();
149
150        let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
151        let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
152
153        let mut merged = Vec::with_capacity(upper_dists.len());
154
155        // Merge the two results
156        for ((i_u, dist_u), (i_l, dist_l)) in upper_dists.into_iter().zip(lower_dists.into_iter()) {
157            let (chosen_index, edit_distance) = if dist_u <= dist_l {
158                (i_u, dist_u)
159            } else {
160                (i_l, dist_l)
161            };
162
163            let (word, metadata) = &self.words[chosen_index as usize];
164
165            merged.push(FuzzyMatchResult {
166                word,
167                edit_distance,
168                metadata: Cow::Borrowed(metadata),
169            })
170        }
171
172        merged.sort_unstable_by_key(|v| v.word);
173        merged.dedup_by_key(|v| v.word);
174        merged.sort_unstable_by_key(|v| v.edit_distance);
175        merged.truncate(max_results);
176
177        merged
178    }
179
180    fn fuzzy_match_str(
181        &'_ self,
182        word: &str,
183        max_distance: u8,
184        max_results: usize,
185    ) -> Vec<FuzzyMatchResult<'_>> {
186        self.fuzzy_match(
187            word.chars().collect::<Vec<_>>().as_slice(),
188            max_distance,
189            max_results,
190        )
191    }
192
193    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
194        self.mutable_dict.words_iter()
195    }
196
197    fn word_count(&self) -> usize {
198        self.mutable_dict.word_count()
199    }
200
201    fn contains_exact_word(&self, word: &[char]) -> bool {
202        self.mutable_dict.contains_exact_word(word)
203    }
204
205    fn contains_exact_word_str(&self, word: &str) -> bool {
206        self.mutable_dict.contains_exact_word_str(word)
207    }
208
209    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
210        self.mutable_dict.get_correct_capitalization_of(word)
211    }
212
213    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
214        self.mutable_dict.get_word_from_id(id)
215    }
216
217    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
218        self.mutable_dict.find_words_with_prefix(prefix)
219    }
220
221    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
222        self.mutable_dict.find_words_with_common_prefix(word)
223    }
224}
225
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    /// With transpositions enabled, swapping two adjacent letters is a single edit.
    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    /// With transpositions disabled, the same swap exceeds a maximum distance of one.
    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word of the wrapped dictionary must also be a key of the FST.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let misspelled_normalized = word.normalized();
            let misspelled_word = misspelled_normalized.to_string();

            assert!(!misspelled_word.is_empty());
            // Report the offending word on failure instead of printing every word.
            assert!(
                dict.word_map.contains_key(&misspelled_word),
                "`{misspelled_word}` is missing from the FST"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let misspelled_normalized = word.normalized();
        let misspelled_word = misspelled_normalized.to_string();
        let misspelled_lower = misspelled_normalized.to_lower().to_string();

        assert!(dict.contains_word(&misspelled_normalized));
        assert!(
            dict.word_map.contains_key(misspelled_lower)
                || dict.word_map.contains_key(misspelled_word)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist);
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        let contractions = ["there's", "we're", "here's"];

        for contraction in contractions {
            // The message names the failing contraction, so no debug printing is needed.
            assert!(
                dict.get_word_metadata_str(contraction)
                    .unwrap()
                    .derived_from
                    .is_none(),
                "`{contraction}` should not be derived from another word"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("llamas")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("llama")
        );
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("cats")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("cat")
        );
    }

    #[test]
    fn unhappy_derived_from_happy() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("unhappy")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("happy")
        );
    }

    #[test]
    fn quickly_derived_from_quick() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("quickly")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("quick")
        );
    }
}