use std::collections::HashMap;
/// A scoring function that compares a model `prediction` against an `expected` answer.
///
/// The `Send + Sync` supertraits let metric values (including `Box<dyn Metric>`) be moved
/// to and shared between threads.
pub trait Metric: Send + Sync {
    /// Identifier for this metric. The built-in metrics each return a fixed snake_case
    /// string such as `"exact_match"`.
    fn name(&self) -> &'static str;

    /// Scores `prediction` against `expected`.
    ///
    /// The built-in metrics return scores in `[0.0, 1.0]`; the trait itself does not
    /// constrain the range for custom implementations.
    fn evaluate(&self, prediction: &str, expected: &str) -> f64;
}
/// Forwards [`Metric`] through any `Box`, so boxed metrics can be passed wherever a
/// generic `M: Metric` is expected.
///
/// This covers both `Box<dyn Metric>` and boxed concrete types such as `Box<ExactMatch>`.
/// The `?Sized` bound is what admits trait objects as the boxed type, including
/// non-`'static` ones like `dyn Metric + 'a`.
impl<M: Metric + ?Sized> Metric for Box<M> {
    #[inline]
    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        (**self).evaluate(prediction, expected)
    }

    #[inline]
    fn name(&self) -> &'static str {
        (**self).name()
    }
}
/// Forwards [`Metric`] through shared references, so a metric can be lent to generic
/// `M: Metric` code without giving up ownership.
///
/// The `?Sized` bound lets borrowed trait objects (`&dyn Metric`) qualify as well as
/// references to concrete metric types.
impl<M: Metric + ?Sized> Metric for &M {
    #[inline]
    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        (**self).evaluate(prediction, expected)
    }

    #[inline]
    fn name(&self) -> &'static str {
        (**self).name()
    }
}
/// Binary metric: `1.0` when `prediction` and `expected` are identical after trimming
/// leading and trailing whitespace, `0.0` otherwise.
///
/// Interior whitespace and letter case must match exactly.
#[derive(Debug, Clone, Copy, Default)]
pub struct ExactMatch;

impl Metric for ExactMatch {
    #[inline]
    fn name(&self) -> &'static str {
        "exact_match"
    }

    #[inline]
    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        let is_match = prediction.trim() == expected.trim();
        // `true` -> 1, `false` -> 0, widened losslessly to `f64`.
        f64::from(u8::from(is_match))
    }
}
/// Binary metric: `1.0` when the trimmed `expected` text occurs as a substring of the
/// trimmed `prediction`, `0.0` otherwise.
///
/// Matching is case-sensitive. An `expected` that is empty or whitespace-only always
/// scores `1.0`, because every string contains the empty string.
#[derive(Debug, Clone, Copy, Default)]
pub struct Contains;

impl Metric for Contains {
    #[inline]
    fn name(&self) -> &'static str {
        "contains"
    }

    #[inline]
    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        let haystack = prediction.trim();
        let needle = expected.trim();
        if !haystack.contains(needle) {
            return 0.0;
        }
        1.0
    }
}
/// Token-level F1 score between a prediction and an expected answer.
///
/// Both strings are split on Unicode whitespace and treated as multisets of tokens, so
/// repeated tokens count as many times as they appear. The overlap is the sum, over each
/// distinct token, of the smaller of its two counts. The score is the harmonic mean of
/// `precision = overlap / |prediction tokens|` and `recall = overlap / |expected tokens|`.
///
/// Tokens are compared verbatim, so matching is case- and punctuation-sensitive.
///
/// Edge cases:
/// - two token-less inputs score `1.0`;
/// - exactly one token-less input scores `0.0`;
/// - no shared tokens scores `0.0`.
#[derive(Debug, Clone, Copy, Default)]
pub struct F1Token;

impl F1Token {
    /// Builds a token -> occurrence-count table for `text` and returns it together with
    /// the total number of tokens.
    fn token_bag(text: &str) -> (HashMap<&str, u32>, usize) {
        let mut bag: HashMap<&str, u32> = HashMap::new();
        let mut total = 0usize;
        for token in text.split_whitespace() {
            *bag.entry(token).or_default() += 1;
            total += 1;
        }
        (bag, total)
    }
}

