use std::collections::{HashMap, HashSet};
use crate::dataset::{CocoDataset, CocoDetections, EvalDataset, ImageMeta};
use crate::error::EvalError;
use crate::evaluate::EvalKernel;
use crate::parity::ParityMode;
use crate::tide::cross_class::compute_cross_class_ious;
/// Sparse confusion-matrix tallies produced by `compute_confusion_matrix`.
#[derive(Debug, Clone, Default)]
pub struct ConfusionMatrixCounts {
/// Occurrence counts keyed by `(ground-truth class, detection class)`.
///
/// `(Some(g), Some(d))` is a detection of class `d` matched to a ground truth
/// of class `g`; `(None, Some(d))` is a detection with no ground-truth match;
/// `(Some(g), None)` is a ground truth no detection claimed. Cells that never
/// occurred have no entry.
///
/// NOTE(review): the class indices come from the cross-class IoU helper; the
/// unit tests treat them as positions in `category_ids` — confirm.
pub counts: HashMap<(Option<usize>, Option<usize>), u64>,
/// Category ids of the ground-truth dataset; sorted ascending when produced
/// by `compute_confusion_matrix`.
pub category_ids: Vec<i64>,
}
pub fn compute_confusion_matrix<K: EvalKernel>(
gt: &CocoDataset,
dt: &CocoDetections,
kernel: &K,
iou_threshold: f64,
max_dets_per_image: usize,
parity_mode: ParityMode,
) -> Result<ConfusionMatrixCounts, EvalError> {
let cross = compute_cross_class_ious(gt, dt, kernel, parity_mode, max_dets_per_image)?;
let mut category_ids: Vec<i64> = gt.categories().iter().map(|c| c.id.0).collect();
category_ids.sort_unstable();
let mut images: Vec<&ImageMeta> = gt.images().iter().collect();
images.sort_unstable_by_key(|im| im.id.0);
let gt_anns = gt.annotations();
let mut counts: HashMap<(Option<usize>, Option<usize>), u64> = HashMap::new();
for (image_idx, image) in images.iter().enumerate() {
let gt_indices = gt.ann_indices_for_image(image.id);
let (Some(iou), Some(dt_classes), Some(gt_classes)) = (
cross.get(image_idx),
cross.dt_classes(image_idx),
cross.gt_classes(image_idx),
) else {
continue;
};
let n_d = iou.shape()[0];
let n_g = iou.shape()[1];
let mut gt_taken: HashSet<usize> = HashSet::new();
for d in 0..n_d {
let mut best_g: Option<usize> = None;
let mut best_iou = f64::NEG_INFINITY;
for g in 0..n_g {
if gt_taken.contains(&g) {
continue;
}
let v = iou[(d, g)];
if v > best_iou {
best_iou = v;
best_g = Some(g);
}
}
let dt_class_idx = dt_classes[d];
if let Some(g) = best_g {
if best_iou >= iou_threshold {
if is_ignore_gt(>_anns[gt_indices[g]]) {
continue;
}
gt_taken.insert(g);
*counts
.entry((Some(gt_classes[g]), Some(dt_class_idx)))
.or_insert(0) += 1;
continue;
}
}
*counts.entry((None, Some(dt_class_idx))).or_insert(0) += 1;
}
for (g, >_class_idx) in gt_classes.iter().enumerate() {
if gt_taken.contains(&g) || is_ignore_gt(>_anns[gt_indices[g]]) {
continue;
}
*counts.entry((Some(gt_class_idx), None)).or_insert(0) += 1;
}
}
Ok(ConfusionMatrixCounts {
counts,
category_ids,
})
}
/// Returns `true` when a ground-truth annotation must be left out of the
/// tallies: it is a crowd region, or its optional ignore flag is set to `true`.
fn is_ignore_gt(ann: &crate::dataset::CocoAnnotation) -> bool {
    if ann.is_crowd {
        return true;
    }
    matches!(ann.ignore_flag, Some(true))
}
/// Test helper: position of `category_id` within `counts.category_ids`
/// (first occurrence), or `None` when the id is not listed.
#[cfg(test)]
fn category_index_for_id(counts: &ConfusionMatrixCounts, category_id: i64) -> Option<usize> {
    counts
        .category_ids
        .iter()
        .enumerate()
        .find_map(|(slot, &id)| (id == category_id).then_some(slot))
}
#[cfg(test)]
mod tests {
    //! Scenario tests for `compute_confusion_matrix` using bbox IoU at a 0.5
    //! threshold in strict parity mode.

    use super::*;
    use crate::dataset::{
        AnnId, Bbox, CategoryId, CategoryMeta, CocoAnnotation, DetectionInput, ImageId, ImageMeta,
    };
    use crate::similarity::BboxIou;

