use super::{Metric, MetricInput, MetricResult};
use crate::error::{Result, TrustformersError};
use std::collections::HashMap;
/// Entity-level (span-level) metric for token classification tasks such as
/// named-entity recognition.
///
/// Each accumulated sequence is stored as a list of entity keys of the form
/// `"start:end:label"`, so two entities are equal only when both boundaries
/// and the label agree.
#[derive(Debug, Clone)]
pub struct TokenClassificationMetric {
    /// Entity keys for each predicted sequence, in insertion order.
    predictions: Vec<Vec<String>>,
    /// Entity keys for each reference sequence, in insertion order.
    references: Vec<Vec<String>>,
}

impl TokenClassificationMetric {
    /// Creates an empty metric with no accumulated sequences.
    pub fn new() -> Self {
        Self {
            predictions: Vec::new(),
            references: Vec::new(),
        }
    }

    /// Encodes each `(start, end, label)` span as a `"start:end:label"` key.
    fn convert_spans_to_strings(&self, spans: &[(usize, usize, String)]) -> Vec<String> {
        spans
            .iter()
            .map(|(start, end, label)| format!("{}:{}:{}", start, end, label))
            .collect()
    }

    /// Parses a whitespace-separated IOB tag sequence (e.g. `"B-PER I-PER O"`)
    /// into half-open `(start, end, label)` token-index spans.
    ///
    /// Chunking follows the conlleval convention:
    /// * `B-X` always starts a new `X` span, closing any open span;
    /// * `I-X` continues the open span when its label is `X`; otherwise
    ///   (ill-formed IOB such as `O I-X` or `B-Y I-X`) it starts a new `X` span;
    /// * any other tag (`O` or an unrecognised tag) closes the open span.
    fn parse_tags_to_spans(&self, tags: &str) -> Vec<(usize, usize, String)> {
        let tag_list: Vec<&str> = tags.split_whitespace().collect();
        let mut spans = Vec::new();
        // The currently open span, if any: (start token index, label).
        let mut current: Option<(usize, String)> = None;

        for (i, tag) in tag_list.iter().enumerate() {
            if let Some(label) = tag.strip_prefix("B-") {
                if let Some((start, open_label)) = current.take() {
                    spans.push((start, i, open_label));
                }
                current = Some((i, label.to_string()));
            } else if let Some(label) = tag.strip_prefix("I-") {
                let continues =
                    matches!(&current, Some((_, open_label)) if open_label.as_str() == label);
                if !continues {
                    // An I- tag that does not continue a same-label span
                    // begins a new entity instead of being dropped or merged.
                    if let Some((start, open_label)) = current.take() {
                        spans.push((start, i, open_label));
                    }
                    current = Some((i, label.to_string()));
                }
            } else if let Some((start, open_label)) = current.take() {
                spans.push((start, i, open_label));
            }
        }

        // A span still open at the end runs to the last token.
        if let Some((start, label)) = current {
            spans.push((start, tag_list.len(), label));
        }
        spans
    }
}
impl Metric for TokenClassificationMetric {
fn add_batch(&mut self, predictions: &MetricInput, references: &MetricInput) -> Result<()> {
match (predictions, references) {
(MetricInput::Spans(pred_spans), MetricInput::Spans(ref_spans)) => {
for pred_seq in pred_spans {
self.predictions.push(self.convert_spans_to_strings(pred_seq));
}
for ref_seq in ref_spans {
self.references.push(self.convert_spans_to_strings(ref_seq));
}
Ok(())
},
(MetricInput::Text(pred_text), MetricInput::Text(ref_text)) => {
for pred_tags in pred_text {
let spans = self.parse_tags_to_spans(pred_tags);
self.predictions.push(self.convert_spans_to_strings(&spans));
}
for ref_tags in ref_text {
let spans = self.parse_tags_to_spans(ref_tags);
self.references.push(self.convert_spans_to_strings(&spans));
}
Ok(())
},
_ => Err(TrustformersError::invalid_input_simple("Invalid input types for token classification metric: expected Spans or Text for both predictions and references".to_string()
)),
}
}
fn compute(&self) -> Result<MetricResult> {
let mut total_predicted = 0;
let mut total_reference = 0;
let mut total_correct = 0;
let num_sequences = self.predictions.len().min(self.references.len());
for i in 0..num_sequences {
let pred_entities = &self.predictions[i];
let ref_entities = &self.references[i];
total_predicted += pred_entities.len();
total_reference += ref_entities.len();
for pred_entity in pred_entities {
if ref_entities.contains(pred_entity) {
total_correct += 1;
}
}
}
let precision = if total_predicted > 0 {
total_correct as f64 / total_predicted as f64
} else {
0.0
};
let recall = if total_reference > 0 {
total_correct as f64 / total_reference as f64
} else {
0.0
};
let f1 = if precision + recall > 0.0 {
2.0 * precision * recall / (precision + recall)
} else {
0.0
};
let mut details = HashMap::new();
details.insert("precision".to_string(), precision);
details.insert("recall".to_string(), recall);
details.insert("f1".to_string(), f1);
Ok(MetricResult {
name: "token_classification".to_string(),
value: f1, details,
metadata: HashMap::new(),
})
}
fn reset(&mut self) {
self.predictions.clear();
self.references.clear();
}
fn name(&self) -> &str {
"token_classification"
}
}
impl Default for TokenClassificationMetric {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a single sequence of `(start, end, label)` spans as `MetricInput::Spans`.
    fn one_sequence(spans: &[(usize, usize, &str)]) -> MetricInput {
        MetricInput::Spans(vec![spans
            .iter()
            .map(|&(start, end, label)| (start, end, label.to_string()))
            .collect()])
    }

