use crate::error::{CvError, CvResult};
/// Input errors reported by [`InstanceSegmenter::segment_checked`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SegmentationError {
    /// `width` or `height` is zero.
    ImageTooSmall,
    /// The pixel buffer size is inconsistent with `width × height` (e.g. too short).
    InvalidDimensions,
}

impl std::fmt::Display for SegmentationError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let message = match self {
            Self::ImageTooSmall => "image is too small for segmentation",
            Self::InvalidDimensions => "pixel buffer size does not match width × height",
        };
        f.write_str(message)
    }
}

impl std::error::Error for SegmentationError {}
/// Tuning parameters for [`InstanceSegmenter`].
#[derive(Debug, Clone)]
pub struct SegmentationConfig {
    /// Connected components with fewer foreground pixels than this are discarded.
    pub min_area_pixels: u32,
    /// Maximum number of instances returned; the largest components are kept.
    pub max_instances: u32,
    /// A pixel is foreground when its grayscale value is strictly greater than this.
    pub threshold: u8,
}

impl Default for SegmentationConfig {
    /// Minimum area of 50 pixels, at most 32 instances, threshold 127.
    fn default() -> Self {
        let min_area_pixels = 50;
        let max_instances = 32;
        let threshold = 127;
        Self {
            threshold,
            max_instances,
            min_area_pixels,
        }
    }
}
impl SegmentationConfig {
    /// Checks that the configuration can produce any output at all.
    ///
    /// # Errors
    ///
    /// Returns the error built by `CvError::invalid_parameter` for `max_instances`
    /// when it is zero, since a cap of zero would truncate every result to nothing.
    pub fn validate(&self) -> CvResult<()> {
        match self.max_instances {
            0 => Err(CvError::invalid_parameter(
                "max_instances",
                "must be at least 1",
            )),
            _ => Ok(()),
        }
    }
}
/// One segmented object instance.
#[derive(Debug, Clone)]
pub struct SegmentMask {
    /// Identifier of the instance within one result (`InstanceSegmenter` numbers them from 1).
    pub object_id: u32,
    /// Class label (`InstanceSegmenter` always sets 1).
    pub class_id: u32,
    /// Score in `[0, 1]` (`InstanceSegmenter` uses mean foreground intensity / 255).
    pub confidence: f32,
    /// Per-pixel mask; `InstanceSegmenter` fills it row-major with 255 for member pixels, 0 elsewhere.
    pub mask: Vec<u8>,
    /// Inclusive bounding box as `(x_min, y_min, x_max, y_max)`.
    pub bounding_box: (u32, u32, u32, u32),
}

impl SegmentMask {
    /// Area in pixels of the inclusive bounding box.
    ///
    /// An axis whose max lies before its min is treated as one pixel wide.
    #[must_use]
    pub fn bbox_area(&self) -> u64 {
        let (min_x, min_y, max_x, max_y) = self.bounding_box;
        let span = |lo: u32, hi: u32| u64::from(hi.saturating_sub(lo)) + 1;
        span(min_x, max_x) * span(min_y, max_y)
    }

    /// Number of non-zero entries in `mask`.
    #[must_use]
    pub fn pixel_count(&self) -> usize {
        self.mask.iter().copied().filter(|&value| value != 0).count()
    }
}
/// Disjoint-set forest over the elements `0..n`, using union by rank and path halving.
struct UnionFind {
    /// Parent link of each element; an element that is its own parent is a root.
    parent: Vec<usize>,
    /// Rank of each element, consulted only for roots when merging.
    rank: Vec<u8>,
}

impl UnionFind {
    /// Creates `n` singleton sets.
    fn new(n: usize) -> Self {
        let parent: Vec<usize> = (0..n).collect();
        let rank = vec![0u8; n];
        Self { parent, rank }
    }

    /// Returns the root of `x`'s set, re-pointing each visited node at its
    /// grandparent on the way up (path halving).
    fn find(&mut self, x: usize) -> usize {
        let mut node = x;
        loop {
            let up = self.parent[node];
            if up == node {
                return node;
            }
            let skip = self.parent[up];
            self.parent[node] = skip;
            node = skip;
        }
    }

