use crate::tree::RadixNode;
/// Reusable scratch storage for the two-row Levenshtein dynamic program.
///
/// Keeping both rows here lets repeated distance computations reuse the same
/// heap allocations; the rows only ever grow, they are never shrunk.
#[derive(Default)]
struct BufferCache {
    /// The previously completed DP row (holds the final row after a computation).
    prev_row: Vec<usize>,
    /// The DP row currently being filled in.
    curr_row: Vec<usize>,
}
/// Tests whether `search_term` fuzzily matches some prefix of `target`.
///
/// Every prefix of `target` whose length lies within `max_distance` of
/// `search_term.len()` (and is at least one byte long) is compared against
/// `search_term` with `levenshtein_distance`. Prefixes outside that length
/// window would need more than `max_distance` insertions/deletions, so they
/// are not examined.
///
/// Returns `Some(d)` with the smallest distance found when `d <= max_distance`.
/// Returns `None` otherwise, including when no prefix length falls in the window
/// (e.g. an empty `target`). An empty `search_term` always yields `Some(0)`.
fn is_fuzzy_prefix_match(
    search_term: &[u8],
    target: &[u8],
    max_distance: u8,
    cache: &mut BufferCache,
) -> Option<usize> {
    if search_term.is_empty() {
        return Some(0);
    }

    let budget = usize::from(max_distance);
    let wanted = search_term.len();
    // Candidate prefix lengths: within `budget` of `wanted`, never the empty
    // prefix, and never longer than `target` itself.
    let shortest = wanted.saturating_sub(budget).max(1);
    let longest = (wanted + budget).min(target.len());

    (shortest..=longest)
        .map(|len| levenshtein_distance(search_term, &target[..len], cache))
        .min()
        .filter(|&best| best <= budget)
}
fn levenshtein_distance(a: &[u8], b: &[u8], cache: &mut BufferCache) -> usize {
if a.is_empty() {
return b.len();
}
if b.is_empty() {
return a.len();
}
let m = a.len();
let n = b.len();
let prev_row = &mut cache.prev_row;
let curr_row = &mut cache.curr_row;
if prev_row.len() < n + 1 {
prev_row.resize(n + 1, 0);
}
if curr_row.len() < n + 1 {
curr_row.resize(n + 1, 0);
}
#[allow(clippy::needless_range_loop)]
for j in 0..=n {
prev_row[j] = j;
}
for i in 1..=m {
curr_row[0] = i;
for j in 1..=n {
let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
curr_row[j] = std::cmp::min(
std::cmp::min(
prev_row[j] + 1, curr_row[j - 1] + 1, ),
prev_row[j - 1] + cost, );
}
std::mem::swap(prev_row, curr_row);
}
prev_row[n]
}
/// Iterator over radix-tree entries whose accumulated key has a prefix within
/// `max_distance` byte-level edits of `search_key`.
///
/// Yields `(key, value, distance)` triples while walking the tree with an
/// explicit stack.
pub struct TypoTolerantSearchIterator<'a, T> {
    /// Owned copy of the key being searched for.
    search_key: Vec<u8>,
    /// Largest edit distance an entry may have and still be yielded.
    max_distance: u8,
    /// Pending traversal frames: `(node, child_index, key so far, extra)`.
    /// NOTE(review): the trailing `u8` is always pushed as 0 and never read in
    /// the visible code — confirm before relying on it.
    stack: Vec<(&'a RadixNode<T>, usize, Vec<u8>, u8)>,
    /// Scratch rows reused across edit-distance computations.
    cache: BufferCache,
}
impl<'a, T> TypoTolerantSearchIterator<'a, T> {
pub(crate) fn new(root: &'a RadixNode<T>, search_key: &[u8], max_distance: u8) -> Self {
let mut iterator = Self {
stack: Vec::new(),
search_key: search_key.to_vec(),
max_distance,
cache: Default::default(),
};
iterator.stack.push((root, 0, Vec::new(), 0));
iterator
}
}
impl<'a, T> Iterator for TypoTolerantSearchIterator<'a, T> {
    /// `(accumulated key, value, edit distance)` for each matching entry.
    type Item = (Vec<u8>, &'a T, u8);

    /// Advances the stack-driven traversal to the next node whose value matches.
    ///
    /// Frames are `(node, child_index, key_so_far, _)`:
    /// - `child_index == 0` means the node's own value has not been inspected yet;
    /// - `child_index == k` (k >= 1) means child `k - 1` is the next one to descend into.
    ///
    /// The fourth tuple element is always pushed as 0 and ignored when popped.
    fn next(&mut self) -> Option<Self::Item> {
        loop {
            let (node, child_index, current_key, _) = self.stack.pop()?;

            if child_index == 0 {
                if let Some(ref value) = node.value {
                    if let Some(distance) = is_fuzzy_prefix_match(
                        &self.search_key,
                        &current_key,
                        self.max_distance,
                        &mut self.cache,
                    ) {
                        // Re-queue the node so its children are still visited after
                        // this match has been yielded.
                        if !node.children.is_empty() {
                            self.stack.push((node, 1, current_key.clone(), 0));
                        }
                        // `is_fuzzy_prefix_match` only returns distances
                        // `<= max_distance`, which is a `u8`, so this cast is lossless.
                        return Some((current_key, value, distance as u8));
                    }
                }
                if !node.children.is_empty() {
                    self.stack.push((node, 1, current_key, 0));
                }
                continue;
            }

            let current_child_idx = child_index - 1;
            if current_child_idx < node.children.len() {
                // Push the continuation for the next sibling first, so the child
                // pushed below is popped (and fully explored) before it.
                if current_child_idx + 1 < node.children.len() {
                    self.stack
                        .push((node, child_index + 1, current_key.clone(), 0));
                }
                let (_, child) = &node.children[current_child_idx];
                let mut child_key = current_key;
                child_key.extend_from_slice(&child.key);
                self.stack.push((child, 0, child_key, 0));
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks every `(a, b, expected)` case against `levenshtein_distance`,
    /// sharing one cache across the whole table.
    fn assert_distances(cases: &[(&str, &str, usize)]) {
        let mut cache = BufferCache::default();
        for &(a, b, expected) in cases {
            assert_eq!(
                levenshtein_distance(a.as_bytes(), b.as_bytes(), &mut cache),
                expected,
                "levenshtein_distance({:?}, {:?})",
                a,
                b
            );
        }
    }

    /// Checks every `(search, target, max_distance, expected)` case against
    /// `is_fuzzy_prefix_match`, sharing one cache across the whole table.
    fn assert_prefix_matches(cases: &[(&str, &str, u8, Option<usize>)]) {
        let mut cache = BufferCache::default();
        for &(search, target, max_distance, expected) in cases {
            assert_eq!(
                is_fuzzy_prefix_match(
                    search.as_bytes(),
                    target.as_bytes(),
                    max_distance,
                    &mut cache
                ),
                expected,
                "is_fuzzy_prefix_match({:?}, {:?}, {})",
                search,
                target,
                max_distance
            );
        }
    }

    #[test]
    fn test_levenshtein_distance_basic() {
        assert_distances(&[
            ("", "", 0),
            ("abc", "", 3),
            ("", "xyz", 3),
            ("abc", "abc", 0),
            ("hello", "hello", 0),
            ("abc", "abd", 1),
            ("abc", "abcd", 1),
            ("abcd", "abc", 1),
            ("kitten", "sitting", 3),
            ("saturday", "sunday", 3),
            ("abdc", "abc", 1),
            ("abc", "abdc", 1),
            ("abc", "ac", 1),
            ("ac", "abc", 1),
        ]);
    }

    #[test]
    fn test_levenshtein_distance_utf8() {
        // Distances are byte-based: "é" is two bytes, "🚀" is four.
        assert_distances(&[("café", "cafe", 2), ("🚀", "x", 4), ("café", "café", 0)]);
    }

    #[test]
    fn test_fuzzy_prefix_match_basic() {
        assert_prefix_matches(&[
            ("hel", "hello", 0, Some(0)),
            ("hello", "hello", 0, Some(0)),
            ("helo", "hello", 1, Some(1)),
            ("hel", "hello", 1, Some(0)),
            ("xyz", "abc", 2, None),
            ("hello", "world", 2, None),
            ("", "anything", 1, Some(0)),
        ]);
    }

    #[test]
    fn test_fuzzy_prefix_match_edge_cases() {
        assert_prefix_matches(&[
            ("hello", "hel", 1, None),
            ("helo", "hello", 2, Some(1)),
            ("hllo", "hello", 2, Some(1)),
            ("hlo", "hello", 2, Some(2)),
        ]);
    }
}