// fst_no_std/automaton/levenshtein.rs

1use core::cmp;
2use core::fmt;
3#[cfg(feature = "std")]
4use std::collections::hash_map::Entry;
5#[cfg(feature = "std")]
6use std::collections::{HashMap, HashSet};
7
8use utf8_ranges::{Utf8Range, Utf8Sequences};
9
10use crate::automaton::Automaton;
11
/// Upper bound on DFA states used when no explicit limit is supplied.
///
/// Every DFA state carries a full 256-entry `Option<usize>` transition table
/// (4 KiB on 64-bit targets), so hitting this limit already means at least
/// ~20 MB of transition tables (about 40 MB on 64-bit targets).
const DEFAULT_STATE_LIMIT: usize = 10_000;
13
/// An error that occurred while building a Levenshtein automaton.
///
/// This error is only defined when the `levenshtein` crate feature is enabled.
#[derive(Debug)]
pub enum LevenshteinError {
    /// Construction of the automaton needed more DFA states than the
    /// configured limit allows.
    ///
    /// The payload is the limit that was exceeded (not the number of states
    /// that would actually have been needed).
    TooManyStates(usize),
}

impl fmt::Display for LevenshteinError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::TooManyStates(limit) => write!(
                f,
                "Levenshtein automaton exceeds size limit of {} states",
                limit
            ),
        }
    }
}

// Without `std`, implement the `core` version of the trait.
#[cfg(not(feature = "std"))]
impl core::error::Error for LevenshteinError {}

// With `std`, implement the `std` version of the trait.
#[cfg(feature = "std")]
impl std::error::Error for LevenshteinError {}
43
/// A Unicode-aware Levenshtein automaton for running efficient fuzzy queries.
///
/// This is only defined when the `levenshtein` crate feature is enabled.
///
/// A Levenshtein automaton is one way to search any finite state transducer
/// for keys that *approximately* match a given query. It does this by
/// returning all keys within a certain edit distance of the query. The edit
/// distance is the number of insertions, deletions and substitutions required
/// to turn the query into the key. Insertions, deletions and substitutions
/// operate on **Unicode characters** (where each character is a single
/// Unicode scalar value), not on bytes.
///
/// # Example
///
/// This example shows how to find all keys within an edit distance of `1`
/// from `foo`.
///
/// ```rust
/// use fst_no_std::automaton::Levenshtein;
/// use fst_no_std::{IntoStreamer, Streamer, Set};
///
/// let keys = vec!["fa", "fo", "fob", "focus", "foo", "food", "foul"];
/// let set = Set::from_iter(keys).unwrap();
///
/// let lev = Levenshtein::new("foo", 1).unwrap();
/// let mut stream = set.search(&lev).into_stream();
///
/// let mut keys = vec![];
/// while let Some(key) = stream.next() {
///     keys.push(key.to_vec());
/// }
///
/// assert_eq!(keys, vec![
///     "fo".as_bytes(),   // 1 deletion
///     "fob".as_bytes(),  // 1 substitution
///     "foo".as_bytes(),  // 0 insertions/deletions/substitutions
///     "food".as_bytes(), // 1 insertion
/// ]);
/// ```
///
/// This example only uses ASCII characters, but it will work equally well
/// on Unicode characters.
///
/// # Warning: experimental
///
/// While executing this Levenshtein automaton against a finite state
/// transducer will be very fast, *constructing* an automaton may not be.
/// Namely, this implementation is a proof of concept. While I believe the
/// algorithmic complexity is not exponential, the implementation is not speedy
/// and it can use enormous amounts of memory (tens of MB before a hard-coded
/// limit will cause an error to be returned).
///
/// This is important functionality, so one should count on this implementation
/// being vastly improved in the future.
#[cfg(feature = "alloc")]
pub struct Levenshtein {
    /// The query string and maximum distance this automaton was built from.
    prog: DynamicLevenshtein,
    /// The byte-level DFA compiled from `prog`; this is what searches run.
    dfa: Dfa,
}
104
#[cfg(feature = "alloc")]
impl Levenshtein {
    /// Builds a query matching every key at most `distance` edit operations
    /// (insertions, deletions or substitutions of characters) away from
    /// `query`, using the default state limit.
    ///
    /// The result implements `Automaton`, so it can be passed to the `search`
    /// method of any finite state transducer.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] when the compiled automaton
    /// would be too large. Use `new_with_limit` to raise the limit.
    pub fn new(
        query: &str,
        distance: u32,
    ) -> Result<Levenshtein, LevenshteinError> {
        let prog = DynamicLevenshtein {
            query: query.to_owned(),
            dist: distance as usize,
        };
        DfaBuilder::new(prog.clone())
            .build()
            .map(|dfa| Levenshtein { prog, dfa })
    }

