use std::cmp::min;
use std::collections::HashMap;
/// A single character-level edit.
///
/// As produced by `ErrorModel::min_edit_operations`, each operation describes how the
/// intended (correct) word was turned into the observed (typo) word.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum EditOp {
    /// The character was present in the correct word but is missing from the typo.
    Delete(char),
    /// The character appears in the typo but not in the correct word.
    Insert(char),
    /// `(correct_char, typo_char)`: the first character was replaced by the second.
    Substitute(char, char),
    /// Two adjacent characters of the correct word, in their original order, appear
    /// swapped in the typo.
    Transpose(char, char),
}
/// Probabilistic model of typing errors, weighted per kind of edit.
#[derive(Debug, Clone)]
pub struct ErrorModel {
    /// Probability weight of a deleted character.
    pub p_deletion: f64,
    /// Probability weight of an inserted character.
    pub p_insertion: f64,
    /// Probability weight of a substituted character.
    pub p_substitution: f64,
    /// Probability weight of two swapped adjacent characters.
    pub p_transposition: f64,
    // Per-character-pair weights. Always initialised empty and never read in this
    // file; presumably reserved for a confusion matrix (TODO confirm).
    _char_confusion: HashMap<(char, char), f64>,
    // Edit-distance cutoff used by the edit-operation search.
    max_edit_distance: usize,
}
impl Default for ErrorModel {
fn default() -> Self {
Self {
p_deletion: 0.25,
p_insertion: 0.25,
p_substitution: 0.25,
p_transposition: 0.25,
_char_confusion: HashMap::new(),
max_edit_distance: 2, }
}
}
impl ErrorModel {
    /// Builds a model from relative weights for the four edit kinds.
    ///
    /// The weights are normalised so the four probabilities sum to 1. If the weights sum
    /// to zero, a negative number, or a non-finite value, normalising would produce NaN or
    /// infinite probabilities, so uniform weights (0.25 each) are used instead. The
    /// edit-distance cutoff starts at 2; see [`ErrorModel::with_max_distance`].
    pub fn new(
        p_deletion: f64,
        p_insertion: f64,
        p_substitution: f64,
        p_transposition: f64,
    ) -> Self {
        let total = p_deletion + p_insertion + p_substitution + p_transposition;
        if !(total.is_finite() && total > 0.0) {
            // Degenerate weights: fall back to uniform instead of dividing by zero.
            return Self::new(1.0, 1.0, 1.0, 1.0);
        }
        Self {
            p_deletion: p_deletion / total,
            p_insertion: p_insertion / total,
            p_substitution: p_substitution / total,
            p_transposition: p_transposition / total,
            _char_confusion: HashMap::new(),
            max_edit_distance: 2,
        }
    }

    /// Returns the model with its edit-distance cutoff replaced by `maxdistance`.
    pub fn with_max_distance(mut self, maxdistance: usize) -> Self {
        self.max_edit_distance = maxdistance;
        self
    }

    /// Heuristic probability that `correct` was mistyped as `typo`.
    ///
    /// - `1.0` when the strings are identical.
    /// - `0.0` when the pair lies beyond `max_edit_distance`, either because the
    ///   character-length gap already exceeds it or because the edit search needs
    ///   more edits than the cutoff allows (or gives up early).
    /// - For `n` edits, `0.1^(n-1)` multiplied by each edit's per-kind probability.
    ///   A single edit therefore yields exactly that edit kind's probability.
    pub fn error_probability(&self, typo: &str, correct: &str) -> f64 {
        if typo == correct {
            return 1.0;
        }
        // `min_edit_operations` answers such pairs with a `Substitute('?', '?')`
        // placeholder, which is not a real edit and must not be scored as one.
        if Self::length_gap(typo.chars().count(), correct.chars().count())
            > self.max_edit_distance
        {
            return 0.0;
        }
        let ops = self.min_edit_operations(typo, correct);
        // Distinct strings with no ops means the DP abandoned the search past the cutoff.
        if ops.is_empty() || ops.len() > self.max_edit_distance {
            return 0.0;
        }
        // Extra penalty of 0.1 per additional edit (0.1^0 = 1 for a single edit).
        let exponent = (ops.len() - 1).min(i32::MAX as usize) as i32;
        ops.iter()
            .fold(0.1f64.powi(exponent), |prob, op| prob * self.op_probability(op))
    }

    /// Per-kind probability of a single edit.
    fn op_probability(&self, op: &EditOp) -> f64 {
        match op {
            EditOp::Delete(_) => self.p_deletion,
            EditOp::Insert(_) => self.p_insertion,
            EditOp::Substitute(_, _) => self.p_substitution,
            EditOp::Transpose(_, _) => self.p_transposition,
        }
    }

