use std::collections::HashMap;
use serde::Serialize;
use super::COCOeval;
/// One equal-width confidence bin of a calibration histogram.
///
/// Averages are left at `0.0` when the bin holds no detections.
#[derive(Debug, Clone, Serialize)]
pub struct CalibrationBin {
/// Lower edge of the bin, computed as `i / n_bins` for bin index `i`.
pub bin_lower: f64,
/// Upper edge of the bin, computed as `(i + 1) / n_bins` for bin index `i`.
pub bin_upper: f64,
/// Mean confidence score of the detections assigned to this bin.
pub avg_confidence: f64,
/// Fraction of the detections assigned to this bin that were marked correct.
pub avg_accuracy: f64,
/// Number of detections assigned to this bin.
pub count: usize,
}
/// Calibration summary returned by `COCOeval::calibration`.
#[derive(Debug, Clone, Serialize)]
pub struct CalibrationResult {
/// Count-weighted mean of `|avg_accuracy - avg_confidence|` over the non-empty bins.
pub ece: f64,
/// Largest `|avg_accuracy - avg_confidence|` among the non-empty bins.
pub mce: f64,
/// Histogram built from all collected detections, one entry per bin.
pub bins: Vec<CalibrationBin>,
/// Same weighted gap as `ece`, computed separately per category id.
pub per_category: HashMap<u64, f64>,
/// IoU threshold that was requested for deciding matches.
pub iou_threshold: f64,
/// Number of bins that was requested.
pub n_bins: usize,
/// Number of non-ignored detections that went into the histogram.
pub num_detections: usize,
}
/// A single detection reduced to what calibration needs.
#[derive(Clone, Copy)]
struct Detection {
/// Score reported for the detection (taken from `dt_scores`).
confidence: f64,
/// Whether the detection was matched at the selected IoU threshold.
correct: bool,
}
impl COCOeval {
    /// Measures how well detection confidence scores agree with match outcomes.
    ///
    /// Every non-ignored detection stored in `eval_imgs` for the `"all"` area
    /// range (index 0 if no range of that name exists) is placed into one of
    /// `n_bins` equal-width confidence bins. A detection counts as correct when
    /// it is flagged as matched at `iou_threshold`. The result holds the
    /// overall histogram, its ECE and MCE, and one ECE value per category id.
    ///
    /// # Errors
    ///
    /// Returns an error when
    /// * `eval_imgs` is empty (no evaluation results are available),
    /// * `iou_threshold` is not within `1e-9` of any entry of `params.iou_thrs`,
    /// * `n_bins` is zero.
    pub fn calibration(
        &self,
        n_bins: usize,
        iou_threshold: f64,
    ) -> crate::error::Result<CalibrationResult> {
        if self.eval_imgs.is_empty() {
            return Err("calibration() requires evaluate() to be called first".into());
        }

        // Thresholds are floats, so compare with a tolerance instead of `==`.
        let t_idx = self
            .params
            .iou_thrs
            .iter()
            .position(|&t| (t - iou_threshold).abs() < 1e-9)
            .ok_or_else(|| {
                format!(
                    "iou_threshold={iou_threshold} not found in params.iou_thrs={:?}",
                    self.params.iou_thrs
                )
            })?;

        // A histogram with zero bins cannot hold any detection, so reject the
        // request here rather than handing it to the binning step.
        if n_bins == 0 {
            return Err("calibration() requires n_bins >= 1".into());
        }

        let target_area_rng = self.params.area_range_idx("all").unwrap_or(0);
        let target_area = self.params.area_ranges[target_area_rng].range;

        let mut all_dets: Vec<Detection> = Vec::new();
        let mut per_cat_dets: HashMap<u64, Vec<Detection>> = HashMap::new();

        for eval_img in self.eval_imgs.iter().flatten() {
            // Only entries evaluated for the selected area range are used.
            if eval_img.area_rng != target_area {
                continue;
            }

            let matched = &eval_img.dt_matched[t_idx];
            let ignored = &eval_img.dt_ignore[t_idx];
            debug_assert_eq!(matched.len(), eval_img.dt_scores.len());
            debug_assert_eq!(ignored.len(), eval_img.dt_scores.len());

            // `zip` stops at the shortest of the three sequences, which keeps
            // release builds in bounds even if the lengths ever disagree.
            let per_detection = matched
                .iter()
                .zip(ignored.iter())
                .zip(eval_img.dt_scores.iter());
            for ((&correct, &is_ignored), &confidence) in per_detection {
                if is_ignored {
                    continue;
                }
                let det = Detection {
                    confidence,
                    correct,
                };
                // The entry is created lazily so that a category only shows up
                // in the result once it has a non-ignored detection.
                per_cat_dets
                    .entry(eval_img.category_id)
                    .or_default()
                    .push(det);
                all_dets.push(det);
            }
        }

        let num_detections = all_dets.len();
        let bins = compute_bins(&all_dets, n_bins);
        let (ece, mce) = compute_ece_mce(&bins, num_detections);

        let per_category: HashMap<u64, f64> = per_cat_dets
            .iter()
            .map(|(&cat_id, dets)| {
                let cat_bins = compute_bins(dets, n_bins);
                let (cat_ece, _) = compute_ece_mce(&cat_bins, dets.len());
                (cat_id, cat_ece)
            })
            .collect();

        Ok(CalibrationResult {
            ece,
            mce,
            bins,
            per_category,
            iou_threshold,
            n_bins,
            num_detections,
        })
    }
}
/// Builds a histogram of `n_bins` equal-width confidence bins over `[0, 1]`.
///
/// A detection with confidence `c` goes to bin `floor(c * n_bins)`, clamped to
/// the last bin, so scores of `1.0` and above land in the final bin. The
/// float-to-integer cast saturates, so negative scores (and NaN) land in the
/// first bin. Each returned bin carries its edges, the mean confidence and
/// the fraction of correct detections it received, and its detection count;
/// empty bins keep both averages at `0.0`.
///
/// Returns an empty vector when `n_bins` is zero.
fn compute_bins(dets: &[Detection], n_bins: usize) -> Vec<CalibrationBin> {
    // Without this guard `n_bins - 1` below would underflow as soon as there
    // is a detection to place.
    if n_bins == 0 {
        return Vec::new();
    }

    let bin_count = n_bins as f64;
    let last_idx = n_bins - 1;

    let mut bins: Vec<CalibrationBin> = (0..n_bins)
        .map(|i| CalibrationBin {
            bin_lower: i as f64 / bin_count,
            bin_upper: (i + 1) as f64 / bin_count,
            avg_confidence: 0.0,
            avg_accuracy: 0.0,
            count: 0,
        })
        .collect();

    // First pass: accumulate sums in the `avg_*` fields.
    for det in dets {
        let idx = ((det.confidence * bin_count) as usize).min(last_idx);
        let bin = &mut bins[idx];
        bin.avg_confidence += det.confidence;
        if det.correct {
            bin.avg_accuracy += 1.0;
        }
        bin.count += 1;
    }

    // Second pass: turn the sums into means; empty bins stay at zero.
    for bin in bins.iter_mut().filter(|b| b.count > 0) {
        let n = bin.count as f64;
        bin.avg_confidence /= n;
        bin.avg_accuracy /= n;
    }

    bins
}
/// Reduces a calibration histogram to `(ece, mce)`.
///
/// For every non-empty bin the gap is `|avg_accuracy - avg_confidence|`.
/// `ece` is the sum of the gaps weighted by `count / total`; `mce` is the
/// largest gap. Both are `0.0` when `total` is zero.
fn compute_ece_mce(bins: &[CalibrationBin], total: usize) -> (f64, f64) {
    if total == 0 {
        return (0.0, 0.0);
    }

    bins.iter()
        .filter(|b| b.count > 0)
        .fold((0.0_f64, 0.0_f64), |(weighted_sum, worst), b| {
            let gap = (b.avg_accuracy - b.avg_confidence).abs();
            let weight = b.count as f64 / total as f64;
            (weighted_sum + weight * gap, worst.max(gap))
        })
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds `count` identical detections.
    fn uniform(count: usize, confidence: f64, correct: bool) -> Vec<Detection> {
        vec![
            Detection {
                confidence,
                correct,
            };
            count
        ]
    }

    /// Scores at the centre of each 0.01 step must spread evenly over ten bins.
    #[test]
    fn test_compute_bins_perfect_calibration() {
        let mut dets = Vec::with_capacity(100);
        for i in 0..100u32 {
            dets.push(Detection {
                confidence: (f64::from(i) + 0.5) / 100.0,
                correct: i % 2 == 0,
            });
        }

        let bins = compute_bins(&dets, 10);

        assert_eq!(bins.len(), 10);
        for bin in &bins {
            assert_eq!(bin.count, 10);
        }
    }

    /// Always-correct detections at 0.95 are under-confident by 0.05.
    #[test]
    fn test_ece_all_correct_high_confidence() {
        let dets = uniform(100, 0.95, true);

        let bins = compute_bins(&dets, 10);
        let (ece, mce) = compute_ece_mce(&bins, dets.len());

        assert!((ece - 0.05).abs() < 1e-9);
        assert!((mce - 0.05).abs() < 1e-9);
    }

    /// No detections at all yields zero for both metrics.
    #[test]
    fn test_ece_empty() {
        let dets: Vec<Detection> = Vec::new();

        let bins = compute_bins(&dets, 10);
        let (ece, mce) = compute_ece_mce(&bins, 0);

        assert_eq!(ece, 0.0);
        assert_eq!(mce, 0.0);
    }

    /// Half-correct detections at 0.9 are over-confident by 0.4.
    #[test]
    fn test_ece_single_bin_overconfident() {
        let dets: Vec<Detection> = (0..100)
            .map(|i| Detection {
                confidence: 0.9,
                correct: i < 50,
            })
            .collect();

        let bins = compute_bins(&dets, 10);
        let (ece, mce) = compute_ece_mce(&bins, dets.len());

        assert!((ece - 0.4).abs() < 1e-9);
        assert!((mce - 0.4).abs() < 1e-9);
    }
}