use scirs2_core::ndarray::{Array, Dimension, IxDyn};
use std::collections::HashMap;
use super::Connectivity;
use crate::error::{NdimageError, NdimageResult};
/// Disjoint-set forest over the integers `0..size`, using union by rank and path compression.
struct UnionFind {
    /// `parent[i]` is the parent of `i`; an element that is its own parent is a root.
    parent: Vec<usize>,
    /// Rank of each root, used to keep trees shallow when merging.
    rank: Vec<usize>,
}
impl UnionFind {
    /// Creates `size` singleton sets, one per element.
    fn new(size: usize) -> Self {
        Self {
            parent: (0..size).collect(),
            rank: vec![0; size],
        }
    }

    /// Returns the root of the set containing `x`.
    ///
    /// Every node visited on the way up is re-pointed directly at the root (path compression).
    ///
    /// # Panics
    /// Panics if `x` is out of range.
    fn find(&mut self, x: usize) -> usize {
        // First walk: locate the root.
        let mut root = x;
        while self.parent[root] != root {
            root = self.parent[root];
        }
        // Second walk: flatten the path so later lookups are O(1).
        let mut node = x;
        while node != root {
            let next = self.parent[node];
            self.parent[node] = root;
            node = next;
        }
        root
    }

    /// Merges the sets containing `x` and `y`; this is a no-op if they are already joined.
    ///
    /// The lower-ranked root is attached under the higher-ranked one. On a tie, `x`'s root wins
    /// and its rank increases.
    fn union(&mut self, x: usize, y: usize) {
        let (root_a, root_b) = (self.find(x), self.find(y));
        if root_a == root_b {
            return;
        }
        match self.rank[root_a].cmp(&self.rank[root_b]) {
            std::cmp::Ordering::Less => self.parent[root_a] = root_b,
            std::cmp::Ordering::Greater => self.parent[root_b] = root_a,
            std::cmp::Ordering::Equal => {
                self.parent[root_b] = root_a;
                self.rank[root_a] += 1;
            }
        }
    }

