1use crate::EntityType;
77use serde::{Deserialize, Serialize};
78use std::collections::HashMap;
79use std::fmt;
80
81#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
96pub struct TypeDistribution {
97 probs: Vec<(EntityType, f64)>,
99}
100
101impl TypeDistribution {
102 #[must_use]
107 pub fn new(probs: Vec<(EntityType, f64)>) -> Self {
108 let probs = probs
109 .into_iter()
110 .map(|(t, p)| (t, p.clamp(0.0, 1.0)))
111 .filter(|(_, p)| *p > 0.0)
112 .collect();
113 Self { probs }
114 }
115
116 #[must_use]
118 pub fn uniform(types: &[EntityType]) -> Self {
119 if types.is_empty() {
120 return Self { probs: vec![] };
121 }
122 let p = 1.0 / types.len() as f64;
123 Self::new(types.iter().map(|t| (t.clone(), p)).collect())
124 }
125
126 #[must_use]
128 pub fn point_mass(entity_type: EntityType, confidence: f64) -> Self {
129 Self::new(vec![(entity_type, confidence)])
130 }
131
132 #[must_use]
134 pub fn argmax(&self) -> Option<(&EntityType, f64)> {
135 self.probs
136 .iter()
137 .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
138 .map(|(t, p)| (t, *p))
139 }
140
141 #[must_use]
143 pub fn prob(&self, entity_type: &EntityType) -> f64 {
144 self.probs
145 .iter()
146 .find(|(t, _)| t == entity_type)
147 .map(|(_, p)| *p)
148 .unwrap_or(0.0)
149 }
150
151 #[must_use]
155 pub fn entropy(&self) -> f64 {
156 let total: f64 = self.probs.iter().map(|(_, p)| p).sum();
157 if total <= 0.0 {
158 return 0.0;
159 }
160
161 let mut h = 0.0;
162 for (_, p) in &self.probs {
163 if *p > 0.0 {
164 let normalized = p / total;
165 h -= normalized * normalized.ln();
166 }
167 }
168 h
169 }
170
171 #[must_use]
175 pub fn margin(&self) -> f64 {
176 if self.probs.len() < 2 {
177 return 1.0;
178 }
179
180 let mut sorted: Vec<f64> = self.probs.iter().map(|(_, p)| *p).collect();
181 sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
182
183 sorted[0] - sorted.get(1).unwrap_or(&0.0)
184 }
185
186 #[must_use]
188 pub fn is_confident(&self, threshold: f64) -> bool {
189 self.argmax().is_some_and(|(_, p)| p >= threshold)
190 }
191
192 #[must_use]
194 pub fn to_map(&self) -> HashMap<EntityType, f64> {
195 self.probs.iter().cloned().collect()
196 }
197
198 #[must_use]
200 pub fn num_types(&self) -> usize {
201 self.probs.len()
202 }
203
204 #[must_use]
206 pub fn is_empty(&self) -> bool {
207 self.probs.is_empty()
208 }
209
210 pub fn iter(&self) -> impl Iterator<Item = (&EntityType, f64)> {
212 self.probs.iter().map(|(t, p)| (t, *p))
213 }
214}
215
216impl fmt::Display for TypeDistribution {
217 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218 if self.probs.is_empty() {
219 return write!(f, "∅");
220 }
221 let parts: Vec<String> = self
222 .probs
223 .iter()
224 .map(|(t, p)| format!("{}:{:.1}%", t.as_label(), p * 100.0))
225 .collect();
226 write!(f, "{{{}}}", parts.join(", "))
227 }
228}
229
230#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
242pub enum Abstention {
243 LowConfidence {
245 max_score: f64,
247 },
248
249 Ambiguous {
251 top_types: Vec<EntityType>,
253 margin: f64,
255 },
256
257 OutOfDomain {
259 detected_domain: Option<String>,
261 },
262
263 Conflict {
265 predictions: Vec<(String, EntityType)>, },
268
269 InvalidSpan {
271 reason: String,
273 },
274
275 Declined {
277 reason: String,
279 },
280}
281
282impl Abstention {
283 #[must_use]
285 pub fn description(&self) -> String {
286 match self {
287 Self::LowConfidence { max_score } => {
288 format!("Low confidence: max score {:.1}%", max_score * 100.0)
289 }
290 Self::Ambiguous { top_types, margin } => {
291 let types: Vec<_> = top_types.iter().map(|t| t.as_label()).collect();
292 format!(
293 "Ambiguous between {} (margin: {:.1}%)",
294 types.join(" vs "),
295 margin * 100.0
296 )
297 }
298 Self::OutOfDomain { detected_domain } => match detected_domain {
299 Some(d) => format!("Out of domain: detected '{}'", d),
300 None => "Out of domain".to_string(),
301 },
302 Self::Conflict { predictions } => {
303 let conflicts: Vec<_> = predictions
304 .iter()
305 .map(|(src, t)| format!("{}→{}", src, t.as_label()))
306 .collect();
307 format!("Conflict: {}", conflicts.join(", "))
308 }
309 Self::InvalidSpan { reason } => format!("Invalid span: {}", reason),
310 Self::Declined { reason } => format!("Declined: {}", reason),
311 }
312 }
313
314 #[must_use]
316 pub fn is_resolvable(&self) -> bool {
317 matches!(
318 self,
319 Self::LowConfidence { .. } | Self::Ambiguous { .. } | Self::Conflict { .. }
320 )
321 }
322}
323
324impl fmt::Display for Abstention {
325 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326 write!(f, "{}", self.description())
327 }
328}
329
330#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
345pub enum UncertainPrediction {
346 Single {
348 entity_type: EntityType,
350 confidence: f64,
352 },
353
354 Distributed(TypeDistribution),
356
357 Abstained(Abstention),
359}
360
361impl UncertainPrediction {
362 #[must_use]
364 pub fn from_type(entity_type: EntityType, confidence: f64) -> Self {
365 Self::Single {
366 entity_type,
367 confidence: confidence.clamp(0.0, 1.0),
368 }
369 }
370
371 #[must_use]
373 pub fn distributed(dist: TypeDistribution) -> Self {
374 Self::Distributed(dist)
375 }
376
377 #[must_use]
379 pub fn abstain(reason: Abstention) -> Self {
380 Self::Abstained(reason)
381 }
382
383 #[must_use]
385 pub fn abstain_low_confidence(max_score: f64) -> Self {
386 Self::Abstained(Abstention::LowConfidence { max_score })
387 }
388
389 #[must_use]
391 pub fn abstain_ambiguous(top_types: Vec<EntityType>, margin: f64) -> Self {
392 Self::Abstained(Abstention::Ambiguous { top_types, margin })
393 }
394
395 #[must_use]
397 pub fn is_abstention(&self) -> bool {
398 matches!(self, Self::Abstained(_))
399 }
400
401 #[must_use]
403 pub fn is_confident(&self, threshold: f64) -> bool {
404 match self {
405 Self::Single { confidence, .. } => *confidence >= threshold,
406 Self::Distributed(dist) => dist.is_confident(threshold),
407 Self::Abstained(_) => false,
408 }
409 }
410
411 #[must_use]
415 pub fn best(&self) -> Option<(&EntityType, f64)> {
416 match self {
417 Self::Single {
418 entity_type,
419 confidence,
420 } => Some((entity_type, *confidence)),
421 Self::Distributed(dist) => dist.argmax(),
422 Self::Abstained(_) => None,
423 }
424 }
425
426 #[must_use]
430 pub fn get_type(&self) -> Option<&EntityType> {
431 match self {
432 Self::Single { entity_type, .. } => Some(entity_type),
433 Self::Distributed(dist) => dist.argmax().map(|(t, _)| t),
434 Self::Abstained(_) => None,
435 }
436 }
437
438 #[must_use]
442 pub fn confidence(&self) -> f64 {
443 match self {
444 Self::Single { confidence, .. } => *confidence,
445 Self::Distributed(dist) => dist.argmax().map(|(_, p)| p).unwrap_or(0.0),
446 Self::Abstained(_) => 0.0,
447 }
448 }
449
450 #[must_use]
452 pub fn distribution(&self) -> Option<&TypeDistribution> {
453 match self {
454 Self::Distributed(dist) => Some(dist),
455 _ => None,
456 }
457 }
458
459 #[must_use]
461 pub fn abstention_reason(&self) -> Option<&Abstention> {
462 match self {
463 Self::Abstained(reason) => Some(reason),
464 _ => None,
465 }
466 }
467
468 #[must_use]
472 pub fn with_threshold(self, threshold: f64) -> Self {
473 match &self {
474 Self::Single { confidence, .. } if *confidence < threshold => {
475 Self::abstain_low_confidence(*confidence)
476 }
477 Self::Distributed(dist) => {
478 if let Some((_, p)) = dist.argmax() {
479 if p < threshold {
480 if dist.num_types() >= 2 {
481 let top: Vec<_> = dist.iter().take(2).map(|(t, _)| t.clone()).collect();
482 return Self::abstain_ambiguous(top, dist.margin());
483 }
484 return Self::abstain_low_confidence(p);
485 }
486 }
487 self
488 }
489 _ => self,
490 }
491 }
492}
493
494impl fmt::Display for UncertainPrediction {
495 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
496 match self {
497 Self::Single {
498 entity_type,
499 confidence,
500 } => {
501 write!(f, "{} ({:.1}%)", entity_type.as_label(), confidence * 100.0)
502 }
503 Self::Distributed(dist) => write!(f, "{}", dist),
504 Self::Abstained(reason) => write!(f, "ABSTAIN: {}", reason),
505 }
506 }
507}
508
509#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct UncertainEntity {
518 pub text: String,
520 pub start: usize,
522 pub end: usize,
524 pub prediction: UncertainPrediction,
526 pub source: Option<String>,
528}
529
530impl UncertainEntity {
531 #[must_use]
533 pub fn new(text: String, start: usize, end: usize, prediction: UncertainPrediction) -> Self {
534 Self {
535 text,
536 start,
537 end,
538 prediction,
539 source: None,
540 }
541 }
542
543 #[must_use]
545 pub fn with_source(mut self, source: impl Into<String>) -> Self {
546 self.source = Some(source.into());
547 self
548 }
549
550 #[must_use]
552 pub fn should_include(&self, threshold: f64) -> bool {
553 self.prediction.is_confident(threshold)
554 }
555
556 #[must_use]
560 pub fn to_entity(&self, threshold: f64) -> Option<crate::Entity> {
561 if !self.prediction.is_confident(threshold) {
562 return None;
563 }
564
565 let (entity_type, confidence) = self.prediction.best()?;
566 Some(crate::Entity::new(
567 &self.text,
568 entity_type.clone(),
569 self.start,
570 self.end,
571 confidence,
572 ))
573 }
574}
575
576#[derive(Debug, Clone, Default, Serialize, Deserialize)]
584pub struct SelectiveMetrics {
585 pub predictions: usize,
587 pub abstentions: usize,
589 pub correct: usize,
591 pub coverage: f64,
593 pub accuracy: f64,
595 pub risk: f64,
597}
598
599impl SelectiveMetrics {
600 #[must_use]
608 pub fn compute(predictions: &[(Option<EntityType>, EntityType)]) -> Self {
609 let mut metrics = Self::default();
610 let total = predictions.len();
611 if total == 0 {
612 return metrics;
613 }
614
615 for (pred, gold) in predictions {
616 match pred {
617 Some(pred_type) => {
618 metrics.predictions += 1;
619 if pred_type == gold {
620 metrics.correct += 1;
621 }
622 }
623 None => {
624 metrics.abstentions += 1;
625 }
626 }
627 }
628
629 metrics.coverage = metrics.predictions as f64 / total as f64;
630 metrics.accuracy = if metrics.predictions > 0 {
631 metrics.correct as f64 / metrics.predictions as f64
632 } else {
633 0.0
634 };
635
636 let incorrect = metrics.predictions - metrics.correct;
639 metrics.risk = incorrect as f64 / total as f64;
640
641 metrics
642 }
643
644 #[must_use]
649 pub fn coverage_accuracy_auc(
650 uncertain_predictions: &[(UncertainPrediction, EntityType)],
651 ) -> f64 {
652 if uncertain_predictions.is_empty() {
653 return 0.0;
654 }
655
656 let mut sorted: Vec<_> = uncertain_predictions
658 .iter()
659 .map(|(pred, gold)| (pred.confidence(), pred.get_type(), gold))
660 .collect();
661 sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
662
663 let total = sorted.len() as f64;
665 let mut correct = 0.0;
666 let mut auc = 0.0;
667
668 for (i, (_, pred_type, gold)) in sorted.iter().enumerate() {
669 if pred_type.is_some_and(|pt| pt == *gold) {
670 correct += 1.0;
671 }
672 let coverage = (i + 1) as f64 / total;
673 let accuracy = correct / (i + 1) as f64;
674
675 if i > 0 {
677 let prev_coverage = i as f64 / total;
678 auc += (coverage - prev_coverage) * accuracy;
679 }
680 }
681
682 auc
683 }
684}
685
686impl fmt::Display for SelectiveMetrics {
687 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
688 write!(
689 f,
690 "Coverage: {:.1}%, Accuracy: {:.1}%, Risk: {:.1}% ({}/{} predicted, {} abstained)",
691 self.coverage * 100.0,
692 self.accuracy * 100.0,
693 self.risk * 100.0,
694 self.predictions,
695 self.predictions + self.abstentions,
696 self.abstentions
697 )
698 }
699}
700
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for floating-point comparisons.
    const EPS: f64 = 1e-10;

    /// `true` when `actual` is within [`EPS`] of `expected`.
    fn close(actual: f64, expected: f64) -> bool {
        (actual - expected).abs() < EPS
    }

    #[test]
    fn test_type_distribution_argmax() {
        let dist = TypeDistribution::new(vec![
            (EntityType::Person, 0.7),
            (EntityType::Organization, 0.2),
            (EntityType::Location, 0.1),
        ]);

        let Some((winner, score)) = dist.argmax() else {
            panic!("a non-empty distribution must have an argmax");
        };
        assert_eq!(winner, &EntityType::Person);
        assert!(close(score, 0.7));
    }

    #[test]
    fn test_type_distribution_entropy() {
        // All mass on one type: no uncertainty.
        let certain = TypeDistribution::point_mass(EntityType::Person, 1.0);
        assert!(close(certain.entropy(), 0.0));

        // Mass shared across three types: strictly positive entropy.
        let spread = TypeDistribution::uniform(&[
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
        ]);
        assert!(spread.entropy() > 0.0);
    }

    #[test]
    fn test_type_distribution_margin() {
        let cases = vec![
            // A clear winner leaves a large gap...
            (vec![(EntityType::Person, 0.9), (EntityType::Organization, 0.1)], 0.8),
            // ...while a tie leaves none.
            (vec![(EntityType::Person, 0.5), (EntityType::Organization, 0.5)], 0.0),
        ];
        for (probs, expected) in cases {
            assert!(close(TypeDistribution::new(probs).margin(), expected));
        }
    }

    #[test]
    fn test_uncertain_prediction_single() {
        let pred = UncertainPrediction::from_type(EntityType::Person, 0.85);
        assert!(!pred.is_abstention());
        assert!(pred.is_confident(0.8));
        assert!(!pred.is_confident(0.9));

        let (ty, conf) = pred.best().expect("a single prediction always has a best type");
        assert_eq!(ty, &EntityType::Person);
        assert!(close(conf, 0.85));
    }

    #[test]
    fn test_uncertain_prediction_abstain() {
        let pred = UncertainPrediction::abstain_low_confidence(0.35);
        assert!(pred.is_abstention());
        assert!(pred.best().is_none());
        assert!(!pred.is_confident(0.1));
        assert!(matches!(
            pred.abstention_reason(),
            Some(Abstention::LowConfidence { .. })
        ));
    }

    #[test]
    fn test_with_threshold() {
        let pred = UncertainPrediction::from_type(EntityType::Person, 0.6);
        // 0.6 misses a 0.7 bar, so the prediction turns into an abstention...
        assert!(pred.clone().with_threshold(0.7).is_abstention());
        // ...but clears a 0.5 bar and is kept.
        assert!(!pred.with_threshold(0.5).is_abstention());
    }

    #[test]
    fn test_selective_metrics() {
        let pairs = vec![
            (Some(EntityType::Person), EntityType::Person), // correct
            (Some(EntityType::Organization), EntityType::Person), // wrong
            (None, EntityType::Location), // abstained
            (Some(EntityType::Location), EntityType::Location), // correct
        ];

        let m = SelectiveMetrics::compute(&pairs);
        assert_eq!((m.predictions, m.abstentions, m.correct), (3, 1, 2));
        assert!(close(m.coverage, 0.75));
        assert!(close(m.accuracy, 2.0 / 3.0));
    }

    #[test]
    fn test_uncertain_entity_to_entity() {
        let ue = UncertainEntity::new(
            String::from("John"),
            0,
            4,
            UncertainPrediction::from_type(EntityType::Person, 0.9),
        );

        let entity = ue.to_entity(0.8).expect("0.9 clears a 0.8 threshold");
        assert_eq!(entity.text, "John");
        assert_eq!(entity.entity_type, EntityType::Person);

        assert!(ue.to_entity(0.95).is_none(), "0.9 must not clear a 0.95 threshold");
    }

    #[test]
    fn test_abstention_resolvable() {
        let low = Abstention::LowConfidence { max_score: 0.3 };
        let ambiguous = Abstention::Ambiguous {
            top_types: Vec::new(),
            margin: 0.1,
        };
        let off_domain = Abstention::OutOfDomain {
            detected_domain: None,
        };

        assert!(low.is_resolvable());
        assert!(ambiguous.is_resolvable());
        assert!(!off_domain.is_resolvable());
    }
}
842}