1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::sync::LazyLock;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
/// A dictionary that answers fuzzy (edit-distance) queries by running Levenshtein automata
/// over an `fst::Map`, and delegates all other lookups to a `MutableDictionary` that `new`
/// builds from the same word list.
pub struct FstDictionary {
    /// Holds the same words as `words`; exact lookups, metadata queries and iteration
    /// are forwarded to it.
    mutable_dict: Arc<MutableDictionary>,
    /// FST keyed by each word's UTF-8 spelling; the value is that word's index in `words`.
    word_map: FstMap<Vec<u8>>,
    /// `(word, metadata)` pairs, sorted by word with duplicate words removed. Indexed by
    /// the values stored in `word_map`.
    words: Vec<(CharString, DictWordMetadata)>,
}
26
/// Edit distance whose automaton builder is pre-seeded into every thread's
/// `AUTOMATON_BUILDERS` cache; builders for any other distance are created on first use.
const EXPECTED_DISTANCE: u8 = 3;
/// Transposition flag passed to `LevenshteinAutomatonBuilder::new`. When `true`, swapping
/// two adjacent characters counts as one edit (so "woof" -> "wofo" has distance 1).
const TRANSPOSITION_COST_ONE: bool = true;
29
30static DICT: LazyLock<Arc<FstDictionary>> =
31 LazyLock::new(|| Arc::new((*MutableDictionary::curated()).clone().into()));
32
33thread_local! {
34 static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
39 EXPECTED_DISTANCE,
40 LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
41 )]);
42}
43
44impl PartialEq for FstDictionary {
45 fn eq(&self, other: &Self) -> bool {
46 self.mutable_dict == other.mutable_dict
47 }
48}
49
50impl FstDictionary {
51 pub fn curated() -> Arc<Self> {
54 (*DICT).clone()
55 }
56
57 pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60 words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61 words.dedup_by(|(a, _), (b, _)| a == b);
62
63 let mut builder = fst::MapBuilder::memory();
64 for (index, (word, _)) in words.iter().enumerate() {
65 let word = word.iter().collect::<String>();
66 builder
67 .insert(word, index as u64)
68 .expect("Insertion not in lexicographical order!");
69 }
70
71 let mut mutable_dict = MutableDictionary::new();
72 mutable_dict.extend_words(words.iter().cloned());
73
74 let fst_bytes = builder.into_inner().unwrap();
75 let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77 FstDictionary {
78 mutable_dict: Arc::new(mutable_dict),
79 word_map,
80 words,
81 }
82 }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86 AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88 if !v.iter().any(|t| t.0 == max_distance) {
89 v.push((
90 max_distance,
91 LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92 ));
93 }
94 });
95
96 AUTOMATON_BUILDERS.with_borrow(|v| {
97 v.iter()
98 .find(|a| a.0 == max_distance)
99 .unwrap()
100 .1
101 .build_dfa(query)
102 })
103}
104
105fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107 let mut word_index_pairs = Vec::new();
108 while let Some((_, v, s)) = stream.next() {
109 word_index_pairs.push((v, dfa.distance(s).to_u8()));
110 }
111
112 word_index_pairs
113}
114
/// Folds `(word index, edit distance)` pairs into `best_distances`, keeping the smallest
/// distance seen so far for each index.
fn merge_best_distances(
    best_distances: &mut HashMap<u64, u8>,
    distances: impl IntoIterator<Item = (u64, u8)>,
) {
    distances.into_iter().for_each(|(word_index, distance)| {
        let best = best_distances.entry(word_index).or_insert(distance);
        *best = (*best).min(distance);
    });
}
127
128impl Dictionary for FstDictionary {
129 fn contains_word(&self, word: &[char]) -> bool {
130 self.mutable_dict.contains_word(word)
131 }
132
133 fn contains_word_str(&self, word: &str) -> bool {
134 self.mutable_dict.contains_word_str(word)
135 }
136
137 fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
138 self.mutable_dict.get_word_metadata(word)
139 }
140
141 fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
142 self.mutable_dict.get_word_metadata_str(word)
143 }
144
145 fn fuzzy_match(
146 &'_ self,
147 word: &[char],
148 max_distance: u8,
149 max_results: usize,
150 ) -> Vec<FuzzyMatchResult<'_>> {
151 let misspelled_word_charslice = word.normalized();
152 let misspelled_word_string = misspelled_word_charslice.to_string();
153 let misspelled_lower = misspelled_word_string.to_lowercase();
154 let is_already_lower = misspelled_lower == misspelled_word_string;
155
156 let dfa = build_dfa(max_distance, &misspelled_word_string);
158 let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
159 let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
160
161 let mut best_distances = HashMap::<u64, u8>::new();
165
166 merge_best_distances(&mut best_distances, upper_dists);
167
168 if !is_already_lower {
170 let dfa_lowercase = build_dfa(max_distance, &misspelled_lower);
171 let mut word_indexes_lowercase_stream = self
172 .word_map
173 .search_with_state(&dfa_lowercase)
174 .into_stream();
175 let lower_dists =
176 stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
177
178 merge_best_distances(&mut best_distances, lower_dists);
179 }
180
181 let mut merged = Vec::with_capacity(best_distances.len());
182 for (index, edit_distance) in best_distances {
183 let (word, metadata) = &self.words[index as usize];
184 merged.push(FuzzyMatchResult {
185 word,
186 edit_distance,
187 metadata: Cow::Borrowed(metadata),
188 });
189 }
190
191 merged.retain(|v| v.edit_distance > 0);
193 merged.sort_unstable_by(|a, b| {
194 a.edit_distance
195 .cmp(&b.edit_distance)
196 .then_with(|| a.word.cmp(b.word))
197 });
198 merged.truncate(max_results);
199
200 merged
201 }
202
203 fn fuzzy_match_str(
204 &'_ self,
205 word: &str,
206 max_distance: u8,
207 max_results: usize,
208 ) -> Vec<FuzzyMatchResult<'_>> {
209 self.fuzzy_match(
210 word.chars().collect::<Vec<_>>().as_slice(),
211 max_distance,
212 max_results,
213 )
214 }
215
216 fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
217 self.mutable_dict.words_iter()
218 }
219
220 fn word_count(&self) -> usize {
221 self.mutable_dict.word_count()
222 }
223
224 fn contains_exact_word(&self, word: &[char]) -> bool {
225 self.mutable_dict.contains_exact_word(word)
226 }
227
228 fn contains_exact_word_str(&self, word: &str) -> bool {
229 self.mutable_dict.contains_exact_word_str(word)
230 }
231
232 fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
233 self.mutable_dict.get_correct_capitalization_of(word)
234 }
235
236 fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
237 self.mutable_dict.get_word_from_id(id)
238 }
239
240 fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
241 self.mutable_dict.find_words_with_prefix(prefix)
242 }
243
244 fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
245 self.mutable_dict.find_words_with_common_prefix(word)
246 }
247}
248
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::DictWordMetadata;
    use crate::spell::{Dictionary, MutableDictionary, WordId};

    use super::FstDictionary;

    /// Builds a `MutableDictionary` holding `words` (with default metadata) together with
    /// the `FstDictionary` converted from a clone of it, so the two can be compared.
    fn test_dictionaries(words: &[&str]) -> (MutableDictionary, FstDictionary) {
        let mut mutable = MutableDictionary::new();

        for word in words {
            mutable.append_word_str(word, DictWordMetadata::default());
        }

        let fst = FstDictionary::from(mutable.clone());

        (mutable, fst)
    }

    /// Runs `fuzzy_match_str` on `dict` and returns `(word, edit distance)` pairs sorted by
    /// distance then word, so results from different dictionaries compare deterministically.
    fn fuzzy_matches<D: Dictionary + ?Sized>(
        dict: &D,
        word: &str,
        max_distance: u8,
        max_results: usize,
    ) -> Vec<(String, u8)> {
        let mut matches = dict
            .fuzzy_match_str(word, max_distance, max_results)
            .into_iter()
            .map(|result| (result.word.iter().collect::<String>(), result.edit_distance))
            .collect_vec();

        matches.sort_unstable_by(|a, b| a.1.cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
        matches
    }

    /// Asserts that the curated dictionary records `word` as derived from `base`.
    fn assert_derived_from(word: &str, base: &str) {
        let dict = FstDictionary::curated();
        let metadata = dict
            .get_word_metadata_str(word)
            .unwrap_or_else(|| panic!("{word:?} is missing from the curated dictionary"));

        assert_eq!(
            metadata.derived_from,
            Some(WordId::from_word_str(base)),
            "{word:?} should be derived from {base:?}"
        );
    }

    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            // Fuzzy matching queries the FST with normalized text, so every dictionary word
            // must be present under its normalized spelling.
            let key = word.normalized().to_string();

            assert!(!key.is_empty(), "the dictionary contains an empty word");
            assert!(
                dict.word_map.contains_key(&key),
                "{key:?} is missing from the FST map"
            );
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let normalized = word.normalized();
        let as_written = normalized.to_string();
        let lowered = normalized.to_lower().to_string();

        assert!(dict.contains_word(&normalized));
        assert!(dict.word_map.contains_key(lowered) || dict.word_map.contains_key(as_written));
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist, "fuzzy results are not ordered by edit distance");
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction:?} is missing from the curated dictionary"));

            assert!(
                metadata.derived_from.is_none(),
                "{contraction:?} should not be marked as derived"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        assert_derived_from("llamas", "llama");
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        assert_derived_from("cats", "cat");
    }

    #[test]
    fn unhappy_derived_from_happy() {
        assert_derived_from("unhappy", "happy");
    }

    #[test]
    fn quickly_derived_from_quick() {
        assert_derived_from("quickly", "quick");
    }

    #[test]
    fn lowercase_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "speling", 3, 10);
        let fst_results = fuzzy_matches(&fst, "speling", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn capitalized_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "Speling", 3, 10);
        let fst_results = fuzzy_matches(&fst, "Speling", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn uppercase_fuzzy_match_matches_mutable_dictionary() {
        let (mutable, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let mutable_results = fuzzy_matches(&mutable, "SPELING", 3, 10);
        let fst_results = fuzzy_matches(&fst, "SPELING", 3, 10);

        assert_eq!(fst_results, mutable_results);
        assert_eq!(fst_results.first(), Some(&(String::from("spelling"), 1)));
    }

    #[test]
    fn query_casing_produces_the_same_fuzzy_matches() {
        let (_, fst) =
            test_dictionaries(&["spelling", "spilling", "selling", "smelling", "shelling"]);

        let lowercase = fuzzy_matches(&fst, "speling", 3, 10);
        let capitalized = fuzzy_matches(&fst, "Speling", 3, 10);
        let uppercase = fuzzy_matches(&fst, "SPELING", 3, 10);

        assert_eq!(lowercase, capitalized);
        assert_eq!(lowercase, uppercase);
    }
}