use crate::error::{NdimageError, NdimageResult};
use scirs2_core::ndarray::{Array, Array2, Ix2};
use scirs2_core::numeric::{Float, NumAssign};
use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
/// Pixel adjacency used by the watershed routines in this module.
///
/// The default is [`WatershedConnectivity::Eight`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum WatershedConnectivity {
    /// 4-connectivity: the pixels directly above, below, left and right.
    Four,
    /// 8-connectivity: the 4-connected pixels plus the four diagonal pixels.
    #[default]
    Eight,
}
#[derive(Debug, Clone)]
pub struct WatershedConfig {
pub connectivity: WatershedConnectivity,
pub watershed_line: bool,
pub compact_labels: bool,
}
impl Default for WatershedConfig {
fn default() -> Self {
WatershedConfig {
connectivity: WatershedConnectivity::Eight,
watershed_line: false,
compact_labels: false,
}
}
}
/// Output value for pixels that lie on a watershed line between regions.
const WATERSHED_LABEL: i32 = -1;
/// Internal sentinel for pixels currently waiting in the flooding queue.
/// It is cleared before a label image is returned.
const IN_QUEUE: i32 = WATERSHED_LABEL - 1;
/// A pixel waiting to be flooded.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed: the entry with the
/// *lowest* `priority` compares greatest and is popped first. Equal priorities
/// are broken by `order`, with the lower insertion counter popped first (FIFO).
#[derive(Clone, Debug)]
struct QueueEntry {
    /// Row of the queued pixel.
    row: usize,
    /// Column of the queued pixel.
    col: usize,
    /// Flooding priority; lower values are flooded first.
    priority: f64,
    /// Monotonic insertion counter, used as a stable tie-breaker.
    order: u64,
}

impl PartialEq for QueueEntry {
    fn eq(&self, other: &Self) -> bool {
        // `Eq` must agree with `Ord`; comparing row/col here (as before) made two
        // entries "equal" while `cmp` reported them as ordered.
        self.cmp(other) == Ordering::Equal
    }
}

impl Eq for QueueEntry {}

