use std::borrow::Cow;
use std::collections::HashMap;
use std::ops::Range;
use ndarray::Axis;
use crate::accumulate::Accumulated;
use crate::dataset::{CategoryId, Frequency};
use crate::error::EvalError;
/// Absolute tolerance used when matching a requested IoU threshold against
/// the entries of the caller-supplied threshold list (`|v - target| < tol`).
pub(crate) const IOU_LOOKUP_TOL: f64 = 1e-12;
/// An area-range bucket: its position on the area axis of the accumulated
/// arrays together with the label printed in summary lines.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AreaRng {
    /// Index of this bucket along the area axis.
    pub index: usize,
    /// Display label; borrowed for the built-in buckets, possibly owned for
    /// caller-defined ones.
    pub label: Cow<'static, str>,
}

impl AreaRng {
    /// Creates a bucket from anything convertible into a `Cow<'static, str>`
    /// (a string literal or an owned `String`).
    pub fn new(index: usize, label: impl Into<Cow<'static, str>>) -> Self {
        let label = label.into();
        Self { index, label }
    }

    /// `const` constructor for buckets labelled with a `'static` string.
    pub const fn from_static(index: usize, label: &'static str) -> Self {
        Self {
            label: Cow::Borrowed(label),
            index,
        }
    }

    /// Bucket 0, labelled `"all"`.
    pub const ALL: Self = Self {
        index: 0,
        label: Cow::Borrowed("all"),
    };
    /// Bucket 1, labelled `"small"`.
    pub const SMALL: Self = Self {
        index: 1,
        label: Cow::Borrowed("small"),
    };
    /// Bucket 2, labelled `"medium"`.
    pub const MEDIUM: Self = Self {
        index: 2,
        label: Cow::Borrowed("medium"),
    };
    /// Bucket 3, labelled `"large"`.
    pub const LARGE: Self = Self {
        index: 3,
        label: Cow::Borrowed("large"),
    };
}
/// Which statistic a summary line reports.
///
/// Within this file the variant decides which accumulated array is averaged
/// (`precision` for AP, `recall` for AR) and which title is printed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Metric {
    /// Rendered as "Average Precision (AP)".
    AveragePrecision,
    /// Rendered as "Average Recall (AR)".
    AverageRecall,
}
/// One computed row of a [`Summary`].
#[derive(Debug, Clone)]
pub struct StatLine {
    /// Statistic reported by this row.
    pub metric: Metric,
    /// The single IoU threshold the row was evaluated at, or `None` when the
    /// row averages over every threshold (printed as `0.50:0.95`).
    pub iou_threshold: Option<f64>,
    /// Area bucket the row was evaluated for.
    pub area: AreaRng,
    /// The concrete max-detections value the row's selector resolved to.
    pub max_dets: usize,
    /// Mean of the selected entries, or `-1.0` when none survived filtering.
    pub value: f64,
}
/// Ordered collection of computed statistic rows.
#[derive(Debug, Clone)]
pub struct Summary {
    /// Rows in the same order as the plan they were computed from.
    pub lines: Vec<StatLine>,
}

impl Summary {
    /// Returns just the numeric value of every row, in row order.
    pub fn stats(&self) -> Vec<f64> {
        let mut values = Vec::with_capacity(self.lines.len());
        for row in &self.lines {
            values.push(row.value);
        }
        values
    }

    /// Renders every row as one fixed-width text line: title, kind tag, IoU
    /// (a single threshold with two decimals, or `0.50:0.95` when the row has
    /// no single threshold), area label, max-detections and the value with
    /// three decimals.
    pub fn pretty_lines(&self) -> Vec<String> {
        let mut rendered = Vec::with_capacity(self.lines.len());
        for line in &self.lines {
            let (title, kind) = match line.metric {
                Metric::AverageRecall => ("Average Recall", "(AR)"),
                Metric::AveragePrecision => ("Average Precision", "(AP)"),
            };
            let iou = if let Some(t) = line.iou_threshold {
                format!("{t:0.2}")
            } else {
                "0.50:0.95".to_string()
            };
            rendered.push(format!(
                " {title:<18} {kind} @[ IoU={iou:<9} | area={:>6} | maxDets={:>3} ] = {:0.3}",
                line.area.label, line.max_dets, line.value
            ));
        }
        rendered
    }
}
/// How a [`StatRequest`] picks its entry from the caller's `max_dets` list.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MaxDetSelector {
    /// Selects the *last* entry of the `max_dets` list (by position, not by
    /// numeric comparison).
    Largest,
    /// Selects the first entry equal to the given value; the request is
    /// rejected if the list does not contain it.
    Value(usize),
}
/// Restricts which categories contribute to a statistic.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CategoryFilter {
    /// No restriction; every category contributes.
    All,
    /// Only categories whose frequency entry equals the given value.
    Frequency(Frequency),
    /// Only categories whose id appears in the list.
    ByIds(Vec<CategoryId>),
    /// A named group; the summarizer in this file rejects it and expects it
    /// to have been rewritten to `ByIds` beforehand.
    ByGrouping(Cow<'static, str>),
}

