1use core::cmp;
2use core::fmt;
3#[cfg(feature = "std")]
4use std::collections::hash_map::Entry;
5#[cfg(feature = "std")]
6use std::collections::{HashMap, HashSet};
7
8use utf8_ranges::{Utf8Range, Utf8Sequences};
9
10use crate::automaton::Automaton;
11
/// Number of DFA states a Levenshtein automaton may allocate when the caller
/// does not supply an explicit limit.
const DEFAULT_STATE_LIMIT: usize = 10_000;

/// An error that can occur while building a Levenshtein automaton.
#[derive(Debug)]
pub enum LevenshteinError {
    /// The automaton needed more DFA states than the configured limit.
    ///
    /// The payload is the limit that was exceeded, not the number of states
    /// that had been allocated when construction stopped.
    TooManyStates(usize),
}
25
26impl fmt::Display for LevenshteinError {
27 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
28 match *self {
29 LevenshteinError::TooManyStates(size_limit) => write!(
30 f,
31 "Levenshtein automaton exceeds size limit of \
32 {size_limit} states"
33 ),
34 }
35 }
36}
37
38#[cfg(not(feature = "std"))]
39impl core::error::Error for LevenshteinError {}
40
41#[cfg(feature = "std")]
42impl std::error::Error for LevenshteinError {}
43
/// A Levenshtein automaton: accepts byte strings whose edit distance to a
/// query is within a fixed bound.
///
/// The automaton is compiled up front into a byte-level DFA; matching only
/// walks that table.
#[cfg(feature = "alloc")]
pub struct Levenshtein {
    /// The query and maximum distance the automaton was built from. Retained
    /// so they can be reported by the `Debug` implementation.
    prog: DynamicLevenshtein,
    /// The compiled byte-level transition table driven by the `Automaton`
    /// implementation.
    dfa: Dfa,
}
104
#[cfg(feature = "alloc")]
impl Levenshtein {
    /// Builds an automaton that matches strings within `distance` edits of
    /// `query`, using the builder's default state limit.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] when the compiled DFA
    /// would exceed the default state limit.
    #[cfg(feature = "alloc")]
    pub fn new(
        query: &str,
        distance: u32,
    ) -> Result<Levenshtein, LevenshteinError> {
        let prog = Self::program(query, distance);
        // The builder consumes its own copy; the original is kept on the
        // automaton.
        DfaBuilder::new(prog.clone())
            .build()
            .map(|dfa| Levenshtein { prog, dfa })
    }

    /// Builds an automaton that matches strings within `distance` edits of
    /// `query`, allowing the DFA to use at most `state_limit` states.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] when the compiled DFA
    /// would exceed `state_limit`.
    #[cfg(feature = "alloc")]
    pub fn new_with_limit(
        query: &str,
        distance: u32,
        state_limit: usize,
    ) -> Result<Levenshtein, LevenshteinError> {
        let prog = Self::program(query, distance);
        DfaBuilder::new(prog.clone())
            .build_with_limit(state_limit)
            .map(|dfa| Levenshtein { prog, dfa })
    }

    /// Packages the query and distance into the dynamic (uncompiled) form
    /// that the DFA builder consumes.
    fn program(query: &str, distance: u32) -> DynamicLevenshtein {
        DynamicLevenshtein { query: query.to_owned(), dist: distance as usize }
    }
}
157
#[cfg(feature = "alloc")]
impl fmt::Debug for Levenshtein {
    /// Prints the query and distance only; the compiled DFA is omitted.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let DynamicLevenshtein { query, dist } = &self.prog;
        write!(f, "Levenshtein(query: {:?}, distance: {:?})", query, dist)
    }
}
168
/// The uncompiled form of a Levenshtein query: it steps one character at a
/// time over rows of edit distances and is what the DFA builder explores.
#[cfg(feature = "alloc")]
#[derive(Clone)]
struct DynamicLevenshtein {
    /// The string that inputs are compared against.
    query: String,
    /// The largest edit distance that still counts as a match.
    dist: usize,
}
175
#[cfg(feature = "alloc")]
impl DynamicLevenshtein {
    /// Returns the initial row `[0, 1, ..., n]`, where `n` is the number of
    /// characters (not bytes) in the query.
    fn start(&self) -> Vec<usize> {
        let query_len = self.query.chars().count();
        (0..=query_len).collect()
    }

    /// Reports whether the last entry of `state` is within the allowed
    /// distance. An empty row never matches.
    fn is_match(&self, state: &[usize]) -> bool {
        matches!(state.last(), Some(&n) if n <= self.dist)
    }

    /// Reports whether any entry of `state` is within the allowed distance.
    /// An empty row can never match.
    fn can_match(&self, state: &[usize]) -> bool {
        state.iter().any(|&n| n <= self.dist)
    }

