1use crate::error::{SeqError, SeqResult};
36use crate::metrics::edit_distance::{EditOp, align};
37
/// Outcome of a shift-aware edit-rate computation, as produced by `ter_tokens`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TerResult {
    /// Normalised cost: `(num_edits + num_shifts) / ref_len`.
    pub score: f64,
    /// Edit distance between the hypothesis (after all accepted shifts) and the reference.
    pub num_edits: usize,
    /// Number of block shifts that were applied to the hypothesis.
    pub num_shifts: usize,
    /// Number of tokens in the reference; the denominator of `score`.
    pub ref_len: usize,
}
50
/// Upper bound on the number of tokens in a block that `best_shift` will try to move.
const MAX_SHIFT_LEN: usize = 10;
56
57fn edit_distance<T: Eq>(a: &[T], b: &[T]) -> usize {
59 align(a, b).counts.distance()
60}
61
62fn aligned_mask<T: Eq>(hyp: &[T], ref_: &[T]) -> Vec<bool> {
67 let mut mask = vec![false; hyp.len()];
68 for op in align(hyp, ref_).ops {
69 if let EditOp::Match { src, .. } = op {
70 if src < mask.len() {
71 mask[src] = true;
72 }
73 }
74 }
75 mask
76}
77
/// Returns a copy of `seq` in which the block `seq[from..from + len]` has been moved.
///
/// `to` is an index into the sequence *with the block removed*: the block is
/// re-inserted so that it starts at position `to` of that shortened sequence.
///
/// # Panics
///
/// Panics if `from + len` exceeds `seq.len()` or if `to` exceeds `seq.len() - len`.
fn apply_shift<T: Clone>(seq: &[T], from: usize, len: usize, to: usize) -> Vec<T> {
    let block_end = from + len;
    let moved = &seq[from..block_end];
    // Everything except the moved block, in original order.
    let remainder: Vec<T> = seq[..from]
        .iter()
        .chain(seq[block_end..].iter())
        .cloned()
        .collect();
    let (before, after) = remainder.split_at(to);
    [before, moved, after].concat()
}
91
92fn best_shift<T: Eq + Clone>(
97 hyp: &[T],
98 ref_: &[T],
99 current: usize,
100) -> Option<(usize, usize, usize, usize)> {
101 let h = hyp.len();
102 if h == 0 {
103 return None;
104 }
105 let mask = aligned_mask(hyp, ref_);
106
107 let mut best: Option<(usize, usize, usize, usize)> = None;
108 let max_len = MAX_SHIFT_LEN.min(h);
110 for len in 1..=max_len {
111 for from in 0..=h - len {
112 let block_aligned = (from..from + len).all(|p| mask[p]);
115 if block_aligned {
116 continue;
117 }
118 if !occurs_in(ref_, &hyp[from..from + len]) {
120 continue;
121 }
122 let rest_len = h - len;
124 for to in 0..=rest_len {
125 if to == from {
127 continue;
128 }
129 let shifted = apply_shift(hyp, from, len, to);
130 let dist = edit_distance(&shifted, ref_);
131 if dist + 1 < current {
133 let better = match best {
134 None => true,
135 Some((_, _, _, bd)) => dist < bd,
136 };
137 if better {
138 best = Some((from, len, to, dist));
139 }
140 }
141 }
142 }
143 }
144 best
145}
146
/// Reports whether `needle` appears as a contiguous run inside `haystack`.
///
/// An empty `needle` never matches, and neither does one longer than `haystack`.
fn occurs_in<T: Eq>(haystack: &[T], needle: &[T]) -> bool {
    // `windows` panics on a zero size, so the emptiness check must come first;
    // it yields nothing when the window is longer than the slice.
    !needle.is_empty()
        && haystack
            .windows(needle.len())
            .any(|window| window == needle)
}
160
161fn ter_tokens<T: Eq + Clone>(hyp: &[T], ref_: &[T]) -> SeqResult<TerResult> {
168 let ref_len = ref_.len();
169 if ref_len == 0 {
170 return Err(SeqError::EmptyInput);
171 }
172
173 let mut current_hyp: Vec<T> = hyp.to_vec();
174 let mut num_shifts = 0usize;
175 let mut current_dist = edit_distance(¤t_hyp, ref_);
176
177 loop {
183 if current_dist == 0 {
184 break;
185 }
186 match best_shift(¤t_hyp, ref_, current_dist) {
187 Some((from, len, to, new_dist)) => {
188 current_hyp = apply_shift(¤t_hyp, from, len, to);
189 current_dist = new_dist;
190 num_shifts += 1;
191 }
192 None => break,
193 }
194 }
195
196 let num_edits = current_dist;
197 let score = (num_edits + num_shifts) as f64 / ref_len as f64;
198 Ok(TerResult {
199 score,
200 num_edits,
201 num_shifts,
202 ref_len,
203 })
204}
205
/// Computes the shift-aware edit rate of a tokenised hypothesis against a tokenised reference.
///
/// Both arguments are slices of string tokens; the work is delegated to `ter_tokens`.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when `ref_` is empty.
pub fn ter(hyp: &[&str], ref_: &[&str]) -> SeqResult<TerResult> {
    ter_tokens(hyp, ref_)
}
213
/// Same computation as `ter`, for sequences of integer token ids.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when `ref_` is empty.
pub fn ter_ids(hyp: &[usize], ref_: &[usize]) -> SeqResult<TerResult> {
    ter_tokens(hyp, ref_)
}
218
// Unit tests for the shift-aware edit-rate functions `ter` and `ter_ids`.
#[cfg(test)]
mod tests {
    use super::*;

    /// Splits a sentence on whitespace into borrowed word tokens.
    fn words(s: &str) -> Vec<&str> {
        s.split_whitespace().collect()
    }

