Skip to main content

harper_core/spell/
fst_dictionary.rs

1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::sync::LazyLock;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
14/// An immutable dictionary allowing for very fast spellchecking.
15///
16/// For dictionaries with changing contents, such as user and file dictionaries, prefer
17/// [`MutableDictionary`].
18pub struct FstDictionary {
19    /// Underlying [`super::MutableDictionary`] used for everything except fuzzy finding
20    mutable_dict: Arc<MutableDictionary>,
21    /// Used for fuzzy-finding the index of words or metadata
22    word_map: FstMap<Vec<u8>>,
23    /// Used for fuzzy-finding the index of words or metadata
24    words: Vec<(CharString, DictWordMetadata)>,
25}
26
27const EXPECTED_DISTANCE: u8 = 3;
28const TRANSPOSITION_COST_ONE: bool = true;
29
30static DICT: LazyLock<Arc<FstDictionary>> =
31    LazyLock::new(|| Arc::new((*MutableDictionary::curated()).clone().into()));
32
33thread_local! {
34    // Builders are computationally expensive and do not depend on the word, so we store a
35    // collection of builders and the associated edit distance here.
36    // Currently, the edit distance we use is three, but a value that does not exist in this
37    // collection will create a new builder of that distance and push it to the collection.
38    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
39        EXPECTED_DISTANCE,
40        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
41    )]);
42}
43
44impl PartialEq for FstDictionary {
45    fn eq(&self, other: &Self) -> bool {
46        self.mutable_dict == other.mutable_dict
47    }
48}
49
50impl FstDictionary {
51    /// Create a dictionary from the curated dictionary included
52    /// in the Harper binary.
53    pub fn curated() -> Arc<Self> {
54        (*DICT).clone()
55    }
56
57    /// Construct a new [`FstDictionary`] using a wordlist as a source.
58    /// This can be expensive, so only use this if fast fuzzy searches are worth it.
59    pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60        words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61        words.dedup_by(|(a, _), (b, _)| a == b);
62
63        let mut builder = fst::MapBuilder::memory();
64        for (index, (word, _)) in words.iter().enumerate() {
65            let word = word.iter().collect::<String>();
66            builder
67                .insert(word, index as u64)
68                .expect("Insertion not in lexicographical order!");
69        }
70
71        let mut mutable_dict = MutableDictionary::new();
72        mutable_dict.extend_words(words.iter().cloned());
73
74        let fst_bytes = builder.into_inner().unwrap();
75        let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77        FstDictionary {
78            mutable_dict: Arc::new(mutable_dict),
79            word_map,
80            words,
81        }
82    }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86    // Insert if it does not exist
87    AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88        if !v.iter().any(|t| t.0 == max_distance) {
89            v.push((
90                max_distance,
91                LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92            ));
93        }
94    });
95
96    AUTOMATON_BUILDERS.with_borrow(|v| {
97        v.iter()
98            .find(|a| a.0 == max_distance)
99            .unwrap()
100            .1
101            .build_dfa(query)
102    })
103}
104
105/// Consumes a DFA stream and emits the index-edit distance pairs it produces.
106fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107    let mut word_index_pairs = Vec::new();
108    while let Some((_, v, s)) = stream.next() {
109        word_index_pairs.push((v, dfa.distance(s).to_u8()));
110    }
111
112    word_index_pairs
113}
114
/// Folds `(word index, edit distance)` pairs into `best_distances`.
///
/// Afterwards, every index seen (in this call or earlier ones) maps to the smallest
/// distance reported for it.
fn merge_best_distances(
    best_distances: &mut HashMap<u64, u8>,
    distances: impl IntoIterator<Item = (u64, u8)>,
) {
    distances.into_iter().for_each(|(word_index, distance)| {
        let best = best_distances.entry(word_index).or_insert(distance);
        *best = (*best).min(distance);
    });
}
127
128impl Dictionary for FstDictionary {
129    fn contains_word(&self, word: &[char]) -> bool {
130        self.mutable_dict.contains_word(word)
131    }
132
133    fn contains_word_str(&self, word: &str) -> bool {
134        self.mutable_dict.contains_word_str(word)
135    }
136
137    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
138        self.mutable_dict.get_word_metadata(word)
139    }
140
141    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
142        self.mutable_dict.get_word_metadata_str(word)
143    }
144
145    fn fuzzy_match(
146        &'_ self,
147        word: &[char],
148        max_distance: u8,
149        max_results: usize,
150    ) -> Vec<FuzzyMatchResult<'_>> {
151        let misspelled_word_charslice = word.normalized();
152        let misspelled_word_string = misspelled_word_charslice.to_string();
153        let misspelled_lower = misspelled_word_string.to_lowercase();
154        let is_already_lower = misspelled_lower == misspelled_word_string;
155
156        // Actual FST search
157        let dfa = build_dfa(max_distance, &misspelled_word_string);
158        let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
159        let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
160
161        // Merge the two results, keeping the smallest distance when both DFAs match.
162        // The uppercase and lowercase searches can return different result counts, so
163        // we can't simply zip the vectors without losing matches.
164        let mut best_distances = HashMap::<u64, u8>::new();
165
166        merge_best_distances(&mut best_distances, upper_dists);
167
168        // Only build the lowercase DFA when the query is not already lowercase.
169        if !is_already_lower {
170            let dfa_lowercase = build_dfa(max_distance, &misspelled_lower);
171            let mut word_indexes_lowercase_stream = self
172                .word_map
173                .search_with_state(&dfa_lowercase)
174                .into_stream();
175            let lower_dists =
176                stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
177
178            merge_best_distances(&mut best_distances, lower_dists);
179        }
180
181        let mut merged = Vec::with_capacity(best_distances.len());
182        for (index, edit_distance) in best_distances {
183            let (word, metadata) = &self.words[index as usize];
184            merged.push(FuzzyMatchResult {
185                word,
186                edit_distance,
187                metadata: Cow::Borrowed(metadata),
188            });
189        }
190
191        // Ignore exact matches
192        merged.retain(|v| v.edit_distance > 0);
193        merged.sort_unstable_by(|a, b| {
194            a.edit_distance
195                .cmp(&b.edit_distance)
196                .then_with(|| a.word.cmp(b.word))
197        });
198        merged.truncate(max_results);
199
200        merged
201    }
202
203    fn fuzzy_match_str(
204        &'_ self,
205        word: &str,
206        max_distance: u8,
207        max_results: usize,
208    ) -> Vec<FuzzyMatchResult<'_>> {
209        self.fuzzy_match(
210            word.chars().collect::<Vec<_>>().as_slice(),
211            max_distance,
212            max_results,
213        )
214    }
215
216    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
217        self.mutable_dict.words_iter()
218    }
219
220    fn word_count(&self) -> usize {
221        self.mutable_dict.word_count()
222    }
223
224    fn contains_exact_word(&self, word: &[char]) -> bool {
225        self.mutable_dict.contains_exact_word(word)
226    }
227
228    fn contains_exact_word_str(&self, word: &str) -> bool {
229        self.mutable_dict.contains_exact_word_str(word)
230    }
231
232    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
233        self.mutable_dict.get_correct_capitalization_of(word)
234    }
235
236    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
237        self.mutable_dict.get_word_from_id(id)
238    }
239
240    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
241        self.mutable_dict.find_words_with_prefix(prefix)
242    }
243
244    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
245        self.mutable_dict.find_words_with_common_prefix(word)
246    }
247}
248
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::DictWordMetadata;
    use crate::spell::{Dictionary, MutableDictionary, WordId};

    use super::FstDictionary;

    /// Builds a [`MutableDictionary`] containing `words` and an [`FstDictionary`] converted
    /// from it, so both implementations can be compared on identical contents.
    fn test_dictionaries(words: &[&str]) -> (MutableDictionary, FstDictionary) {
        let mut mutable = MutableDictionary::new();

        for word in words {
            mutable.append_word_str(word, DictWordMetadata::default());
        }

        let fst = FstDictionary::from(mutable.clone());

        (mutable, fst)
    }

    /// Runs a fuzzy match and returns `(word, edit_distance)` pairs, sorted by distance and
    /// then by word, so results from different dictionaries can be compared directly.
    fn fuzzy_matches<D: Dictionary + ?Sized>(
        dict: &D,
        word: &str,
        max_distance: u8,
        max_results: usize,
    ) -> Vec<(String, u8)> {
        let mut matches = dict
            .fuzzy_match_str(word, max_distance, max_results)
            .into_iter()
            .map(|result| (result.word.iter().collect::<String>(), result.edit_distance))
            .collect_vec();

        matches.sort_unstable_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        matches
    }

    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let normalized = word.normalized();
            let key = normalized.to_string();

            assert!(!key.is_empty());
            // Report the offending word in the failure message instead of `dbg!`-printing
            // every word in the curated dictionary.
            assert!(
                dict.word_map.contains_key(&key),
                "FST map is missing the dictionary word {key:?}"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let normalized = word.normalized();
        let exact = normalized.to_string();
        let lower = normalized.to_lower().to_string();

        assert!(dict.contains_word(&normalized));
        assert!(dict.word_map.contains_key(lower) || dict.word_map.contains_key(exact));
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction} should be in the curated dictionary"));

            assert!(
                metadata.derived_from.is_none(),
                "{contraction} should not be marked as derived"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("llamas")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("llama")
        )
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("cats")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("cat")
        );
    }

    #[test]
    fn unhappy_derived_from_happy() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("unhappy")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("happy")
        );
    }

    #[test]
    fn quickly_derived_from_quick() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("quickly")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("quick")
        );
    }

    #[test]
    fn lowercase_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "speling", 3, 10);
        let fst_results = fuzzy_matches(&fst, "speling", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn capitalized_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "Speling", 3, 10);
        let fst_results = fuzzy_matches(&fst, "Speling", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn uppercase_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "SPELING", 3, 10);
        let fst_results = fuzzy_matches(&fst, "SPELING", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn query_casing_produces_the_same_fuzzy_matches() {
        let (_, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let lowercase = fuzzy_matches(&fst, "speling", 3, 10);
        let capitalized = fuzzy_matches(&fst, "Speling", 3, 10);
        let uppercase = fuzzy_matches(&fst, "SPELING", 3, 10);

        assert_eq!(lowercase, capitalized);
        assert_eq!(lowercase, uppercase);
    }
}