use std::collections::{HashMap, HashSet};
use crate::accumulate::{accumulate, AccumulateParams, PerImageEval};
use crate::dataset::{CocoDataset, CocoDetections, EvalDataset, ImageId};
use crate::error::EvalError;
use crate::evaluate::EvalKernel;
use crate::lrp::{optimal_lrp_with_partitioned, LrpKernelMarker, LrpParams, LrpReport};
use crate::parity::{recall_thresholds, ParityMode};
use crate::summarize::{summarize_detection, summarize_with, StatRequest, Summary};
/// Maps every image id in `dataset` to a dense index, ordered by ascending raw id.
pub fn image_id_to_idx<D: EvalDataset>(dataset: &D) -> HashMap<ImageId, usize> {
    let mut sorted_ids: Vec<ImageId> = dataset.images().iter().map(|im| im.id).collect();
    sorted_ids.sort_unstable_by_key(|id| id.0);
    let mut index_of: HashMap<ImageId, usize> = HashMap::with_capacity(sorted_ids.len());
    for (idx, id) in sorted_ids.into_iter().enumerate() {
        index_of.insert(id, idx);
    }
    index_of
}
/// Slice value label for images not claimed by any value of an axis.
pub const UNASSIGNED: &str = "__unassigned__";
/// Separator joining axis names/values in cross (joint) slices, e.g. "weather::time".
pub const CROSS_SEPARATOR: &str = "::";
/// Upper bound on the total number of slices a partition may produce.
pub const SLICES_CAP: usize = 256;
/// Discriminates what a partition manifest's keys refer to.
///
/// NOTE(review): the variant is stored on `PartitionSpec` but not consulted
/// elsewhere in this file; presumably `Image` keys slices by ground-truth
/// image and `Result` by detection entry — confirm against the manifest loader.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyKind {
Image,
Result,
}
/// One partition cell: the set of images sharing `value` on `axis`.
#[derive(Debug, Clone)]
pub struct Slice {
/// Axis name; joint slices use the joined form, e.g. "weather::time".
pub axis: String,
/// Value label on that axis; `UNASSIGNED` marks images no value claimed.
pub value: String,
/// Manifest image ids in this slice (may include ids absent from the dataset).
pub image_ids: HashSet<ImageId>,
/// Dense dataset indices for the subset of `image_ids` present in the dataset.
pub image_indices: HashSet<usize>,
}
/// A validated, ordered list of slices to evaluate.
#[derive(Debug, Clone)]
pub struct PartitionSpec {
/// What the manifest's keys refer to (carried through from `build`).
pub key_kind: KeyKind,
/// Marginal slices first (axes sorted, values sorted, UNASSIGNED last per
/// axis), then joint slices sorted by (axis, value).
pub slices: Vec<Slice>,
}
impl PartitionSpec {
    /// Assembles marginal slices (one per axis value plus a trailing
    /// UNASSIGNED slice per axis, pushed even when empty) followed by sorted
    /// joint slices for every requested cross tuple.
    ///
    /// # Errors
    /// `InvalidConfig` when cross-axis validation fails or the total slice
    /// count exceeds [`SLICES_CAP`].
    pub fn build(
        key_kind: KeyKind,
        per_axis: &HashMap<String, HashMap<String, HashSet<ImageId>>>,
        all_image_ids: &HashSet<ImageId>,
        image_id_to_idx: &HashMap<ImageId, usize>,
        cross_axes: &[Vec<String>],
    ) -> Result<Self, EvalError> {
        validate_cross_axes(per_axis, cross_axes)?;

        // Marginal slices in a deterministic order: axes sorted by name,
        // values sorted within each axis.
        let mut axis_entries: Vec<(&String, &HashMap<String, HashSet<ImageId>>)> =
            per_axis.iter().collect();
        axis_entries.sort_by_key(|&(name, _)| name);
        let mut marginal: Vec<Slice> = Vec::new();
        for (axis, by_value) in axis_entries {
            let mut value_entries: Vec<(&String, &HashSet<ImageId>)> = by_value.iter().collect();
            value_entries.sort_by_key(|&(value, _)| value);
            let mut covered: HashSet<ImageId> = HashSet::new();
            for (value, ids) in value_entries {
                covered.extend(ids.iter().copied());
                marginal.push(make_slice(axis, value, ids.clone(), image_id_to_idx));
            }
            // Dataset images not claimed by any value of this axis land in
            // the axis's UNASSIGNED slice.
            let missing: HashSet<ImageId> =
                all_image_ids.difference(&covered).copied().collect();
            marginal.push(make_slice(axis, UNASSIGNED, missing, image_id_to_idx));
        }

        // Joint (cross) slices, sorted by (axis, value) for stable output.
        let mut joint: Vec<Slice> = Vec::new();
        for tuple in cross_axes {
            joint.extend(expand_cross_axes(
                tuple,
                per_axis,
                all_image_ids,
                image_id_to_idx,
            )?);
        }
        joint.sort_by(|lhs, rhs| (&lhs.axis, &lhs.value).cmp(&(&rhs.axis, &rhs.value)));

        let total = marginal.len() + joint.len();
        if total > SLICES_CAP {
            return Err(EvalError::InvalidConfig {
                detail: format!(
                    "partition would produce {total} slices but the cap is {SLICES_CAP}; \
                     reduce --cross axes or narrow the manifest"
                ),
            });
        }

        let mut slices = marginal;
        slices.extend(joint);
        Ok(Self { key_kind, slices })
    }

