use std::borrow::Cow;
use std::collections::BTreeSet;
use crate::evaluate::{AreaRange, AREA_UNBOUNDED};
use crate::summarize::{AreaRng, MaxDetSelector, Metric, StatRequest};
/// One interval `[lo, hi]` along a breakdown axis, identified by `index` and
/// named by `label`.
///
/// Both bounds are inclusive (see [`Bucket::contains`]).
#[derive(Clone, Debug, PartialEq)]
pub struct Bucket {
    /// Position of this bucket within its owning breakdown.
    pub index: usize,
    /// Display name used in summaries (e.g. `"small"`).
    pub label: Cow<'static, str>,
    /// Inclusive lower bound.
    pub lo: f64,
    /// Inclusive upper bound.
    pub hi: f64,
}
impl Bucket {
    /// Builds a bucket from a `'static` label; usable in `const` contexts.
    pub const fn from_static(index: usize, label: &'static str, lo: f64, hi: f64) -> Self {
        Self { index, label: Cow::Borrowed(label), lo, hi }
    }

    /// Builds a bucket from any label convertible into `Cow<'static, str>`
    /// (for example a `&'static str` or an owned `String`).
    pub fn new(index: usize, label: impl Into<Cow<'static, str>>, lo: f64, hi: f64) -> Self {
        let label: Cow<'static, str> = label.into();
        Self { index, label, lo, hi }
    }

    /// Returns `true` when `key` lies in the closed interval `[lo, hi]`.
    ///
    /// Both endpoints are inclusive, so a value sitting on a shared boundary is
    /// contained by both neighbouring buckets. A NaN `key` is never contained.
    pub fn contains(&self, key: f64) -> bool {
        (self.lo..=self.hi).contains(&key)
    }

    /// Converts to the evaluator's [`AreaRange`], carrying index and bounds only.
    pub fn to_area_range(&self) -> AreaRange {
        let &Self { index, lo, hi, .. } = self;
        AreaRange { index, lo, hi }
    }

    /// Converts to the summarizer's [`AreaRng`], cloning the label.
    pub fn to_area_rng(&self) -> AreaRng {
        let label = self.label.clone();
        AreaRng::new(self.index, label)
    }
}
/// A named evaluation axis (e.g. `"area"`) together with the buckets defined
/// on it.
///
/// Fields are private; build one with [`Breakdown::new`] or a COCO preset so the
/// bucket-index invariants are checked in debug builds.
#[derive(Clone, Debug, PartialEq)]
pub struct Breakdown {
    /// Axis name, e.g. `"area"`.
    axis: Cow<'static, str>,
    /// Buckets in declaration order.
    buckets: Vec<Bucket>,
}
impl Breakdown {
    /// Upper bound of the COCO "small" bucket and lower bound of "medium".
    const SMALL_AREA_MAX: f64 = 32.0 * 32.0;
    /// Upper bound of the COCO "medium" bucket and lower bound of "large".
    const MEDIUM_AREA_MAX: f64 = 96.0 * 96.0;

    /// Creates a breakdown named `axis` over `buckets`.
    ///
    /// # Panics
    ///
    /// Only when debug assertions are enabled: panics if `buckets` is empty, if
    /// any bucket's `index` is not below `buckets.len()`, or if two buckets share
    /// an `index`. No invariant is enforced in release builds.
    pub fn new(axis: impl Into<Cow<'static, str>>, buckets: Vec<Bucket>) -> Self {
        let axis: Cow<'static, str> = axis.into();
        let count = buckets.len();
        debug_assert!(count != 0, "Breakdown must have >= 1 bucket");
        debug_assert!(
            !buckets.iter().any(|bucket| bucket.index >= count),
            "Breakdown bucket index out of range",
        );
        // With every index below `count` and no repeats, the indices form a
        // permutation of 0..count; mark each claimed slot to spot duplicates.
        let mut claimed = vec![false; count];
        for bucket in buckets.iter().filter(|bucket| bucket.index < count) {
            let slot = &mut claimed[bucket.index];
            debug_assert!(!*slot, "Breakdown has duplicate bucket index");
            *slot = true;
        }
        Self { axis, buckets }
    }

