Skip to main content

harper_core/spell/
fst_dictionary.rs

1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::sync::LazyLock;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
14/// An immutable dictionary allowing for very fast spellchecking.
15///
16/// For dictionaries with changing contents, such as user and file dictionaries, prefer
17/// [`MutableDictionary`].
18pub struct FstDictionary {
19    /// Underlying [`super::MutableDictionary`] used for everything except fuzzy finding
20    mutable_dict: Arc<MutableDictionary>,
21    /// Used for fuzzy-finding the index of words or metadata
22    word_map: FstMap<Vec<u8>>,
23    /// Used for fuzzy-finding the index of words or metadata
24    words: Vec<(CharString, DictWordMetadata)>,
25}
26
27const EXPECTED_DISTANCE: u8 = 3;
28const TRANSPOSITION_COST_ONE: bool = true;
29
30static DICT: LazyLock<Arc<FstDictionary>> =
31    LazyLock::new(|| Arc::new((*MutableDictionary::curated()).clone().into()));
32
33thread_local! {
34    // Builders are computationally expensive and do not depend on the word, so we store a
35    // collection of builders and the associated edit distance here.
36    // Currently, the edit distance we use is three, but a value that does not exist in this
37    // collection will create a new builder of that distance and push it to the collection.
38    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
39        EXPECTED_DISTANCE,
40        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
41    )]);
42}
43
44impl PartialEq for FstDictionary {
45    fn eq(&self, other: &Self) -> bool {
46        self.mutable_dict == other.mutable_dict
47    }
48}
49
50impl FstDictionary {
51    /// Create a dictionary from the curated dictionary included
52    /// in the Harper binary.
53    pub fn curated() -> Arc<Self> {
54        (*DICT).clone()
55    }
56
57    /// Construct a new [`FstDictionary`] using a wordlist as a source.
58    /// This can be expensive, so only use this if fast fuzzy searches are worth it.
59    pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60        words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61        words.dedup_by(|(a, _), (b, _)| a == b);
62
63        let mut builder = fst::MapBuilder::memory();
64        for (index, (word, _)) in words.iter().enumerate() {
65            let word = word.iter().collect::<String>();
66            builder
67                .insert(word, index as u64)
68                .expect("Insertion not in lexicographical order!");
69        }
70
71        let mut mutable_dict = MutableDictionary::new();
72        mutable_dict.extend_words(words.iter().cloned());
73
74        let fst_bytes = builder.into_inner().unwrap();
75        let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77        FstDictionary {
78            mutable_dict: Arc::new(mutable_dict),
79            word_map,
80            words,
81        }
82    }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86    // Insert if it does not exist
87    AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88        if !v.iter().any(|t| t.0 == max_distance) {
89            v.push((
90                max_distance,
91                LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92            ));
93        }
94    });
95
96    AUTOMATON_BUILDERS.with_borrow(|v| {
97        v.iter()
98            .find(|a| a.0 == max_distance)
99            .unwrap()
100            .1
101            .build_dfa(query)
102    })
103}
104
105/// Consumes a DFA stream and emits the index-edit distance pairs it produces.
106fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107    let mut word_index_pairs = Vec::new();
108    while let Some((_, v, s)) = stream.next() {
109        word_index_pairs.push((v, dfa.distance(s).to_u8()));
110    }
111
112    word_index_pairs
113}
114
115impl Dictionary for FstDictionary {
116    fn contains_word(&self, word: &[char]) -> bool {
117        self.mutable_dict.contains_word(word)
118    }
119
120    fn contains_word_str(&self, word: &str) -> bool {
121        self.mutable_dict.contains_word_str(word)
122    }
123
124    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
125        self.mutable_dict.get_word_metadata(word)
126    }
127
128    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
129        self.mutable_dict.get_word_metadata_str(word)
130    }
131
132    fn fuzzy_match(
133        &'_ self,
134        word: &[char],
135        max_distance: u8,
136        max_results: usize,
137    ) -> Vec<FuzzyMatchResult<'_>> {
138        let misspelled_word_charslice = word.normalized();
139        let misspelled_word_string = misspelled_word_charslice.to_string();
140
141        // Actual FST search
142        let dfa = build_dfa(max_distance, &misspelled_word_string);
143        let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
144        let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
145        let mut word_indexes_lowercase_stream = self
146            .word_map
147            .search_with_state(&dfa_lowercase)
148            .into_stream();
149
150        let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
151        let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
152
153        // Merge the two results, keeping the smallest distance when both DFAs match.
154        // The uppercase and lowercase searches can return different result counts, so
155        // we can't simply zip the vectors without losing matches.
156        let mut merged = Vec::with_capacity(upper_dists.len().max(lower_dists.len()));
157        let mut best_distances = HashMap::<u64, u8>::new();
158
159        for (idx, dist) in upper_dists.into_iter().chain(lower_dists.into_iter()) {
160            best_distances
161                .entry(idx)
162                .and_modify(|existing| *existing = (*existing).min(dist))
163                .or_insert(dist);
164        }
165
166        for (index, edit_distance) in best_distances {
167            let (word, metadata) = &self.words[index as usize];
168            merged.push(FuzzyMatchResult {
169                word,
170                edit_distance,
171                metadata: Cow::Borrowed(metadata),
172            });
173        }
174
175        // Ignore exact matches
176        merged.retain(|v| v.edit_distance > 0);
177        merged.sort_unstable_by(|a, b| {
178            a.edit_distance
179                .cmp(&b.edit_distance)
180                .then_with(|| a.word.cmp(b.word))
181        });
182        merged.truncate(max_results);
183
184        merged
185    }
186
187    fn fuzzy_match_str(
188        &'_ self,
189        word: &str,
190        max_distance: u8,
191        max_results: usize,
192    ) -> Vec<FuzzyMatchResult<'_>> {
193        self.fuzzy_match(
194            word.chars().collect::<Vec<_>>().as_slice(),
195            max_distance,
196            max_results,
197        )
198    }
199
200    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
201        self.mutable_dict.words_iter()
202    }
203
204    fn word_count(&self) -> usize {
205        self.mutable_dict.word_count()
206    }
207
208    fn contains_exact_word(&self, word: &[char]) -> bool {
209        self.mutable_dict.contains_exact_word(word)
210    }
211
212    fn contains_exact_word_str(&self, word: &str) -> bool {
213        self.mutable_dict.contains_exact_word_str(word)
214    }
215
216    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
217        self.mutable_dict.get_correct_capitalization_of(word)
218    }
219
220    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
221        self.mutable_dict.get_word_from_id(id)
222    }
223
224    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
225        self.mutable_dict.find_words_with_prefix(prefix)
226    }
227
228    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
229        self.mutable_dict.find_words_with_common_prefix(word)
230    }
231}
232
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    /// Asserts that the curated dictionary records `word` as derived from `root`.
    fn assert_derived_from(word: &str, root: &str) {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str(word)
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str(root)
        );
    }

    /// With transposition cost one, swapping adjacent letters is a single edit.
    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    /// Without transposition cost one, the same swap exceeds a distance budget of one.
    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word served by the backing dictionary must also be a key of the FST map.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let key = word.normalized().to_string();

            assert!(!key.is_empty());
            assert!(
                dict.word_map.contains_key(&key),
                "FST map is missing {key:?}"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let normalized = word.normalized();
        let key = normalized.to_string();
        let key_lower = normalized.to_lower().to_string();

        assert!(dict.contains_word(&normalized));
        assert!(dict.word_map.contains_key(key_lower) || dict.word_map.contains_key(key));
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist);
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction:?} is missing from the dictionary"));

            assert!(
                metadata.derived_from.is_none(),
                "{contraction:?} should not be derived from another word"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        assert_derived_from("llamas", "llama");
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        assert_derived_from("cats", "cat");
    }

    #[test]
    fn unhappy_derived_from_happy() {
        assert_derived_from("unhappy", "happy");
    }

    #[test]
    fn quickly_derived_from_quick() {
        assert_derived_from("quickly", "quick");
    }
}