use anno::{Entity, EntityType};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
/// A typed span used as the unit of comparison in NER evaluation.
///
/// The matching helpers on this type (`exact_boundary_match`, `has_overlap`,
/// `type_match`) look only at `entity_type`, `start` and `end`; `text` is
/// carried along for reporting. Note that the derived `PartialEq`/`Hash` DO
/// include `text`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct EvalSpan {
/// Label assigned to the span.
pub entity_type: EntityType,
/// Start offset of the span.
/// NOTE(review): unit (bytes vs. chars) is whatever the producer uses — confirm.
pub start: usize,
/// End offset of the span; the strict `<` overlap test suggests it is exclusive.
pub end: usize,
/// Surface text of the span (not consulted when matching).
pub text: String,
}
impl EvalSpan {
    /// Creates a span of `entity_type` from `start` to `end` carrying `text`.
    #[must_use]
    pub fn new(entity_type: EntityType, start: usize, end: usize, text: impl Into<String>) -> Self {
        let text = text.into();
        Self { entity_type, start, end, text }
    }

    /// Returns `true` when both spans have identical `start` and `end`.
    ///
    /// The entity type is not considered.
    #[must_use]
    pub fn exact_boundary_match(&self, other: &Self) -> bool {
        (self.start, self.end) == (other.start, other.end)
    }

    /// Returns `true` when each span starts before the other one ends.
    ///
    /// Both comparisons are strict, so spans that merely touch (one ends exactly
    /// where the other starts) are not considered overlapping.
    #[must_use]
    pub fn has_overlap(&self, other: &Self) -> bool {
        other.start < self.end && self.start < other.end
    }

