use super::{Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
use crate::evaluation::bridge::NlpAdapter;
use std::collections::HashMap;
/// Accumulating text-generation metric.
///
/// Buffers hypothesis/reference string pairs across batches and, on
/// [`Metric::compute`], scores them with a BLEU-style adapter.
#[derive(Debug, Clone)]
pub struct GenerationMetric {
    // Generated hypothesis strings, one per example, in insertion order.
    predictions: Vec<String>,
    // Ground-truth reference strings, kept parallel to `predictions`.
    references: Vec<String>,
}
impl GenerationMetric {
    /// Creates an empty metric with no accumulated pairs.
    pub fn new() -> Self {
        Self {
            predictions: Vec::new(),
            references: Vec::new(),
        }
    }

    /// Accumulated hypothesis strings, in insertion order.
    ///
    /// Returns a slice rather than `&Vec<String>` — idiomatic and hides the
    /// backing container; existing call sites deref-coerce transparently.
    pub fn predictions(&self) -> &[String] {
        &self.predictions
    }

    /// Accumulated reference strings, parallel to [`Self::predictions`].
    pub fn references(&self) -> &[String] {
        &self.references
    }
}
impl Metric for GenerationMetric {
    /// Buffers a batch of hypothesis/reference text pairs.
    ///
    /// # Errors
    /// Returns an error when either input is not [`MetricInput::Text`], or
    /// when the two batches have different lengths — BLEU pairs hypothesis
    /// `i` with reference `i`, so a mismatch would silently misalign every
    /// pair added afterwards.
    fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
        match (predictions, references) {
            (MetricInput::Text(pred), MetricInput::Text(ref_)) => {
                // Reject misaligned batches up front instead of corrupting
                // the accumulated state.
                if pred.len() != ref_.len() {
                    return Err(TrustformersError::invalid_input_simple(format!(
                        "Mismatched batch sizes for generation metric: {} predictions vs {} references",
                        pred.len(),
                        ref_.len()
                    )));
                }
                // extend_from_slice clones each element without building an
                // intermediate Vec (the old `extend(pred.clone())` did).
                self.predictions.extend_from_slice(pred);
                self.references.extend_from_slice(ref_);
                Ok(())
            },
            _ => Err(TrustformersError::invalid_input_simple(
                "Invalid input types for generation metric: expected Text for both predictions and references".to_string(),
            )),
        }
    }

    /// Computes a BLEU-like score over all buffered pairs.
    ///
    /// The score is published under both the `"bleu_like"` and `"bleu"` keys
    /// in `details` for callers expecting either name.
    ///
    /// # Errors
    /// Fails when no data has been added, or when the underlying BLEU
    /// adapter rejects the input.
    fn compute(&self) -> Result<MetricResult> {
        if self.predictions.is_empty() {
            return Err(TrustformersError::invalid_input_simple(
                "No data available for metric computation".to_string(),
            ));
        }
        // Delegate scoring to the shared adapter; arguments are max n-gram
        // order 4 and a boolean flag (presumably smoothing — TODO confirm
        // against NlpAdapter::bleu).
        let mut bleu_adapter = NlpAdapter::bleu(4, true);
        // The adapter consumes MetricInput values, so the buffered strings
        // must be cloned into owned vectors here.
        let preds = MetricInput::Text(self.predictions.clone());
        let refs = MetricInput::Text(self.references.clone());
        bleu_adapter.add_batch(&preds, &refs)?;
        let bleu_score = bleu_adapter.compute()?.value;

        let mut details = HashMap::new();
        details.insert("bleu_like".to_string(), bleu_score);
        details.insert("bleu".to_string(), bleu_score);

        Ok(MetricResult {
            name: "generation".to_string(),
            value: bleu_score,
            details,
            metadata: HashMap::new(),
        })
    }

    /// Clears all buffered pairs, returning the metric to its initial state.
    fn reset(&mut self) {
        self.predictions.clear();
        self.references.clear();
    }

    /// Stable identifier for this metric.
    fn name(&self) -> &str {
        "generation"
    }
}
impl Default for GenerationMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Happy path: result is named, bounded in [0, 1], and carries the
    // "bleu_like" detail key.
    #[test]
    fn test_generation_metric_basic() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec![
            "the quick brown fox".to_string(),
            "hello world".to_string(),
        ]);
        let references = MetricInput::Text(vec![
            "the quick brown fox jumps".to_string(),
            "hello world test".to_string(),
        ]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "generation");
        assert!(result.value >= 0.0 && result.value <= 1.0);
        assert!(result.details.contains_key("bleu_like"));
    }

    // Identical hypothesis and reference must score exactly 1.0
    // (within float tolerance).
    #[test]
    fn test_generation_metric_perfect_match() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec![
            "the quick brown fox jumps over the lazy dog".to_string()
        ]);
        let references = MetricInput::Text(vec![
            "the quick brown fox jumps over the lazy dog".to_string()
        ]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert!(
            (result.value - 1.0).abs() < 1e-6,
            "perfect match should give BLEU=1.0, got {}",
            result.value
        );
    }

    // Disjoint token sets should produce a low score. The bound is loose
    // (< 0.5) because the adapter may smooth rather than return exactly 0.
    #[test]
    fn test_generation_metric_no_overlap() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec!["foo bar".to_string()]);
        let references = MetricInput::Text(vec!["baz qux".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert!(
            result.value < 0.5,
            "no-overlap BLEU should be low, got {}",
            result.value
        );
    }

    // An empty hypothesis string has no n-grams to match.
    #[test]
    fn test_generation_metric_empty_text() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec!["".to_string()]);
        let references = MetricInput::Text(vec!["hello world".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.value, 0.0, "empty hypothesis should give BLEU=0.0");
    }

    // Partial token overlap lands strictly between 0 and 1.
    #[test]
    fn test_generation_metric_partial_overlap() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec!["hello world test".to_string()]);
        let references = MetricInput::Text(vec!["hello universe test".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert!(result.value > 0.0 && result.value < 1.0);
    }

    // After reset, compute must fail again with the no-data error.
    #[test]
    fn test_generation_metric_reset() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Text(vec!["hello".to_string()]);
        let references = MetricInput::Text(vec!["world".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        metric.reset();
        assert!(metric.compute().is_err());
    }

    // Non-Text variants are rejected by add_batch.
    #[test]
    fn test_generation_metric_invalid_input() {
        let mut metric = GenerationMetric::new();
        let predictions = MetricInput::Classifications(vec![0, 1]);
        let references = MetricInput::Text(vec!["hello".to_string()]);
        let result = metric.add_batch(&predictions, &references);
        assert!(result.is_err());
    }

    // Batches accumulate: one perfect pair plus one disjoint pair should
    // average into the open interval (0, 1).
    #[test]
    fn test_generation_metric_multiple_batches() {
        let mut metric = GenerationMetric::new();
        metric
            .add_batch(
                &MetricInput::Text(vec!["hello world".to_string()]),
                &MetricInput::Text(vec!["hello world".to_string()]),
            )
            .expect("operation failed in test");
        metric
            .add_batch(
                &MetricInput::Text(vec!["foo bar".to_string()]),
                &MetricInput::Text(vec!["baz qux".to_string()]),
            )
            .expect("operation failed in test");
        let result = metric.compute().expect("operation failed in test");
        assert!(
            result.value > 0.0 && result.value < 1.0,
            "averaged BLEU should be in (0, 1), got {}",
            result.value
        );
    }
}