    /// COCO detection buckets on the `"area"` axis: `all` (0), `small` (1),
    /// `medium` (2), `large` (3), split at `32 * 32` and `96 * 96`.
    pub fn coco_area_det() -> Self {
        let buckets = vec![
            Bucket::from_static(0, "all", 0.0, AREA_UNBOUNDED),
            Bucket::from_static(1, "small", 0.0, Self::SMALL_AREA_MAX),
            Bucket::from_static(2, "medium", Self::SMALL_AREA_MAX, Self::MEDIUM_AREA_MAX),
            Bucket::from_static(3, "large", Self::MEDIUM_AREA_MAX, AREA_UNBOUNDED),
        ];
        Self::new("area", buckets)
    }

    /// COCO keypoint buckets on the `"area"` axis: the detection layout without
    /// `small`, re-indexed as `all` (0), `medium` (1), `large` (2).
    pub fn coco_area_keypoints() -> Self {
        let buckets = vec![
            Bucket::from_static(0, "all", 0.0, AREA_UNBOUNDED),
            Bucket::from_static(1, "medium", Self::SMALL_AREA_MAX, Self::MEDIUM_AREA_MAX),
            Bucket::from_static(2, "large", Self::MEDIUM_AREA_MAX, AREA_UNBOUNDED),
        ];
        Self::new("area", buckets)
    }

    /// Name of the axis this breakdown splits.
    pub fn axis(&self) -> &str {
        self.axis.as_ref()
    }

    /// All buckets, in declaration order.
    pub fn buckets(&self) -> &[Bucket] {
        self.buckets.as_slice()
    }

    /// Number of buckets.
    pub fn len(&self) -> usize {
        self.buckets().len()
    }

    /// `true` when there are no buckets (only reachable in release builds,
    /// where [`Breakdown::new`] does not assert non-emptiness).
    pub fn is_empty(&self) -> bool {
        self.buckets().is_empty()
    }

    /// Returns the bucket whose stored `index` equals `index`, if any.
    ///
    /// Matches on the bucket's `index` field, not on its position in the vector.
    pub fn bucket_at(&self, index: usize) -> Option<&Bucket> {
        self.buckets.iter().find(|candidate| candidate.index == index)
    }

    /// One [`AreaRange`] per bucket, in declaration order.
    pub fn area_ranges(&self) -> Vec<AreaRange> {
        self.buckets.iter().map(|bucket| bucket.to_area_range()).collect()
    }

    /// One [`AreaRng`] per bucket, in declaration order.
    pub fn summary_areas(&self) -> Vec<AreaRng> {
        self.buckets.iter().map(|bucket| bucket.to_area_rng()).collect()
    }

    /// Builds the 12-line COCO detection summary plan from this breakdown.
    ///
    /// Returns `None` unless there are exactly four buckets and indices 0..=3
    /// are all present; they are read as all / small / medium / large.
    pub fn detection_plan(&self) -> Option<[StatRequest; 12]> {
        use MaxDetSelector::{Largest, Value};
        use Metric::{AveragePrecision as Ap, AverageRecall as Ar};

        if self.len() != 4 {
            return None;
        }
        let area = |slot: usize| self.bucket_at(slot).map(Bucket::to_area_rng);
        let (all, small, medium, large) = (area(0)?, area(1)?, area(2)?, area(3)?);

        Some([
            // AP on `all` with IoU threshold `None`, `0.5`, `0.75`.
            StatRequest::new(Ap, None, all.clone(), Largest),
            StatRequest::new(Ap, Some(0.5), all.clone(), Largest),
            StatRequest::new(Ap, Some(0.75), all.clone(), Largest),
            // AP on each size bucket.
            StatRequest::new(Ap, None, small.clone(), Largest),
            StatRequest::new(Ap, None, medium.clone(), Largest),
            StatRequest::new(Ap, None, large.clone(), Largest),
            // AR on `all` with max-det selectors `Value(1)`, `Value(10)`, `Value(100)`.
            StatRequest::new(Ar, None, all.clone(), Value(1)),
            StatRequest::new(Ar, None, all.clone(), Value(10)),
            StatRequest::new(Ar, None, all, Value(100)),
            // AR on each size bucket.
            StatRequest::new(Ar, None, small, Largest),
            StatRequest::new(Ar, None, medium, Largest),
            StatRequest::new(Ar, None, large, Largest),
        ])
    }

