use ndarray::{Array2, Array4, Array5, Axis};
use crate::error::EvalError;
use crate::parity::{argsort_score_desc, ParityMode, PARITY_EPS};
#[derive(Debug, Clone)]
pub struct PerImageEval {
pub dt_scores: Vec<f64>,
pub dt_matched: Array2<bool>,
pub dt_ignore: Array2<bool>,
pub gt_ignore: Vec<bool>,
}
#[derive(Debug, Clone, Copy)]
pub struct AccumulateParams<'p> {
pub iou_thresholds: &'p [f64],
pub recall_thresholds: &'p [f64],
pub max_dets: &'p [usize],
pub n_categories: usize,
pub n_area_ranges: usize,
pub n_images: usize,
}
pub fn sort_max_dets(max_dets: &mut [usize]) {
max_dets.sort();
}
#[derive(Debug, Clone)]
pub struct Accumulated {
pub precision: Array5<f64>,
pub recall: Array4<f64>,
pub scores: Array5<f64>,
}
pub fn accumulate(
eval_imgs: &[Option<Box<PerImageEval>>],
p: AccumulateParams<'_>,
_parity_mode: ParityMode,
) -> Result<Accumulated, EvalError> {
let n_t = p.iou_thresholds.len();
let n_r = p.recall_thresholds.len();
let n_k = p.n_categories;
let n_a = p.n_area_ranges;
let n_m = p.max_dets.len();
let n_i = p.n_images;
let expected = n_k * n_a * n_i;
if eval_imgs.len() != expected {
return Err(EvalError::DimensionMismatch {
detail: format!(
"eval_imgs len {} != n_categories({}) * n_area_ranges({}) * n_images({}) = {}",
eval_imgs.len(),
n_k,
n_a,
n_i,
expected
),
});
}
for cell in eval_imgs.iter().flatten() {
if cell.dt_matched.shape() != cell.dt_ignore.shape() {
return Err(EvalError::DimensionMismatch {
detail: format!(
"PerImageEval.dt_matched {:?} != dt_ignore {:?}",
cell.dt_matched.shape(),
cell.dt_ignore.shape()
),
});
}
if cell.dt_matched.nrows() != n_t {
return Err(EvalError::DimensionMismatch {
detail: format!(
"PerImageEval row count {} != iou_thresholds len {}",
cell.dt_matched.nrows(),
n_t
),
});
}
if cell.dt_matched.ncols() != cell.dt_scores.len() {
return Err(EvalError::DimensionMismatch {
detail: format!(
"PerImageEval.dt_matched cols {} != dt_scores len {}",
cell.dt_matched.ncols(),
cell.dt_scores.len()
),
});
}
}
let mut precision = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
let mut recall = Array4::<f64>::from_elem((n_t, n_k, n_a, n_m), -1.0);
let mut scores = Array5::<f64>::from_elem((n_t, n_r, n_k, n_a, n_m), -1.0);
for k in 0..n_k {
let nk = k * n_a * n_i;
for a in 0..n_a {
let na = a * n_i;
let cells: Vec<&PerImageEval> = (0..n_i)
.filter_map(|i| eval_imgs[nk + na + i].as_deref())
.collect();
if cells.is_empty() {
continue;
}
let npig: usize = cells
.iter()
.map(|e| e.gt_ignore.iter().filter(|&&ig| !ig).count())
.sum();
if npig == 0 {
continue;
}
for (m, &max_det) in p.max_dets.iter().enumerate() {
accumulate_cell(
&cells,
max_det,
npig,
n_t,
p.recall_thresholds,
k,
a,
m,
&mut precision,
&mut recall,
&mut scores,
);
}
}
}
Ok(Accumulated {
precision,
recall,
scores,
})
}
#[allow(clippy::too_many_arguments)]
fn accumulate_cell(
cells: &[&PerImageEval],
max_det: usize,
npig: usize,
n_t: usize,
recall_thresholds: &[f64],
k: usize,
a: usize,
m: usize,
precision: &mut Array5<f64>,
recall: &mut Array4<f64>,
scores: &mut Array5<f64>,
) {
let mut takes: Vec<usize> = Vec::with_capacity(cells.len());
let mut total = 0usize;
for cell in cells {
let take = cell.dt_scores.len().min(max_det);
takes.push(take);
total += take;
}
let mut all_scores: Vec<f64> = Vec::with_capacity(total);
for (cell, &take) in cells.iter().zip(&takes) {
all_scores.extend_from_slice(&cell.dt_scores[..take]);
}
let n_d = all_scores.len();
if n_d == 0 {
for t in 0..n_t {
recall[(t, k, a, m)] = 0.0;
}
return;
}
let perm = argsort_score_desc(&all_scores);
let npig_f = npig as f64;
let mut rc = vec![0.0_f64; n_d];
let mut pr = vec![0.0_f64; n_d];
let mut dtm = vec![false; n_d];
let mut dtg = vec![false; n_d];
for t in 0..n_t {
let mut cursor = 0;
for (cell, &take) in cells.iter().zip(&takes) {
let m_row = cell.dt_matched.row(t);
let g_row = cell.dt_ignore.row(t);
for d in 0..take {
dtm[cursor] = m_row[d];
dtg[cursor] = g_row[d];
cursor += 1;
}
}
let mut tp = 0.0_f64;
let mut fp = 0.0_f64;
for (out_idx, &src_idx) in perm.iter().enumerate() {
if !dtg[src_idx] {
if dtm[src_idx] {
tp += 1.0;
} else {
fp += 1.0;
}
}
rc[out_idx] = tp / npig_f;
pr[out_idx] = tp / (tp + fp + PARITY_EPS);
}
recall[(t, k, a, m)] = rc[n_d - 1];
for j in (1..n_d).rev() {
if pr[j] > pr[j - 1] {
pr[j - 1] = pr[j];
}
}
let mut p_lane = precision
.index_axis_mut(Axis(0), t)
.index_axis_move(Axis(1), k)
.index_axis_move(Axis(1), a)
.index_axis_move(Axis(1), m);
let mut s_lane = scores
.index_axis_mut(Axis(0), t)
.index_axis_move(Axis(1), k)
.index_axis_move(Axis(1), a)
.index_axis_move(Axis(1), m);
for (ri, &target) in recall_thresholds.iter().enumerate() {
let pi = rc.partition_point(|&v| v < target);
if pi < n_d {
p_lane[ri] = pr[pi];
s_lane[ri] = all_scores[perm[pi]];
} else {
p_lane[ri] = 0.0;
s_lane[ri] = 0.0;
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::array;
fn one_threshold_eval(
scores: Vec<f64>,
matched: Vec<bool>,
ignore: Vec<bool>,
gt_ignore: Vec<bool>,
) -> PerImageEval {
let n = scores.len();
let dt_matched =
Array2::from_shape_vec((1, n), matched).expect("dt_matched shape mismatch");
let dt_ignore = Array2::from_shape_vec((1, n), ignore).expect("dt_ignore shape mismatch");
PerImageEval {
dt_scores: scores,
dt_matched,
dt_ignore,
gt_ignore,
}
}
fn params<'p>(
iou: &'p [f64],
rec: &'p [f64],
max_dets: &'p [usize],
n_images: usize,
) -> AccumulateParams<'p> {
AccumulateParams {
iou_thresholds: iou,
recall_thresholds: rec,
max_dets,
n_categories: 1,
n_area_ranges: 1,
n_images,
}
}
#[test]
fn empty_grid_returns_all_sentinel() {
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 0);
let out = accumulate(&[], p, ParityMode::Strict).unwrap();
assert!(out.precision.iter().all(|&v| v == -1.0));
assert!(out.recall.iter().all(|&v| v == -1.0));
}
#[test]
fn no_dt_with_real_gt_yields_zero_recall_and_sentinel_precision() {
let cell = PerImageEval {
dt_scores: vec![],
dt_matched: Array2::<bool>::default((2, 0)),
dt_ignore: Array2::<bool>::default((2, 0)),
gt_ignore: vec![false],
};
let p = params(&[0.5, 0.75], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
assert_eq!(out.recall[(1, 0, 0, 0)], 0.0);
for ri in 0..3 {
assert_eq!(out.precision[(0, ri, 0, 0, 0)], -1.0);
assert_eq!(out.precision[(1, ri, 0, 0, 0)], -1.0);
}
}
#[test]
fn cell_with_only_ignore_gts_skips_entirely() {
let cell = one_threshold_eval(vec![0.9], vec![true], vec![true], vec![true]);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
assert_eq!(out.recall[(0, 0, 0, 0)], -1.0);
assert_eq!(out.precision[(0, 0, 0, 0, 0)], -1.0);
}
#[test]
fn perfect_match_yields_ap_one_and_ar_one() {
let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
for ri in 0..3 {
let pr = out.precision[(0, ri, 0, 0, 0)];
assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
assert_eq!(out.scores[(0, ri, 0, 0, 0)], 0.9);
}
}
#[test]
fn lone_fp_yields_zero_recall_zero_precision() {
let cell = one_threshold_eval(vec![0.9], vec![false], vec![false], vec![false]);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
for ri in 0..3 {
assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
}
assert_eq!(out.scores[(0, 0, 0, 0, 0)], 0.9);
assert_eq!(out.scores[(0, 1, 0, 0, 0)], 0.0);
assert_eq!(out.scores[(0, 2, 0, 0, 0)], 0.0);
}
#[test]
fn ignored_dt_does_not_count_as_fp() {
let cell = one_threshold_eval(
vec![0.9, 0.8],
vec![true, false],
vec![false, true],
vec![false],
);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
for ri in 0..3 {
let pr = out.precision[(0, ri, 0, 0, 0)];
assert!((pr - 1.0).abs() < 1e-12, "precision[{ri}] = {pr}");
}
assert_eq!(out.recall[(0, 0, 0, 0)], 1.0);
}
#[test]
fn precision_envelope_runs_right_to_left() {
let cell = one_threshold_eval(
vec![0.9, 0.8, 0.7],
vec![true, false, true],
vec![false, false, false],
vec![false, false],
);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
assert!((out.precision[(0, 2, 0, 0, 0)] - 2.0 / 3.0).abs() < 1e-12);
}
#[test]
fn partition_point_matches_numpy_searchsorted_left() {
let haystack = [0.1, 0.3, 0.3, 0.7];
let lookup = |t: f64| haystack.partition_point(|&v| v < t);
assert_eq!(lookup(0.0), 0);
assert_eq!(lookup(0.3), 1); assert_eq!(lookup(0.5), 3);
assert_eq!(lookup(1.0), 4); }
#[test]
fn merged_sort_breaks_ties_by_input_order() {
let img0 = one_threshold_eval(vec![0.7], vec![true], vec![false], vec![false]);
let img1 = one_threshold_eval(vec![0.7], vec![false], vec![false], vec![false]);
let grid = vec![Some(Box::new(img0)), Some(Box::new(img1))];
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[100], 2);
let out = accumulate(&grid, p, ParityMode::Strict).unwrap();
assert!((out.precision[(0, 0, 0, 0, 0)] - 1.0).abs() < 1e-12);
assert!((out.precision[(0, 1, 0, 0, 0)] - 1.0).abs() < 1e-12);
assert_eq!(out.precision[(0, 2, 0, 0, 0)], 0.0);
}
#[test]
fn max_det_truncation_drops_low_score_dts_per_image() {
let cell = one_threshold_eval(
vec![0.95, 0.9],
vec![false, true], vec![false, false],
vec![false],
);
let p = params(&[0.5], &[0.0, 0.5, 1.0], &[1], 1);
let out = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap();
for ri in 0..3 {
assert!(out.precision[(0, ri, 0, 0, 0)].abs() < 1e-12);
}
assert_eq!(out.recall[(0, 0, 0, 0)], 0.0);
}
#[test]
fn dimension_mismatch_on_grid_size_is_typed_error() {
let p = params(&[0.5], &[0.0], &[100], 5);
let err = accumulate(&[None, None], p, ParityMode::Strict).unwrap_err();
match err {
EvalError::DimensionMismatch { detail } => {
assert!(detail.contains("eval_imgs"));
}
other => panic!("expected DimensionMismatch, got {other:?}"),
}
}
#[test]
fn dimension_mismatch_on_per_image_t_is_typed_error() {
let cell = PerImageEval {
dt_scores: vec![0.9],
dt_matched: array![[true], [true]],
dt_ignore: array![[false], [false]],
gt_ignore: vec![false],
};
let p = params(&[0.5, 0.75, 0.9], &[0.0], &[100], 1);
let err = accumulate(&[Some(Box::new(cell))], p, ParityMode::Strict).unwrap_err();
assert!(matches!(err, EvalError::DimensionMismatch { .. }));
}
#[test]
fn reaccumulate_with_different_area_range_count_is_typed_error() {
let n_i = 1;
let n_a_built = 4;
let n_k = 1;
let cell = one_threshold_eval(vec![0.9], vec![true], vec![false], vec![false]);
let mut eval_imgs: Vec<Option<Box<PerImageEval>>> = vec![None; n_k * n_a_built * n_i];
eval_imgs[0] = Some(Box::new(cell));
let mut bad = params(&[0.5], &[0.0, 0.5, 1.0], &[100], n_i);
bad.n_area_ranges = 3;
let err = accumulate(&eval_imgs, bad, ParityMode::Strict).unwrap_err();
match err {
EvalError::DimensionMismatch { detail } => {
assert!(detail.contains("eval_imgs"), "msg: {detail}");
assert!(detail.contains("n_area_ranges(3)"), "msg: {detail}");
}
other => panic!("expected DimensionMismatch, got {other:?}"),
}
}
#[test]
fn vectorized_inner_sweep_matches_naive_reference() {
let recall_thresholds: Vec<f64> = (0..=10).map(|i| (i as f64) / 10.0).collect();
fn naive_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
let n = pr.len();
let mut env = pr.to_vec();
for j in (1..n).rev() {
if env[j] > env[j - 1] {
env[j - 1] = env[j];
}
}
let mut q = vec![0.0_f64; rec_thr.len()];
for (ri, &target) in rec_thr.iter().enumerate() {
let mut pi = n;
for (j, &r) in rc.iter().enumerate() {
if r >= target {
pi = j;
break;
}
}
if pi < n {
q[ri] = env[pi];
}
}
q
}
fn vectorized_sweep(rc: &[f64], pr: &[f64], rec_thr: &[f64]) -> Vec<f64> {
let n = pr.len();
let mut env = pr.to_vec();
for j in (1..n).rev() {
if env[j] > env[j - 1] {
env[j - 1] = env[j];
}
}
let mut q = vec![0.0_f64; rec_thr.len()];
for (ri, &target) in rec_thr.iter().enumerate() {
let pi = rc.partition_point(|&v| v < target);
if pi < n {
q[ri] = env[pi];
}
}
q
}
let curves: &[(&[f64], &[f64])] = &[
(&[0.1, 0.3, 0.5, 0.7, 1.0], &[1.0, 0.9, 0.7, 0.5, 0.3]),
(&[0.2, 0.4, 0.6, 0.8, 1.0], &[1.0, 0.4, 0.6, 0.2, 0.5]),
(&[0.1, 0.2, 0.3, 0.4, 0.5], &[1.0, 1.0, 1.0, 1.0, 1.0]),
];
for (i, (rc, pr)) in curves.iter().enumerate() {
let q_naive = naive_sweep(rc, pr, &recall_thresholds);
let q_vec = vectorized_sweep(rc, pr, &recall_thresholds);
assert_eq!(q_naive.len(), q_vec.len(), "curve {i}");
for (ri, (a, b)) in q_naive.iter().zip(q_vec.iter()).enumerate() {
assert_eq!(
a.to_bits(),
b.to_bits(),
"curve {i}, recall threshold index {ri}: naive={a}, vec={b}"
);
}
}
}
#[test]
fn sort_max_dets_normalizes_ascending() {
let mut ladder = vec![100usize, 1, 10];
sort_max_dets(&mut ladder);
assert_eq!(ladder, vec![1, 10, 100]);
}
#[test]
fn sort_max_dets_is_idempotent_on_sorted_input() {
let mut ladder = vec![1usize, 10, 100];
sort_max_dets(&mut ladder);
assert_eq!(ladder, vec![1, 10, 100]);
}
#[test]
fn sort_max_dets_handles_duplicates_and_singletons() {
let mut singleton = vec![100usize];
sort_max_dets(&mut singleton);
assert_eq!(singleton, vec![100]);
let mut empty: Vec<usize> = Vec::new();
sort_max_dets(&mut empty);
assert!(empty.is_empty());
let mut dups = vec![10usize, 1, 10, 1, 100];
sort_max_dets(&mut dups);
assert_eq!(dups, vec![1, 1, 10, 10, 100]);
}
#[test]
fn permuted_ladder_after_sort_matches_canonical_order() {
let cell = one_threshold_eval(
vec![0.9, 0.8, 0.7],
vec![true, true, false],
vec![false, false, false],
vec![false, false, false],
);
let iou = [0.5];
let rec = [0.0, 0.5, 1.0];
let canonical = vec![1usize, 10, 100];
let canonical_acc = accumulate(
&[Some(Box::new(cell.clone()))],
params(&iou, &rec, &canonical, 1),
ParityMode::Strict,
)
.unwrap();
let mut permuted = vec![100usize, 1, 10];
sort_max_dets(&mut permuted);
assert_eq!(permuted, canonical);
let permuted_acc = accumulate(
&[Some(Box::new(cell))],
params(&iou, &rec, &permuted, 1),
ParityMode::Strict,
)
.unwrap();
assert_eq!(canonical_acc.precision, permuted_acc.precision);
assert_eq!(canonical_acc.recall, permuted_acc.recall);
assert_eq!(canonical_acc.scores, permuted_acc.scores);
}
}