impl PartialOrd for QueueEntry {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for QueueEntry {
    fn cmp(&self, other: &Self) -> Ordering {
        // Compare `other` against `self` so the max-heap pops the smallest priority.
        // NaN is ranked above every number (and equal to other NaNs), which keeps
        // this a lawful total order. A bare `partial_cmp(..).unwrap_or(Equal)` is
        // not transitive with NaN and can corrupt the heap. -0.0 and +0.0 still tie.
        let by_priority = other
            .priority
            .is_nan()
            .cmp(&self.priority.is_nan())
            .then_with(|| {
                other
                    .priority
                    .partial_cmp(&self.priority)
                    .unwrap_or(Ordering::Equal)
            });
        by_priority.then_with(|| other.order.cmp(&self.order))
    }
}
/// Returns the `(d_row, d_col)` neighbor offsets for `connectivity`, listed in
/// row-major order.
///
/// The order is significant. Callers iterate it when queuing neighbors, so it
/// determines the insertion order of equally-valued pixels.
fn get_offsets(connectivity: WatershedConnectivity) -> &'static [(isize, isize)] {
    static FOUR_NEIGHBORS: [(isize, isize); 4] = [(-1, 0), (0, -1), (0, 1), (1, 0)];
    static EIGHT_NEIGHBORS: [(isize, isize); 8] = [
        (-1, -1),
        (-1, 0),
        (-1, 1),
        (0, -1),
        (0, 1),
        (1, -1),
        (1, 0),
        (1, 1),
    ];
    match connectivity {
        WatershedConnectivity::Four => &FOUR_NEIGHBORS,
        WatershedConnectivity::Eight => &EIGHT_NEIGHBORS,
    }
}
/// Returns `true` when the signed coordinate `(r, c)` lies inside a
/// `rows` x `cols` grid.
#[inline]
fn in_bounds(r: isize, c: isize, rows: usize, cols: usize) -> bool {
    let row_ok = (0..rows as isize).contains(&r);
    let col_ok = (0..cols as isize).contains(&c);
    row_ok && col_ok
}
/// Marker-controlled watershed using the default [`WatershedConfig`].
///
/// This is shorthand for
/// `watershed_with_config(image, markers, &WatershedConfig::default())`.
///
/// # Errors
/// Propagates any error from [`watershed_with_config`], such as a shape
/// mismatch between `image` and `markers`.
pub fn watershed<T>(
    image: &Array<T, Ix2>,
    markers: &Array<i32, Ix2>,
) -> NdimageResult<Array<i32, Ix2>>
where
    T: Float + NumAssign + std::fmt::Debug + 'static,
{
    let defaults = WatershedConfig::default();
    watershed_with_config(image, markers, &defaults)
}
/// Marker-controlled watershed with connectivity given as an integer.
///
/// `connectivity` must be `1` (4-connected) or `2` (8-connected). The flood
/// runs through [`watershed_with_config`] with both `watershed_line` and
/// `compact_labels` disabled.
///
/// # Errors
/// Returns `NdimageError::InvalidInput` when `connectivity` is neither 1 nor 2.
/// Otherwise it propagates any error from [`watershed_with_config`].
pub fn marker_watershed<T>(
    image: &Array<T, Ix2>,
    markers: &Array<i32, Ix2>,
    connectivity: usize,
) -> NdimageResult<Array<i32, Ix2>>
where
    T: Float + NumAssign + std::fmt::Debug + 'static,
{
    let neighborhood = match connectivity {
        1 => WatershedConnectivity::Four,
        2 => WatershedConnectivity::Eight,
        _ => {
            return Err(NdimageError::InvalidInput(
                "Connectivity must be 1 (4-connected) or 2 (8-connected)".to_string(),
            ))
        }
    };
    watershed_with_config(
        image,
        markers,
        &WatershedConfig {
            connectivity: neighborhood,
            watershed_line: false,
            compact_labels: false,
        },
    )
}
pub fn watershed_with_config<T>(
image: &Array<T, Ix2>,
markers: &Array<i32, Ix2>,
config: &WatershedConfig,
) -> NdimageResult<Array<i32, Ix2>>
where
T: Float + NumAssign + std::fmt::Debug + 'static,
{
if image.shape() != markers.shape() {
return Err(NdimageError::DimensionError(
"Input image and markers must have the same shape".to_string(),
));
}
let rows = image.nrows();
let cols = image.ncols();
if rows == 0 || cols == 0 {
return Ok(markers.clone());
}
let offsets = get_offsets(config.connectivity);
let mut output = markers.clone();
let mut queue = BinaryHeap::new();
let mut insertion_order: u64 = 0;
for r in 0..rows {
for c in 0..cols {
let marker = markers[[r, c]];
if marker > 0 {
for &(dr, dc) in offsets {
let nr = r as isize + dr;
let nc = c as isize + dc;
if in_bounds(nr, nc, rows, cols) {
let nr = nr as usize;
let nc = nc as usize;
if output[[nr, nc]] == 0 {
output[[nr, nc]] = IN_QUEUE;
let priority = image[[nr, nc]].to_f64().unwrap_or(f64::INFINITY);
queue.push(QueueEntry {
row: nr,
col: nc,
priority,
order: insertion_order,
});
insertion_order += 1;
}
}
}
}
}
}
while let Some(entry) = queue.pop() {
let r = entry.row;
let c = entry.col;
let mut neighbor_labels: HashMap<i32, usize> = HashMap::new();
let mut _has_watershed_neighbor = false;
for &(dr, dc) in offsets {
let nr = r as isize + dr;
let nc = c as isize + dc;
if in_bounds(nr, nc, rows, cols) {
let nr = nr as usize;
let nc = nc as usize;
let label = output[[nr, nc]];
if label > 0 {
*neighbor_labels.entry(label).or_insert(0) += 1;
} else if label == WATERSHED_LABEL {
_has_watershed_neighbor = true;
}
}
}
if neighbor_labels.is_empty() {
output[[r, c]] = 0;
continue;
}
let distinct_labels: Vec<i32> = neighbor_labels.keys().copied().collect();
if config.watershed_line && distinct_labels.len() > 1 {
output[[r, c]] = WATERSHED_LABEL;
} else {
let best_label = neighbor_labels
.iter()
.max_by_key(|&(_, count)| count)
.map(|(&lbl, _)| lbl)
.unwrap_or(0);
if best_label > 0 {
output[[r, c]] = best_label;
} else {
output[[r, c]] = 0;
continue;
}
}
for &(dr, dc) in offsets {
let nr = r as isize + dr;
let nc = c as isize + dc;
if in_bounds(nr, nc, rows, cols) {
let nr = nr as usize;
let nc = nc as usize;
if output[[nr, nc]] == 0 {
output[[nr, nc]] = IN_QUEUE;
let priority = image[[nr, nc]].to_f64().unwrap_or(f64::INFINITY);
queue.push(QueueEntry {
row: nr,
col: nc,
priority,
order: insertion_order,
});
insertion_order += 1;
}
}
}
}
for val in output.iter_mut() {
if *val == IN_QUEUE {
*val = 0;
}
}
Ok(output)
}
/// Segments the foreground of a binary image by running a watershed on its distance map.
///
/// 1. For each foreground (`true`) pixel, computes the city-block (L1) distance
///    to the nearest background pixel with a two-pass scan. Pixels outside the
///    image are not treated as background.
/// 2. Seeds markers at foreground local maxima of that distance (no neighbor
///    under `connectivity` is strictly larger) whose distance is at least
///    `min_distance`. Connected groups of such maxima (plateaus) share a single
///    label, so a flat ridge yields one region rather than one region per pixel.
///    Labels are numbered from 1 in row-major order of each group's first pixel.
/// 3. Floods the negated distance map from those markers with
///    [`watershed_with_config`] (no watershed lines), then sets every background
///    pixel back to 0.
///
/// The type parameter `T` is not used by the computation; the distance map is
/// always `f64`. Callers must still name it, e.g. `watershed_from_distance::<f64>(...)`.
///
/// # Errors
/// Propagates any error from [`watershed_with_config`].
pub fn watershed_from_distance<T>(
    binary_image: &Array2<bool>,
    connectivity: WatershedConnectivity,
    min_distance: usize,
) -> NdimageResult<Array2<i32>>
where
    T: Float + NumAssign + std::fmt::Debug + 'static,
{
    let rows = binary_image.nrows();
    let cols = binary_image.ncols();
    if rows == 0 || cols == 0 {
        return Ok(Array2::zeros((rows, cols)));
    }
    let distance = city_block_distance(binary_image);
    let offsets = get_offsets(connectivity);
    let markers = label_distance_peaks(&distance, binary_image, offsets, min_distance as f64);
    let neg_distance: Array2<f64> = distance.mapv(|v| -v);
    let config = WatershedConfig {
        connectivity,
        watershed_line: false,
        compact_labels: false,
    };
    let mut result = watershed_with_config(&neg_distance, &markers, &config)?;
    // The flood also labels background pixels; keep labels only on the foreground.
    result.zip_mut_with(binary_image, |label, &is_foreground| {
        if !is_foreground {
            *label = 0;
        }
    });
    Ok(result)
}