impl Metric for F1Token {
    #[inline]
    fn name(&self) -> &'static str {
        "f1_token"
    }

    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        let (pred_bag, pred_total) = Self::token_bag(prediction.trim());
        let (exp_bag, exp_total) = Self::token_bag(expected.trim());

        match (pred_total, exp_total) {
            // Nothing predicted and nothing expected counts as a perfect answer.
            (0, 0) => return 1.0,
            // Exactly one side has tokens, so nothing can overlap.
            (0, _) | (_, 0) => return 0.0,
            _ => {}
        }

        // Multiset intersection size: each shared token contributes its smaller count.
        let overlap: u32 = pred_bag
            .iter()
            .filter_map(|(token, &in_pred)| exp_bag.get(token).map(|&in_exp| in_pred.min(in_exp)))
            .sum();
        if overlap == 0 {
            return 0.0;
        }

        let precision = f64::from(overlap) / pred_total as f64;
        let recall = f64::from(overlap) / exp_total as f64;
        2.0 * precision * recall / (precision + recall)
    }
}
/// Adapts any `Fn(&str, &str) -> f64` closure or function into a [`Metric`].
///
/// The closure receives `(prediction, expected)` exactly as passed to
/// [`Metric::evaluate`], with no trimming or other normalisation. Its return value is
/// used verbatim as the score.
///
/// The `Fn` bound lives on [`FnMetric::new`] and the [`Metric`] impl rather than on the
/// struct itself, so naming the type does not repeat the bound everywhere.
/// `Clone`/`Copy` are available whenever the wrapped closure supports them.
#[derive(Clone, Copy)]
pub struct FnMetric<F> {
    name: &'static str,
    func: F,
}

impl<F> FnMetric<F>
where
    F: Fn(&str, &str) -> f64 + Send + Sync,
{
    /// Wraps `func` as a metric reported under `name`.
    ///
    /// The `Fn` bound is kept here on purpose: it lets the compiler infer the
    /// `(&str, &str)` parameter types of closures passed directly to `new`.
    pub fn new(name: &'static str, func: F) -> Self {
        Self { name, func }
    }
}

impl<F> Metric for FnMetric<F>
where
    F: Fn(&str, &str) -> f64 + Send + Sync,
{
    #[inline]
    fn evaluate(&self, prediction: &str, expected: &str) -> f64 {
        (self.func)(prediction, expected)
    }

    #[inline]
    fn name(&self) -> &'static str {
        self.name
    }
}

