use super::datasets::GoldEntity;
use anno::{Entity, EntityType};
use serde::{Deserialize, Serialize};
/// Tunable options for overlap-based evaluation.
///
/// Read by [`evaluate_with_config`]: `min_overlap` gates the `Partial` and
/// `Type` modes, while `Strict` and `Exact` never consult it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EvalConfig {
/// Minimum overlap ratio (as returned by [`overlap_ratio`]) that a
/// prediction needs against a gold span to count as a match. A value of
/// `0.0` or below accepts any overlap at all.
pub min_overlap: f64,
}
impl Default for EvalConfig {
fn default() -> Self {
Self { min_overlap: 0.0 }
}
}
impl EvalConfig {
    /// Creates a configuration with the default settings (`min_overlap` of `0.0`).
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Creates a configuration identical to [`EvalConfig::default`].
    ///
    /// NOTE(review): despite the name, this does not raise `min_overlap`;
    /// it currently returns exactly what `new()` returns. Confirm whether a
    /// stricter threshold was intended before relying on the name.
    #[must_use]
    pub fn strict() -> Self {
        Self::default()
    }

    /// Returns the configuration with `min_overlap` set to `threshold`,
    /// limited to the range `0.0..=1.0`.
    ///
    /// Values below `0.0` become `0.0` and values above `1.0` become `1.0`.
    /// A NaN `threshold` is treated as `0.0` (no minimum): `f64::clamp`
    /// passes NaN through unchanged, and a NaN threshold would make every
    /// later `ratio >= threshold` comparison false, silently rejecting all
    /// overlap-based matches.
    #[must_use]
    pub fn with_min_overlap(mut self, threshold: f64) -> Self {
        self.min_overlap = if threshold.is_nan() {
            0.0
        } else {
            threshold.clamp(0.0, 1.0)
        };
        self
    }
}
/// How a predicted entity is compared against a gold entity.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default, Serialize, Deserialize)]
pub enum EvalMode {
/// Identical boundaries and matching type. This is the default mode.
#[default]
Strict,
/// Identical boundaries; the entity type is not compared.
Exact,
/// Overlapping boundaries and matching type.
Partial,
/// Overlapping boundaries and matching type. The match test applied in
/// this file is the same one used for `Partial`.
Type,
}
impl EvalMode {
    /// Lists every mode in declaration order: `Strict`, `Exact`, `Partial`, `Type`.
    pub fn all() -> &'static [EvalMode] {
        const EVERY_MODE: &[EvalMode] = &[
            EvalMode::Strict,
            EvalMode::Exact,
            EvalMode::Partial,
            EvalMode::Type,
        ];
        EVERY_MODE
    }

    /// Short display label for the mode (the variant name).
    #[must_use]
    pub fn name(&self) -> &'static str {
        match *self {
            Self::Strict => "Strict",
            Self::Exact => "Exact",
            Self::Partial => "Partial",
            Self::Type => "Type",
        }
    }

    /// One-line, human-readable explanation of what the mode requires.
    #[must_use]
    pub fn description(&self) -> &'static str {
        match *self {
            Self::Strict => "Exact boundary + exact type (CoNLL standard)",
            Self::Exact => "Exact boundary only (type can differ)",
            Self::Partial => "Partial boundary overlap + exact type",
            Self::Type => "Any overlap + exact type",
        }
    }
}
/// Precision, recall and F1 for a single evaluation mode, together with the
/// raw counts they were derived from.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ModeResults {
/// Mode the predictions were scored under.
pub mode: EvalMode,
/// `true_positives / (true_positives + false_positives)`, or `0.0` when
/// that denominator is zero.
pub precision: f64,
/// `true_positives / (true_positives + false_negatives)`, or `0.0` when
/// that denominator is zero.
pub recall: f64,
/// Harmonic mean of precision and recall, or `0.0` when both are zero.
pub f1: f64,
/// Predictions that were paired with a gold entity.
pub true_positives: usize,
/// Predictions that were not paired with any gold entity.
pub false_positives: usize,
/// Gold entities that no prediction was paired with.
pub false_negatives: usize,
}
impl ModeResults {
    /// Scores `predicted` against `gold` under `mode`.
    ///
    /// Match counts come from `count_matches`; precision, recall and F1 are
    /// derived from them. Any ratio whose denominator is zero is reported as
    /// `0.0`, so two empty inputs yield all-zero results.
    #[must_use]
    pub fn compute(predicted: &[Entity], gold: &[GoldEntity], mode: EvalMode) -> Self {
        // Division guarded against an empty denominator.
        let fraction = |hits: usize, total: usize| -> f64 {
            if total > 0 {
                hits as f64 / total as f64
            } else {
                0.0
            }
        };

        let (true_positives, false_positives, false_negatives) =
            count_matches(predicted, gold, mode);

        let precision = fraction(true_positives, true_positives + false_positives);
        let recall = fraction(true_positives, true_positives + false_negatives);

        let precision_plus_recall = precision + recall;
        let f1 = if precision_plus_recall > 0.0 {
            2.0 * precision * recall / precision_plus_recall
        } else {
            0.0
        };

        Self {
            mode,
            precision,
            recall,
            f1,
            true_positives,
            false_positives,
            false_negatives,
        }
    }
}
/// Results for the same predictions scored under every [`EvalMode`].
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MultiModeResults {
/// Results under [`EvalMode::Strict`].
pub strict: ModeResults,
/// Results under [`EvalMode::Exact`].
pub exact: ModeResults,
/// Results under [`EvalMode::Partial`].
pub partial: ModeResults,
/// Results under [`EvalMode::Type`] (named `type_mode` because `type` is
/// a reserved word).
pub type_mode: ModeResults,
}
impl MultiModeResults {
    /// Scores `predicted` against `gold` once per evaluation mode.
    #[must_use]
    pub fn compute(predicted: &[Entity], gold: &[GoldEntity]) -> Self {
        let score = |mode: EvalMode| ModeResults::compute(predicted, gold, mode);
        Self {
            strict: score(EvalMode::Strict),
            exact: score(EvalMode::Exact),
            partial: score(EvalMode::Partial),
            type_mode: score(EvalMode::Type),
        }
    }