    /// Builds the 10-line COCO keypoints summary plan from this breakdown.
    ///
    /// Returns `None` unless there are exactly three buckets and indices 0..=2
    /// are all present; they are read as all / medium / large.
    pub fn keypoints_plan(&self) -> Option<[StatRequest; 10]> {
        use MaxDetSelector::Largest;
        use Metric::{AveragePrecision as Ap, AverageRecall as Ar};

        if self.len() != 3 {
            return None;
        }
        let area = |slot: usize| self.bucket_at(slot).map(Bucket::to_area_rng);
        let (all, medium, large) = (area(0)?, area(1)?, area(2)?);

        Some([
            // AP on `all` with IoU threshold `None`, `0.5`, `0.75`, then per bucket.
            StatRequest::new(Ap, None, all.clone(), Largest),
            StatRequest::new(Ap, Some(0.5), all.clone(), Largest),
            StatRequest::new(Ap, Some(0.75), all.clone(), Largest),
            StatRequest::new(Ap, None, medium.clone(), Largest),
            StatRequest::new(Ap, None, large.clone(), Largest),
            // AR with the same IoU pattern, then per bucket.
            StatRequest::new(Ar, None, all.clone(), Largest),
            StatRequest::new(Ar, Some(0.5), all.clone(), Largest),
            StatRequest::new(Ar, Some(0.75), all, Largest),
            StatRequest::new(Ar, None, medium, Largest),
            StatRequest::new(Ar, None, large, Largest),
        ])
    }
}
/// A labelled set of class ids that is reported as one unit along a class
/// breakdown axis.
///
/// Member ids are kept sorted ascending with duplicates removed.
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ClassGroup {
    /// Position of this group within its breakdown.
    pub index: usize,
    /// Display name of the group.
    pub label: Cow<'static, str>,
    /// Member class ids; sorted and deduplicated by [`ClassGroup::new`].
    class_ids: Vec<u32>,
}

impl ClassGroup {
    /// Builds a group from any iterable of class ids, normalising them to a
    /// sorted list with duplicates removed.
    pub fn new(
        index: usize,
        label: impl Into<Cow<'static, str>>,
        class_ids: impl IntoIterator<Item = u32>,
    ) -> Self {
        let mut members: Vec<u32> = class_ids.into_iter().collect();
        members.sort_unstable();
        members.dedup();
        let label = label.into();
        Self {
            index,
            label,
            class_ids: members,
        }
    }

    /// Member class ids, ascending and without duplicates when built via
    /// [`ClassGroup::new`].
    pub fn class_ids(&self) -> &[u32] {
        self.class_ids.as_slice()
    }
}
/// A named axis whose entries are groups of class ids.
///
/// In debug builds, [`ClassGroupBreakdown::new`] checks that the groups are
/// disjoint, uniquely indexed and uniquely labelled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ClassGroupBreakdown {
    /// Axis name, e.g. `"vehicle_taxonomy"`.
    axis: Cow<'static, str>,
    /// Groups in declaration order.
    groups: Vec<ClassGroup>,
}

