harper_core/spell/
fst_dictionary.rs1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use lazy_static::lazy_static;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::{cell::RefCell, sync::Arc};
6
7use crate::{CharString, CharStringExt, WordMetadata};
8
9use super::Dictionary;
10use super::FuzzyMatchResult;
11
12pub struct FstDictionary {
17 full_dict: Arc<MutableDictionary>,
19 word_map: FstMap<Vec<u8>>,
21 words: Vec<(CharString, WordMetadata)>,
23}
24
25const EXPECTED_DISTANCE: u8 = 3;
26const TRANSPOSITION_COST_ONE: bool = false;
27
28lazy_static! {
29 static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
30}
31
32thread_local! {
33 static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
38 EXPECTED_DISTANCE,
39 LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
40 )]);
41}
42
43impl PartialEq for FstDictionary {
44 fn eq(&self, other: &Self) -> bool {
45 self.full_dict == other.full_dict
46 }
47}
48
49impl FstDictionary {
50 pub fn curated() -> Arc<Self> {
53 (*DICT).clone()
54 }
55
56 pub fn new(mut words: Vec<(CharString, WordMetadata)>) -> Self {
59 words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
60 words.dedup_by(|(a, _), (b, _)| a == b);
61
62 let mut builder = fst::MapBuilder::memory();
63 for (index, (word, _)) in words.iter().enumerate() {
64 let word = word.iter().collect::<String>();
65 builder
66 .insert(word, index as u64)
67 .expect("Insertion not in lexicographical order!");
68 }
69
70 let mut full_dict = MutableDictionary::new();
71 full_dict.extend_words(words.iter().cloned());
72
73 let fst_bytes = builder.into_inner().unwrap();
74 let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
75
76 FstDictionary {
77 full_dict: Arc::new(full_dict),
78 word_map,
79 words,
80 }
81 }
82}
83
84fn build_dfa(max_distance: u8, query: &str) -> DFA {
85 AUTOMATON_BUILDERS.with_borrow_mut(|v| {
87 if !v.iter().any(|t| t.0 == max_distance) {
88 v.push((
89 max_distance,
90 LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
91 ));
92 }
93 });
94
95 AUTOMATON_BUILDERS.with_borrow(|v| {
96 v.iter()
97 .find(|a| a.0 == max_distance)
98 .unwrap()
99 .1
100 .build_dfa(query)
101 })
102}
103
104fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
106 let mut word_index_pairs = Vec::new();
107 while let Some((_, v, s)) = stream.next() {
108 word_index_pairs.push((v, dfa.distance(s).to_u8()));
109 }
110
111 word_index_pairs
112}
113
114impl Dictionary for FstDictionary {
115 fn contains_word(&self, word: &[char]) -> bool {
116 self.full_dict.contains_word(word)
117 }
118
119 fn contains_word_str(&self, word: &str) -> bool {
120 self.full_dict.contains_word_str(word)
121 }
122
123 fn get_word_metadata(&self, word: &[char]) -> Option<&WordMetadata> {
124 self.full_dict.get_word_metadata(word)
125 }
126
127 fn get_word_metadata_str(&self, word: &str) -> Option<&WordMetadata> {
128 self.full_dict.get_word_metadata_str(word)
129 }
130
131 fn fuzzy_match(
132 &self,
133 word: &[char],
134 max_distance: u8,
135 max_results: usize,
136 ) -> Vec<FuzzyMatchResult> {
137 let misspelled_word_charslice = word.normalized();
138 let misspelled_word_string = misspelled_word_charslice.to_string();
139
140 let dfa = build_dfa(max_distance, &misspelled_word_string);
142 let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
143 let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
144 let mut word_indexes_lowercase_stream = self
145 .word_map
146 .search_with_state(&dfa_lowercase)
147 .into_stream();
148
149 let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
150 let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
151
152 let mut merged = Vec::with_capacity(upper_dists.len());
153
154 for ((i_u, dist_u), (i_l, dist_l)) in upper_dists.into_iter().zip(lower_dists.into_iter()) {
156 let (chosen_index, edit_distance) = if dist_u <= dist_l {
157 (i_u, dist_u)
158 } else {
159 (i_l, dist_l)
160 };
161
162 let (word, metadata) = &self.words[chosen_index as usize];
163
164 merged.push(FuzzyMatchResult {
165 word,
166 edit_distance,
167 metadata,
168 })
169 }
170
171 merged.sort_unstable_by_key(|v| v.word);
172 merged.dedup_by_key(|v| v.word);
173 merged.sort_unstable_by_key(|v| v.edit_distance);
174 merged.truncate(max_results);
175
176 merged
177 }
178
179 fn fuzzy_match_str(
180 &self,
181 word: &str,
182 max_distance: u8,
183 max_results: usize,
184 ) -> Vec<FuzzyMatchResult> {
185 self.fuzzy_match(
186 word.chars().collect::<Vec<_>>().as_slice(),
187 max_distance,
188 max_results,
189 )
190 }
191
192 fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
193 self.full_dict.words_iter()
194 }
195
196 fn word_count(&self) -> usize {
197 self.full_dict.word_count()
198 }
199
200 fn contains_exact_word(&self, word: &[char]) -> bool {
201 self.full_dict.contains_exact_word(word)
202 }
203
204 fn contains_exact_word_str(&self, word: &str) -> bool {
205 self.full_dict.contains_exact_word_str(word)
206 }
207
208 fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
209 self.full_dict.get_correct_capitalization_of(word)
210 }
211
212 fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
213 self.full_dict.get_word_from_id(id)
214 }
215}
216
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::Dictionary;
    use crate::WordId;

    use super::FstDictionary;

    /// Asserts that the curated dictionary contains `derived` and records
    /// `root` as the word it was derived from.
    fn assert_derived_from(derived: &str, root: &str) {
        let dict = FstDictionary::curated();

        let metadata = dict
            .get_word_metadata_str(derived)
            .unwrap_or_else(|| panic!("`{derived}` is missing from the curated dictionary"));

        assert_eq!(
            metadata.derived_from,
            Some(WordId::from_word_str(root)),
            "`{derived}` should be derived from `{root}`"
        );
    }

    /// Every word yielded by `words_iter` must be present in the FST index,
    /// either as-is or in lowercase.
    #[test]
    fn fst_map_contains_all_in_full_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let misspelled_normalized = word.normalized();
            let misspelled_word = misspelled_normalized.to_string();
            let misspelled_lower = misspelled_normalized.to_lower().to_string();

            // Name the offending word in the failure message instead of
            // printing every dictionary entry while the test runs.
            assert!(
                !misspelled_word.is_empty(),
                "the dictionary yielded an empty word"
            );
            assert!(
                dict.word_map.contains_key(&misspelled_word)
                    || dict.word_map.contains_key(&misspelled_lower),
                "`{misspelled_word}` is missing from the FST index"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let misspelled_normalized = word.normalized();
        let misspelled_word = misspelled_normalized.to_string();
        let misspelled_lower = misspelled_normalized.to_lower().to_string();

        assert!(dict.contains_word(&misspelled_normalized));
        assert!(
            dict.word_map.contains_key(misspelled_lower)
                || dict.word_map.contains_key(misspelled_word)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict.get_word_metadata_str(contraction).unwrap_or_else(|| {
                panic!("`{contraction}` is missing from the curated dictionary")
            });

            assert!(
                metadata.derived_from.is_none(),
                "`{contraction}` should not be marked as derived from another word"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        assert_derived_from("llamas", "llama");
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        assert_derived_from("cats", "cat");
    }

    #[test]
    fn unhappy_derived_from_happy() {
        assert_derived_from("unhappy", "happy");
    }

    #[test]
    fn quickly_derived_from_quick() {
        assert_derived_from("quickly", "quick");
    }
}