// harper_core/spell/trie_dictionary.rs

1use lazy_static::lazy_static;
2use std::borrow::Cow;
3use std::sync::Arc;
4
5use trie_rs::Trie;
6use trie_rs::iter::{Keys, PrefixIter, SearchIter};
7
8use crate::DictWordMetadata;
9
10use super::{Dictionary, FstDictionary, FuzzyMatchResult, WordId};
11
12/// A [`Dictionary`] optimized for pre- and postfix search.
13/// Wraps another dictionary to implement other operations.
14pub struct TrieDictionary<D: Dictionary> {
15    trie: Trie<char>,
16    inner: D,
17}
18
19lazy_static! {
20    static ref DICT: Arc<TrieDictionary<Arc<FstDictionary>>> =
21        Arc::new(TrieDictionary::new(FstDictionary::curated()));
22}
23
24impl TrieDictionary<Arc<FstDictionary>> {
25    /// Create a dictionary from the curated dictionary included
26    /// in the Harper binary.
27    pub fn curated() -> Arc<Self> {
28        (*DICT).clone()
29    }
30}
31
32impl<D: Dictionary> TrieDictionary<D> {
33    pub fn new(inner: D) -> Self {
34        let trie = Trie::from_iter(inner.words_iter());
35
36        Self { inner, trie }
37    }
38}
39
40impl<D: Dictionary> Dictionary for TrieDictionary<D> {
41    fn contains_word(&self, word: &[char]) -> bool {
42        self.inner.contains_word(word)
43    }
44
45    fn contains_word_str(&self, word: &str) -> bool {
46        self.inner.contains_word_str(word)
47    }
48
49    fn contains_exact_word(&self, word: &[char]) -> bool {
50        self.inner.contains_exact_word(word)
51    }
52
53    fn contains_exact_word_str(&self, word: &str) -> bool {
54        self.inner.contains_exact_word_str(word)
55    }
56
57    fn fuzzy_match(
58        &'_ self,
59        word: &[char],
60        max_distance: u8,
61        max_results: usize,
62    ) -> Vec<FuzzyMatchResult<'_>> {
63        self.inner.fuzzy_match(word, max_distance, max_results)
64    }
65
66    fn fuzzy_match_str(
67        &'_ self,
68        word: &str,
69        max_distance: u8,
70        max_results: usize,
71    ) -> Vec<FuzzyMatchResult<'_>> {
72        self.inner.fuzzy_match_str(word, max_distance, max_results)
73    }
74
75    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
76        self.inner.get_correct_capitalization_of(word)
77    }
78
79    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
80        self.inner.get_word_metadata(word)
81    }
82
83    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
84        self.inner.get_word_metadata_str(word)
85    }
86
87    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
88        self.inner.words_iter()
89    }
90
91    fn word_count(&self) -> usize {
92        self.inner.word_count()
93    }
94
95    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
96        self.inner.get_word_from_id(id)
97    }
98
99    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
100        let results: Keys<SearchIter<'_, char, (), Vec<char>, _>> =
101            self.trie.predictive_search(prefix);
102        results.map(Cow::Owned).collect()
103    }
104
105    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
106        let results: Keys<PrefixIter<'_, char, (), Vec<char>, _>> =
107            self.trie.common_prefix_search(word);
108        results.map(Cow::Owned).collect()
109    }
110}
111
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use crate::DictWordMetadata;
    use crate::char_string::char_string;
    use crate::spell::MutableDictionary;
    use crate::spell::dictionary::Dictionary;
    use crate::spell::trie_dictionary::TrieDictionary;

    /// Builds a [`TrieDictionary`] over a [`MutableDictionary`] holding `words`.
    /// Words are inserted in the given order, each with default metadata.
    fn trie_of(words: &[&str]) -> TrieDictionary<MutableDictionary> {
        let mut inner = MutableDictionary::new();
        for &word in words {
            inner.append_word_str(word, DictWordMetadata::default());
        }
        TrieDictionary::new(inner)
    }

    /// Returns `word` in the owned form that the prefix searches use to report a match.
    fn owned(word: &str) -> Cow<'static, [char]> {
        Cow::Owned(word.chars().collect())
    }

    #[test]
    fn gets_prefixes_as_expected() {
        let dict = trie_of(&["predict", "prelude", "preview", "dwight"]);

        let completions = dict.find_words_with_prefix(char_string!("pre").as_slice());

        // Exactly the three `pre…` words are returned, so `dwight` must be absent.
        assert_eq!(completions.len(), 3);
        for expected in ["predict", "prelude", "preview"] {
            assert!(completions.contains(&owned(expected)));
        }
    }

    #[test]
    fn gets_common_prefixes_as_expected() {
        let dict = trie_of(&["pre", "prep", "dwight"]);

        let prefixes =
            dict.find_words_with_common_prefix(char_string!("preposition").as_slice());

        // Both stored prefixes of `preposition` are reported, and nothing else.
        assert_eq!(prefixes.len(), 2);
        for expected in ["pre", "prep"] {
            assert!(prefixes.contains(&owned(expected)));
        }
    }
}
156}