// oxicuda_seq/metrics/edit_distance.rs

//! Levenshtein edit distance with alignment backtrace.
//!
//! Unlike the scalar [`crate::metrics::metrics::edit_distance`], which returns
//! only the minimum number of edits, this module reconstructs the full
//! **optimal edit script** by backtracing through the dynamic-programming
//! table. The script is a sequence of [`EditOp`] values that transforms the
//! source sequence into the target with the fewest insertions, deletions, and
//! substitutions, and is the basis for ASR Word/Character Error Rate breakdowns
//! and alignment visualisation.
//!
//! The standard Wagner-Fischer recurrence is used:
//! ```text
//! D[i][0] = i,  D[0][j] = j
//! D[i][j] = min( D[i-1][j]   + 1,                 (deletion)
//!                D[i][j-1]   + 1,                 (insertion)
//!                D[i-1][j-1] + (a_i ≠ b_j) )       (match / substitution)
//! ```
//! Ties are broken deterministically with the priority
//! match/substitute → delete → insert, which yields a stable, left-aligned
//! script.

use crate::error::SeqResult;

/// A single elementary edit in an alignment between a source and a target.
///
/// All indices are 0-based positions into the sequences that were aligned:
/// `src` indexes the source sequence and `tgt` indexes the target sequence.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EditOp {
    /// Source and target symbols are equal, so the symbol is kept as-is.
    Match { src: usize, tgt: usize },
    /// The source symbol at `src` is replaced by the different target symbol at `tgt`.
    Substitute { src: usize, tgt: usize },
    /// The source symbol at `src` is dropped (present in source, absent in target).
    Delete { src: usize },
    /// The target symbol at `tgt` is added (absent in source, present in target).
    Insert { tgt: usize },
}

/// Counts of each operation kind in an edit script.
///
/// The total edit distance is not stored; it is derived on demand by
/// [`EditCounts::distance`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct EditCounts {
    /// Number of matched (unchanged) symbols.
    pub matches: usize,
    /// Number of substitutions.
    pub substitutions: usize,
    /// Number of deletions.
    pub deletions: usize,
    /// Number of insertions.
    pub insertions: usize,
}

impl EditCounts {
    /// Levenshtein distance = substitutions + deletions + insertions.
    ///
    /// Matches are free and therefore do not contribute.
    #[must_use]
    pub fn distance(&self) -> usize {
        self.substitutions + self.deletions + self.insertions
    }
}

