use std::collections::HashMap;
use ndarray::Array2;
use crate::dataset::{
CategoryId, CocoAnnotation, CocoDataset, CocoDetection, CocoDetections, EvalDataset, ImageId,
};
use crate::error::EvalError;
use crate::evaluate::dt_top_indices_for_cell;
use crate::matching::{match_image, MatchResult};
use crate::parity::ParityMode;
use crate::tables::CrossClassIous;
use super::params::TideParams;
/// Error-bin label attached to one detection.
///
/// Labels for unmatched detections are built by `pick_bin`; labels for
/// `DtBin::Tp` and `DtBin::Ignore` are written with placeholder values
/// (`target_gt_local_idx == -1`, both IoUs `0.0`).
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DtBinLabel {
/// Bin the detection was assigned to.
pub bin: DtBin,
/// Position, within the image's ground-truth index list, of the ground truth
/// this label points at; `-1` when no ground truth is referenced.
pub target_gt_local_idx: i32,
/// Largest IoU found against a ground truth treated as the same class
/// (`0.0` when none was found, and for `Tp` / `Ignore` labels).
pub iou_same: f64,
/// Largest IoU found against a ground truth of a different class
/// (`0.0` when none was found, and for `Tp` / `Ignore` labels).
pub iou_cross: f64,
}
/// Bins a detection can be sorted into.
///
/// For unmatched detections `pick_bin` tests the conditions in the order
/// `Dupe`, `Cls`, `Loc`, `Both`, `Bkg`; the first one that holds wins.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DtBin {
/// The same-class matching step reported a match for the detection.
Tp,
/// The same-class matching step flagged the detection as ignored.
Ignore,
/// Unmatched; best different-class IoU is at least `t_f`.
Cls,
/// Unmatched; best same-class IoU is at least `t_b` and not below the best
/// different-class IoU.
Loc,
/// Unmatched; best different-class IoU is at least `t_b` and none of the
/// earlier conditions held.
Both,
/// Unmatched although the best same-class IoU is at least `t_f`.
Dupe,
/// Unmatched and none of the other conditions held.
Bkg,
}
/// Result of `assign_bins`: one label per considered detection plus the list
/// of ground truths that no detection claimed.
#[derive(Debug, Default, Clone)]
pub struct BinAssignment {
/// Labels keyed by `(image id, index into the detection list)`.
pub dt_labels: HashMap<(i64, usize), DtBinLabel>,
/// `(image id, index into the annotation list)` for every ground truth that
/// is neither crowd nor ignore-flagged and was not taken by a detection.
pub missed_gts: Vec<(i64, usize)>,
}
/// Builds the bin assignment for every image of the ground-truth dataset.
///
/// Images are visited in ascending image-id order; the position of an image
/// in that order is forwarded as `image_idx` to `assign_bins_for_image`.
///
/// # Errors
///
/// Returns the first error produced while processing an image; images after
/// the failing one are not processed.
pub fn assign_bins(
    gt: &CocoDataset,
    dt: &CocoDetections,
    cross_class: &CrossClassIous,
    params: &TideParams<'_>,
) -> Result<BinAssignment, EvalError> {
    let mut ordered: Vec<&crate::dataset::ImageMeta> = gt.images().iter().collect();
    ordered.sort_unstable_by_key(|meta| meta.id.0);

    let gt_anns = gt.annotations();
    let dt_anns = dt.detections();
    let mut assignment = BinAssignment::default();

    ordered
        .iter()
        .enumerate()
        .try_for_each(|(image_idx, meta)| {
            assign_bins_for_image(
                image_idx,
                meta.id,
                gt,
                dt,
                cross_class,
                params,
                gt_anns,
                dt_anns,
                &mut assignment,
            )
        })?;

    Ok(assignment)
}
#[allow(clippy::too_many_arguments)]
fn assign_bins_for_image(
image_idx: usize,
image_id: ImageId,
gt: &CocoDataset,
dt: &CocoDetections,
cross_class: &CrossClassIous,
params: &TideParams<'_>,
gt_anns: &[CocoAnnotation],
dt_anns: &[CocoDetection],
out: &mut BinAssignment,
) -> Result<(), EvalError> {
let gt_local_indices: &[usize] = gt.ann_indices_for_image(image_id);
let dt_local_indices = dt_top_indices_for_cell(dt, image_id, None, params.max_dets_per_image);
if gt_local_indices.is_empty() && dt_local_indices.is_empty() {
return Ok(());
}
let cross = cross_class.get(image_idx);
let mut per_dt_matched: HashMap<usize, bool> = HashMap::new();
let mut per_dt_ignore: HashMap<usize, bool> = HashMap::new();
let mut gt_taken_by: HashMap<usize, usize> = HashMap::new();
let cats_in_image: Vec<CategoryId> = if params.use_cats {
let mut cats: Vec<CategoryId> = gt_local_indices
.iter()
.map(|&j| gt_anns[j].category_id)
.chain(dt_local_indices.iter().map(|&j| dt_anns[j].category_id))
.collect();
cats.sort_unstable_by_key(|c| c.0);
cats.dedup();
cats
} else {
vec![CategoryId(crate::evaluate::COLLAPSED_CATEGORY_SENTINEL)]
};
for cat in cats_in_image {
same_class_match_one_category(
>_local_indices_with_pos(gt_local_indices, gt_anns, cat, params.use_cats),
&dt_local_indices_with_pos(&dt_local_indices, dt_anns, cat, params.use_cats),
gt_anns,
dt_anns,
params,
&mut per_dt_matched,
&mut per_dt_ignore,
&mut gt_taken_by,
)?;
}
for (row_idx, &dt_input_idx) in dt_local_indices.iter().enumerate() {
let dt = &dt_anns[dt_input_idx];
let key = (image_id.0, dt_input_idx);
if per_dt_ignore.get(&dt_input_idx).copied().unwrap_or(false) {
out.dt_labels.insert(
key,
DtBinLabel {
bin: DtBin::Ignore,
target_gt_local_idx: -1,
iou_same: 0.0,
iou_cross: 0.0,
},
);
continue;
}
if per_dt_matched.get(&dt_input_idx).copied().unwrap_or(false) {
out.dt_labels.insert(
key,
DtBinLabel {
bin: DtBin::Tp,
target_gt_local_idx: -1,
iou_same: 0.0,
iou_cross: 0.0,
},
);
continue;
}
let (iou_same, best_same_col, iou_cross, best_cross_col) = best_same_and_cross(
row_idx,
dt.category_id,
cross,
gt_local_indices,
gt_anns,
params.use_cats,
);
let label = pick_bin(
iou_same,
best_same_col,
iou_cross,
best_cross_col,
params.t_f,
params.t_b,
);
out.dt_labels.insert(key, label);
}
for (col_idx, >_input_idx) in gt_local_indices.iter().enumerate() {
let g = >_anns[gt_input_idx];
if g.is_crowd || g.ignore_flag.unwrap_or(false) {
continue;
}
if gt_taken_by.contains_key(&col_idx) {
continue;
}
out.missed_gts.push((image_id.0, gt_input_idx));
}
Ok(())
}
/// Pairs each selected ground-truth index with its position in
/// `gt_local_indices`.
///
/// With `use_cats` set, only entries whose annotation has category `cat` are
/// kept. Otherwise every entry is kept and `gt_anns` is never indexed.
///
/// Returns `(position, annotation index)` pairs in input order.
fn gt_local_indices_with_pos(
    gt_local_indices: &[usize],
    gt_anns: &[CocoAnnotation],
    cat: CategoryId,
    use_cats: bool,
) -> Vec<(usize, usize)> {
    let mut pairs = Vec::new();
    for (col, &ann_idx) in gt_local_indices.iter().enumerate() {
        // Only consult the annotation when category filtering is on.
        if use_cats && gt_anns[ann_idx].category_id != cat {
            continue;
        }
        pairs.push((col, ann_idx));
    }
    pairs
}
/// Pairs each selected detection index with its position in
/// `dt_local_indices`.
///
/// With `use_cats` set, only entries whose detection has category `cat` are
/// kept. Otherwise every entry is kept and `dt_anns` is never indexed.
///
/// Returns `(position, detection index)` pairs in input order.
fn dt_local_indices_with_pos(
    dt_local_indices: &[usize],
    dt_anns: &[CocoDetection],
    cat: CategoryId,
    use_cats: bool,
) -> Vec<(usize, usize)> {
    dt_local_indices
        .iter()
        .copied()
        .enumerate()
        .filter_map(|(row, det_idx)| {
            let keep = !use_cats || dt_anns[det_idx].category_id == cat;
            if keep {
                Some((row, det_idx))
            } else {
                None
            }
        })
        .collect()
}
#[allow(clippy::too_many_arguments)]
fn same_class_match_one_category(
gts_in_cat: &[(usize, usize)], dts_in_cat: &[(usize, usize)], gt_anns: &[CocoAnnotation],
dt_anns: &[CocoDetection],
params: &TideParams<'_>,
per_dt_matched: &mut HashMap<usize, bool>,
per_dt_ignore: &mut HashMap<usize, bool>,
gt_taken_by: &mut HashMap<usize, usize>,
) -> Result<(), EvalError> {
if dts_in_cat.is_empty() {
return Ok(());
}
let n_g = gts_in_cat.len();
let n_d = dts_in_cat.len();
let mut iou = Array2::<f64>::zeros((n_g, n_d));
if n_g > 0 {
for (gi_local, &(_, gi)) in gts_in_cat.iter().enumerate() {
let g_box = gt_anns[gi].bbox;
for (di_local, &(_, di)) in dts_in_cat.iter().enumerate() {
let d_box = dt_anns[di].bbox;
iou[(gi_local, di_local)] = bbox_iou_pair(g_box, d_box);
}
}
}
let gt_ignore: Vec<bool> = gts_in_cat
.iter()
.map(|&(_, gi)| {
let g = >_anns[gi];
g.is_crowd || g.ignore_flag.unwrap_or(false)
})
.collect();
let gt_iscrowd: Vec<bool> = gts_in_cat
.iter()
.map(|&(_, gi)| gt_anns[gi].is_crowd)
.collect();
let dt_scores: Vec<f64> = dts_in_cat
.iter()
.map(|&(_, di)| dt_anns[di].score)
.collect();
let single_threshold = [params.t_f];
let MatchResult {
dt_perm,
gt_perm,
dt_matches: dt_matches_pos,
gt_matches: gt_matches_pos,
dt_ignore,
} = match_image(
iou.view(),
>_ignore,
>_iscrowd,
&dt_scores,
&single_threshold,
ParityMode::Strict,
)?;
for (sorted_d, &orig_d) in dt_perm.iter().enumerate() {
let (_row_idx, dt_input_idx) = dts_in_cat[orig_d];
let matched = dt_matches_pos[(0, sorted_d)] >= 0;
let is_ignore = dt_ignore[(0, sorted_d)];
per_dt_matched.insert(dt_input_idx, matched);
per_dt_ignore.insert(dt_input_idx, is_ignore);
}
for (sorted_g, &orig_g) in gt_perm.iter().enumerate() {
let dt_pos = gt_matches_pos[(0, sorted_g)];
if dt_pos < 0 {
continue;
}
if gt_ignore[orig_g] {
continue;
}
let (col_idx, _gt_input_idx) = gts_in_cat[orig_g];
let dt_orig = dt_perm[dt_pos as usize];
let (_row_idx, dt_input_idx) = dts_in_cat[dt_orig];
gt_taken_by.insert(col_idx, dt_input_idx);
}
Ok(())
}
/// Intersection-over-union of two boxes given as `x`, `y`, `w`, `h`.
///
/// The intersection extent is clamped at zero on each axis, so disjoint or
/// merely touching boxes yield `0.0`. Returns `0.0` when the union area is
/// not positive.
fn bbox_iou_pair(g: crate::dataset::Bbox, d: crate::dataset::Bbox) -> f64 {
    // Far corners of each box.
    let (g_right, g_bottom) = (g.x + g.w, g.y + g.h);
    let (d_right, d_bottom) = (d.x + d.w, d.y + d.h);

    // Overlap extent per axis, clamped so disjoint boxes contribute nothing.
    let overlap_w = (g_right.min(d_right) - g.x.max(d.x)).max(0.0);
    let overlap_h = (g_bottom.min(d_bottom) - g.y.max(d.y)).max(0.0);
    let overlap = overlap_w * overlap_h;

    let combined = g.w * g.h + d.w * d.h - overlap;
    if combined <= 0.0 {
        return 0.0;
    }
    overlap / combined
}
/// Finds, for one detection, the highest-IoU ground truth of the same class
/// and the highest-IoU ground truth of a different class.
///
/// * `row_idx` — row of the detection in `cross`.
/// * `dt_cat` — category of the detection.
/// * `cross` — the image's IoU matrix, read as `cross[(row_idx, col)]` where
///   `col` is the position of a ground truth in `gt_local_indices`.
/// * `use_cats` — when false, every ground truth counts as same-class.
///
/// Returns `(iou_same, best_same_col, iou_cross, best_cross_col)`. A column
/// is only selected when its IoU is strictly greater than the running best,
/// which starts at `0.0`. So a zero IoU never selects a column (it stays
/// `-1`) and the earliest column wins ties. Returns `(0.0, -1, 0.0, -1)` when
/// `cross` is `None` or has no columns.
///
/// # Panics
///
/// Panics if `cross` is smaller than `row_idx` / `gt_local_indices` require,
/// or if an entry of `gt_local_indices` is out of range for `gt_anns`.
fn best_same_and_cross(
    row_idx: usize,
    dt_cat: CategoryId,
    cross: Option<ndarray::ArrayView2<'_, f64>>,
    gt_local_indices: &[usize],
    gt_anns: &[CocoAnnotation],
    use_cats: bool,
) -> (f64, i32, f64, i32) {
    let cross = match cross {
        Some(m) => m,
        None => return (0.0, -1, 0.0, -1),
    };
    if cross.ncols() == 0 {
        return (0.0, -1, 0.0, -1);
    }

    let mut iou_same = 0.0_f64;
    let mut best_same: i32 = -1;
    let mut iou_cross = 0.0_f64;
    let mut best_cross: i32 = -1;

    for (col, &gt_input_idx) in gt_local_indices.iter().enumerate() {
        let v = cross[(row_idx, col)];
        let g_cat = gt_anns[gt_input_idx].category_id;
        let same_class = !use_cats || g_cat == dt_cat;
        if same_class {
            if v > iou_same {
                iou_same = v;
                best_same = col as i32;
            }
        } else if v > iou_cross {
            iou_cross = v;
            best_cross = col as i32;
        }
    }
    (iou_same, best_same, iou_cross, best_cross)
}
fn pick_bin(
iou_same: f64,
best_same_col: i32,
iou_cross: f64,
best_cross_col: i32,
t_f: f64,
t_b: f64,
) -> DtBinLabel {
let (bin, target) = if iou_same >= t_f {
(DtBin::Dupe, best_same_col)
} else if iou_cross >= t_f {
(DtBin::Cls, best_cross_col)
} else if iou_same >= t_b && iou_same >= iou_cross {
(DtBin::Loc, best_same_col)
} else if iou_cross >= t_b {
(DtBin::Both, best_cross_col)
} else {
(DtBin::Bkg, -1)
};
DtBinLabel {
bin,
target_gt_local_idx: target,
iou_same,
iou_cross,
}
}