use std::cmp::Ordering;
use std::collections::{BinaryHeap, HashMap};
#[derive(Debug, Clone)]
pub struct Watershed3dConfig {
pub connectivity: u8,
pub min_size: usize,
pub remove_small_regions: bool,
pub topology_preserve: bool,
}
impl Default for Watershed3dConfig {
fn default() -> Self {
Watershed3dConfig {
connectivity: 6,
min_size: 50,
remove_small_regions: true,
topology_preserve: false,
}
}
}
#[derive(Debug, Clone)]
struct PqEntry {
neg_priority: OrderedF64,
pos: (usize, usize, usize),
label: usize,
}
#[derive(Debug, Clone, Copy, PartialEq)]
struct OrderedF64(f64);
impl Eq for OrderedF64 {}
impl PartialOrd for OrderedF64 {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for OrderedF64 {
fn cmp(&self, other: &Self) -> Ordering {
self.0.partial_cmp(&other.0).unwrap_or(Ordering::Less)
}
}
impl PartialEq for PqEntry {
fn eq(&self, other: &Self) -> bool {
self.neg_priority == other.neg_priority
}
}
impl Eq for PqEntry {}
impl PartialOrd for PqEntry {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for PqEntry {
fn cmp(&self, other: &Self) -> Ordering {
self.neg_priority.cmp(&other.neg_priority)
}
}
fn neighbors_6(
d: usize,
h: usize,
w: usize,
nd: usize,
nh: usize,
nw: usize,
) -> Vec<(usize, usize, usize)> {
let mut result = Vec::with_capacity(6);
if d > 0 {
result.push((d - 1, h, w));
}
if d + 1 < nd {
result.push((d + 1, h, w));
}
if h > 0 {
result.push((d, h - 1, w));
}
if h + 1 < nh {
result.push((d, h + 1, w));
}
if w > 0 {
result.push((d, h, w - 1));
}
if w + 1 < nw {
result.push((d, h, w + 1));
}
result
}
fn neighbors_26(
d: usize,
h: usize,
w: usize,
nd: usize,
nh: usize,
nw: usize,
) -> Vec<(usize, usize, usize)> {
let mut result = Vec::new();
let d_min = if d == 0 { 0isize } else { -1isize };
let d_max = if d + 1 < nd { 1isize } else { 0isize };
let h_min = if h == 0 { 0isize } else { -1isize };
let h_max = if h + 1 < nh { 1isize } else { 0isize };
let w_min = if w == 0 { 0isize } else { -1isize };
let w_max = if w + 1 < nw { 1isize } else { 0isize };
for dd in d_min..=d_max {
for dh in h_min..=h_max {
for dw in w_min..=w_max {
if dd == 0 && dh == 0 && dw == 0 {
continue;
}
result.push((
(d as isize + dd) as usize,
(h as isize + dh) as usize,
(w as isize + dw) as usize,
));
}
}
}
result
}
#[inline]
fn get_neighbors(
d: usize,
h: usize,
w: usize,
nd: usize,
nh: usize,
nw: usize,
connectivity: u8,
) -> Vec<(usize, usize, usize)> {
if connectivity == 26 {
neighbors_26(d, h, w, nd, nh, nw)
} else {
neighbors_6(d, h, w, nd, nh, nw)
}
}
pub fn watershed_3d(
image: &[Vec<Vec<f64>>],
markers: &[Vec<Vec<i64>>],
config: &Watershed3dConfig,
) -> Vec<Vec<Vec<usize>>> {
let nd = image.len();
if nd == 0 {
return vec![];
}
let nh = image[0].len();
let nw = if nh > 0 { image[0][0].len() } else { 0 };
let flat = |d: usize, h: usize, w: usize| d * nh * nw + h * nw + w;
let mut labels = vec![0usize; nd * nh * nw];
let mut in_queue = vec![false; nd * nh * nw];
let mut heap: BinaryHeap<PqEntry> = BinaryHeap::new();
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
let m = markers[d][h][w];
if m > 0 {
let lbl = m as usize;
let fi = flat(d, h, w);
labels[fi] = lbl;
in_queue[fi] = true;
let priority = image[d][h][w];
heap.push(PqEntry {
neg_priority: OrderedF64(-priority),
pos: (d, h, w),
label: lbl,
});
}
}
}
}
while let Some(entry) = heap.pop() {
let (d, h, w) = entry.pos;
let lbl = entry.label;
let fi = flat(d, h, w);
if labels[fi] != lbl && labels[fi] != 0 {
continue;
}
let nbrs = get_neighbors(d, h, w, nd, nh, nw, config.connectivity);
for (nd2, nh2, nw2) in nbrs {
let nfi = flat(nd2, nh2, nw2);
if labels[nfi] != 0 {
continue; }
labels[nfi] = lbl;
if !in_queue[nfi] {
in_queue[nfi] = true;
let priority = image[nd2][nh2][nw2];
heap.push(PqEntry {
neg_priority: OrderedF64(-priority),
pos: (nd2, nh2, nw2),
label: lbl,
});
}
}
}
let mut result = vec![vec![vec![0usize; nw]; nh]; nd];
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
result[d][h][w] = labels[flat(d, h, w)];
}
}
}
if config.remove_small_regions && config.min_size > 0 {
merge_small_regions_3d(
&mut result,
config.min_size,
config.connectivity,
nd,
nh,
nw,
);
}
result
}
fn merge_small_regions_3d(
labels: &mut Vec<Vec<Vec<usize>>>,
min_size: usize,
connectivity: u8,
nd: usize,
nh: usize,
nw: usize,
) {
let mut size_map: HashMap<usize, usize> = HashMap::new();
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
*size_map.entry(labels[d][h][w]).or_insert(0) += 1;
}
}
}
let mut adjacency: HashMap<usize, HashMap<usize, usize>> = HashMap::new();
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
let lbl = labels[d][h][w];
if lbl == 0 {
continue;
}
let sz = *size_map.get(&lbl).unwrap_or(&0);
if sz < min_size {
let nbrs = get_neighbors(d, h, w, nd, nh, nw, connectivity);
for (nd2, nh2, nw2) in nbrs {
let nlbl = labels[nd2][nh2][nw2];
if nlbl != 0 && nlbl != lbl {
let entry = adjacency.entry(lbl).or_default();
*entry.entry(nlbl).or_insert(0) += 1;
}
}
}
}
}
}
let mut remap: HashMap<usize, usize> = HashMap::new();
for (small_lbl, neighbours) in &adjacency {
let sz = *size_map.get(small_lbl).unwrap_or(&0);
if sz >= min_size {
continue;
}
if let Some((&target, _)) = neighbours.iter().max_by_key(|(_, &cnt)| cnt) {
remap.insert(*small_lbl, target);
}
}
if !remap.is_empty() {
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
let lbl = labels[d][h][w];
if let Some(&new_lbl) = remap.get(&lbl) {
labels[d][h][w] = new_lbl;
}
}
}
}
}
}
pub fn find_watershed_lines(labels: &[Vec<Vec<usize>>]) -> Vec<Vec<Vec<bool>>> {
let nd = labels.len();
if nd == 0 {
return vec![];
}
let nh = labels[0].len();
let nw = if nh > 0 { labels[0][0].len() } else { 0 };
let mut lines = vec![vec![vec![false; nw]; nh]; nd];
for d in 0..nd {
for h in 0..nh {
for w in 0..nw {
let lbl = labels[d][h][w];
if lbl == 0 {
continue;
}
let nbrs = neighbors_6(d, h, w, nd, nh, nw);
for (nd2, nh2, nw2) in nbrs {
let nlbl = labels[nd2][nh2][nw2];
if nlbl != 0 && nlbl != lbl {
lines[d][h][w] = true;
break;
}
}
}
}
}
lines
}
pub fn topological_number_3d(
labels: &[Vec<Vec<usize>>],
pos: (usize, usize, usize),
label: usize,
) -> (u8, u8) {
let nd = labels.len();
if nd == 0 {
return (0, 0);
}
let nh = labels[0].len();
let nw = if nh > 0 { labels[0][0].len() } else { 0 };
let (d, h, w) = pos;
let d_min = d.saturating_sub(1);
let d_max = (d + 1).min(nd - 1);
let h_min = h.saturating_sub(1);
let h_max = (h + 1).min(nh - 1);
let w_min = w.saturating_sub(1);
let w_max = (w + 1).min(nw - 1);
let mut fg_voxels: Vec<(usize, usize, usize)> = Vec::new();
for dd in d_min..=d_max {
for dh in h_min..=h_max {
for dw in w_min..=w_max {
if dd == d && dh == h && dw == w {
continue;
}
if labels[dd][dh][dw] == label {
fg_voxels.push((dd, dh, dw));
}
}
}
}
let t_fore = count_components_in_ball(&fg_voxels);
let mut bg_voxels: Vec<(usize, usize, usize)> = Vec::new();
let face_offsets: &[(isize, isize, isize)] = &[
(-1, 0, 0),
(1, 0, 0),
(0, -1, 0),
(0, 1, 0),
(0, 0, -1),
(0, 0, 1),
];
for &(dd, dh, dw) in face_offsets {
let nd2 = d as isize + dd;
let nh2 = h as isize + dh;
let nw2 = w as isize + dw;
if nd2 >= 0
&& nd2 < nd as isize
&& nh2 >= 0
&& nh2 < nh as isize
&& nw2 >= 0
&& nw2 < nw as isize
{
let nd2 = nd2 as usize;
let nh2 = nh2 as usize;
let nw2 = nw2 as usize;
if labels[nd2][nh2][nw2] != label {
bg_voxels.push((nd2, nh2, nw2));
}
}
}
let t_back = count_components_in_ball(&bg_voxels);
(t_fore as u8, t_back as u8)
}
fn count_components_in_ball(voxels: &[(usize, usize, usize)]) -> usize {
if voxels.is_empty() {
return 0;
}
let n = voxels.len();
let mut visited = vec![false; n];
let mut count = 0usize;
for start in 0..n {
if visited[start] {
continue;
}
count += 1;
let mut queue = std::collections::VecDeque::new();
queue.push_back(start);
visited[start] = true;
while let Some(idx) = queue.pop_front() {
let (d, h, w) = voxels[idx];
for other in 0..n {
if visited[other] {
continue;
}
let (od, oh, ow) = voxels[other];
let max_diff = (d as isize - od as isize)
.abs()
.max((h as isize - oh as isize).abs())
.max((w as isize - ow as isize).abs());
if max_diff <= 1 {
visited[other] = true;
queue.push_back(other);
}
}
}
}
count
}
#[cfg(test)]
mod tests {
use super::*;
fn make_two_basin_image() -> (Vec<Vec<Vec<f64>>>, Vec<Vec<Vec<i64>>>) {
let nd = 5;
let nh = 5;
let nw = 5;
let mut image = vec![vec![vec![0.5f64; nw]; nh]; nd];
let mut markers = vec![vec![vec![0i64; nw]; nh]; nd];
for h in 0..nh {
for w in 0..nw {
image[0][h][w] = 0.1;
image[1][h][w] = 0.2;
}
}
for h in 0..nh {
for w in 0..nw {
image[2][h][w] = 1.0;
}
}
for h in 0..nh {
for w in 0..nw {
image[3][h][w] = 0.1;
image[4][h][w] = 0.2;
}
}
markers[0][2][2] = 1;
markers[4][2][2] = 2;
(image, markers)
}
#[test]
fn test_watershed_3d_two_basins() {
let (image, markers) = make_two_basin_image();
let config = Watershed3dConfig {
connectivity: 6,
min_size: 1,
remove_small_regions: false,
topology_preserve: false,
};
let labels = watershed_3d(&image, &markers, &config);
assert_eq!(labels.len(), 5);
assert_eq!(labels[0][2][2], 1);
assert_eq!(labels[4][2][2], 2);
let mut has_1 = false;
let mut has_2 = false;
for d in 0..5 {
for h in 0..5 {
for w in 0..5 {
match labels[d][h][w] {
1 => has_1 = true,
2 => has_2 = true,
_ => {}
}
}
}
}
assert!(has_1, "Expected label 1 in result");
assert!(has_2, "Expected label 2 in result");
}
#[test]
fn test_watershed_lines() {
let (image, markers) = make_two_basin_image();
let config = Watershed3dConfig {
connectivity: 6,
min_size: 1,
remove_small_regions: false,
topology_preserve: false,
};
let labels = watershed_3d(&image, &markers, &config);
let lines = find_watershed_lines(&labels);
let has_lines = lines
.iter()
.any(|plane| plane.iter().any(|row| row.iter().any(|&v| v)));
assert!(has_lines, "Expected watershed lines between basins");
}
#[test]
fn test_min_size_merging() {
let nd = 3;
let nh = 3;
let nw = 3;
let image = vec![vec![vec![0.5f64; nw]; nh]; nd];
let mut markers = vec![vec![vec![0i64; nw]; nh]; nd];
markers[0][0][0] = 1;
markers[2][2][2] = 2;
let config = Watershed3dConfig {
connectivity: 6,
min_size: 5, remove_small_regions: true,
topology_preserve: false,
};
let labels = watershed_3d(&image, &markers, &config);
let has_2 = labels
.iter()
.any(|plane| plane.iter().any(|row| row.iter().any(|&v| v == 2)));
assert!(!labels.is_empty(), "Labels should not be empty");
let _ = has_2; }
#[test]
fn test_topological_number_3d() {
let nd = 3;
let nh = 3;
let nw = 3;
let mut labels = vec![vec![vec![1usize; nw]; nh]; nd];
labels[1][1][1] = 0;
let (t_fore, t_back) = topological_number_3d(&labels, (1, 1, 1), 1);
assert_eq!(
t_back, 0,
"No background neighbours in 6-ball of solid cube"
);
assert!(t_fore >= 1, "Should have foreground component(s)");
}
#[test]
fn test_watershed_3d_single_marker() {
let nd = 2;
let nh = 2;
let nw = 2;
let image = vec![vec![vec![0.5f64; nw]; nh]; nd];
let mut markers = vec![vec![vec![0i64; nw]; nh]; nd];
markers[0][0][0] = 1;
let config = Watershed3dConfig::default();
let labels = watershed_3d(&image, &markers, &config);
let all_labelled = labels
.iter()
.all(|plane| plane.iter().all(|row| row.iter().all(|&v| v == 1)));
assert!(
all_labelled,
"All voxels should be label 1 with single marker"
);
}
}