use std::sync::OnceLock;
use ndarray::ArrayViewMut2;
use super::Similarity;
use crate::dataset::Bbox;
use crate::error::EvalError;
/// Returns the process-wide [`pulp::Arch`] handle.
///
/// The value is built by `pulp::Arch::new` on first use and cached in a
/// `OnceLock`, so every later call returns the same reference without
/// re-running the constructor.
fn arch() -> &'static pulp::Arch {
    static DETECTED_ARCH: OnceLock<pulp::Arch> = OnceLock::new();
    DETECTED_ARCH.get_or_init(pulp::Arch::new)
}
/// Cutoff on the number of output cells (`gts.len() * dts.len()`).
///
/// Below this count the kernels are called directly; at or above it they are
/// invoked through `arch().dispatch`.
const SMALL_CELL_THRESHOLD: usize = 32;
/// Per-call wall-clock capture for the bbox kernels.
///
/// Only compiled when the `bench-histogram` feature is enabled. Each kernel
/// call creates a [`histogram::CallTimer`]; dropping it appends one sample to
/// a process-wide buffer that [`histogram::dump_csv`] writes out and drains.
#[cfg(feature = "bench-histogram")]
pub(crate) mod histogram {
    use std::path::Path;
    use std::sync::Mutex;
    use std::time::Instant;

    /// Kernel a timing sample was taken from.
    #[derive(Clone, Copy)]
    pub(crate) enum KernelKind {
        /// The full IoU matrix kernel.
        FullIou,
        /// The binary overlap-mask kernel.
        OverlapMask,
    }

    impl KernelKind {
        /// Text written to the `kind` column of the CSV dump.
        fn label(self) -> &'static str {
            match self {
                Self::FullIou => "FullIou",
                Self::OverlapMask => "OverlapMask",
            }
        }
    }

    /// One completed kernel call: which kernel, the input sizes, and its duration.
    #[derive(Clone, Copy)]
    struct Record {
        kind: KernelKind,
        g: u32,
        d: u32,
        wall_ns: u64,
    }

    /// Process-wide sample buffer: appended to when a `CallTimer` is dropped,
    /// drained by `dump_csv`.
    static RECORDS: Mutex<Vec<Record>> = Mutex::new(Vec::new());

    /// Guard that measures the time between its construction and its drop.
    pub(super) struct CallTimer {
        kind: KernelKind,
        g: u32,
        d: u32,
        start: Instant,
    }

    impl CallTimer {
        /// Starts timing a call to `kind` over `g` ground truths and `d` detections.
        ///
        /// Sizes that do not fit in `u32` are recorded as `u32::MAX`.
        pub(super) fn new(kind: KernelKind, g: usize, d: usize) -> Self {
            Self {
                kind,
                g: u32::try_from(g).unwrap_or(u32::MAX),
                d: u32::try_from(d).unwrap_or(u32::MAX),
                start: Instant::now(),
            }
        }
    }

    impl Drop for CallTimer {
        /// Records the elapsed time, saturating at `u64::MAX` nanoseconds.
        ///
        /// Best effort: if the buffer's lock is poisoned the sample is dropped
        /// rather than panicking inside `drop`.
        fn drop(&mut self) {
            let elapsed = self.start.elapsed().as_nanos();
            let wall_ns = u64::try_from(elapsed).unwrap_or(u64::MAX);
            let record = Record {
                kind: self.kind,
                g: self.g,
                d: self.d,
                wall_ns,
            };
            if let Ok(mut records) = RECORDS.lock() {
                records.push(record);
            }
        }
    }

    /// Writes every buffered sample to `path` as CSV and empties the buffer.
    ///
    /// The file is created (or truncated) and starts with the header
    /// `kind,g,d,wall_ns`, followed by one line per sample. Returns the number
    /// of samples written.
    ///
    /// # Errors
    ///
    /// Returns the underlying I/O error if the file cannot be created, written
    /// or flushed. On any error the in-memory samples are left untouched so the
    /// dump can be retried.
    pub(crate) fn dump_csv(path: &Path) -> std::io::Result<usize> {
        use std::io::Write;

        // A poisoned lock still guards a usable Vec, so recover it.
        let mut records = RECORDS.lock().unwrap_or_else(|p| p.into_inner());
        let mut file = std::io::BufWriter::new(std::fs::File::create(path)?);
        writeln!(file, "kind,g,d,wall_ns")?;
        for r in records.iter() {
            writeln!(file, "{},{},{},{}", r.kind.label(), r.g, r.d, r.wall_ns)?;
        }
        // Flush BEFORE clearing: buffered bytes may not have been written yet,
        // and a failed flush must not discard samples that were never persisted.
        file.flush()?;
        let n = records.len();
        records.clear();
        Ok(n)
    }

    /// Number of samples currently buffered (test helper).
    #[cfg(test)]
    pub(super) fn len() -> usize {
        let records = RECORDS.lock().unwrap_or_else(|p| p.into_inner());
        records.len()
    }
}
/// A bounding box together with its crowd flag; the annotation type consumed by [`BboxIou`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BboxAnn {
/// Box geometry. The kernels in this file read its `x`, `y`, `w` and `h` fields
/// and treat `(x + w, y + h)` as the opposite corner.
pub bbox: Bbox,
/// Crowd marker. On a ground truth it makes the IoU kernel divide the
/// intersection by the detection's area instead of the union; the kernels do
/// not read it on detections.
pub is_crowd: bool,
}
/// Stateless bounding-box IoU kernel; implements [`Similarity`] over [`BboxAnn`] annotations.
#[derive(Debug, Default, Clone, Copy)]
pub struct BboxIou;
impl Similarity for BboxIou {
    type Annotation = BboxAnn;