    /// Number of slices in the partition.
    pub fn len(&self) -> usize {
        self.slices.len()
    }

    /// True when the partition contains no slices.
    pub fn is_empty(&self) -> bool {
        self.slices.is_empty()
    }
}
/// Summary for a single slice, plus bookkeeping counts.
#[derive(Debug, Clone)]
pub struct SliceResult {
pub slice: Slice,
pub summary: Summary,
/// Number of manifest image ids in the slice.
pub n_images: u64,
/// Detections attributed to the slice (counted once per (category, image)).
pub n_detections: u64,
}
/// Overall summary plus one `SliceResult` per partition slice.
#[derive(Debug, Clone)]
pub struct PartitionedSummary {
pub overall: Summary,
/// Total images in the evaluated grid.
pub overall_n_images: u64,
/// Total detections in the evaluated grid.
pub overall_n_detections: u64,
pub slices: Vec<SliceResult>,
}
/// Chooses how accumulated results are summarized.
#[derive(Debug, Clone, Copy)]
pub enum SummaryPlan<'a> {
/// COCO detection defaults (maxDets 1/10/100).
DetectionDefault,
/// COCO keypoints defaults (maxDets 20).
KeypointsDefault,
/// Caller-supplied stat requests and max-detections list.
Custom {
plan: &'a [StatRequest],
max_dets: &'a [usize],
},
}
impl<'a> SummaryPlan<'a> {
    /// Max-detections thresholds this plan feeds to the accumulate step.
    fn max_dets(&self) -> &'a [usize] {
        match *self {
            Self::DetectionDefault => &DETECTION_MAX_DETS,
            Self::KeypointsDefault => &KEYPOINTS_MAX_DETS,
            Self::Custom { max_dets, .. } => max_dets,
        }
    }

    /// Runs the summarization matching this plan over an accumulated result.
    fn summarize(
        &self,
        accum: &crate::accumulate::Accumulated,
        iou_thresholds: &[f64],
    ) -> Result<Summary, EvalError> {
        match *self {
            Self::DetectionDefault => {
                summarize_detection(accum, iou_thresholds, &DETECTION_MAX_DETS)
            }
            Self::KeypointsDefault => {
                // Keypoints use the stock COCO keypoint stat plan.
                let requests = StatRequest::coco_keypoints_default();
                summarize_with(accum, &requests, iou_thresholds, &KEYPOINTS_MAX_DETS)
            }
            Self::Custom { plan, max_dets } => {
                summarize_with(accum, plan, iou_thresholds, max_dets)
            }
        }
    }
}
/// COCO detection-default maxDets thresholds.
const DETECTION_MAX_DETS: [usize; 3] = [1, 10, 100];
/// COCO keypoints-default maxDets threshold.
const KEYPOINTS_MAX_DETS: [usize; 1] = [20];
/// Dimensions of the flattened per-image evaluation grid, laid out
/// category-major, then area range, then image
/// (flat index = k * n_area_ranges * n_images + a * n_images + i).
#[derive(Debug, Clone, Copy)]
pub struct GridDims {
pub n_categories: usize,
pub n_area_ranges: usize,
pub n_images: usize,
}
/// Accumulates and summarizes the full grid once, then re-accumulates each
/// slice of `spec` over only that slice's images.
///
/// `eval_imgs` must be the flattened (category × area-range × image) grid
/// described by `grid`; a length mismatch is rejected up front.
pub fn evaluate_partitioned(
    eval_imgs: &[Option<Box<PerImageEval>>],
    grid: GridDims,
    spec: &PartitionSpec,
    iou_thresholds: &[f64],
    parity_mode: ParityMode,
    summary_plan: SummaryPlan<'_>,
) -> Result<PartitionedSummary, EvalError> {
    let want = grid.n_categories * grid.n_area_ranges * grid.n_images;
    if eval_imgs.len() != want {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "eval_imgs len {} != n_categories({}) * n_area_ranges({}) * n_images({}) = {}",
                eval_imgs.len(),
                grid.n_categories,
                grid.n_area_ranges,
                grid.n_images,
                want,
            ),
        });
    }
    let params = AccumulateParams {
        iou_thresholds,
        recall_thresholds: recall_thresholds(),
        max_dets: summary_plan.max_dets(),
        n_categories: grid.n_categories,
        n_area_ranges: grid.n_area_ranges,
        n_images: grid.n_images,
    };
    // Overall pass over the unfiltered grid.
    let overall_accum = accumulate(eval_imgs, params, parity_mode)?;
    let overall = summary_plan.summarize(&overall_accum, iou_thresholds)?;
    let overall_n_detections = count_detections(eval_imgs, grid, None);
    // One filtered pass per slice.
    let mut per_slice: Vec<SliceResult> = Vec::with_capacity(spec.slices.len());
    for slice in &spec.slices {
        let (kept, n_detections) =
            filtered_flatten_and_count(eval_imgs, grid, &slice.image_indices);
        let accum = accumulate(&kept, params, parity_mode)?;
        let summary = summary_plan.summarize(&accum, iou_thresholds)?;
        per_slice.push(SliceResult {
            // NOTE(review): counts manifest ids, which may include ids absent
            // from the dataset (the filter uses image_indices) — confirm intended.
            n_images: slice.image_ids.len() as u64,
            n_detections,
            slice: slice.clone(),
            summary,
        });
    }
    Ok(PartitionedSummary {
        overall,
        overall_n_images: grid.n_images as u64,
        overall_n_detections,
        slices: per_slice,
    })
}
/// LRP report for a single slice, plus bookkeeping counts.
#[derive(Debug, Clone)]
pub struct LrpSliceResult {
pub slice: Slice,
pub report: LrpReport,
/// Number of manifest image ids in the slice.
pub n_images: u64,
/// Detections whose image falls inside the slice.
pub n_detections: u64,
}
/// Overall LRP report plus one `LrpSliceResult` per partition slice.
#[derive(Debug, Clone)]
pub struct PartitionedLrpReport {
pub overall: LrpReport,
/// Total images in the ground-truth dataset.
pub overall_n_images: u64,
/// Total detections submitted.
pub overall_n_detections: u64,
pub slices: Vec<LrpSliceResult>,
}
/// Runs the optimal-LRP kernel once overall plus once per slice of `spec`.
///
/// `optimal_lrp_with_partitioned` receives one image-index filter per slice
/// and must return exactly `1 + slices` reports, overall first.
///
/// NOTE(review): `overall_n_detections` counts every detection, including
/// those whose image id is absent from `gt`, while per-slice counts skip such
/// detections — confirm this asymmetry is intended.
pub fn evaluate_partitioned_lrp<K: EvalKernel>(
    gt: &CocoDataset,
    dt: &CocoDetections,
    kernel: &K,
    kernel_marker: LrpKernelMarker,
    params: LrpParams<'_>,
    parity_mode: ParityMode,
    spec: &PartitionSpec,
) -> Result<PartitionedLrpReport, EvalError> {
    let filters: Vec<HashSet<usize>> =
        spec.slices.iter().map(|s| s.image_indices.clone()).collect();
    let mut reports = optimal_lrp_with_partitioned(
        gt, dt, kernel, kernel_marker, params, parity_mode, &filters,
    )?;
    let expected = spec.slices.len() + 1;
    if reports.len() != expected {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "lrp partition: expected {} reports (1 overall + {} slices), got {}",
                expected,
                spec.slices.len(),
                reports.len()
            ),
        });
    }
    let overall = reports.remove(0);
    // Tally detections per slice via the dense image-index map.
    let idx_by_id = image_id_to_idx(gt);
    let mut per_slice_dets: Vec<u64> = vec![0; spec.slices.len()];
    let mut total_dets: u64 = 0;
    for det in dt.detections() {
        total_dets = total_dets.saturating_add(1);
        if let Some(&img_idx) = idx_by_id.get(&det.image_id) {
            for (s, slice) in spec.slices.iter().enumerate() {
                if slice.image_indices.contains(&img_idx) {
                    per_slice_dets[s] = per_slice_dets[s].saturating_add(1);
                }
            }
        }
    }
    let slices: Vec<LrpSliceResult> = spec
        .slices
        .iter()
        .zip(reports)
        .zip(per_slice_dets)
        .map(|((slice, report), n_detections)| LrpSliceResult {
            n_images: slice.image_ids.len() as u64,
            n_detections,
            slice: slice.clone(),
            report,
        })
        .collect();
    Ok(PartitionedLrpReport {
        overall,
        overall_n_images: gt.images().len() as u64,
        overall_n_detections: total_dets,
        slices,
    })
}
/// Convenience wrapper over [`evaluate_partitioned`] for a caller-supplied
/// stat plan and max-detections list.
pub fn evaluate_partitioned_with(
    eval_imgs: &[Option<Box<PerImageEval>>],
    grid: GridDims,
    spec: &PartitionSpec,
    iou_thresholds: &[f64],
    max_dets: &[usize],
    parity_mode: ParityMode,
    plan: &[StatRequest],
) -> Result<PartitionedSummary, EvalError> {
    let summary_plan = SummaryPlan::Custom { plan, max_dets };
    evaluate_partitioned(eval_imgs, grid, spec, iou_thresholds, parity_mode, summary_plan)
}
fn filtered_flatten_and_count(
eval_imgs: &[Option<Box<PerImageEval>>],
grid: GridDims,
slice_indices: &HashSet<usize>,
) -> (Vec<Option<Box<PerImageEval>>>, u64) {
let total = grid.n_categories * grid.n_area_ranges * grid.n_images;
let mut out: Vec<Option<Box<PerImageEval>>> = Vec::with_capacity(total);
let mut n_detections: u64 = 0;
for k in 0..grid.n_categories {
for a in 0..grid.n_area_ranges {
for i in 0..grid.n_images {
let flat = k * grid.n_area_ranges * grid.n_images + a * grid.n_images + i;
let cell = if slice_indices.contains(&i) {
eval_imgs.get(flat).and_then(|c| c.clone())
} else {
None
};
if a == 0 {
if let Some(ref c) = cell {
n_detections = n_detections.saturating_add(c.dt_scores.len() as u64);
}
}
out.push(cell);
}
}
}
(out, n_detections)
}
/// Counts detections in the grid, optionally restricted to a slice's images.
///
/// Only the first area-range plane (a == 0) is inspected so each
/// (category, image) detection list is counted exactly once.
fn count_detections(
    eval_imgs: &[Option<Box<PerImageEval>>],
    grid: GridDims,
    slice_indices: Option<&HashSet<usize>>,
) -> u64 {
    let plane = grid.n_area_ranges * grid.n_images;
    let mut total: u64 = 0;
    for k in 0..grid.n_categories {
        for i in 0..grid.n_images {
            // No filter means every image participates.
            if !slice_indices.map_or(true, |set| set.contains(&i)) {
                continue;
            }
            if let Some(cell) = eval_imgs.get(k * plane + i).and_then(|c| c.as_deref()) {
                total = total.saturating_add(cell.dt_scores.len() as u64);
            }
        }
    }
    total
}
/// Builds a `Slice`, resolving manifest ids to dense dataset indices.
/// Ids absent from `image_id_to_idx` stay in `image_ids` but contribute no index.
fn make_slice(
    axis: &str,
    value: &str,
    image_ids: HashSet<ImageId>,
    image_id_to_idx: &HashMap<ImageId, usize>,
) -> Slice {
    let mut image_indices: HashSet<usize> = HashSet::with_capacity(image_ids.len());
    for id in &image_ids {
        if let Some(&idx) = image_id_to_idx.get(id) {
            image_indices.insert(idx);
        }
    }
    Slice {
        axis: axis.to_string(),
        value: value.to_string(),
        image_ids,
        image_indices,
    }
}
/// Validates `--cross` tuples against the manifest axes.
///
/// Rejects axis names containing the reserved separator, tuples with fewer
/// than two axes, tuples referencing unknown axes, and repeated axes in a tuple.
fn validate_cross_axes(
    per_axis: &HashMap<String, HashMap<String, HashSet<ImageId>>>,
    cross_axes: &[Vec<String>],
) -> Result<(), EvalError> {
    // The separator is reserved for joined joint-slice names.
    if let Some(bad) = per_axis.keys().find(|a| a.contains(CROSS_SEPARATOR)) {
        return Err(EvalError::InvalidConfig {
            detail: format!(
                "manifest axis name {bad:?} contains the reserved separator {CROSS_SEPARATOR:?}; rename the axis"
            ),
        });
    }
    for tuple in cross_axes {
        if tuple.len() < 2 {
            return Err(EvalError::InvalidConfig {
                detail: format!(
                    "--cross requires at least two axes per tuple; got {} ({:?})",
                    tuple.len(),
                    tuple
                ),
            });
        }
        let mut seen: HashSet<&String> = HashSet::with_capacity(tuple.len());
        for name in tuple {
            if !per_axis.contains_key(name) {
                return Err(EvalError::InvalidConfig {
                    detail: format!(
                        "--cross references axis {name:?} which is not present in the manifest"
                    ),
                });
            }
            if !seen.insert(name) {
                return Err(EvalError::InvalidConfig {
                    detail: format!("--cross tuple {tuple:?} repeats axis {name:?}"),
                });
            }
        }
    }
    Ok(())
}
type AxisValueEntry<'a> = (&'a str, &'a HashSet<ImageId>);
fn expand_cross_axes(
axes: &[String],
per_axis: &HashMap<String, HashMap<String, HashSet<ImageId>>>,
all_image_ids: &HashSet<ImageId>,
image_id_to_idx: &HashMap<ImageId, usize>,
) -> Result<Vec<Slice>, EvalError> {
let mut value_sets: Vec<(&str, Vec<AxisValueEntry<'_>>)> = Vec::with_capacity(axes.len());
for axis in axes {
let by_value = per_axis.get(axis).ok_or_else(|| EvalError::InvalidConfig {
detail: format!("--cross axis {axis:?} missing during expansion"),
})?;
let mut entries: Vec<(&str, &HashSet<ImageId>)> =
by_value.iter().map(|(v, ids)| (v.as_str(), ids)).collect();
entries.sort_by_key(|(v, _)| *v);
value_sets.push((axis.as_str(), entries));
}
let mut combos: Vec<Vec<(&str, &str, &HashSet<ImageId>)>> = vec![Vec::new()];
for (axis_name, values) in &value_sets {
let mut next: Vec<Vec<(&str, &str, &HashSet<ImageId>)>> = Vec::new();
for combo in &combos {
for (value, ids) in values {
let mut extended = combo.clone();
extended.push((axis_name, value, ids));
next.push(extended);
}
}
combos = next;
}
let joined_axis: String = axes.join(CROSS_SEPARATOR);
let mut out: Vec<Slice> = Vec::with_capacity(combos.len() + 1);
let mut covered: HashSet<ImageId> = HashSet::new();
for combo in combos {
let mut iter = combo.iter().map(|(_, _, ids)| *ids);
let mut joint: HashSet<ImageId> = match iter.next() {
Some(first) => first.clone(),
None => HashSet::new(),
};
for ids in iter {
joint = joint.intersection(ids).copied().collect();
}
covered.extend(joint.iter().copied());
let joined_value: String = combo
.iter()
.map(|(_, v, _)| (*v).to_owned())
.collect::<Vec<_>>()
.join(CROSS_SEPARATOR);
out.push(make_slice(
&joined_axis,
&joined_value,
joint,
image_id_to_idx,
));
}
let missing: HashSet<ImageId> = all_image_ids
.iter()
.copied()
.filter(|id| !covered.contains(id))
.collect();
out.push(make_slice(
&joined_axis,
UNASSIGNED,
missing,
image_id_to_idx,
));
Ok(out)
}
// Unit tests cover slice construction ordering, UNASSIGNED handling,
// cross-product expansion, validation failures, the slice cap, and the
// detection counters.
#[cfg(test)]
mod tests {
use super::*;
use crate::accumulate::PerImageEval;
use ndarray::Array2;
// Shorthand for constructing an ImageId from a raw i64.
fn id(n: i64) -> ImageId {
ImageId(n)
}
// Builds ids 1..=n plus the dense id -> index map (id i maps to index i-1).
fn build_image_grid(n: i64) -> (HashSet<ImageId>, HashMap<ImageId, usize>) {
let all: HashSet<ImageId> = (1..=n).map(id).collect();
let map: HashMap<ImageId, usize> = (1..=n).map(|i| (id(i), (i - 1) as usize)).collect();
(all, map)
}
// Marginal slices come out axis-sorted, value-sorted, UNASSIGNED last per axis.
#[test]
fn marginal_order_is_axis_then_value_then_unassigned_last() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather".into(),
HashMap::from([
("fog".into(), HashSet::from([id(1)])),
("clear".into(), HashSet::from([id(2)])),
]),
);
per_axis.insert(
"time".into(),
HashMap::from([("day".into(), HashSet::from([id(1), id(2)]))]),
);
let (all, map) = build_image_grid(3);
let spec = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &[]).unwrap();
let labels: Vec<(&str, &str)> = spec
.slices
.iter()
.map(|s| (s.axis.as_str(), s.value.as_str()))
.collect();
assert_eq!(
labels,
vec![
("time", "day"),
("time", UNASSIGNED),
("weather", "clear"),
("weather", "fog"),
("weather", UNASSIGNED),
]
);
}
// Images covered by no value of an axis land in that axis's UNASSIGNED slice.
#[test]
fn unassigned_collects_dataset_images_not_in_any_value() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather".into(),
HashMap::from([("fog".into(), HashSet::from([id(1)]))]),
);
let (all, map) = build_image_grid(3);
let spec = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &[]).unwrap();
let unassigned = spec
.slices
.iter()
.find(|s| s.axis == "weather" && s.value == UNASSIGNED)
.expect("expected an unassigned slice on `weather`");
let mut ids: Vec<i64> = unassigned.image_ids.iter().map(|i| i.0).collect();
ids.sort();
assert_eq!(ids, vec![2, 3]);
}
// 2x2 cross => 4 joint cells + 1 joint UNASSIGNED; members are intersections.
#[test]
fn cross_axes_emits_intersection_joint_cells() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather".into(),
HashMap::from([
("fog".into(), HashSet::from([id(1), id(2)])),
("clear".into(), HashSet::from([id(3), id(4)])),
]),
);
per_axis.insert(
"time".into(),
HashMap::from([
("day".into(), HashSet::from([id(1), id(3)])),
("night".into(), HashSet::from([id(2), id(4)])),
]),
);
let (all, map) = build_image_grid(4);
let cross = vec![vec!["weather".to_string(), "time".to_string()]];
let spec = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &cross).unwrap();
let joint: Vec<&Slice> = spec
.slices
.iter()
.filter(|s| s.axis.contains(CROSS_SEPARATOR))
.collect();
assert_eq!(joint.len(), 5);
let fog_day = joint
.iter()
.find(|s| s.value == "fog::day")
.expect("fog::day must exist");
let mut ids: Vec<i64> = fog_day.image_ids.iter().map(|i| i.0).collect();
ids.sort();
assert_eq!(ids, vec![1]);
}
// Referencing an axis absent from the manifest is a config error.
#[test]
fn cross_axes_with_unknown_axis_is_rejected() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather".into(),
HashMap::from([("fog".into(), HashSet::from([id(1)]))]),
);
let (all, map) = build_image_grid(2);
let cross = vec![vec!["weather".into(), "missing".into()]];
let err = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &cross).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// A cross tuple needs at least two axes.
#[test]
fn cross_axes_singleton_tuple_is_rejected() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather".into(),
HashMap::from([("fog".into(), HashSet::from([id(1)]))]),
);
let (all, map) = build_image_grid(2);
let cross = vec![vec!["weather".into()]];
let err = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &cross).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// SLICES_CAP + 1 values (plus UNASSIGNED) exceeds the cap and is rejected.
#[test]
fn slice_cap_is_enforced() {
let mut by_value: HashMap<String, HashSet<ImageId>> = HashMap::new();
let n = SLICES_CAP + 1;
for i in 1..=n as i64 {
by_value.insert(format!("v{i}"), HashSet::from([id(i)]));
}
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert("axis".into(), by_value);
let (all, map) = build_image_grid(n as i64);
let err = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &[]).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Axis names may not contain the reserved "::" separator.
#[test]
fn axis_name_with_cross_separator_is_rejected() {
let mut per_axis: HashMap<String, HashMap<String, HashSet<ImageId>>> = HashMap::new();
per_axis.insert(
"weather::extra".into(),
HashMap::from([("fog".into(), HashSet::from([id(1)]))]),
);
let (all, map) = build_image_grid(2);
let err = PartitionSpec::build(KeyKind::Image, &per_axis, &all, &map, &[]).unwrap_err();
assert!(matches!(err, EvalError::InvalidConfig { .. }));
}
// Minimal PerImageEval cell carrying n_dts detection scores.
fn fake_cell(n_dts: usize) -> Box<PerImageEval> {
Box::new(PerImageEval {
dt_scores: vec![0.5; n_dts],
dt_matched: Array2::default((1, n_dts)),
dt_ignore: Array2::default((1, n_dts)),
gt_ignore: vec![false],
})
}
// Out-of-slice cells are blanked; counts cover only the surviving cells.
#[test]
fn filtered_flatten_keeps_only_in_slice_cells() {
let grid = GridDims {
n_categories: 1,
n_area_ranges: 1,
n_images: 3,
};
let eval_imgs: Vec<Option<Box<PerImageEval>>> =
vec![Some(fake_cell(2)), Some(fake_cell(3)), Some(fake_cell(4))];
let slice_indices: HashSet<usize> = HashSet::from([0, 2]);
let (filtered, n_detections) = filtered_flatten_and_count(&eval_imgs, grid, &slice_indices);
assert!(filtered[0].is_some());
assert!(filtered[1].is_none());
assert!(filtered[2].is_some());
assert_eq!(n_detections, 2 + 4);
}
// Counting with None covers everything; with a filter, only in-slice images.
#[test]
fn count_detections_skips_out_of_slice_images() {
let grid = GridDims {
n_categories: 2,
n_area_ranges: 1,
n_images: 2,
};
let eval_imgs: Vec<Option<Box<PerImageEval>>> = vec![
Some(fake_cell(1)),
Some(fake_cell(2)),
Some(fake_cell(3)),
Some(fake_cell(4)),
];
let total = count_detections(&eval_imgs, grid, None);
assert_eq!(total, 1 + 2 + 3 + 4);
let only_first = count_detections(&eval_imgs, grid, Some(&HashSet::from([0])));
assert_eq!(only_first, 4);
}
}