    /// Computes the row that follows `state` after consuming `chr`.
    ///
    /// `None` stands for a character that equals no query character. Every
    /// entry except the first is clamped to `dist + 1`.
    ///
    /// # Panics
    ///
    /// Panics if `state` has fewer entries than the query has characters
    /// plus one.
    fn accept(&self, state: &[usize], chr: Option<char>) -> Vec<usize> {
        let ceiling = self.dist + 1;
        let mut next = Vec::with_capacity(state.len());
        next.push(state[0] + 1);
        for (i, c) in self.query.chars().enumerate() {
            // The three classic edit-distance recurrences for cell `i + 1`.
            let from_left = next[i] + 1;
            let from_above = state[i + 1] + 1;
            let from_diagonal = state[i] + usize::from(Some(c) != chr);
            let best = from_left.min(from_above).min(from_diagonal);
            next.push(best.min(ceiling));
        }
        next
    }
}
203
#[cfg(feature = "alloc")]
impl Automaton for Levenshtein {
    /// `Some(i)` is the index of a live state in the compiled DFA; `None` is
    /// the dead state from which nothing can match.
    type State = Option<usize>;

    /// The DFA's first state is always the start state.
    #[inline]
    fn start(&self) -> Option<usize> {
        Some(0)
    }

    /// A live state matches when the DFA marks it as accepting; the dead
    /// state never matches.
    #[inline]
    fn is_match(&self, state: &Option<usize>) -> bool {
        match *state {
            Some(si) => self.dfa.states[si].is_match,
            None => false,
        }
    }

    /// Every live state can still lead to a match; the dead state cannot.
    #[inline]
    fn can_match(&self, state: &Option<usize>) -> bool {
        state.is_some()
    }

    /// Follows the transition for `byte`, staying dead once dead.
    #[inline]
    fn accept(&self, state: &Option<usize>, byte: u8) -> Option<usize> {
        let si = (*state)?;
        self.dfa.states[si].next[usize::from(byte)]
    }
}
228
/// A byte-level deterministic automaton stored as a flat list of states.
/// Transitions refer to other states by their index in `states`.
#[cfg(feature = "alloc")]
#[derive(Debug)]
struct Dfa {
    /// All states of the automaton, addressed by index.
    states: Vec<State>,
}
234
/// A single DFA state: one optional outgoing transition per byte value.
struct State {
    /// `next[b]` is the index of the state reached on byte `b`, or `None`
    /// when there is no transition for that byte.
    next: [Option<usize>; 256],
    /// Whether stopping in this state counts as a match.
    is_match: bool,
}
239
240impl fmt::Debug for State {
241 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
242 writeln!(f, "State {{")?;
243 writeln!(f, " is_match: {:?}", self.is_match)?;
244 for i in 0..256 {
245 if let Some(si) = self.next[i] {
246 writeln!(f, " {i:?}: {si:?}")?;
247 }
248 }
249 write!(f, "}}")
250 }
251}
252
/// Compiles a `DynamicLevenshtein` into a byte-level `Dfa`.
#[cfg(feature = "alloc")]
struct DfaBuilder {
    /// The DFA under construction.
    dfa: Dfa,
    /// The dynamic automaton whose reachable rows are being explored.
    lev: DynamicLevenshtein,
    /// Maps each Levenshtein row already seen to the index of the DFA state
    /// created for it.
    cache: HashMap<Vec<usize>, usize>,
}
259
#[cfg(feature = "alloc")]
impl DfaBuilder {
    /// Creates a builder for `lev` with an empty DFA and an empty cache.
    fn new(lev: DynamicLevenshtein) -> DfaBuilder {
        DfaBuilder {
            dfa: Dfa { states: Vec::with_capacity(16) },
            lev,
            cache: HashMap::with_capacity(1024),
        }
    }

    /// Explores every Levenshtein row reachable from the start row and
    /// compiles the result into a byte-level DFA.
    ///
    /// # Errors
    ///
    /// Returns [`LevenshteinError::TooManyStates`] as soon as the DFA holds
    /// more than `state_limit` states. The check runs once per explored row,
    /// after that row's transitions have been added.
    fn build_with_limit(
        mut self,
        state_limit: usize,
    ) -> Result<Dfa, LevenshteinError> {
        let mut stack = vec![self.lev.start()];
        let mut seen = HashSet::new();
        // Cloned so the query can be iterated while `self` is mutably
        // borrowed inside the loop.
        let query = self.lev.query.clone();
        while let Some(lev_state) = stack.pop() {
            let dfa_si = self
                .cached_state(&lev_state)
                .expect("only rows that can still match are put on the stack");
            // Wire the "character matches nothing in the query" transitions
            // first; the per-character transitions below refine them.
            let mismatch = self.add_mismatch_utf8_states(dfa_si, &lev_state);
            if let Some((next_si, lev_next)) = mismatch {
                if seen.insert(next_si) {
                    stack.push(lev_next);
                }
            }
            for (i, c) in query.chars().enumerate() {
                if lev_state[i] > self.lev.dist {
                    continue;
                }
                let lev_next = self.lev.accept(&lev_state, Some(c));
                if let Some(next_si) = self.cached_state(&lev_next) {
                    self.add_utf8_sequences(true, dfa_si, next_si, c, c);
                    if seen.insert(next_si) {
                        stack.push(lev_next);
                    }
                }
            }
            if self.dfa.states.len() > state_limit {
                return Err(LevenshteinError::TooManyStates(state_limit));
            }
        }
        Ok(self.dfa)
    }