    /// Image metadata with the given id and size, no file name.
    fn img(id: i64, w: u32, h: u32) -> ImageMeta {
        ImageMeta {
            id: ImageId(id),
            width: w,
            height: h,
            file_name: None,
        }
    }

    /// Category with the given id and name, no supercategory.
    fn cat(id: i64, name: &str) -> CategoryMeta {
        CategoryMeta {
            id: CategoryId(id),
            name: name.into(),
            supercategory: None,
        }
    }

    /// Bbox-only ground-truth annotation; `bbox` is `(x, y, w, h)` and the
    /// area is derived as `w * h`.
    fn ann(
        id: i64,
        image: i64,
        cat: i64,
        bbox: (f64, f64, f64, f64),
        iscrowd: bool,
    ) -> CocoAnnotation {
        CocoAnnotation {
            id: AnnId(id),
            image_id: ImageId(image),
            category_id: CategoryId(cat),
            area: bbox.2 * bbox.3,
            is_crowd: iscrowd,
            ignore_flag: None,
            bbox: Bbox {
                x: bbox.0,
                y: bbox.1,
                w: bbox.2,
                h: bbox.3,
            },
            segmentation: None,
            keypoints: None,
            num_keypoints: None,
        }
    }

    /// Bbox-only detection; `bbox` is `(x, y, w, h)`.
    fn dt_input(image: i64, cat: i64, score: f64, bbox: (f64, f64, f64, f64)) -> DetectionInput {
        DetectionInput {
            id: None,
            image_id: ImageId(image),
            category_id: CategoryId(cat),
            score,
            bbox: Bbox {
                x: bbox.0,
                y: bbox.1,
                w: bbox.2,
                h: bbox.3,
            },
            segmentation: None,
            keypoints: None,
            num_keypoints: None,
        }
    }

