use crate::error::SeqResult;
/// One step of an edit script relating a source sequence to a target sequence.
///
/// `src` is a 0-based index into the source slice and `tgt` a 0-based index
/// into the target slice (the first and second arguments of `align`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EditOp {
/// Source element `src` equals target element `tgt`; no edit cost.
Match { src: usize, tgt: usize },
/// Source element `src` is replaced by the differing target element `tgt`.
Substitute { src: usize, tgt: usize },
/// Source element `src` has no counterpart in the target.
Delete { src: usize },
/// Target element `tgt` has no counterpart in the source.
Insert { tgt: usize },
}
/// Tally of each kind of [`EditOp`] found in an alignment.
///
/// The `Default` value has every counter set to zero.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct EditCounts {
    /// Number of `Match` operations (equal elements, zero cost).
    pub matches: usize,
    /// Number of `Substitute` operations.
    pub substitutions: usize,
    /// Number of `Delete` operations.
    pub deletions: usize,
    /// Number of `Insert` operations.
    pub insertions: usize,
}

impl EditCounts {
    /// Returns the total number of edits: substitutions plus deletions plus
    /// insertions. Matches do not contribute.
    #[must_use]
    pub fn distance(&self) -> usize {
        let Self {
            substitutions,
            deletions,
            insertions,
            ..
        } = *self;
        substitutions + deletions + insertions
    }
}
/// Result of aligning a source sequence against a target sequence.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EditAlignment {
/// Edit operations in forward order (from the start of both sequences to the end).
pub ops: Vec<EditOp>,
/// Per-kind tally of the operations in `ops`.
pub counts: EditCounts,
}
/// Computes a minimum-cost edit alignment from `a` (source) to `b` (target).
///
/// Every substitution, deletion and insertion costs 1; matches cost 0. The
/// returned [`EditAlignment`] lists the operations in forward order together
/// with a per-kind tally.
///
/// When several optimal alignments exist, the one chosen is determined by the
/// backward traceback, which at each cell prefers the diagonal step
/// (match/substitute), then a deletion, then an insertion.
///
/// Time and memory are proportional to `(a.len() + 1) * (b.len() + 1)`: the
/// full cost table is kept so the operations can be recovered.
pub fn align<T: Eq>(a: &[T], b: &[T]) -> EditAlignment {
    let table = fill_cost_table(a, b);
    trace_back(a, b, &table)
}

/// Builds the row-major `(a.len() + 1) x (b.len() + 1)` table whose cell
/// `(i, j)` holds the edit distance between `a[..i]` and `b[..j]`.
fn fill_cost_table<T: Eq>(a: &[T], b: &[T]) -> Vec<usize> {
    let cols = b.len() + 1;
    let mut table = vec![0usize; (a.len() + 1) * cols];

    // Row 0: turning an empty prefix into `b[..j]` takes `j` insertions.
    for (j, cell) in table.iter_mut().take(cols).enumerate() {
        *cell = j;
    }

    for (i, x) in a.iter().enumerate() {
        let prev_row = i * cols;
        let cur_row = prev_row + cols;
        // Column 0: turning `a[..=i]` into an empty prefix takes `i + 1` deletions.
        table[cur_row] = i + 1;
        for (j, y) in b.iter().enumerate() {
            let cost = if x == y { 0 } else { 1 };
            let del = table[prev_row + j + 1] + 1;
            let ins = table[cur_row + j] + 1;
            let sub = table[prev_row + j] + cost;
            table[cur_row + j + 1] = del.min(ins).min(sub);
        }
    }
    table
}

