use std::cmp::min;
/// Computes the Levenshtein edit distance between `s1` and `s2`.
///
/// The distance is the minimum number of single-character insertions,
/// deletions and substitutions needed to turn `s1` into `s2`. Strings are
/// compared `char` by `char` (Unicode scalar values), not byte by byte.
///
/// Only two rows of the dynamic-programming table are kept, so the working
/// memory is proportional to the length of `s2` rather than to the product of
/// both lengths.
///
/// # Examples
///
/// ```ignore
/// assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
/// assert_eq!(levenshtein_distance("", "abc"), 3);
/// ```
pub fn levenshtein_distance(s1: &str, s2: &str) -> usize {
    // With one side empty the answer is simply the length of the other side;
    // answer without allocating anything.
    if s1.is_empty() {
        return s2.chars().count();
    }
    if s2.is_empty() {
        return s1.chars().count();
    }

    let s1_chars: Vec<char> = s1.chars().collect();
    let s2_chars: Vec<char> = s2.chars().collect();
    let len2 = s2_chars.len();

    // `prev_row[j]` holds the distance between the already processed prefix of
    // `s1` and the first `j` chars of `s2`. It starts as the row for the empty
    // prefix: turning "" into `j` chars takes `j` insertions.
    let mut prev_row: Vec<usize> = (0..=len2).collect();
    let mut curr_row = vec![0; len2 + 1];

    for (i, &c1) in s1_chars.iter().enumerate() {
        // Turning the first `i + 1` chars of `s1` into "" takes `i + 1` deletions.
        curr_row[0] = i + 1;
        for (j, &c2) in s2_chars.iter().enumerate() {
            let cost = usize::from(c1 != c2);
            curr_row[j + 1] = min(
                // deletion from `s1` / insertion into `s1`
                min(prev_row[j + 1] + 1, curr_row[j] + 1),
                // match (cost 0) or substitution (cost 1)
                prev_row[j] + cost,
            );
        }
        // The freshly filled row becomes the "previous" row of the next pass;
        // the old one is recycled as scratch space.
        std::mem::swap(&mut prev_row, &mut curr_row);
    }

    prev_row[len2]
}
/// Computes the Levenshtein distance between `s1` and `s2`, giving up as soon
/// as it is certain that the distance exceeds `threshold`.
///
/// Returns `Some(distance)` when the distance is at most `threshold` and
/// `None` otherwise. Strings are compared `char` by `char`.
///
/// Two shortcuts avoid filling the whole table:
/// * the difference of the two lengths is a lower bound on the distance, so a
///   difference above `threshold` is rejected before anything is allocated;
/// * each cell of a row is computed from cells of the row above (or from its
///   left neighbour) plus a non-negative cost, so once the smallest value in
///   a row exceeds `threshold` the final result must exceed it as well.
pub fn levenshtein_distance_threshold(s1: &str, s2: &str, threshold: usize) -> Option<usize> {
    let rows = s1.chars().count();
    let cols = s2.chars().count();

    if rows.abs_diff(cols) > threshold {
        return None;
    }
    // With one side empty the distance is the length of the other side.
    if rows == 0 {
        return Some(cols).filter(|&d| d <= threshold);
    }
    if cols == 0 {
        return Some(rows).filter(|&d| d <= threshold);
    }

    let left: Vec<char> = s1.chars().collect();
    let right: Vec<char> = s2.chars().collect();

    // `above` is the row for the previous prefix of `s1`; it starts as the row
    // for the empty prefix (0, 1, 2, ...). `current` is the row being filled.
    let mut above: Vec<usize> = (0..=cols).collect();
    let mut current = vec![0; cols + 1];

    for (row, &left_ch) in left.iter().enumerate() {
        let prefix_len = row + 1;
        current[0] = prefix_len;
        let mut row_best = prefix_len;

        for (col, &right_ch) in right.iter().enumerate() {
            let deletion = above[col + 1] + 1;
            let insertion = current[col] + 1;
            let substitution = above[col] + usize::from(left_ch != right_ch);

            let cell = min(min(deletion, insertion), substitution);
            current[col + 1] = cell;
            row_best = min(row_best, cell);
        }

        // No cell in this row is within the budget, so no later cell can be.
        if row_best > threshold {
            return None;
        }
        std::mem::swap(&mut above, &mut current);
    }

    // After the final swap the last completed row lives in `above`.
    Some(above[cols]).filter(|&d| d <= threshold)
}
/// Computes the Damerau-Levenshtein distance between `s1` and `s2` in its
/// restricted ("optimal string alignment") form.
///
/// Besides insertions, deletions and substitutions, swapping two *adjacent*
/// characters counts as a single edit, so `"ab"` -> `"ba"` is 1. A pair that
/// has been swapped is never edited again, which is why `"ca"` -> `"abc"` is
/// 3 here. Strings are compared `char` by `char`.
///
/// The recurrence only ever looks at the current row and the two rows before
/// it, so three rolling rows are kept instead of the full table.
// Index loops are kept on purpose: every cell reads neighbours at `i - 1`,
// `i - 2`, `j - 1` and `j - 2`, which is clearer with explicit indices.
#[allow(clippy::needless_range_loop)]
pub fn damerau_levenshtein_distance(s1: &str, s2: &str) -> usize {
    // With one side empty the answer is the length of the other side.
    if s1.is_empty() {
        return s2.chars().count();
    }
    if s2.is_empty() {
        return s1.chars().count();
    }

    let s1_chars: Vec<char> = s1.chars().collect();
    let s2_chars: Vec<char> = s2.chars().collect();
    let len1 = s1_chars.len();
    let len2 = s2_chars.len();

    // Rows `i - 2`, `i - 1` and `i` of the table. `two_back` is not read
    // before `i == 2`, by which time it holds row 0.
    let mut two_back = vec![0; len2 + 1];
    let mut prev_row: Vec<usize> = (0..=len2).collect();
    let mut curr_row = vec![0; len2 + 1];

    for i in 1..=len1 {
        let c1 = s1_chars[i - 1];
        curr_row[0] = i;

        for j in 1..=len2 {
            let c2 = s2_chars[j - 1];
            let cost = usize::from(c1 != c2);

            let mut best = min(
                // deletion from `s1` / insertion into `s1`
                min(prev_row[j] + 1, curr_row[j - 1] + 1),
                // match (cost 0) or substitution (cost 1)
                prev_row[j - 1] + cost,
            );

            // The last two chars of both prefixes are each other's mirror
            // image: one adjacent transposition bridges them. (If the two
            // chars happen to be identical the plain diagonal step is already
            // at least as cheap, so charging a flat 1 never changes a result.)
            if i > 1 && j > 1 && c1 == s2_chars[j - 2] && s1_chars[i - 2] == c2 {
                best = min(best, two_back[j - 2] + 1);
            }

            curr_row[j] = best;
        }

        // Rotate the rows: (i-1, i, scratch) become (i-2, i-1, current) for
        // the next iteration.
        std::mem::swap(&mut two_back, &mut prev_row);
        std::mem::swap(&mut prev_row, &mut curr_row);
    }

    // After the rotation the last completed row lives in `prev_row`.
    prev_row[len2]
}
/// A query string bundled with data derived from it, used to compare the same
/// query against many candidates.
#[derive(Debug, Clone)]
pub struct LevenshteinMatcher {
    /// The query exactly as it was supplied by the caller.
    query: String,
    /// The query split into `char`s.
    // Stored but not read by any code visible in this file, hence the allow.
    #[allow(dead_code)]
    query_chars: Vec<char>,
    /// Number of `char`s in the query.
    // Stored but not read by any code visible in this file, hence the allow.
    #[allow(dead_code)]
    query_len: usize,
}
impl LevenshteinMatcher {
    /// Creates a matcher for `query`.
    ///
    /// The query is split into `char`s once, and both the resulting vector and
    /// its length are stored alongside the original string.
    pub fn new(query: String) -> Self {
        let query_chars = query.chars().collect::<Vec<char>>();
        Self {
            // Read the length before the vector is moved into its field.
            query_len: query_chars.len(),
            query_chars,
            query,
        }
    }

