use unicode_normalization::UnicodeNormalization;
use unicode_segmentation::UnicodeSegmentation;
use crate::alignment::{align, count_operations, rapidfuzz_char_distance, rapidfuzz_word_distance};
use crate::output::{AlignmentOutput, SplitKind, build_output};
/// Widens a count to `f64` so it can take part in ratio computations.
///
/// There is no lossless `From<usize>` for `f64`; the cast rounds only for
/// values above 2^53, which is why the precision-loss lint is silenced here
/// once instead of at every call site.
#[inline]
#[allow(clippy::cast_precision_loss)]
fn to_f64(n: usize) -> f64 {
    n as f64
}
/// Tokenises `text` into words separated by Unicode whitespace.
///
/// Leading, trailing and repeated whitespace never produce empty tokens; the
/// returned slices borrow from `text`.
fn split_words(text: &str) -> Vec<&str> {
    let mut words = Vec::new();
    words.extend(text.split_whitespace());
    words
}
/// Splits `text` into extended grapheme clusters (the `true` argument selects
/// extended rather than legacy clusters).
///
/// The returned slices borrow from `text`.
fn split_graphemes(text: &str) -> Vec<&str> {
    text.graphemes(true).collect::<Vec<&str>>()
}
/// Returns `true` when every extended grapheme cluster of `text` consists of
/// exactly one `char`, i.e. iterating by `char` and iterating by grapheme
/// yield the same units.
///
/// An empty string trivially satisfies the condition.
fn is_all_single_char_graphemes(text: &str) -> bool {
    for cluster in text.graphemes(true) {
        if cluster.chars().count() != 1 {
            return false;
        }
    }
    true
}
/// Computes the word error rate (WER) of `hypothesis` against `reference`.
///
/// Both inputs are tokenised on Unicode whitespace, so leading, trailing and
/// repeated whitespace do not influence the result. The value is the distance
/// reported by `rapidfuzz_word_distance` divided by the number of reference
/// words; an empty reference yields `0.0`.
#[must_use]
pub fn wer(reference: &str, hypothesis: &str) -> f64 {
    compute_wer_fast(&split_words(reference), &split_words(hypothesis))
}
/// Computes a single word error rate over several sentences.
///
/// Sentence boundaries are discarded: the words of all reference sentences
/// are concatenated into one sequence, likewise for the hypotheses, and the
/// two flattened sequences are compared as a whole. The slices are not
/// required to have the same length.
#[must_use]
pub fn wer_sentences(ref_sentences: &[&str], hyp_sentences: &[&str]) -> f64 {
    /// Concatenates the whitespace-separated words of every sentence.
    fn flatten_words<'a>(sentences: &[&'a str]) -> Vec<&'a str> {
        let mut words = Vec::new();
        for sentence in sentences {
            words.extend(sentence.split_whitespace());
        }
        words
    }

    let all_ref = flatten_words(ref_sentences);
    let all_hyp = flatten_words(hyp_sentences);
    compute_wer_fast(&all_ref, &all_hyp)
}
/// Divides the distance reported by `rapidfuzz_word_distance` by the number
/// of reference tokens.
///
/// Returns `0.0` for an empty reference (regardless of the hypothesis) so the
/// division is never performed with a zero denominator.
fn compute_wer_fast<S: AsRef<str> + PartialEq>(reference: &[S], hypothesis: &[S]) -> f64 {
    if reference.is_empty() {
        return 0.0;
    }
    let edit_distance = rapidfuzz_word_distance(reference, hypothesis);
    to_f64(edit_distance) / to_f64(reference.len())
}
/// Computes the character error rate (CER) of `hypothesis` against
/// `reference`.
///
/// Both strings are NFC-normalised first, so canonically equivalent spellings
/// (precomposed vs. combining sequences) compare equal. The unit of
/// comparison is the extended grapheme cluster. When every cluster in both
/// strings is a single `char`, the comparison runs directly on `char`s;
/// otherwise the strings are split into clusters. An empty reference yields
/// `0.0` on either path.
#[must_use]
pub fn cer(reference: &str, hypothesis: &str) -> f64 {
    let ref_nfc = reference.nfc().collect::<String>();
    let hyp_nfc = hypothesis.nfc().collect::<String>();

    // Multi-`char` clusters present: compare cluster by cluster.
    if !is_all_single_char_graphemes(&ref_nfc) || !is_all_single_char_graphemes(&hyp_nfc) {
        let ref_clusters = split_graphemes(&ref_nfc);
        let hyp_clusters = split_graphemes(&hyp_nfc);
        return compute_wer_fast(&ref_clusters, &hyp_clusters);
    }

    // Every cluster is one `char`, so `char` iteration is equivalent.
    let ref_len = ref_nfc.chars().count();
    if ref_len == 0 {
        return 0.0;
    }
    let edit_distance = rapidfuzz_char_distance(ref_nfc.chars(), hyp_nfc.chars());
    to_f64(edit_distance) / to_f64(ref_len)
}
/// Computes the match error rate (MER) of `hypothesis` against `reference`.
///
/// Inputs are tokenised on Unicode whitespace. The value is the share of
/// alignment operations that are errors; two empty inputs yield `0.0`.
#[must_use]
pub fn mer(reference: &str, hypothesis: &str) -> f64 {
    compute_mer(&split_words(reference), &split_words(hypothesis))
}
fn compute_mer<S: AsRef<str> + PartialEq>(reference: &[S], hypothesis: &[S]) -> f64 {
let ops = align(reference, hypothesis);
let counts = count_operations(&ops);
let total = counts.hits + counts.substitutions + counts.deletions + counts.insertions;
if total == 0 {
return 0.0;
}
let errors = counts.substitutions + counts.deletions + counts.insertions;
to_f64(errors) / to_f64(total)
}
/// Computes the word information preserved (WIP) score of `hypothesis`
/// against `reference`.
///
/// Inputs are tokenised on Unicode whitespace. Two empty inputs score `1.0`;
/// exactly one empty input, or an alignment without any hit, scores `0.0`.
#[must_use]
pub fn wip(reference: &str, hypothesis: &str) -> f64 {
    compute_wip(&split_words(reference), &split_words(hypothesis))
}
/// Word information preserved over pre-tokenised sequences:
/// `WIP = (H / N_ref) * (H / N_hyp)`, where `H` is the number of hits in the
/// alignment, `N_ref` the reference length and `N_hyp` the hypothesis length.
///
/// Edge cases:
/// * both sequences empty: `1.0` (nothing to lose);
/// * exactly one sequence empty: `0.0`;
/// * no hits in the alignment: `0.0`.
fn compute_wip<S: AsRef<str> + PartialEq>(reference: &[S], hypothesis: &[S]) -> f64 {
    let n = reference.len();
    let h = hypothesis.len();
    if n == 0 && h == 0 {
        return 1.0;
    }
    if n == 0 || h == 0 {
        return 0.0;
    }
    let ops = align(reference, hypothesis);
    let counts = count_operations(&ops);
    let hits = counts.hits;
    if hits == 0 {
        return 0.0;
    }
    let recall = to_f64(hits) / to_f64(n);
    // Precision is relative to the hypothesis length (H + S + I). Deletions
    // consume reference tokens only, so they must not enter this denominator;
    // counting them would turn it into the MER denominator and under-report
    // WIP whenever the hypothesis drops words.
    let precision = to_f64(hits) / to_f64(h);
    recall * precision
}
/// Computes the word information lost (WIL) score, defined as the complement
/// of [`wip`]: `1.0 - wip(reference, hypothesis)`.
#[must_use]
pub fn wil(reference: &str, hypothesis: &str) -> f64 {
    let preserved = wip(reference, hypothesis);
    1.0 - preserved
}
/// Aligns `hypothesis` against `reference` at word level and returns the full
/// alignment report.
///
/// Inputs are tokenised on Unicode whitespace; the alignment operations and
/// their counts are handed to `build_output` with [`SplitKind::Words`].
#[must_use]
pub fn process_words(reference: &str, hypothesis: &str) -> AlignmentOutput {
    let ref_tokens = split_words(reference);
    let hyp_tokens = split_words(hypothesis);
    let operations = align(&ref_tokens, &hyp_tokens);
    let tally = count_operations(&operations);
    build_output(
        &ref_tokens,
        &hyp_tokens,
        &operations,
        &tally,
        SplitKind::Words,
    )
}
/// Aligns `hypothesis` against `reference` at character level and returns the
/// full alignment report.
///
/// Both strings are NFC-normalised and then split into extended grapheme
/// clusters; the alignment operations and their counts are handed to
/// `build_output` with [`SplitKind::Graphemes`].
#[must_use]
pub fn process_chars(reference: &str, hypothesis: &str) -> AlignmentOutput {
    let ref_normalized = reference.nfc().collect::<String>();
    let hyp_normalized = hypothesis.nfc().collect::<String>();
    let ref_clusters = split_graphemes(&ref_normalized);
    let hyp_clusters = split_graphemes(&hyp_normalized);
    let operations = align(&ref_clusters, &hyp_clusters);
    let tally = count_operations(&operations);
    build_output(
        &ref_clusters,
        &hyp_clusters,
        &operations,
        &tally,
        SplitKind::Graphemes,
    )
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for comparing floating-point metric values.
    const TOLERANCE: f64 = 1e-10;

    // Family emoji built from three pictographs joined by ZERO WIDTH JOINER.
    // Each sequence is ONE extended grapheme cluster made of several `char`s.
    // The joiners are written as escapes because U+200D is invisible and is
    // easily stripped by editors or copy/paste, which would silently turn one
    // cluster into three.
    /// Man + ZWJ + woman + ZWJ + girl.
    const FAMILY_GIRL: &str = "\u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F467}";
    /// Man + ZWJ + woman + ZWJ + boy.
    const FAMILY_BOY: &str = "\u{1F468}\u{200D}\u{1F469}\u{200D}\u{1F466}";

    /// Returns `true` when `a` and `b` differ by less than [`TOLERANCE`].
    fn approx_eq(a: f64, b: f64) -> bool {
        (a - b).abs() < TOLERANCE
    }

    /// Asserts that two `f64` expressions are equal within [`TOLERANCE`].
    ///
    /// The expansion is a single block, so the temporaries cannot leak into
    /// or clash with the calling scope.
    macro_rules! assert_approx_eq {
        ($left:expr, $right:expr $(,)?) => {{
            let left: f64 = $left;
            let right: f64 = $right;
            assert!(
                approx_eq(left, right),
                "assertion failed: {left:?} != {right:?}"
            );
        }};
    }

    // ---- wer ---------------------------------------------------------------

    #[test]
    fn wer_perfect_match() {
        assert_approx_eq!(wer("hello world", "hello world"), 0.0);
    }

    #[test]
    fn wer_all_substituted() {
        assert_approx_eq!(wer("hello world", "foo bar"), 1.0);
    }

    #[test]
    fn wer_with_deletion() {
        assert_approx_eq!(wer("the cat sat", "the sat"), 1.0 / 3.0);
    }

    #[test]
    fn wer_with_insertion() {
        assert_approx_eq!(wer("the sat", "the cat sat"), 0.5);
    }

    #[test]
    fn wer_empty_both() {
        assert_approx_eq!(wer("", ""), 0.0);
    }

    #[test]
    fn wer_empty_reference() {
        // An empty reference is defined to score 0.0 rather than dividing by zero.
        assert_approx_eq!(wer("", "hello world"), 0.0);
    }

    #[test]
    fn wer_empty_hypothesis() {
        assert_approx_eq!(wer("hello world", ""), 1.0);
    }

    #[test]
    fn wer_multiple_sentences() {
        let ref_sents = ["the cat sat", "the dog ran"];
        let hyp_sents = ["the cat sat", "the dog walked"];
        assert_approx_eq!(wer_sentences(&ref_sents, &hyp_sents), 1.0 / 6.0);
    }

    #[test]
    fn wer_whitespace_agnostic() {
        // Leading, trailing, repeated and mixed whitespace must not matter.
        assert_approx_eq!(wer("  hello \t world\n", "hello world"), 0.0);
    }

    #[test]
    fn wer_single_word_match() {
        assert_approx_eq!(wer("hello", "hello"), 0.0);
    }

    #[test]
    fn wer_single_word_mismatch() {
        assert_approx_eq!(wer("hello", "world"), 1.0);
    }

    // ---- cer ---------------------------------------------------------------

    #[test]
    fn cer_perfect_match() {
        assert_approx_eq!(cer("hello", "hello"), 0.0);
    }

    #[test]
    fn cer_with_substitution() {
        assert_approx_eq!(cer("abcde", "axcde"), 0.2);
    }

    #[test]
    fn cer_empty_both() {
        assert_approx_eq!(cer("", ""), 0.0);
    }

    #[test]
    fn cer_empty_reference() {
        assert_approx_eq!(cer("", "hello"), 0.0);
    }

    #[test]
    fn cer_empty_hypothesis() {
        assert_approx_eq!(cer("abc", ""), 1.0);
    }

    #[test]
    fn cer_with_unicode() {
        assert_approx_eq!(cer("hello 👋", "hello 👋"), 0.0);
    }

    #[test]
    fn cer_grapheme_cluster_nfc_normalized() {
        // Precomposed reference vs. decomposed hypothesis.
        assert_approx_eq!(cer("\u{00E9}", "e\u{0301}"), 0.0);
    }

    #[test]
    fn cer_insertion() {
        assert_approx_eq!(cer("ac", "abc"), 1.0 / 2.0);
    }

    #[test]
    fn cer_deletion() {
        assert_approx_eq!(cer("abc", "ac"), 1.0 / 3.0);
    }

    // ---- mer ---------------------------------------------------------------

    #[test]
    fn mer_perfect_match() {
        assert_approx_eq!(mer("hello world", "hello world"), 0.0);
    }

    #[test]
    fn mer_with_insertion() {
        assert_approx_eq!(mer("a", "a b"), 0.5);
    }

    #[test]
    fn mer_with_deletion() {
        assert_approx_eq!(mer("a b", "a"), 0.5);
    }

    #[test]
    fn mer_empty_both() {
        assert_approx_eq!(mer("", ""), 0.0);
    }

    // ---- wip / wil ---------------------------------------------------------

    #[test]
    fn wip_perfect_match() {
        assert_approx_eq!(wip("hello world", "hello world"), 1.0);
    }

    #[test]
    fn wip_empty_both() {
        assert_approx_eq!(wip("", ""), 1.0);
    }

    #[test]
    fn wip_empty_reference() {
        assert_approx_eq!(wip("", "hello"), 0.0);
    }

    #[test]
    fn wip_empty_hypothesis() {
        assert_approx_eq!(wip("hello", ""), 0.0);
    }

    #[test]
    fn wip_no_match() {
        assert_approx_eq!(wip("hello", "world"), 0.0);
    }

    #[test]
    fn wil_perfect_match() {
        assert_approx_eq!(wil("hello world", "hello world"), 0.0);
    }

    #[test]
    fn wil_no_match() {
        assert_approx_eq!(wil("hello", "world"), 1.0);
    }

    #[test]
    fn wil_empty_both() {
        assert_approx_eq!(wil("", ""), 0.0);
    }

    // ---- process_words -----------------------------------------------------

    #[test]
    fn process_words_returns_output() {
        let output = process_words("the cat sat", "the cat sat on");
        assert_approx_eq!(output.wer, 1.0 / 3.0);
        assert_approx_eq!(output.mer, 0.25);
        assert_approx_eq!(output.wip, 0.75);
        assert_approx_eq!(output.wil, 0.25);
    }

    #[test]
    fn process_words_empty() {
        let output = process_words("", "");
        assert_approx_eq!(output.wer, 0.0);
        assert_eq!(output.hits, 0);
    }

    #[test]
    fn process_words_cer_zero_for_word_mode() {
        let output = process_words("hello", "world");
        assert_approx_eq!(output.cer, 0.0);
    }

    #[test]
    fn process_words_perfect() {
        let output = process_words("a b c", "a b c");
        assert_approx_eq!(output.wer, 0.0);
        assert_eq!(output.hits, 3);
    }

    // ---- process_chars -----------------------------------------------------

    #[test]
    fn process_chars_returns_output() {
        let output = process_chars("abcde", "axcde");
        assert_approx_eq!(output.cer, 0.2);
        assert_approx_eq!(output.wer, 0.2);
    }

    #[test]
    fn process_chars_empty() {
        let output = process_chars("", "");
        assert_approx_eq!(output.cer, 0.0);
    }

    #[test]
    fn process_chars_perfect() {
        let output = process_chars("hello", "hello");
        assert_approx_eq!(output.cer, 0.0);
    }

    // ---- additional edge cases ----------------------------------------------

    #[test]
    fn wip_zero_hits_non_empty() {
        assert_approx_eq!(wip("a", "b"), 0.0);
    }

    #[test]
    fn mer_with_deletions_only() {
        // One hit and two deletions: 2 errors out of 3 operations.
        assert_approx_eq!(mer("a b c", "a"), 2.0 / 3.0);
    }

    #[test]
    fn mer_all_errors() {
        assert_approx_eq!(mer("a b", "c d"), 1.0);
    }

    #[test]
    fn process_words_with_deletion_only() {
        let output = process_words("a b c", "a c");
        assert_eq!(output.ref_len, 3);
        assert_eq!(output.hyp_len, 2);
        assert_eq!(output.hits, 2);
        assert_eq!(output.deletions, 1);
        assert_approx_eq!(output.wer, 1.0 / 3.0);
    }

    #[test]
    fn process_words_with_insertion_only() {
        let output = process_words("a", "a b");
        assert_eq!(output.ref_len, 1);
        assert_eq!(output.hyp_len, 2);
        assert_eq!(output.hits, 1);
        assert_eq!(output.insertions, 1);
        assert_approx_eq!(output.wer, 1.0);
    }

    #[test]
    fn process_chars_with_all_operations() {
        let output = process_chars("abcd", "axd");
        assert_eq!(output.ref_len, 4);
        assert_eq!(output.hits, 2);
        assert_eq!(output.substitutions, 1);
        assert_eq!(output.deletions, 1);
        assert_approx_eq!(output.cer, 2.0 / 4.0);
    }

    #[test]
    fn process_chars_display_with_cer() {
        let output = process_chars("abc", "axc");
        let display = format!("{output}");
        assert!(display.contains("CER:"));
    }

    #[test]
    fn wer_cer_consistency_with_process() {
        let ref_text = "the quick brown fox jumps";
        let hyp_text = "the slow brown fox jumped";
        // Compared with a tolerance: the two code paths may compute the same
        // ratio through different floating-point operations.
        assert_approx_eq!(
            wer(ref_text, hyp_text),
            process_words(ref_text, hyp_text).wer
        );
        assert_approx_eq!(
            cer(ref_text, hyp_text),
            process_chars(ref_text, hyp_text).cer
        );
    }

    // ---- cer: normalisation, CJK and emoji -----------------------------------

    #[test]
    fn cer_nfc_combining_chars_normalized() {
        // Decomposed reference vs. precomposed hypothesis (the reverse
        // direction of `cer_grapheme_cluster_nfc_normalized`).
        assert_approx_eq!(cer("e\u{0301}", "\u{00E9}"), 0.0);
    }

    #[test]
    fn cer_nfc_cjk_identical() {
        assert_approx_eq!(cer("你好世界", "你好世界"), 0.0);
    }

    #[test]
    fn cer_nfc_cjk_with_error() {
        assert_approx_eq!(cer("你好世界", "你们世界"), 0.25);
    }

    #[test]
    fn cer_nfc_long_cjk() {
        // One inserted character per repetition: 100 edits in total.
        let ref_text = "今天天气真好我们可以出去玩".repeat(100);
        let hyp_text = "今天天气真好人我们可以出去玩".repeat(100);
        let expected = 100.0 / to_f64(ref_text.chars().count());
        assert_approx_eq!(cer(&ref_text, &hyp_text), expected);
    }

    #[test]
    fn cer_emoji_fallback() {
        // A multi-`char` cluster forces the grapheme-based path.
        assert_approx_eq!(cer(FAMILY_GIRL, FAMILY_GIRL), 0.0);
    }

    #[test]
    fn cer_emoji_with_substitution() {
        // Each family is a single cluster, so this is 1 substitution out of 1.
        assert_approx_eq!(cer(FAMILY_GIRL, FAMILY_BOY), 1.0);
    }

    #[test]
    fn cer_mixed_emoji_and_cjk() {
        assert_approx_eq!(cer("你好👋", "你好👋"), 0.0);
    }

    #[test]
    fn cer_mixed_emoji_with_error() {
        assert_approx_eq!(cer("你好👋世界", "你好🌍世界"), 0.2);
    }

    #[test]
    fn cer_various_substitutions() {
        assert_approx_eq!(cer("你好世界", "你们世纪"), 0.5);
    }

    // ---- process_chars: normalisation, CJK and emoji --------------------------

    #[test]
    fn process_chars_nfc_normalization() {
        let output = process_chars("\u{00E9}bc", "e\u{0301}bc");
        assert_approx_eq!(output.cer, 0.0);
        assert_eq!(output.hits, 3);
    }

    #[test]
    fn process_chars_emoji_fallback() {
        let output = process_chars(FAMILY_GIRL, FAMILY_BOY);
        assert_approx_eq!(output.cer, 1.0);
        assert_eq!(output.substitutions, 1);
    }

    #[test]
    fn process_chars_cjk() {
        let output = process_chars("你好世界", "你们世纪");
        assert_approx_eq!(output.cer, 0.5);
        assert_eq!(output.substitutions, 2);
        assert_eq!(output.hits, 2);
    }
}