Skip to main content

oxicuda_seq/tagging/
span_f1.rs

1//! Span-based (entity-level) precision / recall / F1 — the `seqeval` metric.
2//!
3//! Token-level accuracy is a poor measure for sequence labelling: a model that
4//! gets every `O` right but mangles entity boundaries can still score highly.
5//! The standard NER / chunking metric (CoNLL-2000/2003, implemented by the
6//! `seqeval` library) instead compares **entity spans**: a predicted span is a
7//! true positive only when an identical span — same entity type **and** same
8//! start/end boundaries — appears in the gold sequence.
9//!
10//! Given gold tags `y` and predicted tags `ŷ`, spans are extracted with
11//! [`crate::tagging::bioes::extract_spans`] and matched exactly:
12//!
13//! ```text
14//! precision = |gold ∩ pred| / |pred|
15//! recall    = |gold ∩ pred| / |gold|
16//! f1        = 2 P R / (P + R)
17//! ```
18//!
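//! Here `∩` is a multiset intersection: duplicated spans each need their own
//! match. Any ratio whose denominator is zero is reported as `0.0`.
//!
//! Per-type breakdowns and a micro-averaged total are provided.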
20
21use crate::error::{SeqError, SeqResult};
22use crate::tagging::bioes::{Span, Tag, extract_spans};
23use std::collections::BTreeMap;
24
25// ─── Counts ──────────────────────────────────────────────────────────────────
26
/// Span-level scores bundled with the raw counts they were derived from.
#[derive(Debug, Clone, PartialEq)]
pub struct PrfScore {
    /// Number of predicted spans that exactly match a gold span.
    pub tp: usize,
    /// Number of predicted spans left without a gold match.
    pub fp: usize,
    /// Number of gold spans left without a predicted match.
    pub fn_: usize,
    /// `tp / (tp + fp)`, or `0.0` when nothing was predicted.
    pub precision: f64,
    /// `tp / (tp + fn_)`, or `0.0` when there are no gold spans.
    pub recall: f64,
    /// Harmonic mean of precision and recall, or `0.0` when both are zero.
    pub f1: f64,
}

impl PrfScore {
    /// Derive precision, recall and F1 from raw TP / FP / FN counts.
    ///
    /// Any ratio with a zero denominator is reported as `0.0` rather than NaN.
    #[must_use]
    pub fn from_counts(tp: usize, fp: usize, fn_: usize) -> Self {
        /// `num / den` as `f64`, with `0.0` standing in for division by zero.
        fn ratio(num: usize, den: usize) -> f64 {
            match den {
                0 => 0.0,
                d => num as f64 / d as f64,
            }
        }

        let precision = ratio(tp, tp + fp);
        let recall = ratio(tp, tp + fn_);
        let pr_sum = precision + recall;
        let f1 = if pr_sum == 0.0 {
            0.0
        } else {
            2.0 * precision * recall / pr_sum
        };

        Self {
            tp,
            fp,
            fn_,
            precision,
            recall,
            f1,
        }
    }
}
73
74/// Full span-F1 report: a micro-averaged overall score plus per-type scores.
75#[derive(Debug, Clone)]
76pub struct SpanF1Report {
77    /// Micro-averaged score over all entity types.
78    pub overall: PrfScore,
79    /// Per-entity-type scores, keyed by type name (sorted).
80    pub per_type: BTreeMap<String, PrfScore>,
81}
82
83// ─── Span multiset matching ──────────────────────────────────────────────────
84
85/// Count exactly-matched spans (true positives) between two span lists.
86///
87/// Spans are matched as a multiset: each gold span can satisfy at most one
88/// predicted span.  Returns `(tp, n_pred, n_gold)`.
89fn match_spans(gold: &[Span], pred: &[Span]) -> (usize, usize, usize) {
90    // Multiset of gold spans by value.
91    let mut remaining: BTreeMap<(String, usize, usize), usize> = BTreeMap::new();
92    for g in gold {
93        *remaining
94            .entry((g.entity_type.clone(), g.start, g.end))
95            .or_insert(0) += 1;
96    }
97    let mut tp = 0usize;
98    for p in pred {
99        let key = (p.entity_type.clone(), p.start, p.end);
100        if let Some(cnt) = remaining.get_mut(&key) {
101            if *cnt > 0 {
102                *cnt -= 1;
103                tp += 1;
104            }
105        }
106    }
107    (tp, pred.len(), gold.len())
108}
109
110/// Compute the micro-averaged span-F1 from pre-extracted gold/predicted spans.
111#[must_use]
112pub fn span_f1_from_spans(gold: &[Span], pred: &[Span]) -> SpanF1Report {
113    let (tp, n_pred, n_gold) = match_spans(gold, pred);
114    let overall = PrfScore::from_counts(tp, n_pred - tp, n_gold - tp);
115
116    // Per-type.
117    let mut types: BTreeMap<String, (Vec<Span>, Vec<Span>)> = BTreeMap::new();
118    for g in gold {
119        types
120            .entry(g.entity_type.clone())
121            .or_default()
122            .0
123            .push(g.clone());
124    }
125    for p in pred {
126        types
127            .entry(p.entity_type.clone())
128            .or_default()
129            .1
130            .push(p.clone());
131    }
132    let mut per_type = BTreeMap::new();
133    for (ty, (g, p)) in types {
134        let (t, np, ng) = match_spans(&g, &p);
135        per_type.insert(ty, PrfScore::from_counts(t, np - t, ng - t));
136    }
137
138    SpanF1Report { overall, per_type }
139}
140
141/// Compute span-F1 directly from gold and predicted **tag** sequences.
142///
143/// Both sequences must have equal length (one tag per token).
144///
145/// # Errors
146///
147/// * [`SeqError::LengthMismatch`] — if `gold.len() ≠ pred.len()`.
148/// * [`SeqError::EmptyInput`]     — if the sequences are empty.
149pub fn span_f1(gold: &[Tag], pred: &[Tag]) -> SeqResult<SpanF1Report> {
150    if gold.len() != pred.len() {
151        return Err(SeqError::LengthMismatch {
152            a: gold.len(),
153            b: pred.len(),
154        });
155    }
156    if gold.is_empty() {
157        return Err(SeqError::EmptyInput);
158    }
159    let gspans = extract_spans(gold);
160    let pspans = extract_spans(pred);
161    Ok(span_f1_from_spans(&gspans, &pspans))
162}
163
164// ─── Tests ───────────────────────────────────────────────────────────────────
165
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tagging::bioes::parse_tags;

    /// Parse string tags, panicking on malformed test input.
    fn tags(strs: &[&str]) -> Vec<Tag> {
        parse_tags(strs).expect("parse")
    }

    /// Build a span by hand for the span-level entry point.
    fn span(entity_type: &str, start: usize, end: usize) -> Span {
        Span {
            entity_type: entity_type.into(),
            start,
            end,
        }
    }

    #[test]
    fn prf_from_counts_basic() {
        let s = PrfScore::from_counts(3, 1, 1);
        assert_eq!(s.tp, 3);
        assert!((s.precision - 0.75).abs() < 1e-12);
        assert!((s.recall - 0.75).abs() < 1e-12);
        assert!((s.f1 - 0.75).abs() < 1e-12);
    }

    #[test]
    fn prf_zero_predictions() {
        let s = PrfScore::from_counts(0, 0, 4);
        assert_eq!(s.precision, 0.0);
        assert_eq!(s.recall, 0.0);
        assert_eq!(s.f1, 0.0);
    }

    #[test]
    fn perfect_match_is_one() {
        let g = tags(&["B-PER", "I-PER", "O", "S-LOC"]);
        let report = span_f1(&g, &g).expect("ok");
        assert!((report.overall.f1 - 1.0).abs() < 1e-12);
        assert!((report.overall.precision - 1.0).abs() < 1e-12);
        assert!((report.overall.recall - 1.0).abs() < 1e-12);
        assert_eq!(report.overall.tp, 2);
    }

    #[test]
    fn boundary_error_counts_as_miss() {
        // Gold PER span [0..1]; predicted PER span [0..0] — different boundary.
        let g = tags(&["B-PER", "I-PER", "O"]);
        let p = tags(&["B-PER", "O", "O"]);
        let r = span_f1(&g, &p).expect("ok");
        assert_eq!(r.overall.tp, 0); // no exact match
        assert_eq!(r.overall.fp, 1); // predicted PER[0..0]
        assert_eq!(r.overall.fn_, 1); // gold PER[0..1] missed
        assert_eq!(r.overall.f1, 0.0);
    }

    #[test]
    fn type_error_counts_as_miss() {
        // Same boundary, wrong type.
        let g = tags(&["S-PER"]);
        let p = tags(&["S-LOC"]);
        let r = span_f1(&g, &p).expect("ok");
        assert_eq!(r.overall.tp, 0);
        assert_eq!(r.overall.fp, 1);
        assert_eq!(r.overall.fn_, 1);
    }

    #[test]
    fn partial_credit_micro_average() {
        // gold: PER[0..1], LOC[3]; pred: PER[0..1], LOC[4] (LOC boundary wrong)
        let g = tags(&["B-PER", "I-PER", "O", "S-LOC", "O"]);
        let p = tags(&["B-PER", "I-PER", "O", "O", "S-LOC"]);
        let r = span_f1(&g, &p).expect("ok");
        // 1 TP (PER), 1 FP (LOC@4), 1 FN (LOC@3).
        assert_eq!(r.overall.tp, 1);
        assert_eq!(r.overall.fp, 1);
        assert_eq!(r.overall.fn_, 1);
        assert!((r.overall.precision - 0.5).abs() < 1e-12);
        assert!((r.overall.recall - 0.5).abs() < 1e-12);
        assert!((r.overall.f1 - 0.5).abs() < 1e-12);
    }

    #[test]
    fn per_type_breakdown() {
        let g = tags(&["B-PER", "I-PER", "S-LOC", "S-ORG"]);
        let p = tags(&["B-PER", "I-PER", "S-LOC", "O"]);
        let r = span_f1(&g, &p).expect("ok");
        // PER perfect, LOC perfect, ORG missed.
        assert!((r.per_type["PER"].f1 - 1.0).abs() < 1e-12);
        assert!((r.per_type["LOC"].f1 - 1.0).abs() < 1e-12);
        assert_eq!(r.per_type["ORG"].tp, 0);
        assert_eq!(r.per_type["ORG"].fn_, 1);
    }

    #[test]
    fn type_only_in_prediction_is_reported() {
        // LOC never appears in gold, so it must still get a per-type entry
        // made up purely of false positives.
        let g = tags(&["S-PER", "O"]);
        let p = tags(&["S-PER", "S-LOC"]);
        let r = span_f1(&g, &p).expect("ok");
        assert_eq!(r.per_type.len(), 2);
        let loc = &r.per_type["LOC"];
        assert_eq!(loc.tp, 0);
        assert_eq!(loc.fp, 1);
        assert_eq!(loc.fn_, 0);
        assert_eq!(loc.precision, 0.0);
        assert_eq!(loc.recall, 0.0);
        assert_eq!(loc.f1, 0.0);
        assert!((r.per_type["PER"].f1 - 1.0).abs() < 1e-12);
    }

    #[test]
    fn overall_counts_are_sum_of_per_type() {
        // Micro averaging pools counts, so overall TP/FP/FN must equal the
        // per-type totals.
        let g = tags(&["B-PER", "I-PER", "O", "S-LOC", "S-ORG"]);
        let p = tags(&["B-PER", "I-PER", "S-ORG", "O", "S-LOC"]);
        let r = span_f1(&g, &p).expect("ok");
        let summed = r
            .per_type
            .values()
            .fold((0usize, 0usize, 0usize), |(tp, fp, fn_), s| {
                (tp + s.tp, fp + s.fp, fn_ + s.fn_)
            });
        assert_eq!(summed, (r.overall.tp, r.overall.fp, r.overall.fn_));
        assert_eq!(summed, (1, 2, 2));
    }

    #[test]
    fn length_mismatch_errors() {
        let g = tags(&["O", "O"]);
        let p = tags(&["O"]);
        assert!(matches!(
            span_f1(&g, &p),
            Err(SeqError::LengthMismatch { .. })
        ));
    }

    #[test]
    fn empty_errors() {
        assert!(matches!(span_f1(&[], &[]), Err(SeqError::EmptyInput)));
    }

    #[test]
    fn all_outside_gives_zero_spans() {
        let g = tags(&["O", "O", "O"]);
        let p = tags(&["O", "O", "O"]);
        let r = span_f1(&g, &p).expect("ok");
        // No spans at all: P/R/F1 conventionally 0 (seqeval returns 0).
        assert_eq!(r.overall.tp, 0);
        assert_eq!(r.overall.fp, 0);
        assert_eq!(r.overall.fn_, 0);
        assert_eq!(r.overall.f1, 0.0);
    }

    #[test]
    fn precision_recall_can_differ() {
        // gold has 1 span, pred has 2 (one matching, one spurious).
        let g = tags(&["S-PER", "O", "O"]);
        let p = tags(&["S-PER", "S-LOC", "O"]);
        let r = span_f1(&g, &p).expect("ok");
        assert_eq!(r.overall.tp, 1);
        assert_eq!(r.overall.fp, 1);
        assert_eq!(r.overall.fn_, 0);
        assert!((r.overall.precision - 0.5).abs() < 1e-12);
        assert!((r.overall.recall - 1.0).abs() < 1e-12);
    }

    #[test]
    fn duplicate_predictions_only_one_tp() {
        // Two identical predicted spans, one gold: exactly one TP, one FP.
        let g = [Span {
            entity_type: "PER".into(),
            start: 0,
            end: 0,
        }];
        let p = [
            Span {
                entity_type: "PER".into(),
                start: 0,
                end: 0,
            },
            Span {
                entity_type: "PER".into(),
                start: 0,
                end: 0,
            },
        ];
        let r = span_f1_from_spans(&g, &p);
        assert_eq!(r.overall.tp, 1);
        assert_eq!(r.overall.fp, 1);
    }

    #[test]
    fn duplicate_gold_spans_each_need_a_match() {
        // Two identical gold spans and three identical predictions: the
        // multiset matching yields two TPs and one FP, with no misses.
        let g = [span("PER", 0, 0), span("PER", 0, 0)];
        let p = [span("PER", 0, 0), span("PER", 0, 0), span("PER", 0, 0)];
        let r = span_f1_from_spans(&g, &p);
        assert_eq!(r.overall.tp, 2);
        assert_eq!(r.overall.fp, 1);
        assert_eq!(r.overall.fn_, 0);
        assert_eq!(r.per_type["PER"].tp, 2);
    }

    #[test]
    fn from_spans_matches_tag_path() {
        let g = tags(&["B-PER", "E-PER", "O"]);
        let p = tags(&["B-PER", "E-PER", "S-LOC"]);
        let via_tags = span_f1(&g, &p).expect("ok");
        let gs = extract_spans(&g);
        let ps = extract_spans(&p);
        let via_spans = span_f1_from_spans(&gs, &ps);
        assert_eq!(via_tags.overall, via_spans.overall);
    }
}