use super::{GenerationMetric, Metric, MetricInput, MetricResult};
use crate::error::Result;
use crate::evaluation::bridge::NlpAdapter;
use std::collections::HashMap;
/// Sequence-to-sequence evaluation metric (e.g. translation, summarization).
///
/// Prediction/reference text pairs are buffered in an inner [`GenerationMetric`];
/// scoring happens in `compute`, which averages ROUGE-2 and ROUGE-L scores.
#[derive(Debug, Clone)]
pub struct Seq2SeqMetric {
    /// Holds the text pairs passed to `add_batch` until `compute` is called.
    generation_metric: GenerationMetric,
}
impl Seq2SeqMetric {
pub fn new() -> Self {
Self {
generation_metric: GenerationMetric::new(),
}
}
}
/// Scores buffered sequence-to-sequence outputs with ROUGE-2 and ROUGE-L.
impl Metric for Seq2SeqMetric {
    /// Buffers one batch of predictions and their references.
    ///
    /// # Errors
    ///
    /// Propagates any error from the inner `GenerationMetric::add_batch`
    /// (e.g. a non-text prediction input is rejected this way).
    fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
        self.generation_metric.add_batch(predictions, references)
    }

    /// Scores every buffered prediction/reference pair.
    ///
    /// The primary value is the arithmetic mean of the ROUGE-2 and ROUGE-L
    /// scores produced by `NlpAdapter`. `details` contains:
    /// - `"rouge_2"` and `"rouge_l"`: the individual component scores;
    /// - `"bleu_like"`: the same value as the primary score. Despite the name it
    ///   is *not* BLEU; the key is retained only for backward compatibility.
    ///
    /// When nothing has been buffered, the call is forwarded to the inner
    /// `GenerationMetric` and, on success, its result is renamed `"seq2seq"`.
    ///
    /// # Errors
    ///
    /// Propagates errors from the inner `GenerationMetric::compute` (empty
    /// buffer) or from either ROUGE adapter's `add_batch`/`compute`.
    fn compute(&self) -> Result<MetricResult> {
        let predictions = self.generation_metric.predictions();
        if predictions.is_empty() {
            // Defer to the inner metric's empty-state behaviour (after `reset`,
            // callers expect an error here) rather than scoring nothing.
            return self.generation_metric.compute().map(|mut r| {
                r.name = "seq2seq".to_string();
                r
            });
        }

        // The adapters consume owned `MetricInput`s, so copy the buffers only
        // once we know there is something to score.
        let preds = MetricInput::Text(predictions.clone());
        let refs = MetricInput::Text(self.generation_metric.references().clone());

        let mut rouge2_adapter = NlpAdapter::rouge_n(2);
        rouge2_adapter.add_batch(&preds, &refs)?;
        let rouge_2 = rouge2_adapter.compute()?.value;

        let mut rouge_l_adapter = NlpAdapter::rouge_l();
        rouge_l_adapter.add_batch(&preds, &refs)?;
        let rouge_l = rouge_l_adapter.compute()?.value;

        let primary = (rouge_2 + rouge_l) / 2.0;

        let mut details = HashMap::with_capacity(3);
        details.insert("rouge_2".to_string(), rouge_2);
        details.insert("rouge_l".to_string(), rouge_l);
        // Backward-compatible alias of `primary` (see the method docs above).
        details.insert("bleu_like".to_string(), primary);

        Ok(MetricResult {
            name: "seq2seq".to_string(),
            value: primary,
            details,
            metadata: HashMap::new(),
        })
    }

    /// Clears all buffered predictions and references.
    fn reset(&mut self) {
        self.generation_metric.reset();
    }

    /// Returns the metric identifier, `"seq2seq"`.
    fn name(&self) -> &str {
        "seq2seq"
    }
}
impl Default for Seq2SeqMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_seq2seq_metric_basic() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec![
            "the quick brown fox".to_string(),
            "hello world".to_string(),
        ]);
        let references = MetricInput::Text(vec![
            "the quick brown fox jumps".to_string(),
            "hello world test".to_string(),
        ]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert!(result.value >= 0.0 && result.value <= 1.0);
        assert!(result.details.contains_key("bleu_like"));
        assert!(result.details.contains_key("rouge_2"));
        assert!(result.details.contains_key("rouge_l"));
        // The primary score is the mean of the two ROUGE components, and the
        // legacy `bleu_like` entry mirrors it exactly.
        let rouge_2 = result.details["rouge_2"];
        let rouge_l = result.details["rouge_l"];
        assert!(
            ((rouge_2 + rouge_l) / 2.0 - result.value).abs() < 1e-12,
            "primary value should be the mean of rouge_2 and rouge_l"
        );
        assert_eq!(result.details["bleu_like"], result.value);
    }

    #[test]
    fn test_seq2seq_metric_perfect_match() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec!["hello world".to_string()]);
        let references = MetricInput::Text(vec!["hello world".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert!(
            (result.value - 1.0).abs() < 1e-6,
            "perfect match should give 1.0, got {}",
            result.value
        );
        assert!((result.details["rouge_2"] - 1.0).abs() < 1e-6);
        assert!((result.details["rouge_l"] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_seq2seq_metric_no_overlap() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec!["foo bar".to_string()]);
        let references = MetricInput::Text(vec!["baz qux".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert_eq!(result.value, 0.0, "no-overlap seq2seq should be 0.0");
        assert_eq!(result.details["rouge_2"], 0.0);
        assert_eq!(result.details["rouge_l"], 0.0);
    }

    #[test]
    fn test_seq2seq_metric_translation_example() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec![
            "The cat sits on the mat".to_string(),
            "I love artificial intelligence".to_string(),
        ]);
        let references = MetricInput::Text(vec![
            "The cat is sitting on the mat".to_string(),
            "I love AI and machine learning".to_string(),
        ]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert!(result.value > 0.0);
    }

    #[test]
    fn test_seq2seq_metric_summarization_example() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec!["AI research advances rapidly".to_string()]);
        let references = MetricInput::Text(vec![
            "Artificial intelligence research is advancing rapidly".to_string(),
        ]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert!(result.value > 0.0);
    }

    #[test]
    fn test_seq2seq_metric_reset() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec!["hello".to_string()]);
        let references = MetricInput::Text(vec!["world".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        metric.reset();
        assert!(metric.compute().is_err());
    }

    #[test]
    fn test_seq2seq_metric_invalid_input() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Classifications(vec![0, 1]);
        let references = MetricInput::Text(vec!["hello".to_string()]);
        let result = metric.add_batch(&predictions, &references);
        assert!(result.is_err());
    }

    #[test]
    fn test_seq2seq_metric_multiple_batches() {
        let mut metric = Seq2SeqMetric::new();
        metric
            .add_batch(
                &MetricInput::Text(vec!["hello world".to_string()]),
                &MetricInput::Text(vec!["hello world".to_string()]),
            )
            .expect("operation failed in test");
        metric
            .add_batch(
                &MetricInput::Text(vec!["foo bar".to_string()]),
                &MetricInput::Text(vec!["baz qux".to_string()]),
            )
            .expect("operation failed in test");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert!(
            result.value >= 0.0 && result.value <= 1.0,
            "mixed seq2seq should be in [0, 1], got {}",
            result.value
        );
    }

    #[test]
    fn test_seq2seq_metric_name() {
        let metric = Seq2SeqMetric::new();
        assert_eq!(metric.name(), "seq2seq");
    }

    #[test]
    fn test_seq2seq_metric_empty_sequences() {
        let mut metric = Seq2SeqMetric::new();
        let predictions = MetricInput::Text(vec!["".to_string()]);
        let references = MetricInput::Text(vec!["hello world".to_string()]);
        metric.add_batch(&predictions, &references).expect("add operation failed");
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.name, "seq2seq");
        assert_eq!(result.value, 0.0, "empty prediction should give 0.0");
    }
}