use crate::error::TextError;
use std::collections::{HashMap, HashSet};
use std::fmt;
32
/// A labelled entity span over a token sequence.
///
/// The span covers the half-open token range `start..end`: `start` is the
/// index of the first token and `end` is one past the last token.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct NerSpan {
    /// Index of the first token in the span (inclusive).
    pub start: usize,
    /// Index one past the last token in the span (exclusive).
    pub end: usize,
    /// Entity type of the span, e.g. `"PER"` or `"ORG"`.
    pub label: String,
}

impl NerSpan {
    /// Builds a span covering tokens `start..end` with the given entity label.
    ///
    /// No validation is performed; callers are responsible for passing
    /// `start <= end`.
    pub fn new(start: usize, end: usize, label: impl Into<String>) -> Self {
        Self {
            start,
            end,
            label: label.into(),
        }
    }
}

impl fmt::Display for NerSpan {
    /// Renders the span as `LABEL[start..end)`; the mismatched brackets are
    /// deliberate half-open interval notation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        write!(f, "{}[{}..{})", self.label, self.start, self.end)
    }
}
62
/// Span-level evaluation metrics for a single entity class (or for all
/// classes pooled together).
#[derive(Debug, Clone)]
pub struct ClassMetrics {
    /// `tp / (tp + fp)`, or `0.0` when nothing was predicted.
    pub precision: f64,
    /// `tp / (tp + fn_)`, or `0.0` when there are no gold spans.
    pub recall: f64,
    /// Harmonic mean of precision and recall, or `0.0` when both are zero.
    pub f1: f64,
    /// Number of gold spans, i.e. `tp + fn_`.
    pub support: usize,
    /// True positives: predicted spans that match a gold span.
    pub tp: usize,
    /// False positives: predicted spans with no matching gold span.
    pub fp: usize,
    /// False negatives: gold spans that were not predicted.
    pub fn_: usize,
}

