use crate::core::{Pix, PixelDepth};
use crate::region::conncomp::ConnectivityType;
use crate::region::error::{RegionError, RegionResult};
use std::collections::BinaryHeap;
/// Tuning knobs for [`watershed_segmentation`].
///
/// Build with [`WatershedOptions::new`] (or `Default`) and chain the `with_*`
/// methods. Defaults: `min_depth = 1`, `connectivity = ConnectivityType::EightWay`.
#[derive(Debug, Clone)]
pub struct WatershedOptions {
    /// Minimum-depth setting read by [`watershed_segmentation`].
    pub min_depth: u32,
    /// Pixel adjacency used for minima detection, flooding and boundary marking.
    pub connectivity: ConnectivityType,
}
impl Default for WatershedOptions {
    fn default() -> Self {
        Self::new()
    }
}
impl WatershedOptions {
    /// Creates options with the defaults (`min_depth = 1`, eight-way connectivity).
    pub fn new() -> Self {
        WatershedOptions {
            min_depth: 1,
            connectivity: ConnectivityType::EightWay,
        }
    }
    /// Returns these options with `min_depth` replaced by `depth`.
    pub fn with_min_depth(self, depth: u32) -> Self {
        WatershedOptions {
            min_depth: depth,
            ..self
        }
    }
    /// Returns these options with `connectivity` replaced.
    pub fn with_connectivity(self, connectivity: ConnectivityType) -> Self {
        WatershedOptions {
            connectivity,
            ..self
        }
    }
}
/// A queued pixel in the priority flood: its coordinates and gray value.
///
/// `BinaryHeap` is a max-heap, so the ordering is reversed to pop the *lowest*
/// value first. Equal values are broken by raster position (smaller `y`, then
/// smaller `x`, pops first). This makes the ordering total and consistent with
/// the derived `PartialEq`/`Eq`, as the `Ord` contract requires, and makes the
/// flood order deterministic.
#[derive(Debug, Clone, Eq, PartialEq)]
struct PixelEntry {
    x: u32,
    y: u32,
    value: u32,
}
impl Ord for PixelEntry {
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        // Every key is compared as `other` vs `self`: reversed so the max-heap
        // behaves as a min-heap on (value, y, x).
        other
            .value
            .cmp(&self.value)
            .then_with(|| other.y.cmp(&self.y))
            .then_with(|| other.x.cmp(&self.x))
    }
}
impl PartialOrd for PixelEntry {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Value written to the output label image for boundary pixels. Basin labels
/// are assigned starting at 1, so 0 never collides with a basin.
const WATERSHED: u32 = 0;
/// Sentinel in the working label array for pixels not yet assigned to a basin.
const UNLABELED: u32 = u32::MAX;
/// Reserved marker value. In this file it is only ever compared against and is
/// never stored. NOTE(review): candidate for removal.
const IN_QUEUE: u32 = u32::MAX - 1;
/// Segments an 8-bit grayscale image into catchment basins with a
/// priority-flood watershed.
///
/// Every pixel returned by [`find_local_minima`] seeds its own basin. Pixels are
/// then flooded in order of increasing gray value. Each unlabeled neighbor of a
/// popped pixel joins that pixel's basin.
///
/// When two basins touch, the spill level of that connection is the higher of
/// the two pixel values. If either basin's depth at that level (spill level
/// minus the basin's lowest value) is below `options.min_depth`, the shallower
/// basin is merged into the deeper one. A `min_depth` of 0 therefore never
/// merges. The default of 1 merges only when the spill level equals one
/// basin's minimum, for example several seeds on the rim of one flat plateau.
///
/// Surviving basins are renumbered `1..=k` in ascending order of their seed
/// label, so when nothing merges the labels are exactly the seed order.
///
/// # Returns
/// A 32-bpp label image. A pixel with a neighbor (per `options.connectivity`)
/// in a different basin is written as `WATERSHED` (0), as is any pixel no seed
/// reached. For example, a flat image with more than one pixel has no seeds.
///
/// # Errors
/// - `RegionError::UnsupportedDepth` if `pix` is not 8 bpp.
/// - `RegionError::EmptyImage` if either dimension is zero.
/// - `RegionError::Core` if the output image cannot be allocated.
pub fn watershed_segmentation(pix: &Pix, options: &WatershedOptions) -> RegionResult<Pix> {
    if pix.depth() != PixelDepth::Bit8 {
        return Err(RegionError::UnsupportedDepth {
            expected: "8-bit",
            actual: pix.depth().bits(),
        });
    }
    let width = pix.width();
    let height = pix.height();
    if width == 0 || height == 0 {
        return Err(RegionError::EmptyImage);
    }
    // Index in usize so large images cannot overflow a u32 `width * height`.
    let row = width as usize;
    let index_of = |x: u32, y: u32| y as usize * row + x as usize;
    let mut labels = vec![UNLABELED; row * height as usize];

    // Union-find over basin labels: `parent[l]` links label `l` toward its
    // representative, and `basin_min[l]` is the lowest gray value of the basin
    // rooted at `l`. Slot 0 is a placeholder so labels index the vectors directly.
    let mut parent: Vec<u32> = vec![WATERSHED];
    let mut basin_min: Vec<u32> = vec![0];
    let mut queue = BinaryHeap::new();
    for (mx, my) in find_local_minima(pix, options.connectivity)? {
        let idx = index_of(mx, my);
        if labels[idx] == UNLABELED {
            let label = parent.len() as u32;
            let value = pix.get_pixel(mx, my).unwrap_or(0);
            labels[idx] = label;
            parent.push(label);
            basin_min.push(value);
            queue.push(PixelEntry {
                x: mx,
                y: my,
                value,
            });
        }
    }

    while let Some(entry) = queue.pop() {
        let (x, y) = (entry.x, entry.y);
        // Labels are written when a pixel is queued, so this is always a basin.
        let mut basin = find_basin_root(&mut parent, labels[index_of(x, y)]);
        for (nx, ny) in get_neighbors(x, y, width, height, options.connectivity) {
            let nidx = index_of(nx, ny);
            let neighbor_label = labels[nidx];
            if neighbor_label == UNLABELED {
                labels[nidx] = basin;
                let value = pix.get_pixel(nx, ny).unwrap_or(0);
                queue.push(PixelEntry {
                    x: nx,
                    y: ny,
                    value,
                });
                continue;
            }
            let other = find_basin_root(&mut parent, neighbor_label);
            if other == basin {
                continue;
            }
            // Water must rise to the higher of the two pixels to cross here.
            let spill = entry.value.max(pix.get_pixel(nx, ny).unwrap_or(0));
            let depth = spill
                .saturating_sub(basin_min[basin as usize])
                .min(spill.saturating_sub(basin_min[other as usize]));
            if depth < options.min_depth {
                basin = merge_basins(&mut parent, &basin_min, basin, other);
            }
        }
    }

    // Give every surviving root a compact label 1..=k (ascending seed order) and
    // point every merged label at its root's new number.
    let mut remap = vec![WATERSHED; parent.len()];
    let mut next_label = 1u32;
    for label in 1..parent.len() as u32 {
        if find_basin_root(&mut parent, label) == label {
            remap[label as usize] = next_label;
            next_label += 1;
        }
    }
    for label in 1..parent.len() as u32 {
        let root = find_basin_root(&mut parent, label);
        remap[label as usize] = remap[root as usize];
    }
    for label in labels.iter_mut().filter(|l| **l != UNLABELED) {
        *label = remap[*label as usize];
    }

    let mut output = Pix::new(width, height, PixelDepth::Bit32)
        .map_err(RegionError::Core)?
        .try_into_mut()
        .unwrap_or_else(|p| p.to_mut());
    for y in 0..height {
        for x in 0..width {
            let label = labels[index_of(x, y)];
            if label == UNLABELED || label == IN_QUEUE {
                let _ = output.set_pixel(x, y, 0);
                continue;
            }
            // Pixels on both sides of a basin contact become boundary pixels.
            let is_boundary = get_neighbors(x, y, width, height, options.connectivity)
                .into_iter()
                .any(|(nx, ny)| {
                    let n = labels[index_of(nx, ny)];
                    n != label && n != UNLABELED && n != IN_QUEUE && n != WATERSHED
                });
            let _ = output.set_pixel(x, y, if is_boundary { WATERSHED } else { label });
        }
    }
    Ok(output.into())
}

/// Returns the representative label of `label`'s basin, applying path halving
/// (each visited node is relinked to its grandparent) along the way.
fn find_basin_root(parent: &mut [u32], mut label: u32) -> u32 {
    while parent[label as usize] != label {
        let grandparent = parent[parent[label as usize] as usize];
        parent[label as usize] = grandparent;
        label = grandparent;
    }
    label
}

/// Merges two distinct basin roots and returns the surviving root.
///
/// The root with the lower minimum survives (ties go to the smaller label), so
/// its `basin_min` entry is already the minimum of the merged basin.
fn merge_basins(parent: &mut [u32], basin_min: &[u32], a: u32, b: u32) -> u32 {
    let (keep, absorb) = if (basin_min[a as usize], a) <= (basin_min[b as usize], b) {
        (a, b)
    } else {
        (b, a)
    };
    parent[absorb as usize] = keep;
    keep
}
/// Collects the coordinates of local minima in an 8-bit image.
///
/// A pixel qualifies when no neighbor (under `connectivity`) is darker and at
/// least one neighbor is strictly brighter. A pixel with no neighbors at all
/// (a 1x1 image) also qualifies. Interior pixels of a flat plateau therefore
/// never qualify. Results are produced in raster order (row by row, left to
/// right). Pixel reads that fail fall back to 255.
///
/// # Errors
/// `RegionError::UnsupportedDepth` if `pix` is not 8 bpp.
pub fn find_local_minima(
    pix: &Pix,
    connectivity: ConnectivityType,
) -> RegionResult<Vec<(u32, u32)>> {
    if pix.depth() != PixelDepth::Bit8 {
        return Err(RegionError::UnsupportedDepth {
            expected: "8-bit",
            actual: pix.depth().bits(),
        });
    }
    let (width, height) = (pix.width(), pix.height());
    let sample = |px: u32, py: u32| pix.get_pixel(px, py).unwrap_or(255);
    let mut minima = Vec::new();
    for y in 0..height {
        for x in 0..width {
            let center = sample(x, y);
            let neighbors = get_neighbors(x, y, width, height, connectivity);
            // `dips`: some neighbor is darker (disqualifies immediately).
            // `rises`: some neighbor is strictly brighter, or there are none.
            let mut rises = neighbors.is_empty();
            let mut dips = false;
            for &(nx, ny) in &neighbors {
                let around = sample(nx, ny);
                if around < center {
                    dips = true;
                    break;
                }
                rises |= around > center;
            }
            if rises && !dips {
                minima.push((x, y));
            }
        }
    }
    Ok(minima)
}
/// Collects the coordinates of local maxima in an 8-bit image.
///
/// Mirror image of [`find_local_minima`]. A pixel qualifies when no neighbor
/// (under `connectivity`) is brighter and at least one neighbor is strictly
/// darker, or when it has no neighbors at all. Results are produced in raster
/// order. Pixel reads that fail fall back to 0.
///
/// # Errors
/// `RegionError::UnsupportedDepth` if `pix` is not 8 bpp.
pub fn find_local_maxima(
    pix: &Pix,
    connectivity: ConnectivityType,
) -> RegionResult<Vec<(u32, u32)>> {
    if pix.depth() != PixelDepth::Bit8 {
        return Err(RegionError::UnsupportedDepth {
            expected: "8-bit",
            actual: pix.depth().bits(),
        });
    }
    let width = pix.width();
    let height = pix.height();
    let read = |px: u32, py: u32| pix.get_pixel(px, py).unwrap_or(0);
    let mut maxima = Vec::new();
    for y in 0..height {
        for x in 0..width {
            let peak = read(x, y);
            let around = get_neighbors(x, y, width, height, connectivity);
            let mut saw_lower = around.is_empty();
            // Stops at the first brighter neighbor; until then records whether
            // any neighbor was strictly darker.
            let dominated = around.iter().any(|&(nx, ny)| {
                let level = read(nx, ny);
                saw_lower |= level < peak;
                level > peak
            });
            if !dominated && saw_lower {
                maxima.push((x, y));
            }
        }
    }
    Ok(maxima)
}
/// Approximates the gradient magnitude of an 8-bit image.
///
/// Uses central differences `gx = I(x+1, y) - I(x-1, y)` and
/// `gy = I(x, y+1) - I(x, y-1)`. A neighbor outside the image is replaced by
/// the center pixel. Each output pixel is `(|gx| + |gy|) / 2`, capped at 255,
/// in a new 8-bpp image.
///
/// # Errors
/// `RegionError::UnsupportedDepth` if `pix` is not 8 bpp;
/// `RegionError::Core` if the output image cannot be allocated.
pub fn compute_gradient(pix: &Pix) -> RegionResult<Pix> {
    if pix.depth() != PixelDepth::Bit8 {
        return Err(RegionError::UnsupportedDepth {
            expected: "8-bit",
            actual: pix.depth().bits(),
        });
    }
    let width = pix.width();
    let height = pix.height();
    let mut output = Pix::new(width, height, PixelDepth::Bit8)
        .map_err(RegionError::Core)?
        .try_into_mut()
        .unwrap_or_else(|p| p.to_mut());
    let at = |px: u32, py: u32| pix.get_pixel(px, py).unwrap_or(0) as i32;
    for y in 0..height {
        for x in 0..width {
            let center = at(x, y);
            let west = if x == 0 { center } else { at(x - 1, y) };
            let east = if x + 1 >= width { center } else { at(x + 1, y) };
            let north = if y == 0 { center } else { at(x, y - 1) };
            let south = if y + 1 >= height { center } else { at(x, y + 1) };
            let sum = (east - west).abs() + (south - north).abs();
            let _ = output.set_pixel(x, y, (sum / 2).min(255) as u32);
        }
    }
    Ok(output.into())
}
fn get_neighbors(
x: u32,
y: u32,
width: u32,
height: u32,
connectivity: ConnectivityType,
) -> Vec<(u32, u32)> {
let mut neighbors = Vec::with_capacity(8);
if x > 0 {
neighbors.push((x - 1, y));
}
if x + 1 < width {
neighbors.push((x + 1, y));
}
if y > 0 {
neighbors.push((x, y - 1));
}
if y + 1 < height {
neighbors.push((x, y + 1));
}
if connectivity == ConnectivityType::EightWay {
if x > 0 && y > 0 {
neighbors.push((x - 1, y - 1));
}
if x + 1 < width && y > 0 {
neighbors.push((x + 1, y - 1));
}
if x > 0 && y + 1 < height {
neighbors.push((x - 1, y + 1));
}
if x + 1 < width && y + 1 < height {
neighbors.push((x + 1, y + 1));
}
}
neighbors
}
/// Output of [`watershed_with_basins`]: the label image plus one image per basin.
pub struct WatershedResult {
    /// One 8-bpp image per label `1..=n` (stored at index `label - 1`), holding
    /// the source values at that basin's pixels. Other pixels are left as
    /// allocated by `Pix::new`.
    basins: Vec<Pix>,
    /// 32-bpp label image produced by [`watershed_segmentation`].
    labeled: Pix,
    /// Width of the source image.
    width: u32,
    /// Height of the source image.
    height: u32,
}
impl WatershedResult {
    /// Number of basin images. This equals the largest label in the label image.
    pub fn num_basins(&self) -> u32 {
        self.basins.len() as u32
    }
    /// Per-basin images; index `i` corresponds to label `i + 1`.
    pub fn basins(&self) -> &[Pix] {
        &self.basins
    }
    /// The 32-bpp label image (0 = watershed boundary or unreached pixel).
    ///
    /// Previously this image was stored but unreachable outside this module.
    pub fn labeled(&self) -> &Pix {
        &self.labeled
    }
    /// Width of the segmented image.
    pub fn width(&self) -> u32 {
        self.width
    }
    /// Height of the segmented image.
    pub fn height(&self) -> u32 {
        self.height
    }
}
/// Runs [`watershed_segmentation`] and splits the source image into per-basin images.
///
/// Basin image `i` (0-based) receives the source value of every pixel labeled
/// `i + 1`. The number of basin images equals the largest label present in the
/// label image. A label below that maximum that appears nowhere still gets an
/// (untouched) image. If no pixel carries a label, no basin images are created.
///
/// # Errors
/// Propagates errors from [`watershed_segmentation`], and returns
/// `RegionError::Core` if a basin image cannot be allocated.
pub fn watershed_with_basins(
    pix: &Pix,
    options: &WatershedOptions,
) -> RegionResult<WatershedResult> {
    let labeled = watershed_segmentation(pix, options)?;
    let (width, height) = (pix.width(), pix.height());

    let max_label = (0..height)
        .flat_map(|y| (0..width).map(move |x| (x, y)))
        .map(|(x, y)| labeled.get_pixel(x, y).unwrap_or(0))
        .max()
        .unwrap_or(0);
    if max_label == 0 {
        return Ok(WatershedResult {
            basins: Vec::new(),
            labeled,
            width,
            height,
        });
    }

    let mut basin_images = Vec::with_capacity(max_label as usize);
    for _ in 0..max_label {
        let image = Pix::new(width, height, PixelDepth::Bit8).map_err(RegionError::Core)?;
        basin_images.push(image.try_into_mut().unwrap_or_else(|p| p.to_mut()));
    }

    for y in 0..height {
        for x in 0..width {
            match labeled.get_pixel(x, y).unwrap_or(0) {
                0 => {}
                label => {
                    let value = pix.get_pixel(x, y).unwrap_or(0);
                    let _ = basin_images[(label - 1) as usize].set_pixel(x, y, value);
                }
            }
        }
    }

    Ok(WatershedResult {
        basins: basin_images.into_iter().map(Into::into).collect(),
        labeled,
        width,
        height,
    })
}
/// Renders each basin filled with its minimum source value, as an 8-bpp image.
///
/// For every label `l` in `1..=result.num_basins()`, all pixels labeled `l` are
/// set to the smallest value that basin image `l - 1` holds over those pixels.
/// Pixels labeled 0 (watershed lines / unreached) and labels without a basin
/// image are left as allocated by `Pix::new`.
///
/// # Errors
/// `RegionError::Core` if the output image cannot be allocated.
pub fn watershed_render_fill(result: &WatershedResult) -> RegionResult<Pix> {
    let width = result.width;
    let height = result.height;
    let mut output = Pix::new(width, height, PixelDepth::Bit8)
        .map_err(RegionError::Core)?
        .try_into_mut()
        .unwrap_or_else(|p| p.to_mut());
    let num_basins = result.basins.len();
    // Maps a pixel to its basin index (`label - 1`), or None for label 0 and
    // for labels that have no basin image.
    let basin_index = |x: u32, y: u32| -> Option<usize> {
        let label = result.labeled.get_pixel(x, y).unwrap_or(0) as usize;
        if (1..=num_basins).contains(&label) {
            Some(label - 1)
        } else {
            None
        }
    };

    // Two passes over the image in total, instead of two full passes per basin.
    // Pass 1: minimum source value per basin; u32::MAX means "no pixels seen".
    let mut basin_min = vec![u32::MAX; num_basins];
    for y in 0..height {
        for x in 0..width {
            if let Some(i) = basin_index(x, y) {
                let v = result.basins[i].get_pixel(x, y).unwrap_or(u32::MAX);
                basin_min[i] = basin_min[i].min(v);
            }
        }
    }

    // Pass 2: paint every basin pixel with its basin's minimum.
    for y in 0..height {
        for x in 0..width {
            if let Some(i) = basin_index(x, y) {
                let min_val = basin_min[i];
                if min_val != u32::MAX {
                    let _ = output.set_pixel(x, y, min_val);
                }
            }
        }
    }
    Ok(output.into())
}
pub fn watershed_render_colors(result: &WatershedResult) -> RegionResult<Pix> {
let width = result.width;
let height = result.height;
let mut output = Pix::new(width, height, PixelDepth::Bit32)
.map_err(RegionError::Core)?
.try_into_mut()
.unwrap_or_else(|p| p.to_mut());
let colors: Vec<u32> = (0..result.basins.len())
.map(|idx| {
let h = idx.wrapping_add(1).wrapping_mul(2654435761);
let r = ((h >> 16) & 0xFF) as u32 | 0x40; let g = ((h >> 8) & 0xFF) as u32 | 0x40;
let b = (h & 0xFF) as u32 | 0x40;
(r << 24) | (g << 16) | (b << 8) | 0xFF
})
.collect();
for y in 0..height {
for x in 0..width {
let label = result.labeled.get_pixel(x, y).unwrap_or(0);
if label > 0 && (label as usize) <= colors.len() {
let _ = output.set_pixel(x, y, colors[(label - 1) as usize]);
} else {
let _ = output.set_pixel(x, y, 0x000000FF);
}
}
}
Ok(output.into())
}
pub fn find_basins(pix: &Pix, connectivity: ConnectivityType) -> RegionResult<Pix> {
let options = WatershedOptions::new()
.with_min_depth(0)
.with_connectivity(connectivity);
watershed_segmentation(pix, &options)
}
#[cfg(test)]
mod tests {
    // Unit tests for minima/maxima detection, gradient computation, neighbor
    // enumeration, watershed segmentation and the basin renderers.
    use super::*;