    /// Fills `out[[g, d]]` with the IoU of `gts[g]` against `dts[d]`.
    ///
    /// A ground truth flagged `is_crowd` divides the intersection by the
    /// detection's area rather than by the union. When either slice is empty
    /// the call succeeds without writing anything.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DimensionMismatch`] when `out` is not exactly
    /// `gts.len()` rows by `dts.len()` columns.
    fn compute(
        &self,
        gts: &[BboxAnn],
        dts: &[BboxAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        let (n_gt, n_dt) = (gts.len(), dts.len());
        if (out.nrows(), out.ncols()) != (n_gt, n_dt) {
            return Err(EvalError::DimensionMismatch {
                detail: format!(
                    "bbox IoU output is {}x{}, expected {}x{}",
                    out.nrows(),
                    out.ncols(),
                    n_gt,
                    n_dt
                ),
            });
        }
        if n_gt == 0 || n_dt == 0 {
            return Ok(());
        }

        // Dropped at the end of this function, which is when the sample is recorded.
        #[cfg(feature = "bench-histogram")]
        let _timer = histogram::CallTimer::new(histogram::KernelKind::FullIou, n_gt, n_dt);

        // Tiny matrices call the kernel directly; larger ones go through the
        // cached `pulp::Arch` dispatcher.
        let cell_count = n_gt.saturating_mul(n_dt);
        if cell_count >= SMALL_CELL_THRESHOLD {
            arch().dispatch(|| full_iou_inner(gts, dts, out));
        } else {
            full_iou_inner(gts, dts, out);
        }
        Ok(())
    }
}
#[inline(always)]
fn full_iou_inner(gts: &[BboxAnn], dts: &[BboxAnn], out: &mut ArrayViewMut2<'_, f64>) {
for (g, gt) in gts.iter().enumerate() {
let gxa = gt.bbox.x;
let gya = gt.bbox.y;
let gw = gt.bbox.w;
let gh = gt.bbox.h;
let gxb = gxa + gw;
let gyb = gya + gh;
let g_area = gw * gh;
let mut row = out.row_mut(g);
if gt.is_crowd {
for (d, dt) in dts.iter().enumerate() {
row[d] = iou_pair(gxa, gya, gxb, gyb, dt.bbox, CrowdDenom);
}
} else {
for (d, dt) in dts.iter().enumerate() {
row[d] = iou_pair(gxa, gya, gxb, gyb, dt.bbox, UnionDenom(g_area));
}
}
}
}
impl BboxIou {
    /// Fills `out[[g, d]]` with `1.0` when `gts[g]` and `dts[d]` overlap with
    /// strictly positive width and height, and `0.0` otherwise.
    ///
    /// When either slice is empty the call succeeds without writing anything.
    ///
    /// # Errors
    ///
    /// Returns [`EvalError::DimensionMismatch`] when `out` is not exactly
    /// `gts.len()` rows by `dts.len()` columns.
    pub(super) fn compute_overlap_mask(
        &self,
        gts: &[BboxAnn],
        dts: &[BboxAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        let (n_gt, n_dt) = (gts.len(), dts.len());
        if (out.nrows(), out.ncols()) != (n_gt, n_dt) {
            return Err(EvalError::DimensionMismatch {
                detail: format!(
                    "bbox overlap-mask output is {}x{}, expected {}x{}",
                    out.nrows(),
                    out.ncols(),
                    n_gt,
                    n_dt
                ),
            });
        }
        if n_gt == 0 || n_dt == 0 {
            return Ok(());
        }

        // Dropped at the end of this function, which is when the sample is recorded.
        #[cfg(feature = "bench-histogram")]
        let _timer = histogram::CallTimer::new(histogram::KernelKind::OverlapMask, n_gt, n_dt);

        // Tiny matrices call the kernel directly; larger ones go through the
        // cached `pulp::Arch` dispatcher.
        let cell_count = n_gt.saturating_mul(n_dt);
        if cell_count >= SMALL_CELL_THRESHOLD {
            arch().dispatch(|| overlap_mask_inner(gts, dts, out));
        } else {
            overlap_mask_inner(gts, dts, out);
        }
        Ok(())
    }
}
#[inline(always)]
fn overlap_mask_inner(gts: &[BboxAnn], dts: &[BboxAnn], out: &mut ArrayViewMut2<'_, f64>) {
for (g, gt) in gts.iter().enumerate() {
let gxa = gt.bbox.x;
let gya = gt.bbox.y;
let gxb = gxa + gt.bbox.w;
let gyb = gya + gt.bbox.h;
let mut row = out.row_mut(g);
for (d, dt) in dts.iter().enumerate() {
let dxa = dt.bbox.x;
let dya = dt.bbox.y;
let dxb = dxa + dt.bbox.w;
let dyb = dya + dt.bbox.h;
let iw = gxb.min(dxb) - gxa.max(dxa);
let ih = gyb.min(dyb) - gya.max(dya);
row[d] = if iw > 0.0 && ih > 0.0 { 1.0 } else { 0.0 };
}
}
}
/// Strategy for the denominator of an IoU ratio.
///
/// Implementations are zero- or one-field `Copy` values so the choice is made
/// at compile time through the generic parameter of `iou_pair`.
trait Denom: Copy {
    /// Returns the divisor for one pair, given the detection's area and the
    /// intersection area.
    fn denom(self, d_area: f64, inter: f64) -> f64;
}

/// Denominator for crowd ground truths: the detection's own area.
#[derive(Clone, Copy)]
struct CrowdDenom;

impl Denom for CrowdDenom {
    #[inline(always)]
    fn denom(self, det_area: f64, _overlap: f64) -> f64 {
        det_area
    }
}

/// Denominator for regular ground truths: the area of the union.
///
/// The wrapped value is the ground truth's area.
#[derive(Clone, Copy)]
struct UnionDenom(f64);

impl Denom for UnionDenom {
    #[inline(always)]
    fn denom(self, det_area: f64, overlap: f64) -> f64 {
        let Self(gt_area) = self;
        gt_area + det_area - overlap
    }
}
#[inline(always)]
fn iou_pair<D: Denom>(gxa: f64, gya: f64, gxb: f64, gyb: f64, dt: Bbox, denom: D) -> f64 {
let dxa = dt.x;
let dya = dt.y;
let dw = dt.w;
let dh = dt.h;
let dxb = dxa + dw;
let dyb = dya + dh;
let d_area = dw * dh;
let iw = (gxb.min(dxb) - gxa.max(dxa)).max(0.0);
let ih = (gyb.min(dyb) - gya.max(dya)).max(0.0);
let inter = iw * ih;
let denom = denom.denom(d_area, inter);
if denom > 0.0 {
inter / denom
} else {
0.0
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Builds an annotation from origin, size and crowd flag.
    fn make_ann(x: f64, y: f64, w: f64, h: f64, is_crowd: bool) -> BboxAnn {
        BboxAnn {
            bbox: Bbox { x, y, w, h },
            is_crowd,
        }
    }

    /// Runs the full IoU kernel into a freshly allocated matrix.
    fn compute(gts: &[BboxAnn], dts: &[BboxAnn]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
        BboxIou.compute(gts, dts, &mut out.view_mut()).unwrap();
        out
    }

    #[test]
    fn perfect_overlap_is_one() {
        let gts = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let dts = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let m = compute(&gts, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
    }

    #[test]
    fn no_overlap_is_zero() {
        let gts = [make_ann(0.0, 0.0, 1.0, 1.0, false)];
        let dts = [make_ann(10.0, 10.0, 1.0, 1.0, false)];
        let m = compute(&gts, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn i4_edge_sharing_is_zero() {
        // The boxes touch along x = 1 but share no area.
        let gts = [make_ann(0.0, 0.0, 1.0, 1.0, false)];
        let dts = [make_ann(1.0, 0.0, 1.0, 1.0, false)];
        let m = compute(&gts, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn quarter_overlap_matches_hand_traced_value() {
        // Intersection 1x1 = 1; union 4 + 4 - 1 = 7.
        let gts = [make_ann(0.0, 0.0, 2.0, 2.0, false)];
        let dts = [make_ann(1.0, 1.0, 2.0, 2.0, false)];
        let m = compute(&gts, &dts);
        let expected = 1.0_f64 / 7.0_f64;
        assert_eq!(m[[0, 0]].to_bits(), expected.to_bits());
    }

    #[test]
    fn e1_crowd_gt_uses_dt_area_denominator() {
        let gts_crowd = [make_ann(0.0, 0.0, 10.0, 10.0, true)];
        let gts_normal = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let dts = [make_ann(2.0, 2.0, 1.0, 1.0, false)];
        let crowd_m = compute(&gts_crowd, &dts);
        let normal_m = compute(&gts_normal, &dts);
        // Crowd: intersection 1 / detection area 1.
        assert_eq!(crowd_m[[0, 0]].to_bits(), 1.0_f64.to_bits());
        // Normal: intersection 1 / union (100 + 1 - 1).
        let expected_normal = 1.0_f64 / 100.0_f64;
        assert_eq!(normal_m[[0, 0]].to_bits(), expected_normal.to_bits());
    }

    #[test]
    fn dt_iscrowd_flag_is_ignored() {
        let gts = [make_ann(0.0, 0.0, 2.0, 2.0, false)];
        let dts_marked = [make_ann(1.0, 1.0, 2.0, 2.0, true)];
        let dts_clean = [make_ann(1.0, 1.0, 2.0, 2.0, false)];
        let with_flag = compute(&gts, &dts_marked);
        let without = compute(&gts, &dts_clean);
        assert_eq!(with_flag[[0, 0]].to_bits(), without[[0, 0]].to_bits());
    }

    #[test]
    fn zero_area_gt_with_zero_inter_yields_zero_not_nan() {
        // Zero-width ground truth lying inside the detection.
        let gts = [make_ann(5.0, 5.0, 0.0, 5.0, false)];
        let dts = [make_ann(0.0, 0.0, 10.0, 10.0, false)];
        let m = compute(&gts, &dts);
        assert!(m[[0, 0]].is_finite());
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn zero_area_gt_and_dt_both_zero_yields_zero_via_denom_guard() {
        // Both boxes are points, so the denominator is zero.
        let gts = [make_ann(5.0, 5.0, 0.0, 0.0, false)];
        let dts = [make_ann(5.0, 5.0, 0.0, 0.0, false)];
        let m = compute(&gts, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn dimension_mismatch_returns_typed_error() {
        let gts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 2];
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::zeros((1, 1));
        let err = BboxIou
            .compute(&gts, &dts, &mut out.view_mut())
            .unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("2"));
                assert!(detail.contains("3"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn empty_inputs_return_unchanged_matrix() {
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::from_elem((0, 3), 7.0);
        BboxIou.compute(&[], &dts, &mut out.view_mut()).unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    #[test]
    fn three_by_three_matrix_all_pairs_evaluated() {
        let gts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(5.0, 5.0, 2.0, 2.0, false),
            make_ann(0.0, 0.0, 10.0, 10.0, true),
        ];
        let dts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(1.0, 1.0, 2.0, 2.0, false),
            make_ann(20.0, 20.0, 1.0, 1.0, false),
        ];
        let m = compute(&gts, &dts);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[0, 1]].to_bits(), (1.0_f64 / 7.0_f64).to_bits());
        assert_eq!(m[[0, 2]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 0]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 1]].to_bits(), 0.0_f64.to_bits());
        assert_eq!(m[[1, 2]].to_bits(), 0.0_f64.to_bits());
        // Crowd row: each overlapping detection lies fully inside the crowd box.
        assert_eq!(m[[2, 0]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[2, 1]].to_bits(), 1.0_f64.to_bits());
        assert_eq!(m[[2, 2]].to_bits(), 0.0_f64.to_bits());
    }

    #[test]
    fn impl_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BboxIou>();
    }

    /// Runs the overlap-mask kernel into a freshly allocated matrix.
    fn overlap_mask(gts: &[BboxAnn], dts: &[BboxAnn]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
        BboxIou
            .compute_overlap_mask(gts, dts, &mut out.view_mut())
            .unwrap();
        out
    }

    #[test]
    fn overlap_mask_writes_only_zero_or_one_sentinels() {
        let gts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(5.0, 5.0, 2.0, 2.0, false),
            make_ann(0.0, 0.0, 10.0, 10.0, true),
            make_ann(5.0, 5.0, 0.0, 5.0, false),
        ];
        let dts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(1.0, 1.0, 2.0, 2.0, false),
            make_ann(2.0, 0.0, 1.0, 1.0, false),
            make_ann(20.0, 20.0, 1.0, 1.0, false),
        ];
        let m = overlap_mask(&gts, &dts);
        let zero = 0.0_f64.to_bits();
        let one = 1.0_f64.to_bits();
        for g in 0..gts.len() {
            for d in 0..dts.len() {
                let bits = m[[g, d]].to_bits();
                assert!(
                    bits == zero || bits == one,
                    "overlap_mask[{g},{d}] = {} ({:#x}); expected 0.0 or 1.0",
                    m[[g, d]],
                    bits,
                );
            }
        }
    }

    #[test]
    fn overlap_mask_survivor_bit_matches_full_iou() {
        let gts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(5.0, 5.0, 2.0, 2.0, false),
            make_ann(0.0, 0.0, 10.0, 10.0, true),
            make_ann(5.0, 5.0, 0.0, 5.0, false),
            make_ann(5.0, 5.0, 0.0, 0.0, false),
        ];
        let dts = [
            make_ann(0.0, 0.0, 2.0, 2.0, false),
            make_ann(1.0, 1.0, 2.0, 2.0, false),
            make_ann(2.0, 0.0, 1.0, 1.0, false),
            make_ann(20.0, 20.0, 1.0, 1.0, false),
            make_ann(5.0, 5.0, 0.0, 0.0, false),
        ];
        let iou = compute(&gts, &dts);
        let mask = overlap_mask(&gts, &dts);
        // The mask must be positive exactly where the full IoU is positive.
        for g in 0..gts.len() {
            for d in 0..dts.len() {
                let iou_pos = iou[[g, d]] > 0.0;
                let mask_pos = mask[[g, d]] > 0.0;
                assert_eq!(
                    iou_pos,
                    mask_pos,
                    "survivor-bit mismatch at ({g},{d}): iou={}, mask={}",
                    iou[[g, d]],
                    mask[[g, d]],
                );
            }
        }
    }

    #[test]
    fn overlap_mask_dimension_mismatch_returns_typed_error() {
        let gts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 2];
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::zeros((1, 1));
        let err = BboxIou
            .compute_overlap_mask(&gts, &dts, &mut out.view_mut())
            .unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("2"));
                assert!(detail.contains("3"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }

    #[test]
    fn overlap_mask_empty_inputs_return_unchanged_matrix() {
        let dts = [make_ann(0.0, 0.0, 1.0, 1.0, false); 3];
        let mut out = Array2::<f64>::from_elem((0, 3), 7.0);
        BboxIou
            .compute_overlap_mask(&[], &dts, &mut out.view_mut())
            .unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }

    #[cfg(feature = "bench-histogram")]
    #[test]
    fn histogram_records_kernel_calls_when_feature_on() {
        use super::histogram;
        let gts = [make_ann(0.0, 0.0, 2.0, 2.0, false)];
        let dts = [make_ann(1.0, 1.0, 2.0, 2.0, false)];
        let _ = compute(&gts, &dts);
        let _ = overlap_mask(&gts, &dts);
        // One sample per kernel call; other tests may have added more.
        assert!(histogram::len() >= 2);
        let tmp = std::env::temp_dir().join("vernier-bench-histogram-smoke.csv");
        let n = histogram::dump_csv(&tmp).expect("dump_csv should succeed");
        assert!(n >= 2);
        let csv = std::fs::read_to_string(&tmp).expect("dumped file should be readable");
        assert!(csv.starts_with("kind,g,d,wall_ns\n"));
        // Header line plus one line per sample.
        assert!(csv.lines().count() > n);
        assert!(
            csv.contains("FullIou,"),
            "expected FullIou rows in CSV: {csv}"
        );
        assert!(
            csv.contains("OverlapMask,"),
            "expected OverlapMask rows in CSV: {csv}"
        );
        std::fs::remove_file(&tmp).ok();
    }
}