impl ClassMetrics {
    /// Derives precision, recall, F1 and support from raw counts.
    ///
    /// Any ratio whose denominator is zero is reported as `0.0` rather than
    /// `NaN`.
    fn from_counts(tp: usize, fp: usize, fn_: usize) -> Self {
        /// `hits / total` as a float, with `0.0` for an empty denominator.
        fn ratio(hits: usize, total: usize) -> f64 {
            if total == 0 {
                0.0
            } else {
                hits as f64 / total as f64
            }
        }

        let precision = ratio(tp, tp + fp);
        let recall = ratio(tp, tp + fn_);
        // Precision and recall are both zero exactly when `tp == 0`, so the
        // integer test guards the division without comparing floats.
        let f1 = if tp == 0 {
            0.0
        } else {
            2.0 * precision * recall / (precision + recall)
        };

        Self {
            precision,
            recall,
            f1,
            support: tp + fn_,
            tp,
            fp,
            fn_,
        }
    }
}
110
111#[derive(Debug, Clone)]
113pub struct NerEvaluationResult {
114 pub precision: f64,
116 pub recall: f64,
118 pub f1: f64,
120 pub per_class: HashMap<String, ClassMetrics>,
122 pub exact_match: f64,
125}
126
127pub fn extract_spans_from_bio(tokens: &[&str], labels: &[&str]) -> Vec<NerSpan> {
142 if tokens.len() != labels.len() {
143 return Vec::new();
144 }
145
146 let mut spans: Vec<NerSpan> = Vec::new();
147 let mut current: Option<(usize, String)> = None; for (i, label) in labels.iter().enumerate() {
150 let tag = *label;
151
152 if tag == "O" || tag.is_empty() {
153 if let Some((start, lbl)) = current.take() {
155 spans.push(NerSpan::new(start, i, lbl));
156 }
157 } else if let Some(rest) = tag.strip_prefix("B-") {
158 if let Some((start, lbl)) = current.take() {
160 spans.push(NerSpan::new(start, i, lbl));
161 }
162 current = Some((i, rest.to_string()));
163 } else if let Some(rest) = tag.strip_prefix("I-") {
164 match ¤t {
166 Some((_, lbl)) if lbl == rest => {
167 }
169 _ => {
170 if let Some((start, lbl)) = current.take() {
172 spans.push(NerSpan::new(start, i, lbl));
173 }
174 current = Some((i, rest.to_string()));
175 }
176 }
177 } else {
178 if let Some((start, lbl)) = current.take() {
180 spans.push(NerSpan::new(start, i, lbl));
181 }
182 current = Some((i, tag.to_string()));
183 }
184 }
185
186 if let Some((start, lbl)) = current.take() {
188 spans.push(NerSpan::new(start, labels.len(), lbl));
189 }
190
191 spans
192}
193
194pub fn evaluate_ner(predictions: &[Vec<NerSpan>], gold: &[Vec<NerSpan>]) -> NerEvaluationResult {
204 if predictions.len() != gold.len() {
205 return NerEvaluationResult {
207 precision: 0.0,
208 recall: 0.0,
209 f1: 0.0,
210 per_class: HashMap::new(),
211 exact_match: 0.0,
212 };
213 }
214
215 let mut class_tp: HashMap<String, usize> = HashMap::new();
217 let mut class_fp: HashMap<String, usize> = HashMap::new();
218 let mut class_fn: HashMap<String, usize> = HashMap::new();
219
220 let mut exact_match_count = 0usize;
221
222 for (pred_spans, gold_spans) in predictions.iter().zip(gold.iter()) {
223 use std::collections::HashSet;
225 let gold_set: HashSet<&NerSpan> = gold_spans.iter().collect();
226 let pred_set: HashSet<&NerSpan> = pred_spans.iter().collect();
227
228 for span in pred_spans {
229 let label = &span.label;
230 if gold_set.contains(span) {
231 *class_tp.entry(label.clone()).or_insert(0) += 1;
232 } else {
233 *class_fp.entry(label.clone()).or_insert(0) += 1;
234 }
235 }
236
237 for span in gold_spans {
238 let label = &span.label;
239 if !pred_set.contains(span) {
240 *class_fn.entry(label.clone()).or_insert(0) += 1;
241 }
242 }
243
244 if pred_set == gold_set {
246 exact_match_count += 1;
247 }
248 }
249
250 let mut all_labels: std::collections::BTreeSet<String> = std::collections::BTreeSet::new();
252 for k in class_tp
253 .keys()
254 .chain(class_fp.keys())
255 .chain(class_fn.keys())
256 {
257 all_labels.insert(k.clone());
258 }
259
260 let mut per_class: HashMap<String, ClassMetrics> = HashMap::new();
261 let mut total_tp = 0usize;
262 let mut total_fp = 0usize;
263 let mut total_fn = 0usize;
264
265 for label in &all_labels {
266 let tp = *class_tp.get(label).unwrap_or(&0);
267 let fp = *class_fp.get(label).unwrap_or(&0);
268 let fn_ = *class_fn.get(label).unwrap_or(&0);
269
270 per_class.insert(label.clone(), ClassMetrics::from_counts(tp, fp, fn_));
271
272 total_tp += tp;
273 total_fp += fp;
274 total_fn += fn_;
275 }
276
277 let overall = ClassMetrics::from_counts(total_tp, total_fp, total_fn);
278 let n_sequences = predictions.len();
279 let exact_match = if n_sequences == 0 {
280 0.0
281 } else {
282 exact_match_count as f64 / n_sequences as f64
283 };
284
285 NerEvaluationResult {
286 precision: overall.precision,
287 recall: overall.recall,
288 f1: overall.f1,
289 per_class,
290 exact_match,
291 }
292}
293
294pub fn conll_format_report(result: &NerEvaluationResult) -> String {
307 let mut lines: Vec<String> = Vec::new();
308
309 let total_gold: usize = result.per_class.values().map(|m| m.support).sum();
311 let total_pred: usize = result.per_class.values().map(|m| m.tp + m.fp).sum();
312 let total_correct: usize = result.per_class.values().map(|m| m.tp).sum();
313
314 lines.push(format!(
315 "processed N tokens with {} phrases; found: {} phrases; correct: {}",
316 total_gold, total_pred, total_correct,
317 ));
318
319 lines.push(format!(
320 "accuracy: N/A; precision: {:6.2}%; recall: {:6.2}%; FB1: {:6.2}",
321 result.precision * 100.0,
322 result.recall * 100.0,
323 result.f1 * 100.0,
324 ));
325
326 let mut labels: Vec<&String> = result.per_class.keys().collect();
328 labels.sort();
329
330 for label in labels {
331 let m = &result.per_class[label];
332 lines.push(format!(
333 " {:>8}: precision: {:6.2}%; recall: {:6.2}%; FB1: {:6.2} {}",
334 label,
335 m.precision * 100.0,
336 m.recall * 100.0,
337 m.f1 * 100.0,
338 m.support,
339 ));
340 }
341
342 lines.join("\n")
343}
344
345pub fn evaluate_ner_checked(
348 predictions: &[Vec<NerSpan>],
349 gold: &[Vec<NerSpan>],
350) -> Result<NerEvaluationResult, TextError> {
351 if predictions.len() != gold.len() {
352 return Err(TextError::InvalidInput(format!(
353 "Prediction batch size ({}) != gold batch size ({})",
354 predictions.len(),
355 gold.len()
356 )));
357 }
358 Ok(evaluate_ner(predictions, gold))
359}
360
#[cfg(test)]
mod tests {
    use super::*;