impl CategoryFilter {
    /// Reports whether the filter is one of the variants (`Frequency`,
    /// `ByIds`) that the summarizer can only resolve with category context.
    pub fn needs_lvis_context(&self) -> bool {
        match self {
            Self::Frequency(_) | Self::ByIds(_) => true,
            Self::All | Self::ByGrouping(_) => false,
        }
    }
}
/// Describes one statistic to compute; a slice of these forms a plan.
#[derive(Debug, Clone)]
pub struct StatRequest {
    /// Statistic to compute.
    pub metric: Metric,
    /// `Some(t)` to evaluate at the single threshold `t`, `None` to average
    /// over all thresholds.
    pub iou_threshold: Option<f64>,
    /// Area bucket to evaluate.
    pub area: AreaRng,
    /// How to pick the max-detections entry.
    pub max_dets: MaxDetSelector,
    /// Which categories contribute.
    pub category_filter: CategoryFilter,
}
impl StatRequest {
pub const fn new(
metric: Metric,
iou_threshold: Option<f64>,
area: AreaRng,
max_dets: MaxDetSelector,
) -> Self {
Self {
metric,
iou_threshold,
area,
max_dets,
category_filter: CategoryFilter::All,
}
}
pub const fn new_with_filter(
metric: Metric,
iou_threshold: Option<f64>,
area: AreaRng,
max_dets: MaxDetSelector,
category_filter: CategoryFilter,
) -> Self {
Self {
metric,
iou_threshold,
area,
max_dets,
category_filter,
}
}
pub const fn coco_detection_default() -> [Self; 12] {
use MaxDetSelector::{Largest, Value};
use Metric::{AveragePrecision, AverageRecall};
[
Self::new(AveragePrecision, None, AreaRng::ALL, Largest),
Self::new(AveragePrecision, Some(0.5), AreaRng::ALL, Largest),
Self::new(AveragePrecision, Some(0.75), AreaRng::ALL, Largest),
Self::new(AveragePrecision, None, AreaRng::SMALL, Largest),
Self::new(AveragePrecision, None, AreaRng::MEDIUM, Largest),
Self::new(AveragePrecision, None, AreaRng::LARGE, Largest),
Self::new(AverageRecall, None, AreaRng::ALL, Value(1)),
Self::new(AverageRecall, None, AreaRng::ALL, Value(10)),
Self::new(AverageRecall, None, AreaRng::ALL, Value(100)),
Self::new(AverageRecall, None, AreaRng::SMALL, Largest),
Self::new(AverageRecall, None, AreaRng::MEDIUM, Largest),
Self::new(AverageRecall, None, AreaRng::LARGE, Largest),
]
}
pub const fn lvis_default() -> [Self; 13] {
use CategoryFilter::{All as AllK, Frequency as FreqK};
use MaxDetSelector::Largest;
use Metric::{AveragePrecision, AverageRecall};
[
Self::new_with_filter(AveragePrecision, None, AreaRng::ALL, Largest, AllK),
Self::new_with_filter(AveragePrecision, Some(0.5), AreaRng::ALL, Largest, AllK),
Self::new_with_filter(AveragePrecision, Some(0.75), AreaRng::ALL, Largest, AllK),
Self::new_with_filter(AveragePrecision, None, AreaRng::SMALL, Largest, AllK),
Self::new_with_filter(AveragePrecision, None, AreaRng::MEDIUM, Largest, AllK),
Self::new_with_filter(AveragePrecision, None, AreaRng::LARGE, Largest, AllK),
Self::new_with_filter(
AveragePrecision,
None,
AreaRng::ALL,
Largest,
FreqK(Frequency::Rare),
),
Self::new_with_filter(
AveragePrecision,
None,
AreaRng::ALL,
Largest,
FreqK(Frequency::Common),
),
Self::new_with_filter(
AveragePrecision,
None,
AreaRng::ALL,
Largest,
FreqK(Frequency::Frequent),
),
Self::new_with_filter(AverageRecall, None, AreaRng::ALL, Largest, AllK),
Self::new_with_filter(AverageRecall, None, AreaRng::SMALL, Largest, AllK),
Self::new_with_filter(AverageRecall, None, AreaRng::MEDIUM, Largest, AllK),
Self::new_with_filter(AverageRecall, None, AreaRng::LARGE, Largest, AllK),
]
}
pub const fn coco_keypoints_default() -> [Self; 10] {
use MaxDetSelector::Largest;
use Metric::{AveragePrecision, AverageRecall};
const ALL: AreaRng = AreaRng::from_static(0, "all");
const MEDIUM: AreaRng = AreaRng::from_static(1, "medium");
const LARGE: AreaRng = AreaRng::from_static(2, "large");
[
Self::new(AveragePrecision, None, ALL, Largest),
Self::new(AveragePrecision, Some(0.5), ALL, Largest),
Self::new(AveragePrecision, Some(0.75), ALL, Largest),
Self::new(AveragePrecision, None, MEDIUM, Largest),
Self::new(AveragePrecision, None, LARGE, Largest),
Self::new(AverageRecall, None, ALL, Largest),
Self::new(AverageRecall, Some(0.5), ALL, Largest),
Self::new(AverageRecall, Some(0.75), ALL, Largest),
Self::new(AverageRecall, None, MEDIUM, Largest),
Self::new(AverageRecall, None, LARGE, Largest),
]
}
}
/// Summarizes `accum` with the 12-row default detection plan.
///
/// # Errors
/// Propagates any error produced by [`summarize_with`] for that plan.
pub fn summarize_detection(
    accum: &Accumulated,
    iou_thresholds: &[f64],
    max_dets: &[usize],
) -> Result<Summary, EvalError> {
    let default_plan = StatRequest::coco_detection_default();
    summarize_with(accum, &default_plan, iou_thresholds, max_dets)
}
/// Summarizes `accum` according to a caller-supplied `plan`, without any
/// category context.
///
/// # Errors
/// Returns whatever `summarize_dispatch` returns. Because no category
/// context is passed, a plan entry using `CategoryFilter::Frequency` or
/// `CategoryFilter::ByIds` is rejected with `EvalError::InvalidConfig`.
pub fn summarize_with(
    accum: &Accumulated,
    plan: &[StatRequest],
    iou_thresholds: &[f64],
    max_dets: &[usize],
) -> Result<Summary, EvalError> {
    summarize_dispatch(accum, plan, iou_thresholds, max_dets, None)
}
/// Summarizes `accum` according to `plan`, supplying the category ids (and
/// optionally a per-category frequency map) needed by category filters.
///
/// `category_ids` must have one entry per position of the precision array's
/// category axis (axis 2).
///
/// # Errors
/// Returns `EvalError::InvalidConfig` when `category_ids.len()` differs from
/// that axis length; otherwise propagates errors from the shared dispatcher.
pub fn summarize_with_lvis(
    accum: &Accumulated,
    plan: &[StatRequest],
    iou_thresholds: &[f64],
    max_dets: &[usize],
    category_ids: &[CategoryId],
    category_frequency: Option<&HashMap<CategoryId, Frequency>>,
) -> Result<Summary, EvalError> {
    let n_k = accum.precision.shape()[2];
    if category_ids.len() == n_k {
        let ctx = LvisCtx {
            category_ids,
            category_frequency,
        };
        return summarize_dispatch(accum, plan, iou_thresholds, max_dets, Some(&ctx));
    }
    let detail = format!(
        "category_ids len {} != precision K-axis {n_k}",
        category_ids.len()
    );
    Err(EvalError::InvalidConfig { detail })
}
/// Borrowed category context handed to the dispatcher by
/// `summarize_with_lvis`.
struct LvisCtx<'a> {
    /// Category id for each position along the category axis.
    category_ids: &'a [CategoryId],
    /// Optional frequency per category id; when absent, frequency filters
    /// select no category.
    category_frequency: Option<&'a HashMap<CategoryId, Frequency>>,
}
/// Shared implementation behind the public `summarize_*` entry points.
///
/// Validates the array shapes against `iou_thresholds` / `max_dets`, resolves
/// every plan entry to concrete indices (failing before any statistic is
/// computed), then evaluates each entry with `mean_slice`.
///
/// The precision array is read as `(T, R, K, A, M)` and the recall array as
/// `(T, K, A, M)`, following the axis names used in the error messages below.
///
/// # Errors
/// * `EvalError::DimensionMismatch` — the T or M axis of `precision` differs
///   from the length of `iou_thresholds` / `max_dets`, or `recall` disagrees
///   with `precision` on those axes.
/// * `EvalError::InvalidConfig` — a plan entry names an out-of-range area
///   index, a `max_dets` value or IoU threshold that is not present, asks for
///   `MaxDetSelector::Largest` while `max_dets` is empty, or uses a category
///   filter that cannot be resolved.
fn summarize_dispatch(
    accum: &Accumulated,
    plan: &[StatRequest],
    iou_thresholds: &[f64],
    max_dets: &[usize],
    lvis: Option<&LvisCtx<'_>>,
) -> Result<Summary, EvalError> {
    let p_shape = accum.precision.shape();
    let r_shape = accum.recall.shape();
    let n_t = p_shape[0];
    let n_m = p_shape[4];
    if n_t != iou_thresholds.len() {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "precision T-axis {} != iou_thresholds len {}",
                n_t,
                iou_thresholds.len()
            ),
        });
    }
    if n_m != max_dets.len() {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "precision M-axis {} != max_dets len {}",
                n_m,
                max_dets.len()
            ),
        });
    }
    if r_shape[0] != n_t || r_shape[3] != n_m {
        return Err(EvalError::DimensionMismatch {
            detail: format!("recall {r_shape:?} disagrees with precision {p_shape:?}"),
        });
    }
    let n_a = p_shape[3];
    let n_k = p_shape[2];
    // Index of the last `max_dets` entry. `None` when the list is empty: a
    // plain `len() - 1` would overflow there (panic in debug builds, wrap to
    // `usize::MAX` in release), so the failure is reported lazily as a typed
    // error only if a request actually asks for `Largest`.
    let m_max = max_dets.len().checked_sub(1);

    // Pass 1: resolve every request to (max-dets index, threshold range,
    // optional category mask) so that configuration errors surface before any
    // averaging happens.
    let resolved: Vec<(usize, Range<usize>, Option<Vec<bool>>)> = plan
        .iter()
        .map(|req| {
            if req.area.index >= n_a {
                return Err(EvalError::InvalidConfig {
                    detail: format!(
                        "AreaRng index {} is out of range for A-axis (size {})",
                        req.area.index, n_a
                    ),
                });
            }
            let m_idx = match req.max_dets {
                MaxDetSelector::Largest => m_max.ok_or_else(|| EvalError::InvalidConfig {
                    detail: "max_dets is empty; MaxDetSelector::Largest has nothing to select"
                        .to_string(),
                })?,
                MaxDetSelector::Value(v) => {
                    max_dets.iter().position(|&d| d == v).ok_or_else(|| {
                        EvalError::InvalidConfig {
                            detail: format!("max_dets does not contain {v}"),
                        }
                    })?
                }
            };
            let t_range = match req.iou_threshold {
                // No single threshold requested: average over all of them.
                None => 0..n_t,
                Some(target) => {
                    let t = iou_thresholds
                        .iter()
                        .position(|&v| (v - target).abs() < IOU_LOOKUP_TOL)
                        .ok_or_else(|| EvalError::InvalidConfig {
                            detail: format!("iou_threshold {target} not in ladder"),
                        })?;
                    t..(t + 1)
                }
            };
            let k_mask = resolve_category_filter(&req.category_filter, n_k, lvis)?;
            Ok((m_idx, t_range, k_mask))
        })
        .collect::<Result<Vec<_>, EvalError>>()?;

    // Pass 2: compute one line per request, in plan order.
    let lines = plan
        .iter()
        .zip(resolved)
        .map(|(req, (m_idx, t_range, k_mask))| {
            let value = mean_slice(
                accum,
                req.metric,
                t_range,
                req.area.index,
                m_idx,
                k_mask.as_deref(),
            );
            StatLine {
                metric: req.metric,
                iou_threshold: req.iou_threshold,
                area: req.area.clone(),
                max_dets: max_dets[m_idx],
                value,
            }
        })
        .collect();
    Ok(Summary { lines })
}
/// Turns a `CategoryFilter` into an optional per-category boolean mask.
///
/// Returns `Ok(None)` for `All` (no masking). For `Frequency` and `ByIds`
/// the mask has one entry per element of the context's `category_ids`;
/// a `Frequency` filter without a frequency map yields an all-`false` mask of
/// length `n_k`.
///
/// # Errors
/// `EvalError::InvalidConfig` when `Frequency` / `ByIds` is used without a
/// category context, and always for `ByGrouping`.
fn resolve_category_filter(
    filter: &CategoryFilter,
    n_k: usize,
    lvis: Option<&LvisCtx<'_>>,
) -> Result<Option<Vec<bool>>, EvalError> {
    // Fetches the category context or fails with the given message.
    let require_ctx = |detail: &str| {
        lvis.ok_or_else(|| EvalError::InvalidConfig {
            detail: detail.to_string(),
        })
    };

    let mask = match filter {
        CategoryFilter::All => return Ok(None),
        CategoryFilter::Frequency(target) => {
            let ctx = require_ctx("CategoryFilter::Frequency requires summarize_with_lvis")?;
            match ctx.category_frequency {
                // Without a frequency map no category can match.
                None => vec![false; n_k],
                Some(freq_map) => ctx
                    .category_ids
                    .iter()
                    .map(|cid| matches!(freq_map.get(cid), Some(f) if f == target))
                    .collect(),
            }
        }
        CategoryFilter::ByIds(ids) => {
            let ctx = require_ctx("CategoryFilter::ByIds requires summarize_with_lvis")?;
            let mut allow: std::collections::HashSet<&CategoryId> =
                std::collections::HashSet::with_capacity(ids.len());
            allow.extend(ids.iter());
            ctx.category_ids
                .iter()
                .map(|cid| allow.contains(cid))
                .collect()
        }
        CategoryFilter::ByGrouping(label) => {
            return Err(EvalError::InvalidConfig {
                detail: format!(
                    "CategoryFilter::ByGrouping({label:?}) must be resolved to ByIds at the \
                     evaluator boundary before reaching the kernel summarizer (ADR-0041 / 0042). \
                     Resolution maps the group label against the active ClassGroupBreakdown."
                ),
            });
        }
    };
    Ok(Some(mask))
}
/// Averages the selected entries of the precision (AP) or recall (AR) array.
///
/// For every threshold index in `t_range`, the slice at `area_idx` / `m_idx`
/// is visited (AP: every recall row, then every category; AR: every
/// category). Categories whose `k_mask` entry is `false` are skipped, and
/// only values strictly greater than `-1.0` are kept. Returns the pairwise
/// sum of the kept values divided by their count, or `-1.0` when nothing was
/// kept.
///
/// The visiting order is part of the contract: it fixes the order in which
/// values are summed and therefore the exact floating-point result.
fn mean_slice(
    accum: &Accumulated,
    metric: Metric,
    t_range: Range<usize>,
    area_idx: usize,
    m_idx: usize,
    k_mask: Option<&[bool]>,
) -> f64 {
    let n_k = accum.precision.shape()[2];
    let n_thresholds = t_range.len();
    // Upper bound on the number of kept values, to avoid regrowth.
    let upper_bound = match metric {
        Metric::AverageRecall => n_thresholds * n_k,
        Metric::AveragePrecision => n_thresholds * accum.precision.shape()[1] * n_k,
    };
    let mut kept: Vec<f64> = Vec::with_capacity(upper_bound);

    // A category participates when there is no mask or its mask bit is set.
    let selected = |k: usize| match k_mask {
        None => true,
        Some(mask) => mask[k],
    };

    match metric {
        Metric::AveragePrecision => {
            for t in t_range {
                let p_t = accum.precision.index_axis(Axis(0), t);
                let p_ta = p_t.index_axis(Axis(2), area_idx);
                let plane = p_ta.index_axis(Axis(2), m_idx);
                for r in 0..plane.shape()[0] {
                    kept.extend(
                        (0..n_k)
                            .filter(|&k| selected(k))
                            .map(|k| plane[(r, k)])
                            .filter(|&v| v > -1.0),
                    );
                }
            }
        }
        Metric::AverageRecall => {
            for t in t_range {
                let r_t = accum.recall.index_axis(Axis(0), t);
                let r_ta = r_t.index_axis(Axis(1), area_idx);
                let plane = r_ta.index_axis(Axis(1), m_idx);
                kept.extend(
                    (0..n_k)
                        .filter(|&k| selected(k))
                        .map(|k| plane[k])
                        .filter(|&v| v > -1.0),
                );
            }
        }
    }

    if kept.is_empty() {
        return -1.0;
    }
    pairwise_sum(&kept) / kept.len() as f64
}
/// Sums `values` with a blocked pairwise scheme.
///
/// * Fewer than 8 values: plain left-to-right accumulation starting at `0.0`.
/// * Up to 128 values: eight independent running lanes seeded with the first
///   eight values, fed eight values at a time, combined as a balanced tree,
///   then the leftover tail is added left to right.
/// * More than 128 values: split at `len / 2` rounded down to a multiple of
///   8 and recurse on both halves.
///
/// The exact order of additions is deliberate; a test in this file pins the
/// result bit-for-bit, so do not reorder the arithmetic.
pub(crate) fn pairwise_sum(values: &[f64]) -> f64 {
    const PW_BLOCKSIZE: usize = 128;
    let len = values.len();

    if len < 8 {
        // Explicit fold from +0.0 (not `Iterator::sum`) to keep the seed fixed.
        return values.iter().fold(0.0_f64, |acc, &v| acc + v);
    }

    if len > PW_BLOCKSIZE {
        let half = len / 2;
        let (head, tail) = values.split_at(half - half % 8);
        return pairwise_sum(head) + pairwise_sum(tail);
    }

    // 8 <= len <= PW_BLOCKSIZE: unrolled eight-lane accumulation.
    let mut lanes = [0.0_f64; 8];
    lanes.copy_from_slice(&values[..8]);
    let mut blocks = values[8..].chunks_exact(8);
    for block in blocks.by_ref() {
        for (lane, &v) in lanes.iter_mut().zip(block) {
            *lane += v;
        }
    }
    let mut total = ((lanes[0] + lanes[1]) + (lanes[2] + lanes[3]))
        + ((lanes[4] + lanes[5]) + (lanes[6] + lanes[7]));
    for &v in blocks.remainder() {
        total += v;
    }
    total
}
#[cfg(test)]
mod tests {
use super::*;
use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
use crate::parity::{iou_thresholds, recall_thresholds, ParityMode};
use ndarray::{Array2, Array4, Array5};
/// Builds a `PerImageEval` with a single detection (score 0.9) that is
/// flagged matched and not ignored in each of the `t` rows, and a single
/// non-ignored ground truth.
fn perfect_match_eval(t: usize) -> PerImageEval {
    PerImageEval {
        dt_scores: vec![0.9],
        dt_matched: Array2::from_elem((t, 1), true),
        dt_ignore: Array2::from_elem((t, 1), false),
        gt_ignore: vec![false],
    }
}
// One perfectly matched cell in grid slot 0 (the other three slots empty):
// rows 0, 1, 2, 6, 7, 8 of the default plan must come out as 1.0 and rows
// 3, 4, 5, 9, 10, 11 as the -1 sentinel.
#[test]
fn perfect_match_summarizes_to_ones() {
    let iou = iou_thresholds();
    let rec = recall_thresholds();
    let max_dets = [1usize, 10, 100];
    let cell = perfect_match_eval(iou.len());
    let mut grid: Vec<Option<Box<PerImageEval>>> = vec![None; 4];
    grid[0] = Some(Box::new(cell));
    let p = AccumulateParams {
        iou_thresholds: iou,
        recall_thresholds: rec,
        max_dets: &max_dets,
        n_categories: 1,
        n_area_ranges: 4,
        n_images: 1,
    };
    let accum = accumulate(&grid, p, ParityMode::Strict).unwrap();
    let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
    let stats = summary.stats();
    assert_eq!(stats.len(), 12);
    for &i in &[0usize, 1, 2, 6, 7, 8] {
        let v = stats[i];
        assert!((v - 1.0).abs() < 1e-9, "stat[{i}] = {v}");
    }
    for &i in &[3usize, 4, 5, 9, 10, 11] {
        assert_eq!(stats[i], -1.0, "stat[{i}] should be -1 sentinel");
    }
}
// Accumulating an empty grid (zero images) must summarize to -1.0 on every
// row of the default detection plan.
#[test]
fn empty_grid_yields_all_neg_one_stats() {
    let iou = iou_thresholds();
    let rec = recall_thresholds();
    let max_dets = [1usize, 10, 100];
    let p = AccumulateParams {
        iou_thresholds: iou,
        recall_thresholds: rec,
        max_dets: &max_dets,
        n_categories: 1,
        n_area_ranges: 4,
        n_images: 0,
    };
    let accum = accumulate(&[], p, ParityMode::Strict).unwrap();
    let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
    assert!(summary.stats().iter().all(|&v| v == -1.0));
}
// The default plan asks for max-dets value 1, which `[10, 100]` does not
// contain; the summarizer must return `InvalidConfig` rather than panic.
#[test]
fn missing_max_det_value_is_typed_error() {
    let iou = iou_thresholds();
    let max_dets = [10usize, 100];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 2), -1.0),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 2), -1.0),
    };
    let err = summarize_detection(&accum, iou, &max_dets).unwrap_err();
    assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Arrays with 10 entries on the threshold axis but only 5 thresholds passed