    /// Returns `true` when both spans carry the same entity type.
    ///
    /// Boundaries are not considered.
    #[must_use]
    pub fn type_match(&self, other: &Self) -> bool {
        self.entity_type == other.entity_type
    }
}
impl From<&Entity> for EvalSpan {
    /// Copies the entity's type, offsets and text into a standalone span.
    fn from(entity: &Entity) -> Self {
        Self::new(
            entity.entity_type.clone(),
            entity.start(),
            entity.end(),
            entity.text.clone(),
        )
    }
}
/// MUC-style match counters for a single evaluation schema.
///
/// `possible()` (gold-side total) and `actual()` (prediction-side total) are
/// derived from these fields and feed the precision/recall/F1 methods.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct MucCounts {
/// Matched pairs that earn full credit under the schema.
pub correct: usize,
/// Matched pairs that fail the schema's criteria.
pub incorrect: usize,
/// Matched pairs with overlapping but non-identical boundaries; in this file
/// only the partial schema increments it (it earns half credit there).
pub partial: usize,
/// Gold spans that were not paired with any prediction.
pub missed: usize,
/// Predictions that were not paired with any gold span.
pub spurious: usize,
}
impl MucCounts {
#[must_use]
pub fn possible(&self) -> usize {
self.correct + self.incorrect + self.partial + self.missed
}
#[must_use]
pub fn actual(&self) -> usize {
self.correct + self.incorrect + self.partial + self.spurious
}
#[must_use]
pub fn precision_exact(&self) -> f64 {
let actual = self.actual();
if actual == 0 {
return 0.0;
}
self.correct as f64 / actual as f64
}
#[must_use]
pub fn recall_exact(&self) -> f64 {
let possible = self.possible();
if possible == 0 {
return 0.0;
}
self.correct as f64 / possible as f64
}
#[must_use]
pub fn precision_partial(&self) -> f64 {
let actual = self.actual();
if actual == 0 {
return 0.0;
}
(self.correct as f64 + 0.5 * self.partial as f64) / actual as f64
}
#[must_use]
pub fn recall_partial(&self) -> f64 {
let possible = self.possible();
if possible == 0 {
return 0.0;
}
(self.correct as f64 + 0.5 * self.partial as f64) / possible as f64
}
#[must_use]
pub fn f1_exact(&self) -> f64 {
let p = self.precision_exact();
let r = self.recall_exact();
if p + r == 0.0 {
return 0.0;
}
2.0 * p * r / (p + r)
}
#[must_use]
pub fn f1_partial(&self) -> f64 {
let p = self.precision_partial();
let r = self.recall_partial();
if p + r == 0.0 {
return 0.0;
}
2.0 * p * r / (p + r)
}
pub fn merge(&mut self, other: &MucCounts) {
self.correct += other.correct;
self.incorrect += other.incorrect;
self.partial += other.partial;
self.missed += other.missed;
self.spurious += other.spurious;
}
}
/// Aggregate NER evaluation counts under four scoring schemas, plus a
/// per-entity-type breakdown.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NerEvalResults {
/// Credit only when boundaries and entity type both match exactly.
pub strict: MucCounts,
/// Credit when boundaries match exactly, regardless of entity type.
pub exact: MucCounts,
/// Full credit for exact boundaries, `partial` for overlapping boundaries;
/// entity type is ignored.
pub partial: MucCounts,
/// Credit when entity types match and the spans overlap or share boundaries.
pub ent_type: MucCounts,
/// Per-type counts keyed by the `Debug` rendering of the entity type. As
/// filled by `evaluate_ner`, a matched gold span is `correct` only when both
/// boundaries and type agree, otherwise `incorrect`.
pub by_type: HashMap<String, MucCounts>,
}
impl NerEvalResults {
    /// Creates an empty result set: every counter is zero and `by_type` is empty.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns the aggregate counters into precision / recall / F1 figures.
    ///
    /// The strict, exact and type schemas use exact-credit formulas; the
    /// partial schema gives half credit to partial matches.
    #[must_use]
    pub fn summary(&self) -> NerEvalSummary {
        let (strict, exact, partial, ent_type) =
            (&self.strict, &self.exact, &self.partial, &self.ent_type);
        NerEvalSummary {
            strict_precision: strict.precision_exact(),
            strict_recall: strict.recall_exact(),
            strict_f1: strict.f1_exact(),
            exact_precision: exact.precision_exact(),
            exact_recall: exact.recall_exact(),
            exact_f1: exact.f1_exact(),
            partial_precision: partial.precision_partial(),
            partial_recall: partial.recall_partial(),
            partial_f1: partial.f1_partial(),
            type_precision: ent_type.precision_exact(),
            type_recall: ent_type.recall_exact(),
            type_f1: ent_type.f1_exact(),
        }
    }

    /// Adds every counter of `other` into `self`, including the per-type counts.
    pub fn merge(&mut self, other: &NerEvalResults) {
        self.strict.merge(&other.strict);
        self.exact.merge(&other.exact);
        self.partial.merge(&other.partial);
        self.ent_type.merge(&other.ent_type);
        other.by_type.iter().for_each(|(entity_type, counts)| {
            self.by_type
                .entry(entity_type.clone())
                .or_default()
                .merge(counts);
        });
    }