/// Two-pass city-block (L1) distance from each `true` pixel to the nearest
/// `false` pixel. `false` pixels get 0. The forward pass looks up/left and the
/// backward pass looks down/right. A pixel with no earlier neighbor in the
/// forward pass starts from `rows + cols`, which exceeds any in-image L1 distance.
fn city_block_distance(binary_image: &Array2<bool>) -> Array2<f64> {
    let rows = binary_image.nrows();
    let cols = binary_image.ncols();
    let fallback = (rows + cols) as f64;
    let mut distance = Array2::<f64>::zeros((rows, cols));
    for r in 0..rows {
        for c in 0..cols {
            if !binary_image[[r, c]] {
                continue;
            }
            let above = if r > 0 {
                distance[[r - 1, c]] + 1.0
            } else {
                f64::INFINITY
            };
            let left = if c > 0 {
                distance[[r, c - 1]] + 1.0
            } else {
                f64::INFINITY
            };
            let d = above.min(left);
            distance[[r, c]] = if d.is_finite() { d } else { fallback };
        }
    }
    for r in (0..rows).rev() {
        for c in (0..cols).rev() {
            if !binary_image[[r, c]] {
                continue;
            }
            let mut d = distance[[r, c]];
            if r + 1 < rows {
                d = d.min(distance[[r + 1, c]] + 1.0);
            }
            if c + 1 < cols {
                d = d.min(distance[[r, c + 1]] + 1.0);
            }
            distance[[r, c]] = d;
        }
    }
    distance
}

