Skip to main content

anno/types/
uncertain.rs

1//! Uncertain predictions and abstention for selective NER.
2//!
3//! # The Problem with Forced Labels
4//!
5//! Traditional NER systems must output a label for every span they consider,
6//! even when uncertain. This leads to:
7//!
8//! - **Overconfident false positives**: System says "PERSON" with 0.51 confidence
9//! - **Hidden uncertainty**: Users can't distinguish confident vs. guessing
10//! - **No recourse**: Can't say "I don't know, ask a human"
11//!
12//! # Selective Prediction
13//!
14//! This module provides types for **selective prediction**, where the model can:
15//!
16//! 1. **Abstain**: Explicitly decline to label a span
17//! 2. **Distribute**: Output probabilities over multiple types
18//! 3. **Bound**: Provide confidence intervals, not point estimates
19//!
20//! # Research Background
21//!
//! - Geifman & El-Yaniv (2017): "Selective Classification for Deep Neural Networks"
23//! - El-Yaniv & Wiener (2010): "On the Foundations of Noise-Free Selective Classification"
24//! - Kamath et al. (2020): "Selective Question Answering under Domain Shift"
25//!
26//! # Trade-off: Coverage vs. Accuracy
27//!
//! ```text
//! Accuracy
//!      │
//! 100% ┤    *   *
//!      │            *
//!  90% ┤                *   *
//!      │                        *
//!  80% ┤                            *   *
//!      │                                    *
//!  70% ┤                                        *  ← baseline (answer everything)
//!      └────────┬───────────┬───────────┬───────┬── Coverage
//!              20%         50%         80%     100%
//!
//!  • Low coverage (20%): Only answer when very confident → 95%+
//!  • High coverage (100%): Answer everything → baseline accuracy
//!  (The curve is illustrative; its exact shape depends on the model.)
//! ```
52//!
53//! # Example
54//!
55//! ```rust
56//! use anno::types::uncertain::{UncertainPrediction, TypeDistribution, Abstention};
57//! use anno::EntityType;
58//!
59//! // A confident prediction
60//! let confident = UncertainPrediction::from_type(EntityType::Person, 0.95);
61//! assert!(confident.is_confident(0.8));
62//!
63//! // An uncertain prediction with distribution
64//! let uncertain = UncertainPrediction::distributed(TypeDistribution::new(vec![
65//!     (EntityType::Person, 0.45),
66//!     (EntityType::Organization, 0.40),
67//!     (EntityType::Location, 0.15),
68//! ]));
69//! assert!(!uncertain.is_confident(0.8));
70//!
71//! // Explicit abstention
72//! let abstain = UncertainPrediction::abstain(Abstention::LowConfidence { max_score: 0.35 });
73//! assert!(abstain.is_abstention());
74//! ```
75
76use crate::EntityType;
77use serde::{Deserialize, Serialize};
78use std::collections::HashMap;
79use std::fmt;
80
81// =============================================================================
82// Type Distribution
83// =============================================================================
84
85/// Distribution over entity types.
86///
87/// Unlike a single EntityType prediction, this captures uncertainty
88/// by assigning probability mass to multiple types.
89///
90/// # Invariants
91///
92/// - Probabilities are in [0, 1]
93/// - Probabilities may not sum to 1 (unnormalized is allowed)
94/// - Empty distributions are valid (no prediction)
95#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
96pub struct TypeDistribution {
97    /// Type → probability mapping
98    probs: Vec<(EntityType, f64)>,
99}
100
101impl TypeDistribution {
102    /// Create a new type distribution.
103    ///
104    /// Probabilities are clamped to [0, 1] but NOT normalized.
105    /// This allows representing "I'm 30% sure it's PERSON, 20% ORG, and 50% nothing".
106    #[must_use]
107    pub fn new(probs: Vec<(EntityType, f64)>) -> Self {
108        let probs = probs
109            .into_iter()
110            .map(|(t, p)| (t, p.clamp(0.0, 1.0)))
111            .filter(|(_, p)| *p > 0.0)
112            .collect();
113        Self { probs }
114    }
115
116    /// Create a uniform distribution over types.
117    #[must_use]
118    pub fn uniform(types: &[EntityType]) -> Self {
119        if types.is_empty() {
120            return Self { probs: vec![] };
121        }
122        let p = 1.0 / types.len() as f64;
123        Self::new(types.iter().map(|t| (t.clone(), p)).collect())
124    }
125
126    /// Create a distribution with all mass on one type.
127    #[must_use]
128    pub fn point_mass(entity_type: EntityType, confidence: f64) -> Self {
129        Self::new(vec![(entity_type, confidence)])
130    }
131
132    /// Get the most likely type and its probability.
133    #[must_use]
134    pub fn argmax(&self) -> Option<(&EntityType, f64)> {
135        self.probs
136            .iter()
137            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal))
138            .map(|(t, p)| (t, *p))
139    }
140
141    /// Get probability for a specific type.
142    #[must_use]
143    pub fn prob(&self, entity_type: &EntityType) -> f64 {
144        self.probs
145            .iter()
146            .find(|(t, _)| t == entity_type)
147            .map(|(_, p)| *p)
148            .unwrap_or(0.0)
149    }
150
151    /// Get entropy of the distribution (higher = more uncertain).
152    ///
153    /// Returns 0 for point mass, log(n) for uniform over n types.
154    #[must_use]
155    pub fn entropy(&self) -> f64 {
156        let total: f64 = self.probs.iter().map(|(_, p)| p).sum();
157        if total <= 0.0 {
158            return 0.0;
159        }
160
161        let mut h = 0.0;
162        for (_, p) in &self.probs {
163            if *p > 0.0 {
164                let normalized = p / total;
165                h -= normalized * normalized.ln();
166            }
167        }
168        h
169    }
170
171    /// Get margin between top two predictions (higher = more confident).
172    ///
173    /// Returns 1.0 if only one type, 0.0 if tied.
174    #[must_use]
175    pub fn margin(&self) -> f64 {
176        if self.probs.len() < 2 {
177            return 1.0;
178        }
179
180        let mut sorted: Vec<f64> = self.probs.iter().map(|(_, p)| *p).collect();
181        sorted.sort_by(|a, b| b.partial_cmp(a).unwrap_or(std::cmp::Ordering::Equal));
182
183        sorted[0] - sorted.get(1).unwrap_or(&0.0)
184    }
185
186    /// Check if the top prediction exceeds a confidence threshold.
187    #[must_use]
188    pub fn is_confident(&self, threshold: f64) -> bool {
189        self.argmax().is_some_and(|(_, p)| p >= threshold)
190    }
191
192    /// Convert to a HashMap for easier iteration.
193    #[must_use]
194    pub fn to_map(&self) -> HashMap<EntityType, f64> {
195        self.probs.iter().cloned().collect()
196    }
197
198    /// Get the number of types with non-zero probability.
199    #[must_use]
200    pub fn num_types(&self) -> usize {
201        self.probs.len()
202    }
203
204    /// Check if distribution is empty (no predictions).
205    #[must_use]
206    pub fn is_empty(&self) -> bool {
207        self.probs.is_empty()
208    }
209
210    /// Iterate over (type, probability) pairs.
211    pub fn iter(&self) -> impl Iterator<Item = (&EntityType, f64)> {
212        self.probs.iter().map(|(t, p)| (t, *p))
213    }
214}
215
216impl fmt::Display for TypeDistribution {
217    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
218        if self.probs.is_empty() {
219            return write!(f, "∅");
220        }
221        let parts: Vec<String> = self
222            .probs
223            .iter()
224            .map(|(t, p)| format!("{}:{:.1}%", t.as_label(), p * 100.0))
225            .collect();
226        write!(f, "{{{}}}", parts.join(", "))
227    }
228}
229
230// =============================================================================
231// Abstention
232// =============================================================================
233
234/// Reason for abstaining from a prediction.
235///
236/// Different abstention reasons may warrant different downstream handling:
237/// - LowConfidence: Might benefit from more context or a human review
238/// - Ambiguous: Might benefit from entity linking to disambiguate
239/// - OutOfDomain: Should not be used for this domain
240/// - Conflict: Multiple signals disagree, needs resolution
241#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
242pub enum Abstention {
243    /// Maximum prediction score was below threshold.
244    LowConfidence {
245        /// The highest confidence score observed
246        max_score: f64,
247    },
248
249    /// Multiple types have similar probabilities (high entropy).
250    Ambiguous {
251        /// The top two competing types
252        top_types: Vec<EntityType>,
253        /// The margin between them (close to 0 = ambiguous)
254        margin: f64,
255    },
256
257    /// The text appears outside the model's training domain.
258    OutOfDomain {
259        /// Optional domain name if detected
260        detected_domain: Option<String>,
261    },
262
263    /// Multiple extraction methods produced conflicting results.
264    Conflict {
265        /// The conflicting predictions
266        predictions: Vec<(String, EntityType)>, // (source, type)
267    },
268
269    /// Span length or structure is invalid for entity extraction.
270    InvalidSpan {
271        /// Reason the span is invalid
272        reason: String,
273    },
274
275    /// Model explicitly declined (e.g., safety filter, policy).
276    Declined {
277        /// Reason for declining
278        reason: String,
279    },
280}
281
282impl Abstention {
283    /// Get a human-readable description of the abstention reason.
284    #[must_use]
285    pub fn description(&self) -> String {
286        match self {
287            Self::LowConfidence { max_score } => {
288                format!("Low confidence: max score {:.1}%", max_score * 100.0)
289            }
290            Self::Ambiguous { top_types, margin } => {
291                let types: Vec<_> = top_types.iter().map(|t| t.as_label()).collect();
292                format!(
293                    "Ambiguous between {} (margin: {:.1}%)",
294                    types.join(" vs "),
295                    margin * 100.0
296                )
297            }
298            Self::OutOfDomain { detected_domain } => match detected_domain {
299                Some(d) => format!("Out of domain: detected '{}'", d),
300                None => "Out of domain".to_string(),
301            },
302            Self::Conflict { predictions } => {
303                let conflicts: Vec<_> = predictions
304                    .iter()
305                    .map(|(src, t)| format!("{}→{}", src, t.as_label()))
306                    .collect();
307                format!("Conflict: {}", conflicts.join(", "))
308            }
309            Self::InvalidSpan { reason } => format!("Invalid span: {}", reason),
310            Self::Declined { reason } => format!("Declined: {}", reason),
311        }
312    }
313
314    /// Check if this abstention might be resolvable with more context.
315    #[must_use]
316    pub fn is_resolvable(&self) -> bool {
317        matches!(
318            self,
319            Self::LowConfidence { .. } | Self::Ambiguous { .. } | Self::Conflict { .. }
320        )
321    }
322}
323
324impl fmt::Display for Abstention {
325    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
326        write!(f, "{}", self.description())
327    }
328}
329
330// =============================================================================
331// Uncertain Prediction
332// =============================================================================
333
334/// A prediction that may include uncertainty or abstention.
335///
336/// This is the core type for selective prediction, replacing simple
337/// (EntityType, confidence) pairs with richer uncertainty information.
338///
339/// # Variants
340///
341/// - **Single**: A single type prediction with confidence (traditional)
342/// - **Distributed**: Probabilities over multiple types (soft prediction)
343/// - **Abstained**: Explicit refusal to predict with reason
344#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
345pub enum UncertainPrediction {
346    /// Single type prediction (traditional NER output).
347    Single {
348        /// The predicted type
349        entity_type: EntityType,
350        /// Confidence score [0, 1]
351        confidence: f64,
352    },
353
354    /// Distribution over types (soft prediction).
355    Distributed(TypeDistribution),
356
357    /// Explicit abstention with reason.
358    Abstained(Abstention),
359}
360
361impl UncertainPrediction {
362    /// Create a single-type prediction.
363    #[must_use]
364    pub fn from_type(entity_type: EntityType, confidence: f64) -> Self {
365        Self::Single {
366            entity_type,
367            confidence: confidence.clamp(0.0, 1.0),
368        }
369    }
370
371    /// Create a distributed prediction.
372    #[must_use]
373    pub fn distributed(dist: TypeDistribution) -> Self {
374        Self::Distributed(dist)
375    }
376
377    /// Create an abstention.
378    #[must_use]
379    pub fn abstain(reason: Abstention) -> Self {
380        Self::Abstained(reason)
381    }
382
383    /// Create abstention due to low confidence.
384    #[must_use]
385    pub fn abstain_low_confidence(max_score: f64) -> Self {
386        Self::Abstained(Abstention::LowConfidence { max_score })
387    }
388
389    /// Create abstention due to ambiguity.
390    #[must_use]
391    pub fn abstain_ambiguous(top_types: Vec<EntityType>, margin: f64) -> Self {
392        Self::Abstained(Abstention::Ambiguous { top_types, margin })
393    }
394
395    /// Check if this is an abstention.
396    #[must_use]
397    pub fn is_abstention(&self) -> bool {
398        matches!(self, Self::Abstained(_))
399    }
400
401    /// Check if this prediction is confident (above threshold).
402    #[must_use]
403    pub fn is_confident(&self, threshold: f64) -> bool {
404        match self {
405            Self::Single { confidence, .. } => *confidence >= threshold,
406            Self::Distributed(dist) => dist.is_confident(threshold),
407            Self::Abstained(_) => false,
408        }
409    }
410
411    /// Get the best prediction type and confidence.
412    ///
413    /// Returns `None` if abstained or distribution is empty.
414    #[must_use]
415    pub fn best(&self) -> Option<(&EntityType, f64)> {
416        match self {
417            Self::Single {
418                entity_type,
419                confidence,
420            } => Some((entity_type, *confidence)),
421            Self::Distributed(dist) => dist.argmax(),
422            Self::Abstained(_) => None,
423        }
424    }
425
426    /// Get the entity type if prediction is confident.
427    ///
428    /// Returns `None` if abstained, distributed, or below threshold.
429    #[must_use]
430    pub fn get_type(&self) -> Option<&EntityType> {
431        match self {
432            Self::Single { entity_type, .. } => Some(entity_type),
433            Self::Distributed(dist) => dist.argmax().map(|(t, _)| t),
434            Self::Abstained(_) => None,
435        }
436    }
437
438    /// Get confidence score.
439    ///
440    /// Returns 0.0 for abstentions, max prob for distributions.
441    #[must_use]
442    pub fn confidence(&self) -> f64 {
443        match self {
444            Self::Single { confidence, .. } => *confidence,
445            Self::Distributed(dist) => dist.argmax().map(|(_, p)| p).unwrap_or(0.0),
446            Self::Abstained(_) => 0.0,
447        }
448    }
449
450    /// Get the type distribution if available.
451    #[must_use]
452    pub fn distribution(&self) -> Option<&TypeDistribution> {
453        match self {
454            Self::Distributed(dist) => Some(dist),
455            _ => None,
456        }
457    }
458
459    /// Get the abstention reason if abstained.
460    #[must_use]
461    pub fn abstention_reason(&self) -> Option<&Abstention> {
462        match self {
463            Self::Abstained(reason) => Some(reason),
464            _ => None,
465        }
466    }
467
468    /// Convert to a single prediction, applying threshold for abstention.
469    ///
470    /// If the prediction is below threshold, converts to abstention.
471    #[must_use]
472    pub fn with_threshold(self, threshold: f64) -> Self {
473        match &self {
474            Self::Single { confidence, .. } if *confidence < threshold => {
475                Self::abstain_low_confidence(*confidence)
476            }
477            Self::Distributed(dist) => {
478                if let Some((_, p)) = dist.argmax() {
479                    if p < threshold {
480                        if dist.num_types() >= 2 {
481                            let top: Vec<_> = dist.iter().take(2).map(|(t, _)| t.clone()).collect();
482                            return Self::abstain_ambiguous(top, dist.margin());
483                        }
484                        return Self::abstain_low_confidence(p);
485                    }
486                }
487                self
488            }
489            _ => self,
490        }
491    }
492}
493
494impl fmt::Display for UncertainPrediction {
495    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
496        match self {
497            Self::Single {
498                entity_type,
499                confidence,
500            } => {
501                write!(f, "{} ({:.1}%)", entity_type.as_label(), confidence * 100.0)
502            }
503            Self::Distributed(dist) => write!(f, "{}", dist),
504            Self::Abstained(reason) => write!(f, "ABSTAIN: {}", reason),
505        }
506    }
507}
508
509// =============================================================================
510// Uncertain Entity
511// =============================================================================
512
513/// An entity extraction with uncertainty information.
514///
515/// Extends the standard Entity with richer uncertainty modeling.
516#[derive(Debug, Clone, Serialize, Deserialize)]
517pub struct UncertainEntity {
518    /// The text of the extracted entity
519    pub text: String,
520    /// Start offset (characters)
521    pub start: usize,
522    /// End offset (characters)
523    pub end: usize,
524    /// The uncertain prediction
525    pub prediction: UncertainPrediction,
526    /// Source backend name
527    pub source: Option<String>,
528}
529
530impl UncertainEntity {
531    /// Create a new uncertain entity.
532    #[must_use]
533    pub fn new(text: String, start: usize, end: usize, prediction: UncertainPrediction) -> Self {
534        Self {
535            text,
536            start,
537            end,
538            prediction,
539            source: None,
540        }
541    }
542
543    /// Set the source backend.
544    #[must_use]
545    pub fn with_source(mut self, source: impl Into<String>) -> Self {
546        self.source = Some(source.into());
547        self
548    }
549
550    /// Check if this entity should be included at a given threshold.
551    #[must_use]
552    pub fn should_include(&self, threshold: f64) -> bool {
553        self.prediction.is_confident(threshold)
554    }
555
556    /// Convert to standard Entity if confident enough.
557    ///
558    /// Returns `None` if abstained or below threshold.
559    #[must_use]
560    pub fn to_entity(&self, threshold: f64) -> Option<crate::Entity> {
561        if !self.prediction.is_confident(threshold) {
562            return None;
563        }
564
565        let (entity_type, confidence) = self.prediction.best()?;
566        Some(crate::Entity::new(
567            &self.text,
568            entity_type.clone(),
569            self.start,
570            self.end,
571            confidence,
572        ))
573    }
574}
575
576// =============================================================================
577// Selective Metrics
578// =============================================================================
579
580/// Metrics for evaluating selective prediction.
581///
582/// Captures the coverage-accuracy tradeoff.
583#[derive(Debug, Clone, Default, Serialize, Deserialize)]
584pub struct SelectiveMetrics {
585    /// Total predictions made (excluding abstentions)
586    pub predictions: usize,
587    /// Total abstentions
588    pub abstentions: usize,
589    /// Correct predictions
590    pub correct: usize,
591    /// Coverage = predictions / (predictions + abstentions)
592    pub coverage: f64,
593    /// Accuracy = correct / predictions (among non-abstained)
594    pub accuracy: f64,
595    /// Risk = incorrect / (predictions + abstentions) (including abstention penalty)
596    pub risk: f64,
597}
598
599impl SelectiveMetrics {
600    /// Compute selective metrics from predictions and gold labels.
601    ///
602    /// # Arguments
603    ///
604    /// * `predictions` - Vec of (predicted_type_or_none, gold_type)
605    ///   - `Some(t)` = prediction made
606    ///   - `None` = abstained
607    #[must_use]
608    pub fn compute(predictions: &[(Option<EntityType>, EntityType)]) -> Self {
609        let mut metrics = Self::default();
610        let total = predictions.len();
611        if total == 0 {
612            return metrics;
613        }
614
615        for (pred, gold) in predictions {
616            match pred {
617                Some(pred_type) => {
618                    metrics.predictions += 1;
619                    if pred_type == gold {
620                        metrics.correct += 1;
621                    }
622                }
623                None => {
624                    metrics.abstentions += 1;
625                }
626            }
627        }
628
629        metrics.coverage = metrics.predictions as f64 / total as f64;
630        metrics.accuracy = if metrics.predictions > 0 {
631            metrics.correct as f64 / metrics.predictions as f64
632        } else {
633            0.0
634        };
635
636        // Risk: fraction of incorrect predictions over all items
637        // (abstentions count as 0 error, predictions count their actual error)
638        let incorrect = metrics.predictions - metrics.correct;
639        metrics.risk = incorrect as f64 / total as f64;
640
641        metrics
642    }
643
644    /// Compute AUC for coverage-accuracy curve.
645    ///
646    /// Higher is better. Measures area under the coverage-accuracy curve
647    /// as threshold varies.
648    #[must_use]
649    pub fn coverage_accuracy_auc(
650        uncertain_predictions: &[(UncertainPrediction, EntityType)],
651    ) -> f64 {
652        if uncertain_predictions.is_empty() {
653            return 0.0;
654        }
655
656        // Sort by confidence descending
657        let mut sorted: Vec<_> = uncertain_predictions
658            .iter()
659            .map(|(pred, gold)| (pred.confidence(), pred.get_type(), gold))
660            .collect();
661        sorted.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
662
663        // Compute accuracy at each coverage level
664        let total = sorted.len() as f64;
665        let mut correct = 0.0;
666        let mut auc = 0.0;
667
668        for (i, (_, pred_type, gold)) in sorted.iter().enumerate() {
669            if pred_type.is_some_and(|pt| pt == *gold) {
670                correct += 1.0;
671            }
672            let coverage = (i + 1) as f64 / total;
673            let accuracy = correct / (i + 1) as f64;
674
675            // Trapezoidal rule for AUC
676            if i > 0 {
677                let prev_coverage = i as f64 / total;
678                auc += (coverage - prev_coverage) * accuracy;
679            }
680        }
681
682        auc
683    }
684}
685
686impl fmt::Display for SelectiveMetrics {
687    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
688        write!(
689            f,
690            "Coverage: {:.1}%, Accuracy: {:.1}%, Risk: {:.1}% ({}/{} predicted, {} abstained)",
691            self.coverage * 100.0,
692            self.accuracy * 100.0,
693            self.risk * 100.0,
694            self.predictions,
695            self.predictions + self.abstentions,
696            self.abstentions
697        )
698    }
699}
700
701// =============================================================================
702// Tests
703// =============================================================================
704
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_type_distribution_argmax() {
        let dist = TypeDistribution::new(vec![
            (EntityType::Person, 0.7),
            (EntityType::Organization, 0.2),
            (EntityType::Location, 0.1),
        ]);

        let (best, prob) = dist.argmax().unwrap();
        assert_eq!(*best, EntityType::Person);
        assert!((prob - 0.7).abs() < 1e-10);
    }

    #[test]
    fn test_type_distribution_new_clamps_and_drops_zero() {
        // Out-of-range values are clamped; zero/negative mass is dropped.
        let dist = TypeDistribution::new(vec![
            (EntityType::Person, 1.5),
            (EntityType::Organization, -0.2),
            (EntityType::Location, 0.0),
        ]);
        assert_eq!(dist.num_types(), 1);
        assert!((dist.prob(&EntityType::Person) - 1.0).abs() < 1e-10);
        assert!(dist.prob(&EntityType::Organization).abs() < 1e-10);

        assert!(TypeDistribution::new(vec![]).is_empty());
        assert!(TypeDistribution::uniform(&[]).is_empty());
    }

    #[test]
    fn test_type_distribution_entropy() {
        // Point mass should have 0 entropy
        let point = TypeDistribution::point_mass(EntityType::Person, 1.0);
        assert!((point.entropy() - 0.0).abs() < 1e-10);

        // Uniform over n types should have entropy ln(n)
        let uniform = TypeDistribution::uniform(&[
            EntityType::Person,
            EntityType::Organization,
            EntityType::Location,
        ]);
        assert!(uniform.entropy() > 0.0);
        assert!((uniform.entropy() - 3f64.ln()).abs() < 1e-10);
    }

    #[test]
    fn test_type_distribution_margin() {
        // Clear winner
        let clear = TypeDistribution::new(vec![
            (EntityType::Person, 0.9),
            (EntityType::Organization, 0.1),
        ]);
        assert!((clear.margin() - 0.8).abs() < 1e-10);

        // Tied
        let tied = TypeDistribution::new(vec![
            (EntityType::Person, 0.5),
            (EntityType::Organization, 0.5),
        ]);
        assert!((tied.margin() - 0.0).abs() < 1e-10);

        // A single type reports the documented margin of 1.0
        let single = TypeDistribution::point_mass(EntityType::Person, 0.3);
        assert!((single.margin() - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_uncertain_prediction_single() {
        let pred = UncertainPrediction::from_type(EntityType::Person, 0.85);
        assert!(pred.is_confident(0.8));
        assert!(!pred.is_confident(0.9));
        assert!(!pred.is_abstention());

        let (t, c) = pred.best().unwrap();
        assert_eq!(*t, EntityType::Person);
        assert!((c - 0.85).abs() < 1e-10);
    }

    #[test]
    fn test_uncertain_prediction_distributed() {
        let pred = UncertainPrediction::distributed(TypeDistribution::new(vec![
            (EntityType::Person, 0.45),
            (EntityType::Organization, 0.40),
            (EntityType::Location, 0.15),
        ]));
        assert!(!pred.is_abstention());
        assert!(pred.is_confident(0.4));
        assert!(!pred.is_confident(0.8));
        assert_eq!(pred.get_type(), Some(&EntityType::Person));
        assert!((pred.confidence() - 0.45).abs() < 1e-10);
        assert!(pred.distribution().is_some());

        // Below threshold with several candidate types -> ambiguity abstention.
        let abstained = pred.with_threshold(0.8);
        assert!(matches!(
            abstained.abstention_reason(),
            Some(Abstention::Ambiguous { .. })
        ));
    }

    #[test]
    fn test_uncertain_prediction_abstain() {
        let pred = UncertainPrediction::abstain_low_confidence(0.35);
        assert!(pred.is_abstention());
        assert!(!pred.is_confident(0.1));
        assert!(pred.best().is_none());

        let reason = pred.abstention_reason().unwrap();
        assert!(matches!(reason, Abstention::LowConfidence { .. }));
    }

    #[test]
    fn test_with_threshold() {
        let pred = UncertainPrediction::from_type(EntityType::Person, 0.6);

        // Below threshold → abstain
        let result = pred.clone().with_threshold(0.7);
        assert!(result.is_abstention());

        // Above threshold → keep
        let result2 = pred.with_threshold(0.5);
        assert!(!result2.is_abstention());
    }

    #[test]
    fn test_selective_metrics() {
        let predictions = vec![
            (Some(EntityType::Person), EntityType::Person), // correct
            (Some(EntityType::Organization), EntityType::Person), // incorrect
            (None, EntityType::Location),                   // abstained
            (Some(EntityType::Location), EntityType::Location), // correct
        ];

        let metrics = SelectiveMetrics::compute(&predictions);

        assert_eq!(metrics.predictions, 3);
        assert_eq!(metrics.abstentions, 1);
        assert_eq!(metrics.correct, 2);
        assert!((metrics.coverage - 0.75).abs() < 1e-10);
        assert!((metrics.accuracy - 2.0 / 3.0).abs() < 1e-10);
        // One incorrect prediction over four items; the abstention adds no error.
        assert!((metrics.risk - 0.25).abs() < 1e-10);
    }

    #[test]
    fn test_selective_metrics_empty() {
        let metrics = SelectiveMetrics::compute(&[]);
        assert_eq!(metrics.predictions, 0);
        assert_eq!(metrics.abstentions, 0);
        assert!(metrics.coverage.abs() < 1e-10);
        assert!(metrics.risk.abs() < 1e-10);
    }

    #[test]
    fn test_uncertain_entity_to_entity() {
        let ue = UncertainEntity::new(
            "John".to_string(),
            0,
            4,
            UncertainPrediction::from_type(EntityType::Person, 0.9),
        );

        // Should convert when above threshold
        let entity = ue.to_entity(0.8);
        assert!(entity.is_some());
        let e = entity.unwrap();
        assert_eq!(e.text, "John");
        assert_eq!(e.entity_type, EntityType::Person);

        // Should not convert when below threshold
        let entity_low = ue.to_entity(0.95);
        assert!(entity_low.is_none());
    }

    #[test]
    fn test_uncertain_entity_abstained_with_source() {
        let ue = UncertainEntity::new(
            "Acme".to_string(),
            0,
            4,
            UncertainPrediction::abstain_low_confidence(0.2),
        )
        .with_source("rules");

        assert_eq!(ue.source.as_deref(), Some("rules"));
        // Abstentions are never included, whatever the threshold.
        assert!(!ue.should_include(0.0));
        assert!(ue.to_entity(0.0).is_none());
    }

    #[test]
    fn test_abstention_resolvable() {
        assert!(Abstention::LowConfidence { max_score: 0.3 }.is_resolvable());
        assert!(Abstention::Ambiguous {
            top_types: vec![],
            margin: 0.1
        }
        .is_resolvable());
        assert!(Abstention::Conflict {
            predictions: vec![("rules".to_string(), EntityType::Person)]
        }
        .is_resolvable());
        assert!(!Abstention::OutOfDomain {
            detected_domain: None
        }
        .is_resolvable());
        assert!(!Abstention::InvalidSpan {
            reason: "empty".to_string()
        }
        .is_resolvable());
        assert!(!Abstention::Declined {
            reason: "policy".to_string()
        }
        .is_resolvable());
    }
}
842}