Skip to main content

scirs2_text/evaluation/
ner.rs

1//! CoNLL-2003 NER evaluation protocol — span-level F1.
2//!
3//! This module implements the *exact match* span evaluation protocol used by
4//! the CoNLL-2003 shared task.  Both span boundaries **and** the entity label
5//! must agree for a prediction to count as a true positive.
6//!
//! # BIO encoding
//!
//! Token labels follow the standard BIO (a.k.a. IOB) scheme:
//! - `B-<TYPE>` — beginning of a named entity of type `<TYPE>`
//! - `I-<TYPE>` — inside/continuation of a named entity of the same type
//! - `O` — outside any entity
//!
//! See [`extract_spans_from_bio`] for how malformed or non-BIO tags are handled.
//!
14//! # Example
15//!
16//! ```rust
17//! use scirs2_text::evaluation::ner::{extract_spans_from_bio, evaluate_ner, NerSpan};
18//!
19//! let tokens = ["John", "Smith", "works", "at", "Google", "."];
20//! let labels = ["B-PER", "I-PER", "O", "O", "B-ORG", "O"];
21//!
22//! let gold = extract_spans_from_bio(
23//!     &tokens.iter().map(|s| *s).collect::<Vec<_>>(),
24//!     &labels.iter().map(|s| *s).collect::<Vec<_>>(),
25//! );
26//! assert_eq!(gold.len(), 2);
27//! ```
28
29use crate::error::TextError;
30use std::collections::HashMap;
31use std::fmt;
32
33// ─── Core data types ──────────────────────────────────────────────────────────
34
/// A named-entity mention covering the half-open token range `start..end`.
///
/// Equality and hashing are derived over all three fields, so two spans are
/// equal only when their start, end and label all coincide.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NerSpan {
    /// Index of the first token in the span (inclusive).
    pub start: usize,
    /// Index one past the last token in the span (exclusive).
    pub end: usize,
    /// Entity type, e.g. `"PER"`, `"ORG"`, `"LOC"` or `"MISC"`.
    pub label: String,
}

impl NerSpan {
    /// Builds a span over tokens `start..end` tagged with `label`.
    ///
    /// No validation of `start`/`end` is performed.
    pub fn new(start: usize, end: usize, label: impl Into<String>) -> Self {
        let label = label.into();
        Self { start, end, label }
    }
}

impl fmt::Display for NerSpan {
    /// Renders the span as `LABEL[start..end)`, e.g. `PER[0..2)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { start, end, label } = self;
        write!(f, "{}[{}..{})", label, start, end)
    }
}
62
/// Precision / recall / F1 for a single entity type, together with the raw
/// counts they were derived from.
#[derive(Debug, Clone)]
pub struct ClassMetrics {
    /// `tp / (tp + fp)`, or `0.0` when `tp + fp == 0`.
    pub precision: f64,
    /// `tp / (tp + fn_)`, or `0.0` when `tp + fn_ == 0`.
    pub recall: f64,
    /// Harmonic mean of precision and recall, or `0.0` when both are zero.
    pub f1: f64,
    /// Gold span count for this type, computed as `tp + fn_`.
    pub support: usize,
    /// True positives.
    pub tp: usize,
    /// False positives.
    pub fp: usize,
    /// False negatives.
    pub fn_: usize,
}

