1use super::{MutableDictionary, WordId};
2use fst::{IntoStreamer, Map as FstMap, Streamer, map::StreamWithState};
3use hashbrown::HashMap;
4use levenshtein_automata::{DFA, LevenshteinAutomatonBuilder};
5use std::borrow::Cow;
6use std::sync::LazyLock;
7use std::{cell::RefCell, sync::Arc};
8
9use crate::{CharString, CharStringExt, DictWordMetadata};
10
11use super::Dictionary;
12use super::FuzzyMatchResult;
13
/// A dictionary that answers fuzzy (edit-distance) queries by searching a finite-state
/// transducer, and delegates every other lookup to a [`MutableDictionary`] built from the
/// same word list.
pub struct FstDictionary {
    /// Backing dictionary holding the same words as `words`; all non-fuzzy queries go here.
    mutable_dict: Arc<MutableDictionary>,
    /// FST mapping each word (as a UTF-8 string) to its index in `words`.
    word_map: FstMap<Vec<u8>>,
    /// Sorted, deduplicated word list; the values stored in `word_map` index into it.
    words: Vec<(CharString, DictWordMetadata)>,
}
26
/// Edit distance for which an automaton builder is created eagerly in every thread's
/// builder cache (presumably the most commonly requested distance — confirm with callers).
const EXPECTED_DISTANCE: u8 = 3;
/// When `true`, swapping two adjacent characters counts as one edit (Damerau-Levenshtein)
/// instead of two.
const TRANSPOSITION_COST_ONE: bool = true;
29
30static DICT: LazyLock<Arc<FstDictionary>> =
31 LazyLock::new(|| Arc::new((*MutableDictionary::curated()).clone().into()));
32
thread_local! {
    /// Per-thread cache of Levenshtein automaton builders, keyed by maximum edit distance.
    ///
    /// Seeded with a builder for `EXPECTED_DISTANCE`; `build_dfa` appends builders for other
    /// distances the first time they are requested. Because each thread owns its own cache,
    /// a `RefCell` is sufficient for mutation.
    static AUTOMATON_BUILDERS: RefCell<Vec<(u8, LevenshteinAutomatonBuilder)>> = RefCell::new(vec![(
        EXPECTED_DISTANCE,
        LevenshteinAutomatonBuilder::new(EXPECTED_DISTANCE, TRANSPOSITION_COST_ONE),
    )]);
}
43
44impl PartialEq for FstDictionary {
45 fn eq(&self, other: &Self) -> bool {
46 self.mutable_dict == other.mutable_dict
47 }
48}
49
50impl FstDictionary {
51 pub fn curated() -> Arc<Self> {
54 (*DICT).clone()
55 }
56
57 pub fn new(mut words: Vec<(CharString, DictWordMetadata)>) -> Self {
60 words.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));
61 words.dedup_by(|(a, _), (b, _)| a == b);
62
63 let mut builder = fst::MapBuilder::memory();
64 for (index, (word, _)) in words.iter().enumerate() {
65 let word = word.iter().collect::<String>();
66 builder
67 .insert(word, index as u64)
68 .expect("Insertion not in lexicographical order!");
69 }
70
71 let mut mutable_dict = MutableDictionary::new();
72 mutable_dict.extend_words(words.iter().cloned());
73
74 let fst_bytes = builder.into_inner().unwrap();
75 let word_map = FstMap::new(fst_bytes).expect("Unable to build FST map.");
76
77 FstDictionary {
78 mutable_dict: Arc::new(mutable_dict),
79 word_map,
80 words,
81 }
82 }
83}
84
85fn build_dfa(max_distance: u8, query: &str) -> DFA {
86 AUTOMATON_BUILDERS.with_borrow_mut(|v| {
88 if !v.iter().any(|t| t.0 == max_distance) {
89 v.push((
90 max_distance,
91 LevenshteinAutomatonBuilder::new(max_distance, TRANSPOSITION_COST_ONE),
92 ));
93 }
94 });
95
96 AUTOMATON_BUILDERS.with_borrow(|v| {
97 v.iter()
98 .find(|a| a.0 == max_distance)
99 .unwrap()
100 .1
101 .build_dfa(query)
102 })
103}
104
105fn stream_distances_vec(stream: &mut StreamWithState<&DFA>, dfa: &DFA) -> Vec<(u64, u8)> {
107 let mut word_index_pairs = Vec::new();
108 while let Some((_, v, s)) = stream.next() {
109 word_index_pairs.push((v, dfa.distance(s).to_u8()));
110 }
111
112 word_index_pairs
113}
114
115impl Dictionary for FstDictionary {
116 fn contains_word(&self, word: &[char]) -> bool {
117 self.mutable_dict.contains_word(word)
118 }
119
120 fn contains_word_str(&self, word: &str) -> bool {
121 self.mutable_dict.contains_word_str(word)
122 }
123
124 fn get_word_metadata(&self, word: &[char]) -> Option<Cow<'_, DictWordMetadata>> {
125 self.mutable_dict.get_word_metadata(word)
126 }
127
128 fn get_word_metadata_str(&self, word: &str) -> Option<Cow<'_, DictWordMetadata>> {
129 self.mutable_dict.get_word_metadata_str(word)
130 }
131
132 fn fuzzy_match(
133 &'_ self,
134 word: &[char],
135 max_distance: u8,
136 max_results: usize,
137 ) -> Vec<FuzzyMatchResult<'_>> {
138 let misspelled_word_charslice = word.normalized();
139 let misspelled_word_string = misspelled_word_charslice.to_string();
140
141 let dfa = build_dfa(max_distance, &misspelled_word_string);
143 let dfa_lowercase = build_dfa(max_distance, &misspelled_word_string.to_lowercase());
144 let mut word_indexes_stream = self.word_map.search_with_state(&dfa).into_stream();
145 let mut word_indexes_lowercase_stream = self
146 .word_map
147 .search_with_state(&dfa_lowercase)
148 .into_stream();
149
150 let upper_dists = stream_distances_vec(&mut word_indexes_stream, &dfa);
151 let lower_dists = stream_distances_vec(&mut word_indexes_lowercase_stream, &dfa_lowercase);
152
153 let mut merged = Vec::with_capacity(upper_dists.len().max(lower_dists.len()));
157 let mut best_distances = HashMap::<u64, u8>::new();
158
159 for (idx, dist) in upper_dists.into_iter().chain(lower_dists.into_iter()) {
160 best_distances
161 .entry(idx)
162 .and_modify(|existing| *existing = (*existing).min(dist))
163 .or_insert(dist);
164 }
165
166 for (index, edit_distance) in best_distances {
167 let (word, metadata) = &self.words[index as usize];
168 merged.push(FuzzyMatchResult {
169 word,
170 edit_distance,
171 metadata: Cow::Borrowed(metadata),
172 });
173 }
174
175 merged.retain(|v| v.edit_distance > 0);
177 merged.sort_unstable_by(|a, b| {
178 a.edit_distance
179 .cmp(&b.edit_distance)
180 .then_with(|| a.word.cmp(b.word))
181 });
182 merged.truncate(max_results);
183
184 merged
185 }
186
187 fn fuzzy_match_str(
188 &'_ self,
189 word: &str,
190 max_distance: u8,
191 max_results: usize,
192 ) -> Vec<FuzzyMatchResult<'_>> {
193 self.fuzzy_match(
194 word.chars().collect::<Vec<_>>().as_slice(),
195 max_distance,
196 max_results,
197 )
198 }
199
200 fn words_iter(&self) -> Box<dyn Iterator<Item = &'_ [char]> + Send + '_> {
201 self.mutable_dict.words_iter()
202 }
203
204 fn word_count(&self) -> usize {
205 self.mutable_dict.word_count()
206 }
207
208 fn contains_exact_word(&self, word: &[char]) -> bool {
209 self.mutable_dict.contains_exact_word(word)
210 }
211
212 fn contains_exact_word_str(&self, word: &str) -> bool {
213 self.mutable_dict.contains_exact_word_str(word)
214 }
215
216 fn get_correct_capitalization_of(&self, word: &[char]) -> Option<&'_ [char]> {
217 self.mutable_dict.get_correct_capitalization_of(word)
218 }
219
220 fn get_word_from_id(&self, id: &WordId) -> Option<&[char]> {
221 self.mutable_dict.get_word_from_id(id)
222 }
223
224 fn find_words_with_prefix(&self, prefix: &[char]) -> Vec<Cow<'_, [char]>> {
225 self.mutable_dict.find_words_with_prefix(prefix)
226 }
227
228 fn find_words_with_common_prefix(&self, word: &[char]) -> Vec<Cow<'_, [char]>> {
229 self.mutable_dict.find_words_with_common_prefix(word)
230 }
231}
232
#[cfg(test)]
mod tests {
    use itertools::Itertools;