impl ClassGroupBreakdown {
    /// Creates a class-group breakdown named `axis` over `groups`.
    ///
    /// # Panics
    ///
    /// Only when debug assertions are enabled. In that case it panics if any of
    /// the following hold, checked in this order:
    /// - `groups` is empty;
    /// - any group's `index` is not below `groups.len()`;
    /// - two groups share an `index`;
    /// - two groups share a `label`;
    /// - a class id appears in more than one group.
    ///
    /// Release builds skip all of these checks.
    pub fn new(axis: impl Into<Cow<'static, str>>, groups: Vec<ClassGroup>) -> Self {
        let out = Self {
            axis: axis.into(),
            groups,
        };
        // All bookkeeping below exists only to feed the assertions, so it is
        // skipped entirely in release builds. Inside this gate `assert!` is
        // equivalent to `debug_assert!`. Set insertions are done outside the
        // assertion macro so no check depends on a side-effecting condition.
        if cfg!(debug_assertions) {
            assert!(
                !out.groups.is_empty(),
                "ClassGroupBreakdown must have >= 1 group",
            );
            let n = out.groups.len();
            assert!(
                out.groups.iter().all(|g| g.index < n),
                "ClassGroupBreakdown group index out of range",
            );
            let mut seen_idx = vec![false; n];
            let mut seen_labels: BTreeSet<&str> = BTreeSet::new();
            let mut seen_ids: BTreeSet<u32> = BTreeSet::new();
            for g in &out.groups {
                if let Some(slot) = seen_idx.get_mut(g.index) {
                    assert!(!*slot, "ClassGroupBreakdown duplicate group index");
                    *slot = true;
                }
                let label_is_new = seen_labels.insert(g.label.as_ref());
                assert!(label_is_new, "ClassGroupBreakdown duplicate group label");
                for &cid in g.class_ids() {
                    let id_is_new = seen_ids.insert(cid);
                    assert!(
                        id_is_new,
                        "ClassGroupBreakdown class id {cid} appears in multiple groups",
                    );
                }
            }
        }
        out
    }

    /// Name of the axis this breakdown splits.
    pub fn axis(&self) -> &str {
        &self.axis
    }

    /// All groups, in declaration order.
    pub fn groups(&self) -> &[ClassGroup] {
        &self.groups
    }

    /// Number of groups.
    pub fn len(&self) -> usize {
        self.groups.len()
    }

    /// `true` when there are no groups. This is only reachable in release
    /// builds, where `new` does not assert non-emptiness.
    pub fn is_empty(&self) -> bool {
        self.groups.is_empty()
    }

    /// Returns the group whose stored `index` equals `index`, if any.
    pub fn group_at(&self, index: usize) -> Option<&ClassGroup> {
        self.groups.iter().find(|g| g.index == index)
    }