    /// Renders the summary as a Markdown table with one row per schema.
    ///
    /// Values are shown as percentages with one decimal place; there is no
    /// trailing newline.
    #[must_use]
    pub fn to_markdown(&self) -> String {
        let s = self.summary();
        // Row prefixes are reproduced verbatim, including the tight "Partial|" cell.
        let rows = [
            ("| Strict |", s.strict_precision, s.strict_recall, s.strict_f1),
            ("| Exact |", s.exact_precision, s.exact_recall, s.exact_f1),
            ("| Partial|", s.partial_precision, s.partial_recall, s.partial_f1),
            ("| Type |", s.type_precision, s.type_recall, s.type_f1),
        ];
        let mut lines = vec![
            String::from("| Schema | Precision | Recall | F1 |"),
            String::from("|--------|-----------|--------|----|"),
        ];
        lines.extend(rows.iter().map(|&(label, p, r, f)| {
            format!(
                "{} {:.1}% | {:.1}% | {:.1}% |",
                label,
                p * 100.0,
                r * 100.0,
                f * 100.0
            )
        }));
        lines.join("\n")
    }
}
/// Precision, recall and F1 for each schema tracked by `NerEvalResults`.
///
/// Values are fractions (`1.0` = perfect), not percentages; `to_markdown`
/// multiplies them by 100 for display. The `partial_*` figures use the
/// half-credit formulas; all others use exact credit.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct NerEvalSummary {
pub strict_precision: f64,
pub strict_recall: f64,
pub strict_f1: f64,
pub exact_precision: f64,
pub exact_recall: f64,
pub exact_f1: f64,
pub partial_precision: f64,
pub partial_recall: f64,
pub partial_f1: f64,
pub type_precision: f64,
pub type_recall: f64,
pub type_f1: f64,
}
/// Scores `predicted` spans against `gold` spans under the strict, exact,
/// partial and type schemas.
///
/// Gold and predicted spans are paired one-to-one. Every overlapping (or
/// boundary-identical) gold/prediction pair is ranked by match quality and
/// pairs are accepted greedily, best first. A prediction is therefore never
/// consumed by a weaker match when a stronger one is available, whatever the
/// order of `gold`. Ties keep input order: earlier gold spans, then earlier
/// predictions, win.
///
/// Each paired gold span is scored by `update_counts`. Unpaired gold spans
/// count as `missed`, and unpaired predictions as `spurious`, in every schema.
///
/// `by_type` is keyed by the `Debug` rendering of the entity type:
/// - matched and missed entries use the gold type;
/// - spurious entries use the predicted type;
/// - a match counts as `correct` there only when boundaries and type both agree.
#[must_use]
pub fn evaluate_ner(gold: &[EvalSpan], predicted: &[EvalSpan]) -> NerEvalResults {
    let mut results = NerEvalResults::new();
    let (gold_matches, matched_preds) = assign_matches(gold, predicted);

    for (gold_span, assignment) in gold.iter().zip(gold_matches) {
        let entity_type_str = format!("{:?}", gold_span.entity_type);
        match assignment {
            Some((pred_idx, match_type)) => {
                update_counts(&mut results, gold_span, &predicted[pred_idx], match_type);
                let type_counts = results.by_type.entry(entity_type_str).or_default();
                if match_type == MatchType::ExactBoth {
                    type_counts.correct += 1;
                } else {
                    type_counts.incorrect += 1;
                }
            }
            None => {
                results.strict.missed += 1;
                results.exact.missed += 1;
                results.partial.missed += 1;
                results.ent_type.missed += 1;
                results.by_type.entry(entity_type_str).or_default().missed += 1;
            }
        }
    }

    for (pred_span, matched) in predicted.iter().zip(matched_preds) {
        if !matched {
            results.strict.spurious += 1;
            results.exact.spurious += 1;
            results.partial.spurious += 1;
            results.ent_type.spurious += 1;
            let entity_type_str = format!("{:?}", pred_span.entity_type);
            results.by_type.entry(entity_type_str).or_default().spurious += 1;
        }
    }
    results
}

