use ndarray::ArrayViewMut2;
use super::Similarity;
use crate::dataset::Bbox;
use crate::error::EvalError;
/// One box annotation scored by [`BboxIou`]: the box geometry plus a crowd flag.
#[derive(Clone, Copy, Debug, PartialEq)]
pub struct BboxAnn {
    /// Box geometry as an origin corner (`x`, `y`) plus width `w` and height `h`.
    pub bbox: Bbox,
    /// On a ground truth, switches the IoU denominator from the union to the
    /// detection's own area. `BboxIou` does not read this flag on detections.
    pub is_crowd: bool,
}
/// Bounding-box IoU similarity over [`BboxAnn`] annotations.
///
/// Stateless unit struct: it carries no configuration, so it is freely
/// copyable and shareable across threads.
#[derive(Clone, Copy, Debug, Default)]
pub struct BboxIou;
impl Similarity for BboxIou {
    type Annotation = BboxAnn;

    /// Writes the IoU of every (ground truth, detection) pair into `out`, where
    /// `out[[g, d]]` scores `gts[g]` against `dts[d]`.
    ///
    /// Crowd ground truths (`is_crowd == true`) divide the intersection by the
    /// detection's area instead of the union, so a detection lying entirely
    /// inside a crowd region scores 1.0. The crowd flag on detections is not
    /// consulted. A pair whose denominator is not strictly positive (e.g. two
    /// zero-area boxes) scores 0.0 instead of producing NaN.
    ///
    /// `out` may have any memory layout (e.g. a transposed view); cells are
    /// visited through ndarray iterators rather than by index.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DimensionMismatch`] when `out` is not
    /// `gts.len() x dts.len()`; `out` is not modified in that case.
    fn compute(
        &self,
        gts: &[BboxAnn],
        dts: &[BboxAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        if out.nrows() != gts.len() || out.ncols() != dts.len() {
            return Err(EvalError::DimensionMismatch {
                detail: format!(
                    "bbox IoU output is {}x{}, expected {}x{}",
                    out.nrows(),
                    out.ncols(),
                    gts.len(),
                    dts.len()
                ),
            });
        }
        // Nothing to fill; bail out before CPU feature detection.
        if gts.is_empty() || dts.is_empty() {
            return Ok(());
        }
        // Run the loops under the instruction set pulp selects for this CPU; the
        // `#[inline(always)]` helpers let that codegen reach the inner loop.
        let arch = pulp::Arch::new();
        arch.dispatch(|| {
            // The shape check above guarantees exactly one row per ground truth
            // and one column per detection, so zipping visits every cell once
            // without a bounds check per write.
            for (mut row, gt) in out.outer_iter_mut().zip(gts) {
                let gxa = gt.bbox.x;
                let gya = gt.bbox.y;
                let gxb = gxa + gt.bbox.w;
                let gyb = gya + gt.bbox.h;
                // Two monomorphised loops keep the crowd/union choice out of the
                // per-detection body.
                if gt.is_crowd {
                    for (cell, dt) in row.iter_mut().zip(dts) {
                        *cell = iou_pair(gxa, gya, gxb, gyb, dt.bbox, CrowdDenom);
                    }
                } else {
                    // GT area is invariant across the row; compute it once.
                    let union = UnionDenom(gt.bbox.w * gt.bbox.h);
                    for (cell, dt) in row.iter_mut().zip(dts) {
                        *cell = iou_pair(gxa, gya, gxb, gyb, dt.bbox, union);
                    }
                }
            }
        });
        Ok(())
    }
}
/// Chooses the denominator of the overlap ratio computed by `iou_pair`.
///
/// Implementors are small `Copy` types; `iou_pair` is generic over them, so the
/// choice is fixed per call site at compile time.
trait Denom: Copy {
    /// Returns the denominator for a detection of area `dt_area` whose
    /// intersection with the ground truth is `overlap`.
    fn denom(self, dt_area: f64, overlap: f64) -> f64;
}

/// Crowd ground truth: the denominator is the detection's own area.
#[derive(Copy, Clone)]
struct CrowdDenom;

impl Denom for CrowdDenom {
    #[inline(always)]
    fn denom(self, dt_area: f64, _overlap: f64) -> f64 {
        dt_area
    }
}

/// Regular ground truth: the denominator is the union area. The wrapped value
/// is the ground truth's own area.
#[derive(Copy, Clone)]
struct UnionDenom(f64);

impl Denom for UnionDenom {
    #[inline(always)]
    fn denom(self, dt_area: f64, overlap: f64) -> f64 {
        let UnionDenom(gt_area) = self;
        // Evaluated as (gt_area + dt_area) - overlap, same order as always.
        gt_area + dt_area - overlap
    }
}
/// Overlap ratio between one ground-truth box, given by its corners
/// (`gxa`, `gya`) and (`gxb`, `gyb`), and the detection box `dt`.
///
/// The intersection area is divided by whatever `denom` yields (union or
/// detection area). When that denominator is not strictly positive, including
/// NaN, the result is 0.0 instead of a division by zero.
#[inline(always)]
fn iou_pair<D: Denom>(gxa: f64, gya: f64, gxb: f64, gyb: f64, dt: Bbox, denom: D) -> f64 {
    // Detection corners and area.
    let (left, top) = (dt.x, dt.y);
    let (right, bottom) = (left + dt.w, top + dt.h);
    let dt_area = dt.w * dt.h;

    // Per-axis overlap, clamped so disjoint or edge-touching boxes give zero.
    let overlap_w = (gxb.min(right) - gxa.max(left)).max(0.0);
    let overlap_h = (gyb.min(bottom) - gya.max(top)).max(0.0);
    let overlap = overlap_w * overlap_h;

    let divisor = denom.denom(dt_area, overlap);
    if divisor > 0.0 {
        return overlap / divisor;
    }
    // Degenerate pair: zero, negative or NaN denominator.
    0.0
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds an annotation from raw box coordinates and a crowd flag.
    fn make_ann(x: f64, y: f64, w: f64, h: f64, is_crowd: bool) -> BboxAnn {
        BboxAnn {
            bbox: Bbox { x, y, w, h },
            is_crowd,
        }
    }

    /// Runs `BboxIou` into a freshly zeroed `gts.len() x dts.len()` matrix.
    fn compute(gts: &[BboxAnn], dts: &[BboxAnn]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
        BboxIou.compute(gts, dts, &mut out.view_mut()).unwrap();
        out
    }

    #[test]
    fn perfect_overlap_is_one() {
        let truths = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let dts = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let m = compute(&truths, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
    }

    #[test]
    fn no_overlap_is_zero() {
        let truths = [make_ann(0.0, 0.0, 1.0, 1.0, false)];
        let dts = [make_ann(10.0, 10.0, 1.0, 1.0, false)];
        let m = compute(&truths, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn i4_edge_sharing_is_zero() {
        let truths = [make_ann(0.0, 0.0, 1.0, 1.0, false)];
        let dts = [make_ann(1.0, 0.0, 1.0, 1.0, false)];
        let m = compute(&truths, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn quarter_overlap_matches_hand_traced_value() {
        let truths = [make_ann(0.0, 0.0, 2.0, 2.0, false)];
        let dts = [make_ann(1.0, 1.0, 2.0, 2.0, false)];
        let m = compute(&truths, &dts);
        let expected = 1.0_f64 / 7.0_f64;
        assert_eq!(m[[0, 0]].to_bits(), expected.to_bits());
    }

    #[test]
    fn e1_crowd_gt_uses_dt_area_denominator() {
        let crowd_truths = [make_ann(0.0, 0.0, 10.0, 10.0, true)];
        let plain_truths = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let dts = [make_ann(2.0, 2.0, 1.0, 1.0, false)];
        let crowd_m = compute(&crowd_truths, &dts);
        let normal_m = compute(&plain_truths, &dts);
        assert_eq!(crowd_m[[0, 0]].to_bits(), 1.0_f64.to_bits());
        let expected_normal = 1.0_f64 / 100.0_f64;
        assert_eq!(normal_m[[0, 0]].to_bits(), expected_normal.to_bits());
    }

    #[test]
    fn dt_iscrowd_flag_is_ignored() {
        let truths = [make_ann(0.0, 0.0, 2.0, 2.0, false)];
        let dts_marked = [make_ann(1.0, 1.0, 2.0, 2.0, true)];
        let dts_clean = [make_ann(1.0, 1.0, 2.0, 2.0, false)];
        let with_flag = compute(&truths, &dts_marked);
        let without = compute(&truths, &dts_clean);
        assert_eq!(with_flag[[0, 0]].to_bits(), without[[0, 0]].to_bits());
    }

    #[test]
    fn zero_area_gt_with_zero_inter_yields_zero_not_nan() {
        let truths = [make_ann(5.0, 5.0, 0.0, 5.0, false)];
        let dts = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let m = compute(&truths, &dts);
        assert!(m[[0, 0]].is_finite());
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn zero_area_gt_and_dt_both_zero_yields_zero_via_denom_guard() {
        let truths = [make_ann(5.0, 5.0, 0.0, 0.0, false)];
        let dts = [make_ann(5.0, 5.0, 0.0, 0.0, false)];
        let m = compute(&truths, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn dimension_mismatch_returns_typed_error() {
        let truths = [make_ann(0.0, 0.0, 1.0, 1.0, false); 2];
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::zeros((1, 1));
        let err = BboxIou
            .compute(&truths, &dts, &mut out.view_mut())
            .unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("2"));
                assert!(detail.contains("3"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn column_only_mismatch_is_rejected_and_output_untouched() {
        let truths = [make_ann(0.0, 0.0, 1.0, 1.0, false); 2];
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::from_elem((2, 2), 7.0);
        let err = BboxIou
            .compute(&truths, &dts, &mut out.view_mut())
            .unwrap_err();
        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
        assert!(out.iter().all(|v| v.to_bits() == 7.0_f64.to_bits()));
    }

    #[test]
    fn empty_inputs_return_unchanged_matrix() {
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::from_elem((0, 3), 7.0);
        BboxIou.compute(&[], &dts, &mut out.view_mut()).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    #[test]
    fn empty_detections_with_ground_truths_return_ok() {
        let truths = [make_ann(0.0, 0.0, 1.0, 1.0, false); 2];
        let mut out = Array2::<f64>::from_elem((2, 0), 7.0);
        BboxIou.compute(&truths, &[], &mut out.view_mut()).unwrap();
        assert_eq!(out.shape(), &[2, 0]);
    }

    #[test]
    fn three_by_three_matrix_all_pairs_evaluated() {
        let truths = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(5.0, 5.0, 2.0, 2.0, false),
            make_ann(0.0, 0.0, 10.0, 10.0, true),
        ];
        let dts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(1.0, 1.0, 2.0, 2.0, false),
            make_ann(20.0, 20.0, 1.0, 1.0, false),
        ];
        let m = compute(&truths, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[0, 1]].to_bits(), (1.0_f64 / 7.0_f64).to_bits());
        assert_eq!(m[[0, 2]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 0]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 1]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 2]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[2, 0]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[2, 1]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[2, 2]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn transposed_output_view_matches_contiguous_result() {
        let truths = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(0.0, 0.0, 10.0, 10.0, true),
        ];
        let dts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(1.0, 1.0, 2.0, 2.0, false),
            make_ann(20.0, 20.0, 1.0, 1.0, false),
        ];
        // Storage is dts x gts; reversing the axes yields a gts x dts view with
        // non-standard strides.
        let mut storage = Array2::<f64>::zeros((dts.len(), truths.len()));
        BboxIou
            .compute(&truths, &dts, &mut storage.view_mut().reversed_axes())
            .unwrap();
        let expected = compute(&truths, &dts);
        assert_eq!(storage.t(), expected.view());
    }

    #[test]
    fn impl_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BboxIou>();
    }
}