fst_levenshtein/
lib.rs

1use std::cmp;
2use std::collections::hash_map::Entry;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5
6use utf8_ranges::{Utf8Range, Utf8Sequences};
7
8use fst::automaton::Automaton;
9
10pub use self::error::Error;
11
12mod error;
13
/// Maximum number of DFA states tolerated during construction.
///
/// Once the automaton being built holds more than this many states,
/// `DfaBuilder::build` stops and returns `Error::TooManyStates`. Every state
/// carries a 256-entry transition table, so memory use grows quickly with
/// the state count.
const STATE_LIMIT: usize = 10_000; // currently at least 20MB >_<
15
/// A Unicode-aware Levenshtein automaton for running efficient fuzzy queries.
///
/// A Levenshtein automaton is one way to search any finite state transducer
/// for keys that *approximately* match a given query. A Levenshtein automaton
/// approximates this by returning all keys within a certain edit distance of
/// the query. The edit distance is defined by the number of insertions,
/// deletions and substitutions required to turn the query into the key.
/// Insertions, deletions and substitutions are based on
/// **Unicode characters** (where each character is a single Unicode scalar
/// value).
///
/// # Example
///
/// This example shows how to find all keys within an edit distance of `1`
/// from `foo`.
///
/// ```rust
/// extern crate fst;
/// extern crate fst_levenshtein;
///
/// use fst::{IntoStreamer, Streamer, Set};
/// use fst_levenshtein::Levenshtein;
///
/// fn main() {
///     let keys = vec!["fa", "fo", "fob", "focus", "foo", "food", "foul"];
///     let set = Set::from_iter(keys).unwrap();
///
///     let lev = Levenshtein::new("foo", 1).unwrap();
///     let mut stream = set.search(&lev).into_stream();
///
///     let mut keys = vec![];
///     while let Some(key) = stream.next() {
///         keys.push(key.to_vec());
///     }
///     assert_eq!(keys, vec![
///         "fo".as_bytes(),   // 1 deletion
///         "fob".as_bytes(),  // 1 substitution
///         "foo".as_bytes(),  // 0 insertions/deletions/substitutions
///         "food".as_bytes(), // 1 insertion
///     ]);
/// }
/// ```
///
/// This example only uses ASCII characters, but it will work equally well
/// on Unicode characters.
///
/// # Warning: experimental
///
/// While executing this Levenshtein automaton against a finite state
/// transducer will be very fast, *constructing* an automaton may not be.
/// Namely, this implementation is a proof of concept. While I believe the
/// algorithmic complexity is not exponential, the implementation is not speedy
/// and it can use enormous amounts of memory (tens of MB before a hard-coded
/// limit will cause an error to be returned).
///
/// This is important functionality, so one should count on this implementation
/// being vastly improved in the future.
pub struct Levenshtein {
    /// The query string and maximum edit distance the automaton was built
    /// from; also used when formatting this value for debugging.
    prog: DynamicLevenshtein,
    /// The precompiled byte-level DFA that the `Automaton` implementation
    /// walks.
    dfa: Dfa,
}
77
78impl Levenshtein {
79    /// Create a new Levenshtein query.
80    ///
81    /// The query finds all matching terms that are at most `distance`
82    /// edit operations from `query`. (An edit operation may be an insertion,
83    /// a deletion or a substitution.)
84    ///
85    /// If the underlying automaton becomes too big, then an error is returned.
86    ///
87    /// A `Levenshtein` value satisfies the `Automaton` trait, which means it
88    /// can be used with the `search` method of any finite state transducer.
89    #[inline]
90    pub fn new(query: &str, distance: u32) -> Result<Levenshtein, Error> {
91        let lev = DynamicLevenshtein {
92            query: query.to_owned(),
93            dist: distance as usize,
94        };
95        let dfa = DfaBuilder::new(lev.clone()).build()?;
96        Ok(Levenshtein { prog: lev, dfa })
97    }
98}
99
100impl fmt::Debug for Levenshtein {
101    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102        write!(
103            f,
104            "Levenshtein(query: {:?}, distance: {:?})",
105            self.prog.query, self.prog.dist
106        )
107    }
108}
109
/// The dynamic-programming formulation of Levenshtein distance, advanced one
/// key character at a time.
///
/// A state is one row of the edit-distance table: entry `i` relates the key
/// characters consumed so far to the first `i` characters of `query`, so a
/// row always has one more entry than `query` has characters.
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string that keys are compared against.
    query: String,
    /// The largest edit distance at which a key still counts as a match.
    dist: usize,
}