    // Identical hypothesis and reference: no edits, no shifts, score zero.
    #[test]
    fn identical_sentences_score_zero() {
        let h = words("the cat sat on the mat");
        let r = words("the cat sat on the mat");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_edits, 0);
        assert_eq!(res.num_shifts, 0);
        assert!(res.score.abs() < 1e-12, "score={}", res.score);
        assert_eq!(res.ref_len, 6);
    }

    // Two tokens differ and neither replacement occurs in the reference,
    // so no shift can help: the cost is two edits over five reference tokens.
    #[test]
    fn pure_substitutions_no_shifts() {
        let r = words("a b c d e");
        let h = words("a x c y e");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_shifts, 0);
        assert_eq!(res.num_edits, 2);
        assert!((res.score - 2.0 / 5.0).abs() < 1e-12, "score={}", res.score);
    }

    // The hypothesis is the reference with the trailing `E F` block moved to
    // the front; one shift should restore it and beat the plain edit distance.
    #[test]
    fn block_reordering_finds_shift_and_lowers_score() {
        let r = words("A B C D E F");
        let h = words("E F A B C D");
        // Baseline: edit distance without any shifting, normalised the same way.
        let no_shift = align(&h, &r).counts.distance() as f64 / r.len() as f64;
        let res = ter(&h, &r).expect("ter");
        assert!(
            res.num_shifts >= 1,
            "expected a shift, got {}",
            res.num_shifts
        );
        assert!(
            res.score < no_shift - 1e-12,
            "shifted score {} should beat no-shift {}",
            res.score,
            no_shift
        );
        assert_eq!(res.num_edits, 0);
        assert_eq!(res.num_shifts, 1);
        assert!((res.score - 1.0 / 6.0).abs() < 1e-12, "score={}", res.score);
    }

    // Swapping two adjacent tokens is expected to cost exactly one shift.
    #[test]
    fn single_swap_is_one_shift() {
        let r = words("a b c");
        let h = words("b a c");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_edits, 0);
        assert_eq!(res.num_shifts, 1);
        assert!((res.score - 1.0 / 3.0).abs() < 1e-12);
    }

    // The hypothesis has one extra token relative to the reference.
    #[test]
    fn insertions_counted() {
        let r = words("the quick fox");
        let h = words("the quick brown fox");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_shifts, 0);
        assert_eq!(res.num_edits, 1);
        assert!((res.score - 1.0 / 3.0).abs() < 1e-12);
    }

    // The hypothesis is missing one token relative to the reference.
    #[test]
    fn deletions_counted() {
        let r = words("the quick brown fox");
        let h = words("the quick fox");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_shifts, 0);
        assert_eq!(res.num_edits, 1);
        assert!((res.score - 1.0 / 4.0).abs() < 1e-12);
    }

    // The same single error yields a different score depending on reference length.
    #[test]
    fn normalisation_by_reference_length() {
        let r_short = words("a b");
        let h_short = words("a x");
        let res_short = ter(&h_short, &r_short).expect("ter");
        assert!((res_short.score - 1.0 / 2.0).abs() < 1e-12);

        let r_long = words("a b c d");
        let h_long = words("a x c d");
        let res_long = ter(&h_long, &r_long).expect("ter");
        assert!((res_long.score - 1.0 / 4.0).abs() < 1e-12);
    }

    // An empty reference cannot be normalised against and must be rejected.
    #[test]
    fn empty_reference_is_error() {
        let h = words("a b c");
        let r: Vec<&str> = Vec::new();
        assert!(ter(&h, &r).is_err());
    }

    // An empty hypothesis needs one edit per reference token, giving a score of 1.
    #[test]
    fn empty_hypothesis_against_reference() {
        let h: Vec<&str> = Vec::new();
        let r = words("a b c");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_shifts, 0);
        assert_eq!(res.num_edits, 3);
        assert!((res.score - 1.0).abs() < 1e-12);
    }

    // Same shape as the block-reordering case, expressed with integer ids.
    #[test]
    fn token_id_variant_matches_string_variant() {
        let h_ids = vec![4usize, 5, 0, 1, 2, 3];
        let r_ids = vec![0usize, 1, 2, 3, 4, 5];
        let res = ter_ids(&h_ids, &r_ids).expect("ter");
        assert_eq!(res.num_edits, 0);
        assert_eq!(res.num_shifts, 1);
        assert!((res.score - 1.0 / 6.0).abs() < 1e-12);
    }

    // Shifting is only accepted when it lowers the total cost, so the final
    // score must never exceed the plain normalised edit distance.
    #[test]
    fn shift_never_increases_total_cost() {
        // (hypothesis, reference) pairs covering several reorderings.
        let cases = [
            ("the cat sat", "the sat cat"),
            ("one two three four", "four three two one"),
            ("a b c d e", "b c d e a"),
            ("hello world foo bar", "foo bar hello world"),
        ];
        for (hs, rs) in cases {
            let h = words(hs);
            let r = words(rs);
            let baseline = align(&h, &r).counts.distance() as f64 / r.len() as f64;
            let res = ter(&h, &r).expect("ter");
            assert!(
                res.score <= baseline + 1e-12,
                "case ({hs} | {rs}): ter {} > baseline {}",
                res.score,
                baseline
            );
        }
    }

    // The `a b` block sits two positions away from where the reference has it;
    // a single shift should still account for the whole difference.
    #[test]
    fn far_block_move_is_single_shift() {
        let r = words("w x a b c d y z");
        let h = words("a b w x c d y z");
        let res = ter(&h, &r).expect("ter");
        assert_eq!(res.num_edits, 0);
        assert_eq!(res.num_shifts, 1);
        assert!((res.score - 1.0 / 8.0).abs() < 1e-12, "score={}", res.score);
    }
}