use std::borrow::Cow;
use std::ops::Range;
use ndarray::Axis;
use crate::accumulate::Accumulated;
use crate::error::EvalError;
/// Absolute tolerance used when matching a requested IoU threshold against the
/// entries of the caller-supplied threshold ladder (`|v - target| < IOU_LOOKUP_TOL`).
const IOU_LOOKUP_TOL: f64 = 1e-12;
/// An area-range bucket: its position on the area axis plus a display label.
///
/// The label is a [`Cow`] so the predefined buckets can be `const` (borrowed
/// `'static` strings) while custom buckets may carry an owned `String`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AreaRng {
    /// Position of this bucket along the area axis of the accumulated arrays.
    pub index: usize,
    /// Name shown in the `area=` column of the rendered summary lines.
    pub label: Cow<'static, str>,
}

impl AreaRng {
    /// Predefined bucket at index 0, labelled `"all"`.
    pub const ALL: Self = Self {
        index: 0,
        label: Cow::Borrowed("all"),
    };
    /// Predefined bucket at index 1, labelled `"small"`.
    pub const SMALL: Self = Self {
        index: 1,
        label: Cow::Borrowed("small"),
    };
    /// Predefined bucket at index 2, labelled `"medium"`.
    pub const MEDIUM: Self = Self {
        index: 2,
        label: Cow::Borrowed("medium"),
    };
    /// Predefined bucket at index 3, labelled `"large"`.
    pub const LARGE: Self = Self {
        index: 3,
        label: Cow::Borrowed("large"),
    };

    /// Builds a bucket from any label convertible into a `Cow<'static, str>`
    /// (a `&'static str` stays borrowed, a `String` becomes owned).
    pub fn new(index: usize, label: impl Into<Cow<'static, str>>) -> Self {
        let label: Cow<'static, str> = label.into();
        AreaRng { index, label }
    }

    /// `const` constructor for buckets whose label is a string literal.
    pub const fn from_static(index: usize, label: &'static str) -> Self {
        AreaRng {
            index,
            label: Cow::Borrowed(label),
        }
    }
}
/// Which accumulated array a summary line is averaged from.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
/// Averaged from `Accumulated::precision`; rendered as "Average Precision (AP)".
AveragePrecision,
/// Averaged from `Accumulated::recall`; rendered as "Average Recall (AR)".
AverageRecall,
}
/// One resolved row of a summary: the request parameters together with the
/// value computed for them.
#[derive(Debug, Clone)]
pub struct StatLine {
/// Metric this row reports.
pub metric: Metric,
/// The single IoU threshold the row was restricted to, or `None` when the row
/// averages over every threshold (rendered as `0.50:0.95`).
pub iou_threshold: Option<f64>,
/// Area bucket the row was restricted to.
pub area: AreaRng,
/// Concrete `max_dets` entry the request's selector resolved to.
pub max_dets: usize,
/// Mean of the selected cells, or `-1.0` when no cell held a usable value.
pub value: f64,
}
/// Ordered list of resolved stat lines, one per request in the plan.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Rows in the same order as the plan that produced them.
    pub lines: Vec<StatLine>,
}

impl Summary {
    /// Returns only the numeric value of each line, in line order.
    pub fn stats(&self) -> Vec<f64> {
        let mut values = Vec::with_capacity(self.lines.len());
        for line in &self.lines {
            values.push(line.value);
        }
        values
    }