impl DynamicLevenshtein {
    /// Returns the row for the empty key: `[0, 1, ..., n]` where `n` is the
    /// number of characters in the query.
    fn start(&self) -> Vec<usize> {
        (0..=self.query.chars().count()).collect()
    }

    /// Returns true if the final entry of `state` is within the allowed
    /// distance. An empty `state` never matches.
    fn is_match(&self, state: &[usize]) -> bool {
        state.last().map_or(false, |&n| n <= self.dist)
    }

    /// Returns true if any entry of `state` is within the allowed distance.
    /// An empty `state` can never match.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().any(|&n| n <= self.dist)
    }

    /// Computes the row that follows `state` after consuming `chr`.
    ///
    /// `None` stands for a character that equals no character of the query.
    ///
    /// Every entry of the returned row, including the first, is clamped to
    /// `dist + 1`. Entries above `dist` are all equally "out of budget": the
    /// recurrence only takes minima and adds non-negative costs, so clamping
    /// never changes which entries end up `<= dist`. Clamping the first
    /// column too makes rows that differ only in how far over budget that
    /// column is compare equal, so callers that deduplicate rows create
    /// fewer states.
    ///
    /// # Panics
    ///
    /// Panics if `state` has fewer entries than the query has characters
    /// plus one.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let ceiling = self.dist.saturating_add(1);
        let mut next = Vec::with_capacity(state.len());
        next.push(cmp::min(state[0] + 1, ceiling));
        for (i, c) in self.query.chars().enumerate() {
            let cost = if Some(c) == chr { 0 } else { 1 };
            // min(entry to the left in the new row + 1,
            //     entry above in the old row + 1,
            //     diagonal entry + match cost)
            let v = cmp::min(
                cmp::min(next[i] + 1, state[i + 1] + 1),
                state[i] + cost,
            );
            next.push(cmp::min(v, ceiling));
        }
        next
    }
}
142
143impl Automaton for Levenshtein {
144    type State = Option<usize>;
145
146    #[inline]
147    fn start(&self) -> Option<usize> {
148        Some(0)
149    }
150
151    #[inline]
152    fn is_match(&self, state: &Option<usize>) -> bool {
153        state.map(|state| self.dfa.states[state].is_match).unwrap_or(false)
154    }
155
156    #[inline]
157    fn can_match(&self, state: &Option<usize>) -> bool {
158        state.is_some()
159    }
160
161    #[inline]
162    fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> {
163        state.and_then(|state| self.dfa.states[state].next[byte as usize])
164    }
165}
166
/// A byte-level deterministic finite automaton stored as a flat list of
/// states.
///
/// States refer to one another by their index in `states`.
#[derive(Debug)]
pub struct Dfa {
    /// Every state of the automaton; `Levenshtein` starts matching at
    /// index `0`.
    states: Vec<State>,
}
171
/// One state of a `Dfa`.
struct State {
    /// Transition table indexed by input byte. `Some(i)` moves to the state
    /// at index `i`; `None` means there is no transition for that byte.
    next: [Option<usize>; 256],
    /// Whether input that ends in this state is accepted.
    is_match: bool,
}

