use super::{Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
use std::collections::HashMap;
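
/// Metric for single-label classification tasks.
///
/// Accumulates predicted and reference class indices across batches and computes
/// accuracy together with macro-averaged precision, recall, and F1.
///
/// A minimal usage sketch (assuming this module's items are in scope as shown):
///
/// ```ignore
/// let mut metric = ClassificationMetric::new();
/// metric.add_batch(
///     &MetricInput::Classifications(vec![0, 1, 1]),
///     &MetricInput::Classifications(vec![0, 1, 0]),
/// )?;
/// let result = metric.compute()?;
/// println!("accuracy = {}", result.value);
/// ```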
#[derive(Debug, Clone)]
pub struct ClassificationMetric {
predictions: Vec<usize>,
references: Vec<usize>,
}
impl ClassificationMetric {
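    /// Creates an empty metric with no accumulated predictions or references.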
pub fn new() -> Self {
Self {
predictions: Vec::new(),
references: Vec::new(),
}
}
}
impl Metric for ClassificationMetric {
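    /// Accumulates a batch of predictions and references.
    ///
    /// Predictions may be class indices or probability distributions; distributions
    /// are converted to class indices by taking the argmax.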
fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
match (predictions, references) {
(MetricInput::Classifications(pred), MetricInput::Classifications(ref_)) => {
self.predictions.extend(pred);
self.references.extend(ref_);
Ok(())
},
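            // Probability inputs: take the argmax over the first inner distribution
            // (p[0]) as the predicted class, falling back to class 0 if it is empty.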
(MetricInput::Probabilities(probs), MetricInput::Classifications(ref_)) => {
let pred: Vec<usize> = probs
.iter()
.map(|p| {
if !p.is_empty() && !p[0].is_empty() {
p[0].iter()
.enumerate()
.max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
.map(|(idx, _)| idx)
.unwrap_or(0)
} else {
0
}
})
.collect();
self.predictions.extend(pred);
self.references.extend(ref_);
Ok(())
},
_ => Err(TrustformersError::invalid_input_simple(
"Invalid input types for classification metric: expected Classifications or Probabilities for predictions, Classifications for references".to_string()
)),
}
}
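    /// Computes accuracy and macro-averaged precision, recall, and F1 over all
    /// accumulated batches. Returns an error if no data has been added or if the
    /// prediction and reference counts differ.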
fn compute(&self) -> Result<MetricResult> {
if self.predictions.len() != self.references.len() {
return Err(TrustformersError::invalid_input_simple(
"Predictions and references must have the same length".to_string(),
));
}
if self.predictions.is_empty() {
return Err(TrustformersError::invalid_input_simple(
"No data available for metric computation".to_string(),
));
}
let correct = self
.predictions
.iter()
.zip(self.references.iter())
.filter(|(p, r)| p == r)
.count();
let accuracy = correct as f64 / self.predictions.len() as f64;
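        // Infer the number of classes from the largest label observed, with a floor of two.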
let max_pred = self.predictions.iter().max().copied().unwrap_or(0);
let max_ref = self.references.iter().max().copied().unwrap_or(0);
let num_classes = (max_pred.max(max_ref) + 1).max(2);
let mut precision_sum = 0.0;
let mut recall_sum = 0.0;
let mut f1_sum = 0.0;
let mut valid_classes = 0;
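        // Per-class true positives, false positives, and false negatives. Classes that
        // never appear in either predictions or references are skipped so they do not
        // dilute the macro averages.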
for class in 0..num_classes {
let tp = self
.predictions
.iter()
.zip(self.references.iter())
.filter(|(p, r)| **p == class && **r == class)
.count() as f64;
let fp = self
.predictions
.iter()
.zip(self.references.iter())
.filter(|(p, r)| **p == class && **r != class)
.count() as f64;
let fn_ = self
.predictions
.iter()
.zip(self.references.iter())
.filter(|(p, r)| **p != class && **r == class)
.count() as f64;
if tp + fp > 0.0 || tp + fn_ > 0.0 {
let precision = if tp + fp > 0.0 { tp / (tp + fp) } else { 0.0 };
let recall = if tp + fn_ > 0.0 { tp / (tp + fn_) } else { 0.0 };
let f1 = if precision + recall > 0.0 {
2.0 * precision * recall / (precision + recall)
} else {
0.0
};
precision_sum += precision;
recall_sum += recall;
f1_sum += f1;
valid_classes += 1;
}
}
let macro_precision =
if valid_classes > 0 { precision_sum / valid_classes as f64 } else { 0.0 };
let macro_recall = if valid_classes > 0 { recall_sum / valid_classes as f64 } else { 0.0 };
let macro_f1 = if valid_classes > 0 { f1_sum / valid_classes as f64 } else { 0.0 };
let mut details = HashMap::new();
details.insert("accuracy".to_string(), accuracy);
details.insert("macro_precision".to_string(), macro_precision);
details.insert("macro_recall".to_string(), macro_recall);
details.insert("macro_f1".to_string(), macro_f1);
Ok(MetricResult {
name: "classification".to_string(),
            value: accuracy,
            details,
metadata: HashMap::new(),
})
}
fn reset(&mut self) {
self.predictions.clear();
self.references.clear();
}
fn name(&self) -> &str {
"classification"
}
}
impl Default for ClassificationMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_classification_metric_basic() {
let mut metric = ClassificationMetric::new();
let predictions = MetricInput::Classifications(vec![0, 1, 0, 1]);
let references = MetricInput::Classifications(vec![0, 0, 1, 1]);
metric.add_batch(&predictions, &references).expect("add operation failed");
let result = metric.compute().expect("operation failed in test");
assert_eq!(result.name, "classification");
        assert_eq!(result.value, 0.5);
        assert!(result.details.contains_key("macro_f1"));
}
#[test]
fn test_classification_metric_perfect_score() {
let mut metric = ClassificationMetric::new();
let predictions = MetricInput::Classifications(vec![0, 1, 2, 0, 1, 2]);
let references = MetricInput::Classifications(vec![0, 1, 2, 0, 1, 2]);
metric.add_batch(&predictions, &references).expect("add operation failed");
let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.value, 1.0);
        assert_eq!(result.details.get("accuracy"), Some(&1.0));
assert_eq!(result.details.get("macro_f1"), Some(&1.0));
}
#[test]
fn test_classification_metric_probabilities() {
let mut metric = ClassificationMetric::new();
        let probabilities = MetricInput::Probabilities(vec![
            vec![vec![0.9, 0.1]],
            vec![vec![0.2, 0.8]],
        ]);
let references = MetricInput::Classifications(vec![0, 1]);
metric.add_batch(&probabilities, &references).expect("add operation failed");
let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.value, 1.0);
    }
#[test]
fn test_classification_metric_reset() {
let mut metric = ClassificationMetric::new();
let predictions = MetricInput::Classifications(vec![0, 1]);
let references = MetricInput::Classifications(vec![0, 1]);
metric.add_batch(&predictions, &references).expect("add operation failed");
metric.reset();
assert!(metric.compute().is_err());
}
#[test]
fn test_classification_metric_invalid_input() {
let mut metric = ClassificationMetric::new();
let predictions = MetricInput::Text(vec!["hello".to_string()]);
let references = MetricInput::Classifications(vec![0]);
let result = metric.add_batch(&predictions, &references);
assert!(result.is_err());
}
#[test]
fn test_classification_metric_mismatched_lengths() {
let mut metric = ClassificationMetric::new();
let predictions = MetricInput::Classifications(vec![0, 1]);
let references = MetricInput::Classifications(vec![0]);
metric.add_batch(&predictions, &references).expect("add operation failed");
let result = metric.compute();
assert!(result.is_err());
}
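
    // A minimal sketch of an empty-input check: `compute` errors when no data
    // has been accumulated.
    #[test]
    fn test_classification_metric_empty() {
        let metric = ClassificationMetric::new();
        assert!(metric.compute().is_err());
    }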
}