    /// Merges the sets containing `a` and `b`.
    ///
    /// The lower-rank root is attached beneath the higher-rank one; on a tie,
    /// `b`'s root goes beneath `a`'s and `a`'s rank grows by one (saturating).
    fn union(&mut self, a: usize, b: usize) {
        let (ra, rb) = (self.find(a), self.find(b));
        if ra == rb {
            return;
        }
        let (rank_a, rank_b) = (self.rank[ra], self.rank[rb]);
        if rank_a < rank_b {
            self.parent[ra] = rb;
        } else {
            self.parent[rb] = ra;
            if rank_a == rank_b {
                self.rank[ra] = rank_a.saturating_add(1);
            }
        }
    }
}
#[derive(Debug, Clone)]
pub struct InstanceSegmenter {
config: SegmentationConfig,
}
impl InstanceSegmenter {
#[must_use]
pub fn new(config: SegmentationConfig) -> Self {
Self { config }
}
#[must_use]
pub fn default() -> Self {
Self::new(SegmentationConfig::default())
}
fn to_gray(image: &[u8], width: u32, height: u32) -> Option<Vec<u8>> {
let n = (width as usize) * (height as usize);
if image.is_empty() || n == 0 {
return Some(vec![]);
}
let bpp = image.len() / n;
match bpp {
1 => Some(image.to_vec()),
3 => Some(
image
.chunks_exact(3)
.map(|c| {
let r = u32::from(c[0]);
let g = u32::from(c[1]);
let b = u32::from(c[2]);
((r * 299 + g * 587 + b * 114) / 1000) as u8
})
.collect(),
),
4 => Some(
image
.chunks_exact(4)
.map(|c| {
let r = u32::from(c[0]);
let g = u32::from(c[1]);
let b = u32::from(c[2]);
((r * 299 + g * 587 + b * 114) / 1000) as u8
})
.collect(),
),
_ => None,
}
}
fn validate(image: &[u8], width: u32, height: u32) -> Result<(), SegmentationError> {
if width == 0 || height == 0 {
return Err(SegmentationError::ImageTooSmall);
}
let n = (width as usize) * (height as usize);
if image.len() < n {
return Err(SegmentationError::InvalidDimensions);
}
Ok(())
}
#[allow(clippy::cast_precision_loss)]
pub fn segment(&self, image: &[u8], width: u32, height: u32) -> Vec<SegmentMask> {
if Self::validate(image, width, height).is_err() {
return Vec::new();
}
let gray = match Self::to_gray(image, width, height) {
Some(g) => g,
None => return Vec::new(),
};
let w = width as usize;
let h = height as usize;
let n = w * h;
let foreground: Vec<bool> = gray.iter().map(|&p| p > self.config.threshold).collect();
let mut uf = UnionFind::new(n);
for y in 0..h {
for x in 0..w {
let idx = y * w + x;
if !foreground[idx] {
continue;
}
let neighbours: &[(i64, i64)] = &[(-1, -1), (0, -1), (1, -1), (-1, 0)];
for &(dx, dy) in neighbours {
let nx = x as i64 + dx;
let ny = y as i64 + dy;
if nx < 0 || ny < 0 || nx >= w as i64 || ny >= h as i64 {
continue;
}
let nidx = ny as usize * w + nx as usize;
if foreground[nidx] {
uf.union(idx, nidx);
}
}
}
}
use std::collections::HashMap;
let mut components: HashMap<usize, (u32, u64, u32, u32, u32, u32)> = HashMap::new();
for y in 0..h {
for x in 0..w {
let idx = y * w + x;
if !foreground[idx] {
continue;
}
let root = uf.find(idx);
let intensity = u64::from(gray[idx]);
let entry = components
.entry(root)
.or_insert((0, 0, x as u32, y as u32, x as u32, y as u32));
entry.0 += 1;
entry.1 += intensity;
if (x as u32) < entry.2 {
entry.2 = x as u32;
}
if (y as u32) < entry.3 {
entry.3 = y as u32;
}
if (x as u32) > entry.4 {
entry.4 = x as u32;
}
if (y as u32) > entry.5 {
entry.5 = y as u32;
}
}
}
let min_area = self.config.min_area_pixels;
let mut surviving: Vec<(usize, u32, u64, u32, u32, u32, u32)> = components
.into_iter()
.filter(|(_, (count, ..))| *count >= min_area)
.map(|(root, (count, sum, x0, y0, x1, y1))| (root, count, sum, x0, y0, x1, y1))
.collect();
surviving.sort_by(|a, b| b.1.cmp(&a.1));
surviving.truncate(self.config.max_instances as usize);
if surviving.is_empty() {
return Vec::new();
}
let accepted: std::collections::HashSet<usize> =
surviving.iter().map(|(root, ..)| *root).collect();
let root_to_idx: HashMap<usize, usize> = surviving
.iter()
.enumerate()
.map(|(i, (root, ..))| (*root, i))
.collect();
let num_instances = surviving.len();
let mut masks_data: Vec<Vec<u8>> = vec![vec![0u8; n]; num_instances];
for y in 0..h {
for x in 0..w {
let idx = y * w + x;
if !foreground[idx] {
continue;
}
let root = uf.find(idx);
if !accepted.contains(&root) {
continue;
}
if let Some(&inst_idx) = root_to_idx.get(&root) {
masks_data[inst_idx][idx] = 255;
}
}
}
surviving
.into_iter()
.zip(masks_data.into_iter())
.enumerate()
.map(|(i, ((_, count, sum, x0, y0, x1, y1), mask))| {
let confidence = if count > 0 {
(sum as f32 / (count as f32 * 255.0)).clamp(0.0, 1.0)
} else {
0.0
};
SegmentMask {
object_id: (i as u32) + 1,
class_id: 1,
confidence,
mask,
bounding_box: (x0, y0, x1, y1),
}
})
.collect()
}
pub fn segment_checked(
&self,
image: &[u8],
width: u32,
height: u32,
) -> Result<Vec<SegmentMask>, SegmentationError> {
Self::validate(image, width, height)?;
Ok(self.segment(image, width, height))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `w × h` single-channel image filled with `bg`, with the half-open
    /// rectangle `x in rx0..rx1`, `y in ry0..ry1` set to `fg`.
    // Test-only helper: the flat argument list mirrors how the rectangles are written.
    #[allow(clippy::too_many_arguments)]
    fn make_image_with_rect(
        w: u32,
        h: u32,
        bg: u8,
        rx0: u32,
        ry0: u32,
        rx1: u32,
        ry1: u32,
        fg: u8,
    ) -> Vec<u8> {
        let mut img = vec![bg; (w * h) as usize];
        for y in ry0..ry1 {
            for x in rx0..rx1 {
                img[(y * w + x) as usize] = fg;
            }
        }
        img
    }

    /// Segmenter with `min_area_pixels: 4` so the small test shapes survive filtering.
    fn default_segmenter() -> InstanceSegmenter {
        InstanceSegmenter::new(SegmentationConfig {
            min_area_pixels: 4,
            max_instances: 32,
            threshold: 127,
        })
    }

    #[test]
    fn test_empty_image_returns_empty() {
        let seg = default_segmenter();
        let result = seg.segment(&[], 0, 0);
        assert!(result.is_empty(), "expected empty for zero-size image");
    }

    #[test]
    fn test_all_black_returns_empty() {
        let seg = default_segmenter();
        let img = vec![0u8; 20 * 20];
        let result = seg.segment(&img, 20, 20);
        assert!(result.is_empty(), "no foreground pixels → no masks");
    }

    #[test]
    fn test_threshold_is_strict() {
        let seg = default_segmenter();
        let img = vec![127u8; 10 * 10];
        assert!(
            seg.segment(&img, 10, 10).is_empty(),
            "pixels equal to the threshold are background"
        );
    }

    #[test]
    fn test_single_object_detected() {
        let seg = default_segmenter();
        let img = make_image_with_rect(20, 20, 0, 6, 6, 14, 14, 255);
        let masks = seg.segment(&img, 20, 20);
        assert_eq!(masks.len(), 1, "exactly one object expected");
        let m = &masks[0];
        assert_eq!(m.object_id, 1);
        assert_eq!(m.class_id, 1);
        assert!((m.confidence - 1.0).abs() < 1e-6, "all-white object → confidence 1");
        assert_eq!(m.mask.len(), 20 * 20);
        assert_eq!(m.pixel_count(), 64);
        assert_eq!(m.bbox_area(), 64);
    }

    #[test]
    fn test_single_object_bounding_box() {
        let seg = default_segmenter();
        let img = make_image_with_rect(20, 20, 0, 5, 3, 10, 8, 255);
        let masks = seg.segment(&img, 20, 20);
        assert_eq!(masks.len(), 1);
        let (x0, y0, x1, y1) = masks[0].bounding_box;
        assert_eq!(x0, 5);
        assert_eq!(y0, 3);
        assert_eq!(x1, 9);
        assert_eq!(y1, 7);
        assert_eq!(masks[0].bbox_area(), 25);
    }

    #[test]
    fn test_two_separated_objects() {
        let seg = default_segmenter();
        let mut img = vec![0u8; 30 * 30];
        for y in 2..6usize {
            for x in 2..6usize {
                img[y * 30 + x] = 255;
            }
        }
        for y in 20..24usize {
            for x in 20..24usize {
                img[y * 30 + x] = 255;
            }
        }
        let masks = seg.segment(&img, 30, 30);
        assert_eq!(masks.len(), 2, "two objects expected");
        let ids: Vec<u32> = masks.iter().map(|m| m.object_id).collect();
        assert!(ids.contains(&1) && ids.contains(&2));
    }

    #[test]
    fn test_diagonal_pixels_are_connected() {
        let seg = default_segmenter();
        let mut img = vec![0u8; 10 * 10];
        for i in 0..5usize {
            img[i * 10 + i] = 255;
        }
        let masks = seg.segment(&img, 10, 10);
        assert_eq!(masks.len(), 1, "diagonal neighbours are 8-connected");
        assert_eq!(masks[0].pixel_count(), 5);
        assert_eq!(masks[0].bounding_box, (0, 0, 4, 4));
    }

    #[test]
    fn test_max_instances_cap() {
        let seg = InstanceSegmenter::new(SegmentationConfig {
            min_area_pixels: 1,
            max_instances: 1,
            threshold: 127,
        });
        let mut img = vec![0u8; 30 * 30];
        for y in 0..4usize {
            for x in 0..4usize {
                img[y * 30 + x] = 255;
            }
        }
        for y in 20..24usize {
            for x in 20..24usize {
                img[y * 30 + x] = 255;
            }
        }
        let masks = seg.segment(&img, 30, 30);
        assert_eq!(masks.len(), 1, "max_instances cap not respected");
        assert_eq!(masks[0].object_id, 1);
        assert_eq!(masks[0].pixel_count(), 16);
    }

    #[test]
    fn test_min_area_filter_removes_small_blobs() {
        let seg = InstanceSegmenter::new(SegmentationConfig {
            min_area_pixels: 100,
            max_instances: 32,
            threshold: 127,
        });
        let img = make_image_with_rect(20, 20, 0, 2, 2, 6, 6, 255);
        let masks = seg.segment(&img, 20, 20);
        assert!(masks.is_empty(), "blob should be filtered by min_area");
    }

    #[test]
    fn test_rgb_image_supported() {
        let seg = default_segmenter();
        let img = vec![255u8; 10 * 10 * 3];
        let masks = seg.segment(&img, 10, 10);
        assert_eq!(masks.len(), 1);
    }

    #[test]
    fn test_rgba_image_supported() {
        let seg = default_segmenter();
        let img = vec![255u8; 10 * 10 * 4];
        let masks = seg.segment(&img, 10, 10);
        assert_eq!(masks.len(), 1);
    }

    #[test]
    fn test_unsupported_channel_count_returns_empty() {
        let seg = default_segmenter();
        let img = vec![255u8; 10 * 10 * 2];
        assert!(seg.segment(&img, 10, 10).is_empty());
    }

    #[test]
    fn test_checked_invalid_dimensions() {
        let seg = default_segmenter();
        let result = seg.segment_checked(&[], 0, 0);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), SegmentationError::ImageTooSmall);
    }

    #[test]
    fn test_checked_buffer_too_small() {
        let seg = default_segmenter();
        let result = seg.segment_checked(&[0u8; 10], 20, 20);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err(), SegmentationError::InvalidDimensions);
    }

    #[test]
    fn test_config_validate_max_instances_zero() {
        let cfg = SegmentationConfig {
            min_area_pixels: 10,
            max_instances: 0,
            threshold: 127,
        };
        assert!(cfg.validate().is_err());
    }

    #[test]
    fn test_default_config_is_valid() {
        assert!(SegmentationConfig::default().validate().is_ok());
    }

    #[test]
    fn test_pixel_count_matches_segment_count() {
        let seg = default_segmenter();
        let img = make_image_with_rect(20, 20, 0, 2, 2, 8, 8, 255);
        let masks = seg.segment(&img, 20, 20);
        assert!(!masks.is_empty());
        assert_eq!(masks[0].pixel_count(), 36);
    }

    #[test]
    fn test_mask_length_equals_image_size() {
        let seg = default_segmenter();
        let img = make_image_with_rect(15, 15, 0, 3, 3, 10, 10, 200);
        let masks = seg.segment(&img, 15, 15);
        assert_eq!(masks.len(), 1, "expected the 7×7 rectangle to be detected");
        assert_eq!(masks[0].mask.len(), 15 * 15);
    }

    #[test]
    fn test_mask_values_are_binary() {
        let seg = default_segmenter();
        let img = make_image_with_rect(12, 12, 0, 2, 2, 9, 9, 180);
        let masks = seg.segment(&img, 12, 12);
        assert_eq!(masks.len(), 1);
        assert!(masks[0].mask.iter().all(|&v| v == 0 || v == 255));
    }

    #[test]
    fn test_union_find_merges_sets() {
        let mut uf = UnionFind::new(4);
        uf.union(0, 1);
        uf.union(2, 3);
        assert_eq!(uf.find(0), uf.find(1));
        assert_eq!(uf.find(2), uf.find(3));
        assert_ne!(uf.find(1), uf.find(2));
        uf.union(1, 3);
        assert_eq!(uf.find(0), uf.find(2));
    }
}