1use crate::error::{SeqError, SeqResult};
22use crate::tagging::bioes::{Span, Tag, extract_spans};
23use std::collections::BTreeMap;
24
/// Precision, recall and F1 together with the raw counts behind them.
///
/// Values built through [`PrfScore::from_counts`] have ratios that are
/// consistent with their counts; the fields are public, so values assembled
/// by hand carry no such guarantee.
#[derive(Clone, Debug, PartialEq)]
pub struct PrfScore {
    /// Number of true positives.
    pub tp: usize,
    /// Number of false positives.
    pub fp: usize,
    /// Number of false negatives (`fn` is a keyword, hence the underscore).
    pub fn_: usize,
    /// Precision: `tp / (tp + fp)` as computed by [`PrfScore::from_counts`].
    pub precision: f64,
    /// Recall: `tp / (tp + fn_)` as computed by [`PrfScore::from_counts`].
    pub recall: f64,
    /// F1: harmonic mean of precision and recall.
    pub f1: f64,
}
43
44impl PrfScore {
45 #[must_use]
47 pub fn from_counts(tp: usize, fp: usize, fn_: usize) -> Self {
48 let precision = if tp + fp == 0 {
49 0.0
50 } else {
51 tp as f64 / (tp + fp) as f64
52 };
53 let recall = if tp + fn_ == 0 {
54 0.0
55 } else {
56 tp as f64 / (tp + fn_) as f64
57 };
58 let f1 = if precision + recall == 0.0 {
59 0.0
60 } else {
61 2.0 * precision * recall / (precision + recall)
62 };
63 Self {
64 tp,
65 fp,
66 fn_,
67 precision,
68 recall,
69 f1,
70 }
71 }
72}
73
74#[derive(Debug, Clone)]
76pub struct SpanF1Report {
77 pub overall: PrfScore,
79 pub per_type: BTreeMap<String, PrfScore>,
81}
82
83fn match_spans(gold: &[Span], pred: &[Span]) -> (usize, usize, usize) {
90 let mut remaining: BTreeMap<(String, usize, usize), usize> = BTreeMap::new();
92 for g in gold {
93 *remaining
94 .entry((g.entity_type.clone(), g.start, g.end))
95 .or_insert(0) += 1;
96 }
97 let mut tp = 0usize;
98 for p in pred {
99 let key = (p.entity_type.clone(), p.start, p.end);
100 if let Some(cnt) = remaining.get_mut(&key) {
101 if *cnt > 0 {
102 *cnt -= 1;
103 tp += 1;
104 }
105 }
106 }
107 (tp, pred.len(), gold.len())
108}
109
110#[must_use]
112pub fn span_f1_from_spans(gold: &[Span], pred: &[Span]) -> SpanF1Report {
113 let (tp, n_pred, n_gold) = match_spans(gold, pred);
114 let overall = PrfScore::from_counts(tp, n_pred - tp, n_gold - tp);
115
116 let mut types: BTreeMap<String, (Vec<Span>, Vec<Span>)> = BTreeMap::new();
118 for g in gold {
119 types
120 .entry(g.entity_type.clone())
121 .or_default()
122 .0
123 .push(g.clone());
124 }
125 for p in pred {
126 types
127 .entry(p.entity_type.clone())
128 .or_default()
129 .1
130 .push(p.clone());
131 }
132 let mut per_type = BTreeMap::new();
133 for (ty, (g, p)) in types {
134 let (t, np, ng) = match_spans(&g, &p);
135 per_type.insert(ty, PrfScore::from_counts(t, np - t, ng - t));
136 }
137
138 SpanF1Report { overall, per_type }
139}
140
141pub fn span_f1(gold: &[Tag], pred: &[Tag]) -> SeqResult<SpanF1Report> {
150 if gold.len() != pred.len() {
151 return Err(SeqError::LengthMismatch {
152 a: gold.len(),
153 b: pred.len(),
154 });
155 }
156 if gold.is_empty() {
157 return Err(SeqError::EmptyInput);
158 }
159 let gspans = extract_spans(gold);
160 let pspans = extract_spans(pred);
161 Ok(span_f1_from_spans(&gspans, &pspans))
162}
163
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tagging::bioes::parse_tags;

    /// Absolute tolerance used when comparing floating-point scores.
    const EPS: f64 = 1e-12;

    /// Parses a tag sequence, panicking if it is rejected.
    fn tags(strs: &[&str]) -> Vec<Tag> {
        parse_tags(strs).expect("parse")
    }

    /// Parses both tag sequences and scores them, panicking on any error.
    fn score(gold: &[&str], pred: &[&str]) -> SpanF1Report {
        let gold_tags = tags(gold);
        let pred_tags = tags(pred);
        span_f1(&gold_tags, &pred_tags).expect("ok")
    }

    /// Asserts that `actual` lies within `EPS` of `expected`.
    #[track_caller]
    fn assert_close(actual: f64, expected: f64) {
        assert!((actual - expected).abs() < EPS);
    }

    /// A single-token `PER` span at position 0.
    fn per_at_zero() -> Span {
        Span {
            entity_type: "PER".into(),
            start: 0,
            end: 0,
        }
    }

    #[test]
    fn prf_from_counts_basic() {
        let prf = PrfScore::from_counts(3, 1, 1);
        assert_eq!(prf.tp, 3);
        assert_close(prf.precision, 0.75);
        assert_close(prf.recall, 0.75);
        assert_close(prf.f1, 0.75);
    }

    #[test]
    fn prf_zero_predictions() {
        let prf = PrfScore::from_counts(0, 0, 4);
        assert_eq!(prf.precision, 0.0);
        assert_eq!(prf.recall, 0.0);
        assert_eq!(prf.f1, 0.0);
    }

    #[test]
    fn perfect_match_is_one() {
        let seq = tags(&["B-PER", "I-PER", "O", "S-LOC"]);
        let overall = span_f1(&seq, &seq).expect("ok").overall;
        assert_close(overall.f1, 1.0);
        assert_close(overall.precision, 1.0);
        assert_close(overall.recall, 1.0);
        assert_eq!(overall.tp, 2);
    }

    #[test]
    fn boundary_error_counts_as_miss() {
        let overall = score(&["B-PER", "I-PER", "O"], &["B-PER", "O", "O"]).overall;
        assert_eq!(overall.tp, 0);
        assert_eq!(overall.fp, 1);
        assert_eq!(overall.fn_, 1);
        assert_eq!(overall.f1, 0.0);
    }

    #[test]
    fn type_error_counts_as_miss() {
        let overall = score(&["S-PER"], &["S-LOC"]).overall;
        assert_eq!(overall.tp, 0);
        assert_eq!(overall.fp, 1);
        assert_eq!(overall.fn_, 1);
    }

    #[test]
    fn partial_credit_micro_average() {
        let overall = score(
            &["B-PER", "I-PER", "O", "S-LOC", "O"],
            &["B-PER", "I-PER", "O", "O", "S-LOC"],
        )
        .overall;
        assert_eq!(overall.tp, 1);
        assert_eq!(overall.fp, 1);
        assert_eq!(overall.fn_, 1);
        assert_close(overall.precision, 0.5);
        assert_close(overall.recall, 0.5);
        assert_close(overall.f1, 0.5);
    }

    #[test]
    fn per_type_breakdown() {
        let by_type = score(
            &["B-PER", "I-PER", "S-LOC", "S-ORG"],
            &["B-PER", "I-PER", "S-LOC", "O"],
        )
        .per_type;
        assert_close(by_type["PER"].f1, 1.0);
        assert_close(by_type["LOC"].f1, 1.0);
        assert_eq!(by_type["ORG"].tp, 0);
        assert_eq!(by_type["ORG"].fn_, 1);
    }

    #[test]
    fn length_mismatch_errors() {
        let longer = tags(&["O", "O"]);
        let shorter = tags(&["O"]);
        let outcome = span_f1(&longer, &shorter);
        assert!(matches!(outcome, Err(SeqError::LengthMismatch { .. })));
    }

    #[test]
    fn empty_errors() {
        let outcome = span_f1(&[], &[]);
        assert!(matches!(outcome, Err(SeqError::EmptyInput)));
    }

    #[test]
    fn all_outside_gives_zero_spans() {
        let overall = score(&["O", "O", "O"], &["O", "O", "O"]).overall;
        assert_eq!(overall.tp, 0);
        assert_eq!(overall.fp, 0);
        assert_eq!(overall.fn_, 0);
        assert_eq!(overall.f1, 0.0);
    }

    #[test]
    fn precision_recall_can_differ() {
        let overall = score(&["S-PER", "O", "O"], &["S-PER", "S-LOC", "O"]).overall;
        assert_eq!(overall.tp, 1);
        assert_eq!(overall.fp, 1);
        assert_eq!(overall.fn_, 0);
        assert_close(overall.precision, 0.5);
        assert_close(overall.recall, 1.0);
    }

    #[test]
    fn duplicate_predictions_only_one_tp() {
        let gold = [per_at_zero()];
        let pred = [per_at_zero(), per_at_zero()];
        let overall = span_f1_from_spans(&gold, &pred).overall;
        assert_eq!(overall.tp, 1);
        assert_eq!(overall.fp, 1);
    }

    #[test]
    fn from_spans_matches_tag_path() {
        let gold = tags(&["B-PER", "E-PER", "O"]);
        let pred = tags(&["B-PER", "E-PER", "S-LOC"]);
        let via_tags = span_f1(&gold, &pred).expect("ok");
        let via_spans = span_f1_from_spans(&extract_spans(&gold), &extract_spans(&pred));
        assert_eq!(via_tags.overall, via_spans.overall);
    }
}