1use std::cmp;
2use std::collections::hash_map::Entry;
3use std::collections::{HashMap, HashSet};
4use std::fmt;
5
6use utf8_ranges::{Utf8Range, Utf8Sequences};
7
8use fst::automaton::Automaton;
9
10pub use self::error::Error;
11
12mod error;
13
14const STATE_LIMIT: usize = 10_000; pub struct Levenshtein {
74 prog: DynamicLevenshtein,
75 dfa: Dfa,
76}
77
78impl Levenshtein {
79 #[inline]
90 pub fn new(query: &str, distance: u32) -> Result<Levenshtein, Error> {
91 let lev = DynamicLevenshtein {
92 query: query.to_owned(),
93 dist: distance as usize,
94 };
95 let dfa = DfaBuilder::new(lev.clone()).build()?;
96 Ok(Levenshtein { prog: lev, dfa })
97 }
98}
99
100impl fmt::Debug for Levenshtein {
101 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
102 write!(
103 f,
104 "Levenshtein(query: {:?}, distance: {:?})",
105 self.prog.query, self.prog.dist
106 )
107 }
108}
109
/// Char-at-a-time edit-distance computation for a fixed query.
///
/// A "state" is one row of the edit-distance table: a slice with one entry
/// per prefix of `query` (so `query.chars().count() + 1` entries), where
/// entry `i` is the cost of aligning the first `i` query chars with the
/// input consumed so far.
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string being searched for.
    query: String,
    /// The largest cost that still counts as a match.
    dist: usize,
}

impl DynamicLevenshtein {
    /// Returns the row for empty input: `[0, 1, ..., n]` for a query of
    /// `n` chars.
    fn start(&self) -> Vec<usize> {
        let query_len = self.query.chars().count();
        (0..query_len + 1).collect()
    }

    /// Returns true if the whole query is within `dist` of the input, i.e.
    /// the last entry of `state` is at most `dist`. An empty `state` never
    /// matches.
    fn is_match(&self, state: &[usize]) -> bool {
        match state.last() {
            Some(&cost) => cost <= self.dist,
            None => false,
        }
    }

    /// Returns true if at least one entry of `state` is at most `dist`.
    /// An empty `state` yields false.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().any(|&cost| cost <= self.dist)
    }

    /// Computes the row that follows `state` after consuming `chr`.
    ///
    /// `None` stands for a char that equals no query char. Every entry
    /// except the first is capped at `dist + 1`.
    ///
    /// Panics if `state` has fewer entries than `query` has chars plus one.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let mut row = Vec::with_capacity(state.len());
        row.push(state[0] + 1);
        for (i, c) in self.query.chars().enumerate() {
            let from_left = row[i] + 1;
            let from_above = state[i + 1] + 1;
            let from_diagonal = if chr == Some(c) { state[i] } else { state[i] + 1 };
            let best = cmp::min(cmp::min(from_left, from_above), from_diagonal);
            row.push(cmp::min(best, self.dist + 1));
        }
        row
    }
}
142
143impl Automaton for Levenshtein {
144 type State = Option<usize>;
145
146 #[inline]
147 fn start(&self) -> Option<usize> {
148 Some(0)
149 }
150
151 #[inline]
152 fn is_match(&self, state: &Option<usize>) -> bool {
153 state.map(|state| self.dfa.states[state].is_match).unwrap_or(false)
154 }
155
156 #[inline]
157 fn can_match(&self, state: &Option<usize>) -> bool {
158 state.is_some()
159 }
160
161 #[inline]
162 fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> {
163 state.and_then(|state| self.dfa.states[state].next[byte as usize])
164 }
165}
166
/// A byte-level DFA stored as a flat table of states.
///
/// States are identified by their index in `states`; transitions store
/// those indices.
#[derive(Debug)]
pub struct Dfa {
    /// All states of the automaton, addressed by index.
    states: Vec<State>,
}
171
/// A single DFA state: one optional transition per input byte.
struct State {
    /// `next[b]` is the index of the state reached on byte `b`, or `None`
    /// if there is no transition for that byte.
    next: [Option<usize>; 256],
    /// Whether this state is accepting.
    is_match: bool,
}

