use crate::accumulate::PerImageEval;
use crate::error::EvalError;
use crate::parity::{quantile_linear, ParityMode};
use crate::summarize::pairwise_sum;
/// z-score for a two-sided 95% confidence interval (Phi^-1(0.975));
/// used by the Wilson score interval in `wilson_ci`.
const Z_95: f64 = 1.959_963_984_540_054;
/// Strategy for placing histogram bin edges over detection scores.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Binning {
    /// Edges at evenly spaced score quantiles, so bins hold roughly
    /// equal numbers of detections (see `build_edges`).
    Quantile,
    /// Evenly spaced edges from max(`min_score`, observed minimum) to 1.0.
    EqualWidth,
}
/// Method used for per-bin accuracy confidence intervals.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConfidenceKind {
    /// Wilson score interval (implemented; see `wilson_ci`).
    Wilson,
    /// Clopper-Pearson exact interval; currently rejected by
    /// `build_reliability` as not yet implemented.
    ClopperPearson,
}
/// Aggregation scheme for combining per-class calibration metrics.
/// NOTE(review): not read anywhere in this file — presumably consumed by a
/// downstream reporter; confirm at call sites.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Aggregation {
    Macro,
    Micro,
}
/// Configuration for `summarize_calibration`.
#[derive(Debug, Clone)]
pub struct CalibrationParams {
    /// Row of the IoU-threshold axis in `PerImageEval` to evaluate at.
    pub iou_index: usize,
    /// Requested number of histogram bins; must be > 0. Degenerate
    /// (duplicate) edges may reduce the effective bin count.
    pub n_bins: usize,
    /// Bin-edge placement strategy.
    pub binning: Binning,
    /// Detections scoring strictly below this threshold are excluded.
    pub min_score: f64,
    /// Confidence-interval method for per-bin accuracy.
    pub confidence: ConfidenceKind,
    /// Whether to also emit a per-class ECE/MCE table.
    pub per_class: bool,
    /// How per-class metrics would be aggregated. NOTE(review): unused in
    /// this file — verify it is consumed elsewhere.
    pub per_class_aggregation: Aggregation,
}
impl Default for CalibrationParams {
    /// Defaults: first IoU threshold, 15 quantile bins, 0.05 score cutoff,
    /// Wilson intervals, no per-class breakdown.
    fn default() -> Self {
        Self {
            iou_index: 0,
            n_bins: 15,
            binning: Binning::Quantile,
            min_score: 0.05,
            confidence: ConfidenceKind::Wilson,
            per_class: false,
            per_class_aggregation: Aggregation::Macro,
        }
    }
}
/// Column-oriented reliability diagram: index `i` across all vectors
/// describes bin `i`. Empty bins keep their row but carry NaN statistics
/// and a `count` of 0.
#[derive(Debug, Clone)]
pub struct ReliabilityTable {
    /// Bin index (0-based).
    pub bin_id: Vec<u32>,
    /// Lower score edge of the bin.
    pub score_lo: Vec<f64>,
    /// Upper score edge of the bin.
    pub score_hi: Vec<f64>,
    /// Mean detection score within the bin (NaN when empty).
    pub mean_score: Vec<f64>,
    /// Fraction of matched detections within the bin (NaN when empty).
    pub accuracy: Vec<f64>,
    /// Number of detections assigned to the bin.
    pub count: Vec<u64>,
    /// Signed calibration gap: accuracy - mean_score (NaN when empty).
    pub gap: Vec<f64>,
    /// Lower bound of the accuracy confidence interval (NaN when empty).
    pub ci_lo: Vec<f64>,
    /// Upper bound of the accuracy confidence interval (NaN when empty).
    pub ci_hi: Vec<f64>,
}
/// Column-oriented per-class calibration metrics, ordered by class id
/// (built from a BTreeMap in `build_per_class`).
#[derive(Debug, Clone)]
pub struct PerClassTable {
    /// Class (category) id.
    pub class_id: Vec<u32>,
    /// Per-class Expected Calibration Error.
    pub ece: Vec<f64>,
    /// Per-class Maximum Calibration Error.
    pub mce: Vec<f64>,
    /// Per-class count of binned detections.
    pub n: Vec<u64>,
}
/// Top-level result of `summarize_calibration`.
#[derive(Debug, Clone)]
pub struct CalibrationSummary {
    /// Expected Calibration Error (NaN when no detections or no bins).
    pub ece: f64,
    /// Maximum Calibration Error (NaN when no detections or no bins).
    pub mce: f64,
    /// Number of detections that landed in a bin.
    pub n_detections: u64,
    /// Number of usable bins after degenerate-edge merging.
    pub effective_n_bins: usize,
    /// Per-bin reliability diagram.
    pub reliability: ReliabilityTable,
    /// Per-class breakdown, present when `CalibrationParams.per_class` is set.
    pub per_class: Option<PerClassTable>,
}
/// One filtered detection flattened out of the evaluation grid.
#[derive(Debug, Clone, Copy)]
struct Detection {
    // Confidence score; guaranteed finite and >= min_score by the collector.
    score: f64,
    // 1.0 if matched at the chosen IoU threshold, 0.0 otherwise (stored as
    // f64 so it can feed pairwise_sum directly).
    correct: f64,
    // Category index, narrowed to u32.
    class: u32,
}
/// Flattens detections from the evaluation grid (area-range index 0 only)
/// into a single list of (score, correctness, class) records.
///
/// Detections are skipped when their ignore flag is set at `iou_index` or
/// their score is below `min_score`.
///
/// # Errors
/// - `DimensionMismatch` when `eval_imgs.len()` is not an exact multiple of
///   `n_categories * n_area_ranges`, or when a cell has inconsistent shapes.
/// - `NonFinite` when a detection score is NaN or infinite.
/// - `InvalidConfig` when `iou_index` is out of range for a cell, or a
///   category index does not fit in `u32`.
fn collect_detections(
    eval_imgs: &[Option<Box<PerImageEval>>],
    n_categories: usize,
    n_area_ranges: usize,
    iou_index: usize,
    min_score: f64,
) -> Result<Vec<Detection>, EvalError> {
    // Infer the number of images from the flattened grid length.
    let n_i = if n_categories == 0 || n_area_ranges == 0 {
        0
    } else {
        eval_imgs.len() / (n_categories * n_area_ranges)
    };
    let expected = n_categories * n_area_ranges * n_i;
    if eval_imgs.len() != expected {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "eval_imgs len {} != n_categories({}) * n_area_ranges({}) * n_images({}) = {}",
                eval_imgs.len(),
                n_categories,
                n_area_ranges,
                n_i,
                expected
            ),
        });
    }
    let mut out: Vec<Detection> = Vec::new();
    // Calibration is computed over the first ("all") area range only.
    let area_idx: usize = 0;
    if n_area_ranges == 0 || area_idx >= n_area_ranges {
        return Ok(out);
    }
    for k in 0..n_categories {
        // Hoisted out of the per-detection loop: the class id is invariant
        // within a category, so convert it once instead of per detection.
        let class: u32 = u32::try_from(k).map_err(|_| EvalError::InvalidConfig {
            detail: format!("class index {k} does not fit in u32"),
        })?;
        // Grid layout: [category][area_range][image], flattened.
        let nk = k * n_area_ranges * n_i;
        let na = area_idx * n_i;
        for i in 0..n_i {
            let cell = match eval_imgs[nk + na + i].as_deref() {
                Some(c) => c,
                None => continue,
            };
            validate_cell(cell, iou_index)?;
            let n_d = cell.dt_scores.len();
            for d in 0..n_d {
                if cell.dt_ignore[(iou_index, d)] {
                    continue;
                }
                let s = cell.dt_scores[d];
                if !s.is_finite() {
                    return Err(EvalError::NonFinite {
                        context: "calibration::dt_scores",
                    });
                }
                if s < min_score {
                    continue;
                }
                let correct = if cell.dt_matched[(iou_index, d)] {
                    1.0_f64
                } else {
                    0.0_f64
                };
                out.push(Detection {
                    score: s,
                    correct,
                    class,
                });
            }
        }
    }
    Ok(out)
}
/// Checks that one grid cell is internally consistent and that the
/// requested IoU-threshold row exists; returns a descriptive error otherwise.
fn validate_cell(cell: &PerImageEval, iou_index: usize) -> Result<(), EvalError> {
    // The matched and ignore matrices must be the same shape (T x D).
    let matched_shape = cell.dt_matched.shape();
    let ignore_shape = cell.dt_ignore.shape();
    if matched_shape != ignore_shape {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "PerImageEval.dt_matched {:?} != dt_ignore {:?}",
                matched_shape, ignore_shape
            ),
        });
    }
    // One matrix column per detection score.
    let n_cols = cell.dt_matched.ncols();
    let n_scores = cell.dt_scores.len();
    if n_cols != n_scores {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "PerImageEval.dt_matched cols {} != dt_scores len {}",
                n_cols, n_scores
            ),
        });
    }
    // The IoU index must address an existing threshold row.
    let n_thresholds = cell.dt_matched.nrows();
    if iou_index >= n_thresholds {
        return Err(EvalError::InvalidConfig {
            detail: format!(
                "iou_index {iou_index} out of range for PerImageEval with T={}",
                n_thresholds
            ),
        });
    }
    Ok(())
}
/// Builds sorted bin edges over `scores_sorted` (must be ascending) and
/// returns them together with the number of usable bins after degenerate
/// (duplicate) edges are merged.
///
/// Returns `(vec![], 0)` for empty input or `n_bins == 0`.
fn build_edges(
    scores_sorted: &[f64],
    n_bins: usize,
    binning: Binning,
    min_score: f64,
) -> (Vec<f64>, usize) {
    if scores_sorted.is_empty() || n_bins == 0 {
        return (Vec::new(), 0);
    }
    let mut edges = match binning {
        Binning::Quantile => {
            // One edge per quantile i/n_bins, i = 0..=n_bins, so bins hold
            // roughly equal numbers of detections.
            let n_bins_f = n_bins as f64;
            let qs: Vec<f64> = (0..=n_bins).map(|i| (i as f64) / n_bins_f).collect();
            quantile_linear(scores_sorted, &qs)
        }
        Binning::EqualWidth => {
            // Equal-width edges from max(min_score, observed minimum) to 1.0.
            let observed_min = scores_sorted[0];
            let lo = if min_score > observed_min {
                min_score
            } else {
                observed_min
            };
            crate::parity::linspace(lo, 1.0_f64, n_bins + 1)
        }
    };
    // Tied scores make quantile edges collide, yielding zero-width bins;
    // drop consecutive duplicates. `Vec::dedup` removes consecutive equal
    // elements via PartialEq, exactly matching the previous manual scan
    // (NaN never equals NaN, so NaN edges are kept — same as before).
    edges.dedup();
    // A valid histogram needs at least two distinct edges; with fewer,
    // the effective bin count is zero.
    let effective = edges.len().saturating_sub(1);
    (edges, effective)
}
/// Maps `score` to a bin index given ascending, deduplicated bin `edges`.
///
/// Bins are half-open `[lo, hi)` except the last, which is closed
/// `[lo, hi]` so the maximum score is not dropped. Returns `None` when
/// there are fewer than two edges, when the score falls outside
/// `[edges[0], edges.last()]`, or when the score is NaN.
///
/// Uses a binary search (`partition_point`) instead of the previous linear
/// scan; edges are sorted by construction in `build_edges`.
fn assign_bin(score: f64, edges: &[f64]) -> Option<usize> {
    if edges.len() < 2 {
        return None;
    }
    if score < edges[0] || score > edges[edges.len() - 1] {
        return None;
    }
    // Number of edges <= score. A NaN score makes every comparison false
    // (and also slips past the range check above), giving 0 here — bail
    // out instead of underflowing, preserving the old NaN -> None behavior.
    let below_or_eq = edges.partition_point(|&e| e <= score);
    if below_or_eq == 0 {
        return None;
    }
    // Clamp so that score == last edge lands in the final (closed) bin.
    Some((below_or_eq - 1).min(edges.len() - 2))
}
fn wilson_ci(correct: f64, count: u64) -> (f64, f64) {
let n = count as f64;
let phat = correct / n;
let z = Z_95;
let zz = z * z;
let denom = 1.0 + zz / n;
let center = (phat + zz / (2.0 * n)) / denom;
let margin = (z / denom)
* (phat * (1.0 - phat) / n + zz / (4.0 * n * n))
.max(0.0)
.sqrt();
(center - margin, center + margin)
}
/// Builds the reliability table plus scalar (ECE, MCE) over the given edges.
///
/// Detections are histogrammed with `assign_bin`; each non-empty bin gets a
/// mean score, accuracy, signed gap (accuracy - mean score), and a
/// confidence interval. ECE is the bin-count-weighted mean absolute gap;
/// MCE is the maximum absolute gap. Empty bins keep their row with NaN
/// statistics. Returns NaN metrics when there are no usable bins or every
/// bin is empty.
///
/// # Errors
/// `InvalidConfig` for the unimplemented Clopper-Pearson interval or a bin
/// id that does not fit in u32.
fn build_reliability(
    detections: &[Detection],
    edges: &[f64],
    effective_n_bins: usize,
    confidence: ConfidenceKind,
) -> Result<(ReliabilityTable, f64, f64), EvalError> {
    let mut bin_id: Vec<u32> = Vec::with_capacity(effective_n_bins);
    let mut score_lo: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut score_hi: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut mean_score: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut accuracy: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut count: Vec<u64> = Vec::with_capacity(effective_n_bins);
    let mut gap: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut ci_lo: Vec<f64> = Vec::with_capacity(effective_n_bins);
    let mut ci_hi: Vec<f64> = Vec::with_capacity(effective_n_bins);
    // Degenerate edge set: empty table with NaN metrics, not an error.
    if effective_n_bins == 0 || edges.len() < 2 {
        return Ok((
            ReliabilityTable {
                bin_id,
                score_lo,
                score_hi,
                mean_score,
                accuracy,
                count,
                gap,
                ci_lo,
                ci_hi,
            },
            f64::NAN,
            f64::NAN,
        ));
    }
    // Histogram pass: keep the raw per-bin values so the sums below can use
    // pairwise_sum (deterministic float summation for parity).
    let mut per_bin_scores: Vec<Vec<f64>> = vec![Vec::new(); effective_n_bins];
    let mut per_bin_correct: Vec<Vec<f64>> = vec![Vec::new(); effective_n_bins];
    for det in detections {
        if let Some(b) = assign_bin(det.score, edges) {
            per_bin_scores[b].push(det.score);
            per_bin_correct[b].push(det.correct);
        }
    }
    // Total over binned detections only; out-of-range scores are excluded.
    let total_n: u64 = per_bin_scores.iter().map(|v| v.len() as u64).sum();
    let total_n_f = total_n as f64;
    let mut ece_acc = 0.0_f64;
    let mut mce_acc: f64 = 0.0_f64;
    let mut any_nonempty = false;
    for b in 0..effective_n_bins {
        let scores_b = &per_bin_scores[b];
        let correct_b = &per_bin_correct[b];
        let n_b = scores_b.len() as u64;
        bin_id.push(u32::try_from(b).map_err(|_| EvalError::InvalidConfig {
            detail: format!("bin id {b} does not fit in u32"),
        })?);
        score_lo.push(edges[b]);
        score_hi.push(edges[b + 1]);
        count.push(n_b);
        if n_b == 0 {
            // Empty bin: keep the row but mark every statistic NaN.
            mean_score.push(f64::NAN);
            accuracy.push(f64::NAN);
            gap.push(f64::NAN);
            ci_lo.push(f64::NAN);
            ci_hi.push(f64::NAN);
            continue;
        }
        any_nonempty = true;
        let sum_s = pairwise_sum(scores_b);
        let sum_c = pairwise_sum(correct_b);
        let n_b_f = n_b as f64;
        let mean_s = sum_s / n_b_f;
        let acc = sum_c / n_b_f;
        // Signed calibration gap; its absolute value feeds ECE/MCE.
        let g = acc - mean_s;
        mean_score.push(mean_s);
        accuracy.push(acc);
        gap.push(g);
        match confidence {
            ConfidenceKind::Wilson => {
                let (lo_ci, hi_ci) = wilson_ci(sum_c, n_b);
                ci_lo.push(lo_ci);
                ci_hi.push(hi_ci);
            }
            ConfidenceKind::ClopperPearson => {
                // Not implemented yet; surfaced as a configuration error.
                return Err(EvalError::InvalidConfig {
                    detail: "Clopper-Pearson CI not yet implemented; use Wilson".to_string(),
                });
            }
        }
        let abs_gap = g.abs();
        if total_n > 0 {
            // ECE contribution: bin weight (n_b / N) times |gap|.
            ece_acc += (n_b_f / total_n_f) * abs_gap;
        }
        if abs_gap > mce_acc {
            mce_acc = abs_gap;
        }
    }
    // All bins empty: report NaN rather than a misleading 0.0.
    let (ece, mce) = if any_nonempty {
        (ece_acc, mce_acc)
    } else {
        (f64::NAN, f64::NAN)
    };
    Ok((
        ReliabilityTable {
            bin_id,
            score_lo,
            score_hi,
            mean_score,
            accuracy,
            count,
            gap,
            ci_lo,
            ci_hi,
        },
        ece,
        mce,
    ))
}
/// Summary returned when no detections survive filtering: NaN metrics,
/// zero counts, empty reliability columns, and no per-class table.
fn empty_summary() -> CalibrationSummary {
    let reliability = ReliabilityTable {
        bin_id: vec![],
        score_lo: vec![],
        score_hi: vec![],
        mean_score: vec![],
        accuracy: vec![],
        count: vec![],
        gap: vec![],
        ci_lo: vec![],
        ci_hi: vec![],
    };
    CalibrationSummary {
        ece: f64::NAN,
        mce: f64::NAN,
        n_detections: 0,
        effective_n_bins: 0,
        reliability,
        per_class: None,
    }
}
/// Public entry point: builds a reliability diagram plus ECE/MCE from the
/// flattened per-image evaluation grid.
///
/// Returns an all-NaN empty summary (not an error) when no detections
/// survive filtering.
///
/// # Errors
/// `InvalidConfig` for `n_bins == 0`, an out-of-range `iou_index`, or an
/// unimplemented confidence kind; `DimensionMismatch` / `NonFinite` bubble
/// up from detection collection.
pub fn summarize_calibration(
    eval_imgs: &[Option<Box<PerImageEval>>],
    n_categories: usize,
    n_area_ranges: usize,
    params: &CalibrationParams,
    _parity_mode: ParityMode,
) -> Result<CalibrationSummary, EvalError> {
    if params.n_bins == 0 {
        return Err(EvalError::InvalidConfig {
            detail: "calibration n_bins must be > 0".to_string(),
        });
    }
    // Fail fast on a bad IoU index before any heavier work.
    for cell in eval_imgs.iter().flatten() {
        let t = cell.dt_matched.nrows();
        if params.iou_index >= t {
            return Err(EvalError::InvalidConfig {
                detail: format!(
                    "calibration iou_index {} out of range for T={t}",
                    params.iou_index
                ),
            });
        }
    }
    let detections = collect_detections(
        eval_imgs,
        n_categories,
        n_area_ranges,
        params.iou_index,
        params.min_score,
    )?;
    if detections.is_empty() {
        return Ok(empty_summary());
    }
    let mut sorted_scores: Vec<f64> = detections.iter().map(|d| d.score).collect();
    // Scores are guaranteed finite here, so the total-order fallback in the
    // comparator never actually triggers.
    sorted_scores.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
    let (edges, effective_n_bins) = build_edges(
        &sorted_scores,
        params.n_bins,
        params.binning,
        params.min_score,
    );
    let (reliability, ece, mce) =
        build_reliability(&detections, &edges, effective_n_bins, params.confidence)?;
    // Count only detections that actually landed in a bin.
    let n_detections = reliability.count.iter().copied().sum::<u64>();
    let per_class = if !params.per_class {
        None
    } else {
        Some(build_per_class(
            &detections,
            &edges,
            effective_n_bins,
            params.confidence,
        )?)
    };
    Ok(CalibrationSummary {
        ece,
        mce,
        n_detections,
        effective_n_bins,
        reliability,
        per_class,
    })
}
/// Groups detections by class id and computes ECE/MCE per class over the
/// shared global bin edges.
fn build_per_class(
    detections: &[Detection],
    edges: &[f64],
    effective_n_bins: usize,
    confidence: ConfidenceKind,
) -> Result<PerClassTable, EvalError> {
    use std::collections::BTreeMap;
    // BTreeMap keeps class ids sorted so the output columns are ordered.
    let mut grouped: BTreeMap<u32, Vec<Detection>> = BTreeMap::new();
    for &det in detections {
        grouped.entry(det.class).or_default().push(det);
    }
    let n_classes = grouped.len();
    let mut class_id: Vec<u32> = Vec::with_capacity(n_classes);
    let mut ece: Vec<f64> = Vec::with_capacity(n_classes);
    let mut mce: Vec<f64> = Vec::with_capacity(n_classes);
    let mut n: Vec<u64> = Vec::with_capacity(n_classes);
    for (cls, class_dets) in grouped {
        let (table, ece_val, mce_val) =
            build_reliability(&class_dets, edges, effective_n_bins, confidence)?;
        class_id.push(cls);
        ece.push(ece_val);
        mce.push(mce_val);
        n.push(table.count.iter().sum::<u64>());
    }
    Ok(PerClassTable {
        class_id,
        ece,
        mce,
        n,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    // NOTE: every `summarize_calibration` call site below originally read
    // `¶ms` — mojibake for `&params` (an `&para` HTML entity fused with
    // the identifier) — which does not compile. Restored to `&params`.

    /// Builds an evaluation grid with one cell per (scores, matched,
    /// ignored) triple: single image, single area range, T=1. An empty
    /// score list becomes a `None` cell.
    fn build_grid(
        per_class_data: &[(Vec<f64>, Vec<bool>, Vec<bool>)],
    ) -> Vec<Option<Box<PerImageEval>>> {
        let mut grid: Vec<Option<Box<PerImageEval>>> = Vec::new();
        for (scores, matched, ignored) in per_class_data {
            let d = scores.len();
            assert_eq!(matched.len(), d);
            assert_eq!(ignored.len(), d);
            if d == 0 {
                grid.push(None);
                continue;
            }
            let mut dt_matched: Array2<bool> = Array2::from_elem((1, d), false);
            let mut dt_ignore: Array2<bool> = Array2::from_elem((1, d), false);
            for (j, &m) in matched.iter().enumerate() {
                dt_matched[(0, j)] = m;
            }
            for (j, &ig) in ignored.iter().enumerate() {
                dt_ignore[(0, j)] = ig;
            }
            grid.push(Some(Box::new(PerImageEval {
                dt_scores: scores.clone(),
                dt_matched,
                dt_ignore,
                gt_ignore: vec![],
            })));
        }
        grid
    }

    /// Library defaults but with the score cutoff disabled so every
    /// detection participates.
    fn default_params() -> CalibrationParams {
        CalibrationParams {
            min_score: 0.0,
            ..CalibrationParams::default()
        }
    }

    // n_bins == 0 is rejected up front as a configuration error.
    #[test]
    fn n_bins_zero_returns_error() {
        let grid = build_grid(&[(vec![0.5], vec![true], vec![false])]);
        let params = CalibrationParams {
            n_bins: 0,
            ..default_params()
        };
        let err = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap_err();
        match err {
            EvalError::InvalidConfig { detail } => {
                assert!(detail.contains("n_bins"), "got: {detail}");
            }
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    // An iou_index beyond the cell's threshold axis fails fast.
    #[test]
    fn iou_index_out_of_range_returns_error() {
        let grid = build_grid(&[(vec![0.5], vec![true], vec![false])]);
        let params = CalibrationParams {
            iou_index: 5,
            ..default_params()
        };
        let err = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap_err();
        match err {
            EvalError::InvalidConfig { detail } => {
                assert!(detail.contains("iou_index"), "got: {detail}");
            }
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    // Clopper-Pearson is declared but unimplemented; requesting it errors.
    #[test]
    fn clopper_pearson_is_phase2_error() {
        let grid = build_grid(&[(vec![0.5, 0.6], vec![true, false], vec![false, false])]);
        let params = CalibrationParams {
            confidence: ConfidenceKind::ClopperPearson,
            n_bins: 2,
            ..default_params()
        };
        let err = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap_err();
        match err {
            EvalError::InvalidConfig { detail } => {
                assert!(detail.contains("Clopper-Pearson"), "got: {detail}");
            }
            other => panic!("expected InvalidConfig, got {other:?}"),
        }
    }

    // No detections at all yields the NaN/empty summary, not an error.
    #[test]
    fn empty_input_returns_empty_summary_not_error() {
        let grid: Vec<Option<Box<PerImageEval>>> = vec![None];
        let params = default_params();
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        assert_eq!(out.n_detections, 0);
        assert_eq!(out.effective_n_bins, 0);
        assert!(out.ece.is_nan());
        assert!(out.mce.is_nan());
        assert!(out.reliability.bin_id.is_empty());
        assert!(out.per_class.is_none());
    }

    // Detections below min_score (0.01, 0.04 < 0.05) are excluded.
    #[test]
    fn min_score_cutoff_excludes_low_score_detections() {
        let grid = build_grid(&[(
            vec![0.01, 0.04, 0.6, 0.9],
            vec![true, false, true, true],
            vec![false, false, false, false],
        )]);
        let params = CalibrationParams {
            min_score: 0.05,
            n_bins: 2,
            ..CalibrationParams::default()
        };
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        assert_eq!(out.n_detections, 2);
    }

    // Detections flagged as ignore at the chosen IoU row are dropped.
    #[test]
    fn ignore_region_detections_drop_from_histogram() {
        let grid = build_grid(&[(
            vec![0.4, 0.5, 0.6],
            vec![true, false, true],
            vec![false, true, false],
        )]);
        let params = CalibrationParams {
            n_bins: 2,
            ..default_params()
        };
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        assert_eq!(out.n_detections, 2);
    }

    // All-tied scores collapse quantile edges; effective bins shrink and
    // table columns track the effective count.
    #[test]
    fn bin_edge_degeneracy_merges_to_fewer_bins() {
        let grid = build_grid(&[(
            vec![0.5; 5],
            vec![true, true, false, true, false],
            vec![false; 5],
        )]);
        let params = CalibrationParams {
            n_bins: 10,
            binning: Binning::Quantile,
            ..default_params()
        };
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        assert!(out.effective_n_bins < 10);
        assert_eq!(out.reliability.bin_id.len(), out.effective_n_bins);
        assert_eq!(out.reliability.score_lo.len(), out.effective_n_bins);
    }

    // Equal-width binning over sparse scores leaves at least one empty bin,
    // whose statistics must all be NaN.
    #[test]
    fn zero_count_bin_emits_nan_under_equal_width() {
        let grid = build_grid(&[(
            vec![0.0, 0.1, 0.95],
            vec![true, false, true],
            vec![false, false, false],
        )]);
        let params = CalibrationParams {
            n_bins: 4,
            binning: Binning::EqualWidth,
            min_score: 0.0,
            ..CalibrationParams::default()
        };
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        let mut found_empty = false;
        for (b, &c) in out.reliability.count.iter().enumerate() {
            if c == 0 {
                found_empty = true;
                assert!(out.reliability.accuracy[b].is_nan());
                assert!(out.reliability.mean_score[b].is_nan());
                assert!(out.reliability.gap[b].is_nan());
                assert!(out.reliability.ci_lo[b].is_nan());
                assert!(out.reliability.ci_hi[b].is_nan());
            }
        }
        assert!(found_empty, "expected at least one empty bin");
    }

    // Single bin, all correct: ECE == MCE == |1.0 - mean score|.
    #[test]
    fn perfect_all_correct_high_score_ece_matches_gap() {
        let grid = build_grid(&[(
            vec![0.9, 0.95, 0.99],
            vec![true, true, true],
            vec![false, false, false],
        )]);
        let params = CalibrationParams {
            n_bins: 1,
            binning: Binning::Quantile,
            min_score: 0.0,
            ..CalibrationParams::default()
        };
        let out = summarize_calibration(&grid, 1, 1, &params, ParityMode::Strict).unwrap();
        assert_eq!(out.n_detections, 3);
        let expected_mean = (0.9 + 0.95 + 0.99) / 3.0;
        let expected_ece = (1.0_f64 - expected_mean).abs();
        assert!(
            (out.ece - expected_ece).abs() < 1e-12,
            "ece={} expected~{}",
            out.ece,
            expected_ece
        );
        assert!((out.mce - expected_ece).abs() < 1e-12);
    }

    // Per-class counts partition the total binned detections.
    #[test]
    fn per_class_breakdown_sums_to_total_detections() {
        let grid = build_grid(&[
            (
                vec![0.2, 0.4, 0.6],
                vec![false, false, true],
                vec![false; 3],
            ),
            (vec![0.8, 0.9], vec![true, true], vec![false; 2]),
        ]);
        let params = CalibrationParams {
            n_bins: 2,
            per_class: true,
            min_score: 0.0,
            ..CalibrationParams::default()
        };
        let out = summarize_calibration(&grid, 2, 1, &params, ParityMode::Strict).unwrap();
        let pc = out.per_class.expect("per_class table");
        assert_eq!(pc.class_id, vec![0, 1]);
        let pc_sum: u64 = pc.n.iter().sum();
        assert_eq!(pc_sum, out.n_detections);
        assert_eq!(pc.n.len(), 2);
        assert_eq!(pc.ece.len(), 2);
        assert_eq!(pc.mce.len(), 2);
    }

    // Identical inputs must produce bit-identical metrics (determinism /
    // IoU-type genericity).
    #[test]
    fn identical_cells_produce_identical_summaries_iou_type_genericity() {
        let grid_a = build_grid(&[(vec![0.3, 0.7], vec![false, true], vec![false, false])]);
        let grid_b = build_grid(&[(vec![0.3, 0.7], vec![false, true], vec![false, false])]);
        let params = CalibrationParams {
            n_bins: 2,
            min_score: 0.0,
            ..CalibrationParams::default()
        };
        let out_a = summarize_calibration(&grid_a, 1, 1, &params, ParityMode::Strict).unwrap();
        let out_b = summarize_calibration(&grid_b, 1, 1, &params, ParityMode::Strict).unwrap();
        assert_eq!(out_a.n_detections, out_b.n_detections);
        assert_eq!(out_a.effective_n_bins, out_b.effective_n_bins);
        assert_eq!(out_a.reliability.count, out_b.reliability.count);
        assert_eq!(out_a.ece.to_bits(), out_b.ece.to_bits());
        assert_eq!(out_a.mce.to_bits(), out_b.mce.to_bits());
    }

    // Spot-check the Wilson interval against precomputed values (8/10).
    #[test]
    fn wilson_ci_known_values() {
        let (lo, hi) = wilson_ci(8.0, 10);
        assert!((lo - 0.490_162_471_5).abs() < 1e-9, "lo={lo}");
        assert!((hi - 0.943_317_848_5).abs() < 1e-9, "hi={hi}");
    }
}