use serde::{Deserialize, Serialize};
use std::collections::HashMap;
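/// Calibration metrics computed from `(confidence, was_correct)` predictions.
///
/// `ece` is the expected calibration error (bin-weighted mean of |accuracy - avg confidence|),
/// `mce` the maximum per-bin calibration error, and `brier_score` the mean squared error
/// between confidence and the 0/1 outcome. `confidence_gap` is the average confidence on
/// correct predictions minus the average confidence on incorrect ones.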
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CalibrationResults {
pub ece: f64,
pub mce: f64,
pub brier_score: f64,
pub avg_confidence_correct: f64,
pub avg_confidence_incorrect: f64,
pub confidence_gap: f64,
pub reliability_bins: Vec<ReliabilityBin>,
pub total_predictions: usize,
pub threshold_accuracy: HashMap<String, ThresholdMetrics>,
}
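/// One equal-width confidence bin of a reliability diagram: its confidence range,
/// average confidence, empirical accuracy, sample count, and |accuracy - confidence| gap.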
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReliabilityBin {
pub range: (f64, f64),
pub avg_confidence: f64,
pub accuracy: f64,
pub count: usize,
pub calibration_error: f64,
}
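/// Accuracy and coverage over the predictions whose confidence is at or above a threshold.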
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ThresholdMetrics {
pub accuracy: f64,
pub coverage: f64,
pub count: usize,
}
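/// Bins predictions into `num_bins` equal-width confidence bins and reports calibration
/// metrics plus accuracy/coverage at each configured confidence threshold.
///
/// A minimal usage sketch (the prediction values are illustrative only):
///
/// ```ignore
/// let predictions = vec![(0.9, true), (0.8, true), (0.7, false), (0.6, true)];
/// // Default configuration: 10 bins, thresholds 0.5, 0.7, 0.8, 0.9, 0.95.
/// let results = CalibrationEvaluator::compute(&predictions);
/// assert_eq!(results.total_predictions, 4);
/// // Custom bin count, default thresholds.
/// let fine_grained = CalibrationEvaluator::new(20).evaluate(&predictions);
/// assert!(fine_grained.ece >= 0.0);
/// ```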
#[derive(Debug, Clone)]
pub struct CalibrationEvaluator {
pub num_bins: usize,
pub thresholds: Vec<f64>,
}
impl Default for CalibrationEvaluator {
fn default() -> Self {
Self {
num_bins: 10,
thresholds: vec![0.5, 0.7, 0.8, 0.9, 0.95],
}
}
}
impl CalibrationEvaluator {
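/// Creates an evaluator with a custom bin count, keeping the default thresholds.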
pub fn new(num_bins: usize) -> Self {
Self {
num_bins,
..Default::default()
}
}
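/// Evaluates `predictions` with the default configuration (10 bins).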
pub fn compute(predictions: &[(f64, bool)]) -> CalibrationResults {
Self::default().evaluate(predictions)
}
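/// Computes calibration metrics over `(confidence, was_correct)` pairs.
///
/// ECE is the bin-size-weighted sum of per-bin |accuracy - mean confidence|, MCE is the
/// maximum of that gap over non-empty bins, and the Brier score is the mean of
/// (confidence - outcome)^2 with outcome in {0, 1}.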
pub fn evaluate(&self, predictions: &[(f64, bool)]) -> CalibrationResults {
if predictions.is_empty() {
return CalibrationResults {
ece: 0.0,
mce: 0.0,
brier_score: 0.0,
avg_confidence_correct: 0.0,
avg_confidence_incorrect: 0.0,
confidence_gap: 0.0,
reliability_bins: Vec::new(),
total_predictions: 0,
threshold_accuracy: HashMap::new(),
};
}
let bin_width = 1.0 / self.num_bins as f64;
let mut bins: Vec<Vec<(f64, bool)>> = vec![Vec::new(); self.num_bins];
for &(conf, correct) in predictions {
let bin_idx = ((conf * self.num_bins as f64) as usize).min(self.num_bins - 1);
bins[bin_idx].push((conf, correct));
}
let mut reliability_bins = Vec::new();
let mut ece_sum = 0.0;
let mut mce: f64 = 0.0;
for (i, bin) in bins.iter().enumerate() {
if bin.is_empty() {
continue;
}
let range_start = i as f64 * bin_width;
let range_end = (i + 1) as f64 * bin_width;
let avg_confidence = bin.iter().map(|(c, _)| c).sum::<f64>() / bin.len() as f64;
let accuracy =
bin.iter().filter(|(_, correct)| *correct).count() as f64 / bin.len() as f64;
let calibration_error = (accuracy - avg_confidence).abs();
let weight = bin.len() as f64 / predictions.len() as f64;
ece_sum += weight * calibration_error;
mce = mce.max(calibration_error);
reliability_bins.push(ReliabilityBin {
range: (range_start, range_end),
avg_confidence,
accuracy,
count: bin.len(),
calibration_error,
});
}
let brier_score = predictions
.iter()
.map(|(conf, correct)| {
let target = if *correct { 1.0 } else { 0.0 };
(conf - target).powi(2)
})
.sum::<f64>()
/ predictions.len() as f64;
let correct_confs: Vec<f64> = predictions
.iter()
.filter(|(_, c)| *c)
.map(|(conf, _)| *conf)
.collect();
let incorrect_confs: Vec<f64> = predictions
.iter()
.filter(|(_, c)| !*c)
.map(|(conf, _)| *conf)
.collect();
let avg_confidence_correct = if correct_confs.is_empty() {
0.0
} else {
correct_confs.iter().sum::<f64>() / correct_confs.len() as f64
};
let avg_confidence_incorrect = if incorrect_confs.is_empty() {
0.0
} else {
incorrect_confs.iter().sum::<f64>() / incorrect_confs.len() as f64
};
let mut threshold_accuracy = HashMap::new();
for &threshold in &self.thresholds {
let above: Vec<_> = predictions
.iter()
.filter(|(c, _)| *c >= threshold)
.collect();
if above.is_empty() {
threshold_accuracy.insert(
format!("{:.2}", threshold),
ThresholdMetrics {
accuracy: 0.0,
coverage: 0.0,
count: 0,
},
);
} else {
let acc = above.iter().filter(|(_, correct)| *correct).count() as f64
/ above.len() as f64;
let cov = above.len() as f64 / predictions.len() as f64;
threshold_accuracy.insert(
format!("{:.2}", threshold),
ThresholdMetrics {
accuracy: acc,
coverage: cov,
count: above.len(),
},
);
}
}
CalibrationResults {
ece: ece_sum,
mce,
brier_score,
avg_confidence_correct,
avg_confidence_incorrect,
confidence_gap: avg_confidence_correct - avg_confidence_incorrect,
reliability_bins,
total_predictions: predictions.len(),
threshold_accuracy,
}
}
}
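/// Maps an expected calibration error (ECE) value to a coarse qualitative grade.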
pub fn calibration_grade(ece: f64) -> &'static str {
if ece < 0.05 {
"Well calibrated"
} else if ece < 0.10 {
"Moderately calibrated"
} else if ece < 0.15 {
"Poorly calibrated"
} else {
"Very poorly calibrated"
}
}
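/// Maps the correct-vs-incorrect confidence gap to a coarse discrimination grade.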
pub fn confidence_gap_grade(gap: f64) -> &'static str {
if gap > 0.3 {
"Excellent discrimination"
} else if gap > 0.2 {
"Good discrimination"
} else if gap > 0.1 {
"Moderate discrimination"
} else if gap > 0.0 {
"Weak discrimination"
} else {
"No discrimination (or reversed)"
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_perfect_calibration() {
let predictions = vec![
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, true),
(0.9, false),
];
let results = CalibrationEvaluator::compute(&predictions);
assert!(
results.ece < 0.1,
"ECE should be low for well-calibrated predictions"
);
}
#[test]
fn test_overconfident_model() {
let predictions = vec![
(0.95, false),
(0.95, false),
(0.95, false),
(0.95, true),
(0.95, false),
];
let results = CalibrationEvaluator::compute(&predictions);
assert!(
results.ece > 0.5,
"ECE should be high for overconfident predictions"
);
}
#[test]
fn test_confidence_gap() {
let predictions = vec![
(0.95, true),
(0.90, true),
(0.85, true),
(0.30, false),
(0.25, false),
(0.20, false),
];
let results = CalibrationEvaluator::compute(&predictions);
assert!(
results.avg_confidence_correct > 0.8,
"Correct predictions should have high confidence"
);
assert!(
results.avg_confidence_incorrect < 0.4,
"Incorrect predictions should have low confidence"
);
assert!(
results.confidence_gap > 0.4,
"Should have large confidence gap"
);
}
#[test]
fn test_threshold_metrics() {
let predictions = vec![
(0.95, true),
(0.85, true),
(0.75, false),
(0.65, true),
(0.55, false),
];
let results = CalibrationEvaluator::compute(&predictions);
let t80 = results.threshold_accuracy.get("0.80").unwrap();
assert!((t80.accuracy - 1.0).abs() < 0.01, "Should be 100% at 0.80");
assert!((t80.coverage - 0.4).abs() < 0.01, "Coverage should be 40%");
}
#[test]
fn test_empty_predictions() {
let results = CalibrationEvaluator::compute(&[]);
assert_eq!(results.total_predictions, 0);
assert_eq!(results.ece, 0.0);
}
#[test]
fn test_calibration_grades() {
assert_eq!(calibration_grade(0.03), "Well calibrated");
assert_eq!(calibration_grade(0.07), "Moderately calibrated");
assert_eq!(calibration_grade(0.12), "Poorly calibrated");
assert_eq!(calibration_grade(0.25), "Very poorly calibrated");
}
#[test]
fn test_entropy_single_source() {
let scores = vec![0.9];
let entropy = confidence_entropy(&scores);
assert!(
(entropy - 0.0).abs() < 0.001,
"Single source should have 0 entropy"
);
}
#[test]
fn test_entropy_agreement() {
let scores = vec![0.9, 0.88, 0.92];
let entropy = confidence_entropy(&scores);
assert!(
entropy < 0.5,
"Agreeing sources should have low entropy: {}",
entropy
);
}
#[test]
fn test_entropy_conflict() {
let scores = vec![0.95, 0.05, 0.5, 0.8, 0.2];
let entropy = confidence_entropy(&scores);
assert!(
entropy > 0.5,
"Conflicting sources should have high entropy: {}",
entropy
);
}
#[test]
fn test_entropy_filter() {
let candidates = [
("Apple Inc.", vec![0.9, 0.88, 0.92]),
("Apple", vec![0.95, 0.05, 0.5]),
("Microsoft", vec![0.85, 0.87]),
];
let filter = EntropyFilter::new(0.6);
let filtered: Vec<_> = candidates
.iter()
.filter(|(_, scores)| filter.should_keep(scores))
.map(|(name, _)| *name)
.collect();
assert!(filtered.contains(&"Apple Inc."));
assert!(filtered.contains(&"Microsoft"));
assert!(
!filtered.contains(&"Apple"),
"Conflicting 'Apple' should be filtered"
);
}
}
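/// Disagreement score in [0, 1] for a set of confidence scores.
///
/// Despite the name, this is not Shannon entropy: it is the sample standard deviation of
/// the scores, normalized by 0.5 (the largest population standard deviation attainable for
/// values in [0, 1]) and clamped to 1.0. Returns 0.0 for fewer than two scores.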
#[must_use]
pub fn confidence_entropy(scores: &[f64]) -> f64 {
if scores.len() <= 1 {
return 0.0;
}
let mean = scores.iter().sum::<f64>() / scores.len() as f64;
let n = scores.len() as f64;
// Sample standard deviation (Bessel's correction); n > 1 is guaranteed by the early return.
let variance = scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / (n - 1.0);
let std_dev = variance.sqrt();
// Normalize by 0.5 (the largest population std dev for values in [0, 1]) and clamp to 1.0.
(std_dev / 0.5).min(1.0)
}
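/// Population variance of the confidence scores; returns 0.0 for fewer than two scores.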
#[must_use]
pub fn confidence_variance(scores: &[f64]) -> f64 {
if scores.len() <= 1 {
return 0.0;
}
let mean = scores.iter().sum::<f64>() / scores.len() as f64;
scores.iter().map(|s| (s - mean).powi(2)).sum::<f64>() / scores.len() as f64
}
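/// Filters candidates by how much their per-source confidence scores agree.
///
/// A candidate is kept when its disagreement score (normalized-entropy- or variance-based,
/// depending on `use_variance`) stays at or below the configured ceiling, or when too few
/// sources reported to judge agreement. A short sketch with illustrative scores:
///
/// ```ignore
/// let filter = EntropyFilter::new(0.6);
/// assert!(filter.should_keep(&[0.9, 0.88, 0.92])); // sources agree
/// assert!(!filter.should_keep(&[0.95, 0.05, 0.5])); // sources conflict
/// ```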
#[derive(Debug, Clone)]
pub struct EntropyFilter {
pub max_entropy: f64,
pub min_sources: usize,
pub use_variance: bool,
pub max_variance: f64,
}
impl Default for EntropyFilter {
fn default() -> Self {
Self {
max_entropy: 0.7,
min_sources: 2,
use_variance: false,
max_variance: 0.1,
}
}
}
impl EntropyFilter {
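/// Creates a filter with a custom entropy ceiling and default settings otherwise.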
#[must_use]
pub fn new(max_entropy: f64) -> Self {
Self {
max_entropy,
..Default::default()
}
}
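/// Stricter preset: entropy ceiling 0.4, and at least three sources before filtering applies.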
#[must_use]
pub fn strict() -> Self {
Self {
max_entropy: 0.4,
min_sources: 3,
..Default::default()
}
}
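/// Permissive preset: entropy ceiling 0.85.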
#[must_use]
pub fn permissive() -> Self {
Self {
max_entropy: 0.85,
min_sources: 2,
..Default::default()
}
}
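/// Returns `true` if the candidate should be kept: either too few sources to judge
/// agreement, or a disagreement score within the configured ceiling.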
#[must_use]
pub fn should_keep(&self, scores: &[f64]) -> bool {
// Too few sources to judge agreement; keep the candidate by default.
if scores.len() < self.min_sources {
return true;
}
if self.use_variance {
confidence_variance(scores) <= self.max_variance
} else {
confidence_entropy(scores) <= self.max_entropy
}
}
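/// Returns the raw disagreement score the filter uses (variance or normalized entropy).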
#[must_use]
pub fn compute_score(&self, scores: &[f64]) -> f64 {
if self.use_variance {
confidence_variance(scores)
} else {
confidence_entropy(scores)
}
}
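/// Maps the disagreement score to a qualitative agreement grade, using variance-scale or
/// entropy-scale cutoffs depending on `use_variance`.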
#[must_use]
pub fn agreement_grade(&self, scores: &[f64]) -> &'static str {
let score = self.compute_score(scores);
if self.use_variance {
if score < 0.02 {
"Strong agreement"
} else if score < 0.05 {
"Good agreement"
} else if score < 0.1 {
"Moderate agreement"
} else {
"Disagreement"
}
} else if score < 0.3 {
"Strong agreement"
} else if score < 0.5 {
"Good agreement"
} else if score < 0.7 {
"Moderate agreement"
} else {
"Disagreement"
}
}
}