    /// Feeds one batch into a fresh metric and returns the computed result.
    fn evaluate(predictions: &MetricInput, references: &MetricInput) -> MetricResult {
        let mut metric = TokenClassificationMetric::new();
        metric.add_batch(predictions, references).expect("add operation failed");
        metric.compute().expect("operation failed in test")
    }

    /// Checks the precision/recall/F1 entries of a result's details map.
    fn assert_prf(result: &MetricResult, precision: f64, recall: f64, f1: f64) {
        assert_eq!(result.details.get("precision"), Some(&precision));
        assert_eq!(result.details.get("recall"), Some(&recall));
        assert_eq!(result.details.get("f1"), Some(&f1));
    }

    #[test]
    fn test_token_classification_metric_spans() {
        let gold = one_sequence(&[(0, 2, "PERSON"), (3, 5, "LOCATION")]);
        let result = evaluate(&gold, &gold);
        assert_eq!(result.name, "token_classification");
        assert_eq!(result.value, 1.0);
        assert_prf(&result, 1.0, 1.0, 1.0);
    }

    #[test]
    fn test_token_classification_metric_partial_match() {
        let result = evaluate(
            &one_sequence(&[(0, 2, "PERSON"), (3, 5, "LOCATION")]),
            &one_sequence(&[(0, 2, "PERSON"), (3, 6, "LOCATION")]),
        );
        assert_prf(&result, 0.5, 0.5, 0.5);
    }

    #[test]
    fn test_token_classification_metric_no_match() {
        let result = evaluate(
            &one_sequence(&[(0, 2, "PERSON")]),
            &one_sequence(&[(3, 5, "LOCATION")]),
        );
        assert_prf(&result, 0.0, 0.0, 0.0);
    }

    #[test]
    fn test_token_classification_metric_text_tags() {
        let tags = MetricInput::Text(vec!["B-PER I-PER O B-LOC O".to_string()]);
        let result = evaluate(&tags, &tags);
        assert_eq!(result.value, 1.0);
    }

    #[test]
    fn test_token_classification_metric_empty() {
        let nothing = one_sequence(&[]);
        let result = evaluate(&nothing, &nothing);
        assert_prf(&result, 0.0, 0.0, 0.0);
    }

    #[test]
    fn test_token_classification_metric_reset() {
        let mut metric = TokenClassificationMetric::new();
        let person = one_sequence(&[(0, 2, "PERSON")]);
        metric.add_batch(&person, &person).expect("operation failed in test");
        metric.reset();
        let result = metric.compute().expect("operation failed in test");
        assert_eq!(result.value, 0.0);
    }

    #[test]
    fn test_token_classification_metric_invalid_input() {
        let mut metric = TokenClassificationMetric::new();
        let outcome = metric.add_batch(
            &MetricInput::Classifications(vec![0, 1]),
            &one_sequence(&[(0, 2, "PERSON")]),
        );
        assert!(outcome.is_err());
    }

    #[test]
    fn test_convert_spans_to_strings() {
        let metric = TokenClassificationMetric::new();
        let encoded = metric
            .convert_spans_to_strings(&[(0, 3, "PERSON".to_string()), (5, 8, "LOC".to_string())]);
        assert_eq!(encoded, vec!["0:3:PERSON", "5:8:LOC"]);
    }

    #[test]
    fn test_parse_tags_to_spans() {
        let metric = TokenClassificationMetric::new();
        let expected: Vec<(usize, usize, String)> = [(0, 2, "PER"), (3, 4, "LOC"), (5, 6, "ORG")]
            .iter()
            .map(|&(start, end, label)| (start, end, label.to_string()))
            .collect();
        assert_eq!(metric.parse_tags_to_spans("B-PER I-PER O B-LOC O B-ORG"), expected);
    }

    #[test]
    fn test_parse_tags_empty() {
        let metric = TokenClassificationMetric::new();
        assert!(metric.parse_tags_to_spans("O O O").is_empty());
    }

    #[test]
    fn test_mixed_input_types() {
        let mut metric = TokenClassificationMetric::new();
        let outcome = metric.add_batch(
            &one_sequence(&[(0, 2, "PERSON")]),
            &MetricInput::Text(vec!["B-PER I-PER".to_string()]),
        );
        assert!(outcome.is_err());
    }
}