58/// Result of an edit-distance alignment: the optimal script and its counts.
59#[derive(Debug, Clone, PartialEq, Eq)]
60pub struct EditAlignment {
61    /// The optimal edit script, ordered source-left to source-right.
62    pub ops: Vec<EditOp>,
63    /// Aggregate operation counts.
64    pub counts: EditCounts,
65}
66
67/// Compute the Levenshtein distance and an optimal alignment between `a` and `b`.
68///
69/// Works for any sequence of `Eq` elements (characters, tokens, words …). The
70/// returned [`EditAlignment`] contains both the operation list and the counts.
71pub fn align<T: Eq>(a: &[T], b: &[T]) -> EditAlignment {
72    let m = a.len();
73    let n = b.len();
74    let cols = n + 1;
75    // dp[i][j] in flat layout.
76    let mut dp = vec![0usize; (m + 1) * cols];
77    for i in 0..=m {
78        dp[i * cols] = i;
79    }
80    for j in 0..=n {
81        dp[j] = j;
82    }
83    for i in 1..=m {
84        for j in 1..=n {
85            let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
86            let del = dp[(i - 1) * cols + j] + 1;
87            let ins = dp[i * cols + (j - 1)] + 1;
88            let sub = dp[(i - 1) * cols + (j - 1)] + cost;
89            dp[i * cols + j] = del.min(ins).min(sub);
90        }
91    }
92
93    // Backtrace from (m, n) to (0, 0).
94    let mut ops_rev: Vec<EditOp> = Vec::new();
95    let mut counts = EditCounts::default();
96    let mut i = m;
97    let mut j = n;
98    while i > 0 || j > 0 {
99        let here = dp[i * cols + j];
100        // Priority: diagonal (match/sub) → delete → insert.
101        if i > 0 && j > 0 {
102            let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
103            if here == dp[(i - 1) * cols + (j - 1)] + cost {
104                if cost == 0 {
105                    ops_rev.push(EditOp::Match {
106                        src: i - 1,
107                        tgt: j - 1,
108                    });
109                    counts.matches += 1;
110                } else {
111                    ops_rev.push(EditOp::Substitute {
112                        src: i - 1,
113                        tgt: j - 1,
114                    });
115                    counts.substitutions += 1;
116                }
117                i -= 1;
118                j -= 1;
119                continue;
120            }
121        }
122        if i > 0 && here == dp[(i - 1) * cols + j] + 1 {
123            ops_rev.push(EditOp::Delete { src: i - 1 });
124            counts.deletions += 1;
125            i -= 1;
126            continue;
127        }
128        // Remaining case: insertion (j > 0 must hold here).
129        ops_rev.push(EditOp::Insert { tgt: j - 1 });
130        counts.insertions += 1;
131        j -= 1;
132    }
133
134    ops_rev.reverse();
135    EditAlignment {
136        ops: ops_rev,
137        counts,
138    }
139}
140
141/// Levenshtein distance via the backtracing aligner (kept for symmetry).
142///
143/// Equivalent in value to [`crate::metrics::metrics::edit_distance`].
144pub fn edit_distance_aligned<T: Eq>(a: &[T], b: &[T]) -> usize {
145    align(a, b).counts.distance()
146}
147
148/// Word Error Rate: `(S + D + I) / max(N_ref, 1)` where `N_ref = reference.len()`.
149///
150/// `reference` and `hypothesis` are token sequences (e.g. words). Returns `0.0`
151/// for two empty inputs and `f64::from(hyp.len())` semantics scaled by the
152/// reference length otherwise.
153pub fn word_error_rate<T: Eq>(reference: &[T], hypothesis: &[T]) -> SeqResult<f64> {
154    let counts = align(reference, hypothesis).counts;
155    let n = reference.len();
156    if n == 0 {
157        // No reference tokens: rate is 0 if hypothesis also empty, else the
158        // number of spurious insertions (conventionally normalised by 1).
159        return Ok(counts.distance() as f64);
160    }
161    Ok(counts.distance() as f64 / n as f64)
162}
163
164/// Character Error Rate over `char` sequences derived from `&str` inputs.
165pub fn character_error_rate(reference: &str, hypothesis: &str) -> SeqResult<f64> {
166    let r: Vec<char> = reference.chars().collect();
167    let h: Vec<char> = hypothesis.chars().collect();
168    word_error_rate(&r, &h)
169}
170
// ─── Tests ───────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    /// Split a string into its `char`s so it can be passed to `align`.
    fn chars(s: &str) -> Vec<char> {
        s.chars().collect()
    }

    #[test]
    fn kitten_to_sitting_distance_three() {
        let a = chars("kitten");
        let b = chars("sitting");
        let al = align(&a, &b);
        assert_eq!(al.counts.distance(), 3);
    }

    #[test]
    fn kitten_to_sitting_op_breakdown() {
        // kitten → sitting: substitute k→s, substitute e→i, insert g.
        let a = chars("kitten");
        let b = chars("sitting");
        let al = align(&a, &b);
        assert_eq!(al.counts.substitutions, 2);
        assert_eq!(al.counts.insertions, 1);
        assert_eq!(al.counts.deletions, 0);
        assert_eq!(al.counts.matches, 4); // i, t, t, n
    }

    #[test]
    fn identical_sequences_all_matches() {
        let a = chars("hello");
        let al = align(&a, &a);
        assert_eq!(al.counts.distance(), 0);
        assert_eq!(al.counts.matches, 5);
        assert!(al.ops.iter().all(|op| matches!(op, EditOp::Match { .. })));
    }

    #[test]
    fn empty_source_is_all_insertions() {
        let a: Vec<char> = Vec::new();
        let b = chars("abc");
        let al = align(&a, &b);
        assert_eq!(al.counts.insertions, 3);
        assert_eq!(al.counts.distance(), 3);
        assert_eq!(al.ops.len(), 3);
    }

    #[test]
    fn empty_target_is_all_deletions() {
        let a = chars("abc");
        let b: Vec<char> = Vec::new();
        let al = align(&a, &b);
        assert_eq!(al.counts.deletions, 3);
        assert_eq!(al.counts.distance(), 3);
    }

    #[test]
    fn both_empty_is_no_ops() {
        let a: Vec<char> = Vec::new();
        let b: Vec<char> = Vec::new();
        let al = align(&a, &b);
        assert!(al.ops.is_empty());
        assert_eq!(al.counts.distance(), 0);
    }

    #[test]
    fn ops_reconstruct_target() {
        // Applying the script to the source must reproduce the target exactly.
        let a = chars("intention");
        let b = chars("execution");
        let al = align(&a, &b);
        let mut rebuilt: Vec<char> = Vec::new();
        for op in &al.ops {
            match *op {
                EditOp::Match { tgt, .. } | EditOp::Substitute { tgt, .. } => rebuilt.push(b[tgt]),
                EditOp::Insert { tgt } => rebuilt.push(b[tgt]),
                EditOp::Delete { .. } => {}
            }
        }
        assert_eq!(rebuilt, b);
    }

    #[test]
    fn ops_consume_source_in_order() {
        // The source indices touched by Match/Substitute/Delete must be 0..m in order.
        let a = chars("abcdef");
        let b = chars("azced");
        let al = align(&a, &b);
        let mut consumed: Vec<usize> = Vec::new();
        for op in &al.ops {
            match *op {
                EditOp::Match { src, .. }
                | EditOp::Substitute { src, .. }
                | EditOp::Delete { src } => consumed.push(src),
                EditOp::Insert { .. } => {}
            }
        }
        let expected: Vec<usize> = (0..a.len()).collect();
        assert_eq!(consumed, expected);
    }

    #[test]
    fn distance_matches_scalar_reference() {
        // Agreement with the existing scalar edit_distance on several pairs.
        let pairs = [
            ("flaw", "lawn"),
            ("gumbo", "gambol"),
            ("book", "back"),
            ("", "nonempty"),
            ("same", "same"),
        ];
        for (x, y) in pairs {
            let a = chars(x);
            let b = chars(y);
            let via_align = edit_distance_aligned(&a, &b);
            let via_scalar = crate::metrics::metrics::edit_distance(&a, &b);
            assert_eq!(via_align, via_scalar, "{x} vs {y}");
        }
    }

    #[test]
    fn op_count_equals_alignment_length_invariant() {
        // matches + subs is the number of aligned columns where both advance.
        let a = chars("alignment");
        let b = chars("assignment");
        let al = align(&a, &b);
        let c = al.counts;
        // Source length = matches + subs + deletions.
        assert_eq!(c.matches + c.substitutions + c.deletions, a.len());
        // Target length = matches + subs + insertions.
        assert_eq!(c.matches + c.substitutions + c.insertions, b.len());
    }

    #[test]
    fn word_error_rate_basic() {
        // ref: the cat sat ; hyp: the cat sit  → 1 substitution / 3 = 0.333…
        let r = vec!["the", "cat", "sat"];
        let h = vec!["the", "cat", "sit"];
        let wer = word_error_rate(&r, &h).expect("wer");
        assert!((wer - 1.0 / 3.0).abs() < 1e-9, "wer={wer}");
    }

    #[test]
    fn word_error_rate_perfect_is_zero() {
        let r = vec!["a", "b", "c"];
        let wer = word_error_rate(&r, &r).expect("wer");
        assert!(wer.abs() < 1e-12);
    }

    #[test]
    fn word_error_rate_empty_reference() {
        let r: Vec<&str> = Vec::new();
        let h = vec!["x", "y"];
        let wer = word_error_rate(&r, &h).expect("wer");
        assert!((wer - 2.0).abs() < 1e-12);
    }

    #[test]
    fn character_error_rate_string_api() {
        let cer = character_error_rate("kitten", "sitting").expect("cer");
        assert!((cer - 3.0 / 6.0).abs() < 1e-9, "cer={cer}");
    }

    #[test]
    fn works_on_token_ids() {
        let a = vec![1usize, 2, 3, 4];
        let b = vec![1usize, 3, 4];
        let al = align(&a, &b);
        // Deleting the token `2` is the single optimal edit.
        assert_eq!(al.counts.distance(), 1);
        assert_eq!(al.counts.deletions, 1);
    }
}