use super::{Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
use crate::evaluation::bridge::NlpAdapter;
use std::collections::HashMap;
/// Accumulator for question-answering evaluation data.
///
/// Holds the raw predicted answers and the raw reference answers as two
/// parallel lists of strings.
#[derive(Debug, Clone)]
pub struct QuestionAnsweringMetric {
    /// Predicted answers collected so far.
    predictions: Vec<String>,
    /// Reference (gold) answers collected so far.
    references: Vec<String>,
}

impl QuestionAnsweringMetric {
    /// Creates a metric with no stored predictions or references.
    pub fn new() -> Self {
        let predictions = Vec::new();
        let references = Vec::new();
        Self {
            predictions,
            references,
        }
    }

    /// Canonicalizes an answer string before comparison.
    ///
    /// The text is lowercased, every character that is neither alphanumeric
    /// nor whitespace is dropped, runs of whitespace collapse into a single
    /// space, and leading/trailing whitespace is removed.
    fn normalize_answer(&self, s: &str) -> String {
        let lowered = s.to_lowercase();
        let mut normalized = String::with_capacity(lowered.len());
        // Set when whitespace has been seen since the last kept character; the
        // separator is only emitted once another word actually starts.
        let mut gap_pending = false;

        for ch in lowered.chars() {
            if ch.is_whitespace() {
                gap_pending = true;
            } else if ch.is_alphanumeric() {
                if gap_pending && !normalized.is_empty() {
                    normalized.push(' ');
                }
                gap_pending = false;
                normalized.push(ch);
            }
            // Anything else (punctuation, symbols) is discarded without
            // splitting the surrounding word.
        }

        normalized
    }
}
/// [`Metric`] implementation scoring free-text answers with exact match and token F1.
impl Metric for QuestionAnsweringMetric {
    /// Appends one batch of predicted and reference answers.
    ///
    /// Answers are paired by position across all accumulated batches, so both
    /// inputs must be `MetricInput::Text` and hold the same number of entries.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when either argument is not
    /// `MetricInput::Text`, or when the two batches differ in length. Nothing
    /// is stored in either case.
    fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
        match (predictions, references) {
            (MetricInput::Text(pred), MetricInput::Text(ref_)) => {
                // An uneven batch would shift every later prediction against the
                // wrong reference, so reject it before touching stored state.
                if pred.len() != ref_.len() {
                    return Err(TrustformersError::invalid_input_simple(format!(
                        "Mismatched batch sizes for QA metric: {} predictions vs {} references",
                        pred.len(),
                        ref_.len()
                    )));
                }
                // Clone element-wise instead of cloning the whole vector first.
                self.predictions.extend(pred.iter().cloned());
                self.references.extend(ref_.iter().cloned());
                Ok(())
            }
            _ => Err(TrustformersError::invalid_input_simple(
                "Invalid input types for QA metric: expected Text for both predictions and references".to_string(),
            )),
        }
    }

    /// Computes exact match and F1 over every stored prediction/reference pair.
    ///
    /// Both sides are normalized with `normalize_answer` first. Exact match is
    /// the fraction of pairs whose normalized strings are equal. F1 is the
    /// value reported by `NlpAdapter::token_f1()` for the normalized texts.
    ///
    /// The returned `value` is the arithmetic mean of the two scores; the
    /// individual scores are available in `details` under `"exact_match"` and
    /// `"f1"`. `metadata` is left empty.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when no predictions have been added, and
    /// propagates any error raised by the token-F1 adapter.
    fn compute(&self) -> Result<MetricResult> {
        if self.predictions.is_empty() {
            return Err(TrustformersError::invalid_input_simple(
                "No data available for metric computation".to_string(),
            ));
        }

        // Normalize each answer exactly once; the same strings feed both the
        // exact-match count and the token-F1 adapter.
        let norm_preds: Vec<String> =
            self.predictions.iter().map(|s| self.normalize_answer(s)).collect();
        let norm_refs: Vec<String> =
            self.references.iter().map(|s| self.normalize_answer(s)).collect();

        let num_pairs = norm_preds.len().min(norm_refs.len());
        let exact_matches = norm_preds
            .iter()
            .zip(norm_refs.iter())
            .filter(|(pred, ref_)| pred == ref_)
            .count();
        let exact_match_score =
            if num_pairs > 0 { exact_matches as f64 / num_pairs as f64 } else { 0.0 };

        let preds_input = MetricInput::Text(norm_preds);
        let refs_input = MetricInput::Text(norm_refs);
        let mut token_f1_adapter = NlpAdapter::token_f1();
        token_f1_adapter.add_batch(&preds_input, &refs_input)?;
        let avg_f1 = token_f1_adapter.compute()?.value;

        let mut details = HashMap::new();
        details.insert("exact_match".to_string(), exact_match_score);
        details.insert("f1".to_string(), avg_f1);

        Ok(MetricResult {
            name: "question_answering".to_string(),
            value: (exact_match_score + avg_f1) / 2.0,
            details,
            metadata: HashMap::new(),
        })
    }

    /// Discards all stored predictions and references.
    fn reset(&mut self) {
        self.predictions.clear();
        self.references.clear();
    }

    /// Returns the metric identifier, `"question_answering"`.
    fn name(&self) -> &str {
        "question_answering"
    }
}
impl Default for QuestionAnsweringMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `MetricInput::Text` batch from string slices.
    fn text_batch(answers: &[&str]) -> MetricInput {
        MetricInput::Text(answers.iter().map(|&answer| answer.to_owned()).collect())
    }

    /// Feeds a single batch into a fresh metric and returns the computed result.
    fn score(predictions: &[&str], references: &[&str]) -> MetricResult {
        let mut metric = QuestionAnsweringMetric::new();
        metric
            .add_batch(&text_batch(predictions), &text_batch(references))
            .expect("add operation failed");
        metric.compute().expect("operation failed in test")
    }

    /// A mixed batch yields a named result in [0, 1] with both detail scores present.
    #[test]
    fn test_qa_metric_basic() {
        let result = score(
            &["Barack Obama", "July 20, 1969"],
            &["Barack Hussein Obama", "July 20, 1969"],
        );
        assert_eq!(result.name, "question_answering");
        assert!((0.0..=1.0).contains(&result.value));
        for key in ["exact_match", "f1"].iter() {
            assert!(result.details.contains_key(*key));
        }
    }

    /// Identical answers score 1.0 everywhere.
    #[test]
    fn test_qa_metric_perfect_match() {
        let result = score(&["Barack Obama"], &["Barack Obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&1.0));
        assert_eq!(result.details.get("f1"), Some(&1.0));
        assert_eq!(result.value, 1.0);
    }

    /// Answers with no shared tokens score 0.0 everywhere.
    #[test]
    fn test_qa_metric_no_match() {
        let result = score(&["completely different"], &["Barack Obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&0.0));
        assert_eq!(result.details.get("f1"), Some(&0.0));
        assert_eq!(result.value, 0.0);
    }

    /// Overlapping but unequal answers miss exact match and get a fractional F1.
    #[test]
    fn test_qa_metric_partial_match() {
        let result = score(&["Barack Obama"], &["Barack Hussein Obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&0.0));
        let f1 = *result.details.get("f1").expect("expected value not found");
        assert!(f1 > 0.0 && f1 < 1.0);
    }

    /// Punctuation and surrounding whitespace do not prevent an exact match.
    #[test]
    fn test_qa_metric_normalization() {
        let result = score(&[" Barack H. Obama! "], &["Barack H Obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&1.0));
    }

    /// An empty prediction scores zero against a non-empty reference.
    #[test]
    fn test_qa_metric_empty_answer() {
        let result = score(&[""], &["Barack Obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&0.0));
        assert_eq!(result.details.get("f1"), Some(&0.0));
    }

    /// Scores are aggregated across separately added batches.
    #[test]
    fn test_qa_metric_multiple_batches() {
        let batches = [
            ("Barack Obama", "Barack Obama"),
            ("different answer", "Barack Obama"),
        ];
        let mut metric = QuestionAnsweringMetric::new();
        for &(prediction, reference) in batches.iter() {
            metric
                .add_batch(&text_batch(&[prediction]), &text_batch(&[reference]))
                .expect("operation failed in test");
        }
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.details.get("exact_match"), Some(&0.5));
        assert_eq!(result.details.get("f1"), Some(&0.5));
        assert_eq!(result.value, 0.5);
    }

    /// After `reset` there is no data left, so `compute` fails.
    #[test]
    fn test_qa_metric_reset() {
        let mut metric = QuestionAnsweringMetric::new();
        metric
            .add_batch(&text_batch(&["Barack Obama"]), &text_batch(&["Barack Obama"]))
            .expect("operation failed in test");
        metric.reset();
        assert!(metric.compute().is_err());
    }

    /// Non-text predictions are rejected.
    #[test]
    fn test_qa_metric_invalid_input() {
        let mut metric = QuestionAnsweringMetric::new();
        let predictions = MetricInput::Classifications(vec![0, 1]);
        let references = text_batch(&["Barack Obama"]);
        assert!(metric.add_batch(&predictions, &references).is_err());
    }

    /// Direct checks of the normalization helper.
    #[test]
    fn test_answer_normalization() {
        let metric = QuestionAnsweringMetric::new();
        let cases = [
            (" Barack H. Obama! ", "barack h obama"),
            ("PARIS", "paris"),
            ("New York City", "new york city"),
            ("123-456-7890", "1234567890"),
            ("", ""),
            (" ", ""),
        ];
        for &(raw, expected) in cases.iter() {
            assert_eq!(metric.normalize_answer(raw), expected);
        }
    }

    /// Letter case is ignored by both exact match and F1.
    #[test]
    fn test_qa_metric_case_sensitivity() {
        let result = score(&["BARACK OBAMA"], &["barack obama"]);
        assert_eq!(result.details.get("exact_match"), Some(&1.0));
        assert_eq!(result.details.get("f1"), Some(&1.0));
    }
}