Skip to main content

harper_core/spell/
trie_dictionary.rs

1use std::borrow::Cow;
2use std::sync::{Arc, LazyLock};
3
4use trie_rs::Trie;
5use trie_rs::iter::{Keys, PrefixIter, SearchIter};
6
7use crate::DictWordMetadata;
8
9use super::{Dictionary, FstDictionary, FuzzyMatchResult, WordId};
10
11/// A [`Dictionary`] optimized for pre- and postfix search.
12/// Wraps another dictionary to implement other operations.
13pub struct TrieDictionary<D: Dictionary> {
14    trie: Trie<char>,
15    inner: D,
16}
17
18pub static DICT: LazyLock<Arc<TrieDictionary<Arc<FstDictionary>>>> =
19    LazyLock::new(|| Arc::new(TrieDictionary::new(FstDictionary::curated())));
20
21impl TrieDictionary<Arc<FstDictionary>> {
22    /// Create a dictionary from the curated dictionary included
23    /// in the Harper binary.
24    pub fn curated() -> Arc<Self> {
25        (*DICT).clone()
26    }
27}
28
29impl<D: Dictionary> TrieDictionary<D> {
30    pub fn new(inner: D) -> Self {
31        let trie = Trie::from_iter(inner.words_iter());
32
33        Self { inner, trie }
34    }
35}
36
37impl<D: Dictionary> Dictionary for TrieDictionary<D> {
38    fn contains_word(&self, word: &[char]) -> bool {
39        self.inner.contains_word(word)
40    }
41
42    fn contains_word_str(&self, word: &str) -> bool {
43        self.inner.contains_word_str(word)
44    }
45
46    fn contains_exact_word(&self, word: &[char]) -> bool {
47        self.inner.contains_exact_word(word)
48    }
49
50    fn contains_exact_word_str(&self, word: &str) -> bool {
51        self.inner.contains_exact_word_str(word)
52    }
53
54    fn fuzzy_match(
55        &'_ self,
56        word: &[char],
57        max_distance: u8,
58        max_results: usize,
59    ) -> Vec<FuzzyMatchResult<'_>> {
60        self.inner.fuzzy_match(word, max_distance, max_results)
61    }
62
63    fn fuzzy_match_str(
64        &'_ self,
65        word: &str,
66        max_distance: u8,
67        max_results: usize,
68    ) -> Vec<FuzzyMatchResult<'_>> {
69        self.inner.fuzzy_match_str(word, max_distance, max_results)
70    }
71
72    fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
73        self.inner.get_correct_capitalization_of(word)
74    }
75
76    fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
77        self.inner.get_word_metadata(word)
78    }
79
80    fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
81        self.inner.get_word_metadata_str(word)
82    }
83
84    fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
85        self.inner.words_iter()
86    }
87
88    fn word_count(&self) -> usize {
89        self.inner.word_count()
90    }
91
92    fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
93        self.inner.get_word_from_id(id)
94    }
95
96    fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
97        let results: Keys<SearchIter<'_, char, (), Vec<char>, _>> =
98            self.trie.predictive_search(prefix);
99        results.map(Cow::Owned).collect()
100    }
101
102    fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
103        let results: Keys<PrefixIter<'_, char, (), Vec<char>, _>> =
104            self.trie.common_prefix_search(word);
105        results.map(Cow::Owned).collect()
106    }
107}
108
#[cfg(test)]
mod tests {
    use std::borrow::Cow;

    use crate::DictWordMetadata;
    use crate::char_string::char_string;
    use crate::spell::MutableDictionary;
    use crate::spell::dictionary::Dictionary;
    use crate::spell::trie_dictionary::TrieDictionary;

    /// Builds a [`TrieDictionary`] over a fresh [`MutableDictionary`]
    /// containing `words`, each with default metadata.
    fn trie_over(words: &[&str]) -> TrieDictionary<MutableDictionary> {
        let mut inner = MutableDictionary::new();
        for &word in words {
            inner.append_word_str(word, DictWordMetadata::default());
        }
        TrieDictionary::new(inner)
    }

    /// Converts `word` into the owned form the prefix searches return.
    fn owned(word: &str) -> Cow<'static, [char]> {
        Cow::Owned(word.chars().collect())
    }

    #[test]
    fn gets_prefixes_as_expected() {
        let dict = trie_over(&["predict", "prelude", "preview", "dwight"]);

        let with_prefix = dict.find_words_with_prefix(char_string!("pre").as_slice());

        assert_eq!(with_prefix.len(), 3);
        for expected in ["predict", "prelude", "preview"] {
            assert!(with_prefix.contains(&owned(expected)));
        }
    }

    #[test]
    fn gets_common_prefixes_as_expected() {
        let dict = trie_over(&["pre", "prep", "dwight"]);

        let with_prefix =
            dict.find_words_with_common_prefix(char_string!("preposition").as_slice());

        assert_eq!(with_prefix.len(), 2);
        for expected in ["pre", "prep"] {
            assert!(with_prefix.contains(&owned(expected)));
        }
    }
}
149        assert_eq!(with_prefix.len(), 2);
150        assert!(with_prefix.contains(&Cow::Owned(char_string!("pre").into_vec())));
151        assert!(with_prefix.contains(&Cow::Owned(char_string!("prep").into_vec())));
152    }
153}