/// Pairs gold spans with predictions one-to-one, strongest match first.
///
/// Returns, for each gold span, the index and match type of its assigned
/// prediction (if any), plus one flag per prediction recording whether it
/// was assigned.
fn assign_matches(
    gold: &[EvalSpan],
    predicted: &[EvalSpan],
) -> (Vec<Option<(usize, MatchType)>>, Vec<bool>) {
    let mut candidates: Vec<(usize, usize, MatchType)> = Vec::new();
    for (gold_idx, gold_span) in gold.iter().enumerate() {
        for (pred_idx, pred_span) in predicted.iter().enumerate() {
            let match_type = classify_match(gold_span, pred_span);
            if match_type != MatchType::None {
                candidates.push((gold_idx, pred_idx, match_type));
            }
        }
    }
    // `sort_by` is stable, so equal-priority candidates stay in (gold, pred)
    // order and ties are broken exactly as the per-gold scan would.
    candidates.sort_by(|a, b| b.2.priority().cmp(&a.2.priority()));

    let mut gold_matches = vec![None; gold.len()];
    let mut matched_preds = vec![false; predicted.len()];
    for (gold_idx, pred_idx, match_type) in candidates {
        if gold_matches[gold_idx].is_none() && !matched_preds[pred_idx] {
            gold_matches[gold_idx] = Some((pred_idx, match_type));
            matched_preds[pred_idx] = true;
        }
    }
    (gold_matches, matched_preds)
}
/// How a gold span relates to a candidate prediction, as decided by
/// `classify_match`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum MatchType {
    /// Neither identical boundaries nor any overlap.
    None,
    /// Overlapping, non-identical boundaries; entity types differ.
    PartialBoundaryWrongType,
    /// Overlapping, non-identical boundaries; entity types agree.
    PartialBoundaryCorrectType,
    /// Identical boundaries; entity types differ.
    ExactBoundaryWrongType,
    /// Identical boundaries and identical entity types.
    ExactBoth,
}