    /// Same as `new`, but with a caller-chosen cap on the number of DFA
    /// states.
    ///
    /// The result implements `Automaton`, so it can be passed to the `search`
    /// method of any finite state transducer.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] when building the
    /// automaton needs more than `state_limit` states.
    pub fn new_with_limit(
        query: &str,
        distance: u32,
        state_limit: usize,
    ) -> Result<Levenshtein, LevenshteinError> {
        let prog = DynamicLevenshtein {
            query: query.to_owned(),
            dist: distance as usize,
        };
        let builder = DfaBuilder::new(prog.clone());
        let dfa = builder.build_with_limit(state_limit)?;
        Ok(Levenshtein { prog, dfa })
    }
}
157
#[cfg(feature = "alloc")]
impl fmt::Debug for Levenshtein {
    /// Shows only the query and the distance; the compiled DFA is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let DynamicLevenshtein { query, dist } = &self.prog;
        write!(f, "Levenshtein(query: {:?}, distance: {:?})", query, dist)
    }
}
168
/// A Levenshtein query evaluated one character at a time with the classic
/// dynamic-programming recurrence for edit distance.
///
/// A "state" is one DP row with `query.chars().count() + 1` entries. Entry
/// `i` is the edit distance between the input consumed so far and the first
/// `i` query characters. Entries at index 1 and above are clamped to
/// `dist + 1`. Entry 0 is not clamped and simply counts consumed characters.
#[cfg(feature = "alloc")]
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string keys are compared against.
    query: String,
    /// The largest edit distance that still counts as a match.
    dist: usize,
}

#[cfg(feature = "alloc")]
impl DynamicLevenshtein {
    /// The row before any input is consumed: `[0, 1, ..., n]` for a query of
    /// `n` characters.
    fn start(&self) -> Vec<usize> {
        let len = self.query.chars().count();
        (0..=len).collect()
    }

    /// True when the distance to the *whole* query is within budget.
    fn is_match(&self, state: &[usize]) -> bool {
        match state.last() {
            Some(&d) => d <= self.dist,
            None => false,
        }
    }

    /// True when at least one entry of the row is still within budget.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().any(|&d| d <= self.dist)
    }

    /// Computes the row that follows `state` after consuming `chr`.
    ///
    /// `None` stands for a character equal to no query character, so every
    /// diagonal step costs 1.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let ceiling = self.dist + 1;
        let mut row = Vec::with_capacity(state.len());
        row.push(state[0] + 1);
        for (i, qc) in self.query.chars().enumerate() {
            let diagonal = if chr == Some(qc) { state[i] } else { state[i] + 1 };
            let from_prev_row = state[i + 1] + 1;
            let from_this_row = row[i] + 1;
            let best = diagonal.min(from_prev_row).min(from_this_row);
            row.push(best.min(ceiling));
        }
        row
    }
}
203
/// Runs the precompiled DFA. The automaton state is a DFA state index, and
/// `None` is the dead state reached once a byte has no transition.
#[cfg(feature = "alloc")]
impl Automaton for Levenshtein {
    type State = Option<usize>;

    #[inline]
    fn start(&self) -> Option<usize> {
        // `DfaBuilder` caches the start row before anything else, so it is
        // always DFA state 0.
        Some(0)
    }

    #[inline]
    fn is_match(&self, state: &Option<usize>) -> bool {
        state.is_some_and(|si| self.dfa.states[si].is_match)
    }

    #[inline]
    fn can_match(&self, state: &Option<usize>) -> bool {
        // Only the dead state is treated as unable to match.
        matches!(state, Some(_))
    }

