1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use lazy_static::lazy_static;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::{cell::RefCell, sync::Arc};
7
8use crate::{CharString, CharStringExt, DictWordMetadata};
9
10use super::Dictionary;
11use super::FuzzyMatchResult;
12
13pub struct FstDictionary {
18 mutable_dict: Arc<MutableDictionary>,
20 word_map: FstMap<Vec<u8>>,
22 words: Vec<(CharString, DictWordMetadata)>,
24}
25
/// The edit distance for which each thread pre-builds a Levenshtein
/// automaton builder.
const EXPECTED_DISTANCE: u8 = 3;

/// Passed to every `LevenshteinAutomatonBuilder` in this file: when `true`,
/// swapping two adjacent characters counts as a single edit.
const TRANSPOSITION_COST_ONE: bool = true;
28
29lazy_static! {
30 static ref DICT: Arc<FstDictionary> = Arc::new((*MutableDictionary::curated()).clone().into());
31}
32
33thread_local! {
34 static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
39 EXPECTED_DISTANCE,
40 LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
41 )]);
42}
43
44impl PartialEq for FstDictionary {
45 fn eq(&self, other: &Self) -> bool {
46 self.mutable_dict == other.mutable_dict
47 }
48}
49
50impl FstDictionary {
51 pub fn curated() -> Arc<Self> {
54 (*DICT).clone()
55 }
56
57 pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60 words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61 words.dedup_by(|(a, _), (b, _)| a == b);
62
63 let mut builder = fst::MapBuilder::memory();
64 for (index, (word, _)) in words.iter().enumerate() {
65 let word = word.iter().collect::<String>();
66 builder
67 .insert(word, index as u64)
68 .expect("Insertion not in lexicographical order!");
69 }
70
71 let mut mutable_dict = MutableDictionary::new();
72 mutable_dict.extend_words(words.iter().cloned());
73
74 let fst_bytes = builder.into_inner().unwrap();
75 let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77 FstDictionary {
78 mutable_dict: Arc::new(mutable_dict),
79 word_map,
80 words,
81 }
82 }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86 AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88 if !v.iter().any(|t| t.0 == max_distance) {
89 v.push((
90 max_distance,
91 LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92 ));
93 }
94 });
95
96 AUTOMATON_BUILDERS.with_borrow(|v| {
97 v.iter()
98 .find(|a| a.0 == max_distance)
99 .unwrap()
100 .1
101 .build_dfa(query)
102 })
103}
104
105fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107 let mut word_index_pairs = Vec::new();
108 while let Some((_, v, s)) = stream.next() {
109 word_index_pairs.push((v, dfa.distance(s).to_u8()));
110 }
111
112 word_index_pairs
113}
114
115impl Dictionary for FstDictionary {
116 fn contains_word(&self, word: &[char]) -> bool {
117 self.mutable_dict.contains_word(word)
118 }
119
120 fn contains_word_str(&self, word: &str) -> bool {
121 self.mutable_dict.contains_word_str(word)
122 }
123
124 fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
125 self.mutable_dict.get_word_metadata(word)
126 }
127
128 fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
129 self.mutable_dict.get_word_metadata_str(word)
130 }
131
132 fn fuzzy_match(
133 &'_ self,
134 word: &[char],
135 max_distance: u8,
136 max_results: usize,
137 ) -> Vec<FuzzyMatchResult<'_>> {
138 let misspelled_word_charslice = word.normalized();
139 let misspelled_word_string = misspelled_word_charslice.to_string();
140
141 let dfa = build_dfa(max_distance, &misspelled_word_string);
143 let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
144 let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
145 let mut word_indexes_lowercase_stream = self
146 .word_map
147 .search_with_state(&dfa_lowercase)
148 .into_stream();
149
150 let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
151 let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
152
153 let mut merged = Vec::with_capacity(upper_dists.len());
154
155 for ((i_u, dist_u), (i_l, dist_l)) in upper_dists.into_iter().zip(lower_dists.into_iter()) {
157 let (chosen_index, edit_distance) = if dist_u <= dist_l {
158 (i_u, dist_u)
159 } else {
160 (i_l, dist_l)
161 };
162
163 let (word, metadata) = &self.words[chosen_index as usize];
164
165 merged.push(FuzzyMatchResult {
166 word,
167 edit_distance,
168 metadata: Cow::Borrowed(metadata),
169 })
170 }
171
172 merged.sort_unstable_by_key(|v| v.word);
173 merged.dedup_by_key(|v| v.word);
174 merged.sort_unstable_by_key(|v| v.edit_distance);
175 merged.truncate(max_results);
176
177 merged
178 }
179
180 fn fuzzy_match_str(
181 &'_ self,
182 word: &str,
183 max_distance: u8,
184 max_results: usize,
185 ) -> Vec<FuzzyMatchResult<'_>> {
186 self.fuzzy_match(
187 word.chars().collect::<Vec<_>>().as_slice(),
188 max_distance,
189 max_results,
190 )
191 }
192
193 fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
194 self.mutable_dict.words_iter()
195 }
196
197 fn word_count(&self) -> usize {
198 self.mutable_dict.word_count()
199 }
200
201 fn contains_exact_word(&self, word: &[char]) -> bool {
202 self.mutable_dict.contains_exact_word(word)
203 }
204
205 fn contains_exact_word_str(&self, word: &str) -> bool {
206 self.mutable_dict.contains_exact_word_str(word)
207 }
208
209 fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
210 self.mutable_dict.get_correct_capitalization_of(word)
211 }
212
213 fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
214 self.mutable_dict.get_word_from_id(id)
215 }
216
217 fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
218 self.mutable_dict.find_words_with_prefix(prefix)
219 }
220
221 fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
222 self.mutable_dict.find_words_with_common_prefix(word)
223 }
224}
225
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    /// Asserts that the curated dictionary records `base` as the word that
    /// `derived` was derived from.
    fn assert_derived_from(derived: &str, base: &str) {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str(derived)
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str(base),
            "expected {derived:?} to be derived from {base:?}"
        );
    }

    /// With transposition cost enabled, swapping two adjacent letters is a
    /// single edit.
    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    /// With transposition cost disabled, the same swap exceeds a distance
    /// budget of one.
    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word reported by `words_iter` must be present as a key in the
    /// FST map.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let misspelled_normalized = word.normalized();
            let misspelled_word = misspelled_normalized.to_string();

            assert!(
                !misspelled_word.is_empty(),
                "dictionary contains an empty word"
            );
            assert!(
                dict.word_map.contains_key(&misspelled_word),
                "FST map is missing {misspelled_word:?}"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let misspelled_normalized = word.normalized();
        let misspelled_word = misspelled_normalized.to_string();
        let misspelled_lower = misspelled_normalized.to_lower().to_string();

        assert!(dict.contains_word(&misspelled_normalized));
        assert!(
            dict.word_map.contains_key(misspelled_lower)
                || dict.word_map.contains_key(misspelled_word)
        );
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        let contractions = ["there's", "we're", "here's"];

        for contraction in contractions {
            assert!(
                dict.get_word_metadata_str(contraction)
                    .unwrap()
                    .derived_from
                    .is_none(),
                "{contraction:?} should not be marked as derived from another word"
            )
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        assert_derived_from("llamas", "llama");
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        assert_derived_from("cats", "cat");
    }

    #[test]
    fn unhappy_derived_from_happy() {
        assert_derived_from("unhappy", "happy");
    }

    #[test]
    fn quickly_derived_from_quick() {
        assert_derived_from("quickly", "quick");
    }
}