// in: the mismatch must be reported as `DimensionMismatch`.
#[test]
fn iou_threshold_dimension_mismatch_is_typed_error() {
    let max_dets = [100usize];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
        recall: Array4::<f64>::from_elem((10, 1, 4, 1), -1.0),
        scores: Array5::<f64>::from_elem((10, 101, 1, 4, 1), -1.0),
    };
    let err = summarize_detection(&accum, &[0.5, 0.6, 0.7, 0.8, 0.9], &max_dets).unwrap_err();
    assert!(matches!(err, EvalError::DimensionMismatch { .. }));
}
// A two-entry custom plan over constant arrays (precision 0.5, recall 0.7)
// must yield exactly two lines carrying those constants.
#[test]
fn summarize_with_custom_plan_evaluates_only_requested_lines() {
    let iou = iou_thresholds();
    let max_dets = [100usize];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 0.5),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 0.7),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
    };
    let plan = [
        StatRequest::new(
            Metric::AveragePrecision,
            Some(0.5),
            AreaRng::ALL,
            MaxDetSelector::Largest,
        ),
        StatRequest::new(
            Metric::AverageRecall,
            Some(0.75),
            AreaRng::ALL,
            MaxDetSelector::Largest,
        ),
    ];
    let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
    assert_eq!(summary.lines.len(), 2);
    assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
    assert_eq!(summary.lines[0].iou_threshold, Some(0.5));
    assert!((summary.lines[1].value - 0.7).abs() < 1e-12);
    assert_eq!(summary.lines[1].metric, Metric::AverageRecall);
}
// `summarize_detection` must give the same stats as calling `summarize_with`
// with the default detection plan explicitly.
#[test]
fn summarize_detection_matches_canonical_plan_via_summarize_with() {
    let iou = iou_thresholds();
    let max_dets = [1usize, 10, 100];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 0.5),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 0.7),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
    };
    let direct = summarize_detection(&accum, iou, &max_dets).unwrap();
    let via_plan = summarize_with(
        &accum,
        &StatRequest::coco_detection_default(),
        iou,
        &max_dets,
    )
    .unwrap();
    assert_eq!(direct.stats(), via_plan.stats());
}
// A caller-defined fifth area bucket (index 4, label "tiny") on arrays with
// five area entries must be accepted and its label must appear in the
// rendered line.
#[test]
fn custom_area_bucket_with_owned_label_renders_in_pretty_lines() {
    let iou = iou_thresholds();
    let max_dets = [100usize];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 5, 1), 1.0),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0),
    };
    let plan = [StatRequest::new(
        Metric::AveragePrecision,
        None,
        AreaRng::new(4, "tiny"),
        MaxDetSelector::Largest,
    )];
    let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
    let lines = summary.pretty_lines();
    assert_eq!(lines.len(), 1);
    assert!(lines[0].contains("tiny"), "unexpected line: {}", lines[0]);
}
// Area index 4 on arrays with only four area entries must be rejected with
// `InvalidConfig`.
#[test]
fn out_of_range_area_index_is_typed_error() {
    let iou = iou_thresholds();
    let max_dets = [100usize];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 1), 1.0),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 1), 1.0),
    };
    let plan = [StatRequest::new(
        Metric::AveragePrecision,
        None,
        AreaRng::new(4, "tiny"),
        MaxDetSelector::Largest,
    )];
    let err = summarize_with(&accum, &plan, iou, &max_dets).unwrap_err();
    assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Spot-checks the rendered text of the default plan: 12 lines, the AP title,
