1use crate::error::SeqResult;
23
/// A single step of an edit script that turns a source sequence into a target sequence.
///
/// `src` and `tgt` are zero-based positions into the source and target slices
/// that were handed to [`align`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EditOp {
    /// Source element `src` and target element `tgt` compared equal; nothing to change.
    Match { src: usize, tgt: usize },
    /// Source element `src` is replaced by target element `tgt`.
    Substitute { src: usize, tgt: usize },
    /// Source element `src` has no counterpart in the target and is dropped.
    Delete { src: usize },
    /// Target element `tgt` has no counterpart in the source and is added.
    Insert { tgt: usize },
}
36
/// Per-kind tallies of the operations in an edit script.
///
/// The `Default` value has every counter at zero.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub struct EditCounts {
    /// Number of `Match` operations.
    pub matches: usize,
    /// Number of `Substitute` operations.
    pub substitutions: usize,
    /// Number of `Delete` operations.
    pub deletions: usize,
    /// Number of `Insert` operations.
    pub insertions: usize,
}

impl EditCounts {
    /// Returns the edit distance implied by these tallies.
    ///
    /// Only substitutions, deletions and insertions contribute; matches are free.
    #[must_use]
    pub fn distance(&self) -> usize {
        let Self {
            substitutions,
            deletions,
            insertions,
            ..
        } = *self;
        [substitutions, deletions, insertions].iter().sum()
    }
}
57
58#[derive(Debug, Clone, PartialEq, Eq)]
60pub struct EditAlignment {
61 pub ops: Vec<EditOp>,
63 pub counts: EditCounts,
65}
66
67pub fn align<T: Eq>(a: &[T], b: &[T]) -> EditAlignment {
72 let m = a.len();
73 let n = b.len();
74 let cols = n + 1;
75 let mut dp = vec![0usize; (m + 1) * cols];
77 for i in 0..=m {
78 dp[i * cols] = i;
79 }
80 for j in 0..=n {
81 dp[j] = j;
82 }
83 for i in 1..=m {
84 for j in 1..=n {
85 let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
86 let del = dp[(i - 1) * cols + j] + 1;
87 let ins = dp[i * cols + (j - 1)] + 1;
88 let sub = dp[(i - 1) * cols + (j - 1)] + cost;
89 dp[i * cols + j] = del.min(ins).min(sub);
90 }
91 }
92
93 let mut ops_rev: Vec<EditOp> = Vec::new();
95 let mut counts = EditCounts::default();
96 let mut i = m;
97 let mut j = n;
98 while i > 0 || j > 0 {
99 let here = dp[i * cols + j];
100 if i > 0 && j > 0 {
102 let cost = if a[i - 1] == b[j - 1] { 0 } else { 1 };
103 if here == dp[(i - 1) * cols + (j - 1)] + cost {
104 if cost == 0 {
105 ops_rev.push(EditOp::Match {
106 src: i - 1,
107 tgt: j - 1,
108 });
109 counts.matches += 1;
110 } else {
111 ops_rev.push(EditOp::Substitute {
112 src: i - 1,
113 tgt: j - 1,
114 });
115 counts.substitutions += 1;
116 }
117 i -= 1;
118 j -= 1;
119 continue;
120 }
121 }
122 if i > 0 && here == dp[(i - 1) * cols + j] + 1 {
123 ops_rev.push(EditOp::Delete { src: i - 1 });
124 counts.deletions += 1;
125 i -= 1;
126 continue;
127 }
128 ops_rev.push(EditOp::Insert { tgt: j - 1 });
130 counts.insertions += 1;
131 j -= 1;
132 }
133
134 ops_rev.reverse();
135 EditAlignment {
136 ops: ops_rev,
137 counts,
138 }
139}
140
141pub fn edit_distance_aligned<T: Eq>(a: &[T], b: &[T]) -> usize {
145 align(a, b).counts.distance()
146}
147
148pub fn word_error_rate<T: Eq>(reference: &[T], hypothesis: &[T]) -> SeqResult<f64> {
154 let counts = align(reference, hypothesis).counts;
155 let n = reference.len();
156 if n == 0 {
157 return Ok(counts.distance() as f64);
160 }
161 Ok(counts.distance() as f64 / n as f64)
162}
163
164pub fn character_error_rate(reference: &str, hypothesis: &str) -> SeqResult<f64> {
166 let r: Vec<char> = reference.chars().collect();
167 let h: Vec<char> = hypothesis.chars().collect();
168 word_error_rate(&r, &h)
169}
170
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a string into its `char`s so it can be fed to the slice-based API.
    fn chars(s: &str) -> Vec<char> {
        s.chars().collect()
    }

    #[test]
    fn kitten_to_sitting_distance_three() {
        let alignment = align(&chars("kitten"), &chars("sitting"));
        assert_eq!(alignment.counts.distance(), 3);
    }

    #[test]
    fn kitten_to_sitting_op_breakdown() {
        let EditCounts {
            matches,
            substitutions,
            deletions,
            insertions,
        } = align(&chars("kitten"), &chars("sitting")).counts;
        assert_eq!(substitutions, 2);
        assert_eq!(insertions, 1);
        assert_eq!(deletions, 0);
        assert_eq!(matches, 4);
    }

    #[test]
    fn identical_sequences_all_matches() {
        let word = chars("hello");
        let alignment = align(&word, &word);
        assert_eq!(alignment.counts.distance(), 0);
        assert_eq!(alignment.counts.matches, 5);
        for op in &alignment.ops {
            assert!(matches!(op, EditOp::Match { .. }));
        }
    }

    #[test]
    fn empty_source_is_all_insertions() {
        let alignment = align::<char>(&[], &chars("abc"));
        assert_eq!(alignment.counts.insertions, 3);
        assert_eq!(alignment.counts.distance(), 3);
        assert_eq!(alignment.ops.len(), 3);
    }

    #[test]
    fn empty_target_is_all_deletions() {
        let alignment = align::<char>(&chars("abc"), &[]);
        assert_eq!(alignment.counts.deletions, 3);
        assert_eq!(alignment.counts.distance(), 3);
    }

    #[test]
    fn both_empty_is_no_ops() {
        let alignment = align::<char>(&[], &[]);
        assert!(alignment.ops.is_empty());
        assert_eq!(alignment.counts.distance(), 0);
    }

    #[test]
    fn ops_reconstruct_target() {
        let source = chars("intention");
        let target = chars("execution");
        // Every op that touches the target contributes exactly one target element.
        let rebuilt: Vec<char> = align(&source, &target)
            .ops
            .iter()
            .filter_map(|op| match *op {
                EditOp::Match { tgt, .. }
                | EditOp::Substitute { tgt, .. }
                | EditOp::Insert { tgt } => Some(target[tgt]),
                EditOp::Delete { .. } => None,
            })
            .collect();
        assert_eq!(rebuilt, target);
    }

    #[test]
    fn ops_consume_source_in_order() {
        let source = chars("abcdef");
        let target = chars("azced");
        // Every op that touches the source must visit indices 0, 1, 2, ... exactly once.
        let consumed: Vec<usize> = align(&source, &target)
            .ops
            .iter()
            .filter_map(|op| match *op {
                EditOp::Match { src, .. }
                | EditOp::Substitute { src, .. }
                | EditOp::Delete { src } => Some(src),
                EditOp::Insert { .. } => None,
            })
            .collect();
        let expected: Vec<usize> = (0..source.len()).collect();
        assert_eq!(consumed, expected);
    }

    #[test]
    fn distance_matches_scalar_reference() {
        let pairs = [
            ("flaw", "lawn"),
            ("gumbo", "gambol"),
            ("book", "back"),
            ("", "nonempty"),
            ("same", "same"),
        ];
        for &(x, y) in &pairs {
            let (a, b) = (chars(x), chars(y));
            let via_align = edit_distance_aligned(&a, &b);
            let via_scalar = crate::metrics::metrics::edit_distance(&a, &b);
            assert_eq!(via_align, via_scalar, "{x} vs {y}");
        }
    }

    #[test]
    fn op_count_equals_alignment_length_invariant() {
        let source = chars("alignment");
        let target = chars("assignment");
        let EditCounts {
            matches,
            substitutions,
            deletions,
            insertions,
        } = align(&source, &target).counts;
        // Each source element is matched, substituted or deleted exactly once.
        assert_eq!(matches + substitutions + deletions, source.len());
        // Each target element is matched, substituted or inserted exactly once.
        assert_eq!(matches + substitutions + insertions, target.len());
    }

    #[test]
    fn word_error_rate_basic() {
        let reference: &[&str] = &["the", "cat", "sat"];
        let hypothesis: &[&str] = &["the", "cat", "sit"];
        let wer = word_error_rate(reference, hypothesis).expect("wer");
        assert!((wer - 1.0 / 3.0).abs() < 1e-9, "wer={wer}");
    }

    #[test]
    fn word_error_rate_perfect_is_zero() {
        let reference: &[&str] = &["a", "b", "c"];
        let wer = word_error_rate(reference, reference).expect("wer");
        assert!(wer.abs() < 1e-12);
    }

    #[test]
    fn word_error_rate_empty_reference() {
        let reference: &[&str] = &[];
        let hypothesis: &[&str] = &["x", "y"];
        let wer = word_error_rate(reference, hypothesis).expect("wer");
        assert!((wer - 2.0).abs() < 1e-12);
    }

    #[test]
    fn character_error_rate_string_api() {
        let cer = character_error_rate("kitten", "sitting").expect("cer");
        assert!((cer - 3.0 / 6.0).abs() < 1e-9, "cer={cer}");
    }

    #[test]
    fn works_on_token_ids() {
        let source: &[usize] = &[1, 2, 3, 4];
        let target: &[usize] = &[1, 3, 4];
        let counts = align(source, target).counts;
        assert_eq!(counts.distance(), 1);
        assert_eq!(counts.deletions, 1);
    }
}