    /// Minimal edits that turn `correct` into `typo`.
    ///
    /// `Delete(c)` is a char of `correct` missing from `typo`; `Insert(c)` is an extra
    /// char in `typo`; `Substitute(a, b)` replaces `correct`'s `a` with `typo`'s `b`;
    /// `Transpose(a, b)` names an adjacent pair of `correct` that is swapped in `typo`.
    ///
    /// Special results:
    /// - An empty `Vec` for identical strings. It is also empty when the DP gives up
    ///   because the distance exceeds `max_edit_distance`.
    /// - `[Substitute('?', '?')]` when the character-length gap alone exceeds the
    ///   cutoff. This is a placeholder kept for backward compatibility, not a real edit.
    ///
    /// Single-edit fast paths do not consult the cutoff, and a completed DP may return
    /// more edits than the cutoff.
    pub fn min_edit_operations(&self, typo: &str, correct: &str) -> Vec<EditOp> {
        if typo == correct {
            return Vec::new();
        }
        let typo_chars: Vec<char> = typo.chars().collect();
        let correct_chars: Vec<char> = correct.chars().collect();
        if Self::length_gap(typo_chars.len(), correct_chars.len()) > self.max_edit_distance {
            return vec![EditOp::Substitute('?', '?')];
        }

        // Fast paths for the common single-edit cases.
        if let Some(c) = Self::single_removal(&correct_chars, &typo_chars) {
            return vec![EditOp::Delete(c)];
        }
        if let Some(c) = Self::single_removal(&typo_chars, &correct_chars) {
            return vec![EditOp::Insert(c)];
        }
        if typo_chars.len() == correct_chars.len() {
            let diffs: Vec<usize> = (0..correct_chars.len())
                .filter(|&i| correct_chars[i] != typo_chars[i])
                .collect();
            match diffs.as_slice() {
                &[i] => return vec![EditOp::Substitute(correct_chars[i], typo_chars[i])],
                &[i, j]
                    if i + 1 == j
                        && correct_chars[i] == typo_chars[j]
                        && correct_chars[j] == typo_chars[i] =>
                {
                    return vec![EditOp::Transpose(correct_chars[i], correct_chars[j])];
                }
                _ => {}
            }
        }

        let mut operations = Vec::new();
        self.levenshtein_with_ops_efficient(correct, typo, &mut operations);
        operations
    }

    /// Absolute difference of two lengths, computed without signed casts.
    fn length_gap(a: usize, b: usize) -> usize {
        a.max(b) - a.min(b)
    }

    /// If removing exactly one char from `longer` yields `shorter`, returns that char.
    ///
    /// Only the first mismatch needs checking. Any other valid removal point lies in a
    /// run of identical chars ending there, so it would remove the same char.
    fn single_removal(longer: &[char], shorter: &[char]) -> Option<char> {
        if longer.len() != shorter.len() + 1 {
            return None;
        }
        let split = longer
            .iter()
            .zip(shorter)
            .position(|(a, b)| a != b)
            .unwrap_or(shorter.len());
        (longer[split + 1..] == shorter[split..]).then(|| longer[split])
    }