/// Prints the match flag followed by only the bytes that have a transition,
/// one per line, instead of all 256 table slots.
impl fmt::Debug for State {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        writeln!(f, "State {{")?;
        writeln!(f, " is_match: {:?}", self.is_match)?;
        let transitions = self
            .next
            .iter()
            .enumerate()
            .filter_map(|(byte, target)| target.map(|si| (byte, si)));
        for (byte, si) in transitions {
            writeln!(f, " {:?}: {:?}", byte, si)?;
        }
        write!(f, "}}")
    }
}
189
/// Incrementally constructs a `Dfa` from a `DynamicLevenshtein` program.
struct DfaBuilder {
    /// The automaton under construction.
    dfa: Dfa,
    /// Produces the edit-distance rows that the DFA states are derived from.
    lev: DynamicLevenshtein,
    /// Maps each edit-distance row seen so far to the index of the DFA state
    /// that was created for it.
    cache: HashMap<Vec<usize>, usize>,
}
195
196impl DfaBuilder {
197 fn new(lev: DynamicLevenshtein) -> DfaBuilder {
198 DfaBuilder {
199 dfa: Dfa { states: Vec::with_capacity(16) },
200 lev,
201 cache: HashMap::with_capacity(1024),
202 }
203 }
204
205 fn build(mut self) -> Result<Dfa, Error> {
206 let mut stack = vec![self.lev.start()];
207 let mut seen = HashSet::new();
208 let query = self.lev.query.clone(); while let Some(lev_state) = stack.pop() {
210 let dfa_si = self.cached_state(&lev_state).unwrap();
211 let mismatch = self.add_mismatch_utf8_states(dfa_si, &lev_state);
212 if let Some((next_si, lev_next)) = mismatch {
213 if !seen.contains(&next_si) {
214 seen.insert(next_si);
215 stack.push(lev_next);
216 }
217 }
218 for (i, c) in query.chars().enumerate() {
219 if lev_state[i] > self.lev.dist {
220 continue;
221 }
222 let lev_next = self.lev.accept(&lev_state, Some(c));
223 let next_si = self.cached_state(&lev_next);
224 if let Some(next_si) = next_si {
225 self.add_utf8_sequences(true, dfa_si, next_si, c, c);
226 if !seen.contains(&next_si) {
227 seen.insert(next_si);
228 stack.push(lev_next);
229 }
230 }
231 }
232 if self.dfa.states.len() > STATE_LIMIT {
233 return Err(Error::TooManyStates(STATE_LIMIT));
234 }
235 }
236 Ok(self.dfa)
237 }
238
239 fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> {
240 self.cached(lev_state).map(|(si, _)| si)
241 }
242
243 fn cached(&mut self, lev_state: &[usize]) -> Option<(usize, bool)> {
244 if !self.lev.can_match(lev_state) {
245 return None;
246 }
247 Some(match self.cache.entry(lev_state.to_vec()) {
248 Entry::Occupied(v) => (*v.get(), true),
249 Entry::Vacant(v) => {
250 let is_match = self.lev.is_match(lev_state);
251 self.dfa.states.push(State { next: [None; 256], is_match });
252 (*v.insert(self.dfa.states.len() - 1), false)
253 }
254 })
255 }
256
257 fn add_mismatch_utf8_states(
258 &mut self,
259 from_si: usize,
260 lev_state: &[usize],
261 ) -> Option<(usize, Vec<usize>)> {
262 let mismatch_state = self.lev.accept(lev_state, None);
263 let to_si = match self.cached(&mismatch_state) {
264 None => return None,
265 Some((si, _)) => si,
266 };
269 self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}');
270 return Some((to_si, mismatch_state));
271 }
272
273 fn add_utf8_sequences(
274 &mut self,
275 overwrite: bool,
276 from_si: usize,
277 to_si: usize,
278 from_chr: char,
279 to_chr: char,
280 ) {
281 for seq in Utf8Sequences::new(from_chr, to_chr) {
282 let mut fsi = from_si;
283 for range in &seq.as_slice()[0..seq.len() - 1] {
284 let tsi = self.new_state(false);
285 self.add_utf8_range(overwrite, fsi, tsi, range);
286 fsi = tsi;
287 }
288 self.add_utf8_range(
289 overwrite,
290 fsi,
291 to_si,
292 &seq.as_slice()[seq.len() - 1],
293 );
294 }
295 }
296
297 fn add_utf8_range(
298 &mut self,
299 overwrite: bool,
300 from: usize,
301 to: usize,
302 range: &Utf8Range,
303 ) {
304 for b in range.start as usize..range.end as usize + 1 {
305 if overwrite || self.dfa.states[from].next[b].is_none() {
306 self.dfa.states[from].next[b] = Some(to);
307 }
308 }
309 }
310
311 fn new_state(&mut self, is_match: bool) -> usize {
312 self.dfa.states.push(State { next: [None; 256], is_match });
313 self.dfa.states.len() - 1
314 }
315}