    /// Returns the first group that contains `class_id`, if any.
    ///
    /// This uses a binary search within each group. It therefore relies on each
    /// group's ids being sorted, as [`ClassGroup::new`] guarantees.
    pub fn group_of(&self, class_id: u32) -> Option<&ClassGroup> {
        self.groups
            .iter()
            .find(|g| g.class_ids().binary_search(&class_id).is_ok())
    }
}
// Unit tests for the area and class-group breakdowns: parity with the legacy
// `AreaRange` presets and the static summary plans, plus debug-only invariant checks.
#[cfg(test)]
mod tests {
use super::*;
use crate::accumulate::{accumulate, AccumulateParams};
use crate::evaluate::AreaRange;
use crate::parity::{iou_thresholds, recall_thresholds, ParityMode};
use crate::summarize::{summarize_with, Metric};
use ndarray::{Array4, Array5};
// The detection preset must reproduce `AreaRange::coco_default()`: same length,
// same indices, and bit-identical `lo`/`hi` bounds (compared via `to_bits`).
#[test]
fn coco_area_det_matches_legacy_area_range_coco_default_bitwise() {
let bd = Breakdown::coco_area_det();
let legacy = AreaRange::coco_default();
let ranges = bd.area_ranges();
assert_eq!(ranges.len(), legacy.len());
for (got, want) in ranges.iter().zip(legacy.iter()) {
assert_eq!(got.index, want.index, "index drift on bucket {}", got.index);
assert_eq!(
got.lo.to_bits(),
want.lo.to_bits(),
"lo bound drift on bucket {}",
got.index,
);
assert_eq!(
got.hi.to_bits(),
want.hi.to_bits(),
"hi bound drift on bucket {}",
got.index,
);
}
}
// Same bitwise parity check for the keypoints preset against
// `AreaRange::keypoints_default()`.
#[test]
fn coco_area_keypoints_matches_legacy_keypoints_default_bitwise() {
let bd = Breakdown::coco_area_keypoints();
let legacy = AreaRange::keypoints_default();
let ranges = bd.area_ranges();
assert_eq!(ranges.len(), legacy.len());
for (got, want) in ranges.iter().zip(legacy.iter()) {
assert_eq!(got.index, want.index);
assert_eq!(got.lo.to_bits(), want.lo.to_bits());
assert_eq!(got.hi.to_bits(), want.hi.to_bits());
}
}
// Pins the detection preset's label strings (in order) and its axis name.
#[test]
fn coco_area_det_summary_labels_pin_canonical_strings() {
let bd = Breakdown::coco_area_det();
let labels: Vec<&str> = bd.buckets().iter().map(|b| b.label.as_ref()).collect();
assert_eq!(labels, ["all", "small", "medium", "large"]);
assert_eq!(bd.axis(), "area");
}
// The keypoints preset has three buckets, no "small" bucket, and contiguous
// indices 0..=2.
#[test]
fn coco_area_keypoints_drops_small_bucket_per_d5() {
let bd = Breakdown::coco_area_keypoints();
assert_eq!(bd.len(), 3);
let labels: Vec<&str> = bd.buckets().iter().map(|b| b.label.as_ref()).collect();
assert_eq!(labels, ["all", "medium", "large"]);
assert!(!labels.contains(&"small"));
let indices: Vec<usize> = bd.buckets().iter().map(|b| b.index).collect();
assert_eq!(indices, [0, 1, 2]);
}
// `contains` is inclusive at both ends, so the shared 32*32 boundary belongs to
// both "small" and "medium".
#[test]
fn bucket_contains_is_inclusive_on_both_ends() {
let small = Bucket::from_static(1, "small", 0.0, 32.0 * 32.0);
let medium = Bucket::from_static(2, "medium", 32.0 * 32.0, 96.0 * 96.0);
assert!(small.contains(1024.0));
assert!(medium.contains(1024.0));
assert!(!small.contains(-1.0));
assert!(!medium.contains(1023.999));
assert!(small.contains(0.0));
assert!(medium.contains(96.0 * 96.0));
}
// Summarizing a constant-filled accumulator with `detection_plan()` must give
// the same stats and per-line fields (value compared bitwise) as
// `StatRequest::coco_detection_default()`.
#[test]
fn detection_plan_matches_canonical_default_bitwise() {
let iou = iou_thresholds();
let max_dets = [1usize, 10, 100];
// Shape: (iou thresholds, 101, 1, 4 area buckets, 3 max-det settings).
let accum = crate::Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 0.5),
recall: Array4::<f64>::from_elem((iou.len(), 1, 4, 3), 0.7),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 4, 3), 1.0),
};
let static_plan = StatRequest::coco_detection_default();
let bd = Breakdown::coco_area_det();
let bd_plan = bd.detection_plan().expect("4-bucket layout");
let from_static = summarize_with(&accum, &static_plan, iou, &max_dets).unwrap();
let from_bd = summarize_with(&accum, &bd_plan, iou, &max_dets).unwrap();
assert_eq!(from_static.stats(), from_bd.stats());
for (s, b) in from_static.lines.iter().zip(from_bd.lines.iter()) {
assert_eq!(s.metric, b.metric);
assert_eq!(s.iou_threshold, b.iou_threshold);
assert_eq!(s.area.label, b.area.label);
assert_eq!(s.area.index, b.area.index);
assert_eq!(s.max_dets, b.max_dets);
assert_eq!(s.value.to_bits(), b.value.to_bits());
}
}
// Keypoints counterpart: `keypoints_plan()` versus
// `StatRequest::coco_keypoints_default()` over 3 area buckets and a single
// max-det setting of 20.
#[test]
fn keypoints_plan_matches_canonical_default_bitwise() {
let iou = iou_thresholds();
let max_dets = [20usize];
let accum = crate::Accumulated {
precision: Array5::<f64>::from_elem((iou.len(), 101, 1, 3, 1), 0.5),
recall: Array4::<f64>::from_elem((iou.len(), 1, 3, 1), 0.7),
scores: Array5::<f64>::from_elem((iou.len(), 101, 1, 3, 1), 1.0),
};
let static_plan = StatRequest::coco_keypoints_default();
let bd = Breakdown::coco_area_keypoints();
let bd_plan = bd.keypoints_plan().expect("3-bucket kp layout");
let from_static = summarize_with(&accum, &static_plan, iou, &max_dets).unwrap();
let from_bd = summarize_with(&accum, &bd_plan, iou, &max_dets).unwrap();
assert_eq!(from_static.stats(), from_bd.stats());
for (s, b) in from_static.lines.iter().zip(from_bd.lines.iter()) {
assert_eq!(s.area.index, b.area.index);
assert_eq!(s.area.label, b.area.label);
assert_eq!(s.value.to_bits(), b.value.to_bits());
}
}
// A 3-bucket breakdown cannot produce the 4-bucket detection plan.
#[test]
fn detection_plan_returns_none_for_non_canonical_size() {
let bd = Breakdown::coco_area_keypoints(); assert!(bd.detection_plan().is_none());
}
// A 4-bucket breakdown cannot produce the 3-bucket keypoints plan.
#[test]
fn keypoints_plan_returns_none_for_non_canonical_size() {
let bd = Breakdown::coco_area_det(); assert!(bd.keypoints_plan().is_none());
}
// A custom 5-bucket area breakdown works end to end. Each bucket `a` gets
// precision 0.1*(a+1) and recall 0.2*(a+1). One AP request per bucket must
// return that bucket's label and precision value, and `accumulate` with
// `n_area_ranges = 5` must size the precision area axis (dim 3) to 5.
#[test]
fn fine_grained_five_bucket_breakdown_extends_a_axis() {
let bd = Breakdown::new(
"area",
vec![
Bucket::from_static(0, "tiny", 0.0, 16.0 * 16.0),
Bucket::from_static(1, "small", 16.0 * 16.0, 32.0 * 32.0),
Bucket::from_static(2, "medium", 32.0 * 32.0, 96.0 * 96.0),
Bucket::from_static(3, "large", 96.0 * 96.0, 192.0 * 192.0),
Bucket::from_static(4, "huge", 192.0 * 192.0, AREA_UNBOUNDED),
],
);
assert_eq!(bd.len(), 5);
let iou = iou_thresholds();
let max_dets = [100usize];
let mut precision = Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), -1.0);
let mut recall = Array4::<f64>::from_elem((iou.len(), 1, 5, 1), -1.0);
let scores = Array5::<f64>::from_elem((iou.len(), 101, 1, 5, 1), 1.0);
// Fill every (iou, recall-point) cell of bucket `a` with a per-bucket value.
for a in 0..5 {
let pr_val = 0.1 * (a as f64 + 1.0);
let rc_val = 0.2 * (a as f64 + 1.0);
for t in 0..iou.len() {
for r in 0..101 {
precision[(t, r, 0, a, 0)] = pr_val;
}
recall[(t, 0, a, 0)] = rc_val;
}
}
let accum = crate::Accumulated {
precision,
recall,
scores,
};
let plan: Vec<StatRequest> = bd
.buckets()
.iter()
.map(|b| {
StatRequest::new(
Metric::AveragePrecision,
None,
b.to_area_rng(),
MaxDetSelector::Largest,
)
})
.collect();
let summary = summarize_with(&accum, &plan, iou, &max_dets).unwrap();
assert_eq!(summary.lines.len(), 5, "one line per bucket");
let expected_labels = ["tiny", "small", "medium", "large", "huge"];
for (a, line) in summary.lines.iter().enumerate() {
let expected = 0.1 * (a as f64 + 1.0);
assert_eq!(
line.area.label.as_ref(),
expected_labels[a],
"bucket {a} label drift",
);
assert!(
(line.value - expected).abs() < 1e-12,
"bucket {a} value: got {got}, expected {expected}",
got = line.value,
);
}
// The accumulator must size its area axis from the breakdown's range count.
let area_ranges = bd.area_ranges();
assert_eq!(area_ranges.len(), 5);
let acc_params = AccumulateParams {
iou_thresholds: iou,
recall_thresholds: recall_thresholds(),
max_dets: &max_dets,
n_categories: 1,
n_area_ranges: area_ranges.len(),
n_images: 0,
};
let empty = accumulate(&[], acc_params, ParityMode::Strict).unwrap();
assert_eq!(empty.precision.shape()[3], 5);
}
// Debug-only invariant: `Breakdown::new` rejects an empty bucket list.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "Breakdown must have >= 1 bucket")]
fn empty_breakdown_panics_in_debug() {
let _ = Breakdown::new("axis", vec![]);
}
// Debug-only invariant: every bucket index must be below the bucket count.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "Breakdown bucket index out of range")]
fn out_of_range_index_panics_in_debug() {
let _ = Breakdown::new("axis", vec![Bucket::from_static(5, "x", 0.0, 1.0)]);
}
// Debug-only invariant: bucket indices must be unique.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "Breakdown has duplicate bucket index")]
fn duplicate_index_panics_in_debug() {
let _ = Breakdown::new(
"axis",
vec![
Bucket::from_static(0, "a", 0.0, 1.0),
Bucket::from_static(0, "b", 1.0, 2.0),
],
);
}
// `ClassGroup::new` stores ids sorted and deduplicated, and keeps index and label.
#[test]
fn class_group_new_sorts_and_dedups_class_ids() {
let g = ClassGroup::new(0, "vehicles", vec![8, 3, 6, 3, 8]);
assert_eq!(g.class_ids(), &[3, 6, 8]);
assert_eq!(g.index, 0);
assert_eq!(g.label, "vehicles");
}
// Accessors on a valid two-group breakdown.
#[test]
fn class_group_breakdown_basic_shape() {
let bd = ClassGroupBreakdown::new(
"vehicle_taxonomy",
vec![
ClassGroup::new(0, "small", [3, 4]),
ClassGroup::new(1, "large", [6, 8]),
],
);
assert_eq!(bd.axis(), "vehicle_taxonomy");
assert_eq!(bd.len(), 2);
assert!(!bd.is_empty());
}
// `group_at` looks groups up by index and `group_of` by member class id; both
// return `None` for unknown keys.
#[test]
fn class_group_breakdown_lookup_by_index_and_class_id() {
let bd = ClassGroupBreakdown::new(
"g",
vec![
ClassGroup::new(0, "a", [1, 2, 3]),
ClassGroup::new(1, "b", [10, 20]),
],
);
assert_eq!(bd.group_at(0).unwrap().label, "a");
assert_eq!(bd.group_at(1).unwrap().label, "b");
assert!(bd.group_at(2).is_none());
assert_eq!(bd.group_of(2).unwrap().label, "a");
assert_eq!(bd.group_of(20).unwrap().label, "b");
assert!(bd.group_of(99).is_none());
}
// Debug-only invariant: a class-group breakdown needs at least one group.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "ClassGroupBreakdown must have >= 1 group")]
fn class_group_breakdown_empty_panics_in_debug() {
let _ = ClassGroupBreakdown::new("g", vec![]);
}
// Debug-only invariant: group labels must be unique.
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "ClassGroupBreakdown duplicate group label")]
fn class_group_breakdown_duplicate_label_panics_in_debug() {
let _ = ClassGroupBreakdown::new(
"g",
vec![
ClassGroup::new(0, "vehicles", [1]),
ClassGroup::new(1, "vehicles", [2]),
],
);
}
// Debug-only invariant: a class id may belong to at most one group (id 3 is shared here).
#[test]
#[cfg(debug_assertions)]
#[should_panic(expected = "appears in multiple groups")]
fn class_group_breakdown_partition_violation_panics_in_debug() {
let _ = ClassGroupBreakdown::new(
"g",
vec![
ClassGroup::new(0, "a", [1, 2, 3]),
ClassGroup::new(1, "b", [3, 4]),
],
);
}
}