use std::collections::HashMap;
use std::sync::{Mutex, MutexGuard};
use ndarray::ArrayViewMut2;
use vernier_mask::ops::{
boundary_band_segments_into, intersect_area_offsets, ErodeScratch, SegmentTable,
};
use super::bbox::{BboxAnn, BboxIou};
use super::segm::{to_bbox_ann, SegmAnn};
use super::Similarity;
use crate::boundary_parity::BOUNDARY_DILATION_RATIO_DEFAULT;
use crate::error::EvalError;
/// Reusable buffers for `boundary_iou_compute`.
///
/// Holding these in one struct lets callers amortize allocations across
/// repeated evaluations; every buffer is cleared and refilled per call.
#[derive(Default)]
pub(crate) struct BoundaryComputeScratch {
    // Scratch for the erosion step inside `boundary_band_segments_into`.
    erode: ErodeScratch,
    // Bounding boxes derived from the GT/DT RLEs for the overlap prefilter.
    g_bbox: Vec<BboxAnn>,
    d_bbox: Vec<BboxAnn>,
    // Per-annotation pixel area of each full mask (0 for inactive entries).
    g_mask_area: Vec<u64>,
    d_mask_area: Vec<u64>,
    // Per-annotation pixel area of each boundary band (0 for inactive/crowd).
    g_band_area: Vec<u64>,
    d_band_area: Vec<u64>,
    // Segment-offset tables, one row per annotation, for masks and bands.
    g_mask_segments: SegmentTable,
    g_band_segments: SegmentTable,
    d_mask_segments: SegmentTable,
    d_band_segments: SegmentTable,
    // Marks annotations that survive the bbox-overlap prefilter.
    g_active: Vec<bool>,
    d_active: Vec<bool>,
}
impl BoundaryComputeScratch {
    /// Creates a scratch with every buffer empty; equivalent to `default()`.
    pub(crate) fn new() -> Self {
        Default::default()
    }
}
/// Thread-safe cache of precomputed ground-truth boundary data,
/// shared across calls to `boundary_iou_compute` via `Option<&Self>`.
#[derive(Default)]
pub struct BoundaryGtCache {
    // All state behind one mutex; see `lock()` for poison recovery.
    inner: Mutex<CacheInner>,
}
#[derive(Default)]
struct CacheInner {
    // Cached per-GT entries keyed by annotation id.
    bands: HashMap<i64, BoundaryGtEntry>,
    // Dilation ratio the cached bands were computed with; `align_ratio`
    // clears `bands` whenever this changes. `None` until first aligned.
    ratio: Option<f64>,
}
/// Precomputed data for one ground-truth annotation: its band area plus
/// the segment offsets for mask and band (as stored in a `SegmentTable` row).
#[derive(Clone)]
struct BoundaryGtEntry {
    band_area: u64,
    mask_offsets: Vec<u64>,
    band_offsets: Vec<u64>,
}
impl BoundaryGtCache {
pub fn new() -> Self {
Self::default()
}
pub fn len(&self) -> usize {
self.lock().bands.len()
}
pub fn is_empty(&self) -> bool {
self.lock().bands.is_empty()
}
pub fn clear(&self) {
let mut inner = self.lock();
inner.bands.clear();
inner.ratio = None;
}
pub(crate) fn align_ratio(&self, ratio: f64) {
let mut inner = self.lock();
if inner.ratio != Some(ratio) {
inner.bands.clear();
inner.ratio = Some(ratio);
}
}
fn lock(&self) -> MutexGuard<'_, CacheInner> {
self.inner.lock().unwrap_or_else(|p| p.into_inner())
}
}
/// Boundary IoU similarity kernel: each GT/DT pair is scored as the
/// minimum of the mask IoU and the boundary-band IoU.
#[derive(Debug, Clone, Copy)]
pub struct BoundaryIou {
    // Controls boundary-band width relative to the image; see
    // `boundary_band_segments_into`.
    pub dilation_ratio: f64,
}
impl Default for BoundaryIou {
fn default() -> Self {
Self {
dilation_ratio: BOUNDARY_DILATION_RATIO_DEFAULT,
}
}
}
impl Similarity for BoundaryIou {
    type Annotation = SegmAnn;