    /// Returns the Levenshtein distance between the stored query and
    /// `candidate`, or `None` when that distance is greater than `threshold`.
    ///
    /// This forwards the stored query string to
    /// [`levenshtein_distance_threshold`]; the cached `char` data is not used
    /// here.
    pub fn distance_threshold(&self, candidate: &str, threshold: usize) -> Option<usize> {
        let query = self.query.as_str();
        levenshtein_distance_threshold(query, candidate, threshold)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Plain Levenshtein: empty inputs, identical inputs, and classic pairs.
    #[test]
    fn test_levenshtein_distance() {
        assert_eq!(levenshtein_distance("", ""), 0);
        assert_eq!(levenshtein_distance("", "a"), 1);
        assert_eq!(levenshtein_distance("a", ""), 1);
        assert_eq!(levenshtein_distance("a", "a"), 0);
        assert_eq!(levenshtein_distance("ab", "ac"), 1);
        assert_eq!(levenshtein_distance("abc", "def"), 3);
        assert_eq!(levenshtein_distance("kitten", "sitting"), 3);
        // Without a transposition operation a swapped pair costs two edits.
        assert_eq!(levenshtein_distance("search", "serach"), 2);
    }

    /// Thresholded variant: results at, below and above the budget.
    #[test]
    fn test_levenshtein_distance_threshold() {
        assert_eq!(
            levenshtein_distance_threshold("kitten", "sitting", 3),
            Some(3)
        );
        assert_eq!(levenshtein_distance_threshold("kitten", "sitting", 2), None);
        assert_eq!(
            levenshtein_distance_threshold("search", "search", 0),
            Some(0)
        );
        // The length difference alone (2) already exceeds the threshold.
        assert_eq!(levenshtein_distance_threshold("a", "abc", 1), None);
        assert_eq!(levenshtein_distance_threshold("a", "ab", 1), Some(1));
    }

    /// Damerau-Levenshtein: an adjacent swap counts as a single edit.
    #[test]
    fn test_damerau_levenshtein_distance() {
        assert_eq!(damerau_levenshtein_distance("", ""), 0);
        assert_eq!(damerau_levenshtein_distance("ab", "ba"), 1);
        assert_eq!(damerau_levenshtein_distance("search", "serach"), 1);
        assert_eq!(damerau_levenshtein_distance("kitten", "sitting"), 3);
        // A swapped pair is not edited again, so this takes three edits.
        assert_eq!(damerau_levenshtein_distance("ca", "abc"), 3);
    }

    /// The matcher reports the same distances as the thresholded function.
    #[test]
    fn test_levenshtein_matcher() {
        let matcher = LevenshteinMatcher::new("search".to_string());
        assert_eq!(matcher.distance_threshold("search", 2), Some(0));
        assert_eq!(matcher.distance_threshold("serach", 2), Some(2));
    }

    /// Typical typos stay within two plain edits, and allowing transpositions
    /// never makes a distance larger.
    #[test]
    fn test_common_typos() {
        // (correct spelling, typo)
        let common_typos = [
            ("the", "teh"),
            ("search", "serach"),
            ("hello", "helo"),
            ("world", "wrold"),
            ("quick", "quikc"),
        ];

        for (correct, typo) in common_typos {
            let distance = levenshtein_distance(correct, typo);
            assert!(
                distance <= 2,
                "Distance too high for {correct} -> {typo}: {distance}"
            );

            let damerau_distance = damerau_levenshtein_distance(correct, typo);
            assert!(
                damerau_distance <= distance,
                "Damerau distance should be <= Levenshtein"
            );
        }
    }
}