use ndarray::{Array2, ArrayView2};
use crate::error::EvalError;
use crate::parity::{argsort_score_desc, ParityMode, IOU_BOUNDARY_EPS};
/// Per-image matching outcome, as populated by `match_image`.
///
/// All match indices stored here refer to *sorted* positions (positions in
/// `dt_perm` / `gt_perm`), not to the caller's original input order. Use the
/// permutation vectors to translate back to original indices.
#[derive(Debug, Clone)]
pub(crate) struct MatchResult {
/// Detection order used for matching: `dt_perm[k]` is the original index of
/// the detection processed at sorted position `k` (the result of
/// `argsort_score_desc` on the detection scores).
pub dt_perm: Vec<usize>,
/// Ground-truth order used for matching: `gt_perm[k]` is the original index
/// of the ground truth at sorted position `k`. Non-ignored ground truths
/// come first, ignored ones last, each group keeping its input order.
pub gt_perm: Vec<usize>,
/// One row per IoU threshold, one column per detection (in `dt_perm` order).
/// Holds the sorted ground-truth position the detection was matched to, or
/// `-1` when the detection is unmatched at that threshold.
pub dt_matches: Array2<i64>,
/// One row per IoU threshold, one column per ground truth (in `gt_perm`
/// order). Holds the sorted detection position most recently matched to this
/// ground truth, or `-1` when nothing matched it at that threshold.
pub gt_matches: Array2<i64>,
/// Same layout as `dt_matches`. `true` when the detection was matched to a
/// ground truth flagged as ignored; `false` otherwise (including unmatched).
pub dt_ignore: Array2<bool>,
}
/// Greedily matches detections to ground truths for one image, independently
/// for every IoU threshold.
///
/// # Arguments
/// * `iou_matrix` - IoU values indexed as `(ground_truth, detection)` in the
///   caller's original order; must be `gt_ignore.len()` x `dt_scores.len()`.
/// * `gt_ignore` - per ground truth, whether it is flagged as ignored.
/// * `gt_iscrowd` - per ground truth, whether it is a crowd region; a crowd
///   ground truth may be matched by more than one detection.
/// * `dt_scores` - per detection score; detections are visited in the order
///   returned by `argsort_score_desc`.
/// * `iou_thresholds` - thresholds to evaluate; each yields one row in the
///   output matrices.
/// * `_parity_mode` - currently unused by this function.
///
/// # Returns
/// A [`MatchResult`] whose matrices are indexed by *sorted* positions (see
/// `dt_perm` / `gt_perm`), with `-1` marking "no match".
///
/// # Errors
/// Returns `EvalError::DimensionMismatch` when `gt_iscrowd` and `gt_ignore`
/// differ in length, or when `iou_matrix` does not have the expected shape.
// The view and the mode flag are taken by value as part of the existing signature.
#[allow(clippy::needless_pass_by_value)]
pub(crate) fn match_image(
    iou_matrix: ArrayView2<'_, f64>,
    gt_ignore: &[bool],
    gt_iscrowd: &[bool],
    dt_scores: &[f64],
    iou_thresholds: &[f64],
    _parity_mode: ParityMode,
) -> Result<MatchResult, EvalError> {
    let num_gt = gt_ignore.len();
    let num_dt = dt_scores.len();
    let num_thr = iou_thresholds.len();

    // --- Input validation -------------------------------------------------
    if gt_iscrowd.len() != num_gt {
        let detail = format!(
            "gt_iscrowd len {} does not match gt_ignore len {}",
            gt_iscrowd.len(),
            num_gt
        );
        return Err(EvalError::DimensionMismatch { detail });
    }
    let (rows, cols) = (iou_matrix.nrows(), iou_matrix.ncols());
    if (rows, cols) != (num_gt, num_dt) {
        let detail = format!(
            "iou_matrix is {}x{}, expected {}x{}",
            rows, cols, num_gt, num_dt
        );
        return Err(EvalError::DimensionMismatch { detail });
    }

    // --- Visiting orders --------------------------------------------------
    let dt_perm = argsort_score_desc(dt_scores);

    // Non-ignored ground truths first, ignored ones last; each group keeps
    // its input order (same result as a stable sort on the ignore flag).
    let (kept, ignored): (Vec<usize>, Vec<usize>) =
        (0..num_gt).partition(|&g| !gt_ignore[g]);
    let gt_perm: Vec<usize> = kept.into_iter().chain(ignored).collect();

    let (gt_ignore_sorted, gt_iscrowd_sorted): (Vec<bool>, Vec<bool>) = gt_perm
        .iter()
        .map(|&g| (gt_ignore[g], gt_iscrowd[g]))
        .unzip();

    // --- Output buffers (-1 / false mean "unmatched") -----------------------
    let mut dt_matches = Array2::<i64>::from_elem((num_thr, num_dt), -1);
    let mut gt_matches = Array2::<i64>::from_elem((num_thr, num_gt), -1);
    let mut dt_ignore = Array2::<bool>::default((num_thr, num_dt));

    let has_work = num_gt > 0 && num_dt > 0 && num_thr > 0;
    if has_work {
        for (thr_idx, &thr) in iou_thresholds.iter().enumerate() {
            // Cap the starting bar just below 1.0 so a threshold of 1.0 can
            // still be reached by an IoU of exactly 1.0.
            let starting_bar = thr.min(1.0 - IOU_BOUNDARY_EPS);

            for (dt_rank, &dt_orig) in dt_perm.iter().enumerate() {
                let mut bar = starting_bar;
                let mut chosen: Option<usize> = None;

                for (gt_rank, &gt_orig) in gt_perm.iter().enumerate() {
                    // A non-crowd ground truth can only be claimed once per threshold.
                    let already_claimed = gt_matches[(thr_idx, gt_rank)] >= 0;
                    if already_claimed && !gt_iscrowd_sorted[gt_rank] {
                        continue;
                    }
                    // Once a real (non-ignored) ground truth is held, stop as
                    // soon as the ignored tail of the ordering is reached.
                    if let Some(held) = chosen {
                        if !gt_ignore_sorted[held] && gt_ignore_sorted[gt_rank] {
                            break;
                        }
                    }
                    let overlap = iou_matrix[(gt_orig, dt_orig)];
                    if overlap < bar {
                        continue;
                    }
                    // Equal-or-better overlap wins, so later candidates replace
                    // earlier ones on ties.
                    bar = overlap;
                    chosen = Some(gt_rank);
                }

                if let Some(hit) = chosen {
                    dt_ignore[(thr_idx, dt_rank)] = gt_ignore_sorted[hit];
                    dt_matches[(thr_idx, dt_rank)] = hit as i64;
                    gt_matches[(thr_idx, hit)] = dt_rank as i64;
                }
            }
        }
    }

    Ok(MatchResult {
        dt_perm,
        gt_perm,
        dt_matches,
        gt_matches,
        dt_ignore,
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::array;

    /// Calls `match_image` in `ParityMode::Strict` and unwraps the result.
    fn match_strict(
        iou: &Array2<f64>,
        gt_ignore: &[bool],
        gt_iscrowd: &[bool],
        dt_scores: &[f64],
        thresholds: &[f64],
    ) -> MatchResult {
        let outcome = match_image(
            iou.view(),
            gt_ignore,
            gt_iscrowd,
            dt_scores,
            thresholds,
            ParityMode::Strict,
        );
        outcome.unwrap()
    }

    /// Identity IoU matrix: each detection pairs with its own ground truth.
    #[test]
    fn perfect_match_all_dts_pair_with_distinct_gts() {
        let overlaps = array![[1.0, 0.0], [0.0, 1.0]];
        let out = match_strict(&overlaps, &[false; 2], &[false; 2], &[0.9, 0.8], &[0.5]);

        assert_eq!(out.dt_perm, vec![0, 1]);
        assert_eq!(out.gt_perm, vec![0, 1]);
        for k in 0..2usize {
            let want = k as i64;
            assert_eq!(out.dt_matches[(0, k)], want);
            assert_eq!(out.gt_matches[(0, k)], want);
            assert!(!out.dt_ignore[(0, k)]);
        }
    }

    /// All-zero IoU leaves every detection and ground truth unmatched.
    #[test]
    fn no_overlap_yields_no_matches() {
        let overlaps = array![[0.0, 0.0], [0.0, 0.0]];
        let out = match_strict(&overlaps, &[false; 2], &[false; 2], &[0.9, 0.8], &[0.5]);

        assert!(!out.dt_matches.iter().any(|&v| v != -1));
        assert!(!out.gt_matches.iter().any(|&v| v != -1));
    }

    /// An IoU equal to the threshold counts as a match.
    #[test]
    fn b1_iou_exactly_at_threshold_matches() {
        let overlaps = array![[0.5]];
        let out = match_strict(&overlaps, &[false], &[false], &[0.9], &[0.5]);
        assert_eq!(out.dt_matches[(0, 0)], 0);
    }

    /// A threshold of 1.0 is still reachable by an IoU of exactly 1.0.
    #[test]
    fn b1_iou_at_one_still_matches() {
        let overlaps = array![[1.0]];
        let out = match_strict(&overlaps, &[false], &[false], &[0.9], &[1.0]);
        assert_eq!(out.dt_matches[(0, 0)], 0);
    }

    /// A crowd ground truth can absorb several detections; it records the last one.
    #[test]
    fn b4_crowd_gt_matches_many_dts() {
        let overlaps = array![[1.0, 1.0]];
        let out = match_strict(&overlaps, &[false], &[true], &[0.9, 0.8], &[0.5]);

        for k in 0..2usize {
            assert_eq!(out.dt_matches[(0, k)], 0);
        }
        assert_eq!(out.gt_matches[(0, 0)], 1);
    }

    /// Equal scores keep the detections in their input order.
    #[test]
    fn a1_score_ties_resolve_to_input_order() {
        let overlaps = array![[0.9, 0.6], [0.6, 0.9]];
        let out = match_strict(&overlaps, &[false; 2], &[false; 2], &[0.7, 0.7], &[0.5]);

        assert_eq!(out.dt_perm, vec![0, 1]);
        let expected: [i64; 2] = [0, 1];
        for (k, &want) in expected.iter().enumerate() {
            assert_eq!(out.dt_matches[(0, k)], want);
        }
    }

    /// After matching a real ground truth, an equally good ignored one must not take over.
    #[test]
    fn b3_ignore_gt_terminates_inner_loop_after_real_match() {
        let overlaps = array![[0.8], [0.8]];
        let out = match_strict(&overlaps, &[false, true], &[false; 2], &[0.9], &[0.5]);

        assert_eq!(out.dt_matches[(0, 0)], 0);
        assert!(!out.dt_ignore[(0, 0)]);
    }

    /// A detection matched to an ignored ground truth is itself flagged as ignored.
    #[test]
    fn b6_dt_matched_to_ignore_inherits_flag() {
        let overlaps = array![[0.8]];
        let out = match_strict(&overlaps, &[true], &[false], &[0.9], &[0.5]);

        assert_eq!(out.dt_matches[(0, 0)], 0);
        assert!(out.dt_ignore[(0, 0)]);
    }

    /// Ignored ground truths are moved behind non-ignored ones whatever the input order.
    #[test]
    fn a4_gt_sort_puts_ignore_at_tail_regardless_of_input_order() {
        let overlaps = array![[0.0], [0.9]];
        let out = match_strict(&overlaps, &[true, false], &[false; 2], &[0.9], &[0.5]);

        assert_eq!(out.gt_perm, vec![1, 0]);
        assert_eq!(out.dt_matches[(0, 0)], 0);
        assert!(!out.dt_ignore[(0, 0)]);
    }

    /// Each threshold row is computed on its own: 0.6 passes up to 0.6 and fails at 0.65.
    #[test]
    fn multiple_thresholds_accumulate_independently() {
        let overlaps = array![[0.6]];
        let thresholds = [0.5, 0.55, 0.6, 0.65];
        let out = match_strict(&overlaps, &[false], &[false], &[0.9], &thresholds);

        let expected: [i64; 4] = [0, 0, 0, -1];
        for (row, &want) in expected.iter().enumerate() {
            assert_eq!(out.dt_matches[(row, 0)], want);
        }
    }

    /// No detections and no ground truths: one threshold row, zero columns.
    #[test]
    fn empty_inputs_return_empty_arrays() {
        let overlaps = Array2::<f64>::zeros((0, 0));
        let out = match_strict(&overlaps, &[], &[], &[], &[0.5]);

        assert_eq!(out.dt_matches.shape(), &[1, 0]);
        assert_eq!(out.gt_matches.shape(), &[1, 0]);
        assert!(out.dt_perm.is_empty());
        assert!(out.gt_perm.is_empty());
    }

    /// No thresholds: output matrices have zero rows but keep their column count.
    #[test]
    fn empty_thresholds_yield_zero_row_matrices() {
        let overlaps = array![[0.9]];
        let out = match_strict(&overlaps, &[false], &[false], &[0.9], &[]);

        assert_eq!(out.dt_matches.shape(), &[0, 1]);
        assert_eq!(out.gt_matches.shape(), &[0, 1]);
    }

    /// A 1x1 IoU matrix with two ground truths is rejected with a message naming the matrix.
    #[test]
    fn iou_matrix_dimension_mismatch_is_typed_error() {
        let overlaps = Array2::<f64>::zeros((1, 1));
        let outcome = match_image(
            overlaps.view(),
            &[false, false],
            &[false, false],
            &[0.9],
            &[0.5],
            ParityMode::Strict,
        );
        let err = outcome.unwrap_err();

        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("iou_matrix"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    /// `gt_iscrowd` shorter than `gt_ignore` is rejected as a dimension mismatch.
    #[test]
    fn gt_iscrowd_length_mismatch_is_typed_error() {
        let overlaps = Array2::<f64>::zeros((2, 1));
        let outcome = match_image(
            overlaps.view(),
            &[false, false],
            &[false],
            &[0.9],
            &[0.5],
            ParityMode::Strict,
        );
        let err = outcome.unwrap_err();

        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
    }
}