impl MatchType {
    /// Ranking used when choosing between candidate predictions; larger is better.
    fn priority(self) -> u8 {
        match self {
            Self::ExactBoth => 4,
            Self::ExactBoundaryWrongType => 3,
            Self::PartialBoundaryCorrectType => 2,
            Self::PartialBoundaryWrongType => 1,
            Self::None => 0,
        }
    }
}
/// Classifies how `pred` relates to `gold` by boundaries and entity type.
///
/// Identical boundaries take precedence over mere overlap; spans with neither
/// yield `MatchType::None`.
fn classify_match(gold: &EvalSpan, pred: &EvalSpan) -> MatchType {
    let same_bounds = gold.exact_boundary_match(pred);
    let overlapping = gold.has_overlap(pred);
    let same_type = gold.type_match(pred);
    match (same_bounds, overlapping, same_type) {
        (true, _, true) => MatchType::ExactBoth,
        (true, _, false) => MatchType::ExactBoundaryWrongType,
        (false, true, true) => MatchType::PartialBoundaryCorrectType,
        (false, true, false) => MatchType::PartialBoundaryWrongType,
        (false, false, _) => MatchType::None,
    }
}
fn update_counts(
results: &mut NerEvalResults,
gold: &EvalSpan,
pred: &EvalSpan,
match_type: MatchType,
) {
let exact_boundary = gold.exact_boundary_match(pred);
let type_match = gold.type_match(pred);
if exact_boundary && type_match {
results.strict.correct += 1;
} else {
results.strict.incorrect += 1;
}
if exact_boundary {
results.exact.correct += 1;
} else {
results.exact.incorrect += 1;
}
match match_type {
MatchType::ExactBoth | MatchType::ExactBoundaryWrongType => {
results.partial.correct += 1;
}
MatchType::PartialBoundaryCorrectType | MatchType::PartialBoundaryWrongType => {
results.partial.partial += 1;
}
MatchType::None => {
results.partial.incorrect += 1;
}
}
if type_match && (exact_boundary || gold.has_overlap(pred)) {
results.ent_type.correct += 1;
} else {
results.ent_type.incorrect += 1;
}
}
/// Evaluates predicted `Entity` values against gold entities.
///
/// Both slices are converted to `EvalSpan`s (type, offsets, text) and scored
/// with `evaluate_ner`.
#[must_use]
pub fn evaluate_entities(gold: &[Entity], predicted: &[Entity]) -> NerEvalResults {
    fn to_spans(entities: &[Entity]) -> Vec<EvalSpan> {
        entities.iter().map(EvalSpan::from).collect()
    }
    evaluate_ner(&to_spans(gold), &to_spans(predicted))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a span whose text is irrelevant to the assertions.
    fn span(t: EntityType, start: usize, end: usize) -> EvalSpan {
        EvalSpan::new(t, start, end, "test")
    }

    /// The four aggregate schemas in a fixed order, for bulk assertions.
    fn all_schemas(results: &NerEvalResults) -> [&MucCounts; 4] {
        [
            &results.strict,
            &results.exact,
            &results.partial,
            &results.ent_type,
        ]
    }

    /// Asserts that two scores agree to within 0.01.
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < 0.01);
    }

    #[test]
    fn test_exact_match() {
        let results = evaluate_ner(
            &[span(EntityType::Person, 0, 5)],
            &[span(EntityType::Person, 0, 5)],
        );
        for counts in all_schemas(&results).iter() {
            assert_eq!(counts.correct, 1);
        }
    }

    #[test]
    fn test_wrong_type() {
        let results = evaluate_ner(
            &[span(EntityType::Person, 0, 5)],
            &[span(EntityType::Organization, 0, 5)],
        );
        assert_eq!(results.strict.incorrect, 1);
        assert_eq!(results.exact.correct, 1);
        assert_eq!(results.partial.correct, 1);
        assert_eq!(results.ent_type.incorrect, 1);
    }

    #[test]
    fn test_partial_boundary() {
        let results = evaluate_ner(
            &[span(EntityType::Person, 0, 10)],
            &[span(EntityType::Person, 0, 8)],
        );
        assert_eq!(results.strict.incorrect, 1);
        assert_eq!(results.exact.incorrect, 1);
        assert_eq!(results.partial.partial, 1);
        assert_eq!(results.ent_type.correct, 1);
    }

    #[test]
    fn test_missing_entity() {
        let results = evaluate_ner(&[span(EntityType::Person, 0, 5)], &[]);
        for counts in all_schemas(&results).iter() {
            assert_eq!(counts.missed, 1);
        }
    }

    #[test]
    fn test_spurious_entity() {
        let results = evaluate_ner(&[], &[span(EntityType::Person, 0, 5)]);
        for counts in all_schemas(&results).iter() {
            assert_eq!(counts.spurious, 1);
        }
    }

    #[test]
    fn test_precision_recall_f1() {
        let gold = [
            span(EntityType::Person, 0, 5),
            span(EntityType::Location, 10, 15),
        ];
        let pred = [
            span(EntityType::Person, 0, 5),
            span(EntityType::Organization, 20, 25),
        ];
        let strict = evaluate_ner(&gold, &pred).strict;
        assert_eq!(strict.correct, 1);
        assert_eq!(strict.spurious, 1);
        assert_eq!(strict.missed, 1);
        assert_close(strict.precision_exact(), 0.5);
        assert_close(strict.recall_exact(), 0.5);
        assert_close(strict.f1_exact(), 0.5);
    }

    #[test]
    fn test_markdown_output() {
        let md = evaluate_ner(
            &[span(EntityType::Person, 0, 5)],
            &[span(EntityType::Person, 0, 5)],
        )
        .to_markdown();
        assert!(md.contains("Strict"));
        assert!(md.contains("100.0%"));
    }

    #[test]
    fn test_per_type_breakdown() {
        let gold = [
            span(EntityType::Person, 0, 5),
            span(EntityType::Person, 10, 15),
            span(EntityType::Location, 20, 25),
        ];
        let pred = [
            span(EntityType::Person, 0, 5),
            span(EntityType::Organization, 10, 15),
            span(EntityType::Location, 20, 25),
        ];
        let by_type = evaluate_ner(&gold, &pred).by_type;
        let person = &by_type["Person"];
        assert_eq!(person.correct, 1);
        assert_eq!(person.incorrect, 1);
        assert_eq!(by_type["Location"].correct, 1);
    }
}