use crate::error::{SeqError, SeqResult};
use crate::metrics::edit_distance::{EditOp, align};
/// Outcome of a TER-style comparison between a hypothesis and a reference.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct TerResult {
/// `(num_edits + num_shifts) / ref_len`, computed in floating point.
pub score: f64,
/// Edit distance left between the (shifted) hypothesis and the reference.
pub num_edits: usize,
/// Number of block shifts that were applied to the hypothesis.
pub num_shifts: usize,
/// Number of tokens in the reference; the normaliser for `score`.
pub ref_len: usize,
}
/// Upper bound on the length (in tokens) of a block that `best_shift` will
/// try to move; the effective bound is the smaller of this and the
/// hypothesis length.
const MAX_SHIFT_LEN: usize = 10;
/// Returns the edit distance between `a` and `b`, taken from the operation
/// counts of their alignment.
fn edit_distance<T: Eq>(a: &[T], b: &[T]) -> usize {
    let alignment = align(a, b);
    alignment.counts.distance()
}
fn aligned_mask<T: Eq>(hyp: &[T], ref_: &[T]) -> Vec<bool> {
let mut mask = vec![false; hyp.len()];
for op in align(hyp, ref_).ops {
if let EditOp::Match { src, .. } = op {
if src < mask.len() {
mask[src] = true;
}
}
}
mask
}
/// Returns a copy of `seq` with the block `seq[from..from + len]` moved.
///
/// The block is first removed; `to` is then an index into the *remaining*
/// sequence, and the block is re-inserted in front of the element at that
/// index (or at the end when `to` equals the remaining length).
///
/// # Panics
///
/// Panics if `from + len` exceeds `seq.len()` or if `to` exceeds
/// `seq.len() - len`.
fn apply_shift<T: Clone>(seq: &[T], from: usize, len: usize, to: usize) -> Vec<T> {
    let block_end = from + len;
    let block = &seq[from..block_end];
    // Everything except the block, in original order.
    let remainder: Vec<T> = seq[..from]
        .iter()
        .chain(seq[block_end..].iter())
        .cloned()
        .collect();
    let (before, after) = remainder.split_at(to);
    [before, block, after].concat()
}
/// Searches for the single block shift of `hyp` that most reduces the edit
/// distance to `ref_`.
///
/// `current` is the edit distance of the unshifted `hyp`. A candidate is only
/// accepted when its resulting distance plus one (the cost of the shift
/// itself) is strictly below `current`. Candidate blocks are at most
/// `MAX_SHIFT_LEN` tokens long, must contain at least one position that is
/// not matched in the current alignment, and must occur contiguously
/// somewhere in `ref_`.
///
/// Returns `(from, len, to, dist)` for the winning shift, where `from`/`len`
/// locate the block in `hyp`, `to` is the insertion index in the sequence
/// with the block removed (as expected by `apply_shift`), and `dist` is the
/// edit distance after the shift. Among candidates with equal `dist`, the
/// first one encountered (shorter `len`, then smaller `from`, then smaller
/// `to`) wins. Returns `None` when `hyp` is empty or no shift qualifies.
fn best_shift<T: Eq + Clone>(
    hyp: &[T],
    ref_: &[T],
    current: usize,
) -> Option<(usize, usize, usize, usize)> {
    let hyp_len = hyp.len();
    if hyp_len == 0 {
        return None;
    }
    let matched = aligned_mask(hyp, ref_);
    let mut winner: Option<(usize, usize, usize, usize)> = None;
    for len in 1..=MAX_SHIFT_LEN.min(hyp_len) {
        // Length of the hypothesis once the block is lifted out; this is both
        // the last valid block start and the last valid insertion index.
        let remaining = hyp_len - len;
        for from in 0..=remaining {
            let span = from..from + len;
            // A block whose every token already matches has nothing to gain.
            let has_unmatched = matched[span.clone()].iter().any(|&ok| !ok);
            if !has_unmatched || !occurs_in(ref_, &hyp[span]) {
                continue;
            }
            // `to == from` would put the block back where it came from.
            for to in (0..=remaining).filter(|&to| to != from) {
                let dist = edit_distance(&apply_shift(hyp, from, len, to), ref_);
                if dist + 1 >= current {
                    continue;
                }
                let improves = winner.map_or(true, |(_, _, _, best_dist)| dist < best_dist);
                if improves {
                    winner = Some((from, len, to, dist));
                }
            }
        }
    }
    winner
}
/// Reports whether `needle` appears as a contiguous run inside `haystack`.
///
/// An empty `needle` is never considered to occur, and a `needle` longer
/// than `haystack` trivially does not.
fn occurs_in<T: Eq>(haystack: &[T], needle: &[T]) -> bool {
    match needle.len() {
        0 => false,
        // `windows` yields nothing when `n` exceeds the haystack length.
        n => haystack.windows(n).any(|window| window == needle),
    }
}
/// Computes a TER-style score for `hyp` against `ref_` over arbitrary tokens.
///
/// Starting from the plain edit distance, the function repeatedly asks
/// `best_shift` for the most beneficial block move, applies it, and counts
/// it as one shift. It stops when the distance reaches zero or no further
/// shift qualifies. Because `best_shift` only accepts a shift whose resulting
/// distance plus one is strictly below the current distance, the distance
/// strictly decreases on every iteration and the loop terminates.
///
/// The returned score is `(num_edits + num_shifts) / ref_len`.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when `ref_` is empty, since the score
/// would otherwise divide by zero.
fn ter_tokens<T: Eq + Clone>(hyp: &[T], ref_: &[T]) -> SeqResult<TerResult> {
    let ref_len = ref_.len();
    if ref_len == 0 {
        return Err(SeqError::EmptyInput);
    }
    let mut current_hyp: Vec<T> = hyp.to_vec();
    let mut num_shifts = 0usize;
    let mut current_dist = edit_distance(&current_hyp, ref_);
    while current_dist > 0 {
        match best_shift(&current_hyp, ref_, current_dist) {
            Some((from, len, to, new_dist)) => {
                current_hyp = apply_shift(&current_hyp, from, len, to);
                current_dist = new_dist;
                num_shifts += 1;
            }
            // No remaining shift pays for its own cost.
            None => break,
        }
    }
    let num_edits = current_dist;
    let score = (num_edits + num_shifts) as f64 / ref_len as f64;
    Ok(TerResult {
        score,
        num_edits,
        num_shifts,
        ref_len,
    })
}
/// Scores a hypothesis against a reference, both given as string tokens.
///
/// Delegates to `ter_tokens`.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when `ref_` is empty.
pub fn ter(hyp: &[&str], ref_: &[&str]) -> SeqResult<TerResult> {
    ter_tokens::<&str>(hyp, ref_)
}
/// Scores a hypothesis against a reference, both given as token ids.
///
/// Delegates to `ter_tokens`.
///
/// # Errors
///
/// Returns `SeqError::EmptyInput` when `ref_` is empty.
pub fn ter_ids(hyp: &[usize], ref_: &[usize]) -> SeqResult<TerResult> {
    ter_tokens::<usize>(hyp, ref_)
}
// Unit tests for the public `ter` / `ter_ids` entry points.
#[cfg(test)]
mod tests {
use super::*;
// Splits a sentence on whitespace into borrowed word tokens.
fn words(s: &str) -> Vec<&str> {
s.split_whitespace().collect()
}
// Identical inputs need neither edits nor shifts, so the score is zero.
#[test]
fn identical_sentences_score_zero() {
let h = words("the cat sat on the mat");
let r = words("the cat sat on the mat");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_edits, 0);
assert_eq!(res.num_shifts, 0);
assert!(res.score.abs() < 1e-12, "score={}", res.score);
assert_eq!(res.ref_len, 6);
}
// Two in-place substitutions: the replaced words do not occur in the
// reference, so no shift is expected.
#[test]
fn pure_substitutions_no_shifts() {
let r = words("a b c d e");
let h = words("a x c y e");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_shifts, 0);
assert_eq!(res.num_edits, 2);
assert!((res.score - 2.0 / 5.0).abs() < 1e-12, "score={}", res.score);
}
// A rotated hypothesis is expected to be repaired by exactly one shift,
// and the result must beat the plain edit-distance baseline.
#[test]
fn block_reordering_finds_shift_and_lowers_score() {
let r = words("A B C D E F");
let h = words("E F A B C D");
let no_shift = align(&h, &r).counts.distance() as f64 / r.len() as f64;
let res = ter(&h, &r).expect("ter");
assert!(
res.num_shifts >= 1,
"expected a shift, got {}",
res.num_shifts
);
assert!(
res.score < no_shift - 1e-12,
"shifted score {} should beat no-shift {}",
res.score,
no_shift
);
assert_eq!(res.num_edits, 0);
assert_eq!(res.num_shifts, 1);
assert!((res.score - 1.0 / 6.0).abs() < 1e-12, "score={}", res.score);
}
// Swapping two adjacent words counts as one shift and zero edits.
#[test]
fn single_swap_is_one_shift() {
let r = words("a b c");
let h = words("b a c");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_edits, 0);
assert_eq!(res.num_shifts, 1);
assert!((res.score - 1.0 / 3.0).abs() < 1e-12);
}
// One extra hypothesis word costs one edit, normalised by reference length 3.
#[test]
fn insertions_counted() {
let r = words("the quick fox");
let h = words("the quick brown fox");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_shifts, 0);
assert_eq!(res.num_edits, 1);
assert!((res.score - 1.0 / 3.0).abs() < 1e-12);
}
// One missing hypothesis word costs one edit, normalised by reference length 4.
#[test]
fn deletions_counted() {
let r = words("the quick brown fox");
let h = words("the quick fox");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_shifts, 0);
assert_eq!(res.num_edits, 1);
assert!((res.score - 1.0 / 4.0).abs() < 1e-12);
}
// The same single edit yields a smaller score against a longer reference.
#[test]
fn normalisation_by_reference_length() {
let r_short = words("a b");
let h_short = words("a x");
let res_short = ter(&h_short, &r_short).expect("ter");
assert!((res_short.score - 1.0 / 2.0).abs() < 1e-12);
let r_long = words("a b c d");
let h_long = words("a x c d");
let res_long = ter(&h_long, &r_long).expect("ter");
assert!((res_long.score - 1.0 / 4.0).abs() < 1e-12);
}
// An empty reference is rejected with an error.
#[test]
fn empty_reference_is_error() {
let h = words("a b c");
let r: Vec<&str> = Vec::new();
assert!(ter(&h, &r).is_err());
}
// An empty hypothesis needs one edit per reference word, giving a score of 1.
#[test]
fn empty_hypothesis_against_reference() {
let h: Vec<&str> = Vec::new();
let r = words("a b c");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_shifts, 0);
assert_eq!(res.num_edits, 3);
assert!((res.score - 1.0).abs() < 1e-12);
}
// `ter_ids` on a rotated id sequence gives the same one-shift result as the
// string-based rotation test above.
#[test]
fn token_id_variant_matches_string_variant() {
let h_ids = vec![4usize, 5, 0, 1, 2, 3];
let r_ids = vec![0usize, 1, 2, 3, 4, 5];
let res = ter_ids(&h_ids, &r_ids).expect("ter");
assert_eq!(res.num_edits, 0);
assert_eq!(res.num_shifts, 1);
assert!((res.score - 1.0 / 6.0).abs() < 1e-12);
}
// For several reorderings, the score with shifts must never exceed the
// plain edit-distance baseline.
#[test]
fn shift_never_increases_total_cost() {
let cases = [
("the cat sat", "the sat cat"),
("one two three four", "four three two one"),
("a b c d e", "b c d e a"),
("hello world foo bar", "foo bar hello world"),
];
for (hs, rs) in cases {
let h = words(hs);
let r = words(rs);
let baseline = align(&h, &r).counts.distance() as f64 / r.len() as f64;
let res = ter(&h, &r).expect("ter");
assert!(
res.score <= baseline + 1e-12,
"case ({hs} | {rs}): ter {} > baseline {}",
res.score,
baseline
);
}
}
// Two adjacent two-word blocks in swapped order are repaired by one shift.
#[test]
fn far_block_move_is_single_shift() {
let r = words("w x a b c d y z");
let h = words("a b w x c d y z");
let res = ter(&h, &r).expect("ter");
assert_eq!(res.num_edits, 0);
assert_eq!(res.num_shifts, 1);
assert!((res.score - 1.0 / 8.0).abs() < 1e-12, "score={}", res.score);
}
}