// Hand-written instead of derived: closures do not implement `Debug`, so only the name
// is shown. No `Fn` bound is needed to format it.
impl<F> std::fmt::Debug for FnMetric<F> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("FnMetric")
            .field("name", &self.name)
            .finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // ----- ExactMatch -----

    #[test]
    fn test_exact_match_identical() {
        let m = ExactMatch;
        assert_eq!(m.evaluate("hello", "hello"), 1.0);
    }

    #[test]
    fn test_exact_match_with_whitespace() {
        let m = ExactMatch;
        assert_eq!(m.evaluate(" hello ", "hello"), 1.0);
        assert_eq!(m.evaluate("hello", " hello "), 1.0);
        assert_eq!(m.evaluate("\thello\n", "hello"), 1.0);
    }

    #[test]
    fn test_exact_match_interior_whitespace_is_significant() {
        // Only leading/trailing whitespace is trimmed; interior runs must match exactly.
        let m = ExactMatch;
        assert_eq!(m.evaluate("hello  world", "hello world"), 0.0);
    }

    #[test]
    fn test_exact_match_different() {
        let m = ExactMatch;
        assert_eq!(m.evaluate("hello", "world"), 0.0);
        // Comparison is case-sensitive.
        assert_eq!(m.evaluate("Hello", "hello"), 0.0);
    }

    #[test]
    fn test_exact_match_empty() {
        let m = ExactMatch;
        assert_eq!(m.evaluate("", ""), 1.0);
        assert_eq!(m.evaluate(" ", ""), 1.0);
        assert_eq!(m.evaluate("", "hello"), 0.0);
    }

    #[test]
    fn test_exact_match_name() {
        assert_eq!(ExactMatch.name(), "exact_match");
    }

    // ----- Contains -----

    #[test]
    fn test_contains_substring() {
        let m = Contains;
        assert_eq!(m.evaluate("the answer is 42", "42"), 1.0);
        assert_eq!(m.evaluate("42", "42"), 1.0);
    }

    #[test]
    fn test_contains_no_match() {
        let m = Contains;
        assert_eq!(m.evaluate("the answer is 41", "42"), 0.0);
    }

    #[test]
    fn test_contains_case_sensitive() {
        let m = Contains;
        assert_eq!(m.evaluate("Hello World", "hello"), 0.0);
    }

    #[test]
    fn test_contains_with_whitespace() {
        let m = Contains;
        assert_eq!(m.evaluate(" the answer is 42 ", "42"), 1.0);
        assert_eq!(m.evaluate("the answer is 42", " 42 "), 1.0);
    }

    #[test]
    fn test_contains_empty_expected() {
        let m = Contains;
        // Every string contains the empty string, so a blank target always matches.
        assert_eq!(m.evaluate("anything", ""), 1.0);
        assert_eq!(m.evaluate("anything", "   "), 1.0);
        assert_eq!(m.evaluate("", ""), 1.0);
    }

    #[test]
    fn test_contains_empty_prediction() {
        let m = Contains;
        assert_eq!(m.evaluate("", "42"), 0.0);
        assert_eq!(m.evaluate("   ", "42"), 0.0);
    }

    #[test]
    fn test_contains_name() {
        assert_eq!(Contains.name(), "contains");
    }

    // ----- F1Token -----

    #[test]
    fn test_f1_identical() {
        let m = F1Token;
        assert_eq!(
            m.evaluate("the quick brown fox", "the quick brown fox"),
            1.0
        );
    }

    #[test]
    fn test_f1_partial_overlap() {
        let m = F1Token;
        let score = m.evaluate("the quick brown fox", "the slow brown fox");
        assert!((score - 0.75).abs() < 1e-9, "score = {}", score);
    }

    #[test]
    fn test_f1_no_overlap() {
        let m = F1Token;
        assert_eq!(m.evaluate("hello world", "foo bar"), 0.0);
    }

    #[test]
    fn test_f1_empty_both() {
        let m = F1Token;
        assert_eq!(m.evaluate("", ""), 1.0);
        // Whitespace-only strings contain no tokens either.
        assert_eq!(m.evaluate("   ", "\t\n"), 1.0);
    }

    #[test]
    fn test_f1_one_empty() {
        let m = F1Token;
        assert_eq!(m.evaluate("hello", ""), 0.0);
        assert_eq!(m.evaluate("", "hello"), 0.0);
        assert_eq!(m.evaluate("   ", "hello"), 0.0);
    }

    #[test]
    fn test_f1_single_token() {
        let m = F1Token;
        assert_eq!(m.evaluate("hello", "hello"), 1.0);
        assert_eq!(m.evaluate("hello", "world"), 0.0);
    }

    #[test]
    fn test_f1_different_lengths() {
        let m = F1Token;
        // Prediction is a superset: precision 1/2, recall 1.
        let score = m.evaluate("a b c d", "a b");
        assert!((score - 2.0 / 3.0).abs() < 1e-9, "score = {}", score);
        // Swapping the arguments swaps precision and recall; F1 is unchanged.
        let score = m.evaluate("a b", "a b c d");
        assert!((score - 2.0 / 3.0).abs() < 1e-9, "score = {}", score);
    }

    #[test]
    fn test_f1_repeated_tokens() {
        let m = F1Token;
        // Shared counts are clipped per token: min(2, 1) for "a" + min(1, 2) for "b" = 2.
        let score = m.evaluate("a a b", "a b b");
        assert!((score - 2.0 / 3.0).abs() < 1e-9, "score = {}", score);
    }

    #[test]
    fn test_f1_ignores_order_and_interior_whitespace() {
        let m = F1Token;
        assert_eq!(m.evaluate("b a", "a b"), 1.0);
        assert_eq!(m.evaluate("a   b", "a b"), 1.0);
    }

    #[test]
    fn test_f1_case_sensitive() {
        let m = F1Token;
        assert_eq!(m.evaluate("Hello", "hello"), 0.0);
    }

    #[test]
    fn test_f1_name() {
        assert_eq!(F1Token.name(), "f1_token");
    }

    // ----- FnMetric -----

    #[test]
    fn test_fn_metric_basic() {
        let m = FnMetric::new("custom", |pred, exp| if pred == exp { 1.0 } else { 0.0 });
        assert_eq!(m.evaluate("hello", "hello"), 1.0);
        assert_eq!(m.evaluate("hello", "world"), 0.0);
    }

    #[test]
    fn test_fn_metric_continuous() {
        let m = FnMetric::new("length_ratio", |pred, expected| {
            let ratio = pred.len() as f64 / expected.len().max(1) as f64;
            ratio.min(1.0)
        });
        assert_eq!(m.evaluate("abcd", "abcd"), 1.0);
        assert!((m.evaluate("ab", "abcd") - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_fn_metric_name() {
        let m = FnMetric::new("my_metric", |_, _| 0.5);
        assert_eq!(m.name(), "my_metric");
    }

    #[test]
    fn test_fn_metric_debug() {
        let m = FnMetric::new("test", |_, _| 0.0);
        let debug = format!("{:?}", m);
        assert!(debug.contains("FnMetric"));
        assert!(debug.contains("test"));
    }

    // ----- Dispatch through Box / references -----

    #[test]
    fn test_box_dyn_metric() {
        let m: Box<dyn Metric> = Box::new(ExactMatch);
        assert_eq!(m.evaluate("hello", "hello"), 1.0);
        assert_eq!(m.evaluate("hello", "world"), 0.0);
        assert_eq!(m.name(), "exact_match");
    }

    #[test]
    fn test_ref_metric() {
        let m = ExactMatch;
        let r: &ExactMatch = &m;
        assert_eq!(r.evaluate("hello", "hello"), 1.0);
        assert_eq!(r.name(), "exact_match");
    }

    #[test]
    fn test_metric_via_box_dyn_dispatch() {
        let metrics: Vec<Box<dyn Metric>> = vec![
            Box::new(ExactMatch),
            Box::new(Contains),
            Box::new(F1Token),
            Box::new(FnMetric::new("always_half", |_, _| 0.5)),
        ];

        // Names must survive dynamic dispatch unchanged and in order.
        let names: Vec<&str> = metrics.iter().map(|m| m.name()).collect();
        assert_eq!(names, ["exact_match", "contains", "f1_token", "always_half"]);

        assert_eq!(metrics[0].evaluate("hello", "hello"), 1.0);
        assert_eq!(metrics[1].evaluate("hello world", "world"), 1.0);
        assert_eq!(metrics[2].evaluate("hello", "hello"), 1.0);
        assert_eq!(metrics[3].evaluate("anything", "anything"), 0.5);
    }

    // ----- Text handling -----

    #[test]
    fn test_unicode_handling() {
        let m = ExactMatch;
        assert_eq!(m.evaluate("cafe\u{0301}", "cafe\u{0301}"), 1.0);
        let m = Contains;
        assert_eq!(m.evaluate("I ate at the cafe\u{0301}", "cafe\u{0301}"), 1.0);
        let m = F1Token;
        assert_eq!(m.evaluate("cafe\u{0301} latte", "cafe\u{0301} latte"), 1.0);
    }

    #[test]
    fn test_multiline_strings() {
        let m = ExactMatch;
        assert_eq!(m.evaluate("line1\nline2", "line1\nline2"), 1.0);
        let m = Contains;
        assert_eq!(m.evaluate("line1\nline2\nline3", "line2"), 1.0);
    }
}