    /// Returns the stored results for `mode`.
    #[must_use]
    pub fn get(&self, mode: EvalMode) -> &ModeResults {
        let Self {
            strict,
            exact,
            partial,
            type_mode,
        } = self;
        match mode {
            EvalMode::Strict => strict,
            EvalMode::Exact => exact,
            EvalMode::Partial => partial,
            EvalMode::Type => type_mode,
        }
    }

    /// Prints a table of precision, recall and F1 (as percentages) to
    /// standard output, one row per mode in `EvalMode::all()` order.
    pub fn print_summary(&self) {
        let as_percent = |value: f64| value * 100.0;

        println!("Evaluation Mode Results:");
        println!(
            "{:<10} {:>10} {:>10} {:>10}",
            "Mode", "Precision", "Recall", "F1"
        );
        println!("{:-<43}", "");

        for &mode in EvalMode::all() {
            let row = self.get(mode);
            println!(
                "{:<10} {:>9.1}% {:>9.1}% {:>9.1}%",
                mode.name(),
                as_percent(row.precision),
                as_percent(row.recall),
                as_percent(row.f1)
            );
        }
    }
}
/// Decides whether `pred` counts as a match for `gold` under `mode`.
///
/// * `Strict`: identical start and end, and matching types.
/// * `Exact`: identical start and end; types are not compared.
/// * `Partial` / `Type`: any overlap between the spans, and matching types.
///   Both modes apply exactly the same test here.
fn entities_match(pred: &Entity, gold: &GoldEntity, mode: EvalMode) -> bool {
    let (pred_start, pred_end) = (pred.start(), pred.end());
    let same_span = pred_start == gold.start && pred_end == gold.end;

    match mode {
        EvalMode::Exact => same_span,
        EvalMode::Strict => same_span && types_match(&pred.entity_type, &gold.entity_type),
        EvalMode::Partial | EvalMode::Type => {
            has_overlap(pred_start, pred_end, gold.start, gold.end)
                && types_match(&pred.entity_type, &gold.entity_type)
        }
    }
}
/// Reports whether two entity types count as the same for evaluation.
///
/// Delegates entirely to `super::entity_type_matches`; the actual matching
/// rules are defined there, not in this file.
fn types_match(a: &EntityType, b: &EntityType) -> bool {
super::entity_type_matches(a, b)
}
/// Tests whether the ranges `start1..end1` and `start2..end2` overlap.
///
/// Ends are treated as exclusive: ranges that merely touch (one range's end
/// equals the other's start) do not overlap.
fn has_overlap(start1: usize, end1: usize, start2: usize, end2: usize) -> bool {
    // The ranges miss each other when either one begins at or after the
    // point where the other stops.
    let first_begins_past_second = start1 >= end2;
    let second_begins_past_first = start2 >= end1;
    !(first_begins_past_second || second_begins_past_first)
}
/// Tests whether two ranges overlap by at least `min_threshold`.
///
/// The ranges must overlap at all (per `has_overlap`). When `min_threshold`
/// is `0.0` or lower that alone is enough; otherwise the value returned by
/// `overlap_ratio` must be greater than or equal to `min_threshold`.
fn has_sufficient_overlap(
    start1: usize,
    end1: usize,
    start2: usize,
    end2: usize,
    min_threshold: f64,
) -> bool {
    // `||` short-circuits, so the ratio is only computed when a positive
    // threshold actually has to be checked.
    has_overlap(start1, end1, start2, end2)
        && (min_threshold <= 0.0 || overlap_ratio(start1, end1, start2, end2) >= min_threshold)
}
/// Computes the intersection-over-union of the ranges `start1..end1` and
/// `start2..end2`.
///
/// Returns `0.0` when the ranges share no positions (including when either
/// range is empty or inverted), `1.0` when they are identical and non-empty,
/// and a value strictly between the two for a partial overlap.
#[must_use]
pub fn overlap_ratio(start1: usize, end1: usize, start2: usize, end2: usize) -> f64 {
    let shared_start = start1.max(start2);
    let shared_end = end1.min(end2);
    if shared_start >= shared_end {
        return 0.0;
    }

    // A non-empty intersection implies both ranges are non-empty and their
    // union is one contiguous stretch, so its length is simply the distance
    // between the outermost endpoints. Measuring it this way never sums the
    // two lengths (which could overflow `usize`) and can never be zero here,
    // so no division guard is needed.
    let union_start = start1.min(start2);
    let union_end = end1.max(end2);

    let intersection = (shared_end - shared_start) as f64;
    let union = (union_end - union_start) as f64;
    intersection / union
}
/// Pairs predictions with gold entities and returns
/// `(true_positives, false_positives, false_negatives)`.
///
/// Matching is greedy and order-dependent: each prediction, in input order,
/// claims the first gold entity that is still unclaimed and satisfies
/// `entities_match` for `mode`. A gold entity can be claimed only once.
/// Predictions that claim nothing are false positives; gold entities left
/// unclaimed are false negatives.
fn count_matches(
    predicted: &[Entity],
    gold: &[GoldEntity],
    mode: EvalMode,
) -> (usize, usize, usize) {
    let mut claimed = vec![false; gold.len()];
    let mut true_positives = 0usize;

    for candidate in predicted {
        // Already-claimed gold entities are skipped before the match test runs.
        let first_free_match = gold
            .iter()
            .zip(claimed.iter_mut())
            .find(|(reference, taken)| !**taken && entities_match(candidate, reference, mode));

        if let Some((_, taken)) = first_free_match {
            *taken = true;
            true_positives += 1;
        }
    }

    // Every prediction is either a true positive or a false positive.
    let false_positives = predicted.len() - true_positives;
    let false_negatives = claimed.iter().filter(|taken| !**taken).count();
    (true_positives, false_positives, false_negatives)
}
/// Scores `predicted` against `gold` under a single evaluation `mode`.
///
/// Free-function form of [`ModeResults::compute`]; it forwards its arguments
/// unchanged and returns that function's result.
#[must_use]
pub fn evaluate_with_mode(
predicted: &[Entity],
gold: &[GoldEntity],
mode: EvalMode,
) -> ModeResults {
ModeResults::compute(predicted, gold, mode)
}
/// Scores `predicted` against `gold` under `mode`, honoring `config`.
///
/// Works like `ModeResults::compute`, except that match counting goes
/// through `count_matches_with_config`, so `config.min_overlap` applies to
/// the overlap-based modes. Ratios with a zero denominator are reported as
/// `0.0`.
#[must_use]
pub fn evaluate_with_config(
    predicted: &[Entity],
    gold: &[GoldEntity],
    mode: EvalMode,
    config: &EvalConfig,
) -> ModeResults {
    /// `part / whole`, or `0.0` when `whole` is zero.
    fn share(part: usize, whole: usize) -> f64 {
        match whole {
            0 => 0.0,
            _ => part as f64 / whole as f64,
        }
    }

    let (true_positives, false_positives, false_negatives) =
        count_matches_with_config(predicted, gold, mode, config);

    let precision = share(true_positives, true_positives + false_positives);
    let recall = share(true_positives, true_positives + false_negatives);
    let f1 = match precision + recall {
        total if total > 0.0 => 2.0 * precision * recall / total,
        _ => 0.0,
    };

    ModeResults {
        mode,
        precision,
        recall,
        f1,
        true_positives,
        false_positives,
        false_negatives,
    }
}
/// Pairs predictions with gold entities using `entities_match_with_config`
/// and returns `(true_positives, false_positives, false_negatives)`.
///
/// Matching is greedy: each prediction, in input order, takes the first gold
/// entity that has not been taken yet and passes the match test. Gold
/// entities never taken are counted as false negatives.
fn count_matches_with_config(
    predicted: &[Entity],
    gold: &[GoldEntity],
    mode: EvalMode,
    config: &EvalConfig,
) -> (usize, usize, usize) {
    let mut used = vec![false; gold.len()];
    let mut hits = 0;
    let mut misses = 0;

    for candidate in predicted {
        // Taken gold entities are filtered out before the match test runs.
        let first_available = (0..gold.len()).find(|&idx| {
            !used[idx] && entities_match_with_config(candidate, &gold[idx], mode, config)
        });

        match first_available {
            Some(idx) => {
                used[idx] = true;
                hits += 1;
            }
            None => misses += 1,
        }
    }

    let unmatched_gold = used.iter().filter(|&&taken| !taken).count();
    (hits, misses, unmatched_gold)
}
/// Decides whether `pred` counts as a match for `gold` under `mode`,
/// applying `config.min_overlap` to the overlap-based modes.
///
/// * `Strict`: identical start and end, and matching types.
/// * `Exact`: identical start and end; types are not compared.
/// * `Partial` / `Type`: overlap that satisfies `has_sufficient_overlap`
///   for `config.min_overlap`, and matching types.
fn entities_match_with_config(
    pred: &Entity,
    gold: &GoldEntity,
    mode: EvalMode,
    config: &EvalConfig,
) -> bool {
    // Step 1: the boundary requirement for this mode.
    let boundaries_ok = match mode {
        EvalMode::Strict | EvalMode::Exact => {
            pred.start() == gold.start && pred.end() == gold.end
        }
        EvalMode::Partial | EvalMode::Type => has_sufficient_overlap(
            pred.start(),
            pred.end(),
            gold.start,
            gold.end,
            config.min_overlap,
        ),
    };

    // Step 2: every mode except `Exact` also requires the types to agree.
    // The type comparison only runs once the boundary check has passed.
    let type_is_ignored = matches!(mode, EvalMode::Exact);
    boundaries_ok && (type_is_ignored || types_match(&pred.entity_type, &gold.entity_type))
}
// Unit tests for the evaluation modes, the overlap helpers and `EvalConfig`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a predicted entity with an explicit span; the trailing 0.9 is a
// fixed value passed to `Entity::new` (presumably the confidence score).
fn pred(text: &str, ty: EntityType, start: usize, end: usize) -> Entity {
Entity::new(text, ty, start, end, 0.9)
}
// Builds a gold entity from its text and start only.
// NOTE(review): the expectations below assume `GoldEntity::new` derives the
// end offset from the text length - confirm against its definition.
fn gold(text: &str, ty: EntityType, start: usize) -> GoldEntity {
GoldEntity::new(text, ty, start)
}
// Same span and same type: Strict scores a perfect F1.
#[test]
fn test_strict_exact_match() {
let predicted = vec![pred("John", EntityType::Person, 0, 4)];
let gold_entities = vec![gold("John", EntityType::Person, 0)];
let results = ModeResults::compute(&predicted, &gold_entities, EvalMode::Strict);
assert!((results.f1 - 1.0).abs() < 0.001);
}
// A prediction longer than the gold span fails Strict but passes Partial.
#[test]
fn test_strict_wrong_boundary() {
let predicted = vec![pred("John Smith", EntityType::Person, 0, 10)];
let gold_entities = vec![gold("John", EntityType::Person, 0)];
let results = ModeResults::compute(&predicted, &gold_entities, EvalMode::Strict);
assert_eq!(results.f1, 0.0);
let partial = ModeResults::compute(&predicted, &gold_entities, EvalMode::Partial);
assert!((partial.f1 - 1.0).abs() < 0.001);
}
// Same span with a different type fails Strict but passes Exact, which
// does not compare types.
#[test]
fn test_strict_wrong_type() {
let predicted = vec![pred("Apple", EntityType::Organization, 0, 5)];
let gold_entities = vec![gold("Apple", EntityType::Location, 0)];
let results = ModeResults::compute(&predicted, &gold_entities, EvalMode::Strict);
assert_eq!(results.f1, 0.0);
let exact = ModeResults::compute(&predicted, &gold_entities, EvalMode::Exact);
assert!((exact.f1 - 1.0).abs() < 0.001);
}
// Overlapping spans of the same type: Strict rejects, Partial accepts.
#[test]
fn test_partial_overlap() {
let predicted = vec![pred("New York City", EntityType::Location, 0, 13)];
let gold_entities = vec![gold("New York", EntityType::Location, 0)];
let strict = ModeResults::compute(&predicted, &gold_entities, EvalMode::Strict);
assert_eq!(strict.f1, 0.0);
let partial = ModeResults::compute(&predicted, &gold_entities, EvalMode::Partial);
assert!((partial.f1 - 1.0).abs() < 0.001);
}
// Disjoint spans must score zero under every mode.
#[test]
fn test_no_overlap() {
let predicted = vec![pred("John", EntityType::Person, 0, 4)];
let gold_entities = vec![gold("Mary", EntityType::Person, 10)];
for mode in EvalMode::all() {
let results = ModeResults::compute(&predicted, &gold_entities, *mode);
assert_eq!(
results.f1, 0.0,
"Mode {:?} should fail with no overlap",
mode
);
}
}
// One exact prediction plus one over-long prediction: Strict precision is
// 1 of 2, while Partial accepts both.
#[test]
fn test_multi_mode_results() {
let predicted = vec![
pred("John", EntityType::Person, 0, 4),
pred("New York City", EntityType::Location, 10, 23),
];
let gold_entities = vec![
gold("John", EntityType::Person, 0),
gold("New York", EntityType::Location, 10),
];
let all = MultiModeResults::compute(&predicted, &gold_entities);
assert!((all.strict.precision - 0.5).abs() < 0.001);
assert!((all.partial.f1 - 1.0).abs() < 0.001);
}
// Identical ranges give 1.0, disjoint ranges give 0.0, and a half-shifted
// range gives intersection 5 over union 15.
#[test]
fn test_overlap_ratio() {
assert!((overlap_ratio(0, 10, 0, 10) - 1.0).abs() < 0.001);
assert!((overlap_ratio(0, 5, 10, 15) - 0.0).abs() < 0.001);
assert!(
(overlap_ratio(0, 10, 5, 15) - (5.0 / 15.0)).abs() < 0.001,
"Expected IoU of 5/15 = {}, got {}",
5.0 / 15.0,
overlap_ratio(0, 10, 5, 15)
);
}
// With nothing predicted and nothing expected, every count and F1 is zero.
#[test]
fn test_empty_inputs() {
let empty_pred: Vec<Entity> = vec![];
let empty_gold: Vec<GoldEntity> = vec![];
let results = ModeResults::compute(&empty_pred, &empty_gold, EvalMode::Strict);
assert_eq!(results.f1, 0.0);
assert_eq!(results.true_positives, 0);
assert_eq!(results.false_positives, 0);
assert_eq!(results.false_negatives, 0);
}
// The default configuration imposes no minimum overlap.
#[test]
fn test_config_default() {
let config = EvalConfig::default();
assert_eq!(config.min_overlap, 0.0);
}
// An in-range threshold is stored unchanged.
#[test]
fn test_config_with_overlap() {
let config = EvalConfig::new().with_min_overlap(0.5);
assert_eq!(config.min_overlap, 0.5);
}
// Out-of-range thresholds are limited to 0.0..=1.0.
#[test]
fn test_config_clamp() {
let config = EvalConfig::new().with_min_overlap(1.5);
assert_eq!(config.min_overlap, 1.0);
let config = EvalConfig::new().with_min_overlap(-0.5);
assert_eq!(config.min_overlap, 0.0);
}
// With a zero threshold, the config-aware path accepts any overlap.
#[test]
fn test_partial_with_zero_threshold() {
let predicted = vec![pred("New York City", EntityType::Location, 0, 13)];
let gold_entities = vec![gold("New York", EntityType::Location, 0)];
let config = EvalConfig::default();
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Partial, &config);
assert!((results.f1 - 1.0).abs() < 0.001);
}
// The same pair passes a 0.5 threshold and fails a 0.7 threshold.
// Presumably the gold span is 0..8, giving a ratio of 8/13 - see the
// assumption noted on `gold`.
#[test]
fn test_partial_with_high_threshold() {
let predicted = vec![pred("New York City", EntityType::Location, 0, 13)];
let gold_entities = vec![gold("New York", EntityType::Location, 0)];
let config = EvalConfig::new().with_min_overlap(0.5);
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Partial, &config);
assert!(
(results.f1 - 1.0).abs() < 0.001,
"0.5 threshold should pass"
);
let config = EvalConfig::new().with_min_overlap(0.7);
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Partial, &config);
assert_eq!(results.f1, 0.0, "0.7 threshold should fail");
}
// A minimal overlap passes with no threshold and fails a 0.2 threshold.
// Presumably the gold span is 4..10, sharing one position with 0..5 - see
// the assumption noted on `gold`.
#[test]
fn test_partial_barely_touching() {
let predicted = vec![pred("Apple", EntityType::Organization, 0, 5)];
let gold_entities = vec![gold("Banana", EntityType::Organization, 4)];
let config = EvalConfig::default();
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Partial, &config);
assert!((results.f1 - 1.0).abs() < 0.001);
let config = EvalConfig::new().with_min_overlap(0.2);
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Partial, &config);
assert_eq!(results.f1, 0.0);
}
// Strict never consults `min_overlap`, so a high threshold changes nothing.
#[test]
fn test_strict_mode_ignores_threshold() {
let predicted = vec![pred("John", EntityType::Person, 0, 4)];
let gold_entities = vec![gold("John", EntityType::Person, 0)];
let config = EvalConfig::new().with_min_overlap(0.99);
let results = evaluate_with_config(&predicted, &gold_entities, EvalMode::Strict, &config);
assert!((results.f1 - 1.0).abs() < 0.001);
}
// Direct checks of the threshold helper: ratios of 5/15 and 8/12 against
// thresholds of 0.0 and 0.5, plus a disjoint pair that always fails.
#[test]
fn test_has_sufficient_overlap() {
assert!(has_sufficient_overlap(0, 10, 5, 15, 0.0));
assert!(!has_sufficient_overlap(0, 10, 5, 15, 0.5));
assert!(has_sufficient_overlap(0, 10, 2, 12, 0.5));
assert!(!has_sufficient_overlap(0, 5, 10, 15, 0.0));
}
}