use crate::error::{Result, VisionError};
use image::DynamicImage;
use scirs2_core::ndarray::Array2;
/// Segmentation algorithm selector passed to [`segment`].
///
/// Each variant carries the parameters forwarded to the corresponding
/// backend implementation.
#[derive(Debug, Clone)]
pub enum SegmentMethod {
    /// Global Otsu thresholding; produces a two-label (0/1) map.
    Otsu,
    /// Local adaptive thresholding; produces a two-label (0/1) map.
    Adaptive {
        /// Neighbourhood size forwarded to `adaptive_threshold`.
        /// NOTE(review): presumably must be odd — confirm against the backend.
        block_size: usize,
        /// Constant offset forwarded to `adaptive_threshold`.
        c: f32,
        /// How the local threshold is computed (e.g. `Mean` or `Gaussian`).
        method: super::AdaptiveMethod,
    },
    /// K-means clustering of pixel colours.
    KMeans {
        /// Number of clusters; also reported as `SegmentResult::n_segments`.
        k: usize,
        /// Iteration cap forwarded to the k-means backend.
        max_iterations: usize,
    },
    /// GrabCut foreground extraction seeded by a rectangle; produces a 0/1 map.
    GrabCut {
        /// Rectangle passed to `grabcut_rect`.
        /// NOTE(review): component order (e.g. x, y, width, height) is not
        /// visible here — confirm against `grabcut_rect`.
        rect: (u32, u32, u32, u32),
        /// Forwarded as `GrabCutParams::n_components`
        /// (presumably mixture components per colour model — confirm).
        n_components: usize,
    },
    /// Marker-less watershed segmentation.
    Watershed {
        /// Currently ignored: `segment` always runs watershed without markers.
        n_markers: Option<usize>,
        /// `4` selects 4-connectivity; any other value is treated as 8.
        connectivity: u8,
    },
}
/// Output of [`segment`]: a per-pixel label map plus summary metadata.
#[derive(Debug, Clone)]
pub struct SegmentResult {
    /// Label for every pixel, indexed `[[y, x]]` with shape `(height, width)`.
    pub labels: Array2<u32>,
    /// Number of segments reported by the chosen method.
    ///
    /// - Threshold methods and GrabCut: always 2.
    /// - K-means: the requested `k`.
    /// - Watershed: the number of distinct label values actually present.
    pub n_segments: usize,
    /// Human-readable name of the method that produced this result.
    pub method_name: String,
}
/// Segments `img` with the algorithm selected by `method`.
///
/// This is a thin dispatcher: the chosen [`SegmentMethod`] variant is
/// destructured and its parameters are forwarded unchanged to the matching
/// private implementation.
///
/// # Errors
///
/// Propagates whatever error the selected backend returns.
pub fn segment(img: &DynamicImage, method: SegmentMethod) -> Result<SegmentResult> {
    match method {
        SegmentMethod::Otsu => segment_otsu(img),
        SegmentMethod::Adaptive {
            block_size: window,
            c: offset,
            method: kind,
        } => segment_adaptive(img, window, offset, kind),
        SegmentMethod::KMeans {
            k: clusters,
            max_iterations: iters,
        } => segment_kmeans(img, clusters, iters),
        SegmentMethod::GrabCut {
            rect: seed_rect,
            n_components: components,
        } => segment_grabcut(img, seed_rect, components),
        SegmentMethod::Watershed {
            n_markers: markers,
            connectivity: neighbourhood,
        } => segment_watershed(img, markers, neighbourhood),
    }
}
/// Segments `img` into two classes with a global Otsu threshold.
///
/// `super::otsu_threshold` produces a thresholded image. Pixels whose first
/// channel is non-zero get label `1` and all others get label `0`. The
/// computed threshold value is discarded, and `n_segments` is always 2.
///
/// # Errors
///
/// Propagates any error from `super::otsu_threshold`.
fn segment_otsu(img: &DynamicImage) -> Result<SegmentResult> {
    let (thresholded, _level) = super::otsu_threshold(img)?;
    let (img_w, img_h) = thresholded.dimensions();
    let shape = (img_h as usize, img_w as usize);
    // `from_shape_fn` yields (row, col) = (y, x); image lookups take (x, y).
    let labels = Array2::from_shape_fn(shape, |(row, col)| {
        let is_foreground = thresholded.get_pixel(col as u32, row as u32)[0] > 0;
        u32::from(is_foreground)
    });
    Ok(SegmentResult {
        method_name: "Otsu".to_string(),
        n_segments: 2,
        labels,
    })
}
/// Segments `img` into two classes with local adaptive thresholding.
///
/// `block_size`, `c` and `method` are forwarded to `super::adaptive_threshold`.
/// Non-zero pixels of the resulting binary image become label `1` and zero
/// pixels become label `0`. `n_segments` is always 2.
///
/// # Errors
///
/// Propagates any error from `super::adaptive_threshold`.
fn segment_adaptive(
    img: &DynamicImage,
    block_size: usize,
    c: f32,
    method: super::AdaptiveMethod,
) -> Result<SegmentResult> {
    let mask = super::adaptive_threshold(img, block_size, c, method)?;
    let (cols, rows) = mask.dimensions();
    let labels = Array2::from_shape_fn((rows as usize, cols as usize), |(row, col)| {
        u32::from(mask.get_pixel(col as u32, row as u32)[0] > 0)
    });
    Ok(SegmentResult {
        labels,
        n_segments: 2,
        method_name: "Adaptive".to_string(),
    })
}
/// Clusters pixel colours into `k` groups with k-means.
///
/// Calls `super::kmeans_seg::kmeans_segment` with colour features enabled
/// (`use_color: true`), a tolerance of `1e-4` and `n_init: 3`
/// (presumably the number of restarts — confirm against `KMeansSegParams`).
///
/// `n_segments` reports the requested `k`.
/// NOTE(review): this is not a count of the labels actually present in
/// `labels`.
///
/// # Errors
///
/// Propagates any error from `kmeans_segment`.
fn segment_kmeans(img: &DynamicImage, k: usize, max_iterations: usize) -> Result<SegmentResult> {
    let params = super::kmeans_seg::KMeansSegParams {
        k,
        max_iterations,
        epsilon: 1e-4,
        n_init: 3,
        use_color: true,
    };
    let result = super::kmeans_seg::kmeans_segment(img, &params)?;
    Ok(SegmentResult {
        labels: result.labels,
        n_segments: k,
        method_name: format!("KMeans(k={})", k),
    })
}
/// Extracts the foreground inside `rect` with GrabCut.
///
/// Runs `super::grabcut::grabcut_rect` with fixed settings: 10 iterations,
/// tolerance `1e-3` and smoothness `50.0`. `n_components` is forwarded
/// unchanged.
///
/// The boolean foreground mask is converted to a label map in which
/// foreground is `1` and background is `0`, so `n_segments` is always 2.
///
/// # Errors
///
/// Propagates any error from `grabcut_rect`.
fn segment_grabcut(
    img: &DynamicImage,
    rect: (u32, u32, u32, u32),
    n_components: usize,
) -> Result<SegmentResult> {
    let params = super::grabcut::GrabCutParams {
        n_components,
        max_iterations: 10,
        epsilon: 1e-3,
        smoothness: 50.0,
    };
    let result = super::grabcut::grabcut_rect(img, rect, &params)?;
    let mask = &result.foreground_mask;
    // `true` (foreground) -> 1, `false` (background) -> 0.
    let labels = Array2::from_shape_fn(mask.dim(), |(y, x)| u32::from(mask[[y, x]]));
    Ok(SegmentResult {
        labels,
        n_segments: 2,
        method_name: "GrabCut".to_string(),
    })
}
/// Runs marker-less watershed segmentation on `img`.
///
/// A `connectivity` of `4` requests 4-connectivity; any other value falls back
/// to 8-connectivity.
///
/// `n_segments` is the number of distinct label values present in the
/// returned map. Every distinct value is counted.
///
/// # Errors
///
/// Propagates any error from `super::watershed::watershed`.
fn segment_watershed(
    img: &DynamicImage,
    _n_markers: Option<usize>,
    connectivity: u8,
) -> Result<SegmentResult> {
    // NOTE(review): `_n_markers` is accepted but never forwarded; watershed
    // always runs with `None` markers. TODO wire it up or drop it.
    let neighbourhood = match connectivity {
        4 => 4,
        _ => 8,
    };
    let labels = super::watershed::watershed(img, None, neighbourhood)?;
    let n_segments = labels
        .iter()
        .collect::<std::collections::HashSet<_>>()
        .len();
    Ok(SegmentResult {
        labels,
        n_segments,
        method_name: String::from("Watershed"),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use image::{GrayImage, Luma};

    /// Builds a 40x32 grayscale image:
    /// - bright (220) for columns `0..15`,
    /// - dark (20) for columns `25..40`,
    /// - a linear ramp for the columns in between.
    fn create_bimodal_image() -> DynamicImage {
        let mut buf = GrayImage::new(40, 32);
        for y in 0..32u32 {
            for x in 0..40u32 {
                let val = if x < 15 {
                    220u8
                } else if x > 24 {
                    20u8
                } else {
                    let t = (x - 15) as f32 / 10.0;
                    (220.0 * (1.0 - t) + 20.0 * t) as u8
                };
                buf.put_pixel(x, y, Luma([val]));
            }
        }
        DynamicImage::ImageLuma8(buf)
    }

    /// Asserts that every label in `result` is binary (0 or 1).
    fn assert_binary_labels(result: &SegmentResult) {
        for &label in result.labels.iter() {
            assert!(label <= 1, "Label should be 0 or 1, got {}", label);
        }
    }

    #[test]
    fn test_segment_otsu() {
        let img = create_bimodal_image();
        let result = segment(&img, SegmentMethod::Otsu).expect("Otsu failed");
        assert_eq!(result.n_segments, 2);
        assert_eq!(result.labels.dim(), (32, 40));
        assert_eq!(result.method_name, "Otsu");
        assert_binary_labels(&result);
        let bright = result.labels[[16, 5]];
        let dark = result.labels[[16, 35]];
        assert_ne!(bright, dark, "Otsu should separate bright and dark regions");
    }

    #[test]
    fn test_segment_adaptive_mean() {
        let img = create_bimodal_image();
        let result = segment(
            &img,
            SegmentMethod::Adaptive {
                block_size: 7,
                c: 0.0,
                method: super::super::AdaptiveMethod::Mean,
            },
        )
        .expect("Adaptive mean failed");
        assert_eq!(result.n_segments, 2);
        assert_eq!(result.labels.dim(), (32, 40));
        assert_eq!(result.method_name, "Adaptive");
        assert_binary_labels(&result);
    }

    #[test]
    fn test_segment_adaptive_gaussian() {
        let img = create_bimodal_image();
        let result = segment(
            &img,
            SegmentMethod::Adaptive {
                block_size: 7,
                c: 0.0,
                method: super::super::AdaptiveMethod::Gaussian,
            },
        )
        .expect("Adaptive gaussian failed");
        assert_eq!(result.n_segments, 2);
        assert_eq!(result.labels.dim(), (32, 40));
        assert_eq!(result.method_name, "Adaptive");
        assert_binary_labels(&result);
    }

    #[test]
    fn test_segment_kmeans() {
        let mut buf = image::RgbImage::new(20, 20);
        for y in 0..20u32 {
            for x in 0..20u32 {
                let color = if x < 10 {
                    [200u8, 50, 50]
                } else {
                    [50u8, 50, 200]
                };
                buf.put_pixel(x, y, image::Rgb(color));
            }
        }
        let img = DynamicImage::ImageRgb8(buf);
        let result = segment(
            &img,
            SegmentMethod::KMeans {
                k: 2,
                max_iterations: 100,
            },
        )
        .expect("KMeans failed");
        assert_eq!(result.n_segments, 2);
        assert_eq!(result.labels.dim(), (20, 20));
        assert_eq!(result.method_name, "KMeans(k=2)");
    }

    #[test]
    fn test_segment_grabcut() {
        let mut buf = image::RgbImage::new(20, 20);
        for y in 0..20u32 {
            for x in 0..20u32 {
                let is_center = (5..15).contains(&x) && (5..15).contains(&y);
                let color = if is_center {
                    [220u8, 220, 220]
                } else {
                    [20u8, 20, 20]
                };
                buf.put_pixel(x, y, image::Rgb(color));
            }
        }
        let img = DynamicImage::ImageRgb8(buf);
        let result = segment(
            &img,
            SegmentMethod::GrabCut {
                rect: (4, 4, 12, 12),
                n_components: 3,
            },
        )
        .expect("GrabCut failed");
        assert_eq!(result.n_segments, 2);
        assert_eq!(result.labels.dim(), (20, 20));
        assert_eq!(result.method_name, "GrabCut");
        assert_binary_labels(&result);
    }

    #[test]
    fn test_segment_watershed() {
        let img = create_bimodal_image();
        let result = segment(
            &img,
            SegmentMethod::Watershed {
                n_markers: None,
                connectivity: 8,
            },
        )
        .expect("Watershed failed");
        assert_eq!(result.labels.dim(), (32, 40));
        assert_eq!(result.method_name, "Watershed");
        let distinct: std::collections::HashSet<u32> = result.labels.iter().copied().collect();
        assert_eq!(result.n_segments, distinct.len());
        assert!(result.n_segments >= 1);
    }

    #[test]
    fn test_segment_result_labels_range() {
        let mut buf = GrayImage::new(32, 32);
        for y in 0..32u32 {
            for x in 0..32u32 {
                buf.put_pixel(x, y, Luma([if x < 16 { 240u8 } else { 10u8 }]));
            }
        }
        // Mid-grey seam between the halves, so the threshold has a non-trivial
        // value to decide on.
        for y in 0..32u32 {
            buf.put_pixel(15, y, Luma([125u8]));
            buf.put_pixel(16, y, Luma([125u8]));
        }
        let img = DynamicImage::ImageLuma8(buf);
        let result = segment(&img, SegmentMethod::Otsu).expect("Otsu failed");
        assert_binary_labels(&result);
    }

    #[test]
    fn test_segment_kmeans_three_regions() {
        let mut buf = image::RgbImage::new(30, 10);
        for y in 0..10u32 {
            for x in 0..10u32 {
                buf.put_pixel(x, y, image::Rgb([255, 0, 0]));
                buf.put_pixel(x + 10, y, image::Rgb([0, 255, 0]));
                buf.put_pixel(x + 20, y, image::Rgb([0, 0, 255]));
            }
        }
        let img = DynamicImage::ImageRgb8(buf);
        let result = segment(
            &img,
            SegmentMethod::KMeans {
                k: 3,
                max_iterations: 100,
            },
        )
        .expect("KMeans 3-cluster failed");
        assert_eq!(result.n_segments, 3);
        let l0 = result.labels[[5, 5]];
        let l1 = result.labels[[5, 15]];
        let l2 = result.labels[[5, 25]];
        assert_ne!(l0, l1);
        assert_ne!(l1, l2);
        assert_ne!(l0, l2);
    }
}