use ndarray::{Array2, ArrayView2};
use super::config::{Connectivity, LabelInput};
/// Pair of ring-shaped masks produced by [`extract_annuli`].
///
/// Both arrays have the same shape as the input mask.
#[derive(Debug, Clone, PartialEq, Eq)]
pub(super) struct Annuli {
    /// Cells added by the first `thickness` dilation steps (the seed mask itself excluded).
    pub inner: Array2<bool>,
    /// Cells added by the second `thickness` dilation steps (the inner dilation excluded).
    pub outer: Array2<bool>,
}
/// Binary dilation of `mask`, repeated `iterations` times.
///
/// In every step, each in-bounds neighbour of a `true` cell becomes `true`.
/// The neighbour displacements `(d_row, d_col)` come from
/// `connectivity.offsets()`. Cells that are already `true` stay `true`.
/// Neighbours that fall outside the grid are ignored (no wrapping, no padding).
///
/// When `iterations == 0`, an owned copy of `mask` is returned unchanged.
pub(super) fn dilate(
    mask: ArrayView2<bool>,
    connectivity: Connectivity,
    iterations: usize,
) -> Array2<bool> {
    let (rows, cols) = mask.dim();
    let mut current = mask.to_owned();
    if iterations == 0 {
        return current;
    }
    let offsets = connectivity.offsets();
    // Double buffer: each step reads only `current` and writes into `next`, so
    // cells switched on during a step do not spread further within that same step.
    let mut next = current.clone();
    for _ in 0..iterations {
        // Start from the previous result so already-set cells remain set.
        next.assign(&current);
        for row in 0..rows {
            for col in 0..cols {
                if !current[(row, col)] {
                    continue;
                }
                for &(d_row, d_col) in offsets {
                    let next_row_signed = row as isize + d_row;
                    let next_col_signed = col as isize + d_col;
                    // Neighbours outside the grid are dropped rather than wrapped.
                    if next_row_signed < 0 || next_col_signed < 0 {
                        continue;
                    }
                    let next_row = next_row_signed as usize;
                    let next_col = next_col_signed as usize;
                    if next_row >= rows || next_col >= cols {
                        continue;
                    }
                    next[(next_row, next_col)] = true;
                }
            }
        }
        std::mem::swap(&mut current, &mut next);
    }
    current
}
/// Splits the neighbourhood of `mask` into two concentric rings.
///
/// `mask` is dilated `thickness` times, and that result is dilated `thickness`
/// more times. `inner` marks cells reached by the first growth but not part of
/// `mask`. `outer` marks cells reached by the second growth but not by the first.
///
/// When `label` is `Some`, a cell may appear in either ring only if its value
/// in `label.map` is contained in `label.allowed`. When `label` is `None`, every
/// cell is eligible.
pub(super) fn extract_annuli(
    mask: ArrayView2<bool>,
    label: Option<&LabelInput>,
    connectivity: Connectivity,
    thickness: usize,
) -> Annuli {
    let (rows, cols) = mask.dim();
    let grown_once = dilate(mask, connectivity, thickness);
    let grown_twice = dilate(grown_once.view(), connectivity, thickness);

    let mut inner = Array2::<bool>::from_elem((rows, cols), false);
    let mut outer = Array2::<bool>::from_elem((rows, cols), false);

    for ((row, col), &seed) in mask.indexed_iter() {
        // Label gating is evaluated before anything else for the cell.
        let admitted =
            label.map_or(true, |lbl| lbl.allowed.contains(&lbl.map[(row, col)]));
        if !admitted {
            continue;
        }
        let first = grown_once[(row, col)];
        let second = grown_twice[(row, col)];
        // Both rings start all-false, so writing the boolean directly is
        // equivalent to conditionally setting `true`.
        inner[(row, col)] = first && !seed;
        outer[(row, col)] = second && !first;
    }

    Annuli { inner, outer }
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array2;

    /// Number of `true` cells in `mask`.
    fn count_true(mask: &Array2<bool>) -> usize {
        mask.iter().filter(|&&v| v).count()
    }

    /// A `size` x `size` mask with only `(row, col)` set.
    fn single_pixel(size: usize, row: usize, col: usize) -> Array2<bool> {
        let mut mask = Array2::<bool>::from_elem((size, size), false);
        mask[(row, col)] = true;
        mask
    }

    #[test]
    fn dilate_expands_correctly_per_connectivity() {
        let mask = single_pixel(5, 2, 2);
        let four = dilate(mask.view(), Connectivity::Four, 1);
        assert_eq!(count_true(&four), 5);
        for &(row, col) in &[(2, 2), (1, 2), (3, 2), (2, 1), (2, 3)] {
            assert!(four[(row, col)], "Four-dilated must include ({row}, {col})");
        }
        let eight = dilate(mask.view(), Connectivity::Eight, 1);
        assert_eq!(count_true(&eight), 9);
        for row in 1..=3 {
            for col in 1..=3 {
                assert!(
                    eight[(row, col)],
                    "Eight-dilated must include ({row}, {col})"
                );
            }
        }
    }

    #[test]
    fn dilate_with_zero_iterations_returns_input() {
        let mask = single_pixel(5, 2, 2);
        assert_eq!(dilate(mask.view(), Connectivity::Four, 0), mask);
        assert_eq!(dilate(mask.view(), Connectivity::Eight, 0), mask);
    }

    #[test]
    fn dilate_clips_at_grid_border() {
        let mask = single_pixel(3, 0, 0);
        let four = dilate(mask.view(), Connectivity::Four, 1);
        assert_eq!(count_true(&four), 3);
        let eight = dilate(mask.view(), Connectivity::Eight, 1);
        assert_eq!(count_true(&eight), 4);
        for &(row, col) in &[(0, 0), (0, 1), (1, 0), (1, 1)] {
            assert!(eight[(row, col)], "corner Eight-dilated must include ({row}, {col})");
        }
    }

    #[test]
    fn dilate_multiple_iterations_grows_diamond() {
        let mask = single_pixel(5, 2, 2);
        let grown = dilate(mask.view(), Connectivity::Four, 2);
        assert_eq!(count_true(&grown), 13);
        for row in 0..5usize {
            for col in 0..5usize {
                let dist = (row as isize - 2).abs() + (col as isize - 2).abs();
                assert_eq!(
                    grown[(row, col)],
                    dist <= 2,
                    "unexpected value at ({row}, {col})"
                );
            }
        }
    }

    #[test]
    fn extract_annuli_yields_inner_outer_with_label_gating() {
        let mask = single_pixel(5, 2, 2);
        let mut label_map = Array2::<i32>::from_elem((5, 5), 1);
        label_map[(2, 3)] = 2;
        let label = LabelInput {
            map: label_map.view(),
            allowed: vec![1],
        };
        let annuli = extract_annuli(mask.view(), Some(&label), Connectivity::Four, 1);
        let expected_inner = [(1, 2), (2, 1), (3, 2)];
        assert_eq!(count_true(&annuli.inner), expected_inner.len());
        for &(row, col) in &expected_inner {
            assert!(
                annuli.inner[(row, col)],
                "inner must include ({row}, {col})"
            );
        }
        assert!(!annuli.inner[(2, 3)], "label-gated cell must be excluded");
        assert!(!annuli.inner[(2, 2)], "seed cell must not be in inner");
        let expected_outer = [
            (0, 2),
            (4, 2),
            (2, 0),
            (2, 4),
            (1, 1),
            (1, 3),
            (3, 1),
            (3, 3),
        ];
        assert_eq!(count_true(&annuli.outer), expected_outer.len());
        for &(row, col) in &expected_outer {
            assert!(
                annuli.outer[(row, col)],
                "outer must include ({row}, {col})"
            );
        }
    }

    #[test]
    fn extract_annuli_without_label_keeps_full_rings() {
        let mask = single_pixel(5, 2, 2);
        let annuli = extract_annuli(mask.view(), None, Connectivity::Four, 1);
        assert_eq!(count_true(&annuli.inner), 4);
        assert_eq!(count_true(&annuli.outer), 8);
        assert!(!annuli.inner[(2, 2)], "seed cell must not be in inner");
    }

    #[test]
    fn extract_annuli_with_zero_thickness_is_empty() {
        let mask = single_pixel(5, 2, 2);
        let annuli = extract_annuli(mask.view(), None, Connectivity::Eight, 0);
        assert_eq!(count_true(&annuli.inner), 0);
        assert_eq!(count_true(&annuli.outer), 0);
    }
}