    /// Optimal-string-alignment distance (Levenshtein plus adjacent transpositions)
    /// from `s1` to `s2`, with a cutoff at `max_edit_distance`.
    ///
    /// On success, the ops turning `s1` into `s2` are appended to `operations` and the
    /// distance is returned. If the length gap exceeds the cutoff, or every entry of a DP
    /// row does, the function returns `max_edit_distance + 1` (saturating) and leaves
    /// `operations` untouched. Distances use rolling rows, but the backtrack table is
    /// still `(len1 + 1) × (len2 + 1)`.
    fn levenshtein_with_ops_efficient(
        &self,
        s1: &str,
        s2: &str,
        operations: &mut Vec<EditOp>,
    ) -> usize {
        // Backtrack codes recorded per DP cell.
        const MATCH: u8 = 0;
        const INSERT: u8 = 1;
        const DELETE: u8 = 2;
        const SUBSTITUTE: u8 = 3;
        const TRANSPOSE: u8 = 4;

        if s1 == s2 {
            return 0;
        }
        let chars1: Vec<char> = s1.chars().collect();
        let chars2: Vec<char> = s2.chars().collect();
        let len1 = chars1.len();
        let len2 = chars2.len();
        // Saturating so a cutoff of `usize::MAX` cannot overflow.
        let beyond_cutoff = self.max_edit_distance.saturating_add(1);
        if Self::length_gap(len1, len2) > self.max_edit_distance {
            return beyond_cutoff;
        }

        // Rows i-2, i-1 and i; the transposition recurrence needs row i-2.
        let mut prev2_row = vec![0; len2 + 1];
        let mut prev_row: Vec<usize> = (0..=len2).collect();
        let mut curr_row = vec![0; len2 + 1];
        let mut op_matrix = vec![vec![MATCH; len2 + 1]; len1 + 1];

        for i in 1..=len1 {
            curr_row[0] = i;
            for j in 1..=len2 {
                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
                let del_cost = prev_row[j] + 1;
                let ins_cost = curr_row[j - 1] + 1;
                let sub_cost = prev_row[j - 1] + cost;
                let mut best = min(min(del_cost, ins_cost), sub_cost);
                // Tie-break order: delete, insert, then the diagonal.
                let mut op = if best == del_cost {
                    DELETE
                } else if best == ins_cost {
                    INSERT
                } else if cost > 0 {
                    SUBSTITUTE
                } else {
                    MATCH
                };
                if i > 1
                    && j > 1
                    && chars1[i - 1] == chars2[j - 2]
                    && chars1[i - 2] == chars2[j - 1]
                {
                    let trans_cost = prev2_row[j - 2] + 1;
                    if trans_cost < best {
                        best = trans_cost;
                        op = TRANSPOSE;
                    }
                }
                curr_row[j] = best;
                op_matrix[i][j] = op;
            }
            // Row minima never decrease, so once a whole row exceeds the cutoff the
            // final distance does too.
            if curr_row.iter().all(|&c| c > self.max_edit_distance) {
                return beyond_cutoff;
            }
            // Shift rows: (i-2, i-1, i) -> (i-1, i, scratch).
            std::mem::swap(&mut prev2_row, &mut prev_row);
            std::mem::swap(&mut prev_row, &mut curr_row);
        }

        let mut i = len1;
        let mut j = len2;
        let mut backtrack_ops = Vec::new();
        while i > 0 || j > 0 {
            // The first row and column of the table are pure inserts / deletes.
            let code = if i == 0 {
                INSERT
            } else if j == 0 {
                DELETE
            } else {
                op_matrix[i][j]
            };
            match code {
                MATCH => {
                    i -= 1;
                    j -= 1;
                }
                INSERT => {
                    j -= 1;
                    backtrack_ops.push(EditOp::Insert(chars2[j]));
                }
                DELETE => {
                    i -= 1;
                    backtrack_ops.push(EditOp::Delete(chars1[i]));
                }
                SUBSTITUTE => {
                    i -= 1;
                    j -= 1;
                    backtrack_ops.push(EditOp::Substitute(chars1[i], chars2[j]));
                }
                TRANSPOSE => {
                    i -= 2;
                    j -= 2;
                    // After stepping back, s1[i], s1[i + 1] is the swapped pair.
                    backtrack_ops.push(EditOp::Transpose(chars1[i], chars1[i + 1]));
                }
                _ => unreachable!("op_matrix only holds the five backtrack codes"),
            }
        }
        backtrack_ops.reverse();
        operations.extend(backtrack_ops);
        prev_row[len2]
    }