    #[inline]
    fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> {
        let Some(si) = *state else {
            return None;
        };
        self.dfa.states[si].next[usize::from(byte)]
    }
}
228
/// A byte-level DFA compiled from a Levenshtein query.
///
/// Transitions refer to other states by their index in `states`; the search
/// always starts in state 0.
#[cfg(feature = "alloc")]
#[derive(Debug)]
struct Dfa {
    /// Every state of the automaton, addressed by index.
    states: Vec<State>,
}
234
/// One DFA state: a dense transition table indexed by input byte, plus
/// whether the input consumed so far is a match.
struct State {
    /// `next[b]` is the index of the state reached on byte `b`, or `None` when
    /// byte `b` has no transition from this state.
    next: [Option<usize>; 256],
    /// Whether this state is accepting.
    is_match: bool,
}

impl fmt::Debug for State {
    /// Prints `is_match` followed by one `byte: target` line per defined
    /// transition, in ascending byte order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "State {{")?;
        writeln!(f, "  is_match: {:?}", self.is_match)?;
        let defined = self
            .next
            .iter()
            .enumerate()
            .filter_map(|(i, target)| target.map(|si| (i, si)));
        for (i, si) in defined {
            writeln!(f, "  {i:?}: {si:?}")?;
        }
        write!(f, "}}")
    }
}
252
/// Builds a [`Dfa`] by exploring every Levenshtein row reachable from the
/// query's start row.
///
/// Each distinct row that can still match becomes one DFA state. `cache`
/// maps rows to their state index so identical rows share a state.
/// Multi-byte UTF-8 encodings are spelled out through extra, non-accepting
/// intermediate states.
#[cfg(feature = "alloc")]
struct DfaBuilder {
    /// The DFA under construction.
    dfa: Dfa,
    /// The query being compiled.
    lev: DynamicLevenshtein,
    /// Levenshtein row -> index of its DFA state in `dfa.states`.
    cache: HashMap<Vec<usize>, usize>,
}

#[cfg(feature = "alloc")]
impl DfaBuilder {
    /// Creates a builder with an empty DFA for `lev`.
    fn new(lev: DynamicLevenshtein) -> DfaBuilder {
        DfaBuilder {
            dfa: Dfa { states: Vec::with_capacity(16) },
            lev,
            cache: HashMap::with_capacity(1024),
        }
    }

    /// Explores all reachable rows depth-first and returns the finished DFA.
    ///
    /// The start row is cached before anything else, so it becomes DFA
    /// state `0`.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] once the DFA holds more
    /// than `state_limit` states. The check runs after each row is fully
    /// expanded, so the DFA may briefly grow past the limit first.
    fn build_with_limit(
        mut self,
        state_limit: usize,
    ) -> Result<Dfa, LevenshteinError> {
        let mut stack = vec![self.lev.start()];
        let mut seen = HashSet::new();
        // Decoded up front so the loop below can borrow `self` mutably.
        let query: Vec<char> = self.lev.query.chars().collect();
        while let Some(lev_state) = stack.pop() {
            // Only rows accepted by `cached_state` are pushed, and the start
            // row begins with 0, so every popped row has a DFA state.
            let dfa_si = self
                .cached_state(&lev_state)
                .expect("rows on the stack can always still match");
            // Catch-all transitions for characters that match no query
            // position; query characters below overwrite them where needed.
            if let Some((next_si, lev_next)) =
                self.add_mismatch_utf8_states(dfa_si, &lev_state)
            {
                if seen.insert(next_si) {
                    stack.push(lev_next);
                }
            }
            for (i, &c) in query.iter().enumerate() {
                // A zero-cost diagonal step from an over-budget cell stays
                // over budget, so this position adds nothing that the
                // mismatch transition does not already cover.
                if lev_state[i] > self.lev.dist {
                    continue;
                }
                let lev_next = self.lev.accept(&lev_state, Some(c));
                if let Some(next_si) = self.cached_state(&lev_next) {
                    self.add_utf8_sequences(true, dfa_si, next_si, c, c);
                    if seen.insert(next_si) {
                        stack.push(lev_next);
                    }
                }
            }
            if self.dfa.states.len() > state_limit {
                return Err(LevenshteinError::TooManyStates(state_limit));
            }
        }
        Ok(self.dfa)
    }

    /// Builds with [`DEFAULT_STATE_LIMIT`].
    fn build(self) -> Result<Dfa, LevenshteinError> {
        self.build_with_limit(DEFAULT_STATE_LIMIT)
    }