    /// Builds the DFA using `DEFAULT_STATE_LIMIT` as the state limit.
    fn build(self) -> Result<Dfa, LevenshteinError> {
        self.build_with_limit(DEFAULT_STATE_LIMIT)
    }

    /// Returns the DFA state index for `lev_state`, creating the state if
    /// needed, or `None` when the row can no longer lead to a match.
    fn cached_state(&mut self, lev_state: &[usize]) -> Option<usize> {
        self.cached(lev_state).map(|(si, _)| si)
    }

    /// Like `cached_state`, but also reports whether the state already
    /// existed (`true`) or was created by this call (`false`).
    fn cached(&mut self, lev_state: &[usize]) -> Option<(usize, bool)> {
        if !self.lev.can_match(lev_state) {
            return None;
        }
        Some(match self.cache.entry(lev_state.to_vec()) {
            Entry::Occupied(v) => (*v.get(), true),
            Entry::Vacant(v) => {
                let is_match = self.lev.is_match(lev_state);
                self.dfa.states.push(State { next: [None; 256], is_match });
                (*v.insert(self.dfa.states.len() - 1), false)
            }
        })
    }

    /// Adds the transitions taken from `from_si` on any character that
    /// equals no query character.
    ///
    /// Returns the target DFA state index together with its Levenshtein row,
    /// or `None` when a mismatch leads to a row that can no longer match (in
    /// which case no transitions are added).
    fn add_mismatch_utf8_states(
        &mut self,
        from_si: usize,
        lev_state: &[usize],
    ) -> Option<(usize, Vec<usize>)> {
        let mismatch_state = self.lev.accept(lev_state, None);
        let (to_si, _) = self.cached(&mismatch_state)?;
        // Never overwrite: only bytes without a transition are filled in.
        self.add_utf8_sequences(false, from_si, to_si, '\u{0}', '\u{10FFFF}');
        Some((to_si, mismatch_state))
    }

    /// Adds byte-level transitions so that every character in
    /// `from_chr..=to_chr` leads from `from_si` to `to_si`.
    ///
    /// Multi-byte encodings are routed through intermediate, non-matching
    /// states. With `overwrite` set, existing transitions on the affected
    /// bytes are replaced; otherwise only missing transitions are filled in.
    ///
    /// In overwrite mode each intermediate state starts as a copy of the
    /// state currently reached on that byte. Without this, characters that
    /// merely share a leading byte with the overwritten character would lose
    /// the transitions they already had and fall into the dead state.
    fn add_utf8_sequences(
        &mut self,
        overwrite: bool,
        from_si: usize,
        to_si: usize,
        from_chr: char,
        to_chr: char,
    ) {
        for seq in Utf8Sequences::new(from_chr, to_chr) {
            let mut fsi = from_si;
            for range in &seq.as_slice()[0..seq.len() - 1] {
                let tsi = if overwrite {
                    self.fork_state(fsi, range)
                } else {
                    self.new_state(false)
                };
                self.add_utf8_range(overwrite, fsi, tsi, range);
                fsi = tsi;
            }
            self.add_utf8_range(
                overwrite,
                fsi,
                to_si,
                &seq.as_slice()[seq.len() - 1],
            );
        }
    }

    /// Creates a non-matching state whose transition table is copied from
    /// the state that `from` currently reaches on the first byte of `range`.
    /// If there is no such state, the new state starts empty.
    ///
    /// Only the first byte of `range` is consulted. The overwrite path that
    /// uses this helper is invoked with a single character, for which every
    /// range covers exactly one byte.
    fn fork_state(&mut self, from: usize, range: &Utf8Range) -> usize {
        let inherited =
            match self.dfa.states[from].next[usize::from(range.start)] {
                Some(si) => self.dfa.states[si].next,
                None => [None; 256],
            };
        self.dfa.states.push(State { next: inherited, is_match: false });
        self.dfa.states.len() - 1
    }

    /// Points every byte in `range` from state `from` to state `to`. Bytes
    /// that already have a transition are left alone unless `overwrite` is
    /// set.
    fn add_utf8_range(
        &mut self,
        overwrite: bool,
        from: usize,
        to: usize,
        range: &Utf8Range,
    ) {
        for b in (range.start as usize)..=(range.end as usize) {
            if overwrite || self.dfa.states[from].next[b].is_none() {
                self.dfa.states[from].next[b] = Some(to);
            }
        }
    }

    /// Appends a state with no transitions and returns its index.
    fn new_state(&mut self, is_match: bool) -> usize {
        self.dfa.states.push(State { next: [None; 256], is_match });
        self.dfa.states.len() - 1
    }
}