    // ---- BIO span extraction ------------------------------------------------

    #[test]
    fn test_bio_extraction_basic() {
        let tokens = vec!["John", "Smith", "works", "at", "Google", "."];
        let labels = vec!["B-PER", "I-PER", "O", "O", "B-ORG", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans.len(), 2);
        assert_eq!(spans[0], NerSpan::new(0, 2, "PER"));
        assert_eq!(spans[1], NerSpan::new(4, 5, "ORG"));
    }

    #[test]
    fn test_bio_extraction_single_token_entity() {
        let tokens = vec!["Paris", "is", "beautiful"];
        let labels = vec!["B-LOC", "O", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans.len(), 1);
        assert_eq!(spans[0], NerSpan::new(0, 1, "LOC"));
    }

    #[test]
    fn test_bio_extraction_all_outside() {
        let tokens = vec!["the", "cat", "sat"];
        let labels = vec!["O", "O", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);
        assert!(spans.is_empty());
    }

    #[test]
    fn test_bio_extraction_trailing_entity() {
        // The final span has no closing "O" and must be flushed at the end.
        let tokens = vec!["Visit", "New", "York"];
        let labels = vec!["O", "B-LOC", "I-LOC"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans.len(), 1);
        assert_eq!(spans[0], NerSpan::new(1, 3, "LOC"));
    }

    #[test]
    fn test_bio_extraction_nested_b_tags() {
        // A second "B-" of the same type starts a new span, not a continuation.
        let tokens = vec!["Apple", "Inc", "Google", "LLC"];
        let labels = vec!["B-ORG", "I-ORG", "B-ORG", "I-ORG"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans.len(), 2);
        assert_eq!(spans[0], NerSpan::new(0, 2, "ORG"));
        assert_eq!(spans[1], NerSpan::new(2, 4, "ORG"));
    }

    #[test]
    fn test_bio_extraction_type_change() {
        // An "I-" tag whose type differs from the open span starts a new span.
        let tokens = vec!["John", "Google"];
        let labels = vec!["B-PER", "I-ORG"];
        let spans = extract_spans_from_bio(&tokens, &labels);

        assert_eq!(spans.len(), 2);
        assert_eq!(spans[0].label, "PER");
        assert_eq!(spans[1].label, "ORG");
    }

    #[test]
    fn test_bio_extraction_mismatched_lengths() {
        let tokens = vec!["John"];
        let labels = vec!["B-PER", "O"];
        let spans = extract_spans_from_bio(&tokens, &labels);
        assert!(spans.is_empty());
    }

    #[test]
    fn test_bio_extraction_multiple_types() {
        let tokens = vec!["The", "EU", "said", "Angela", "Merkel", "from", "Germany"];
        let labels = vec!["O", "B-ORG", "O", "B-PER", "I-PER", "O", "B-LOC"];

        let spans = extract_spans_from_bio(&tokens, &labels);
        assert_eq!(spans.len(), 3);
        assert_eq!(spans[0], NerSpan::new(1, 2, "ORG"));
        assert_eq!(spans[1], NerSpan::new(3, 5, "PER"));
        assert_eq!(spans[2], NerSpan::new(6, 7, "LOC"));
    }

    // ---- Span-level evaluation ----------------------------------------------

    #[test]
    fn test_ner_eval_perfect() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER"), NerSpan::new(4, 5, "ORG")]];
        let pred = gold.clone();