    /// Returns the DFA state for `lev_state`, creating it the first time the
    /// row is seen.
    ///
    /// Returns `None`, and creates nothing, when no entry of the row is within
    /// the distance budget.
    fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> {
        if !self.lev.can_match(lev_state) {
            return None;
        }
        let si = match self.cache.entry(lev_state.to_vec()) {
            Entry::Occupied(entry) => *entry.get(),
            Entry::Vacant(entry) => {
                let is_match = self.lev.is_match(lev_state);
                self.dfa.states.push(State { next: [None; 256], is_match });
                *entry.insert(self.dfa.states.len() - 1)
            }
        };
        Some(si)
    }

    /// Adds transitions from `from_si`, on every Unicode scalar value, to the
    /// state of the row reached by a character matching no query position.
    ///
    /// Existing transitions are left untouched. Returns the target state and
    /// its row, or `None` when that row can no longer match.
    fn add_mismatch_utf8_states(
        &mut self,
        from_si: usize,
        lev_state: &[usize],
    ) -> Option<(usize, Vec<usize>)> {
        let mismatch_state = self.lev.accept(lev_state, None);
        let to_si = self.cached_state(&mismatch_state)?;
        self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}');
        Some((to_si, mismatch_state))
    }

    /// Adds byte paths from `from_si` to `to_si` spelling the UTF-8 encoding
    /// of every character in `from_chr..=to_chr`.
    ///
    /// With `overwrite == false`, byte transitions that already exist are
    /// kept. With `overwrite == true` they are replaced.
    ///
    /// In overwrite mode, a non-final byte that is a single value (always
    /// the case for a single character) takes a special path. The walk
    /// descends into a private *copy* of whatever intermediate state was
    /// already there, not a blank state. This keeps other characters that
    /// share the same UTF-8 prefix reachable: the catch-all mismatch
    /// continuations, and earlier query characters. For example, `é` and
    /// `è` both start with `0xC3`. Without the copy, adding one would make
    /// every other `0xC3 xx` character fall into the dead state.
    fn add_utf8_sequences(
        &mut self,
        overwrite: bool,
        from_si: usize,
        to_si: usize,
        from_chr: char,
        to_chr: char,
    ) {
        for seq in Utf8Sequences::new(from_chr, to_chr) {
            let (last, prefix) = seq
                .as_slice()
                .split_last()
                .expect("a UTF-8 sequence has at least one byte range");
            let mut fsi = from_si;
            for range in prefix {
                fsi = if overwrite && range.start == range.end {
                    self.fork_intermediate(fsi, range.start)
                } else {
                    let tsi = self.new_state(false);
                    self.add_utf8_range(overwrite, fsi, tsi, range);
                    tsi
                };
            }
            self.add_utf8_range(overwrite, fsi, to_si, last);
        }
    }

    /// Redirects `from_si --byte-->` to a new intermediate state and returns
    /// that state.
    ///
    /// The new state starts as a copy of the previous target, if there was
    /// one. Otherwise it has no transitions. Only the copy is modified
    /// afterwards, so other lead bytes that shared the old state keep their
    /// behavior. The previous target is always an intermediate state here,
    /// because a UTF-8 lead byte fixes the encoding length.
    fn fork_intermediate(&mut self, from_si: usize, byte: u8) -> usize {
        let b = usize::from(byte);
        let forked = match self.dfa.states[from_si].next[b] {
            Some(old_si) => {
                let old = &self.dfa.states[old_si];
                State { next: old.next, is_match: old.is_match }
            }
            None => State { next: [None; 256], is_match: false },
        };
        self.dfa.states.push(forked);
        let si = self.dfa.states.len() - 1;
        self.dfa.states[from_si].next[b] = Some(si);
        si
    }

    /// Points every byte of `range` in state `from` at state `to`.
    ///
    /// Existing transitions are replaced only when `overwrite` is set.
    fn add_utf8_range(
        &mut self,
        overwrite: bool,
        from: usize,
        to: usize,
        range: &Utf8Range,
    ) {
        for b in usize::from(range.start)..=usize::from(range.end) {
            let slot = &mut self.dfa.states[from].next[b];
            if overwrite || slot.is_none() {
                *slot = Some(to);
            }
        }
    }

    /// Appends a state with no transitions and returns its index.
    fn new_state(&mut self, is_match: bool) -> usize {
        self.dfa.states.push(State { next: [None; 256], is_match });
        self.dfa.states.len() - 1
    }
}