    use crate::CharStringExt;
    use crate::spell::{Dictionary, WordId};

    use super::FstDictionary;

    /// With transpositions costing one, swapping two adjacent letters is a single edit.
    #[test]
    fn damerau_transposition_costs_one() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, true).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::Exact(1)
        );
    }

    /// Without transposition support, the same swap costs two edits.
    #[test]
    fn damerau_transposition_costs_two() {
        let lev_automata =
            levenshtein_automata::LevenshteinAutomatonBuilder::new(1, false).build_dfa("woof");
        assert_eq!(
            lev_automata.eval("wofo"),
            levenshtein_automata::Distance::AtLeast(2)
        );
    }

    /// Every word the dictionary exposes must be an FST key in its normalized form.
    #[test]
    fn fst_map_contains_all_in_mutable_dict() {
        let dict = FstDictionary::curated();

        for word in dict.words_iter() {
            let normalized = word.normalized();
            let key = normalized.to_string();

            assert!(!key.is_empty());
            // The message identifies the failing word without printing every entry.
            assert!(dict.word_map.contains_key(&key), "FST map is missing {key:?}");
        }
    }

    #[test]
    fn fst_contains_hello() {
        let dict = FstDictionary::curated();

        let word: Vec<_> = "hello".chars().collect();
        let normalized = word.normalized();
        let as_is = normalized.to_string();
        let lowered = normalized.to_lower().to_string();

        assert!(dict.contains_word(&normalized));
        assert!(dict.word_map.contains_key(lowered) || dict.word_map.contains_key(as_is));
    }

    #[test]
    fn on_is_not_nominal() {
        let dict = FstDictionary::curated();

        assert!(!dict.get_word_metadata_str("on").unwrap().is_nominal());
    }

    #[test]
    fn fuzzy_result_sorted_by_edit_distance() {
        let dict = FstDictionary::curated();

        let results = dict.fuzzy_match_str("hello", 3, 100);
        let is_sorted_by_dist = results
            .iter()
            .map(|fm| fm.edit_distance)
            .tuple_windows()
            .all(|(a, b)| a <= b);

        assert!(is_sorted_by_dist)
    }

    #[test]
    fn curated_contains_no_duplicates() {
        let dict = FstDictionary::curated();

        assert!(dict.words.iter().map(|(word, _)| word).all_unique());
    }

    #[test]
    fn contractions_not_derived() {
        let dict = FstDictionary::curated();

        for contraction in ["there's", "we're", "here's"] {
            let metadata = dict
                .get_word_metadata_str(contraction)
                .unwrap_or_else(|| panic!("{contraction} is missing from the curated dictionary"));
            assert!(
                metadata.derived_from.is_none(),
                "{contraction} should not be marked as derived"
            );
        }
    }

    #[test]
    fn plural_llamas_derived_from_llama() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("llamas")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("llama")
        )
    }

    #[test]
    fn plural_cats_derived_from_cat() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("cats")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("cat")
        );
    }

    #[test]
    fn unhappy_derived_from_happy() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("unhappy")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("happy")
        );
    }

    #[test]
    fn quickly_derived_from_quick() {
        let dict = FstDictionary::curated();

        assert_eq!(
            dict.get_word_metadata_str("quickly")
                .unwrap()
                .derived_from
                .unwrap(),
            WordId::from_word_str("quick")
        );
    }
}
390}