        let result = evaluate_ner(&pred, &gold);
        assert!((result.precision - 1.0).abs() < 1e-9);
        assert!((result.recall - 1.0).abs() < 1e-9);
        assert!((result.f1 - 1.0).abs() < 1e-9);
        assert!((result.exact_match - 1.0).abs() < 1e-9);
    }

    #[test]
    fn test_ner_eval_no_predictions() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred = vec![vec![]];

        let result = evaluate_ner(&pred, &gold);
        assert!((result.precision - 0.0).abs() < 1e-9);
        assert!((result.recall - 0.0).abs() < 1e-9);
        assert!((result.f1 - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_ner_eval_partial_match() {
        // Overlapping but non-identical boundaries do not count as a match.
        let gold = vec![vec![NerSpan::new(0, 3, "PER")]];
        let pred = vec![vec![NerSpan::new(0, 2, "PER")]];

        let result = evaluate_ner(&pred, &gold);
        assert!((result.precision - 0.0).abs() < 1e-9);
        assert!((result.recall - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_ner_eval_wrong_label() {
        // Same boundaries with a different label is not a match either.
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred = vec![vec![NerSpan::new(0, 2, "ORG")]];

        let result = evaluate_ner(&pred, &gold);
        assert!((result.f1 - 0.0).abs() < 1e-9);
    }

    #[test]
    fn test_ner_eval_per_class_metrics() {
        let gold = vec![vec![
            NerSpan::new(0, 2, "PER"),
            NerSpan::new(3, 5, "ORG"),
            NerSpan::new(6, 8, "LOC"),
        ]];
        let pred = vec![vec![
            NerSpan::new(0, 2, "PER"),  // correct
            NerSpan::new(6, 8, "LOC"),  // correct
            NerSpan::new(9, 11, "ORG"), // spurious; the gold ORG span is missed
        ]];

        let result = evaluate_ner(&pred, &gold);

        let per = result.per_class.get("PER").expect("PER class");
        assert_eq!(per.tp, 1);
        assert_eq!(per.fp, 0);
        assert_eq!(per.fn_, 0);

        let org = result.per_class.get("ORG").expect("ORG class");
        assert_eq!(org.tp, 0);
        assert_eq!(org.fp, 1);
        assert_eq!(org.fn_, 1);

        let loc = result.per_class.get("LOC").expect("LOC class");
        assert_eq!(loc.tp, 1);
        assert_eq!(loc.fp, 0);
        assert_eq!(loc.fn_, 0);
    }

    #[test]
    fn test_ner_eval_multiple_sequences() {
        let gold = vec![
            vec![NerSpan::new(0, 1, "PER")],
            vec![NerSpan::new(0, 1, "ORG")],
        ];
        let pred = vec![
            vec![NerSpan::new(0, 1, "PER")], // correct
            vec![NerSpan::new(0, 1, "PER")], // wrong label
        ];

        let result = evaluate_ner(&pred, &gold);
        assert!((result.precision - 0.5).abs() < 1e-9);
        assert!((result.recall - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_ner_eval_exact_match_rate() {
        let gold = vec![
            vec![NerSpan::new(0, 1, "PER")],
            vec![NerSpan::new(0, 1, "ORG")],
        ];
        let pred = vec![
            vec![NerSpan::new(0, 1, "PER")], // exact match
            vec![NerSpan::new(0, 1, "PER")], // mismatch
        ];

        let result = evaluate_ner(&pred, &gold);
        assert!((result.exact_match - 0.5).abs() < 1e-9);
    }

    #[test]
    fn test_ner_checked_length_mismatch() {
        let gold = vec![vec![NerSpan::new(0, 1, "PER")]];
        let pred: Vec<Vec<NerSpan>> = vec![];
        assert!(evaluate_ner_checked(&pred, &gold).is_err());
    }

    // ---- Report formatting --------------------------------------------------

    #[test]
    fn test_conll_report_format() {
        let gold = vec![vec![NerSpan::new(0, 2, "PER")]];
        let pred = gold.clone();
        let result = evaluate_ner(&pred, &gold);
        let report = conll_format_report(&result);

        assert!(report.contains("precision"));
        assert!(report.contains("recall"));
        assert!(report.contains("FB1"));
        assert!(report.contains("PER"));
        assert!(report.contains("100.00"));
    }
}