/// Walks `table` (as produced by `fill_cost_table` for the same `a` and `b`)
/// from the bottom-right corner back to the origin and records the operations.
fn trace_back<T: Eq>(a: &[T], b: &[T], table: &[usize]) -> EditAlignment {
    let cols = b.len() + 1;
    let (mut i, mut j) = (a.len(), b.len());
    // An alignment always contains at least max(len(a), len(b)) operations,
    // so this reservation is never wasted.
    let mut ops = Vec::with_capacity(i.max(j));
    let mut counts = EditCounts::default();

    while i > 0 || j > 0 {
        let here = table[i * cols + j];

        if i > 0 && j > 0 {
            let same = a[i - 1] == b[j - 1];
            let cost = if same { 0 } else { 1 };
            if here == table[(i - 1) * cols + (j - 1)] + cost {
                if same {
                    ops.push(EditOp::Match {
                        src: i - 1,
                        tgt: j - 1,
                    });
                    counts.matches += 1;
                } else {
                    ops.push(EditOp::Substitute {
                        src: i - 1,
                        tgt: j - 1,
                    });
                    counts.substitutions += 1;
                }
                i -= 1;
                j -= 1;
                continue;
            }
        }

        if i > 0 && here == table[(i - 1) * cols + j] + 1 {
            ops.push(EditOp::Delete { src: i - 1 });
            counts.deletions += 1;
            i -= 1;
            continue;
        }

        // Neither the diagonal nor the deletion step explains this cell, so
        // it came from an insertion. `j > 0` holds here: in column 0 the
        // deletion branch above always fires.
        ops.push(EditOp::Insert { tgt: j - 1 });
        counts.insertions += 1;
        j -= 1;
    }

    // Operations were collected end-to-start; present them start-to-end.
    ops.reverse();
    EditAlignment { ops, counts }
}
pub fn edit_distance_aligned<T: Eq>(a: &[T], b: &[T]) -> usize {
align(a, b).counts.distance()
}
/// Computes the error rate of `hypothesis` against `reference`: the edit
/// distance between the two sequences divided by the reference length.
///
/// When `reference` is empty the division is skipped and the raw edit
/// distance (i.e. the number of hypothesis elements) is returned instead.
///
/// The result is always wrapped in `Ok`.
pub fn word_error_rate<T: Eq>(reference: &[T], hypothesis: &[T]) -> SeqResult<f64> {
    let errors = align(reference, hypothesis).counts.distance() as f64;
    let rate = match reference.len() {
        // No reference tokens to normalise by: report the error count itself.
        0 => errors,
        len => errors / len as f64,
    };
    Ok(rate)
}
/// Computes the error rate between two strings at `char` granularity.
///
/// Both strings are split into their `char`s and the resulting sequences are
/// handed to [`word_error_rate`], so the same empty-reference rule applies.
pub fn character_error_rate(reference: &str, hypothesis: &str) -> SeqResult<f64> {
    let ref_chars = reference.chars().collect::<Vec<_>>();
    let hyp_chars = hypothesis.chars().collect::<Vec<_>>();
    word_error_rate(ref_chars.as_slice(), hyp_chars.as_slice())
}
// Unit tests for the alignment routines and the error-rate helpers built on them.
#[cfg(test)]
mod tests {
use super::*;
// Helper: splits a string into a vector of its `char`s.
fn chars(s: &str) -> Vec<char> {
s.chars().collect()
}
// "kitten" -> "sitting" needs three edits in total.
#[test]
fn kitten_to_sitting_distance_three() {
let a = chars("kitten");
let b = chars("sitting");
let al = align(&a, &b);
assert_eq!(al.counts.distance(), 3);
}
// The same pair broken down by kind: two substitutions, one insertion,
// no deletions, four matches.
#[test]
fn kitten_to_sitting_op_breakdown() {
let a = chars("kitten");
let b = chars("sitting");
let al = align(&a, &b);
assert_eq!(al.counts.substitutions, 2);
assert_eq!(al.counts.insertions, 1);
assert_eq!(al.counts.deletions, 0);
assert_eq!(al.counts.matches, 4); }
// Aligning a sequence with itself yields only `Match` operations.
#[test]
fn identical_sequences_all_matches() {
let a = chars("hello");
let al = align(&a, &a);
assert_eq!(al.counts.distance(), 0);
assert_eq!(al.counts.matches, 5);
assert!(al.ops.iter().all(|op| matches!(op, EditOp::Match { .. })));
}
// With an empty source every target element must be inserted.
#[test]
fn empty_source_is_all_insertions() {
let a: Vec<char> = Vec::new();
let b = chars("abc");
let al = align(&a, &b);
assert_eq!(al.counts.insertions, 3);
assert_eq!(al.counts.distance(), 3);
assert_eq!(al.ops.len(), 3);
}
// With an empty target every source element must be deleted.
#[test]
fn empty_target_is_all_deletions() {
let a = chars("abc");
let b: Vec<char> = Vec::new();
let al = align(&a, &b);
assert_eq!(al.counts.deletions, 3);
assert_eq!(al.counts.distance(), 3);
}
// Two empty sequences produce an empty edit script.
#[test]
fn both_empty_is_no_ops() {
let a: Vec<char> = Vec::new();
let b: Vec<char> = Vec::new();
let al = align(&a, &b);
assert!(al.ops.is_empty());
assert_eq!(al.counts.distance(), 0);
}
// Replaying the `tgt` indices of every non-delete operation, in order,
// must rebuild the target sequence exactly.
#[test]
fn ops_reconstruct_target() {
let a = chars("intention");
let b = chars("execution");
let al = align(&a, &b);
let mut rebuilt: Vec<char> = Vec::new();
for op in &al.ops {
match *op {
EditOp::Match { tgt, .. } | EditOp::Substitute { tgt, .. } => rebuilt.push(b[tgt]),
EditOp::Insert { tgt } => rebuilt.push(b[tgt]),
EditOp::Delete { .. } => {}
}
}
assert_eq!(rebuilt, b);
}
// The `src` indices of every non-insert operation must cover the source
// exactly once, in increasing order.
#[test]
fn ops_consume_source_in_order() {
let a = chars("abcdef");
let b = chars("azced");
let al = align(&a, &b);
let mut consumed: Vec<usize> = Vec::new();
for op in &al.ops {
match *op {
EditOp::Match { src, .. }
| EditOp::Substitute { src, .. }
| EditOp::Delete { src } => consumed.push(src),
EditOp::Insert { .. } => {}
}
}
let expected: Vec<usize> = (0..a.len()).collect();
assert_eq!(consumed, expected);
}
// Cross-checks the alignment-derived distance against
// `crate::metrics::metrics::edit_distance`, which is defined outside this file.
#[test]
fn distance_matches_scalar_reference() {
let pairs = [
("flaw", "lawn"),
("gumbo", "gambol"),
("book", "back"),
("", "nonempty"),
("same", "same"),
];
for (x, y) in pairs {
let a = chars(x);
let b = chars(y);
let via_align = edit_distance_aligned(&a, &b);
let via_scalar = crate::metrics::metrics::edit_distance(&a, &b);
assert_eq!(via_align, via_scalar, "{x} vs {y}");
}
}
// Matches, substitutions and deletions together consume the whole source;
// matches, substitutions and insertions together produce the whole target.
#[test]
fn op_count_equals_alignment_length_invariant() {
let a = chars("alignment");
let b = chars("assignment");
let al = align(&a, &b);
let c = al.counts;
assert_eq!(c.matches + c.substitutions + c.deletions, a.len());
assert_eq!(c.matches + c.substitutions + c.insertions, b.len());
}
// One wrong word out of three gives a rate of 1/3.
#[test]
fn word_error_rate_basic() {
let r = vec!["the", "cat", "sat"];
let h = vec!["the", "cat", "sit"];
let wer = word_error_rate(&r, &h).expect("wer");
assert!((wer - 1.0 / 3.0).abs() < 1e-9, "wer={wer}");
}
// A hypothesis identical to the reference has a rate of zero.
#[test]
fn word_error_rate_perfect_is_zero() {
let r = vec!["a", "b", "c"];
let wer = word_error_rate(&r, &r).expect("wer");
assert!(wer.abs() < 1e-12);
}
// With an empty reference the raw edit count (here 2 insertions) is returned.
#[test]
fn word_error_rate_empty_reference() {
let r: Vec<&str> = Vec::new();
let h = vec!["x", "y"];
let wer = word_error_rate(&r, &h).expect("wer");
assert!((wer - 2.0).abs() < 1e-12);
}
// The string wrapper works per `char`: 3 edits over 6 reference chars.
#[test]
fn character_error_rate_string_api() {
let cer = character_error_rate("kitten", "sitting").expect("cer");
assert!((cer - 3.0 / 6.0).abs() < 1e-9, "cer={cer}");
}
// The aligner is generic: integer token ids work as well as chars.
#[test]
fn works_on_token_ids() {
let a = vec![1usize, 2, 3, 4];
let b = vec![1usize, 3, 4];
let al = align(&a, &b);
assert_eq!(al.counts.distance(), 1);
assert_eq!(al.counts.deletions, 1);
}
}