// tag and IoU range on line 0, and the right-aligned max-dets field.
#[test]
fn pretty_lines_match_pycocotools_shape() {
    let iou = iou_thresholds();
    let max_dets = [1usize, 10, 100];
    let accum = Accumulated {
        precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
        recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 1.0),
        scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
    };
    let summary = summarize_detection(&accum, iou, &max_dets).unwrap();
    let lines = summary.pretty_lines();
    assert_eq!(lines.len(), 12);
    assert!(lines[0].contains("Average Precision"));
    assert!(lines[0].contains("(AP)"));
    assert!(lines[0].contains("0.50:0.95"));
    assert!(lines[0].contains("maxDets=100"));
    assert!(lines[6].contains("Average Recall"));
    assert!(lines[6].contains("maxDets= 1"));
}
// Pins `pairwise_sum` over 1010 alternating 1.0 / 1e-12 values to a fixed
// bit pattern (the reference value is attributed to numpy by the assertion
// message), so any change in summation order is caught.
#[test]
fn pairwise_sum_matches_numpy_add_reduce_bitwise() {
    let v: Vec<f64> = (0..1010)
        .map(|i| if i % 2 == 0 { 1.0 } else { 1e-12 })
        .collect();
    let got = pairwise_sum(&v);
    let expected = f64::from_bits(0x407f_9000_0000_22b4);
    assert_eq!(
        got.to_bits(),
        expected.to_bits(),
        "pairwise_sum drifts from numpy: got {got:e}, expected {expected:e}",
    );
}
// Pins the keypoints plan row by row: metric, IoU threshold, area index and
// selector, and checks that no row refers to an area index above 2.
#[test]
fn coco_keypoints_default_plan_pins_canonical_order() {
    const AP: Metric = Metric::AveragePrecision;
    const AR: Metric = Metric::AverageRecall;
    const LAST: MaxDetSelector = MaxDetSelector::Largest;

    let plan = StatRequest::coco_keypoints_default();
    assert_eq!(plan.len(), 10);

    let expected: [(Metric, Option<f64>, usize, MaxDetSelector); 10] = [
        (AP, None, 0, LAST),
        (AP, Some(0.5), 0, LAST),
        (AP, Some(0.75), 0, LAST),
        (AP, None, 1, LAST),
        (AP, None, 2, LAST),
        (AR, None, 0, LAST),
        (AR, Some(0.5), 0, LAST),
        (AR, Some(0.75), 0, LAST),
        (AR, None, 1, LAST),
        (AR, None, 2, LAST),
    ];
    for (i, (row, (metric, iou, idx, sel))) in plan.iter().zip(expected).enumerate() {
        assert_eq!(row.metric, metric, "row {i} metric");
        assert_eq!(row.iou_threshold, iou, "row {i} iou_threshold");
        assert_eq!(row.area.index, idx, "row {i} area index");
        assert_eq!(row.max_dets, sel, "row {i} selector");
    }
    assert!(plan.iter().all(|r| r.area.index <= 2));
}
// Inputs shorter than eight values take the plain accumulation path; also
// covers the empty slice (sum 0.0) and a single element.
#[test]
fn pairwise_sum_handles_short_inputs_with_naive_fallback() {
    let v = [1.0_f64, 2.0, 3.0, 4.0];
    assert_eq!(pairwise_sum(&v), 10.0);
    assert_eq!(pairwise_sum(&[]), 0.0);
    assert_eq!(pairwise_sum(&[42.0]), 42.0);
}
/// Builds an `Accumulated` fixture with 10 thresholds, 101 recall rows, 4
/// area entries and 1 max-dets entry, in which every precision / recall cell
/// of category `k` holds `precision_per_k[k]` / `recall_per_k[k]`. The
/// `scores` array is all zeros.
///
/// Panics if either slice's length differs from `n_k`.
fn fake_accumulated(n_k: usize, precision_per_k: &[f64], recall_per_k: &[f64]) -> Accumulated {
    const N_T: usize = 10;
    const N_R: usize = 101;
    const N_A: usize = 4;
    const N_M: usize = 1;
    assert_eq!(precision_per_k.len(), n_k);
    assert_eq!(recall_per_k.len(), n_k);
    let mut precision = Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0);
    let mut recall = Array4::<f64>::from_elem((N_T, n_k, N_A, N_M), 0.0);
    for k in 0..n_k {
        for t in 0..N_T {
            // Precision has an extra recall-row axis that recall lacks.
            for r in 0..N_R {
                for a in 0..N_A {
                    for m in 0..N_M {
                        precision[(t, r, k, a, m)] = precision_per_k[k];
                    }
                }
            }
            for a in 0..N_A {
                for m in 0..N_M {
                    recall[(t, k, a, m)] = recall_per_k[k];
                }
            }
        }
    }
    Accumulated {
        precision,
        recall,
        scores: Array5::<f64>::from_elem((N_T, N_R, n_k, N_A, N_M), 0.0),
    }
}
// Pins the LVIS plan layout: rows 0-5 unfiltered AP, rows 6-8 AP filtered by
// rare / common / frequent on the "all" area, rows 9-12 unfiltered AR over
// all / small / medium / large with the `Largest` selector.
#[test]
fn lvis_default_has_13_entries_in_canonical_order() {
    let plan = StatRequest::lvis_default();
    assert_eq!(plan.len(), 13, "AF1: 9 AP + 4 AR");
    for (i, req) in plan.iter().take(6).enumerate() {
        assert_eq!(req.metric, Metric::AveragePrecision, "row {i}");
        assert_eq!(req.category_filter, CategoryFilter::All, "row {i}");
    }
    for (i, expected) in [Frequency::Rare, Frequency::Common, Frequency::Frequent]
        .iter()
        .enumerate()
    {
        let req = &plan[6 + i];
        assert_eq!(req.metric, Metric::AveragePrecision);
        assert_eq!(req.area.index, AreaRng::ALL.index);
        assert_eq!(
            req.category_filter,
            CategoryFilter::Frequency(*expected),
            "row {}: AP{}",
            6 + i,
            expected_letter(*expected),
        );
    }
    for (i, area_idx) in [
        AreaRng::ALL.index,
        AreaRng::SMALL.index,
        AreaRng::MEDIUM.index,
        AreaRng::LARGE.index,
    ]
    .iter()
    .enumerate()
    {
        let req = &plan[9 + i];
        assert_eq!(req.metric, Metric::AverageRecall);
        assert_eq!(req.area.index, *area_idx);
        assert_eq!(req.category_filter, CategoryFilter::All);
        assert_eq!(req.max_dets, MaxDetSelector::Largest);
    }
}
/// Maps a frequency to the one-letter suffix used in assertion messages
/// (`r`, `c`, `f`).
fn expected_letter(f: Frequency) -> char {
    match f {
        Frequency::Rare => 'r',
        Frequency::Common => 'c',
        Frequency::Frequent => 'f',
    }
}
// Running the LVIS plan through `summarize_with` (no category context) must
// fail with an `InvalidConfig` whose message points at `summarize_with_lvis`.
#[test]
fn summarize_with_rejects_frequency_filter_on_coco_path() {
    let accum = fake_accumulated(3, &[1.0, 1.0, 1.0], &[1.0, 1.0, 1.0]);
    let plan = StatRequest::lvis_default();
    let err = summarize_with(&accum, &plan, iou_thresholds(), &[300]).unwrap_err();
    match err {
        EvalError::InvalidConfig { detail } => {
            assert!(detail.contains("summarize_with_lvis"), "msg: {detail}");
        }
        other => panic!("expected InvalidConfig, got {other:?}"),
    }
}
// Three categories with distinct constant precision and distinct
// frequencies: each frequency-filtered AP row must equal the precision of
// its single matching category, and the unfiltered AP row their mean.
#[test]
fn summarize_with_lvis_routes_frequency_buckets_correctly() {
    let accum = fake_accumulated(3, &[0.6, 0.4, 0.2], &[0.6, 0.4, 0.2]);
    let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
    let mut freq_map = HashMap::new();
    freq_map.insert(CategoryId(1), Frequency::Frequent);
    freq_map.insert(CategoryId(2), Frequency::Common);
    freq_map.insert(CategoryId(3), Frequency::Rare);
    let plan = StatRequest::lvis_default();
    let summary = summarize_with_lvis(
        &accum,
        &plan,
        iou_thresholds(),
        &[300],
        &cat_ids,
        Some(&freq_map),
    )
    .unwrap();
    let apr = summary.lines[6].value;
    let apc = summary.lines[7].value;
    let apf = summary.lines[8].value;
    assert!((apr - 0.2).abs() < 1e-12, "APr expected 0.2, got {apr}");
    assert!((apc - 0.4).abs() < 1e-12, "APc expected 0.4, got {apc}");
    assert!((apf - 0.6).abs() < 1e-12, "APf expected 0.6, got {apf}");
    let ap = summary.lines[0].value;
    let expected = (0.6 + 0.4 + 0.2) / 3.0;
    assert!((ap - expected).abs() < 1e-12, "AP overall: {ap}");
}
// Category 1 holds only -1 sentinels: they must be dropped before averaging,
// so overall AP equals category 2's 0.5 and the rare bucket (category 1
// alone) collapses to -1.
#[test]
fn ab3_filters_minus_one_sentinels_before_mean() {
    let accum = fake_accumulated(2, &[-1.0, 0.5], &[-1.0, 0.5]);
    let cat_ids = [CategoryId(1), CategoryId(2)];
    let mut freq_map = HashMap::new();
    freq_map.insert(CategoryId(1), Frequency::Rare);
    freq_map.insert(CategoryId(2), Frequency::Frequent);
    let summary = summarize_with_lvis(
        &accum,
        &StatRequest::lvis_default(),
        iou_thresholds(),
        &[300],
        &cat_ids,
        Some(&freq_map),
    )
    .unwrap();
    assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
    assert_eq!(summary.lines[6].value, -1.0, "APr empty bucket → -1");
    assert!((summary.lines[8].value - 0.5).abs() < 1e-12);
}
// Every category is `Frequent`, so the rare and common buckets select
// nothing; those rows must be exactly -1.0 (neither NaN nor 0.0).
#[test]
fn af6_empty_frequency_bucket_returns_minus_one_not_zero_or_nan() {
    let accum = fake_accumulated(3, &[0.7, 0.8, 0.9], &[0.7, 0.8, 0.9]);
    let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
    let mut freq_map = HashMap::new();
    freq_map.insert(CategoryId(1), Frequency::Frequent);
    freq_map.insert(CategoryId(2), Frequency::Frequent);
    freq_map.insert(CategoryId(3), Frequency::Frequent);
    let summary = summarize_with_lvis(
        &accum,
        &StatRequest::lvis_default(),
        iou_thresholds(),
        &[300],
        &cat_ids,
        Some(&freq_map),
    )
    .unwrap();
    assert!((summary.lines[0].value - 0.8).abs() < 1e-12);
    assert_eq!(summary.lines[6].value, -1.0, "APr");
    assert_eq!(summary.lines[7].value, -1.0, "APc");
    assert!(!summary.lines[6].value.is_nan(), "AF6: never nan");
    assert!(summary.lines[6].value != 0.0, "AF6: never 0.0");
}
// With category ids but no frequency map, the unfiltered AP row is computed
// normally while all three frequency-filtered rows come out as -1.
#[test]
fn ab6_no_frequency_map_yields_minus_one_for_frequency_filtered_lines() {
    let accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
    let cat_ids = [CategoryId(1), CategoryId(2)];
    let summary = summarize_with_lvis(
        &accum,
        &StatRequest::lvis_default(),
        iou_thresholds(),
        &[300],
        &cat_ids,
        None,
    )
    .unwrap();
    assert!((summary.lines[0].value - 0.5).abs() < 1e-12, "AP overall");
    assert_eq!(summary.lines[6].value, -1.0, "APr without freq map");
    assert_eq!(summary.lines[7].value, -1.0, "APc without freq map");
    assert_eq!(summary.lines[8].value, -1.0, "APf without freq map");
}
// A `ByIds` filter naming only category 2 must average that category alone
// (precision 0.5), ignoring the 0.1 and 0.9 categories.
#[test]
fn category_filter_by_ids_subsets_correctly() {
    let accum = fake_accumulated(3, &[0.1, 0.5, 0.9], &[0.1, 0.5, 0.9]);
    let cat_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
    let plan = vec![StatRequest::new_with_filter(
        Metric::AveragePrecision,
        None,
        AreaRng::ALL,
        MaxDetSelector::Largest,
        CategoryFilter::ByIds(vec![CategoryId(2)]),
    )];
    let summary =
        summarize_with_lvis(&accum, &plan, iou_thresholds(), &[300], &cat_ids, None).unwrap();
    assert!((summary.lines[0].value - 0.5).abs() < 1e-12);
}
// Three category ids against arrays that hold only two categories must be
// rejected with `InvalidConfig`.
#[test]
fn category_axis_size_mismatch_is_typed_error() {
    let two_category_accum = fake_accumulated(2, &[0.5, 0.5], &[0.5, 0.5]);
    let three_ids = [CategoryId(1), CategoryId(2), CategoryId(3)];
    let plan = StatRequest::lvis_default();

    let outcome = summarize_with_lvis(
        &two_category_accum,
        &plan,
        iou_thresholds(),
        &[300],
        &three_ids,
        None,
    );

    let err = outcome.unwrap_err();
    assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
}