    /// Renders every line as a fixed-width, human-readable string.
    ///
    /// A line without a specific IoU threshold shows the literal range
    /// `0.50:0.95`; otherwise the threshold is printed with two decimals.
    pub fn pretty_lines(&self) -> Vec<String> {
        let mut rendered = Vec::with_capacity(self.lines.len());
        for line in &self.lines {
            let title = match line.metric {
                Metric::AveragePrecision => "Average Precision",
                Metric::AverageRecall => "Average Recall",
            };
            let kind = match line.metric {
                Metric::AveragePrecision => "(AP)",
                Metric::AverageRecall => "(AR)",
            };
            let iou = if let Some(t) = line.iou_threshold {
                format!("{t:0.2}")
            } else {
                "0.50:0.95".to_string()
            };
            rendered.push(format!(
                " {title:<18} {kind} @[ IoU={iou:<9} | area={:>6} | maxDets={:>3} ] = {:0.3}",
                line.area.label, line.max_dets, line.value
            ));
        }
        rendered
    }
}
/// How a request picks its position on the max-detections axis.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxDetSelector {
/// Resolves to the LAST entry of the caller's `max_dets` slice (by position,
/// not by numeric comparison).
Largest,
/// Resolves to the position of the entry equal to the given value; it is a
/// configuration error if `max_dets` does not contain it.
Value(usize),
}
#[derive(Debug, Clone)]
pub struct StatRequest {
pub metric: Metric,
pub iou_threshold: Option<f64>,
pub area: AreaRng,
pub max_dets: MaxDetSelector,
}
impl StatRequest {
pub const fn new(
metric: Metric,
iou_threshold: Option<f64>,
area: AreaRng,
max_dets: MaxDetSelector,
) -> Self {
Self {
metric,
iou_threshold,
area,
max_dets,
}
}
pub const fn coco_detection_default() -> [Self; 12] {
use MaxDetSelector::{Largest, Value};
use Metric::{AveragePrecision, AverageRecall};
[
Self::new(AveragePrecision, None, AreaRng::ALL, Largest),
Self::new(AveragePrecision, Some(0.5), AreaRng::ALL, Largest),
Self::new(AveragePrecision, Some(0.75), AreaRng::ALL, Largest),
Self::new(AveragePrecision, None, AreaRng::SMALL, Largest),
Self::new(AveragePrecision, None, AreaRng::MEDIUM, Largest),
Self::new(AveragePrecision, None, AreaRng::LARGE, Largest),
Self::new(AverageRecall, None, AreaRng::ALL, Value(1)),
Self::new(AverageRecall, None, AreaRng::ALL, Value(10)),
Self::new(AverageRecall, None, AreaRng::ALL, Value(100)),
Self::new(AverageRecall, None, AreaRng::SMALL, Largest),
Self::new(AverageRecall, None, AreaRng::MEDIUM, Largest),
Self::new(AverageRecall, None, AreaRng::LARGE, Largest),
]
}
pub const fn coco_keypoints_default() -> [Self; 10] {
use MaxDetSelector::Largest;
use Metric::{AveragePrecision, AverageRecall};
const ALL: AreaRng = AreaRng::from_static(0, "all");
const MEDIUM: AreaRng = AreaRng::from_static(1, "medium");
const LARGE: AreaRng = AreaRng::from_static(2, "large");
[
Self::new(AveragePrecision, None, ALL, Largest),
Self::new(AveragePrecision, Some(0.5), ALL, Largest),
Self::new(AveragePrecision, Some(0.75), ALL, Largest),
Self::new(AveragePrecision, None, MEDIUM, Largest),
Self::new(AveragePrecision, None, LARGE, Largest),
Self::new(AverageRecall, None, ALL, Largest),
Self::new(AverageRecall, Some(0.5), ALL, Largest),
Self::new(AverageRecall, Some(0.75), ALL, Largest),
Self::new(AverageRecall, None, MEDIUM, Largest),
Self::new(AverageRecall, None, LARGE, Largest),
]
}
}
/// Summarizes `accum` with the 12-row detection plan.
///
/// This is a convenience wrapper: it builds
/// `StatRequest::coco_detection_default()` and forwards everything to
/// `summarize_with`, so it returns exactly the errors that function returns.
pub fn summarize_detection(
    accum: &Accumulated,
    iou_thresholds: &[f64],
    max_dets: &[usize],
) -> Result<Summary, EvalError> {
    let plan = StatRequest::coco_detection_default();
    summarize_with(accum, &plan, iou_thresholds, max_dets)
}
pub fn summarize_with(
accum: &Accumulated,
plan: &[StatRequest],
iou_thresholds: &[f64],
max_dets: &[usize],
) -> Result<Summary, EvalError> {
let p_shape = accum.precision.shape();
let r_shape = accum.recall.shape();
let n_t = p_shape[0];
let n_m = p_shape[4];
if n_t != iou_thresholds.len() {
return Err(EvalError::DimensionMismatch {
detail: format!(
"precision T-axis {} != iou_thresholds len {}",
n_t,
iou_thresholds.len()
),
});
}
if n_m != max_dets.len() {
return Err(EvalError::DimensionMismatch {
detail: format!(
"precision M-axis {} != max_dets len {}",
n_m,
max_dets.len()
),
});
}
if r_shape[0] != n_t || r_shape[3] != n_m {
return Err(EvalError::DimensionMismatch {
detail: format!("recall {r_shape:?} disagrees with precision {p_shape:?}"),
});
}
let n_a = p_shape[3];
let m_max = max_dets.len() - 1;
let resolved: Vec<(usize, Range<usize>)> = plan
.iter()
.map(|req| {
if req.area.index >= n_a {
return Err(EvalError::InvalidConfig {
detail: format!(
"AreaRng index {} is out of range for A-axis (size {})",
req.area.index, n_a
),
});
}
let m_idx = match req.max_dets {
MaxDetSelector::Largest => m_max,
MaxDetSelector::Value(v) => {
max_dets.iter().position(|&d| d == v).ok_or_else(|| {
EvalError::InvalidConfig {
detail: format!("max_dets does not contain {v}"),
}
})?
}
};
let t_range = match req.iou_threshold {
None => 0..n_t,
Some(target) => {
let t = iou_thresholds
.iter()
.position(|&v| (v - target).abs() < IOU_LOOKUP_TOL)
.ok_or_else(|| EvalError::InvalidConfig {
detail: format!("iou_threshold {target} not in ladder"),
})?;
t..(t + 1)
}
};
Ok((m_idx, t_range))
})
.collect::<Result<Vec<_>, EvalError>>()?;
let lines = plan
.iter()
.zip(resolved)
.map(|(req, (m_idx, t_range))| {
let value = mean_slice(accum, req.metric, t_range, req.area.index, m_idx);
StatLine {
metric: req.metric,
iou_threshold: req.iou_threshold,
area: req.area.clone(),
max_dets: max_dets[m_idx],
value,
}
})
.collect();
Ok(Summary { lines })
}
/// Averages the usable cells of one (threshold range, area, max-dets) slice.
///
/// For each `t` in `t_range` the relevant array is fixed at threshold `t`,
/// area `area_idx` and max-dets `m_idx`; all remaining cells are visited.
/// Only cells strictly greater than `-1.0` are kept. The result is the mean
/// of the kept cells (summed with `pairwise_sum`), or `-1.0` when none are
/// kept.
///
/// The indices are not validated here: an out-of-range index panics inside
/// `index_axis`.
fn mean_slice(
    accum: &Accumulated,
    metric: Metric,
    t_range: Range<usize>,
    area_idx: usize,
    m_idx: usize,
) -> f64 {
    // Upper bound on how many cells can survive the filter.
    let t_count = t_range.len();
    let upper_bound = match metric {
        Metric::AveragePrecision => {
            let dims = accum.precision.shape();
            t_count * dims[1] * dims[2]
        }
        Metric::AverageRecall => t_count * accum.recall.shape()[1],
    };
    let mut kept: Vec<f64> = Vec::with_capacity(upper_bound);

    for t in t_range {
        match metric {
            Metric::AveragePrecision => {
                // Dropping axis 0 shifts the area and max-dets axes down by
                // one, which is why both are addressed as `Axis(2)`.
                let at_threshold = accum.precision.index_axis(Axis(0), t);
                let at_area = at_threshold.index_axis(Axis(2), area_idx);
                let cells = at_area.index_axis(Axis(2), m_idx);
                kept.extend(cells.iter().copied().filter(|&v| v > -1.0));
            }
            Metric::AverageRecall => {
                // Same axis shift as above, one dimension lower.
                let at_threshold = accum.recall.index_axis(Axis(0), t);
                let at_area = at_threshold.index_axis(Axis(1), area_idx);
                let cells = at_area.index_axis(Axis(1), m_idx);
                kept.extend(cells.iter().copied().filter(|&v| v > -1.0));
            }
        }
    }

    match kept.len() {
        0 => -1.0,
        n => pairwise_sum(&kept) / n as f64,
    }
}
/// Sums `values` with blocked pairwise summation.
///
/// The order of additions is fixed and is part of the contract, because the
/// result is compared bit-for-bit:
///
/// * fewer than 8 values: plain left-to-right sum starting from `0.0`;
/// * 8 to 128 values: eight running lane sums seeded from the first eight
///   values, combined as a balanced tree, then the leftover tail is added
///   left to right;
/// * more than 128 values: split at `len / 2` rounded down to a multiple of
///   8 and recurse on both halves.
fn pairwise_sum(values: &[f64]) -> f64 {
    const PW_BLOCKSIZE: usize = 128;
    const LANES: usize = 8;

    let len = values.len();
    if len < LANES {
        // Explicit fold from +0.0 keeps the exact start value of the sum.
        return values.iter().fold(0.0_f64, |acc, &x| acc + x);
    }
    if len > PW_BLOCKSIZE {
        let half = len / 2;
        let (left, right) = values.split_at(half - half % LANES);
        return pairwise_sum(left) + pairwise_sum(right);
    }

    let mut blocks = values.chunks_exact(LANES);
    let mut lanes = [0.0_f64; LANES];
    // Seed by copy (not by adding to zero) so the first block is taken as-is.
    if let Some(seed) = blocks.next() {
        lanes.copy_from_slice(seed);
    }
    for block in blocks.by_ref() {
        for (lane, &x) in lanes.iter_mut().zip(block) {
            *lane += x;
        }
    }
    let combined = ((lanes[0] + lanes[1]) + (lanes[2] + lanes[3]))
        + ((lanes[4] + lanes[5]) + (lanes[6] + lanes[7]));
    // Fewer than eight values remain; add them sequentially.
    blocks.remainder().iter().fold(combined, |acc, &x| acc + x)
}
// Unit tests for the summary layer: plan resolution, typed errors, rendering,
// and the bit-exact behaviour of `pairwise_sum`.
#[cfg(test)]
mod tests {
use super::*;
use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
use crate::parity::{iou_thresholds, recall_thresholds, ParityMode};
use ndarray::{Array2, Array4, Array5};
// Builds a per-image cell with one detection (score 0.9) that is matched and
// not ignored at each of the `t` thresholds, against one non-ignored GT.
fn perfect_match_eval(t: usize) -> PerImageEval {
PerImageEval {
dt_scores: vec![0.9],
dt_matched: Array2::from_elem((t, 1), true),
dt_ignore: Array2::from_elem((t, 1), false),
gt_ignore: vec![false],
}
}
// Only grid cell 0 is populated, so the "all"-area rows (0, 1, 2, 6, 7, 8)
// must be ~1.0 while the small/medium/large rows stay at the -1 sentinel.
#[test]
fn perfect_match_summarizes_to_ones() {
let iou = iou_thresholds();
let rec = recall_thresholds();
let max_dets = [1usize, 10, 100];
let cell = perfect_match_eval(iou.len());
let mut grid: Vec<Option<PerImageEval>> = vec![None; 4];
grid[0] = Some(cell);
let p = AccumulateParams {
iou_thresholds: iou,
recall_thresholds: rec,
max_dets: &max_dets,
n_categories: 1,
n_area_ranges: 4,
n_images: 1,
};
let accum = accumulate(&grid, p, ParityMode::Strict).unwrap();
let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
let stats = summary.stats();
assert_eq!(stats.len(), 12);
for &i in &[0usize, 1, 2, 6, 7, 8] {
let v = stats[i];
assert!((v - 1.0).abs() < 1e-9, "stat[{i}] = {v}");
}
for &i in &[3usize, 4, 5, 9, 10, 11] {
assert_eq!(stats[i], -1.0, "stat[{i}] should be -1 sentinel");
}
}
// With an empty grid and zero images, every stat must be the -1 sentinel.
#[test]
fn empty_grid_yields_all_neg_one_stats() {
let iou = iou_thresholds();
let rec = recall_thresholds();
let max_dets = [1usize, 10, 100];
let p = AccumulateParams {
iou_thresholds: iou,
recall_thresholds: rec,
max_dets: &max_dets,
n_categories: 1,
n_area_ranges: 4,
n_images: 0,
};
let accum = accumulate(&[], p, ParityMode::Strict).unwrap();
let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
assert!(summary.stats().iter().all(|&v| v == -1.0));
}
// The detection plan asks for `Value(1)`, which `[10, 100]` does not contain.
#[test]
fn missing_max_det_value_is_typed_error() {
let iou = iou_thresholds();
let max_dets = [10usize, 100];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 2), -1.0),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
};
let err = summarize_detection(&accum, iou, &max_dets).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Arrays have 10 threshold slots but only 5 thresholds are passed in.
#[test]
fn iou_threshold_dimension_mismatch_is_typed_error() {
let max_dets = [100usize];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
recall: Array4::<f64>::from_elem((10, 1, 4, 1), -1.0),
scores: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
};
let err = summarize_detection(&accum, &[0.5, 0.6, 0.7, 0.8, 0.9], &max_dets).unwrap_err();
assert!(matches!(err, EvalError::DimensionMismatch { .. }));
}
// A two-row custom plan yields exactly two lines, with values taken from the
// constant-filled precision (0.5) and recall (0.7) arrays.
#[test]
fn summarize_with_custom_plan_evaluates_only_requested_lines() {
let iou = iou_thresholds();
let max_dets = [100usize];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 0.5),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 0.7),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
};
let plan = [
StatRequest::new(
Metric::AveragePrecision,
Some(0.5),
AreaRng::ALL,
MaxDetSelector::Largest,
),
StatRequest::new(
Metric::AverageRecall,
Some(0.75),
AreaRng::ALL,
MaxDetSelector::Largest,
),
];
let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
assert_eq!(summary.lines.len(), 2);
assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
assert_eq!(summary.lines[0].iou_threshold, Some(0.5));
assert!((summary.lines[1].value - 0.7).abs() < 1e-12);
assert_eq!(summary.lines[1].metric, Metric::AverageRecall);
}
// `summarize_detection` must be equivalent to calling `summarize_with` with
// the default detection plan.
#[test]
fn summarize_detection_matches_canonical_plan_via_summarize_with() {
let iou = iou_thresholds();
let max_dets = [1usize, 10, 100];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 0.5),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 0.7),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
};
let direct = summarize_detection(&accum, iou, &max_dets).unwrap();
let via_plan = summarize_with(
&accum,
&StatRequest::coco_detection_default(),
iou,
&max_dets,
)
.unwrap();
assert_eq!(direct.stats(), via_plan.stats());
}
// A fifth area bucket (index 4) with a custom label is accepted when the
// arrays have five area slots, and the label appears in the rendered line.
#[test]
fn custom_area_bucket_with_owned_label_renders_in_pretty_lines() {
let iou = iou_thresholds();
let max_dets = [100usize];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
recall: Array4::<f64>::from_elem((iou.len(), 1, 5, 1), 1.0),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
};
let plan = [StatRequest::new(
Metric::AveragePrecision,
None,
AreaRng::new(4, "tiny"),
MaxDetSelector::Largest,
)];
let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
let lines = summary.pretty_lines();
assert_eq!(lines.len(), 1);
assert!(lines[0].contains("tiny"), "unexpected line: {}", lines[0]);
}
// The same index-4 request against arrays with only four area slots must be
// rejected with a typed error rather than panicking.
#[test]
fn out_of_range_area_index_is_typed_error() {
let iou = iou_thresholds();
let max_dets = [100usize];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 1.0),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
};
let plan = [StatRequest::new(
Metric::AveragePrecision,
None,
AreaRng::new(4, "tiny"),
MaxDetSelector::Largest,
)];
let err = summarize_with(&accum, &plan, iou, &max_dets).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Spot-checks the rendered text: titles, the IoU range literal, and the
// right-aligned maxDets field.
#[test]
fn pretty_lines_match_pycocotools_shape() {
let iou = iou_thresholds();
let max_dets = [1usize, 10, 100];
let accum = Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 1.0),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
};
let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
let lines = summary.pretty_lines();
assert_eq!(lines.len(), 12);
assert!(lines[0].contains("Average Precision"));
assert!(lines[0].contains("(AP)"));
assert!(lines[0].contains("0.50:0.95"));
assert!(lines[0].contains("maxDets=100"));
assert!(lines[6].contains("Average Recall"));
assert!(lines[6].contains("maxDets= 1"));
}
// Pins the exact bit pattern of the sum for 1010 alternating 1.0 / 1e-12
// values. Per the test name and message, the expected constant is the value
// numpy produces for the same input (not re-derived here).
#[test]
fn pairwise_sum_matches_numpy_add_reduce_bitwise() {
let v: Vec<f64> = (0..1010)
.map(|i| if i % 2 == 0 { 1.0 } else { 1e-12 })
.collect();
let got = pairwise_sum(&v);
let expected = f64::from_bits(0x407f_9000_0000_22b4);
assert_eq!(
got.to_bits(),
expected.to_bits(),
"pairwise_sum drifts from numpy: got {got:e}, expected {expected:e}",
);
}
// Pins row order, thresholds, area indices and selectors of the keypoints
// plan, and checks that no row uses an area index above 2.
#[test]
fn coco_keypoints_default_plan_pins_canonical_order() {
let plan = StatRequest::coco_keypoints_default();
assert_eq!(plan.len(), 10);
let expected: [(Metric, Option<f64>, usize, MaxDetSelector); 10] = [
(Metric::AveragePrecision, None, 0, MaxDetSelector::Largest), (
Metric::AveragePrecision,
Some(0.5),
0,
MaxDetSelector::Largest,
), (
Metric::AveragePrecision,
Some(0.75),
0,
MaxDetSelector::Largest,
), (Metric::AveragePrecision, None, 1, MaxDetSelector::Largest), (Metric::AveragePrecision, None, 2, MaxDetSelector::Largest), (Metric::AverageRecall, None, 0, MaxDetSelector::Largest), (Metric::AverageRecall, Some(0.5), 0, MaxDetSelector::Largest), (
Metric::AverageRecall,
Some(0.75),
0,
MaxDetSelector::Largest,
), (Metric::AverageRecall, None, 1, MaxDetSelector::Largest), (Metric::AverageRecall, None, 2, MaxDetSelector::Largest), ];
for (i, (metric, iou, idx, sel)) in expected.into_iter().enumerate() {
assert_eq!(plan[i].metric, metric, "row {i} metric");
assert_eq!(plan[i].iou_threshold, iou, "row {i} iou_threshold");
assert_eq!(plan[i].area.index, idx, "row {i} area index");
assert_eq!(plan[i].max_dets, sel, "row {i} selector");
}
assert!(plan.iter().all(|r| r.area.index <= 2));
}
// Inputs shorter than eight elements take the plain sequential path,
// including the empty slice (sum 0.0).
#[test]
fn pairwise_sum_handles_short_inputs_with_naive_fallback() {
let v = [1.0_f64, 2.0, 3.0, 4.0];
assert_eq!(pairwise_sum(&v), 10.0);
assert_eq!(pairwise_sum(&[]), 0.0);
assert_eq!(pairwise_sum(&[42.0]), 42.0);
}
}