    /// Maps every root in the forest to a consecutive id starting at 1.
    ///
    /// Ids are assigned in increasing order of the first element (by index) belonging to each
    /// root. Singleton sets receive ids too.
    fn get_component_mapping(&mut self) -> HashMap<usize, usize> {
        let mut mapping = HashMap::new();
        for element in 0..self.parent.len() {
            let root = self.find(element);
            // `len() + 1` is the next unused id, because ids are handed out densely from 1.
            let candidate = mapping.len() + 1;
            mapping.entry(root).or_insert(candidate);
        }
        mapping
    }
}
/// Returns the in-bounds neighbours of `position` in a grid of extent `shape`.
///
/// - `Face`: only the ±1 step along each axis. For each axis, the lower neighbour is listed
///   before the upper one.
/// - `FaceEdge`: every offset that changes one or two coordinates.
/// - `Full`: every offset in the surrounding 3^ndim cube.
///
/// `position` itself is never included.
#[allow(dead_code)]
fn get_neighbors(
    position: &[usize],
    shape: &[usize],
    connectivity: Connectivity,
) -> Vec<Vec<usize>> {
    // Applies each offset to `origin`. Any result that leaves `0..extents[axis]` on some axis is
    // discarded, as is a result equal to `origin`.
    fn shift_all(origin: &[usize], extents: &[usize], offsets: &[Vec<isize>]) -> Vec<Vec<usize>> {
        offsets
            .iter()
            .filter_map(|delta| {
                (0..origin.len())
                    .map(|axis| {
                        let moved = origin[axis] as isize + delta[axis];
                        if moved < 0 || moved >= extents[axis] as isize {
                            None
                        } else {
                            Some(moved as usize)
                        }
                    })
                    .collect::<Option<Vec<usize>>>()
            })
            .filter(|candidate| candidate.as_slice() != origin)
            .collect()
    }

    let ndim = position.len();
    match connectivity {
        Connectivity::Face => {
            let mut found = Vec::with_capacity(2 * ndim);
            for axis in 0..ndim {
                let coord = position[axis];
                if coord > 0 {
                    let mut lower = position.to_vec();
                    lower[axis] = coord - 1;
                    found.push(lower);
                }
                if coord + 1 < shape[axis] {
                    let mut upper = position.to_vec();
                    upper[axis] = coord + 1;
                    found.push(upper);
                }
            }
            found
        }
        Connectivity::FaceEdge => shift_all(position, shape, &generate_face_edge_offsets(ndim)),
        Connectivity::Full => shift_all(position, shape, &generate_all_offsets(ndim)),
    }
}
/// Enumerates every non-zero offset in `{-1, 0, 1}^ndim`, i.e. 3^ndim - 1 vectors.
///
/// Each offset is decoded from a base-3 counter, with axis 0 as the least-significant digit.
#[allow(dead_code)]
fn generate_all_offsets(ndim: usize) -> Vec<Vec<isize>> {
    let combinations = 3_usize.pow(ndim as u32);
    (0..combinations)
        .map(|code| {
            let mut rest = code;
            (0..ndim)
                .map(|_| {
                    let step = (rest % 3) as isize - 1;
                    rest /= 3;
                    step
                })
                .collect::<Vec<isize>>()
        })
        // Drop the all-zero vector (the centre itself).
        .filter(|offset| offset.iter().any(|&step| step != 0))
        .collect()
}
#[allow(dead_code)]
fn generate_face_edge_offsets(ndim: usize) -> Vec<Vec<isize>> {
let mut offsets = Vec::new();
let total_combinations = 3_usize.pow(ndim as u32);
for i in 0..total_combinations {
let mut offset = Vec::with_capacity(ndim);
let mut temp = i;
for _ in 0..ndim {
let val = (temp % 3) as isize - 1; offset.push(val);
temp /= 3;
}
if offset.iter().all(|&x| x == 0) {
continue;
}
let non_zero_count = offset.iter().filter(|&&x| x != 0).count();
if non_zero_count <= 2 {
offsets.push(offset);
}
}
offsets
}
/// Converts a multi-dimensional index into a row-major (C-order) flat offset for `shape`.
///
/// # Panics
/// Panics if `shape` has fewer entries than `indices`.
#[allow(dead_code)]
fn ravel_index(indices: &[usize], shape: &[usize]) -> usize {
    // Walk axes from last (fastest-varying) to first, carrying the running stride.
    let (flat, _stride) = (0..indices.len())
        .rev()
        .fold((0usize, 1usize), |(acc, stride), axis| {
            (acc + indices[axis] * stride, stride * shape[axis])
        });
    flat
}
/// Converts a row-major (C-order) flat offset back into a multi-dimensional index for `shape`.
///
/// For each axis, the stride is the product of the extents of all later axes. An out-of-range
/// `flat_index` overflows into the first coordinate rather than wrapping.
#[allow(dead_code)]
fn unravel_index(flat_index: usize, shape: &[usize]) -> Vec<usize> {
    let mut remaining = flat_index;
    (0..shape.len())
        .map(|axis| {
            let stride: usize = shape[(axis + 1)..].iter().product();
            let coord = remaining / stride;
            remaining %= stride;
            coord
        })
        .collect()
}
#[allow(dead_code)]
pub fn label<D>(
input: &Array<bool, D>,
structure: Option<&Array<bool, D>>,
connectivity: Option<Connectivity>,
background: Option<bool>,
) -> NdimageResult<(Array<usize, D>, usize)>
where
D: Dimension,
{
if input.ndim() == 0 {
return Err(NdimageError::InvalidInput(
"Input array cannot be 0-dimensional".into(),
));
}
let conn = connectivity.unwrap_or(Connectivity::Face);
let bg = background.unwrap_or(false);
if let Some(struct_elem) = structure {
if struct_elem.ndim() != input.ndim() {
return Err(NdimageError::DimensionError(format!(
"Structure must have same rank as input (got {} expected {})",
struct_elem.ndim(),
input.ndim()
)));
}
}
let shape = input.shape();
let total_elements: usize = shape.iter().product();
if total_elements == 0 {
let output = Array::<usize, D>::zeros(input.raw_dim());
return Ok((output, 0));
}
let mut uf = UnionFind::new(total_elements);
let input_dyn = input.clone().into_dyn();
for flat_idx in 0..total_elements {
let indices = unravel_index(flat_idx, shape);
let current_pixel = input_dyn[IxDyn(&indices)];
if current_pixel == !bg {
let neighbors = get_neighbors(&indices, shape, conn);
for neighbor_indices in neighbors {
let neighbor_pixel = input_dyn[IxDyn(&neighbor_indices)];
if neighbor_pixel == current_pixel {
let neighbor_flat_idx = ravel_index(&neighbor_indices, shape);
uf.union(flat_idx, neighbor_flat_idx);
}
}
}
}
let component_mapping = uf.get_component_mapping();
let mut output = Array::<usize, D>::zeros(input.raw_dim());
let mut num_labels = 0;
let mut output_dyn = output.clone().into_dyn();
for flat_idx in 0..total_elements {
let indices = unravel_index(flat_idx, shape);
let pixel = input_dyn[IxDyn(&indices)];
if pixel == !bg {
let root = uf.find(flat_idx);
if let Some(&label) = component_mapping.get(&root) {
output_dyn[IxDyn(&indices)] = label;
num_labels = num_labels.max(label);
}
}
}
output = output_dyn.into_dimensionality::<D>().map_err(|_| {
NdimageError::DimensionError("Failed to convert back to original dimension type".into())
})?;
Ok((output, num_labels))
}
/// Marks the pixels of a label image that lie on an object boundary.
///
/// A pixel with label `here` is marked if any neighbour (under `connectivity`, default
/// [`Connectivity::Face`]) with label `other` satisfies the rule for `mode` (default
/// `"outer"`):
/// - `"inner"`: `here != 0 && other != here` (the object side of the boundary),
/// - `"outer"`: `here == 0 && other != 0` (the background side of the boundary),
/// - `"thick"`: `here != other` (both sides).
///
/// # Errors
/// - `NdimageError::InvalidInput` if `input` is 0-dimensional or `mode` is not one of the
///   three strings above.
/// - `NdimageError::DimensionError` if the result cannot be converted back to `D`.
#[allow(dead_code)]
pub fn find_boundaries<D>(
    input: &Array<usize, D>,
    connectivity: Option<Connectivity>,
    mode: Option<&str>,
) -> NdimageResult<Array<bool, D>>
where
    D: Dimension,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    let conn = connectivity.unwrap_or(Connectivity::Face);
    let mode_str = mode.unwrap_or("outer");
    // Resolve the mode once, up front, into a (here, other) -> bool rule.
    let crosses_edge: fn(usize, usize) -> bool = match mode_str {
        "inner" => |here: usize, other: usize| here != 0 && other != here,
        "outer" => |here: usize, other: usize| here == 0 && other != 0,
        "thick" => |here: usize, other: usize| here != other,
        _ => {
            return Err(NdimageError::InvalidInput(format!(
                "Mode must be 'inner', 'outer', or 'thick', got '{}'",
                mode_str
            )))
        }
    };
    let shape = input.shape();
    let total_elements: usize = shape.iter().product();
    if total_elements == 0 {
        return Ok(Array::<bool, D>::from_elem(input.raw_dim(), false));
    }
    let labels = input.view().into_dyn();
    let mut marked = Array::<bool, D>::from_elem(input.raw_dim(), false).into_dyn();
    for flat_idx in 0..total_elements {
        let indices = unravel_index(flat_idx, shape);
        let here = labels[IxDyn(&indices)];
        // Background can never be an inner boundary, so skip the neighbour scan entirely.
        if mode_str == "inner" && here == 0 {
            continue;
        }
        let on_edge = get_neighbors(&indices, shape, conn)
            .iter()
            .any(|nb| crosses_edge(here, labels[IxDyn(nb)]));
        if on_edge {
            marked[IxDyn(&indices)] = true;
        }
    }
    marked.into_dimensionality::<D>().map_err(|_| {
        NdimageError::DimensionError("Failed to convert back to original dimension type".into())
    })
}
/// Removes foreground components that have fewer than `min_size` pixels.
///
/// Components are found with [`label`] using `connectivity` (default
/// [`Connectivity::Face`]). A pixel survives only if it belongs to a component with at least
/// `min_size` pixels.
///
/// # Errors
/// - `NdimageError::InvalidInput` if `input` is 0-dimensional or `min_size` is 0.
/// - Any error returned by [`label`].
#[allow(dead_code)]
pub fn remove_small_objects<D>(
    input: &Array<bool, D>,
    min_size: usize,
    connectivity: Option<Connectivity>,
) -> NdimageResult<Array<bool, D>>
where
    D: Dimension,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    if min_size == 0 {
        return Err(NdimageError::InvalidInput(
            "min_size must be greater than 0".into(),
        ));
    }
    let conn = connectivity.unwrap_or(Connectivity::Face);
    let (labeled, num_labels) = label(input, None, Some(conn), None)?;
    if num_labels == 0 {
        return Ok(Array::<bool, D>::from_elem(input.raw_dim(), false));
    }
    // Pixel count per label. Slot 0 (background) is left unused so labels index directly.
    let mut sizes = vec![0usize; num_labels + 1];
    for &lbl in labeled.iter().filter(|&&lbl| lbl > 0) {
        sizes[lbl] += 1;
    }
    Ok(labeled.mapv(|lbl| lbl > 0 && sizes[lbl] >= min_size))
}
/// Fills background regions that have fewer than `min_size` pixels.
///
/// The mask is inverted, [`remove_small_objects`] is applied to the inverted image, and the
/// result is inverted back. Background regions that touch the array border are treated like
/// any other region.
///
/// # Errors
/// - `NdimageError::InvalidInput` if `input` is 0-dimensional or `min_size` is 0.
/// - Any error returned by [`remove_small_objects`].
#[allow(dead_code)]
pub fn remove_small_holes<D>(
    input: &Array<bool, D>,
    min_size: usize,
    connectivity: Option<Connectivity>,
) -> NdimageResult<Array<bool, D>>
where
    D: Dimension,
{
    if input.ndim() == 0 {
        return Err(NdimageError::InvalidInput(
            "Input array cannot be 0-dimensional".into(),
        ));
    }
    if min_size == 0 {
        return Err(NdimageError::InvalidInput(
            "min_size must be greater than 0".into(),
        ));
    }
    let conn = connectivity.unwrap_or(Connectivity::Face);
    // Holes become objects once the mask is inverted.
    let holes_as_objects = input.mapv(|px| !px);
    let kept = remove_small_objects(&holes_as_objects, min_size, Some(conn))?;
    Ok(kept.mapv_into(|px| !px))
}
/// Axis-aligned bounding box of one labelled 2-D object.
///
/// `width` and `height` are plain differences `max - min`. So, as with the boxes that
/// `find_objects_2d` builds, the `min_*` bounds act as inclusive starts and the `max_*`
/// bounds as exclusive ends.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BoundingBox2D {
    /// Label value of the enclosed object.
    pub label: usize,
    /// First row covered.
    pub min_row: usize,
    /// End row bound (`height = max_row - min_row`).
    pub max_row: usize,
    /// First column covered.
    pub min_col: usize,
    /// End column bound (`width = max_col - min_col`).
    pub max_col: usize,
}
impl BoundingBox2D {
    /// Number of columns spanned: `max_col - min_col`.
    ///
    /// # Panics
    /// With overflow checks enabled, panics if `max_col < min_col`.
    pub fn width(&self) -> usize {
        let Self { min_col, max_col, .. } = *self;
        max_col - min_col
    }