    /// One-shot entry point: fresh scratch buffers, no GT cache.
    fn compute(
        &self,
        gts: &[SegmAnn],
        dts: &[SegmAnn],
        out: &mut ArrayViewMut2<'_, f64>,
    ) -> Result<(), EvalError> {
        boundary_iou_compute(
            self.dilation_ratio,
            gts,
            dts,
            out,
            &mut BoundaryComputeScratch::new(),
            None,
        )
    }
}
/// Core Boundary IoU kernel: fills `out[g, d]` with
/// `min(mask IoU, boundary-band IoU)` for each GT/DT pair that survives a
/// bbox-overlap prefilter. Crowd GTs are scored with mask IoU alone, using
/// the detection's area as the denominator.
///
/// * `dilation_ratio` — boundary-band width parameter forwarded to
///   `boundary_band_segments_into`.
/// * `scratch` — reusable buffers; cleared and refilled on every call.
/// * `gt_cache` — optional per-GT cache keyed by `ann_id` (see
///   `populate_gt_entry`).
///
/// # Errors
/// `EvalError::DimensionMismatch` when `out` is not `gts.len() x dts.len()`
/// or when any RLE's `[h, w]` differs from the first GT's.
pub(crate) fn boundary_iou_compute(
    dilation_ratio: f64,
    gts: &[SegmAnn],
    dts: &[SegmAnn],
    out: &mut ArrayViewMut2<'_, f64>,
    scratch: &mut BoundaryComputeScratch,
    gt_cache: Option<&BoundaryGtCache>,
) -> Result<(), EvalError> {
    if out.nrows() != gts.len() || out.ncols() != dts.len() {
        return Err(EvalError::DimensionMismatch {
            detail: format!(
                "boundary IoU output is {}x{}, expected {}x{}",
                out.nrows(),
                out.ncols(),
                gts.len(),
                dts.len()
            ),
        });
    }
    // Nothing to score; `out` is left untouched.
    if gts.is_empty() || dts.is_empty() {
        return Ok(());
    }
    // All RLEs must share the raster size of the first GT.
    let (h, w) = (gts[0].rle.h, gts[0].rle.w);
    for r in gts.iter().chain(dts.iter()).map(|a| &a.rle) {
        if r.h != h || r.w != w {
            return Err(EvalError::DimensionMismatch {
                detail: format!(
                    "boundary IoU expects all RLEs at [{h}, {w}]; got [{}, {}]",
                    r.h, r.w
                ),
            });
        }
    }
    // Prefilter: derive bounding boxes from the RLEs and write bbox overlap
    // into `out`. DT iscrowd is deliberately forced to `false` here.
    scratch.g_bbox.clear();
    scratch
        .g_bbox
        .extend(gts.iter().map(|g| to_bbox_ann(&g.rle, g.is_crowd)));
    scratch.d_bbox.clear();
    scratch
        .d_bbox
        .extend(dts.iter().map(|d| to_bbox_ann(&d.rle, false)));
    BboxIou.compute_overlap_mask(&scratch.g_bbox, &scratch.d_bbox, out)?;
    // Mark annotations that participate in at least one overlapping pair;
    // everything else gets placeholder rows below.
    scratch.g_active.clear();
    scratch.g_active.resize(gts.len(), false);
    scratch.d_active.clear();
    scratch.d_active.resize(dts.len(), false);
    for g in 0..gts.len() {
        for d in 0..dts.len() {
            if out[[g, d]] > 0.0 {
                scratch.g_active[g] = true;
                scratch.d_active[d] = true;
            }
        }
    }
    // Build per-GT areas and segment rows. Inactive GTs push empty
    // placeholders so row indices stay aligned with `gts`.
    scratch.g_mask_area.clear();
    scratch.g_band_area.clear();
    scratch.g_mask_segments.clear();
    scratch.g_band_segments.clear();
    for (g_idx, g) in gts.iter().enumerate() {
        if !scratch.g_active[g_idx] {
            scratch.g_mask_area.push(0);
            scratch.g_mask_segments.push_segments(&[]);
            scratch.g_band_area.push(0);
            scratch.g_band_segments.push_segments(&[]);
            continue;
        }
        scratch.g_mask_area.push(g.rle.area());
        if g.is_crowd {
            // Crowd GTs never use the boundary term, so skip band work.
            scratch.g_mask_segments.push_from_rle(&g.rle);
            scratch.g_band_area.push(0);
            scratch.g_band_segments.push_segments(&[]);
        } else {
            // Cache-aware path: may reuse precomputed offsets by `ann_id`.
            populate_gt_entry(g, dilation_ratio, scratch, gt_cache)?;
        }
    }
    // Build per-DT areas, mask segments, and boundary bands (never cached).
    scratch.d_mask_area.clear();
    scratch.d_band_area.clear();
    scratch.d_mask_segments.clear();
    scratch.d_band_segments.clear();
    for (d_idx, d) in dts.iter().enumerate() {
        if !scratch.d_active[d_idx] {
            scratch.d_mask_area.push(0);
            scratch.d_mask_segments.push_segments(&[]);
            scratch.d_band_area.push(0);
            scratch.d_band_segments.push_segments(&[]);
            continue;
        }
        scratch.d_mask_area.push(d.rle.area());
        scratch.d_mask_segments.push_from_rle(&d.rle);
        let band_area = boundary_band_segments_into(
            &d.rle,
            dilation_ratio,
            &mut scratch.erode,
            &mut scratch.d_band_segments,
        )?;
        scratch.d_band_area.push(band_area);
    }
    // Score every pair that survived the prefilter; skipped pairs keep
    // whatever `compute_overlap_mask` wrote into `out` (<= 0).
    for g in 0..gts.len() {
        let crowd = gts[g].is_crowd;
        let g_mask_seg = scratch.g_mask_segments.row(g);
        let g_band_seg = scratch.g_band_segments.row(g);
        for d in 0..dts.len() {
            if out[[g, d]] <= 0.0 {
                continue;
            }
            let inter_mask = intersect_area_offsets(g_mask_seg, scratch.d_mask_segments.row(d));
            // Crowd GT: denominator is the detection area only.
            let mask_denom = if crowd {
                scratch.d_mask_area[d]
            } else {
                scratch.g_mask_area[g] + scratch.d_mask_area[d] - inter_mask
            };
            // Guard both terms so empty masks yield 0.0, never NaN.
            let mask_iou = if mask_denom > 0 && inter_mask > 0 {
                (inter_mask as f64) / (mask_denom as f64)
            } else {
                0.0
            };
            if crowd {
                out[[g, d]] = mask_iou;
                continue;
            }
            let inter_bound = intersect_area_offsets(g_band_seg, scratch.d_band_segments.row(d));
            let bound_denom = scratch.g_band_area[g] + scratch.d_band_area[d] - inter_bound;
            let bound_iou = if bound_denom > 0 && inter_bound > 0 {
                (inter_bound as f64) / (bound_denom as f64)
            } else {
                0.0
            };
            // Boundary IoU is the minimum of the two terms.
            out[[g, d]] = mask_iou.min(bound_iou);
        }
    }
    Ok(())
}
/// Pushes band area plus mask/band segment rows for a single non-crowd
/// ground-truth annotation into `scratch`, consulting `cache` when present.
///
/// The caller has already pushed `g_mask_area`. On a cache hit the stored
/// offsets are replayed; on a miss (or with no cache) the band is computed
/// via `boundary_band_segments_into`, and the result is inserted into the
/// cache keyed by `ann.ann_id`.
///
/// # Errors
/// Propagates any `EvalError` from `boundary_band_segments_into`.
fn populate_gt_entry(
    ann: &SegmAnn,
    ratio: f64,
    scratch: &mut BoundaryComputeScratch,
    cache: Option<&BoundaryGtCache>,
) -> Result<(), EvalError> {
    // Hold the cache lock (when a cache exists) across lookup and insert,
    // matching the original single-critical-section behavior.
    let mut guard = cache.map(|c| c.lock());
    if let Some(inner) = guard.as_deref() {
        if let Some(entry) = inner.bands.get(&ann.ann_id) {
            // Cache hit: replay the precomputed rows.
            scratch.g_band_area.push(entry.band_area);
            scratch.g_mask_segments.push_segments(&entry.mask_offsets);
            scratch.g_band_segments.push_segments(&entry.band_offsets);
            return Ok(());
        }
    }
    // Miss or uncached: compute mask segments and the boundary band once.
    scratch.g_mask_segments.push_from_rle(&ann.rle);
    let band_area = boundary_band_segments_into(
        &ann.rle,
        ratio,
        &mut scratch.erode,
        &mut scratch.g_band_segments,
    )?;
    scratch.g_band_area.push(band_area);
    // Persist the freshly computed rows for future calls.
    if let Some(inner) = guard.as_deref_mut() {
        inner.bands.insert(
            ann.ann_id,
            BoundaryGtEntry {
                band_area,
                mask_offsets: scratch.g_mask_segments.last_row().to_vec(),
                band_offsets: scratch.g_band_segments.last_row().to_vec(),
            },
        );
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;
    use vernier_mask::ops::boundary_band;
    use vernier_mask::Rle;

    // Helper: wrap an RLE into a SegmAnn (ann_id unused since no cache here).
    fn ann(rle: Rle, is_crowd: bool) -> SegmAnn {
        SegmAnn {
            rle,
            is_crowd,
            ann_id: 0,
        }
    }

    // Helper: run the default kernel over the full GT/DT cross product.
    fn compute(gts: &[SegmAnn], dts: &[SegmAnn]) -> Array2<f64> {
        let mut out = Array2::<f64>::zeros((gts.len(), dts.len()));
        BoundaryIou::default()
            .compute(gts, dts, &mut out.view_mut())
            .unwrap();
        out
    }

    // Helper: rasterize an axis-aligned rectangle; the raster is indexed
    // x * h + y (column-major), matching `Rle::from_raster_bytes`.
    fn filled_rect(h: u32, w: u32, x0: u32, y0: u32, rw: u32, rh: u32) -> Rle {
        let mut raster = vec![0u8; (h as usize) * (w as usize)];
        for x in x0..x0 + rw {
            for y in y0..y0 + rh {
                raster[(x as usize) * (h as usize) + (y as usize)] = 1;
            }
        }
        Rle::from_raster_bytes(&raster, h, w).unwrap()
    }
    #[test]
    fn perfect_overlap_is_one() {
        let r = Rle::from_counts(2, 2, vec![0, 4]);
        let m = compute(&[ann(r.clone(), false)], &[ann(r, false)]);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
    }
    #[test]
    fn disjoint_masks_are_zero_via_bbox_prefilter() {
        let g = Rle::from_counts(2, 2, vec![0, 1, 3]);
        let d = Rle::from_counts(2, 2, vec![3, 1]);
        let m = compute(&[ann(g, false)], &[ann(d, false)]);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }
    #[test]
    fn small_mask_band_clamps_to_full_mask() {
        let g = Rle::from_counts(4, 4, vec![0, 1, 15]);
        let d = Rle::from_counts(4, 4, vec![0, 2, 14]);
        let m = compute(&[ann(g, false)], &[ann(d, false)]);
        assert_eq!(m[[0, 0]].to_bits(), (1.0_f64 / 2.0_f64).to_bits());
    }
    #[test]
    fn partial_overlap_equals_min_of_mask_and_bound_iou() {
        let h = 20;
        let w = 20;
        let gt = filled_rect(h, w, 0, 5, 10, 10);
        let dt = filled_rect(h, w, 5, 5, 10, 10);
        let kernel = BoundaryIou {
            dilation_ratio: 0.04,
        };
        let mut out = Array2::<f64>::zeros((1, 1));
        kernel
            .compute(
                &[ann(gt.clone(), false)],
                &[ann(dt.clone(), false)],
                &mut out.view_mut(),
            )
            .unwrap();
        // Recompute the expected value independently via the public
        // `boundary_band` helper. (Fixed: `&gt` had been mangled to `>`
        // by an HTML-entity escape/unescape pass.)
        let g_band = boundary_band(&gt, 0.04).unwrap();
        let d_band = boundary_band(&dt, 0.04).unwrap();
        let inter_mask = gt.intersect_area(&dt).unwrap();
        let mask_iou = (inter_mask as f64) / ((gt.area() + dt.area() - inter_mask) as f64);
        let inter_bound = g_band.intersect_area(&d_band).unwrap();
        let bound_iou =
            (inter_bound as f64) / ((g_band.area() + d_band.area() - inter_bound) as f64);
        let expected = mask_iou.min(bound_iou);
        assert!(bound_iou < mask_iou);
        assert_eq!(out[[0, 0]].to_bits(), expected.to_bits());
    }
    #[test]
    fn e1_o1_crowd_gt_uses_mask_iou_alone() {
        let gt_full = Rle::from_counts(4, 4, vec![0, 16]);
        let dt_pixel = Rle::from_counts(4, 4, vec![5, 1, 10]);
        let m = compute(&[ann(gt_full, true)], &[ann(dt_pixel, false)]);
        assert_eq!(m[[0, 0]].to_bits(), 1.0_f64.to_bits());
    }
    #[test]
    fn dt_iscrowd_flag_is_ignored() {
        let g = Rle::from_counts(2, 2, vec![0, 1, 3]);
        let d = Rle::from_counts(2, 2, vec![0, 2, 2]);
        let with_flag = compute(&[ann(g.clone(), false)], &[ann(d.clone(), true)]);
        let without = compute(&[ann(g, false)], &[ann(d, false)]);
        assert_eq!(with_flag[[0, 0]].to_bits(), without[[0, 0]].to_bits());
    }
    #[test]
    fn empty_masks_pair_is_zero_not_nan() {
        let empty = Rle::from_counts(2, 2, vec![4]);
        let dt_one = Rle::from_counts(2, 2, vec![0, 1, 3]);
        let m = compute(&[ann(empty.clone(), false)], &[ann(dt_one, false)]);
        assert!(m[[0, 0]].is_finite());
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
        let m = compute(&[ann(empty.clone(), false)], &[ann(empty, false)]);
        assert_eq!(m[[0, 0]].to_bits(), 0.0_f64.to_bits());
    }
    #[test]
    fn empty_inputs_return_unchanged_matrix() {
        let dts: Vec<SegmAnn> = (0..3)
            .map(|_| ann(Rle::from_counts(2, 2, vec![4]), false))
            .collect();
        let mut out = Array2::<f64>::from_elem((0, 3), 7.0);
        BoundaryIou::default()
            .compute(&[], &dts, &mut out.view_mut())
            .unwrap();
        assert_eq!(out.shape(), &[0, 3]);
    }
    #[test]
    fn output_shape_mismatch_returns_typed_error() {
        let g = ann(Rle::from_counts(2, 2, vec![4]), false);
        let d = ann(Rle::from_counts(2, 2, vec![4]), false);
        let mut out = Array2::<f64>::zeros((2, 3));
        let err = BoundaryIou::default()
            .compute(&[g], &[d], &mut out.view_mut())
            .unwrap_err();
        assert!(matches!(err, EvalError::DimensionMismatch { .. }));
    }
    #[test]
    fn rle_dimension_mismatch_returns_typed_error() {
        let g = ann(Rle::from_counts(4, 4, vec![16]), false);
        let d = ann(Rle::from_counts(8, 8, vec![64]), false);
        let mut out = Array2::<f64>::zeros((1, 1));
        let err = BoundaryIou::default()
            .compute(&[g], &[d], &mut out.view_mut())
            .unwrap_err();
        match err {
            EvalError::DimensionMismatch { detail } => {
                assert!(detail.contains("[4, 4]"));
                assert!(detail.contains("[8, 8]"));
            }
            other => panic!("expected DimensionMismatch, got {other:?}"),
        }
    }
    #[test]
    fn default_dilation_ratio_is_pinned_constant() {
        assert_eq!(
            BoundaryIou::default().dilation_ratio,
            BOUNDARY_DILATION_RATIO_DEFAULT
        );
    }
    #[test]
    fn custom_dilation_ratio_flows_through_to_bands() {
        let h = 20;
        let w = 20;
        let gt = filled_rect(h, w, 0, 5, 10, 10);
        let dt = filled_rect(h, w, 5, 5, 10, 10);
        let run = |ratio: f64| -> f64 {
            let mut out = Array2::<f64>::zeros((1, 1));
            BoundaryIou {
                dilation_ratio: ratio,
            }
            .compute(
                &[ann(gt.clone(), false)],
                &[ann(dt.clone(), false)],
                &mut out.view_mut(),
            )
            .unwrap();
            out[[0, 0]]
        };
        let large_ratio = 0.10;
        // Fixed: `&gt` had been mangled to `>` (HTML-entity garbling).
        let g_band = boundary_band(&gt, large_ratio).unwrap();
        let d_band = boundary_band(&dt, large_ratio).unwrap();
        let inter_mask = gt.intersect_area(&dt).unwrap();
        let mask_iou = (inter_mask as f64) / ((gt.area() + dt.area() - inter_mask) as f64);
        let inter_bound = g_band.intersect_area(&d_band).unwrap();
        let bound_iou =
            (inter_bound as f64) / ((g_band.area() + d_band.area() - inter_bound) as f64);
        let expected_large = mask_iou.min(bound_iou);
        let actual_small = run(0.04);
        let actual_large = run(large_ratio);
        assert_eq!(actual_large.to_bits(), expected_large.to_bits());
        assert_ne!(actual_small.to_bits(), actual_large.to_bits());
    }
    #[test]
    fn impl_is_send_and_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BoundaryIou>();
    }
}