use crate::{Result, VisionError};
use scirs2_core::ndarray::{Array2, Array3, ArrayView2, ArrayView3};
use scirs2_core::numeric::Float; use scirs2_core::random::thread_rng;
use std::collections::{HashMap, HashSet, VecDeque};
use torsh_tensor::Tensor;
/// Pixel adjacency rule used when enumerating the neighbours of a pixel.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Connectivity {
/// Only the vertically and horizontally adjacent pixels (at most four).
Four,
/// The four axis-aligned neighbours plus the four diagonal ones (at most eight).
Eight,
}
/// Where the seed labels for [`watershed`] come from.
#[derive(Debug, Clone)]
pub enum WatershedMarkers {
/// Derive markers from the image itself: a pixel becomes a marker when no pixel in the
/// square window of half-width `min_distance` around it has a strictly smaller value.
Automatic { min_distance: usize },
/// Caller-supplied label image. Positive entries act as seeds; entries equal to `0`
/// are the pixels the flood is allowed to claim.
Manual(Array2<i32>),
}
/// Tuning knobs for [`watershed`].
#[derive(Debug, Clone)]
pub struct WatershedConfig {
/// Neighbourhood used both for seeding the flood and for propagating it.
pub connectivity: Connectivity,
/// NOTE(review): no code visible in this file reads this field — confirm whether
/// compact watershed is implemented elsewhere or still pending.
pub compactness: f32,
/// When `true`, a pixel whose labelled neighbours carry more than one distinct label is
/// written as `-1`; when `false`, it receives the smallest of those labels instead.
pub return_watershed_lines: bool,
}
impl Default for WatershedConfig {
fn default() -> Self {
Self {
connectivity: Connectivity::Eight,
compactness: 0.0,
return_watershed_lines: false,
}
}
}
pub fn watershed(
image: &Tensor,
markers: &WatershedMarkers,
config: &WatershedConfig,
) -> Result<Tensor> {
let array = tensor_to_array2(image)?;
let marker_array = match markers {
WatershedMarkers::Automatic { min_distance } => {
generate_watershed_markers(&array, *min_distance)?
}
WatershedMarkers::Manual(m) => m.clone(),
};
let result = watershed_segmentation(&array, &marker_array, config)?;
let result_f32 = result.mapv(|x| x as f32);
array2_to_tensor(&result_f32)
}
/// Builds a marker image by labelling local minima of `image`.
///
/// A pixel becomes a marker when no pixel inside the `(2 * min_distance + 1)`-wide square
/// window centred on it has a strictly smaller value. Pixels closer than `min_distance` to
/// any border are never examined, so the window always lies inside the image. Markers are
/// numbered from `1` in row-major scan order; every other pixel stays `0`.
///
/// Because the comparison is strict, every pixel of a flat plateau qualifies and receives
/// its own label.
///
/// If the window does not fit into the image at all, the result contains no markers.
fn generate_watershed_markers(image: &ArrayView2<f32>, min_distance: usize) -> Result<Array2<i32>> {
    let (height, width) = image.dim();
    let mut markers: Array2<i32> = Array2::zeros((height, width));
    let mut label = 1i32;

    // `saturating_sub` keeps the ranges empty (instead of underflowing `usize`) when
    // `min_distance` exceeds an image dimension.
    let row_end = height.saturating_sub(min_distance);
    let col_end = width.saturating_sub(min_distance);

    for i in min_distance..row_end {
        for j in min_distance..col_end {
            let center_val = image[[i, j]];
            // The loop bounds guarantee `i - min_distance` and `i + min_distance` (and the
            // column equivalents) are valid indices. The centre needs no special case: a
            // value is never strictly smaller than itself. The negated `<` is deliberate so
            // NaN neighbours keep comparing as "not smaller".
            let is_minimum = ((i - min_distance)..=(i + min_distance)).all(|ni| {
                ((j - min_distance)..=(j + min_distance)).all(|nj| !(image[[ni, nj]] < center_val))
            });
            if is_minimum {
                markers[[i, j]] = label;
                label += 1;
            }
        }
    }
    Ok(markers)
}
fn watershed_segmentation(
image: &ArrayView2<f32>,
markers: &Array2<i32>,
config: &WatershedConfig,
) -> Result<Array2<i32>> {
let (height, width) = image.dim();
let mut labels = markers.clone();
let mut queue = std::collections::BinaryHeap::new();
for i in 1..(height - 1) {
for j in 1..(width - 1) {
if markers[[i, j]] > 0 {
let neighbors = get_neighbors(i, j, height, width, &config.connectivity);
for (ni, nj) in neighbors {
if markers[[ni, nj]] == 0 {
queue.push(std::cmp::Reverse((
(image[[ni, nj]] * 1000.0) as i32,
ni,
nj,
)));
}
}
}
}
}
let mut visited = HashSet::new();
while let Some(std::cmp::Reverse((_, i, j))) = queue.pop() {
if visited.contains(&(i, j)) {
continue;
}
visited.insert((i, j));
let neighbors = get_neighbors(i, j, height, width, &config.connectivity);
let mut neighbor_labels: Vec<i32> = neighbors
.iter()
.filter_map(|&(ni, nj)| {
let label = labels[[ni, nj]];
if label > 0 {
Some(label)
} else {
None
}
})
.collect();
neighbor_labels.sort_unstable();
neighbor_labels.dedup();
if neighbor_labels.len() == 1 {
labels[[i, j]] = neighbor_labels[0];
for (ni, nj) in neighbors {
if labels[[ni, nj]] == 0 && !visited.contains(&(ni, nj)) {
queue.push(std::cmp::Reverse((
(image[[ni, nj]] * 1000.0) as i32,
ni,
nj,
)));
}
}
} else if neighbor_labels.len() > 1 && config.return_watershed_lines {
labels[[i, j]] = -1; } else if neighbor_labels.len() > 1 {
labels[[i, j]] = neighbor_labels[0];
}
}
Ok(labels)
}
/// Lists the in-bounds neighbours of pixel `(i, j)` in a `height` x `width` grid.
///
/// The axis-aligned neighbours come first (up, down, left, right); with
/// [`Connectivity::Eight`] the diagonals follow (up-left, up-right, down-left, down-right).
/// Every returned coordinate is a valid index into the grid. A centre outside the grid,
/// or a grid with a zero dimension, yields an empty list.
fn get_neighbors(
    i: usize,
    j: usize,
    height: usize,
    width: usize,
    connectivity: &Connectivity,
) -> Vec<(usize, usize)> {
    if i >= height || j >= width {
        return Vec::new();
    }

    // With `i < height` and `j < width` established, `i + 1` / `j + 1` cannot overflow and
    // no `height - 1` / `width - 1` subtraction (which underflows for 0) is needed.
    let has_up = i > 0;
    let has_down = i + 1 < height;
    let has_left = j > 0;
    let has_right = j + 1 < width;

    let mut neighbors = Vec::with_capacity(8);
    if has_up {
        neighbors.push((i - 1, j));
    }
    if has_down {
        neighbors.push((i + 1, j));
    }
    if has_left {
        neighbors.push((i, j - 1));
    }
    if has_right {
        neighbors.push((i, j + 1));
    }

    if *connectivity == Connectivity::Eight {
        if has_up && has_left {
            neighbors.push((i - 1, j - 1));
        }
        if has_up && has_right {
            neighbors.push((i - 1, j + 1));
        }
        if has_down && has_left {
            neighbors.push((i + 1, j - 1));
        }
        if has_down && has_right {
            neighbors.push((i + 1, j + 1));
        }
    }
    neighbors
}
/// Parameters for [`graph_cuts`].
#[derive(Debug, Clone)]
pub struct GraphCutsConfig {
/// Factor applied to the count of foreground neighbours when forming the spatial part
/// of each pixel's energy.
pub spatial_weight: f32,
/// Upper bound on the number of full relabelling sweeps over the image.
pub max_iterations: usize,
/// Sweeping stops early once the summed absolute change of the mask in one sweep
/// drops below this value.
pub convergence_threshold: f32,
}
impl Default for GraphCutsConfig {
fn default() -> Self {
Self {
spatial_weight: 1.0,
max_iterations: 100,
convergence_threshold: 1e-4,
}
}
}
pub fn graph_cuts(
image: &Tensor,
foreground_seeds: &[(usize, usize)],
background_seeds: &[(usize, usize)],
config: &GraphCutsConfig,
) -> Result<Tensor> {
let array = tensor_to_array3(image)?;
let (height, width, _) = array.dim();
let mut mask = Array2::zeros((height, width));
for &(i, j) in foreground_seeds {
if i < height && j < width {
mask[[i, j]] = 1.0;
}
}
for &(i, j) in background_seeds {
if i < height && j < width {
mask[[i, j]] = 0.0;
}
}
let fg_model = build_color_model(&array, foreground_seeds)?;
let bg_model = build_color_model(&array, background_seeds)?;
for _iteration in 0..config.max_iterations {
let old_mask = mask.clone();
for i in 0..height {
for j in 0..width {
if foreground_seeds.contains(&(i, j)) || background_seeds.contains(&(i, j)) {
continue;
}
let pixel = array.slice(scirs2_core::ndarray::s![i, j, ..]);
let fg_likelihood = compute_likelihood(&pixel, &fg_model);
let bg_likelihood = compute_likelihood(&pixel, &bg_model);
let neighbors = get_neighbors(i, j, height, width, &Connectivity::Eight);
let num_neighbors = neighbors.len() as f32;
let mut spatial_term = 0.0;
for (ni, nj) in &neighbors {
if mask[[*ni, *nj]] == 1.0 {
spatial_term += 1.0;
}
}
spatial_term *= config.spatial_weight;
let fg_energy = -fg_likelihood.ln()
+ spatial_term * (num_neighbors - spatial_term / config.spatial_weight);
let bg_energy = -bg_likelihood.ln() + spatial_term;
mask[[i, j]] = if fg_energy < bg_energy { 1.0 } else { 0.0 };
}
}
let change = (&mask - &old_mask).map(|x| x.abs()).sum();
if change < config.convergence_threshold {
break;
}
}
array2_to_tensor(&mask)
}
/// Per-channel colour statistics as a `(mean, std)` pair; `build_color_model` creates both
/// arrays with shape `(1, 1, channels)`.
type ColorModel = (Array3<f32>, Array3<f32>);
/// Fits a per-channel mean and standard deviation to the image values at `seeds`.
///
/// Seeds outside the image are skipped. The standard deviation is the population form
/// (divided by the number of seeds used) and is clamped to at least `0.01`, so later
/// divisions by it stay finite.
///
/// # Errors
/// Returns `VisionError::InvalidParameter` when `seeds` is empty or when none of the seeds
/// lies inside the image.
fn build_color_model(image: &Array3<f32>, seeds: &[(usize, usize)]) -> Result<ColorModel> {
    if seeds.is_empty() {
        return Err(VisionError::InvalidParameter(
            "Seeds cannot be empty".to_string(),
        ));
    }

    let (height, width, channels) = image.dim();
    // Indexing an out-of-range seed would panic, so only in-bounds seeds contribute.
    let usable: Vec<(usize, usize)> = seeds
        .iter()
        .copied()
        .filter(|&(i, j)| i < height && j < width)
        .collect();
    if usable.is_empty() {
        return Err(VisionError::InvalidParameter(
            "No seed lies inside the image bounds".to_string(),
        ));
    }
    let count = usable.len() as f32;

    let mut mean: Array3<f32> = Array3::zeros((1, 1, channels));
    for &(i, j) in &usable {
        for c in 0..channels {
            mean[[0, 0, c]] += image[[i, j, c]];
        }
    }
    mean /= count;

    let mut std: Array3<f32> = Array3::zeros((1, 1, channels));
    for &(i, j) in &usable {
        for c in 0..channels {
            let diff = image[[i, j, c]] - mean[[0, 0, c]];
            std[[0, 0, c]] += diff * diff;
        }
    }
    std /= count;
    // Variance -> standard deviation, floored to avoid a degenerate zero-width Gaussian.
    std.mapv_inplace(|variance| variance.sqrt().max(0.01));

    Ok((mean, std))
}
/// Evaluates the likelihood of `pixel` under `model`, treating each channel as an
/// independent Gaussian with the model's mean and standard deviation.
///
/// The result is the product of the per-channel densities, multiplied in channel order.
fn compute_likelihood(pixel: &scirs2_core::ndarray::ArrayView1<f32>, model: &ColorModel) -> f32 {
    let (mean, std_dev) = model;
    (0..pixel.len()).fold(1.0, |likelihood, c| {
        let sigma = std_dev[[0, 0, c]];
        let diff = pixel[c] - mean[[0, 0, c]];
        let variance = sigma * sigma;
        let exp_term = (-0.5 * diff * diff / variance).exp();
        likelihood * (exp_term / (sigma * (2.0 * std::f32::consts::PI).sqrt()))
    })
}
/// Parameters for [`region_growing`].
#[derive(Debug, Clone)]
pub struct RegionGrowingConfig {
/// Neighbourhood used when expanding a region.
pub connectivity: Connectivity,
/// A neighbour joins a region when its absolute difference from the seed value is below
/// this threshold (after optional adaptive scaling).
pub intensity_threshold: f32,
/// When `true`, the threshold is multiplied by `max(local_mean / 128.0, 0.5)`, where
/// `local_mean` is the mean over a radius-3 window around the pixel being expanded.
pub adaptive: bool,
/// Maximum number of pixels per region; `0` means unlimited.
pub max_region_size: usize,
}
impl Default for RegionGrowingConfig {
fn default() -> Self {
Self {
connectivity: Connectivity::Eight,
intensity_threshold: 10.0,
adaptive: true,
max_region_size: 0,
}
}
}
pub fn region_growing(
image: &Tensor,
seeds: &[(usize, usize)],
config: &RegionGrowingConfig,
) -> Result<Tensor> {
let array = tensor_to_array2(image)?;
let (height, width) = array.dim();
let mut labels = Array2::zeros((height, width));
for (label, &(seed_i, seed_j)) in seeds.iter().enumerate() {
if seed_i >= height || seed_j >= width {
continue;
}
if labels[[seed_i, seed_j]] != 0.0 {
continue; }
let seed_value = array[[seed_i, seed_j]];
let mut queue = VecDeque::new();
queue.push_back((seed_i, seed_j));
labels[[seed_i, seed_j]] = (label + 1) as f32;
let mut region_size = 1;
let max_size = if config.max_region_size == 0 {
usize::MAX
} else {
config.max_region_size
};
while let Some((i, j)) = queue.pop_front() {
if region_size >= max_size {
break;
}
let _current_value = array[[i, j]];
let threshold = if config.adaptive {
let local_mean = compute_local_mean(&array.to_owned(), i, j, 3);
config.intensity_threshold * (local_mean / 128.0).max(0.5)
} else {
config.intensity_threshold
};
let neighbors = get_neighbors(i, j, height, width, &config.connectivity);
for (ni, nj) in neighbors {
if labels[[ni, nj]] == 0.0 {
let neighbor_value = array[[ni, nj]];
if (neighbor_value - seed_value).abs() < threshold {
labels[[ni, nj]] = (label + 1) as f32;
queue.push_back((ni, nj));
region_size += 1;
if region_size >= max_size {
break;
}
}
}
}
}
}
array2_to_tensor(&labels)
}
/// Averages `array` over the square window of half-width `radius` centred on `(i, j)`,
/// clipped to the array bounds.
///
/// Values are accumulated row by row, left to right. Returns `0.0` when the clipped
/// window contains no element.
fn compute_local_mean(array: &Array2<f32>, i: usize, j: usize, radius: usize) -> f32 {
    let (height, width) = array.dim();
    let rows = i.saturating_sub(radius)..(i + radius + 1).min(height);
    let cols = j.saturating_sub(radius)..(j + radius + 1).min(width);

    let (total, samples) = rows
        .flat_map(|ni| cols.clone().map(move |nj| (ni, nj)))
        .fold((0.0_f32, 0_i32), |(total, samples), (ni, nj)| {
            (total + array[[ni, nj]], samples + 1)
        });

    if samples == 0 {
        return 0.0;
    }
    total / samples as f32
}
/// Placeholder for viewing a 2-D tensor as an `ArrayView2<f32>`.
///
/// Always returns `VisionError::InvalidParameter`: a rank message when the tensor is not
/// 2-D, otherwise a "not yet implemented" message.
fn tensor_to_array2(tensor: &Tensor) -> Result<ArrayView2<'_, f32>> {
    let rank = tensor.shape().len();
    let message = if rank != 2 {
        format!("Expected 2D tensor, got {}D", rank)
    } else {
        "Tensor-to-array conversion not yet implemented".to_string()
    };
    Err(VisionError::InvalidParameter(message))
}
/// Placeholder for converting a 3-D tensor into an owned `Array3<f32>`.
///
/// Always returns `VisionError::InvalidParameter`: a rank message when the tensor is not
/// 3-D, otherwise a "not yet implemented" message.
fn tensor_to_array3(tensor: &Tensor) -> Result<Array3<f32>> {
    let rank = tensor.shape().len();
    let message = if rank != 3 {
        format!("Expected 3D tensor, got {}D", rank)
    } else {
        "Tensor-to-array conversion not yet implemented".to_string()
    };
    Err(VisionError::InvalidParameter(message))
}
/// Placeholder for converting a 2-D array back into a tensor.
///
/// The input is ignored and `VisionError::InvalidParameter` is always returned.
fn array2_to_tensor(_array: &Array2<f32>) -> Result<Tensor> {
    let reason = String::from("Array-to-tensor conversion not yet implemented");
    Err(VisionError::InvalidParameter(reason))
}
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::{array, Array2};

    /// An interior pixel has exactly its four axis-aligned neighbours.
    #[test]
    fn test_get_neighbors_4connectivity() {
        let found = get_neighbors(5, 5, 10, 10, &Connectivity::Four);
        assert_eq!(found.len(), 4);
        for expected in [(4, 5), (6, 5), (5, 4), (5, 6)].iter() {
            assert!(found.contains(expected));
        }
    }

    /// An interior pixel has all eight neighbours under 8-connectivity.
    #[test]
    fn test_get_neighbors_8connectivity() {
        let found = get_neighbors(5, 5, 10, 10, &Connectivity::Eight);
        assert_eq!(found.len(), 8);
    }

    /// The top-left corner only has a neighbour below and one to the right.
    #[test]
    fn test_get_neighbors_boundary() {
        let found = get_neighbors(0, 0, 10, 10, &Connectivity::Four);
        assert_eq!(found.len(), 2);
    }

    /// Default watershed settings.
    #[test]
    fn test_watershed_config_default() {
        let WatershedConfig {
            connectivity,
            compactness,
            return_watershed_lines,
        } = WatershedConfig::default();
        assert_eq!(connectivity, Connectivity::Eight);
        assert_eq!(compactness, 0.0);
        assert!(!return_watershed_lines);
    }

    /// Default graph-cuts settings.
    #[test]
    fn test_graph_cuts_config_default() {
        let defaults = GraphCutsConfig::default();
        assert_eq!(defaults.spatial_weight, 1.0);
        assert_eq!(defaults.max_iterations, 100);
    }

    /// Default region-growing settings.
    #[test]
    fn test_region_growing_config_default() {
        let defaults = RegionGrowingConfig::default();
        assert_eq!(defaults.connectivity, Connectivity::Eight);
        assert_eq!(defaults.intensity_threshold, 10.0);
        assert!(defaults.adaptive);
    }

    /// A radial distance image has its minimum at the centre, so at least one marker
    /// must be produced.
    #[test]
    fn test_generate_watershed_markers() {
        let distance_from_centre = |(i, j): (usize, usize)| {
            ((i as f32 - 5.0).powi(2) + (j as f32 - 5.0).powi(2)).sqrt()
        };
        let image = Array2::from_shape_fn((10, 10), distance_from_centre);
        let markers =
            generate_watershed_markers(&image.view(), 2).expect("Failed to generate markers");
        let num_markers = markers.iter().filter(|&&value| value > 0).count();
        assert!(num_markers >= 1);
    }

    /// The local mean of a constant image equals that constant.
    #[test]
    fn test_compute_local_mean() {
        let constant = Array2::from_elem((10, 10), 5.0);
        let mean = compute_local_mean(&constant, 5, 5, 2);
        assert!((mean - 5.0).abs() < 1e-5);
    }
}