    /// Number of rows spanned: `max_row - min_row`.
    ///
    /// # Panics
    /// With overflow checks enabled, panics if `max_row < min_row`.
    pub fn height(&self) -> usize {
        let Self { min_row, max_row, .. } = *self;
        max_row - min_row
    }

    /// Pixel area of the box: `width() * height()`.
    pub fn area(&self) -> usize {
        let (w, h) = (self.width(), self.height());
        w * h
    }
}
/// Labels the connected `true` regions of a 2-D boolean image.
///
/// `Connectivity::Face` (the default) uses 4-connectivity. `Connectivity::FaceEdge` and
/// `Connectivity::Full` both use 8-connectivity. Labels are the consecutive integers
/// `1..=num_labels`, assigned in row-major order of each region's first pixel. Background
/// pixels are `0`.
///
/// # Returns
/// `(labels, num_labels)`. An empty image yields an all-zero array and `0`.
pub fn label_2d(
    input: &scirs2_core::ndarray::Array2<bool>,
    connectivity: Option<Connectivity>,
) -> NdimageResult<(scirs2_core::ndarray::Array2<usize>, usize)> {
    use scirs2_core::ndarray::Array2;
    let conn = connectivity.unwrap_or(Connectivity::Face);
    let (rows, cols) = input.dim();
    if rows == 0 || cols == 0 {
        return Ok((Array2::zeros((rows, cols)), 0));
    }
    let use_diag = matches!(conn, Connectivity::Full | Connectivity::FaceEdge);
    let flat = |r: usize, c: usize| r * cols + c;

    // Pass 1: link each foreground pixel to the already-scanned neighbours above and to the
    // left. Later neighbours link back to this pixel when they are visited.
    let mut forest = UnionFind::new(rows * cols);
    for ((r, c), &is_fg) in input.indexed_iter() {
        if !is_fg {
            continue;
        }
        let here = flat(r, c);
        let mut join_if_fg = |nr: usize, nc: usize| {
            if input[[nr, nc]] {
                forest.union(here, flat(nr, nc));
            }
        };
        if r > 0 {
            join_if_fg(r - 1, c);
        }
        if c > 0 {
            join_if_fg(r, c - 1);
        }
        if use_diag && r > 0 {
            if c > 0 {
                join_if_fg(r - 1, c - 1);
            }
            if c + 1 < cols {
                join_if_fg(r - 1, c + 1);
            }
        }
    }

    // Pass 2: give each root a dense label in scan order.
    let mut root_to_label: HashMap<usize, usize> = HashMap::new();
    let mut output = Array2::<usize>::zeros((rows, cols));
    for ((r, c), &is_fg) in input.indexed_iter() {
        if is_fg {
            let root = forest.find(flat(r, c));
            let candidate = root_to_label.len() + 1;
            output[[r, c]] = *root_to_label.entry(root).or_insert(candidate);
        }
    }
    let num_labels = root_to_label.len();
    Ok((output, num_labels))
}
/// Computes one bounding box per non-zero label in a 2-D label image.
///
/// Each box has inclusive `min_row`/`min_col` and exclusive `max_row`/`max_col` (one past
/// the last occupied row and column). The result is sorted by label. Label `0` is treated as
/// background and ignored.
pub fn find_objects_2d(
    labeled: &scirs2_core::ndarray::Array2<usize>,
) -> NdimageResult<Vec<BoundingBox2D>> {
    let (rows, cols) = labeled.dim();
    if rows == 0 || cols == 0 {
        return Ok(Vec::new());
    }
    // label -> (min_row, max_row, min_col, max_col), all inclusive while scanning.
    let mut extents: HashMap<usize, (usize, usize, usize, usize)> = HashMap::new();
    for ((r, c), &lbl) in labeled.indexed_iter() {
        if lbl == 0 {
            continue;
        }
        extents
            .entry(lbl)
            .and_modify(|span| {
                span.0 = span.0.min(r);
                span.1 = span.1.max(r);
                span.2 = span.2.min(c);
                span.3 = span.3.max(c);
            })
            .or_insert((r, r, c, c));
    }
    let mut boxes: Vec<BoundingBox2D> = extents
        .into_iter()
        .map(|(label, (row_lo, row_hi, col_lo, col_hi))| BoundingBox2D {
            label,
            min_row: row_lo,
            // Convert the inclusive maxima to exclusive ends.
            max_row: row_hi + 1,
            min_col: col_lo,
            max_col: col_hi + 1,
        })
        .collect();
    // Labels are unique keys, so an unstable sort gives the same order as a stable one.
    boxes.sort_unstable_by_key(|b| b.label);
    Ok(boxes)
}
/// Counts the pixels carrying each non-zero label in a 2-D label image.
///
/// Label `0` (background) never appears in the returned map.
pub fn count_labels_2d(labeled: &scirs2_core::ndarray::Array2<usize>) -> HashMap<usize, usize> {
    labeled
        .iter()
        .copied()
        .filter(|&lbl| lbl != 0)
        .fold(HashMap::new(), |mut tally, lbl| {
            *tally.entry(lbl).or_default() += 1;
            tally
        })
}
// Unit tests for the N-d helpers (`label`, `find_boundaries`, `remove_small_objects`) and the
// 2-D fast paths (`label_2d`, `find_objects_2d`, `count_labels_2d`).
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};
    // Smoke test: the N-d `label` output has the same shape as its input.
    #[test]
    fn test_label() {
        let input = Array2::from_elem((3, 3), true);
        let (result, _num_labels) = label(&input, None, None, None).expect("Operation failed");
        assert_eq!(result.shape(), input.shape());
    }
    // Smoke test: `find_boundaries` with default mode and connectivity preserves shape.
    #[test]
    fn test_find_boundaries() {
        let input = Array2::from_elem((3, 3), 1);
        let result = find_boundaries(&input, None, None).expect("Operation failed");
        assert_eq!(result.shape(), input.shape());
    }
    // Smoke test: `remove_small_objects` with min_size = 1 preserves shape.
    #[test]
    fn test_remove_small_objects() {
        let input = Array2::from_elem((3, 3), true);
        let result = remove_small_objects(&input, 1, None).expect("Operation failed");
        assert_eq!(result.shape(), input.shape());
    }
    // Two diagonally separated 2x2 blocks get two distinct non-zero labels under the default
    // (Face / 4-connected) connectivity.
    #[test]
    fn test_label_2d_two_components() {
        let input = array![
            [true, true, false, false],
            [true, true, false, false],
            [false, false, true, true],
            [false, false, true, true],
        ];
        let (labeled, num) = label_2d(&input, None).expect("label_2d should succeed");
        assert_eq!(num, 2);
        let l1 = labeled[[0, 0]];
        let l2 = labeled[[2, 2]];
        assert_ne!(l1, 0);
        assert_ne!(l2, 0);
        assert_ne!(l1, l2);
        assert_eq!(labeled[[0, 0]], labeled[[0, 1]]);
        assert_eq!(labeled[[0, 0]], labeled[[1, 0]]);
        assert_eq!(labeled[[0, 0]], labeled[[1, 1]]);
    }
    // A diagonal line is a single component when diagonal neighbours count (Full).
    #[test]
    fn test_label_2d_single_component_8conn() {
        let input = array![
            [true, false, false],
            [false, true, false],
            [false, false, true],
        ];
        let (labeled, num) =
            label_2d(&input, Some(Connectivity::Full)).expect("label_2d 8-conn should succeed");
        assert_eq!(num, 1);
        assert_eq!(labeled[[0, 0]], labeled[[1, 1]]);
        assert_eq!(labeled[[1, 1]], labeled[[2, 2]]);
    }
    // The same diagonal line splits into three components under Face connectivity.
    #[test]
    fn test_label_2d_multiple_components_4conn() {
        let input = array![
            [true, false, false],
            [false, true, false],
            [false, false, true],
        ];
        let (labeled, num) =
            label_2d(&input, Some(Connectivity::Face)).expect("label_2d 4-conn should succeed");
        assert_eq!(num, 3);
        let l0 = labeled[[0, 0]];
        let l1 = labeled[[1, 1]];
        let l2 = labeled[[2, 2]];
        assert_ne!(l0, l1);
        assert_ne!(l1, l2);
        assert_ne!(l0, l2);
    }
    // An all-background image has no labels and an all-zero label array.
    #[test]
    fn test_label_2d_empty() {
        let input = Array2::from_elem((3, 3), false);
        let (labeled, num) = label_2d(&input, None).expect("empty should succeed");
        assert_eq!(num, 0);
        for &v in labeled.iter() {
            assert_eq!(v, 0);
        }
    }
    // An all-foreground image is one component with a single shared label.
    #[test]
    fn test_label_2d_all_foreground() {
        let input = Array2::from_elem((4, 4), true);
        let (labeled, num) = label_2d(&input, None).expect("all-foreground should succeed");
        assert_eq!(num, 1);
        let expected_label = labeled[[0, 0]];
        for &v in labeled.iter() {
            assert_eq!(v, expected_label);
        }
    }
    // Bounding boxes use exclusive max_row / max_col and are returned sorted by label.
    #[test]
    fn test_find_objects_2d_basic() {
        let input = array![
            [true, true, false, false],
            [true, true, false, false],
            [false, false, true, true],
            [false, false, true, true],
        ];
        let (labeled, _) = label_2d(&input, None).expect("label_2d should succeed");
        let objects = find_objects_2d(&labeled).expect("find_objects_2d should succeed");
        assert_eq!(objects.len(), 2);
        let obj1 = &objects[0];
        assert_eq!(obj1.min_row, 0);
        assert_eq!(obj1.max_row, 2);
        assert_eq!(obj1.min_col, 0);
        assert_eq!(obj1.max_col, 2);
        assert_eq!(obj1.width(), 2);
        assert_eq!(obj1.height(), 2);
        let obj2 = &objects[1];
        assert_eq!(obj2.min_row, 2);
        assert_eq!(obj2.max_row, 4);
        assert_eq!(obj2.min_col, 2);
        assert_eq!(obj2.max_col, 4);
    }
    // A label image with only zeros yields no boxes.
    #[test]
    fn test_find_objects_2d_no_objects() {
        let labeled = Array2::<usize>::zeros((5, 5));
        let objects = find_objects_2d(&labeled).expect("no objects should succeed");
        assert!(objects.is_empty());
    }
    // A single labelled pixel yields a 1x1 box (area 1) with exclusive upper bounds.
    #[test]
    fn test_find_objects_2d_single_pixel() {
        let mut labeled = Array2::<usize>::zeros((5, 5));
        labeled[[2, 3]] = 1;
        let objects = find_objects_2d(&labeled).expect("single pixel should succeed");
        assert_eq!(objects.len(), 1);
        assert_eq!(objects[0].min_row, 2);
        assert_eq!(objects[0].max_row, 3);
        assert_eq!(objects[0].min_col, 3);
        assert_eq!(objects[0].max_col, 4);
        assert_eq!(objects[0].area(), 1);
    }
    // Per-label pixel counts; background label 0 is never counted.
    #[test]
    fn test_count_labels_2d() {
        let labeled = array![[0, 1, 1, 0], [0, 1, 0, 2], [3, 0, 0, 2], [3, 3, 0, 0],];
        let counts = count_labels_2d(&labeled);
        assert_eq!(counts.get(&1), Some(&3));
        assert_eq!(counts.get(&2), Some(&2));
        assert_eq!(counts.get(&3), Some(&3));
        assert_eq!(counts.get(&0), None);
    }
    // An L-shaped region is one 4-connected component.
    #[test]
    fn test_label_2d_l_shape() {
        let input = array![
            [true, false, false],
            [true, false, false],
            [true, true, true],
        ];
        let (labeled, num) = label_2d(&input, None).expect("L-shape should succeed");
        assert_eq!(num, 1);
        let expected = labeled[[0, 0]];
        assert_eq!(labeled[[1, 0]], expected);
        assert_eq!(labeled[[2, 0]], expected);
        assert_eq!(labeled[[2, 1]], expected);
        assert_eq!(labeled[[2, 2]], expected);
    }
}