    /// Optimal-string-alignment distance from `s1` to `s2` using a full DP matrix and
    /// no cutoff. The ops turning `s1` into `s2` are appended to `operations`.
    pub fn levenshtein_with_ops(&self, s1: &str, s2: &str, operations: &mut Vec<EditOp>) -> usize {
        let chars1: Vec<char> = s1.chars().collect();
        let chars2: Vec<char> = s2.chars().collect();
        let len1 = chars1.len();
        let len2 = chars2.len();
        let mut matrix = vec![vec![0; len2 + 1]; len1 + 1];
        for (i, row) in matrix.iter_mut().enumerate() {
            row[0] = i;
        }
        for j in 0..=len2 {
            matrix[0][j] = j;
        }
        for i in 1..=len1 {
            for j in 1..=len2 {
                let cost = if chars1[i - 1] == chars2[j - 1] { 0 } else { 1 };
                matrix[i][j] = min(
                    min(matrix[i - 1][j] + 1, matrix[i][j - 1] + 1),
                    matrix[i - 1][j - 1] + cost,
                );
                if i > 1
                    && j > 1
                    && chars1[i - 1] == chars2[j - 2]
                    && chars1[i - 2] == chars2[j - 1]
                {
                    matrix[i][j] = min(matrix[i][j], matrix[i - 2][j - 2] + 1);
                }
            }
        }

        // Walk back from the bottom-right cell, preferring a free match, then a
        // transposition, substitution, deletion and insertion.
        let mut i = len1;
        let mut j = len2;
        let mut temp_ops = Vec::new();
        while i > 0 || j > 0 {
            if i > 0 && j > 0 && chars1[i - 1] == chars2[j - 1] {
                i -= 1;
                j -= 1;
            } else if i > 1
                && j > 1
                && chars1[i - 1] == chars2[j - 2]
                && chars1[i - 2] == chars2[j - 1]
                && matrix[i][j] == matrix[i - 2][j - 2] + 1
            {
                temp_ops.push(EditOp::Transpose(chars1[i - 2], chars1[i - 1]));
                i -= 2;
                j -= 2;
            } else if i > 0 && j > 0 && matrix[i][j] == matrix[i - 1][j - 1] + 1 {
                temp_ops.push(EditOp::Substitute(chars1[i - 1], chars2[j - 1]));
                i -= 1;
                j -= 1;
            } else if i > 0 && matrix[i][j] == matrix[i - 1][j] + 1 {
                temp_ops.push(EditOp::Delete(chars1[i - 1]));
                i -= 1;
            } else if j > 0 && matrix[i][j] == matrix[i][j - 1] + 1 {
                temp_ops.push(EditOp::Insert(chars2[j - 1]));
                j -= 1;
            } else {
                break;
            }
        }
        temp_ops.reverse();
        operations.extend(temp_ops);
        matrix[len1][len2]
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Every single-edit pair scores above zero; identical words score exactly 1.0.
    #[test]
    fn test_error_model() {
        let model = ErrorModel::default();
        let single_edit_pairs = [
            ("cat", "cart"),  // deletion
            ("cart", "cat"),  // insertion
            ("cat", "cut"),   // substitution
            ("form", "from"), // transposition
        ];
        for &(typo, correct) in single_edit_pairs.iter() {
            assert!(model.error_probability(typo, correct) > 0.0);
        }
        assert_eq!(model.error_probability("word", "word"), 1.0);
    }

    /// Each single-edit pair yields exactly the expected operation.
    #[test]
    fn test_edit_operations() {
        let model = ErrorModel::default();
        let cases = [
            ("cat", "cart", EditOp::Delete('r')),
            ("cart", "cat", EditOp::Insert('r')),
            ("cut", "cat", EditOp::Substitute('a', 'u')),
            ("from", "form", EditOp::Transpose('o', 'r')),
        ];
        for (typo, correct, expected) in cases.iter() {
            let ops = model.min_edit_operations(typo, correct);
            assert_eq!(ops.len(), 1);
            assert_eq!(&ops[0], expected);
        }
    }

    /// The rolling-row implementation agrees with the full-matrix one on simple pairs.
    #[test]
    fn test_efficient_levenshtein() {
        let model = ErrorModel::default();
        // Runs both implementations on one pair: (full distance, fast distance, full ops, fast ops).
        let run_both = |s1: &str, s2: &str| {
            let mut full_ops = Vec::new();
            let mut fast_ops = Vec::new();
            let full = model.levenshtein_with_ops(s1, s2, &mut full_ops);
            let fast = model.levenshtein_with_ops_efficient(s1, s2, &mut fast_ops);
            (full, fast, full_ops, fast_ops)
        };

        let (full, fast, full_ops, fast_ops) = run_both("hello", "hello");
        assert_eq!(full, 0);
        assert_eq!(fast, 0);
        assert!(full_ops.is_empty());
        assert!(fast_ops.is_empty());

        for &(s1, s2) in [("cat", "bat"), ("cat", "cats"), ("cats", "cat")].iter() {
            let (full, fast, _, _) = run_both(s1, s2);
            assert_eq!(full, 1);
            assert_eq!(fast, 1);
        }

        let (_, _, full_ops, fast_ops) = run_both("abc", "acb");
        assert!(full_ops.len() <= 2);
        assert!(fast_ops.len() <= 2);

        let (full, fast, _, _) = run_both("programming", "programmer");
        assert!(full <= 3);
        assert!(fast <= 3);
    }

    /// Tight cutoffs may abandon the search; loose ones still find the edits.
    #[test]
    fn test_early_termination() {
        let tight = ErrorModel::default().with_max_distance(1);
        let ops = tight.min_edit_operations("cat", "dog");
        if let Some(first) = ops.first() {
            assert!(matches!(first, EditOp::Substitute(_, _)) || ops.len() > 1);
        }

        let loose = ErrorModel::default().with_max_distance(3);
        assert!(!loose.min_edit_operations("kitten", "sitting").is_empty());

        let ops = loose.min_edit_operations("algorithm", "logarithm");
        match ops.len() {
            1 => assert!(matches!(ops[0], EditOp::Substitute(_, _))),
            _ => assert!(!ops.is_empty()),
        }
    }
}