impl fmt::Debug for State {
    /// Prints the match flag followed by one line per byte that has a
    /// transition; bytes without a transition are omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "State {{")?;
        writeln!(f, "  is_match: {:?}", self.is_match)?;
        let live = self
            .next
            .iter()
            .enumerate()
            .filter_map(|(byte, target)| target.map(|si| (byte, si)));
        for (byte, si) in live {
            writeln!(f, "  {:?}: {:?}", byte, si)?;
        }
        write!(f, "}}")
    }
}
189
/// Compiles a `DynamicLevenshtein` program into a byte-level `Dfa`.
struct DfaBuilder {
    /// The automaton under construction; returned by `build` on success.
    dfa: Dfa,
    /// The edit-distance program that supplies the start row and the row
    /// transitions.
    lev: DynamicLevenshtein,
    /// Maps each edit-distance row that already has a DFA state to the index
    /// of that state in `dfa.states`.
    cache: HashMap<Vec<usize>, usize>,
}
195
196impl DfaBuilder {
197    fn new(lev: DynamicLevenshtein) -> DfaBuilder {
198        DfaBuilder {
199            dfa: Dfa { states: Vec::with_capacity(16) },
200            lev,
201            cache: HashMap::with_capacity(1024),
202        }
203    }
204
205    fn build(mut self) -> Result<Dfa, Error> {
206        let mut stack = vec![self.lev.start()];
207        let mut seen = HashSet::new();
208        let query = self.lev.query.clone(); // temp work around of borrowck
209        while let Some(lev_state) = stack.pop() {
210            let dfa_si = self.cached_state(&lev_state).unwrap();
211            let mismatch = self.add_mismatch_utf8_states(dfa_si, &lev_state);
212            if let Some((next_si, lev_next)) = mismatch {
213                if !seen.contains(&next_si) {
214                    seen.insert(next_si);
215                    stack.push(lev_next);
216                }
217            }
218            for (i, c) in query.chars().enumerate() {
219                if lev_state[i] > self.lev.dist {
220                    continue;
221                }
222                let lev_next = self.lev.accept(&lev_state, Some(c));
223                let next_si = self.cached_state(&lev_next);
224                if let Some(next_si) = next_si {
225                    self.add_utf8_sequences(true, dfa_si, next_si, c, c);
226                    if !seen.contains(&next_si) {
227                        seen.insert(next_si);
228                        stack.push(lev_next);
229                    }
230                }
231            }
232            if self.dfa.states.len() > STATE_LIMIT {
233                return Err(Error::TooManyStates(STATE_LIMIT));
234            }
235        }
236        Ok(self.dfa)
237    }
238
239    fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> {
240        self.cached(lev_state).map(|(si, _)| si)
241    }
242
243    fn cached(&mut self, lev_state: &[usize]) -> Option<(usize, bool)> {
244        if !self.lev.can_match(lev_state) {
245            return None;
246        }
247        Some(match self.cache.entry(lev_state.to_vec()) {
248            Entry::Occupied(v) => (*v.get(), true),
249            Entry::Vacant(v) => {
250                let is_match = self.lev.is_match(lev_state);
251                self.dfa.states.push(State { next: [None; 256], is_match });
252                (*v.insert(self.dfa.states.len() - 1), false)
253            }
254        })
255    }
256
257    fn add_mismatch_utf8_states(
258        &mut self,
259        from_si: usize,
260        lev_state: &[usize],
261    ) -> Option<(usize, Vec<usize>)> {
262        let mismatch_state = self.lev.accept(lev_state, None);
263        let to_si = match self.cached(&mismatch_state) {
264            None => return None,
265            Some((si, _)) => si,
266            // Some((si, true)) => return Some((si, mismatch_state)),
267            // Some((si, false)) => si,
268        };
269        self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}');
270        return Some((to_si, mismatch_state));
271    }
272
273    fn add_utf8_sequences(
274        &mut self,
275        overwrite: bool,
276        from_si: usize,
277        to_si: usize,
278        from_chr: char,
279        to_chr: char,
280    ) {
281        for seq in Utf8Sequences::new(from_chr, to_chr) {
282            let mut fsi = from_si;
283            for range in &seq.as_slice()[0..seq.len() - 1] {
284                let tsi = self.new_state(false);
285                self.add_utf8_range(overwrite, fsi, tsi, range);
286                fsi = tsi;
287            }
288            self.add_utf8_range(
289                overwrite,
290                fsi,
291                to_si,
292                &seq.as_slice()[seq.len() - 1],
293            );
294        }
295    }
296
297    fn add_utf8_range(
298        &mut self,
299        overwrite: bool,
300        from: usize,
301        to: usize,
302        range: &Utf8Range,
303    ) {
304        for b in range.start as usize..range.end as usize + 1 {
305            if overwrite || self.dfa.states[from].next[b].is_none() {
306                self.dfa.states[from].next[b] = Some(to);
307            }
308        }
309    }
310
311    fn new_state(&mut self, is_match: bool) -> usize {
312        self.dfa.states.push(State { next: [None; 256], is_match });
313        self.dfa.states.len() - 1
314    }
315}