use super::datasets::GoldEntity;
use anno::Entity;
use std::collections::HashSet;
/// Results of overlap-based ("partial match") entity evaluation.
#[derive(Debug, Clone)]
pub struct PartialMatchMetrics {
    /// Minimum span-overlap ratio a prediction needed to count as a match.
    pub overlap_threshold: f64,
    /// Fraction of predictions that matched some gold entity.
    pub precision: f64,
    /// Fraction of gold entities matched by some prediction.
    pub recall: f64,
    /// Harmonic mean of precision and recall (0.0 when both are 0).
    pub f1: f64,
    /// Number of predictions counted as true positives under the threshold.
    pub partial_matches: usize,
}
/// Jaccard overlap of two half-open character spans: |intersection| / |union|.
///
/// Returns 0.0 for disjoint or degenerate (empty) spans and 1.0 for
/// identical non-empty spans.
#[must_use]
pub fn calculate_overlap(
    pred_start: usize,
    pred_end: usize,
    gt_start: usize,
    gt_end: usize,
) -> f64 {
    let lo = pred_start.max(gt_start);
    let hi = pred_end.min(gt_end);
    // No shared region (also covers zero-length spans).
    if lo >= hi {
        return 0.0;
    }
    let shared = (hi - lo) as f64;
    let combined = ((pred_end - pred_start) + (gt_end - gt_start) - (hi - lo)) as f64;
    // Defensive guard; with a positive intersection the union is positive too.
    if combined == 0.0 {
        return 1.0;
    }
    shared / combined
}
pub fn calculate_partial_match_metrics(
predicted: &[Entity],
ground_truth: &[GoldEntity],
overlap_threshold: f64,
) -> PartialMatchMetrics {
let mut true_positives = 0;
let mut _false_positives = 0;
let mut gt_matched = vec![false; ground_truth.len()];
for pred in predicted {
let mut best_match: Option<(usize, f64)> = None;
for (gt_idx, gt) in ground_truth.iter().enumerate() {
if gt_matched[gt_idx] {
continue; }
if !crate::eval::entity_type_matches(&pred.entity_type, >.entity_type) {
continue;
}
let overlap = calculate_overlap(pred.start(), pred.end(), gt.start, gt.end);
let should_update = match best_match {
None => true,
Some((_, best_overlap)) => best_overlap < overlap,
};
if overlap >= overlap_threshold && should_update {
best_match = Some((gt_idx, overlap));
}
}
if let Some((gt_idx, _)) = best_match {
true_positives += 1;
gt_matched[gt_idx] = true;
} else {
_false_positives += 1;
}
}
let _false_negatives = gt_matched.iter().filter(|&&matched| !matched).count();
let precision = if !predicted.is_empty() {
true_positives as f64 / predicted.len() as f64
} else {
0.0
};
let recall = if !ground_truth.is_empty() {
true_positives as f64 / ground_truth.len() as f64
} else {
0.0
};
let f1 = if precision + recall > 0.0 {
2.0 * precision * recall / (precision + recall)
} else {
0.0
};
PartialMatchMetrics {
overlap_threshold,
precision,
recall,
f1,
partial_matches: true_positives,
}
}
/// Span-quality statistics over a batch of extracted entities.
#[derive(Debug, Clone, Default)]
pub struct ExtractionQualityMetrics {
    /// Total number of entities examined.
    pub total: usize,
    /// Entities whose (type, normalized text) key was already seen.
    pub duplicates: usize,
    /// `duplicates / total` (0.0 when there are no entities).
    pub duplication_rate: f64,
    /// Entities whose span text looks like noise (empty/punctuation/digits).
    pub noisy: usize,
    /// `noisy / total` (0.0 when there are no entities).
    pub noise_rate: f64,
}
/// Canonicalizes a span for duplicate detection: keeps only alphanumeric
/// characters and lowercases them (Unicode-aware, so one char may expand
/// to several).
fn normalize_for_duplication(text: &str) -> String {
    let mut normalized = String::with_capacity(text.len());
    for ch in text.chars() {
        if ch.is_alphanumeric() {
            normalized.extend(ch.to_lowercase());
        }
    }
    normalized
}
/// Heuristic noise filter: a span is "noisy" when, after trimming, it is
/// empty, consists only of ASCII punctuation/whitespace, consists only of
/// digits/whitespace, or contains no alphanumeric content at all.
fn is_noisy_span(text: &str) -> bool {
    let span = text.trim();
    if span.is_empty() {
        return true;
    }
    let punctuation_only = span
        .chars()
        .all(|c| c.is_whitespace() || c.is_ascii_punctuation());
    let digits_only = span.chars().all(|c| c.is_ascii_digit() || c.is_whitespace());
    // No alphanumeric chars means the duplication key would normalize to "";
    // treat that as noise too (e.g. non-ASCII punctuation).
    punctuation_only || digits_only || !span.chars().any(char::is_alphanumeric)
}
/// Scans the extracted entities once, counting noisy spans and duplicates.
///
/// A duplicate is any entity whose `(entity_type, normalized text)` key has
/// already been seen earlier in the slice. Returns all-zero metrics for an
/// empty input.
#[must_use]
pub fn compute_extraction_quality_metrics(entities: &[Entity]) -> ExtractionQualityMetrics {
    let total = entities.len();
    if total == 0 {
        return ExtractionQualityMetrics::default();
    }
    let mut seen: HashSet<(String, String)> = HashSet::with_capacity(total);
    let mut duplicates = 0usize;
    let mut noisy = 0usize;
    for entity in entities {
        if is_noisy_span(&entity.text) {
            noisy += 1;
        }
        let key = (
            format!("{:?}", entity.entity_type),
            normalize_for_duplication(&entity.text),
        );
        // `insert` returns false when the key was already present.
        if !seen.insert(key) {
            duplicates += 1;
        }
    }
    ExtractionQualityMetrics {
        total,
        duplicates,
        duplication_rate: duplicates as f64 / total as f64,
        noisy,
        noise_rate: noisy as f64 / total as f64,
    }
}
/// Result of sweeping confidence thresholds against gold annotations.
#[derive(Debug, Clone)]
pub struct ConfidenceThresholdAnalysis {
    /// The thresholds that were evaluated (0.0 through 1.0 in 0.1 steps).
    pub thresholds: Vec<f64>,
    /// Partial-match metrics computed at each threshold.
    pub metrics_at_threshold: Vec<(f64, PartialMatchMetrics)>,
    /// First threshold achieving the highest F1, or `None` if every F1 was 0.
    pub optimal_threshold: Option<f64>,
}
/// Sweeps confidence cutoffs 0.0, 0.1, …, 1.0, evaluating partial-match
/// metrics on the predictions whose confidence is at or above each cutoff.
///
/// `optimal_threshold` records the first threshold that achieves a strictly
/// higher F1 than any earlier one; it stays `None` when every F1 is 0.
pub fn analyze_confidence_thresholds(
    predicted: &[Entity],
    ground_truth: &[GoldEntity],
    overlap_threshold: f64,
) -> ConfidenceThresholdAnalysis {
    let thresholds: Vec<f64> = (0..=10).map(|i| f64::from(i) / 10.0).collect();
    let mut metrics_at_threshold = Vec::with_capacity(thresholds.len());
    let mut best_f1 = 0.0;
    let mut optimal_threshold = None;
    for &threshold in &thresholds {
        // Keep only predictions meeting this cutoff; clone in one pass
        // instead of collecting a Vec<&Entity> and re-collecting it owned.
        let kept: Vec<Entity> = predicted
            .iter()
            .filter(|e| e.confidence >= threshold)
            .cloned()
            .collect();
        let metrics = calculate_partial_match_metrics(&kept, ground_truth, overlap_threshold);
        if metrics.f1 > best_f1 {
            best_f1 = metrics.f1;
            optimal_threshold = Some(threshold);
        }
        metrics_at_threshold.push((threshold, metrics));
    }
    ConfidenceThresholdAnalysis {
        thresholds,
        metrics_at_threshold,
        optimal_threshold,
    }
}
/// Accumulates multi-class classification outcomes and derives accuracy,
/// per-class precision/recall/F1, and macro/micro/weighted aggregates.
#[derive(Debug, Clone, Default)]
pub struct ClassificationMetrics {
    /// Total number of (predicted, actual) pairs recorded.
    pub total: usize,
    /// Number of pairs where the prediction equalled the gold label.
    pub correct: usize,
    /// Per-class true-positive counts, keyed by class label.
    pub class_tp: std::collections::HashMap<String, usize>,
    /// Per-class false-positive counts (keyed by the *predicted* label).
    pub class_fp: std::collections::HashMap<String, usize>,
    /// Per-class false-negative counts (keyed by the *actual* label).
    pub class_fn: std::collections::HashMap<String, usize>,
    /// Number of gold examples seen per class.
    pub class_support: std::collections::HashMap<String, usize>,
}
impl ClassificationMetrics {
    /// Creates an empty accumulator.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }
    /// Records one prediction against its gold label.
    pub fn add(&mut self, predicted: &str, actual: &str) {
        self.total += 1;
        *self.class_support.entry(actual.to_string()).or_default() += 1;
        if predicted == actual {
            self.correct += 1;
            *self.class_tp.entry(actual.to_string()).or_default() += 1;
        } else {
            // A miss is an FP for the predicted class and an FN for the
            // actual class.
            *self.class_fp.entry(predicted.to_string()).or_default() += 1;
            *self.class_fn.entry(actual.to_string()).or_default() += 1;
        }
    }
    /// Overall accuracy: correct / total (0.0 when nothing was recorded).
    #[must_use]
    pub fn accuracy(&self) -> f64 {
        if self.total == 0 {
            0.0
        } else {
            self.correct as f64 / self.total as f64
        }
    }
    /// Unweighted mean of per-class precision, over every class that appears
    /// either as a gold label or as a prediction.
    #[must_use]
    pub fn macro_precision(&self) -> f64 {
        let labels: std::collections::HashSet<&String> = self
            .class_support
            .keys()
            .chain(self.class_fp.keys())
            .collect();
        if labels.is_empty() {
            return 0.0;
        }
        let class_count = labels.len() as f64;
        let total: f64 = labels.into_iter().map(|c| self.class_precision(c)).sum();
        total / class_count
    }
    /// Unweighted mean of per-class recall over all gold classes.
    #[must_use]
    pub fn macro_recall(&self) -> f64 {
        if self.class_support.is_empty() {
            return 0.0;
        }
        let class_count = self.class_support.len() as f64;
        let total: f64 = self
            .class_support
            .keys()
            .map(|c| self.class_recall(c))
            .sum();
        total / class_count
    }
    /// Harmonic mean of macro precision and macro recall.
    #[must_use]
    pub fn macro_f1(&self) -> f64 {
        let (p, r) = (self.macro_precision(), self.macro_recall());
        if p + r == 0.0 {
            0.0
        } else {
            2.0 * p * r / (p + r)
        }
    }
    /// Micro-averaged precision: total TP / (total TP + total FP).
    #[must_use]
    pub fn micro_precision(&self) -> f64 {
        let tp: usize = self.class_tp.values().sum();
        let fp: usize = self.class_fp.values().sum();
        match tp + fp {
            0 => 0.0,
            denom => tp as f64 / denom as f64,
        }
    }
    /// Micro-averaged recall: total TP / (total TP + total FN).
    #[must_use]
    pub fn micro_recall(&self) -> f64 {
        let tp: usize = self.class_tp.values().sum();
        let fn_total: usize = self.class_fn.values().sum();
        match tp + fn_total {
            0 => 0.0,
            denom => tp as f64 / denom as f64,
        }
    }
    /// Harmonic mean of micro precision and micro recall.
    #[must_use]
    pub fn micro_f1(&self) -> f64 {
        let (p, r) = (self.micro_precision(), self.micro_recall());
        if p + r == 0.0 {
            0.0
        } else {
            2.0 * p * r / (p + r)
        }
    }
    /// Support-weighted mean of per-class F1 scores.
    #[must_use]
    pub fn weighted_f1(&self) -> f64 {
        if self.total == 0 {
            return 0.0;
        }
        let weighted: f64 = self
            .class_support
            .iter()
            .map(|(class, &support)| self.class_f1(class) * support as f64)
            .sum();
        weighted / self.total as f64
    }
    /// Precision for one class: TP / (TP + FP), 0.0 if never predicted.
    #[must_use]
    pub fn class_precision(&self, class: &str) -> f64 {
        let tp = self.class_tp.get(class).copied().unwrap_or(0);
        let fp = self.class_fp.get(class).copied().unwrap_or(0);
        match tp + fp {
            0 => 0.0,
            denom => tp as f64 / denom as f64,
        }
    }
    /// Recall for one class: TP / (TP + FN), 0.0 if the class never occurred.
    #[must_use]
    pub fn class_recall(&self, class: &str) -> f64 {
        let tp = self.class_tp.get(class).copied().unwrap_or(0);
        let fn_count = self.class_fn.get(class).copied().unwrap_or(0);
        match tp + fn_count {
            0 => 0.0,
            denom => tp as f64 / denom as f64,
        }
    }
    /// F1 for one class: harmonic mean of its precision and recall.
    #[must_use]
    pub fn class_f1(&self, class: &str) -> f64 {
        let (p, r) = (self.class_precision(class), self.class_recall(class));
        if p + r == 0.0 {
            0.0
        } else {
            2.0 * p * r / (p + r)
        }
    }
    /// All gold class labels, sorted for stable reporting.
    #[must_use]
    pub fn classes(&self) -> Vec<&String> {
        let mut labels: Vec<_> = self.class_support.keys().collect();
        labels.sort();
        labels
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use anno::EntityType;

    // Accuracy over a two-class confusion: 3 correct out of 4 recorded pairs.
    #[test]
    fn test_classification_metrics() {
        let mut metrics = ClassificationMetrics::new();
        metrics.add("sports", "sports");
        metrics.add("sports", "sports");
        metrics.add("business", "business");
        metrics.add("sports", "business");
        assert_eq!(metrics.total, 4);
        assert_eq!(metrics.correct, 3);
        assert!((metrics.accuracy() - 0.75).abs() < 0.001);
    }

    // Perfect predictions on both classes should give macro F1 of 1.0.
    #[test]
    fn test_classification_macro_f1() {
        let mut metrics = ClassificationMetrics::new();
        metrics.add("a", "a");
        metrics.add("b", "b");
        assert!((metrics.macro_f1() - 1.0).abs() < 0.001);
    }

    // Jaccard overlap: identical spans -> 1.0, partial -> (0, 1), disjoint -> 0.0.
    #[test]
    fn test_calculate_overlap() {
        assert!((calculate_overlap(0, 10, 0, 10) - 1.0).abs() < 0.001);
        let overlap = calculate_overlap(0, 10, 5, 15);
        assert!(overlap > 0.0 && overlap < 1.0);
        assert!((calculate_overlap(0, 10, 20, 30) - 0.0).abs() < 0.001);
    }

    // An exact-span, same-type prediction yields perfect precision and recall.
    #[test]
    fn test_calculate_partial_match_metrics() {
        let predicted = vec![Entity::new("John Smith", EntityType::Person, 0, 10, 0.9)];
        let ground_truth = vec![GoldEntity {
            text: "John Smith".to_string(),
            entity_type: EntityType::Person,
            original_label: "PER".to_string(),
            start: 0,
            end: 10,
        }];
        let metrics = calculate_partial_match_metrics(&predicted, &ground_truth, 0.5);
        assert!((metrics.precision - 1.0).abs() < 0.001);
        assert!((metrics.recall - 1.0).abs() < 0.001);
    }
}