use std::collections::{HashMap, HashSet};
use rand::Rng;
use rand::SeedableRng;
use rand::rngs::SmallRng;
use rayon::prelude::*;
use serde::Serialize;
use super::COCOeval;
use super::accumulate::accumulate_impl;
use super::summarize::{MetricDef, build_metric_defs, per_cat_ap_static, summarize_impl};
/// Options controlling a [`compare`] run.
#[derive(Debug, Clone)]
pub struct CompareOpts {
    /// Number of bootstrap replicates; `0` disables CI computation entirely.
    pub n_bootstrap: usize,
    /// Base RNG seed; replicate `i` is seeded with `seed.wrapping_add(i)` so
    /// results are reproducible across runs.
    pub seed: u64,
    /// Confidence level for the bootstrap intervals; must lie in (0, 1),
    /// e.g. `0.95` for a 95% CI.
    pub confidence: f64,
}
impl Default for CompareOpts {
fn default() -> Self {
Self {
n_bootstrap: 0,
seed: 42,
confidence: 0.95,
}
}
}
/// Percentile-bootstrap summary of one metric's delta (B minus A)
/// distribution across resampled image sets.
#[derive(Debug, Clone, Serialize)]
pub struct BootstrapCI {
    /// Lower percentile bound of the bootstrap delta distribution.
    pub lower: f64,
    /// Upper percentile bound of the bootstrap delta distribution.
    pub upper: f64,
    /// The confidence level these bounds were computed at (copied from opts).
    pub confidence: f64,
    /// Fraction of bootstrap replicates with a strictly positive delta,
    /// i.e. how often model B beat model A under resampling.
    pub prob_positive: f64,
    /// Sample standard deviation of the bootstrap deltas (standard error).
    pub std_err: f64,
}
/// Per-category AP comparison between the two runs.
#[derive(Debug, Clone, Serialize)]
pub struct CategoryDelta {
    /// COCO category id.
    pub cat_id: u64,
    /// Human-readable category name (falls back to the id as a string when
    /// the category is missing from the ground-truth dataset).
    pub cat_name: String,
    /// AP of run A for this category; `-1.0` when undefined for run A.
    pub ap_a: f64,
    /// AP of run B for this category; `-1.0` when undefined for run B.
    pub ap_b: f64,
    /// AP delta (B minus A); an undefined side is treated as 0.0.
    pub delta: f64,
}
/// Full output of [`compare`]: headline metrics for both runs, their deltas,
/// optional bootstrap CIs, and per-category AP deltas.
#[derive(Debug, Clone, Serialize)]
pub struct ComparisonResult {
    /// Metric names in their canonical summary order.
    pub metric_keys: Vec<String>,
    /// Metric name -> value for run A, computed over the shared images only.
    pub metrics_a: HashMap<String, f64>,
    /// Metric name -> value for run B, computed over the shared images only.
    pub metrics_b: HashMap<String, f64>,
    /// Metric name -> delta (B minus A); 0.0 when either side is undefined.
    pub deltas: HashMap<String, f64>,
    /// Bootstrap CIs per metric; `None` when `opts.n_bootstrap == 0`.
    pub ci: Option<HashMap<String, BootstrapCI>>,
    /// Per-category AP deltas, sorted ascending by delta (worst regressions
    /// first). Categories undefined in both runs are omitted.
    pub per_category: Vec<CategoryDelta>,
    /// Number of bootstrap replicates that were requested.
    pub n_bootstrap: usize,
    /// Number of images shared by the two runs (the comparison population).
    pub num_images: usize,
}
/// Compare two evaluated `COCOeval` runs over the images they share.
///
/// Both runs must already have per-image results (`evaluate()` called) and be
/// compatible: same `eval_mode`, same `iou_type`, and the same category list.
/// Metrics for both runs are re-accumulated over only the shared image set so
/// the comparison is apples-to-apples. When `opts.n_bootstrap > 0`,
/// percentile-bootstrap confidence intervals are computed for every metric
/// delta.
///
/// Deltas are reported as `B - A` (positive means B improved on A). A metric
/// whose value is the `-1.0` "undefined" sentinel on either side yields a
/// delta of `0.0`.
///
/// # Errors
///
/// Returns an error when either run has not been evaluated, when the runs are
/// incompatible (mode / IoU type / category list), when they share no images,
/// or when `opts.confidence` is outside `(0, 1)`.
pub fn compare(
    eval_a: &COCOeval,
    eval_b: &COCOeval,
    opts: &CompareOpts,
) -> crate::error::Result<ComparisonResult> {
    // Both runs need per-image evaluation results before anything else.
    if eval_a.eval_imgs.is_empty() {
        return Err("evaluate() must be called on eval_a before compare()".into());
    }
    if eval_b.eval_imgs.is_empty() {
        return Err("evaluate() must be called on eval_b before compare()".into());
    }
    if eval_a.eval_mode != eval_b.eval_mode {
        return Err(format!(
            "eval_mode mismatch: {:?} vs {:?}",
            eval_a.eval_mode, eval_b.eval_mode
        )
        .into());
    }
    if eval_a.params.iou_type != eval_b.params.iou_type {
        return Err(format!(
            "iou_type mismatch: {:?} vs {:?}",
            eval_a.params.iou_type, eval_b.params.iou_type
        )
        .into());
    }
    // The per-category loop below indexes run B's per-category APs with run
    // A's category order, so the two category lists must be identical or the
    // deltas would be silently misaligned.
    if eval_a.params.cat_ids != eval_b.params.cat_ids {
        return Err(
            "cat_ids mismatch: eval_a and eval_b must be evaluated over the same categories"
                .into(),
        );
    }
    if opts.confidence <= 0.0 || opts.confidence >= 1.0 {
        return Err(format!("confidence must be in (0, 1), got {}", opts.confidence).into());
    }
    // Restrict both runs to the images they have in common.
    let imgs_a: HashSet<u64> = eval_a.params.img_ids.iter().copied().collect();
    let imgs_b: HashSet<u64> = eval_b.params.img_ids.iter().copied().collect();
    let shared_set: HashSet<u64> = imgs_a.intersection(&imgs_b).copied().collect();
    if shared_set.is_empty() {
        return Err("no shared images between eval_a and eval_b".into());
    }
    let num_images = shared_set.len();
    // Sorted copy so bootstrap resampling indexes images deterministically.
    let shared_sorted: Vec<u64> = {
        let mut v: Vec<u64> = shared_set.iter().copied().collect();
        v.sort_unstable();
        v
    };
    let metrics = build_metric_defs(&eval_a.params, eval_a.eval_mode);
    let metric_keys: Vec<&str> = metrics.iter().map(|m| m.name).collect();
    // Re-accumulate and summarize each run over only the shared images.
    let acc_a = accumulate_impl(&eval_a.eval_imgs, &eval_a.params, Some(&shared_set));
    let stats_a = summarize_impl(
        &acc_a,
        &eval_a.params,
        eval_a.eval_mode,
        &eval_a.freq_groups,
        &metrics,
    );
    let acc_b = accumulate_impl(&eval_b.eval_imgs, &eval_b.params, Some(&shared_set));
    let stats_b = summarize_impl(
        &acc_b,
        &eval_b.params,
        eval_b.eval_mode,
        &eval_b.freq_groups,
        &metrics,
    );
    let metrics_a = stats_to_map(&metric_keys, &stats_a);
    let metrics_b = stats_to_map(&metric_keys, &stats_b);
    // Headline deltas: B - A, with -1.0 sentinels yielding a 0.0 delta.
    let deltas: HashMap<String, f64> = metric_keys
        .iter()
        .zip(stats_a.iter().zip(stats_b.iter()))
        .map(|(&k, (&a, &b))| {
            let d = if a >= 0.0 && b >= 0.0 { b - a } else { 0.0 };
            (k.to_string(), d)
        })
        .collect();
    // Per-category AP deltas; categories undefined in both runs are dropped.
    let per_cat_a = per_cat_ap_static(&acc_a, &eval_a.params);
    let per_cat_b = per_cat_ap_static(&acc_b, &eval_b.params);
    let mut per_category: Vec<CategoryDelta> = eval_a
        .params
        .cat_ids
        .iter()
        .enumerate()
        .filter_map(|(i, &cat_id)| {
            let ap_a = per_cat_a.get(i).copied().unwrap_or(-1.0);
            let ap_b = per_cat_b.get(i).copied().unwrap_or(-1.0);
            if ap_a < 0.0 && ap_b < 0.0 {
                return None;
            }
            let cat_name = eval_a
                .coco_gt
                .get_cat(cat_id)
                .map_or_else(|| cat_id.to_string(), |c| c.name.clone());
            // When exactly one side is undefined, treat its AP as 0.0 so the
            // delta is still meaningful.
            let delta = match (ap_a >= 0.0, ap_b >= 0.0) {
                (true, true) => ap_b - ap_a,
                (false, true) => ap_b,
                (true, false) => -ap_a,
                (false, false) => unreachable!("both < 0 filtered above"),
            };
            Some(CategoryDelta {
                cat_id,
                cat_name,
                // Normalize any negative sentinel to exactly -1.0.
                ap_a: if ap_a >= 0.0 { ap_a } else { -1.0 },
                ap_b: if ap_b >= 0.0 { ap_b } else { -1.0 },
                delta,
            })
        })
        .collect();
    // Ascending by delta: the worst regressions come first.
    per_category.sort_by(|a, b| {
        a.delta
            .partial_cmp(&b.delta)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    let ci = if opts.n_bootstrap > 0 {
        Some(bootstrap_compare(
            eval_a,
            eval_b,
            &shared_sorted,
            opts,
            &metrics,
            &metric_keys,
        ))
    } else {
        None
    };
    Ok(ComparisonResult {
        metric_keys: metric_keys.iter().map(|&k| k.to_string()).collect(),
        metrics_a,
        metrics_b,
        deltas,
        ci,
        per_category,
        n_bootstrap: opts.n_bootstrap,
        num_images,
    })
}
/// Percentile-bootstrap confidence intervals for every metric delta.
///
/// Resamples the shared image set `opts.n_bootstrap` times (in parallel via
/// rayon), recomputes both runs' metrics on each resample, and summarizes the
/// resulting per-metric delta (B - A) distribution into a [`BootstrapCI`].
/// Caller guarantees `opts.n_bootstrap > 0` and `shared_img_ids` non-empty.
fn bootstrap_compare(
    eval_a: &COCOeval,
    eval_b: &COCOeval,
    shared_img_ids: &[u64],
    opts: &CompareOpts,
    metrics: &[MetricDef],
    metric_keys: &[&str],
) -> HashMap<String, BootstrapCI> {
    let n = shared_img_ids.len();
    // One delta vector (indexed like `metric_keys`) per bootstrap replicate.
    let all_deltas: Vec<Vec<f64>> = (0..opts.n_bootstrap)
        .into_par_iter()
        .map(|i| {
            // Seeding each replicate from (seed + replicate index) keeps the
            // resampling deterministic regardless of rayon's scheduling.
            let mut rng = SmallRng::seed_from_u64(opts.seed.wrapping_add(i as u64));
            // NOTE(review): n draws with replacement are collected into a
            // HashSet, so duplicate draws collapse — each replicate is a
            // unique-image subset rather than a multiplicity-weighted
            // bootstrap sample. Presumably forced by accumulate_impl's
            // set-based image filter; confirm this is intended.
            let sample: HashSet<u64> = (0..n)
                .map(|_| shared_img_ids[rng.random_range(0..n)])
                .collect();
            let acc_a = accumulate_impl(&eval_a.eval_imgs, &eval_a.params, Some(&sample));
            let stats_a = summarize_impl(
                &acc_a,
                &eval_a.params,
                eval_a.eval_mode,
                &eval_a.freq_groups,
                metrics,
            );
            let acc_b = accumulate_impl(&eval_b.eval_imgs, &eval_b.params, Some(&sample));
            let stats_b = summarize_impl(
                &acc_b,
                &eval_b.params,
                eval_b.eval_mode,
                &eval_b.freq_groups,
                metrics,
            );
            // Delta per metric: B - A, forced to 0.0 when either side is the
            // -1.0 "undefined" sentinel (matches compare()'s convention).
            stats_a
                .iter()
                .zip(stats_b.iter())
                .map(|(&a, &b)| if a >= 0.0 && b >= 0.0 { b - a } else { 0.0 })
                .collect()
        })
        .collect();
    let alpha = 1.0 - opts.confidence;
    let nb = opts.n_bootstrap;
    metric_keys
        .iter()
        .enumerate()
        .map(|(m, &name)| {
            // Gather this metric's delta across all replicates and sort for
            // the percentile lookup; NaNs (if any) compare as Equal.
            let mut samples: Vec<f64> = all_deltas.iter().map(|d| d[m]).collect();
            samples.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            // Percentile indices for the (alpha/2, 1 - alpha/2) bounds;
            // clamped below so hi_idx's ceil can never run past the end.
            let lo_idx = ((alpha / 2.0) * nb as f64).floor() as usize;
            let hi_idx = ((1.0 - alpha / 2.0) * nb as f64).ceil() as usize;
            let lower = samples[lo_idx.min(nb - 1)];
            let upper = samples[hi_idx.min(nb - 1)];
            // Fraction of replicates where B strictly beat A.
            let pos_count = samples.iter().filter(|&&x| x > 0.0).count();
            let prob_positive = pos_count as f64 / nb as f64;
            // Sample (n-1 denominator) variance of the bootstrap deltas;
            // its square root is reported as the standard error.
            let mean: f64 = samples.iter().sum::<f64>() / nb as f64;
            let variance = if nb > 1 {
                samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (nb - 1) as f64
            } else {
                0.0
            };
            (
                name.to_string(),
                BootstrapCI {
                    lower,
                    upper,
                    confidence: opts.confidence,
                    prob_positive,
                    std_err: variance.sqrt(),
                },
            )
        })
        .collect()
}
/// Pair metric names with their summary values in a name -> value map.
/// If the slices differ in length, the extra tail entries are dropped.
fn stats_to_map(metric_keys: &[&str], stats: &[f64]) -> HashMap<String, f64> {
    let mut map = HashMap::with_capacity(metric_keys.len().min(stats.len()));
    for (key, &value) in metric_keys.iter().zip(stats) {
        map.insert((*key).to_string(), value);
    }
    map
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;
    use crate::COCO;
    use crate::params::IouType;
    use std::path::PathBuf;

    // Location of the shared JSON fixtures (gt.json / dt.json).
    fn fixtures_dir() -> PathBuf {
        PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/fixtures")
    }

    // Builds a bbox COCOeval from the fixtures with evaluate() already run,
    // so compare() can be called on it directly.
    fn make_eval() -> COCOeval {
        let gt = COCO::new(&fixtures_dir().join("gt.json")).unwrap();
        let dt = gt.load_res(&fixtures_dir().join("dt.json")).unwrap();
        let mut ev = COCOeval::new(gt, dt, IouType::Bbox);
        ev.evaluate();
        ev
    }

    // Comparing a model against itself must produce all-zero deltas and,
    // with n_bootstrap == 0, no confidence intervals.
    #[test]
    fn test_compare_identical_models() {
        let ev_a = make_eval();
        let ev_b = make_eval();
        let result = compare(&ev_a, &ev_b, &CompareOpts::default()).unwrap();
        assert_eq!(result.num_images, ev_a.params.img_ids.len());
        for (key, delta) in &result.deltas {
            assert!(
                delta.abs() < 1e-10,
                "expected zero delta for {key}, got {delta}"
            );
        }
        assert!(result.ci.is_none());
    }

    // With bootstrap enabled, identical models should yield CIs tightly
    // centered near zero for every metric.
    #[test]
    fn test_compare_with_bootstrap() {
        let ev_a = make_eval();
        let ev_b = make_eval();
        let opts = CompareOpts {
            n_bootstrap: 50,
            seed: 42,
            confidence: 0.95,
        };
        let result = compare(&ev_a, &ev_b, &opts).unwrap();
        assert!(result.ci.is_some());
        let ci = result.ci.as_ref().unwrap();
        assert!(!ci.is_empty());
        for (key, boot_ci) in ci {
            assert!(
                boot_ci.lower.abs() < 0.1 && boot_ci.upper.abs() < 0.1,
                "expected tight CI for {key}, got [{}, {}]",
                boot_ci.lower,
                boot_ci.upper
            );
            assert_eq!(boot_ci.confidence, 0.95);
        }
    }

    // The same seed must produce bit-identical CI bounds on repeated runs
    // (bootstrap resampling is deterministic despite rayon parallelism).
    #[test]
    fn test_bootstrap_seed_reproducibility() {
        let ev_a = make_eval();
        let ev_b = make_eval();
        let opts = CompareOpts {
            n_bootstrap: 20,
            seed: 123,
            confidence: 0.95,
        };
        let r1 = compare(&ev_a, &ev_b, &opts).unwrap();
        let r2 = compare(&ev_a, &ev_b, &opts).unwrap();
        let ci1 = r1.ci.as_ref().unwrap();
        let ci2 = r2.ci.as_ref().unwrap();
        for key in ci1.keys() {
            assert_eq!(
                ci1[key].lower, ci2[key].lower,
                "CI lower mismatch for {key}"
            );
            assert_eq!(
                ci1[key].upper, ci2[key].upper,
                "CI upper mismatch for {key}"
            );
        }
    }

    // compare() must refuse to run when evaluate() was never called.
    #[test]
    fn test_compare_before_evaluate_errors() {
        let gt = COCO::new(&fixtures_dir().join("gt.json")).unwrap();
        let dt = gt.load_res(&fixtures_dir().join("dt.json")).unwrap();
        let ev_a = COCOeval::new(gt, dt, IouType::Bbox);
        let gt2 = COCO::new(&fixtures_dir().join("gt.json")).unwrap();
        let dt2 = gt2.load_res(&fixtures_dir().join("dt.json")).unwrap();
        let ev_b = COCOeval::new(gt2, dt2, IouType::Bbox);
        let err = compare(&ev_a, &ev_b, &CompareOpts::default()).unwrap_err();
        assert!(err.to_string().contains("evaluate()"));
    }

    // Identical models: every per-category delta is zero and at least one
    // side has a defined (non-sentinel) AP for each reported category.
    #[test]
    fn test_compare_per_category() {
        let ev_a = make_eval();
        let ev_b = make_eval();
        let result = compare(&ev_a, &ev_b, &CompareOpts::default()).unwrap();
        for cat in &result.per_category {
            assert!(
                cat.delta.abs() < 1e-10,
                "expected zero delta for {}, got {}",
                cat.cat_name,
                cat.delta
            );
            assert!(cat.ap_a >= 0.0 || cat.ap_b >= 0.0);
        }
    }

    // Confidence outside (0, 1) is rejected with a descriptive error.
    #[test]
    fn test_compare_invalid_confidence() {
        let ev_a = make_eval();
        let ev_b = make_eval();
        let opts = CompareOpts {
            confidence: 1.5,
            ..Default::default()
        };
        let err = compare(&ev_a, &ev_b, &opts).unwrap_err();
        assert!(err.to_string().contains("confidence"));
    }
}