use std::collections::HashMap;
use std::sync::{Mutex, MutexGuard};
use ndarray::ArrayViewMut2;
use vernier_mask::ops::{intersect_area_offsets, SegmentTable};
use vernier_mask::Rle;
use super::bbox::{BboxAnn, BboxIou};
use super::Similarity;
use crate::dataset::Bbox;
use crate::error::EvalError;
/// One annotation (ground truth or detection) as consumed by the segmentation
/// IoU kernel in this module.
#[derive(Debug, Clone, PartialEq)]
pub struct SegmAnn {
/// Mask of the annotation. Every mask passed to a single IoU computation
/// must share the same `h` and `w`, otherwise the computation fails with
/// `EvalError::DimensionMismatch`.
pub rle: Rle,
/// Crowd flag. On a ground truth it switches the IoU denominator from the
/// union to the detection area alone; on a detection it is ignored.
pub is_crowd: bool,
/// Annotation id. Used as the lookup key into `SegmGtCache` when a cache is
/// supplied for the ground-truth side.
pub ann_id: i64,
}
/// Stateless mask-IoU similarity: the `Similarity` implementation whose
/// annotation type is `SegmAnn`.
#[derive(Debug, Default, Clone, Copy)]
pub struct SegmIou;
impl Similarity for SegmIou {
    type Annotation = SegmAnn;

    /// Writes the mask IoU of every ground-truth/detection pair into `out`
    /// (rows index `gts`, columns index `dts`).
    ///
    /// This entry point allocates a fresh scratch for the call and runs
    /// without a ground-truth cache.
    ///
    /// # Errors
    ///
    /// Propagates whatever `segm_iou_compute` returns, i.e.
    /// `EvalError::DimensionMismatch` for a wrongly shaped `out` or for masks
    /// whose dimensions disagree.
    fn compute(
        &self,
        gts: &[SegmAnn],
        dts: &[SegmAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        // One-shot scratch: nothing is reused between calls on this path.
        segm_iou_compute(gts, dts, out, &mut SegmComputeScratch::new(), None)
    }
}
/// Working buffers for `segm_iou_compute`.
///
/// Every buffer is cleared and refilled by each call that gets past the
/// input validation, so an instance can be handed to successive calls.
/// The `g_*` buffers hold one entry per ground truth, the `d_*` buffers one
/// entry per detection, in input order.
#[derive(Default)]
pub(crate) struct SegmComputeScratch {
// Ground-truth bounding boxes fed to the bbox prefilter.
g_bbox: Vec<BboxAnn>,
// Detection bounding boxes fed to the bbox prefilter (crowd flag always false).
d_bbox: Vec<BboxAnn>,
// Ground-truth mask areas, used in the IoU denominator.
g_area: Vec<u64>,
// Detection mask areas, used in the IoU denominator.
d_area: Vec<u64>,
// Per-ground-truth segment rows passed to `intersect_area_offsets`.
g_segments: SegmentTable,
// Per-detection segment rows passed to `intersect_area_offsets`.
d_segments: SegmentTable,
}
impl SegmComputeScratch {
    /// Creates a scratch whose buffers all start in their `Default` state.
    pub(crate) fn new() -> Self {
        Default::default()
    }
}
/// Mutex-guarded cache of per-ground-truth data, keyed by `SegmAnn::ann_id`.
///
/// When a cache is passed to `segm_iou_compute`, each ground truth is looked
/// up by id and inserted on first sight. Existing entries are never
/// refreshed: a later annotation carrying an already-cached id reuses the
/// stored bbox, area and offsets as-is.
#[derive(Default)]
pub struct SegmGtCache {
// Map from annotation id to the values derived from that annotation's mask.
inner: Mutex<HashMap<i64, SegmGtEntry>>,
}
/// Values derived once from a ground-truth mask and stored in `SegmGtCache`.
#[derive(Clone)]
struct SegmGtEntry {
// Bounding box from `to_bbox_ann`, including the crowd flag seen at insertion time.
bbox: BboxAnn,
// Mask area as reported by `Rle::area`.
area: u64,
// Output of `Rle::decode_fg_offsets_into`; replayed into the scratch segment
// table with `push_segments` on every cache hit.
fg_offsets: Vec<u64>,
}
impl SegmGtCache {
    /// Creates an empty cache.
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns the number of cached ground-truth entries.
    pub fn len(&self) -> usize {
        let entries = self.lock();
        entries.len()
    }

    /// Returns `true` when no ground-truth entry is cached.
    pub fn is_empty(&self) -> bool {
        let entries = self.lock();
        entries.is_empty()
    }

    /// Drops every cached entry.
    pub fn clear(&self) {
        let mut entries = self.lock();
        entries.clear();
    }

