use super::{ClassificationMetric, Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
/// A metric that delegates all of its work to a [`ClassificationMetric`]
/// and reports the computed result under the name `"default"`.
#[derive(Clone, Debug)]
pub struct DefaultMetric {
    /// Inner metric that accumulates batches and produces the actual scores.
    classification_metric: ClassificationMetric,
}
impl DefaultMetric {
    /// Creates a `DefaultMetric` backed by a freshly constructed
    /// [`ClassificationMetric`].
    pub fn new() -> Self {
        let classification_metric = ClassificationMetric::new();
        Self { classification_metric }
    }
}
impl Metric for DefaultMetric {
    /// Forwards the batch unchanged to the inner classification metric.
    fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
        let inner = &mut self.classification_metric;
        inner.add_batch(predictions, references)
    }

    /// Computes the inner classification result and relabels it with this
    /// metric's name (`"default"`).
    ///
    /// Any error from the inner metric is propagated unchanged.
    fn compute(&self) -> Result<MetricResult> {
        self.classification_metric.compute().map(|mut computed| {
            computed.name = self.name().to_string();
            computed
        })
    }

    /// Clears whatever state the inner metric has accumulated.
    fn reset(&mut self) {
        let inner = &mut self.classification_metric;
        inner.reset();
    }

    /// Always reports the name `"default"`.
    fn name(&self) -> &str {
        "default"
    }
}
impl Default for DefaultMetric {
fn default() -> Self {
Self::new()
}
}
/// A collection of heterogeneous metrics that are fed, reset and computed
/// together.
pub struct CompositeMetric {
    /// The wrapped metrics, in the order they were supplied or added.
    metrics: Vec<Box<dyn Metric>>,
}
impl std::fmt::Debug for CompositeMetric {
    /// Formats as `CompositeMetric { metrics_count: N }`.
    ///
    /// Only the number of wrapped metrics is shown, not the metrics themselves.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let metrics_count = self.metrics.len();
        let mut builder = formatter.debug_struct("CompositeMetric");
        builder.field("metrics_count", &metrics_count);
        builder.finish()
    }
}
impl CompositeMetric {
    /// Creates a composite from an existing list of metrics.
    pub fn new(metrics: Vec<Box<dyn Metric>>) -> Self {
        Self { metrics }
    }

    /// Returns the wrapped metrics in insertion order.
    pub fn metrics(&self) -> &Vec<Box<dyn Metric>> {
        &self.metrics
    }

    /// Appends `metric` to the end of the composite.
    pub fn add_metric(&mut self, metric: Box<dyn Metric>) {
        self.metrics.push(metric);
    }

    /// Returns the number of metrics held by the composite.
    pub fn len(&self) -> usize {
        self.metrics.len()
    }

    /// Returns `true` when the composite holds no metrics.
    pub fn is_empty(&self) -> bool {
        self.metrics.is_empty()
    }

    /// Computes every metric in insertion order.
    ///
    /// An empty composite yields an empty vector.
    ///
    /// # Errors
    ///
    /// Returns the first error produced by any metric's `compute`. Collecting
    /// into `Result` short-circuits, so later metrics are not computed.
    pub fn compute_all(&self) -> Result<Vec<MetricResult>> {
        self.metrics.iter().map(|m| m.compute()).collect()
    }

    /// Calls `reset` on every wrapped metric.
    pub fn reset_all(&mut self) {
        for metric in &mut self.metrics {
            metric.reset();
        }
    }

    /// Returns each metric's `name()`, in insertion order.
    pub fn get_metric_names(&self) -> Vec<&str> {
        self.metrics.iter().map(|m| m.name()).collect()
    }