/// Labels connected groups of foreground local maxima of `distance`.
/// Maxima below `min_distance` are ignored. Adjacent maxima are necessarily
/// equal-valued, so each group is one plateau and receives a single label.
fn label_distance_peaks(
    distance: &Array2<f64>,
    foreground: &Array2<bool>,
    offsets: &[(isize, isize)],
    min_distance: f64,
) -> Array2<i32> {
    let rows = distance.nrows();
    let cols = distance.ncols();

    let mut is_peak = Array2::<bool>::from_elem((rows, cols), false);
    for r in 0..rows {
        for c in 0..cols {
            if !foreground[[r, c]] {
                continue;
            }
            let val = distance[[r, c]];
            if val < min_distance {
                continue;
            }
            let dominated = offsets.iter().any(|&(dr, dc)| {
                let nr = r as isize + dr;
                let nc = c as isize + dc;
                in_bounds(nr, nc, rows, cols) && distance[[nr as usize, nc as usize]] > val
            });
            is_peak[[r, c]] = !dominated;
        }
    }

    let mut markers = Array2::<i32>::zeros((rows, cols));
    let mut next_label = 1i32;
    let mut stack: Vec<(usize, usize)> = Vec::new();
    for r in 0..rows {
        for c in 0..cols {
            if !is_peak[[r, c]] || markers[[r, c]] != 0 {
                continue;
            }
            // Depth-first fill of the plateau containing (r, c).
            markers[[r, c]] = next_label;
            stack.push((r, c));
            while let Some((pr, pc)) = stack.pop() {
                for &(dr, dc) in offsets {
                    let nr = pr as isize + dr;
                    let nc = pc as isize + dc;
                    if in_bounds(nr, nc, rows, cols) {
                        let (nr, nc) = (nr as usize, nc as usize);
                        if is_peak[[nr, nc]] && markers[[nr, nc]] == 0 {
                            markers[[nr, nc]] = next_label;
                            stack.push((nr, nc));
                        }
                    }
                }
            }
            next_label += 1;
        }
    }
    markers
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::array;

    #[test]
    fn test_watershed_basic() {
        let image = array![[0.5, 0.6, 0.7], [0.4, 0.1, 0.2], [0.3, 0.4, 0.5],];
        let markers = array![[0, 0, 0], [0, 1, 0], [0, 0, 2],];
        let result = watershed(&image, &markers).expect("watershed should succeed");
        assert_eq!(result[[1, 1]], 1);
        assert_eq!(result[[2, 2]], 2);
    }

    #[test]
    fn test_watershed_shape_mismatch() {
        let image = array![[0.5, 0.6], [0.4, 0.1],];
        let markers = array![[0, 0, 0], [0, 1, 0],];
        let result = watershed(&image, &markers);
        assert!(result.is_err());
    }

    #[test]
    fn test_marker_watershed_4_connectivity() {
        let image = array![
            [5.0, 5.0, 9.0, 5.0, 5.0],
            [5.0, 3.0, 9.0, 3.0, 5.0],
            [5.0, 1.0, 9.0, 1.0, 5.0],
            [5.0, 3.0, 9.0, 3.0, 5.0],
            [5.0, 5.0, 9.0, 5.0, 5.0],
        ];
        let markers = array![
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 0, 2, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ];
        let result =
            marker_watershed(&image, &markers, 1).expect("marker_watershed should succeed");
        assert_eq!(result[[2, 1]], 1);
        assert_eq!(result[[2, 3]], 2);
    }

    #[test]
    fn test_marker_watershed_8_connectivity() {
        let image = array![
            [5.0, 5.0, 9.0, 5.0, 5.0],
            [5.0, 3.0, 9.0, 3.0, 5.0],
            [5.0, 1.0, 9.0, 1.0, 5.0],
            [5.0, 3.0, 9.0, 3.0, 5.0],
            [5.0, 5.0, 9.0, 5.0, 5.0],
        ];
        let markers = array![
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
            [0, 1, 0, 2, 0],
            [0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0],
        ];
        let result =
            marker_watershed(&image, &markers, 2).expect("marker_watershed should succeed");
        assert_eq!(result[[2, 1]], 1);
        assert_eq!(result[[2, 3]], 2);
    }

    #[test]
    fn test_marker_watershed_invalid_connectivity() {
        let image = array![[1.0, 2.0], [3.0, 4.0],];
        let markers = array![[1, 0], [0, 2],];
        let result = marker_watershed(&image, &markers, 3);
        assert!(result.is_err());
    }

    #[test]
    fn test_watershed_with_config_watershed_line() {
        let image = array![
            [1.0, 2.0, 9.0, 2.0, 1.0],
            [1.0, 2.0, 9.0, 2.0, 1.0],
            [1.0, 2.0, 9.0, 2.0, 1.0],
        ];
        let markers = array![[1, 0, 0, 0, 2], [0, 0, 0, 0, 0], [1, 0, 0, 0, 2],];
        let config = WatershedConfig {
            connectivity: WatershedConnectivity::Four,
            watershed_line: true,
            compact_labels: false,
        };
        let result = watershed_with_config(&image, &markers, &config)
            .expect("watershed with line should succeed");
        assert_eq!(result[[0, 0]], 1);
        assert_eq!(result[[0, 4]], 2);
        for r in 0..3 {
            // Each side of the ridge belongs to its own basin...
            assert_eq!(result[[r, 1]], 1);
            assert_eq!(result[[r, 3]], 2);
            // ...and the high ridge in column 2 is where the basins meet.
            assert_eq!(result[[r, 2]], WATERSHED_LABEL);
        }
    }

    #[test]
    fn test_watershed_single_marker_floods_all() {
        let image = array![[1.0, 2.0, 3.0], [4.0, 5.0, 6.0], [7.0, 8.0, 9.0],];
        let markers = array![[1, 0, 0], [0, 0, 0], [0, 0, 0],];
        let result = watershed(&image, &markers).expect("watershed should succeed");
        for val in result.iter() {
            assert_eq!(*val, 1);
        }
    }

    #[test]
    fn test_watershed_empty_image() {
        let image: Array2<f64> = Array2::zeros((0, 0));
        let markers: Array2<i32> = Array2::zeros((0, 0));
        let result = watershed(&image, &markers).expect("empty should succeed");
        assert_eq!(result.len(), 0);
    }

    #[test]
    fn test_watershed_all_foreground_markers() {
        let image = array![[1.0, 2.0], [3.0, 4.0],];
        let markers = array![[1, 2], [3, 4],];
        let result = watershed(&image, &markers).expect("watershed should succeed");
        assert_eq!(result, markers);
    }

    #[test]
    fn test_watershed_from_distance_basic() {
        let binary = array![
            [true, true, true, false, false, true, true, true],
            [true, true, true, false, false, true, true, true],
            [true, true, true, false, false, true, true, true],
            [false, false, false, false, false, false, false, false],
            [true, true, true, false, false, true, true, true],
            [true, true, true, false, false, true, true, true],
            [true, true, true, false, false, true, true, true],
        ];
        let result = watershed_from_distance::<f64>(&binary, WatershedConnectivity::Four, 1)
            .expect("watershed_from_distance should succeed");
        assert_eq!(result[[0, 3]], 0);
        assert_eq!(result[[3, 0]], 0);
        // Background stays 0 and every foreground pixel receives a label.
        for ((r, c), &is_foreground) in binary.indexed_iter() {
            if is_foreground {
                assert!(result[[r, c]] > 0, "foreground pixel ({}, {}) unlabeled", r, c);
            } else {
                assert_eq!(result[[r, c]], 0);
            }
        }
        // Blobs separated by background never share a label.
        assert_ne!(result[[0, 0]], result[[0, 7]]);
        assert_ne!(result[[0, 0]], result[[6, 0]]);
        assert_ne!(result[[0, 7]], result[[6, 7]]);
    }
}