    #[test]
    fn diagonal_only_when_every_dt_matches_same_class_gt() {
        let images = vec![img(1, 200, 200)];
        let cats = vec![cat(1, "a"), cat(2, "b")];
        let anns = vec![
            ann(1, 1, 1, (10.0, 10.0, 40.0, 40.0), false),
            ann(2, 1, 2, (100.0, 100.0, 40.0, 40.0), false),
        ];
        let gt = CocoDataset::from_parts(images, anns, cats).expect("dataset builds");
        // Each detection sits exactly on the ground truth of its own class.
        let dts = CocoDetections::from_inputs(vec![
            dt_input(1, 1, 0.9, (10.0, 10.0, 40.0, 40.0)),
            dt_input(1, 2, 0.8, (100.0, 100.0, 40.0, 40.0)),
        ])
        .expect("detections build");
        let cm = compute_confusion_matrix(&gt, &dts, &BboxIou, 0.5, 100, ParityMode::Strict)
            .expect("confusion matrix runs");
        let idx_a = category_index_for_id(&cm, 1).expect("class 1 in matrix");
        let idx_b = category_index_for_id(&cm, 2).expect("class 2 in matrix");
        assert_eq!(cm.counts.get(&(Some(idx_a), Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_b), Some(idx_b))), Some(&1));
        assert_eq!(cm.counts.len(), 2);
    }

    #[test]
    fn off_diagonal_when_every_dt_is_wrong_class() {
        let images = vec![img(1, 200, 200)];
        let cats = vec![cat(1, "a"), cat(2, "b")];
        let anns = vec![
            ann(1, 1, 1, (10.0, 10.0, 40.0, 40.0), false),
            ann(2, 1, 2, (100.0, 100.0, 40.0, 40.0), false),
        ];
        let gt = CocoDataset::from_parts(images, anns, cats).expect("dataset builds");
        // Boxes coincide with the ground truths but the class labels are swapped.
        let dts = CocoDetections::from_inputs(vec![
            dt_input(1, 2, 0.9, (10.0, 10.0, 40.0, 40.0)),
            dt_input(1, 1, 0.9, (100.0, 100.0, 40.0, 40.0)),
        ])
        .expect("detections build");
        let cm = compute_confusion_matrix(&gt, &dts, &BboxIou, 0.5, 100, ParityMode::Strict)
            .expect("confusion matrix runs");
        let idx_a = category_index_for_id(&cm, 1).expect("class 1");
        let idx_b = category_index_for_id(&cm, 2).expect("class 2");
        assert_eq!(cm.counts.get(&(Some(idx_a), Some(idx_b))), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_b), Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.len(), 2);
    }

    #[test]
    fn fp_row_for_background_dts_and_no_missed_for_covered_gts() {
        let images = vec![img(1, 1000, 1000)];
        let cats = vec![cat(1, "a"), cat(2, "b")];
        let anns = vec![
            ann(1, 1, 1, (10.0, 10.0, 40.0, 40.0), false),
            ann(2, 1, 2, (100.0, 100.0, 40.0, 40.0), false),
        ];
        let gt = CocoDataset::from_parts(images, anns, cats).expect("dataset builds");
        // Two high-score detections far from any ground truth, followed by two
        // lower-score detections that cover the ground truths exactly.
        let dts = CocoDetections::from_inputs(vec![
            dt_input(1, 1, 0.9, (500.0, 500.0, 30.0, 30.0)),
            dt_input(1, 2, 0.9, (600.0, 500.0, 30.0, 30.0)),
            dt_input(1, 1, 0.5, (10.0, 10.0, 40.0, 40.0)),
            dt_input(1, 2, 0.5, (100.0, 100.0, 40.0, 40.0)),
        ])
        .expect("detections build");
        let cm = compute_confusion_matrix(&gt, &dts, &BboxIou, 0.5, 100, ParityMode::Strict)
            .expect("confusion matrix runs");
        let idx_a = category_index_for_id(&cm, 1).expect("class 1");
        let idx_b = category_index_for_id(&cm, 2).expect("class 2");
        assert_eq!(cm.counts.get(&(None, Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.get(&(None, Some(idx_b))), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_a), Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_b), Some(idx_b))), Some(&1));
        assert!(!cm.counts.contains_key(&(Some(idx_a), None)));
        assert!(!cm.counts.contains_key(&(Some(idx_b), None)));
    }

    #[test]
    fn fp_and_missed_when_dts_and_gts_dont_overlap_at_all() {
        let images = vec![img(1, 1000, 1000)];
        let cats = vec![cat(1, "a"), cat(2, "b")];
        let anns = vec![
            ann(1, 1, 1, (10.0, 10.0, 40.0, 40.0), false),
            ann(2, 1, 2, (100.0, 100.0, 40.0, 40.0), false),
        ];
        let gt = CocoDataset::from_parts(images, anns, cats).expect("dataset builds");
        // Detections are nowhere near the ground truths.
        let dts = CocoDetections::from_inputs(vec![
            dt_input(1, 1, 0.9, (500.0, 500.0, 30.0, 30.0)),
            dt_input(1, 2, 0.9, (600.0, 500.0, 30.0, 30.0)),
        ])
        .expect("detections build");
        let cm = compute_confusion_matrix(&gt, &dts, &BboxIou, 0.5, 100, ParityMode::Strict)
            .expect("confusion matrix runs");
        let idx_a = category_index_for_id(&cm, 1).expect("class 1");
        let idx_b = category_index_for_id(&cm, 2).expect("class 2");
        assert_eq!(cm.counts.get(&(None, Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.get(&(None, Some(idx_b))), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_a), None)), Some(&1));
        assert_eq!(cm.counts.get(&(Some(idx_b), None)), Some(&1));
    }

    #[test]
    fn ignore_gt_neither_matched_nor_missed() {
        let images = vec![img(1, 1000, 1000), img(2, 1000, 1000)];
        let cats = vec![cat(1, "a")];
        let anns = vec![
            // Image 1 holds only a crowd annotation; image 2 a regular one.
            ann(1, 1, 1, (10.0, 10.0, 40.0, 40.0), true),
            ann(2, 2, 1, (10.0, 10.0, 40.0, 40.0), false),
        ];
        let gt = CocoDataset::from_parts(images, anns, cats).expect("dataset builds");
        let dts = CocoDetections::from_inputs(vec![
            // Lands on the crowd region: must vanish from the tallies.
            dt_input(1, 1, 0.9, (10.0, 10.0, 40.0, 40.0)),
            // Background detection on image 1: a false positive.
            dt_input(1, 1, 0.5, (500.0, 500.0, 30.0, 30.0)),
            // Exact hit on the regular ground truth of image 2.
            dt_input(2, 1, 0.9, (10.0, 10.0, 40.0, 40.0)),
        ])
        .expect("detections build");
        let cm = compute_confusion_matrix(&gt, &dts, &BboxIou, 0.5, 100, ParityMode::Strict)
            .expect("confusion matrix runs");
        let idx_a = category_index_for_id(&cm, 1).expect("class 1");
        assert_eq!(cm.counts.get(&(Some(idx_a), Some(idx_a))), Some(&1));
        assert_eq!(cm.counts.get(&(None, Some(idx_a))), Some(&1));
        assert!(!cm.counts.contains_key(&(Some(idx_a), None)));
        assert_eq!(cm.counts.values().sum::<u64>(), 2);
    }
}