1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use lazy_static::lazy_static;
5use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
6use std::borrow::Cow;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
/// A dictionary that pairs a [`MutableDictionary`] with an FST index over the same words.
///
/// In this file, every lookup except fuzzy matching is forwarded to `mutable_dict`;
/// `word_map` and `words` are only read by fuzzy matching.
pub struct FstDictionary {
    /// Backing dictionary that answers every query other than fuzzy matching.
    mutable_dict: Arc<MutableDictionary>,
    /// FST map whose values are used as indices into `words`. When built by
    /// [`FstDictionary::new`], each key is a word collected into a `String`.
    word_map: FstMap<Vec<u8>>,
    /// Word/metadata pairs addressed by the values stored in `word_map`.
    /// [`FstDictionary::new`] sorts this list by word and removes duplicate words.
    words: Vec<(CharString, DictWordMetadata)>,
}
26
/// Maximum edit distance for which every thread's automaton-builder cache is pre-populated.
const EXPECTED_DISTANCE: u8 = 3;
/// Second argument handed to every `LevenshteinAutomatonBuilder::new` call in this file.
/// When `true`, swapping two adjacent characters counts as one edit instead of two.
const TRANSPOSITION_COST_ONE: bool = true;
29
lazy_static! {
    // Process-wide curated dictionary, initialized lazily on first access by converting a
    // clone of the curated `MutableDictionary` into an `FstDictionary`.
    static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
}
33
thread_local! {
    // Per-thread cache of Levenshtein automaton builders, stored as
    // `(max_distance, builder)` pairs. It starts out holding a single builder for
    // `EXPECTED_DISTANCE`; `build_dfa` appends a builder the first time another distance
    // is requested on the thread.
    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
        EXPECTED_DISTANCE,
        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
    )]);
}
44
45impl PartialEq for FstDictionary {
46 fn eq(&self, other: &Self) -> bool {
47 self.mutable_dict == other.mutable_dict
48 }
49}
50
51impl FstDictionary {
52 pub fn curated() -> Arc<Self> {
55 (*DICT).clone()
56 }
57
58 pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
61 words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
62 words.dedup_by(|(a, _), (b, _)| a == b);
63
64 let mut builder = fst::MapBuilder::memory();
65 for (index, (word, _)) in words.iter().enumerate() {
66 let word = word.iter().collect::<String>();
67 builder
68 .insert(word, index as u64)
69 .expect("Insertion not in lexicographical order!");
70 }
71
72 let mut mutable_dict = MutableDictionary::new();
73 mutable_dict.extend_words(words.iter().cloned());
74
75 let fst_bytes = builder.into_inner().unwrap();
76 let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
77
78 FstDictionary {
79 mutable_dict: Arc::new(mutable_dict),
80 word_map,
81 words,
82 }
83 }
84}
85
86fn build_dfa(max_distance: u8, query: &str) -> DFA {
87 AUTOMATON_BUILDERS.with_borrow_mut(|v| {
89 if !v.iter().any(|t| t.0 == max_distance) {
90 v.push((
91 max_distance,
92 LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
93 ));
94 }
95 });
96
97 AUTOMATON_BUILDERS.with_borrow(|v| {
98 v.iter()
99 .find(|a| a.0 == max_distance)
100 .unwrap()
101 .1
102 .build_dfa(query)
103 })
104}
105
106fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
108 let mut word_index_pairs = Vec::new();
109 while let Some((_, v, s)) = stream.next() {
110 word_index_pairs.push((v, dfa.distance(s).to_u8()));
111 }
112
113 word_index_pairs
114}
115
116impl Dictionary for FstDictionary {
117 fn contains_word(&self, word: &[char]) -> bool {
118 self.mutable_dict.contains_word(word)
119 }
120
121 fn contains_word_str(&self, word: &str) -> bool {
122 self.mutable_dict.contains_word_str(word)
123 }
124
125 fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
126 self.mutable_dict.get_word_metadata(word)
127 }
128
129 fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
130 self.mutable_dict.get_word_metadata_str(word)
131 }
132
133 fn fuzzy_match(
134 &'_ self,
135 word: &[char],
136 max_distance: u8,
137 max_results: usize,
138 ) -> Vec<FuzzyMatchResult<'_>> {
139 let misspelled_word_charslice = word.normalized();
140 let misspelled_word_string = misspelled_word_charslice.to_string();
141
142 let dfa = build_dfa(max_distance, &misspelled_word_string);
144 let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
145 let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
146 let mut word_indexes_lowercase_stream = self
147 .word_map
148 .search_with_state(&dfa_lowercase)
149 .into_stream();
150
151 let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
152 let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
153
154 let mut merged = Vec::with_capacity(upper_dists.len().max(lower_dists.len()));
158 let mut best_distances = HashMap::<u64, u8>::new();
159
160 for (idx, dist) in upper_dists.into_iter().chain(lower_dists.into_iter()) {
161 best_distances
162 .entry(idx)
163 .and_modify(|existing| *existing = (*existing).min(dist))
164 .or_insert(dist);
165 }
166
167 for (index, edit_distance) in best_distances {
168 let (word, metadata) = &self.words[index as usize];
169 merged.push(FuzzyMatchResult {
170 word,
171 edit_distance,
172 metadata: Cow::Borrowed(metadata),
173 });
174 }
175
176 merged.retain(|v| v.edit_distance > 0);
178 merged.sort_unstable_by(|a, b| {
179 a.edit_distance
180 .cmp(&b.edit_distance)
181 .then_with(|| a.word.cmp(b.word))
182 });
183 merged.truncate(max_results);
184
185 merged
186 }
187
188 fn fuzzy_match_str(
189 &'_ self,
190 word: &str,
191 max_distance: u8,
192 max_results: usize,
193 ) -> Vec<FuzzyMatchResult<'_>> {
194 self.fuzzy_match(
195 word.chars().collect::<Vec<_>>().as_slice(),
196 max_distance,
197 max_results,
198 )
199 }
200
201 fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
202 self.mutable_dict.words_iter()
203 }
204
205 fn word_count(&self) -> usize {
206 self.mutable_dict.word_count()
207 }
208
209 fn contains_exact_word(&self, word: &[char]) -> bool {
210 self.mutable_dict.contains_exact_word(word)
211 }
212
213 fn contains_exact_word_str(&self, word: &str) -> bool {
214 self.mutable_dict.contains_exact_word_str(word)
215 }
216
217 fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
218 self.mutable_dict.get_correct_capitalization_of(word)
219 }
220
221 fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
222 self.mutable_dict.get_word_from_id(id)
223 }
224
225 fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
226 self.mutable_dict.find_words_with_prefix(prefix)
227 }
228
229 fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
230 self.mutable_dict.find_words_with_common_prefix(word)
231 }
232}
233
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    /// With the transposition flag set, swapping two adjacent characters is one edit.
    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    /// Without the transposition flag, the same swap exceeds a maximum distance of one.
    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word exposed by `words_iter` must be present as a key in the FST map.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let misspelled_normalized = word.normalized();
            let misspelled_word = misspelled_normalized.to_string();

            assert!(
                !misspelled_word.is_empty(),
                "dictionary yielded an empty word"
            );
            // The failure message names the word, so nothing is printed for passing words.
            assert!(
                dict.word_map.contains_key(&misspelled_word),
                "word missing from the FST map: {:?}",
                misspelled_word
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let misspelled_normalized = word.normalized();
        let misspelled_word = misspelled_normalized.to_string();
        let misspelled_lower = misspelled_normalized.to_lower().to_string();

        assert!(dict.contains_word(&misspelled_normalized));
        assert!(
            dict.word_map.contains_key(misspelled_lower)
                || dict.word_map.contains_key(misspelled_word)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist);
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        let contractions = ["there's", "we're", "here's"];

        for contraction in contractions {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("no metadata found for {:?}", contraction));

            assert!(
                metadata.derived_from.is_none(),
                "{:?} should not be derived from another word",
                contraction
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("llamas")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("llama")
        );
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("cats")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("cat")
        );
    }

    #[test]
    fn unhappy_derived_from_happy() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("unhappy")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("happy")
        );
    }

    #[test]
    fn quickly_derived_from_quick() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("quickly")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("quick")
        );
    }
}