impl ClassMetrics {
    /// Derives every ratio from raw true-positive / false-positive /
    /// false-negative counts. Empty denominators yield `0.0` instead of NaN.
    fn from_counts(tp: usize, fp: usize, fn_: usize) -> Self {
        // Division that maps a zero denominator to 0.0.
        fn ratio(num: usize, den: usize) -> f64 {
            match den {
                0 => 0.0,
                d => num as f64 / d as f64,
            }
        }

        let precision = ratio(tp, tp + fp);
        let recall = ratio(tp, tp + fn_);
        let denom = precision + recall;
        let f1 = if denom != 0.0 {
            2.0 * precision * recall / denom
        } else {
            0.0
        };

        Self {
            precision,
            recall,
            f1,
            support: tp + fn_,
            tp,
            fp,
            fn_,
        }
    }
}
110
111/// Overall NER evaluation result (CoNLL-style).
112#[derive(Debug, Clone)]
113pub struct NerEvaluationResult {
114    /// Micro-averaged precision over all entity spans.
115    pub precision: f64,
116    /// Micro-averaged recall over all entity spans.
117    pub recall: f64,
118    /// Micro-averaged F1 (the primary CoNLL metric).
119    pub f1: f64,
120    /// Per-entity-type metrics.
121    pub per_class: HashMap<String, ClassMetrics>,
122    /// Exact match accuracy at the sentence level (fraction of sequences
123    /// where all predicted spans are correct and no gold spans are missing).
124    pub exact_match: f64,
125}
126
127// ─── BIO span extraction ──────────────────────────────────────────────────────
128
129/// Extract named-entity [`NerSpan`]s from a BIO-tagged token sequence.
130///
131/// Handles:
132/// - Standard `B-TYPE` / `I-TYPE` / `O` labels.
133/// - Implicit entity boundary changes (a new `B-` tag after an `I-` tag of a
134///   different type, or a `B-` of the same type, terminates the current span).
135///
136/// # Errors
137///
138/// Returns an empty `Vec` rather than an error on malformed label sequences
139/// (e.g. a bare `I-TYPE` with no preceding `B-TYPE`) — the continuation token
140/// is treated as a fresh span start.
141pub fn extract_spans_from_bio(tokens: &[&str], labels: &[&str]) -> Vec<NerSpan> {
142    if tokens.len() != labels.len() {
143        return Vec::new();
144    }
145
146    let mut spans: Vec<NerSpan> = Vec::new();
147    let mut current: Option<(usize, String)> = None; // (start, label)
148
149    for (i, label) in labels.iter().enumerate() {
150        let tag = *label;
151
152        if tag == "O" || tag.is_empty() {
153            // Close any open span.
154            if let Some((start, lbl)) = current.take() {
155                spans.push(NerSpan::new(start, i, lbl));
156            }
157        } else if let Some(rest) = tag.strip_prefix("B-") {
158            // New entity begins — close previous span.
159            if let Some((start, lbl)) = current.take() {
160                spans.push(NerSpan::new(start, i, lbl));
161            }
162            current = Some((i, rest.to_string()));
163        } else if let Some(rest) = tag.strip_prefix("I-") {
164            // Continuation — only valid if entity type matches.
165            match &current {
166                Some((_, lbl)) if lbl == rest => {
167                    // Same type: continue
168                }
169                _ => {
170                    // Type mismatch or no open span: close old, start new.
171                    if let Some((start, lbl)) = current.take() {
172                        spans.push(NerSpan::new(start, i, lbl));
173                    }
174                    current = Some((i, rest.to_string()));
175                }
176            }
177        } else {
178            // Unknown scheme — treat the whole tag as a B-tag.
179            if let Some((start, lbl)) = current.take() {
180                spans.push(NerSpan::new(start, i, lbl));
181            }
182            current = Some((i, tag.to_string()));
183        }
184    }
185
186    // Close trailing span.
187    if let Some((start, lbl)) = current.take() {
188        spans.push(NerSpan::new(start, labels.len(), lbl));
189    }
190
191    spans
192}
193
194// ─── Evaluation ──────────────────────────────────────────────────────────────
195
196/// Compute CoNLL-2003 NER evaluation metrics.
197///
198/// Evaluation is performed at the **span level**: a prediction is a true
199/// positive only when both token boundaries **and** the entity type label
200/// exactly match a gold span.
201///
202/// Returns micro-averaged precision/recall/F1 plus per-class breakdowns.
203pub fn evaluate_ner(predictions: &[Vec<NerSpan>], gold: &[Vec<NerSpan>]) -> NerEvaluationResult {
204    if predictions.len() != gold.len() {
205        // Return zeroed result for mismatched batch sizes.
206        return NerEvaluationResult {
207            precision: 0.0,
208            recall: 0.0,
209            f1: 0.0,
210            per_class: HashMap::new(),
211            exact_match: 0.0,
212        };
213    }
214
215    // Per-class counters: label → (tp, fp, fn_)
216    let mut class_tp: HashMap<String, usize> = HashMap::new();
217    let mut class_fp: HashMap<String, usize> = HashMap::new();
218    let mut class_fn: HashMap<String, usize> = HashMap::new();
219
220    let mut exact_match_count = 0usize;
221
222    for (pred_spans, gold_spans) in predictions.iter().zip(gold.iter()) {
223        // Convert spans to HashSet for O(1) lookup.
224        use std::collections::HashSet;
225        let gold_set: HashSet<&NerSpan> = gold_spans.iter().collect();
226        let pred_set: HashSet<&NerSpan> = pred_spans.iter().collect();
227
228        for span in pred_spans {
229            let label = &span.label;
230            if gold_set.contains(span) {
231                *class_tp.entry(label.clone()).or_insert(0) += 1;
232            } else {
233                *class_fp.entry(label.clone()).or_insert(0) += 1;
234            }
235        }
236
237        for span in gold_spans {
238            let label = &span.label;
239            if !pred_set.contains(span) {
240                *class_fn.entry(label.clone()).or_insert(0) += 1;
241            }
242        }
243
244        // Exact match: predicted set equals gold set.
245        if pred_set == gold_set {
246            exact_match_count += 1;
247        }
248    }
249
250    // Collect all known labels.
251    let mut all_labels: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
252    for k in class_tp
253        .keys()
254        .chain(class_fp.keys())
255        .chain(class_fn.keys())
256    {
257        all_labels.insert(k.clone());
258    }
259
260    let mut per_class: HashMap<String, ClassMetrics> = HashMap::new();
261    let mut total_tp = 0usize;
262    let mut total_fp = 0usize;
263    let mut total_fn = 0usize;
264
265    for label in &all_labels {
266        let tp = *class_tp.get(label).unwrap_or(&0);
267        let fp = *class_fp.get(label).unwrap_or(&0);
268        let fn_ = *class_fn.get(label).unwrap_or(&0);
269
270        per_class.insert(label.clone(), ClassMetrics::from_counts(tp, fp, fn_));
271
272        total_tp += tp;
273        total_fp += fp;
274        total_fn += fn_;
275    }
276
277    let overall = ClassMetrics::from_counts(total_tp, total_fp, total_fn);
278    let n_sequences = predictions.len();
279    let exact_match = if n_sequences == 0 {
280        0.0
281    } else {
282        exact_match_count as f64 / n_sequences as f64
283    };
284
285    NerEvaluationResult {
286        precision: overall.precision,
287        recall: overall.recall,
288        f1: overall.f1,
289        per_class,
290        exact_match,
291    }
292}
293
294// ─── CoNLL-style report formatting ───────────────────────────────────────────
295
296/// Format an [`NerEvaluationResult`] as a `conlleval`-style text report.
297///
298/// The output closely follows the style produced by the official `conlleval.pl`
299/// Perl script distributed with the CoNLL-2003 shared task:
300///
301/// ```text
302/// processed N tokens; found: K phrases; correct: J
303/// accuracy: XX.XX%; precision: XX.XX%; recall: XX.XX%; FB1: XX.XX
304///        LOC: precision XX.XX%; recall XX.XX%; FB1: XX.XX  N
305/// ```
306pub fn conll_format_report(result: &NerEvaluationResult) -> String {
307    let mut lines: Vec<String> = Vec::new();
308
309    // Compute aggregate counts from per-class metrics.
310    let total_gold: usize = result.per_class.values().map(|m| m.support).sum();
311    let total_pred: usize = result.per_class.values().map(|m| m.tp + m.fp).sum();
312    let total_correct: usize = result.per_class.values().map(|m| m.tp).sum();
313
314    lines.push(format!(
315        "processed N tokens with {} phrases; found: {} phrases; correct: {}",
316        total_gold, total_pred, total_correct,
317    ));
318
319    lines.push(format!(
320        "accuracy:  N/A; precision:  {:6.2}%; recall:  {:6.2}%; FB1:  {:6.2}",
321        result.precision * 100.0,
322        result.recall * 100.0,
323        result.f1 * 100.0,
324    ));
325
326    // Per-class rows sorted alphabetically.
327    let mut labels: Vec<&String> = result.per_class.keys().collect();
328    labels.sort();
329
330    for label in labels {
331        let m = &result.per_class[label];
332        lines.push(format!(
333            "  {:>8}: precision:  {:6.2}%; recall:  {:6.2}%; FB1:  {:6.2}  {}",
334            label,
335            m.precision * 100.0,
336            m.recall * 100.0,
337            m.f1 * 100.0,
338            m.support,
339        ));
340    }
341
342    lines.join("\n")
343}
344
345/// Convenience wrapper that returns an error instead of silently degrading when
346/// input lengths are mismatched.
347pub fn evaluate_ner_checked(
348    predictions: &[Vec<NerSpan>],
349    gold: &[Vec<NerSpan>],
350) -> Result<NerEvaluationResult, TextError> {
351    if predictions.len() != gold.len() {
352        return Err(TextError::InvalidInput(format!(
353            "Prediction batch size ({}) != gold batch size ({})",
354            predictions.len(),
355            gold.len()
356        )));
357    }
358    Ok(evaluate_ner(predictions, gold))
359}
360
361// ─── Tests ────────────────────────────────────────────────────────────────────
362
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that two floats agree to within `1e-9`.
    fn assert_close(actual: f64, expected: f64) {
        assert!(
            (actual - expected).abs() < 1e-9,
            "expected {}, got {}",
            expected,
            actual
        );
    }

    // ── BIO extraction ────────────────────────────────────────────────────

    #[test]
    fn test_bio_extraction_basic() {
        let tokens = ["John", "Smith", "works", "at", "Google", "."];
        let labels = ["B-PER", "I-PER", "O", "O", "B-ORG", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(
            spans,
            vec![NerSpan::new(0, 2, "PER"), NerSpan::new(4, 5, "ORG")]
        );
    }

    #[test]
    fn test_bio_extraction_single_token_entity() {
        let tokens = ["Paris", "is", "beautiful"];
        let labels = ["B-LOC", "O", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans, vec![NerSpan::new(0, 1, "LOC")]);
    }

    #[test]
    fn test_bio_extraction_all_outside() {
        let tokens = ["the", "cat", "sat"];
        let labels = ["O", "O", "O"];
        assert!(extract_spans_from_bio(&tokens, &labels).is_empty());
    }

    #[test]
    fn test_bio_extraction_trailing_entity() {
        // An entity that runs to the end of the sequence must still be closed.
        let tokens = ["Visit", "New", "York"];
        let labels = ["O", "B-LOC", "I-LOC"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans, vec![NerSpan::new(1, 3, "LOC")]);
    }

    #[test]
    fn test_bio_extraction_nested_b_tags() {
        // Two adjacent ORG entities: the second `B-ORG` must terminate the
        // first span even though the entity type does not change.
        let tokens = ["Apple", "Inc", "Google", "LLC"];
        let labels = ["B-ORG", "I-ORG", "B-ORG", "I-ORG"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(
            spans,
            vec![NerSpan::new(0, 2, "ORG"), NerSpan::new(2, 4, "ORG")]
        );
    }

    #[test]
    fn test_bio_extraction_type_change() {
        // `I-ORG` right after `B-PER`: the type differs, so the PER span is
        // closed and an implicit ORG span starts.
        let tokens = ["John", "Google"];
        let labels = ["B-PER", "I-ORG"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(
            spans,
            vec![NerSpan::new(0, 1, "PER"), NerSpan::new(1, 2, "ORG")]
        );
    }

    #[test]
    fn test_bio_extraction_bare_inside_starts_span() {
        // An `I-` tag with no open span starts a new span.
        let tokens = ["in", "New", "York", "."];
        let labels = ["O", "I-LOC", "I-LOC", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans, vec![NerSpan::new(1, 3, "LOC")]);
    }

    #[test]
    fn test_bio_extraction_mismatched_lengths() {
        let tokens = ["John"];
        let labels = ["B-PER", "O"];
        // Graceful degradation: no spans rather than a panic.
        assert!(extract_spans_from_bio(&tokens, &labels).is_empty());
    }

    #[test]
    fn test_bio_extraction_multiple_types() {
        let tokens = ["The", "EU", "said", "Angela", "Merkel", "from", "Germany"];
        let labels = ["O", "B-ORG", "O", "B-PER", "I-PER", "O", "B-LOC"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(
            spans,
            vec![
                NerSpan::new(1, 2, "ORG"),
                NerSpan::new(3, 5, "PER"),
                NerSpan::new(6, 7, "LOC"),
            ]
        );
    }

    #[test]
    fn test_ner_span_display() {
        assert_eq!(NerSpan::new(0, 2, "PER").to_string(), "PER[0..2)");
    }

    // ── NER evaluation ────────────────────────────────────────────────────

    #[test]
    fn test_ner_eval_perfect() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER"), NerSpan::new(4, 5, "ORG")]];
        let pred = gold.clone();

        let result = evaluate_ner(&pred, &gold);
        assert_close(result.precision, 1.0);
        assert_close(result.recall, 1.0);
        assert_close(result.f1, 1.0);
        assert_close(result.exact_match, 1.0);
    }

    #[test]
    fn test_ner_eval_no_predictions() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred: Vec<Vec<NerSpan>> = vec![Vec::new()];

        let result = evaluate_ner(&pred, &gold);
        assert_close(result.precision, 0.0);
        assert_close(result.recall, 0.0);
        assert_close(result.f1, 0.0);
    }

    #[test]
    fn test_ner_eval_partial_match() {
        // Wrong end boundary: the prediction is an FP and the gold span an FN.
        let gold = vec![vec![NerSpan::new(0, 3, "PER")]];
        let pred = vec![vec![NerSpan::new(0, 2, "PER")]];

        let result = evaluate_ner(&pred, &gold);
        // 0 TP, 1 FP, 1 FN → precision = recall = F1 = 0.
        assert_close(result.precision, 0.0);
        assert_close(result.recall, 0.0);
        assert_close(result.f1, 0.0);
    }

    #[test]
    fn test_ner_eval_wrong_label() {
        // Same span boundaries but wrong label.
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred = vec![vec![NerSpan::new(0, 2, "ORG")]];

        let result = evaluate_ner(&pred, &gold);
        assert_close(result.f1, 0.0);
    }

    #[test]
    fn test_ner_eval_per_class_metrics() {
        let gold = vec![vec![
            NerSpan::new(0, 2, "PER"),
            NerSpan::new(3, 5, "ORG"),
            NerSpan::new(6, 8, "LOC"),
        ]];
        let pred = vec![vec![
            NerSpan::new(0, 2, "PER"), // TP for PER
            // ORG at 3..5 missed → FN
            NerSpan::new(6, 8, "LOC"),  // TP for LOC
            NerSpan::new(9, 11, "ORG"), // FP for ORG
        ]];

        let result = evaluate_ner(&pred, &gold);

        let per = result.per_class.get("PER").expect("PER class");
        assert_eq!((per.tp, per.fp, per.fn_), (1, 0, 0));

        let org = result.per_class.get("ORG").expect("ORG class");
        // 0 TP, 1 FP (9..11), 1 FN (3..5); one ORG gold span.
        assert_eq!((org.tp, org.fp, org.fn_), (0, 1, 1));
        assert_eq!(org.support, 1);

        let loc = result.per_class.get("LOC").expect("LOC class");
        assert_eq!((loc.tp, loc.fp, loc.fn_), (1, 0, 0));
    }

    #[test]
    fn test_ner_eval_multiple_sequences() {
        let gold = vec![
            vec![NerSpan::new(0, 1, "PER")],
            vec![NerSpan::new(0, 1, "ORG")],
        ];
        let pred = vec![
            vec![NerSpan::new(0, 1, "PER")], // correct
            vec![NerSpan::new(0, 1, "PER")], // wrong label
        ];

        let result = evaluate_ner(&pred, &gold);
        // Seq 1: PER TP. Seq 2: predicted PER is an FP, gold ORG is an FN.
        // Totals TP=1, FP=1, FN=1 → precision = recall = 0.5.
        assert_close(result.precision, 0.5);
        assert_close(result.recall, 0.5);
    }

    #[test]
    fn test_ner_eval_exact_match_rate() {
        let gold = vec![
            vec![NerSpan::new(0, 1, "PER")],
            vec![NerSpan::new(0, 1, "ORG")],
        ];
        let pred = vec![
            vec![NerSpan::new(0, 1, "PER")], // exact
            vec![NerSpan::new(0, 1, "PER")], // not exact
        ];

        let result = evaluate_ner(&pred, &gold);
        assert_close(result.exact_match, 0.5);
    }

    #[test]
    fn test_ner_eval_batch_size_mismatch_is_zeroed() {
        let gold = vec![vec![NerSpan::new(0, 1, "PER")]];
        let pred: Vec<Vec<NerSpan>> = Vec::new();

        let result = evaluate_ner(&pred, &gold);
        assert_close(result.precision, 0.0);
        assert_close(result.recall, 0.0);
        assert_close(result.f1, 0.0);
        assert_close(result.exact_match, 0.0);
        assert!(result.per_class.is_empty());
    }

    #[test]
    fn test_ner_eval_empty_batch() {
        let empty: Vec<Vec<NerSpan>> = Vec::new();

        let result = evaluate_ner(&empty, &empty);
        assert_close(result.f1, 0.0);
        assert_close(result.exact_match, 0.0);
        assert!(result.per_class.is_empty());
    }

    #[test]
    fn test_conll_report_format() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred = gold.clone();
        let result = evaluate_ner(&pred, &gold);
        let report = conll_format_report(&result);

        assert!(report.starts_with(
            "processed N tokens with 1 phrases; found: 1 phrases; correct: 1"
        ));
        assert!(report.contains("precision"));
        assert!(report.contains("recall"));
        assert!(report.contains("FB1"));
        assert!(report.contains("PER"));
        assert!(report.contains("100.00"));
    }

    #[test]
    fn test_ner_checked_length_mismatch() {
        let gold = vec![vec![NerSpan::new(0, 1, "PER")]];
        let pred: Vec<Vec<NerSpan>> = Vec::new();
        assert!(evaluate_ner_checked(&pred, &gold).is_err());
    }
}