    /// Builds an 8-bpp image from row-major `values` (`values[y][x]`).
    fn create_gray_image(width: u32, height: u32, values: &[Vec<u32>]) -> Pix {
        let pix = Pix::new(width, height, PixelDepth::Bit8).unwrap();
        let mut pix_mut = pix.try_into_mut().unwrap();
        for (y, row) in values.iter().enumerate() {
            for (x, &value) in row.iter().enumerate() {
                let _ = pix_mut.set_pixel(x as u32, y as u32, value);
            }
        }
        pix_mut.into()
    }
    #[test]
    fn test_find_local_minima() {
        let values = vec![vec![5, 5, 5], vec![5, 1, 5], vec![5, 5, 5]];
        let pix = create_gray_image(3, 3, &values);
        let minima = find_local_minima(&pix, ConnectivityType::FourWay).unwrap();
        assert_eq!(minima.len(), 1);
        assert_eq!(minima[0], (1, 1));
    }
    #[test]
    fn test_find_local_maxima() {
        let values = vec![vec![1, 1, 1], vec![1, 5, 1], vec![1, 1, 1]];
        let pix = create_gray_image(3, 3, &values);
        let maxima = find_local_maxima(&pix, ConnectivityType::FourWay).unwrap();
        assert_eq!(maxima.len(), 1);
        assert_eq!(maxima[0], (1, 1));
    }
    #[test]
    fn test_extrema_single_pixel_and_flat_image() {
        // A 1x1 image has no neighbors, so its only pixel is both extrema.
        let single = create_gray_image(1, 1, &[vec![42]]);
        assert_eq!(
            find_local_minima(&single, ConnectivityType::EightWay).unwrap(),
            vec![(0, 0)]
        );
        assert_eq!(
            find_local_maxima(&single, ConnectivityType::EightWay).unwrap(),
            vec![(0, 0)]
        );
        // A flat image has no strictly higher/lower neighbor anywhere.
        let flat = create_gray_image(3, 3, &[vec![7, 7, 7], vec![7, 7, 7], vec![7, 7, 7]]);
        assert!(find_local_minima(&flat, ConnectivityType::EightWay)
            .unwrap()
            .is_empty());
        assert!(find_local_maxima(&flat, ConnectivityType::EightWay)
            .unwrap()
            .is_empty());
    }
    #[test]
    fn test_get_neighbors_order_and_bounds() {
        assert_eq!(
            get_neighbors(1, 1, 3, 3, ConnectivityType::FourWay),
            vec![(0, 1), (2, 1), (1, 0), (1, 2)]
        );
        assert_eq!(
            get_neighbors(1, 1, 3, 3, ConnectivityType::EightWay),
            vec![(0, 1), (2, 1), (1, 0), (1, 2), (0, 0), (2, 0), (0, 2), (2, 2)]
        );
        assert_eq!(
            get_neighbors(0, 0, 3, 3, ConnectivityType::EightWay),
            vec![(1, 0), (0, 1), (1, 1)]
        );
        assert!(get_neighbors(0, 0, 1, 1, ConnectivityType::EightWay).is_empty());
    }
    #[test]
    fn test_compute_gradient() {
        let values = vec![
            vec![0, 0, 100, 100],
            vec![0, 0, 100, 100],
            vec![0, 0, 100, 100],
        ];
        let pix = create_gray_image(4, 3, &values);
        let gradient = compute_gradient(&pix).unwrap();
        let grad_0 = gradient.get_pixel(0, 1).unwrap();
        let grad_1 = gradient.get_pixel(1, 1).unwrap();
        assert!(grad_1 > grad_0);
    }
    #[test]
    fn test_watershed_two_basins() {
        let values = vec![
            vec![0, 5, 10, 5, 0],
            vec![5, 10, 15, 10, 5],
            vec![10, 15, 20, 15, 10],
            vec![5, 10, 15, 10, 5],
            vec![0, 5, 10, 5, 0],
        ];
        let pix = create_gray_image(5, 5, &values);
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_segmentation(&pix, &options).unwrap();
        let mut labels = std::collections::HashSet::new();
        for y in 0..5 {
            for x in 0..5 {
                if let Some(label) = result.get_pixel(x, y)
                    && label > 0
                {
                    labels.insert(label);
                }
            }
        }
        // The four corners are the only minima, and their immediate neighbors
        // are disjoint, so each corner keeps its own basin label.
        assert_eq!(labels.len(), 4);
    }
    #[test]
    fn test_find_basins() {
        let values = vec![vec![0, 10, 0], vec![10, 20, 10], vec![0, 10, 0]];
        let pix = create_gray_image(3, 3, &values);
        let basins = find_basins(&pix, ConnectivityType::FourWay).unwrap();
        let label_0_0 = basins.get_pixel(0, 0).unwrap();
        let label_2_0 = basins.get_pixel(2, 0).unwrap();
        assert!(label_0_0 > 0 || label_2_0 > 0);
    }
    #[test]
    fn test_unsupported_depth() {
        let pix = Pix::new(5, 5, PixelDepth::Bit1).unwrap();
        let options = WatershedOptions::default();
        let result = watershed_segmentation(&pix, &options);
        assert!(result.is_err());
    }
    fn create_two_basin_image() -> Pix {
        create_gray_image(
            5,
            3,
            &[
                vec![5, 3, 20, 3, 5],
                vec![0, 3, 20, 3, 0],
                vec![5, 3, 20, 3, 5],
            ],
        )
    }
    #[test]
    fn test_watershed_two_basin_seeds_keep_distinct_labels() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_segmentation(&pix, &options).unwrap();
        let left = result.get_pixel(0, 1).unwrap();
        let right = result.get_pixel(4, 1).unwrap();
        assert!(left > 0);
        assert!(right > 0);
        assert_ne!(left, right);
    }
    #[test]
    fn test_watershed_with_basins_num_basins() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_with_basins(&pix, &options).unwrap();
        assert_eq!(result.num_basins(), 2);
    }
    #[test]
    fn test_watershed_with_basins_basin_images() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_with_basins(&pix, &options).unwrap();
        assert_eq!(result.basins().len(), result.num_basins() as usize);
        for basin in result.basins() {
            assert_eq!(basin.width(), pix.width());
            assert_eq!(basin.height(), pix.height());
            assert_eq!(basin.depth(), PixelDepth::Bit8);
        }
    }
    #[test]
    fn test_watershed_render_fill_min_value() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_with_basins(&pix, &options).unwrap();
        let filled = watershed_render_fill(&result).unwrap();
        assert_eq!(filled.width(), pix.width());
        assert_eq!(filled.height(), pix.height());
        assert_eq!(filled.depth(), PixelDepth::Bit8);
        assert_eq!(filled.get_pixel(0, 1).unwrap(), 0);
        assert_eq!(filled.get_pixel(4, 1).unwrap(), 0);
    }
    #[test]
    fn test_watershed_render_colors_32bpp() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_with_basins(&pix, &options).unwrap();
        let colored = watershed_render_colors(&result).unwrap();
        assert_eq!(colored.width(), pix.width());
        assert_eq!(colored.height(), pix.height());
        assert_eq!(colored.depth(), PixelDepth::Bit32);
    }
    #[test]
    fn test_watershed_render_colors_different_basins() {
        let pix = create_two_basin_image();
        let options = WatershedOptions::new().with_min_depth(1);
        let result = watershed_with_basins(&pix, &options).unwrap();
        let colored = watershed_render_colors(&result).unwrap();
        let color_left = colored.get_pixel(0, 1).unwrap();
        let color_right = colored.get_pixel(4, 1).unwrap();
        assert_ne!(color_left, color_right);
        assert_ne!(color_left, 0x000000FF);
        assert_ne!(color_right, 0x000000FF);
    }
}