    /// Feeds one batch to every metric, tolerating metrics that reject it.
    ///
    /// Every metric's `add_batch` is called. A metric that returns an error is
    /// skipped, and the remaining metrics still receive the batch. This lets a
    /// composite that mixes metrics for different input kinds accept whichever
    /// inputs at least one member understands.
    ///
    /// # Errors
    ///
    /// Returns an invalid-input error when no metric accepted the batch. This
    /// includes the empty-composite case. The message lists each rejecting
    /// metric's name together with the error it reported.
    pub fn add_batch_to_compatible(
        &mut self,
        predictions: &MetricInput,
        references: &MetricInput,
    ) -> Result<()> {
        let mut accepted_any = false;
        // Keep the reason each metric rejected the batch. If every metric
        // fails, the caller can see why instead of getting a bare message.
        let mut rejections: Vec<String> = Vec::new();
        for metric in &mut self.metrics {
            match metric.add_batch(predictions, references) {
                Ok(()) => accepted_any = true,
                Err(err) => rejections.push(format!("{}: {}", metric.name(), err)),
            }
        }
        if accepted_any {
            return Ok(());
        }
        let mut message =
            String::from("No metrics in the composite could handle the provided input types");
        if !rejections.is_empty() {
            message.push_str(" (");
            message.push_str(&rejections.join("; "));
            message.push(')');
        }
        Err(TrustformersError::invalid_input_simple(message))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::auto::metrics::{GenerationMetric, MetricInput};

    /// Wraps a single word in a one-element `MetricInput::Text` batch.
    fn single_text(word: &str) -> MetricInput {
        MetricInput::Text(vec![word.to_string()])
    }

    #[test]
    fn test_default_metric_basic() {
        // Positions 0 and 3 match, so two of the four predictions are correct (0.5).
        let preds = MetricInput::Classifications(vec![0, 1, 0, 1]);
        let refs = MetricInput::Classifications(vec![0, 0, 1, 1]);
        let mut metric = DefaultMetric::new();
        metric.add_batch(&preds, &refs).expect("add operation failed");

        let outcome = metric.compute().expect("operation failed in test");
        assert_eq!(outcome.name, "default");
        assert_eq!(outcome.value, 0.5);
        assert!(outcome.details.contains_key("accuracy"));
    }

    #[test]
    fn test_default_metric_reset() {
        let batch = MetricInput::Classifications(vec![0, 1]);
        let mut metric = DefaultMetric::new();
        metric
            .add_batch(&batch, &batch)
            .expect("operation failed in test");
        metric.reset();
        // Nothing is accumulated after a reset, so computing must fail.
        assert!(metric.compute().is_err());
    }

    #[test]
    fn test_composite_metric_creation() {
        let members: Vec<Box<dyn Metric>> = vec![
            Box::new(ClassificationMetric::new()),
            Box::new(GenerationMetric::new()),
        ];
        let composite = CompositeMetric::new(members);
        assert!(!composite.is_empty());
        assert_eq!(composite.len(), 2);
        assert_eq!(composite.get_metric_names(), ["classification", "generation"]);
    }

    #[test]
    fn test_composite_metric_add_metric() {
        let mut composite = CompositeMetric::new(vec![]);
        assert!(composite.is_empty());

        composite.add_metric(Box::new(ClassificationMetric::new()));
        assert!(!composite.is_empty());
        assert_eq!(composite.len(), 1);
    }

    #[test]
    fn test_composite_metric_reset_all() {
        let batch = MetricInput::Classifications(vec![0, 1]);
        let mut composite = CompositeMetric::new(vec![Box::new(ClassificationMetric::new())]);
        composite
            .add_batch_to_compatible(&batch, &batch)
            .expect("operation failed in test");
        composite.reset_all();
        assert!(composite.compute_all().is_err());
    }

    #[test]
    fn test_composite_metric_compute_all() {
        let batch = MetricInput::Classifications(vec![0, 1]);
        let mut composite = CompositeMetric::new(vec![Box::new(ClassificationMetric::new())]);
        composite
            .add_batch_to_compatible(&batch, &batch)
            .expect("operation failed in test");

        let computed = composite.compute_all().expect("operation failed in test");
        assert_eq!(computed.len(), 1);
        assert_eq!(computed[0].name, "classification");
    }

    #[test]
    fn test_composite_metric_incompatible_input() {
        let mut composite = CompositeMetric::new(vec![Box::new(ClassificationMetric::new())]);
        let outcome =
            composite.add_batch_to_compatible(&single_text("hello"), &single_text("world"));
        assert!(outcome.is_err());
    }

    #[test]
    fn test_composite_metric_mixed_compatibility() {
        let mut composite = CompositeMetric::new(vec![
            Box::new(ClassificationMetric::new()),
            Box::new(GenerationMetric::new()),
        ]);
        let outcome =
            composite.add_batch_to_compatible(&single_text("hello"), &single_text("world"));
        assert!(outcome.is_ok());
        // The classification metric never received data, so computing all fails.
        assert!(composite.compute_all().is_err());
    }

    #[test]
    fn test_default_metric_name() {
        assert_eq!(DefaultMetric::new().name(), "default");
    }

    #[test]
    fn test_composite_metric_empty() {
        let composite = CompositeMetric::new(vec![]);
        assert_eq!(composite.len(), 0);
        assert!(composite.is_empty());

        let computed = composite.compute_all().expect("operation failed in test");
        assert!(computed.is_empty());
    }
}