use std::collections::HashMap;
use crate::error::{SeqError, SeqResult};
/// Counts every overlapping character n-gram of length `n` in `text`.
///
/// N-grams are built from `char`s (Unicode scalar values) rather than bytes,
/// and whitespace characters are counted like any other character.
///
/// Returns an empty map when `n` is zero or when `text` holds fewer than `n`
/// characters.
pub fn char_ngram_counts(text: &str, n: usize) -> HashMap<String, usize> {
    let mut tally = HashMap::new();
    if n == 0 {
        // `windows` panics on a zero width, and there are no 0-grams to count.
        return tally;
    }
    let symbols: Vec<char> = text.chars().collect();
    // `windows(n)` yields nothing when the text is shorter than `n`.
    for window in symbols.windows(n) {
        let key: String = window.iter().collect();
        *tally.entry(key).or_default() += 1;
    }
    tally
}
/// Counts every overlapping word n-gram of length `n` in `text`.
///
/// Words are the whitespace-separated tokens of `text`, lower-cased before
/// counting. Each n-gram key is its `n` tokens joined by a single space.
///
/// Returns an empty map when `n` is zero or when `text` holds fewer than `n`
/// words.
pub fn word_ngram_counts(text: &str, n: usize) -> HashMap<String, usize> {
    let mut tally = HashMap::new();
    if n == 0 {
        // `windows` panics on a zero width, and there are no 0-grams to count.
        return tally;
    }
    let tokens: Vec<String> = text.split_whitespace().map(str::to_lowercase).collect();
    // `windows(n)` yields nothing when there are fewer than `n` tokens.
    for window in tokens.windows(n) {
        *tally.entry(window.join(" ")).or_default() += 1;
    }
    tally
}
/// Returns the clipped overlap between two n-gram count maps.
///
/// For every n-gram present in both `a` and `b`, the smaller of its two
/// counts is added to the total. N-grams found in only one map contribute
/// nothing.
pub fn ngram_intersection(a: &HashMap<String, usize>, b: &HashMap<String, usize>) -> usize {
    let mut shared = 0usize;
    for (ngram, count_a) in a {
        if let Some(count_b) = b.get(ngram) {
            shared += (*count_a).min(*count_b);
        }
    }
    shared
}
/// Combines `precision` and `recall` into an F-beta score.
///
/// `beta` weights recall against precision: `beta == 1.0` gives the harmonic
/// mean, larger values favour recall. A small epsilon is added to the
/// denominator, and the score is defined as `0.0` when `precision + recall`
/// is below that epsilon.
pub fn f_beta(precision: f64, recall: f64, beta: f64) -> f64 {
    const EPS: f64 = 1e-15;
    // Keep the `<` comparison: a NaN sum must fall through to the formula
    // (and propagate) rather than being reported as zero.
    if precision + recall < EPS {
        0.0
    } else {
        let beta_sq = beta * beta;
        let numerator = (1.0 + beta_sq) * precision * recall;
        let denominator = beta_sq * precision + recall + EPS;
        numerator / denominator
    }
}
/// Computes a sentence-level chrF score for `hypothesis` against `reference`.
///
/// For each character n-gram order from 1 to `max_char_n`, precision and
/// recall are derived from the clipped n-gram overlap and combined with
/// [`f_beta`]. The result is the mean of those per-order F-scores. Orders for
/// which neither text has any n-grams are left out of the mean.
///
/// Returns `0.0` when `max_char_n` is zero or when no order could be scored.
pub fn chrf_score(hypothesis: &str, reference: &str, max_char_n: usize, beta: f64) -> f64 {
    // An empty range (`max_char_n == 0`) simply produces zero scored orders.
    let (f_sum, scored_orders) = (1..=max_char_n)
        .filter_map(|order| {
            let hyp_ngrams = char_ngram_counts(hypothesis, order);
            let ref_ngrams = char_ngram_counts(reference, order);
            let hyp_len = hyp_ngrams.values().sum::<usize>();
            let ref_len = ref_ngrams.values().sum::<usize>();
            if hyp_len == 0 && ref_len == 0 {
                // Both texts are too short for this order: nothing to score.
                return None;
            }
            let overlap = ngram_intersection(&hyp_ngrams, &ref_ngrams) as f64;
            // The epsilon keeps the division finite when one side is empty.
            let precision = overlap / (hyp_len as f64 + 1e-15);
            let recall = overlap / (ref_len as f64 + 1e-15);
            Some(f_beta(precision, recall, beta))
        })
        .fold((0.0_f64, 0usize), |(sum, count), score| (sum + score, count + 1));

    if scored_orders == 0 {
        0.0
    } else {
        f_sum / scored_orders as f64
    }
}
pub fn chrf_plus_plus(hypothesis: &str, reference: &str, max_char_n: usize, beta: f64) -> f64 {
let mut f_scores: Vec<f64> = Vec::new();
for n in 1..=max_char_n {
let hyp_counts = char_ngram_counts(hypothesis, n);
let ref_counts = char_ngram_counts(reference, n);
let hyp_total: usize = hyp_counts.values().sum();
let ref_total: usize = ref_counts.values().sum();
if hyp_total == 0 && ref_total == 0 {
continue;
}
let matches = ngram_intersection(&hyp_counts, &ref_counts);
let prec = matches as f64 / (hyp_total as f64 + 1e-15);
let rec = matches as f64 / (ref_total as f64 + 1e-15);
f_scores.push(f_beta(prec, rec, beta));
}
for n in [1usize, 2] {
let hyp_counts = word_ngram_counts(hypothesis, n);
let ref_counts = word_ngram_counts(reference, n);
let hyp_total: usize = hyp_counts.values().sum();
let ref_total: usize = ref_counts.values().sum();
let matches = ngram_intersection(&hyp_counts, &ref_counts);
let prec = matches as f64 / (hyp_total as f64 + 1e-15);
let rec = matches as f64 / (ref_total as f64 + 1e-15);
f_scores.push(f_beta(prec, rec, beta));
}
if f_scores.is_empty() {
return 0.0;
}
f_scores.iter().sum::<f64>() / f_scores.len() as f64
}
/// Averages sentence-level chrF scores over paired hypotheses and references.
///
/// `hypotheses[i]` is scored against `references[i]` with [`chrf_score`],
/// using the given `max_char_n` and `beta`, and the arithmetic mean of those
/// scores is returned.
///
/// # Errors
///
/// * [`SeqError::EmptyInput`] if either slice is empty.
/// * [`SeqError::LengthMismatch`] if the slices differ in length.
pub fn corpus_chrf(
    hypotheses: &[&str],
    references: &[&str],
    max_char_n: usize,
    beta: f64,
) -> SeqResult<f64> {
    match (hypotheses.len(), references.len()) {
        // Emptiness is reported before a length mismatch.
        (0, _) | (_, 0) => Err(SeqError::EmptyInput),
        (a, b) if a != b => Err(SeqError::LengthMismatch { a, b }),
        (pair_count, _) => {
            let score_sum = hypotheses
                .iter()
                .copied()
                .zip(references.iter().copied())
                .map(|(hyp, reference)| chrf_score(hyp, reference, max_char_n, beta))
                .sum::<f64>();
            Ok(score_sum / pair_count as f64)
        }
    }
}
/// Averages sentence-level chrF++ scores over paired hypotheses and references.
///
/// `hypotheses[i]` is scored against `references[i]` with [`chrf_plus_plus`],
/// using the given `max_char_n` and `beta`, and the arithmetic mean of those
/// scores is returned.
///
/// # Errors
///
/// * [`SeqError::EmptyInput`] if either slice is empty.
/// * [`SeqError::LengthMismatch`] if the slices differ in length.
pub fn corpus_chrf_plus_plus(
    hypotheses: &[&str],
    references: &[&str],
    max_char_n: usize,
    beta: f64,
) -> SeqResult<f64> {
    match (hypotheses.len(), references.len()) {
        // Emptiness is reported before a length mismatch.
        (0, _) | (_, 0) => Err(SeqError::EmptyInput),
        (a, b) if a != b => Err(SeqError::LengthMismatch { a, b }),
        (pair_count, _) => {
            let score_sum = hypotheses
                .iter()
                .copied()
                .zip(references.iter().copied())
                .map(|(hyp, reference)| chrf_plus_plus(hyp, reference, max_char_n, beta))
                .sum::<f64>();
            Ok(score_sum / pair_count as f64)
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Occurrence count recorded for `key`, treating a missing n-gram as zero.
    fn count_of(counts: &HashMap<String, usize>, key: &str) -> usize {
        counts.get(key).copied().unwrap_or(0)
    }

    // ---- char_ngram_counts ------------------------------------------------

    #[test]
    fn char_ngrams_bigrams_abc() {
        let counts = char_ngram_counts("abc", 2);
        assert_eq!(count_of(&counts, "ab"), 1);
        assert_eq!(count_of(&counts, "bc"), 1);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn char_ngrams_n_greater_than_len_returns_empty() {
        let counts = char_ngram_counts("ab", 5);
        assert!(counts.is_empty());
    }

    #[test]
    fn char_ngrams_handles_zero_n() {
        let counts = char_ngram_counts("hello", 0);
        assert!(counts.is_empty());
    }

    #[test]
    fn char_ngrams_repeated_chars() {
        // Overlapping windows: "aaa" contains the bigram "aa" twice.
        let counts = char_ngram_counts("aaa", 2);
        assert_eq!(count_of(&counts, "aa"), 2);
        assert_eq!(counts.len(), 1);
    }

    #[test]
    fn char_ngrams_unicode_cjk() {
        // N-grams are built from `char`s, so multi-byte characters stay whole.
        let counts = char_ngram_counts("你好", 1);
        assert_eq!(counts.len(), 2);
        assert_eq!(count_of(&counts, "你"), 1);
        assert_eq!(count_of(&counts, "好"), 1);
    }

    // ---- word_ngram_counts ------------------------------------------------

    #[test]
    fn word_ngrams_unigrams_hello_world() {
        let counts = word_ngram_counts("hello world", 1);
        assert_eq!(count_of(&counts, "hello"), 1);
        assert_eq!(count_of(&counts, "world"), 1);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn word_ngrams_bigrams_with_repetition() {
        let counts = word_ngram_counts("A B A", 2);
        assert_eq!(count_of(&counts, "a b"), 1);
        assert_eq!(count_of(&counts, "b a"), 1);
        assert_eq!(counts.len(), 2);
    }

    #[test]
    fn word_ngrams_n_greater_than_word_count_returns_empty() {
        let counts = word_ngram_counts("hello", 3);
        assert!(counts.is_empty());
    }

    #[test]
    fn word_ngrams_empty_string_returns_empty() {
        let counts = word_ngram_counts("", 1);
        assert!(counts.is_empty());
    }

    #[test]
    fn word_ngrams_lowercases_tokens() {
        let counts = word_ngram_counts("Hello WORLD", 1);
        assert!(counts.contains_key("hello"));
        assert!(counts.contains_key("world"));
        assert!(!counts.contains_key("Hello"));
    }

    // ---- ngram_intersection -----------------------------------------------

    #[test]
    fn ngram_intersection_identical_maps() {
        let mut a = HashMap::new();
        a.insert("ab".to_string(), 2usize);
        a.insert("bc".to_string(), 3usize);
        let b = a.clone();
        assert_eq!(ngram_intersection(&a, &b), 5);
    }

    #[test]
    fn ngram_intersection_disjoint_maps() {
        let mut a = HashMap::new();
        a.insert("ab".to_string(), 2usize);
        let mut b = HashMap::new();
        b.insert("cd".to_string(), 3usize);
        assert_eq!(ngram_intersection(&a, &b), 0);
    }

    #[test]
    fn ngram_intersection_partial_overlap() {
        let mut a = HashMap::new();
        a.insert("ab".to_string(), 3usize);
        a.insert("cd".to_string(), 2usize);
        let mut b = HashMap::new();
        b.insert("ab".to_string(), 2usize);
        b.insert("ef".to_string(), 5usize);
        // Only "ab" is shared, and its count is clipped to the smaller side.
        assert_eq!(ngram_intersection(&a, &b), 2);
    }

    // ---- f_beta -----------------------------------------------------------

    #[test]
    fn f_beta_both_zero_returns_zero() {
        assert_eq!(f_beta(0.0, 0.0, 1.0), 0.0);
    }

    #[test]
    fn f_beta_perfect_precision_and_recall() {
        let f = f_beta(1.0, 1.0, 1.0);
        assert!((f - 1.0).abs() < 1e-10, "got {f}");
    }

    #[test]
    fn f_beta_harmonic_mean_beta_one() {
        let f = f_beta(0.5, 1.0, 1.0);
        let expected = 2.0 * 0.5 / (0.5 + 1.0);
        assert!((f - expected).abs() < 1e-9, "got {f}");
    }

    #[test]
    fn f_beta_recall_heavy_beta_two() {
        let f = f_beta(0.5, 1.0, 2.0);
        let expected = 5.0 * 0.5 * 1.0 / (4.0 * 0.5 + 1.0);
        assert!((f - expected).abs() < 1e-9, "got {f}");
    }

    // ---- chrf_score -------------------------------------------------------

    #[test]
    fn chrf_score_identical_strings_is_one() {
        let s = "the cat sat on the mat";
        let score = chrf_score(s, s, 6, 2.0);
        assert!((score - 1.0).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn chrf_score_empty_hypothesis_is_zero() {
        let score = chrf_score("", "some reference text", 6, 2.0);
        assert!(score < 1e-10, "got {score}");
    }

    #[test]
    fn chrf_score_completely_different_strings_is_low() {
        let score = chrf_score("aaaa", "bbbb", 6, 2.0);
        assert!(score < 1e-10, "got {score}");
    }

    #[test]
    fn chrf_score_partial_overlap_between_zero_and_one() {
        let hyp = "the cat";
        let reference = "the dog";
        let score = chrf_score(hyp, reference, 6, 2.0);
        assert!(score > 0.0 && score < 1.0, "got {score}");
    }

    #[test]
    fn chrf_score_max_char_n_zero_returns_zero() {
        let score = chrf_score("hello", "hello", 0, 2.0);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn chrf_score_symmetric_for_equal_length_identical() {
        // For identical texts the score is 1 regardless of the beta weighting.
        let s = "same";
        let score_b1 = chrf_score(s, s, 4, 1.0);
        let score_b2 = chrf_score(s, s, 4, 2.0);
        assert!((score_b1 - 1.0).abs() < 1e-9);
        assert!((score_b2 - 1.0).abs() < 1e-9);
    }

    #[test]
    fn chrf_score_in_range_for_partial_match() {
        let hyp = "the cat sat on the mat";
        let reference = "the cat sat on the floor";
        let score = chrf_score(hyp, reference, 6, 2.0);
        assert!(
            score > 0.5 && score < 1.0,
            "expected partial match score, got {score}"
        );
    }

    // ---- chrf_plus_plus ---------------------------------------------------

    #[test]
    fn chrf_plus_plus_identical_strings_is_one() {
        let s = "the quick brown fox";
        let score = chrf_plus_plus(s, s, 6, 2.0);
        assert!((score - 1.0).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn chrf_plus_plus_different_strings_differs_from_chrf() {
        let hyp = "i like cats";
        let reference = "she loves dogs";
        let chrf = chrf_score(hyp, reference, 6, 2.0);
        let chrfpp = chrf_plus_plus(hyp, reference, 6, 2.0);
        assert!(
            (0.0..=1.0 + 1e-10).contains(&chrfpp),
            "score out of range: {chrfpp}"
        );
        // The texts share characters but no words, so the two zero word-level
        // F-scores must pull chrF++ strictly below plain chrF.
        assert!(chrf > 0.0, "expected some character overlap, got {chrf}");
        assert!(
            chrfpp < chrf,
            "chrF++ ({chrfpp}) should be below chrF ({chrf})"
        );
    }

    #[test]
    fn chrf_plus_plus_identical_has_same_score_as_chrf_is_one() {
        let s = "hello world";
        let chrf = chrf_score(s, s, 6, 2.0);
        let chrfpp = chrf_plus_plus(s, s, 6, 2.0);
        assert!((chrf - 1.0).abs() < 1e-9);
        assert!((chrfpp - 1.0).abs() < 1e-9);
    }

    #[test]
    fn chrf_plus_plus_empty_hypothesis_is_low() {
        let score = chrf_plus_plus("", "some reference text here", 6, 2.0);
        assert!(score < 0.5, "got {score}");
    }

    // ---- corpus_chrf ------------------------------------------------------

    #[test]
    fn corpus_chrf_all_perfect_is_one() {
        let hyps = vec!["hello world", "foo bar baz"];
        let refs = vec!["hello world", "foo bar baz"];
        let score = corpus_chrf(&hyps, &refs, 6, 2.0).expect("ok");
        assert!((score - 1.0).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn corpus_chrf_empty_error() {
        let err = corpus_chrf(&[], &[], 6, 2.0).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn corpus_chrf_length_mismatch_error() {
        let hyps = vec!["hello"];
        let refs = vec!["hello", "world"];
        let err = corpus_chrf(&hyps, &refs, 6, 2.0).unwrap_err();
        assert!(matches!(err, SeqError::LengthMismatch { .. }));
    }

    #[test]
    fn corpus_chrf_single_sentence_matches_sentence_level() {
        let hyp = "the cat sat";
        let reference = "the dog sat";
        let sentence = chrf_score(hyp, reference, 6, 2.0);
        let corpus = corpus_chrf(&[hyp], &[reference], 6, 2.0).expect("ok");
        assert!((sentence - corpus).abs() < 1e-12);
    }

    // ---- corpus_chrf_plus_plus --------------------------------------------

    #[test]
    fn corpus_chrf_plus_plus_all_perfect_is_one() {
        let hyps = vec!["the quick brown fox", "jumps over the lazy dog"];
        let refs = vec!["the quick brown fox", "jumps over the lazy dog"];
        let score = corpus_chrf_plus_plus(&hyps, &refs, 6, 2.0).expect("ok");
        assert!((score - 1.0).abs() < 1e-9, "got {score}");
    }

    #[test]
    fn corpus_chrf_plus_plus_empty_error() {
        let err = corpus_chrf_plus_plus(&[], &[], 6, 2.0).unwrap_err();
        assert!(matches!(err, SeqError::EmptyInput));
    }

    #[test]
    fn corpus_chrf_plus_plus_length_mismatch_error() {
        let hyps = vec!["a", "b", "c"];
        let refs = vec!["a", "b"];
        let err = corpus_chrf_plus_plus(&hyps, &refs, 6, 2.0).unwrap_err();
        assert!(matches!(err, SeqError::LengthMismatch { .. }));
    }

    #[test]
    fn corpus_chrf_plus_plus_single_sentence_matches_sentence_level() {
        let hyp = "translation output here";
        let reference = "the reference sentence here";
        let sentence = chrf_plus_plus(hyp, reference, 6, 2.0);
        let corpus = corpus_chrf_plus_plus(&[hyp], &[reference], 6, 2.0).expect("ok");
        assert!((sentence - corpus).abs() < 1e-12);
    }
}