    /// Acquires the inner map.
    ///
    /// A poisoned mutex is not treated as an error: the guard is recovered
    /// from the poison error and handed out as usual.
    fn lock(&self) -> MutexGuard<'_, HashMap<i64, SegmGtEntry>> {
        match self.inner.lock() {
            Ok(guard) => guard,
            Err(poisoned) => poisoned.into_inner(),
        }
    }
}
/// Fills `out[[g, d]]` with the mask IoU of ground truth `g` and detection `d`.
///
/// * `gts` / `dts` - ground-truth and detection annotations; rows of `out`
///   follow `gts`, columns follow `dts`.
/// * `out` - destination matrix, must be exactly `gts.len()` x `dts.len()`.
/// * `scratch` - working buffers; cleared and refilled here.
/// * `gt_cache` - optional cache consulted (and filled) for the ground-truth
///   side, keyed by `ann_id`.
///
/// For a crowd ground truth the denominator is the detection area alone;
/// otherwise it is the union of both areas. The crowd flag on detections is
/// ignored. Pairs with an empty intersection or a zero denominator get `0.0`.
///
/// If either input slice is empty the function returns `Ok(())` without
/// touching `out` or `scratch`.
///
/// # Errors
///
/// Returns `EvalError::DimensionMismatch` when `out` has the wrong shape or
/// when any mask's `h`/`w` differs from those of the first ground truth.
pub(crate) fn segm_iou_compute(
    gts: &[SegmAnn],
    dts: &[SegmAnn],
    out: &mut ArrayViewMut2<'_, f64>,
    scratch: &mut SegmComputeScratch,
    gt_cache: Option<&SegmGtCache>,
) -> Result<(), EvalError> {
    let (n_gt, n_dt) = (gts.len(), dts.len());
    let (rows, cols) = (out.nrows(), out.ncols());
    if (rows, cols) != (n_gt, n_dt) {
        let detail = format!(
            "segm IoU output is {}x{}, expected {}x{}",
            rows, cols, n_gt, n_dt
        );
        return Err(EvalError::DimensionMismatch { detail });
    }

    // Nothing to compute when either side is empty; `out` has no cells then.
    let Some(reference) = gts.first().map(|a| &a.rle) else {
        return Ok(());
    };
    if dts.is_empty() {
        return Ok(());
    }

    // All masks must live on the same canvas as the first ground truth.
    let (h, w) = (reference.h, reference.w);
    let mismatch = gts
        .iter()
        .chain(dts.iter())
        .map(|a| &a.rle)
        .find(|r| r.h != h || r.w != w);
    if let Some(r) = mismatch {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "segm IoU expects all RLEs at [{h}, {w}]; got [{}, {}]",
                r.h, r.w
            ),
        });
    }

    // Ground-truth side (possibly served from the cache).
    scratch.g_bbox.clear();
    scratch.g_area.clear();
    scratch.g_segments.clear();
    populate_gt(gts, scratch, gt_cache);

    // Detection side is always recomputed.
    scratch.d_bbox.clear();
    scratch.d_area.clear();
    scratch.d_segments.clear();
    for dt in dts {
        let ([left, top, width, height], area) =
            scratch.d_segments.push_with_bbox_and_area(&dt.rle);
        let bbox = Bbox {
            x: f64::from(left),
            y: f64::from(top),
            w: f64::from(width),
            h: f64::from(height),
        };
        // The detection's own crowd flag is deliberately not forwarded.
        scratch.d_bbox.push(BboxAnn {
            bbox,
            is_crowd: false,
        });
        scratch.d_area.push(area);
    }

    // Bbox prefilter writes into `out`; cells left at <= 0 are skipped below.
    // NOTE(review): presumably a positive value marks overlapping boxes -
    // confirm against `BboxIou::compute_overlap_mask`.
    BboxIou.compute_overlap_mask(&scratch.g_bbox, &scratch.d_bbox, out)?;

    for (gi, gt) in gts.iter().enumerate() {
        let gt_area = scratch.g_area[gi];
        let gt_row = scratch.g_segments.row(gi);
        for (di, &dt_area) in scratch.d_area.iter().enumerate() {
            if out[[gi, di]] <= 0.0 {
                continue;
            }
            let inter = intersect_area_offsets(gt_row, scratch.d_segments.row(di));
            let denom = if gt.is_crowd {
                dt_area
            } else {
                gt_area + dt_area - inter
            };
            // Guard against 0/0 so empty masks yield 0.0 rather than NaN.
            out[[gi, di]] = if denom == 0 || inter == 0 {
                0.0
            } else {
                (inter as f64) / (denom as f64)
            };
        }
    }
    Ok(())
}
/// Appends one bbox, one area and one segment row per ground truth to the
/// `g_*` buffers of `scratch`, in the order of `gts`.
///
/// Without a cache everything is derived from the mask on the spot. With a
/// cache, each ground truth is looked up by `ann_id`; a miss computes and
/// stores the entry, a hit replays the stored values. The cache lock is held
/// for the whole loop.
fn populate_gt(gts: &[SegmAnn], scratch: &mut SegmComputeScratch, cache: Option<&SegmGtCache>) {
    match cache {
        None => {
            for gt in gts {
                let ([left, top, width, height], area) =
                    scratch.g_segments.push_with_bbox_and_area(&gt.rle);
                let bbox = Bbox {
                    x: f64::from(left),
                    y: f64::from(top),
                    w: f64::from(width),
                    h: f64::from(height),
                };
                scratch.g_bbox.push(BboxAnn {
                    bbox,
                    is_crowd: gt.is_crowd,
                });
                scratch.g_area.push(area);
            }
        }
        Some(cache) => {
            let mut entries = cache.lock();
            for gt in gts {
                // First sighting of this id: derive and remember its data.
                let cached = entries.entry(gt.ann_id).or_insert_with(|| {
                    let mut fg_offsets = Vec::new();
                    gt.rle.decode_fg_offsets_into(&mut fg_offsets);
                    SegmGtEntry {
                        bbox: to_bbox_ann(&gt.rle, gt.is_crowd),
                        area: gt.rle.area(),
                        fg_offsets,
                    }
                });
                scratch.g_bbox.push(cached.bbox);
                scratch.g_area.push(cached.area);
                scratch.g_segments.push_segments(&cached.fg_offsets);
            }
        }
    }
}
/// Builds a `BboxAnn` from the bounding box reported by `rle.bbox()`,
/// converting each of the four components to `f64` and attaching `is_crowd`.
pub(super) fn to_bbox_ann(rle: &Rle, is_crowd: bool) -> BboxAnn {
    let [x, y, w, h] = rle.bbox().map(f64::from);
    BboxAnn {
        bbox: Bbox { x, y, w, h },
        is_crowd,
    }
}
// Unit tests for `SegmIou::compute` (the uncached path; no `SegmGtCache` is
// involved here). Expected values are compared bit-for-bit via `to_bits`.
// NOTE(review): the RLE literals presumably follow the COCO convention of
// alternating background/foreground run counts starting with background -
// confirm against `Rle::from_counts`.
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array2;
// Wraps a mask into a `SegmAnn`; every test annotation gets `ann_id` 0.
fn ann(rle: Rle, is_crowd: bool) -> SegmAnn {
SegmAnn {
rle,
is_crowd,
ann_id: 0,
}
}
// Runs `SegmIou::compute` into a freshly zeroed `gts.len()` x `dts.len()`
// matrix and returns it; panics if the computation reports an error.
fn compute(gts: &[SegmAnn], dts: &[SegmAnn]) -> Array2<f64> {
let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
SegmIou.compute(gts, dts, &mut out.view_mut()).unwrap();
out
}
// Identical masks must score exactly 1.0.
#[test]
fn perfect_overlap_is_one() {
let r = Rle::from_counts(2, 2, vec![0, 4]);
let m = compute(&[ann(r.clone(), false)], &[ann(r, false)]);
assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
}
// Masks that share no pixel must score exactly 0.0.
#[test]
fn disjoint_masks_are_zero_via_bbox_prefilter() {
let g = Rle::from_counts(2, 2, vec![0, 1, 3]);
let d = Rle::from_counts(2, 2, vec![3, 1]);
let m = compute(&[ann(g, false)], &[ann(d, false)]);
assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
}
// Hand-traced case: expected IoU is exactly 1/2.
#[test]
fn partial_overlap_matches_hand_traced_ratio() {
let g = Rle::from_counts(2, 2, vec![0, 1, 3]);
let d = Rle::from_counts(2, 2, vec![0, 2, 2]);
let m = compute(&[ann(g, false)], &[ann(d, false)]);
assert_eq!(m[[0, 0]].to_bits(), (1.0_f64 / 2.0_f64).to_bits());
}
// Same mask pair scored twice: with a crowd GT the result is 1.0 (denominator
// is the DT area), with a regular GT it is 1/16 (denominator is the union).
#[test]
fn e1_crowd_gt_uses_dt_area_denominator() {
let gt_full = Rle::from_counts(4, 4, vec![0, 16]);
let dt_pixel = Rle::from_counts(4, 4, vec![5, 1, 10]);
let crowd_m = compute(
&[ann(gt_full.clone(), true)],
&[ann(dt_pixel.clone(), false)],
);
let normal_m = compute(&[ann(gt_full, false)], &[ann(dt_pixel, false)]);
assert_eq!(crowd_m[[0, 0]].to_bits(), 1.0_f64.to_bits());
assert_eq!(normal_m[[0, 0]].to_bits(), (1.0_f64 / 16.0_f64).to_bits());
}
// Setting `is_crowd` on the detection must not change the result.
#[test]
fn dt_iscrowd_flag_is_ignored() {
let g = Rle::from_counts(2, 2, vec![0, 1, 3]);
let d = Rle::from_counts(2, 2, vec![0, 2, 2]);
let with_flag = compute(&[ann(g.clone(), false)], &[ann(d.clone(), true)]);
let without = compute(&[ann(g, false)], &[ann(d, false)]);
assert_eq!(with_flag[[0, 0]].to_bits(), without[[0, 0]].to_bits());
}
// An empty mask on the GT side, or on both sides, must give a finite 0.0
// rather than a 0/0 NaN.
#[test]
fn empty_gt_or_dt_pair_is_zero_not_nan() {
let empty = Rle::from_counts(2, 2, vec![4]);
let dt_one = Rle::from_counts(2, 2, vec![0, 1, 3]);
let m = compute(&[ann(empty.clone(), false)], &[ann(dt_one, false)]);
assert!(m[[0, 0]].is_finite());
assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
let m = compute(&[ann(empty.clone(), false)], &[ann(empty, false)]);
assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
}
// Masks of different dimensions must be rejected with `DimensionMismatch`,
// and the message must name both the expected and the offending size.
#[test]
fn rle_dimension_mismatch_returns_typed_error() {
let g = ann(Rle::from_counts(4, 4, vec![16]), false);
let d = ann(Rle::from_counts(8, 8, vec![64]), false);
let mut out = Array2::<f64>::zeros((1, 1));
let err = SegmIou
.compute(&[g], &[d], &mut out.view_mut())
.unwrap_err();
match err {
EvalError::DimensionMismatch { detail } => {
assert!(detail.contains("[4, 4]"));
assert!(detail.contains("[8, 8]"));
}
other => panic!("expected DimensionMismatch, got {other:?}"),
}
}
// A 2x3 output for a 1x1 problem must be rejected with `DimensionMismatch`.
#[test]
fn output_shape_mismatch_returns_typed_error() {
let g = ann(Rle::from_counts(2, 2, vec![4]), false);
let d = ann(Rle::from_counts(2, 2, vec![4]), false);
let mut out = Array2::<f64>::zeros((2, 3));
let err = SegmIou
.compute(&[g], &[d], &mut out.view_mut())
.unwrap_err();
assert!(matches!(err, EvalError::DimensionMismatch { .. }));
}
// Zero ground truths against three detections: the call must succeed.
// The 0x3 matrix holds no elements, so only its shape can be checked.
#[test]
fn empty_inputs_return_unchanged_matrix() {
let dts: Vec<SegmAnn> = (0..3)
.map(|_| ann(Rle::from_counts(2, 2, vec![4]), false))
.collect();
let mut out = Array2::<f64>::from_elem((0, 3), 7.0);
SegmIou.compute(&[], &dts, &mut out.view_mut()).unwrap();
assert_eq!(out.shape(), &[0, 3]);
}
// Full 3x3 matrix mixing exact matches, a partial overlap, disjoint pairs
// and a crowd ground truth (row 2, where every expected value is 1.0).
#[test]
fn three_by_three_matrix_exercises_prefilter_and_crowd() {
let g0 = Rle::from_counts(4, 4, vec![0, 1, 15]);
let g1 = Rle::from_counts(4, 4, vec![15, 1]);
let g2 = Rle::from_counts(4, 4, vec![0, 16]);
let d0 = Rle::from_counts(4, 4, vec![0, 1, 15]);
let d1 = Rle::from_counts(4, 4, vec![0, 1, 3, 1, 11]);
let d2 = Rle::from_counts(4, 4, vec![15, 1]);
let m = compute(
&[ann(g0, false), ann(g1, false), ann(g2, true)],
&[ann(d0, false), ann(d1, false), ann(d2, false)],
);
assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
assert_eq!(m[[0, 1]].to_bits(), (1.0_f64 / 2.0_f64).to_bits());
assert_eq!(m[[0, 2]].to_bits(), 0.0_f64.to_bits());
assert_eq!(m[[1, 0]].to_bits(), 0.0_f64.to_bits());
assert_eq!(m[[1, 1]].to_bits(), 0.0_f64.to_bits());
assert_eq!(m[[1, 2]].to_bits(), 1.0_f64.to_bits());
assert_eq!(m[[2, 0]].to_bits(), 1.0_f64.to_bits());
assert_eq!(m[[2, 1]].to_bits(), 1.0_f64.to_bits());
assert_eq!(m[[2, 2]].to_bits(), 1.0_f64.to_bits());
}
// Compile-time check that `SegmIou` is `Send + Sync`; there is nothing to
// assert at run time.
#[test]
fn impl_is_send_and_sync() {
fn assert_send_sync<T: Send + Sync>() {}
assert_send_sync::<SegmIou>();
}
}