Skip to main content

scirs2_vision/
instance_segmentation.rs

//! Instance segmentation algorithms
//!
//! This module provides instance-level segmentation functionality including:
//! - Watershed-based instance separation
//! - Mask IoU and Non-Maximum Suppression
//! - Panoptic Quality metric
//! - Instance overlap utilities
//! - Conversion of dense label maps into instance masks
8
9use crate::error::{Result, VisionError};
10use scirs2_core::ndarray::Array2;
11use std::cmp::Reverse;
12use std::collections::{BinaryHeap, HashMap, HashSet};
13
14// ---------------------------------------------------------------------------
15// InstanceMask
16// ---------------------------------------------------------------------------
17
18/// A single instance produced by an instance segmentation model.
19#[derive(Debug, Clone)]
20pub struct InstanceMask {
21    /// Class identifier (0-indexed)
22    pub class_id: usize,
23    /// Detection confidence / quality score in [0, 1]
24    pub score: f64,
25    /// Binary pixel mask (`true` = object foreground)
26    pub mask: Array2<bool>,
27    /// Tight axis-aligned bounding box `[y_min, x_min, y_max, x_max]`
28    pub bbox: [usize; 4],
29}
30
31impl InstanceMask {
32    /// Construct a new [`InstanceMask`] with an automatically computed bounding box.
33    ///
34    /// If the mask is all-false the bbox is set to `[0, 0, 0, 0]`.
35    pub fn new(class_id: usize, score: f64, mask: Array2<bool>) -> Self {
36        let bbox = compute_bbox(&mask);
37        Self {
38            class_id,
39            score,
40            mask,
41            bbox,
42        }
43    }
44
45    /// Return the number of foreground pixels.
46    pub fn area(&self) -> usize {
47        self.mask.iter().filter(|&&v| v).count()
48    }
49}
50
51/// Compute a tight bounding box from a binary mask.
52///
53/// Returns `[y_min, x_min, y_max, x_max]`.  Returns `[0,0,0,0]` for empty masks.
54fn compute_bbox(mask: &Array2<bool>) -> [usize; 4] {
55    let (height, width) = mask.dim();
56    let mut y_min = height;
57    let mut y_max = 0usize;
58    let mut x_min = width;
59    let mut x_max = 0usize;
60    let mut found = false;
61
62    for y in 0..height {
63        for x in 0..width {
64            if mask[[y, x]] {
65                found = true;
66                if y < y_min {
67                    y_min = y;
68                }
69                if y > y_max {
70                    y_max = y;
71                }
72                if x < x_min {
73                    x_min = x;
74                }
75                if x > x_max {
76                    x_max = x;
77                }
78            }
79        }
80    }
81
82    if found {
83        [y_min, x_min, y_max, x_max]
84    } else {
85        [0, 0, 0, 0]
86    }
87}
88
89// ---------------------------------------------------------------------------
90// Watershed instance segmentation
91// ---------------------------------------------------------------------------
92
93/// Priority queue entry for the watershed flooding (min-heap by gradient value).
94#[derive(PartialEq)]
95struct WatershedEntry {
96    /// Negated gradient so that `BinaryHeap` (max-heap) acts as a min-heap.
97    neg_gradient: ordered_float::NotNan<f64>,
98    y: usize,
99    x: usize,
100}
101
102impl Eq for WatershedEntry {}
103
104impl PartialOrd for WatershedEntry {
105    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
106        Some(self.cmp(other))
107    }
108}
109
110impl Ord for WatershedEntry {
111    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
112        // max-heap on neg_gradient → min-heap on actual gradient
113        self.neg_gradient.cmp(&other.neg_gradient)
114    }
115}
116
117/// Marker-controlled watershed segmentation.
118///
119/// Floods a gradient image starting from pre-placed markers.  Each marker
120/// expands into its catchment basin.  Pixels labelled `0` in `markers` are
121/// uninitialized and will be flooded; non-zero pixels are seeds.
122/// The returned label map uses the same non-zero label values as `markers`;
123/// pixels where no basin reached them retain label `0`.
124///
125/// # Arguments
126/// * `gradient` - Gradient magnitude image `[height, width]` (higher = boundary)
127/// * `markers`  - Initial marker map; `0` = unlabelled, positive = seed label
128pub fn watershed_instance(gradient: &Array2<f64>, markers: &Array2<i32>) -> Result<Array2<i32>> {
129    let (height, width) = gradient.dim();
130    let (mh, mw) = markers.dim();
131    if height != mh || width != mw {
132        return Err(VisionError::DimensionMismatch(format!(
133            "gradient ({height}×{width}) and markers ({mh}×{mw}) must have the same shape"
134        )));
135    }
136    if height == 0 || width == 0 {
137        return Err(VisionError::InvalidParameter(
138            "gradient must be non-empty".to_string(),
139        ));
140    }
141
142    let mut output = markers.to_owned();
143    let mut in_queue = Array2::<bool>::from_elem((height, width), false);
144
145    let mut heap: BinaryHeap<WatershedEntry> = BinaryHeap::new();
146
147    // Seed the heap with all pixels adjacent to markers
148    for y in 0..height {
149        for x in 0..width {
150            if markers[[y, x]] == 0 {
151                continue;
152            }
153            let neighbours: [(i64, i64); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
154            for (dy, dx) in neighbours {
155                let ny = y as i64 + dy;
156                let nx = x as i64 + dx;
157                if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
158                    continue;
159                }
160                let ny = ny as usize;
161                let nx = nx as usize;
162                if output[[ny, nx]] == 0 && !in_queue[[ny, nx]] {
163                    in_queue[[ny, nx]] = true;
164                    let neg = ordered_float::NotNan::new(-gradient[[ny, nx]])
165                        .unwrap_or_else(|_| ordered_float::NotNan::default());
166                    heap.push(WatershedEntry {
167                        neg_gradient: neg,
168                        y: ny,
169                        x: nx,
170                    });
171                }
172            }
173        }
174    }
175
176    // Flood filling
177    while let Some(entry) = heap.pop() {
178        let y = entry.y;
179        let x = entry.x;
180
181        // Assign label from the neighbouring marker with the lowest boundary cost
182        let mut best_label = 0i32;
183        let mut best_grad = f64::INFINITY;
184
185        let neighbours: [(i64, i64); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
186        for (dy, dx) in neighbours {
187            let ny = y as i64 + dy;
188            let nx = x as i64 + dx;
189            if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
190                continue;
191            }
192            let ny = ny as usize;
193            let nx = nx as usize;
194            let nb_label = output[[ny, nx]];
195            if nb_label != 0 {
196                // Use the neighbour's gradient as the "cost" of propagating through it
197                let cost = gradient[[ny, nx]];
198                if cost < best_grad {
199                    best_grad = cost;
200                    best_label = nb_label;
201                }
202            }
203        }
204
205        if best_label != 0 {
206            output[[y, x]] = best_label;
207
208            // Expand into unlabelled neighbours
209            for (dy, dx) in neighbours {
210                let ny = y as i64 + dy;
211                let nx = x as i64 + dx;
212                if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
213                    continue;
214                }
215                let ny = ny as usize;
216                let nx = nx as usize;
217                if output[[ny, nx]] == 0 && !in_queue[[ny, nx]] {
218                    in_queue[[ny, nx]] = true;
219                    let neg = ordered_float::NotNan::new(-gradient[[ny, nx]])
220                        .unwrap_or_else(|_| ordered_float::NotNan::default());
221                    heap.push(WatershedEntry {
222                        neg_gradient: neg,
223                        y: ny,
224                        x: nx,
225                    });
226                }
227            }
228        }
229    }
230
231    Ok(output)
232}
233
234// ---------------------------------------------------------------------------
235// Mask IoU
236// ---------------------------------------------------------------------------
237
238/// Compute intersection-over-union between two binary masks.
239///
240/// Returns 0.0 when both masks are empty (zero union).
241pub fn mask_iou(mask1: &Array2<bool>, mask2: &Array2<bool>) -> Result<f64> {
242    let (h1, w1) = mask1.dim();
243    let (h2, w2) = mask2.dim();
244    if h1 != h2 || w1 != w2 {
245        return Err(VisionError::DimensionMismatch(format!(
246            "mask1 ({h1}×{w1}) and mask2 ({h2}×{w2}) must have the same shape"
247        )));
248    }
249
250    let mut intersection = 0usize;
251    let mut union_ = 0usize;
252
253    for y in 0..h1 {
254        for x in 0..w1 {
255            let a = mask1[[y, x]];
256            let b = mask2[[y, x]];
257            if a && b {
258                intersection += 1;
259            }
260            if a || b {
261                union_ += 1;
262            }
263        }
264    }
265
266    if union_ == 0 {
267        Ok(0.0)
268    } else {
269        Ok(intersection as f64 / union_ as f64)
270    }
271}
272
273// ---------------------------------------------------------------------------
274// Mask NMS
275// ---------------------------------------------------------------------------
276
277/// Non-maximum suppression on instance masks using mask IoU.
278///
279/// Instances are sorted by descending score; any instance whose mask IoU with
280/// an already-selected instance exceeds `iou_threshold` is suppressed.
281///
282/// # Arguments
283/// * `instances`     - Candidate instance masks
284/// * `iou_threshold` - IoU threshold above which the lower-scored instance is suppressed
285pub fn mask_nms(instances: &[InstanceMask], iou_threshold: f64) -> Result<Vec<InstanceMask>> {
286    if instances.is_empty() {
287        return Ok(Vec::new());
288    }
289
290    // Sort indices by descending score
291    let mut indices: Vec<usize> = (0..instances.len()).collect();
292    indices.sort_by(|&a, &b| {
293        instances[b]
294            .score
295            .partial_cmp(&instances[a].score)
296            .unwrap_or(std::cmp::Ordering::Equal)
297    });
298
299    let mut kept: Vec<InstanceMask> = Vec::new();
300
301    'outer: for &idx in &indices {
302        let candidate = &instances[idx];
303        for already_kept in &kept {
304            // Only compare within the same class (optional convention, matches most frameworks)
305            if already_kept.class_id != candidate.class_id {
306                continue;
307            }
308            let iou = mask_iou(&candidate.mask, &already_kept.mask)?;
309            if iou > iou_threshold {
310                continue 'outer;
311            }
312        }
313        kept.push(candidate.clone());
314    }
315
316    Ok(kept)
317}
318
319// ---------------------------------------------------------------------------
320// Panoptic Quality
321// ---------------------------------------------------------------------------
322
323/// Compute Panoptic Quality (PQ), Segmentation Quality (SQ), and Recognition
324/// Quality (RQ) for a single semantic class.
325///
326/// The formulae are from Kirillov et al., "Panoptic Segmentation", CVPR 2019:
327///
328/// ```text
329/// PQ = SQ × RQ
330/// SQ = Σ_{(p,g)∈TP} IoU(p,g) / |TP|
331/// RQ = |TP| / (|TP| + ½|FP| + ½|FN|)
332/// ```
333///
334/// A predicted instance `p` is matched to ground-truth `g` if their mask IoU
335/// exceeds 0.5 (the standard threshold).
336///
337/// # Arguments
338/// * `predicted`    - Predicted instance masks (all same class)
339/// * `ground_truth` - Ground-truth instance masks (all same class)
340///
341/// # Returns
342/// `(pq, sq, rq)` tuple.
343pub fn panoptic_quality(
344    predicted: &[InstanceMask],
345    ground_truth: &[InstanceMask],
346) -> Result<(f64, f64, f64)> {
347    let iou_threshold = 0.5f64;
348
349    // Build a cost matrix: IoU between every predicted × gt pair
350    let n_pred = predicted.len();
351    let n_gt = ground_truth.len();
352
353    if n_pred == 0 && n_gt == 0 {
354        // Nothing to evaluate; PQ = 1 by convention (perfect vacuous case)
355        return Ok((1.0, 1.0, 1.0));
356    }
357
358    // Greedy matching: sort pairs by descending IoU, greedily assign
359    let mut iou_pairs: Vec<(f64, usize, usize)> = Vec::new();
360    for (pi, pred_inst) in predicted.iter().enumerate().take(n_pred) {
361        for (gi, gt_inst) in ground_truth.iter().enumerate().take(n_gt) {
362            // Only compare spatially compatible pairs (same dimensions)
363            let (ph, pw) = pred_inst.mask.dim();
364            let (gh, gw) = gt_inst.mask.dim();
365            if ph != gh || pw != gw {
366                continue;
367            }
368            let iou = mask_iou(&pred_inst.mask, &gt_inst.mask)?;
369            if iou > iou_threshold {
370                iou_pairs.push((iou, pi, gi));
371            }
372        }
373    }
374
375    // Sort descending by IoU
376    iou_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
377
378    let mut matched_pred: HashSet<usize> = HashSet::new();
379    let mut matched_gt: HashSet<usize> = HashSet::new();
380    let mut tp_iou_sum = 0.0f64;
381    let mut tp_count = 0usize;
382
383    for (iou, pi, gi) in &iou_pairs {
384        if matched_pred.contains(pi) || matched_gt.contains(gi) {
385            continue;
386        }
387        matched_pred.insert(*pi);
388        matched_gt.insert(*gi);
389        tp_iou_sum += iou;
390        tp_count += 1;
391    }
392
393    let fp = n_pred - matched_pred.len();
394    let fn_ = n_gt - matched_gt.len();
395
396    let tp_f = tp_count as f64;
397    let fp_f = fp as f64;
398    let fn_f = fn_ as f64;
399
400    let sq = if tp_count > 0 { tp_iou_sum / tp_f } else { 0.0 };
401
402    let denom = tp_f + 0.5 * fp_f + 0.5 * fn_f;
403    let rq = if denom > 0.0 { tp_f / denom } else { 0.0 };
404    let pq = sq * rq;
405
406    Ok((pq, sq, rq))
407}
408
409// ---------------------------------------------------------------------------
410// Instance overlap
411// ---------------------------------------------------------------------------
412
413/// Check whether two instance masks overlap (share at least one foreground pixel).
414///
415/// Returns an error if the masks have different spatial dimensions.
416pub fn instance_overlap(inst1: &InstanceMask, inst2: &InstanceMask) -> Result<bool> {
417    let (h1, w1) = inst1.mask.dim();
418    let (h2, w2) = inst2.mask.dim();
419    if h1 != h2 || w1 != w2 {
420        return Err(VisionError::DimensionMismatch(format!(
421            "inst1 mask ({h1}×{w1}) and inst2 mask ({h2}×{w2}) must have the same shape"
422        )));
423    }
424    for y in 0..h1 {
425        for x in 0..w1 {
426            if inst1.mask[[y, x]] && inst2.mask[[y, x]] {
427                return Ok(true);
428            }
429        }
430    }
431    Ok(false)
432}
433
434// ---------------------------------------------------------------------------
435// Utility: build InstanceMask from a label map
436// ---------------------------------------------------------------------------
437
438/// Convert a dense label map (e.g. watershed output) to a vector of [`InstanceMask`]s.
439///
440/// Label `0` is treated as background and ignored.
441/// All instances are assigned `class_id = 0` and `score = 1.0` (modify as needed).
442pub fn label_map_to_instances(label_map: &Array2<i32>) -> Result<Vec<InstanceMask>> {
443    let (height, width) = label_map.dim();
444    let mut label_set: HashMap<i32, Vec<(usize, usize)>> = HashMap::new();
445
446    for y in 0..height {
447        for x in 0..width {
448            let lbl = label_map[[y, x]];
449            if lbl == 0 {
450                continue;
451            }
452            label_set.entry(lbl).or_default().push((y, x));
453        }
454    }
455
456    let mut instances: Vec<InstanceMask> = Vec::new();
457    for (_, pixels) in label_set {
458        let mut mask = Array2::<bool>::from_elem((height, width), false);
459        for (y, x) in pixels {
460            mask[[y, x]] = true;
461        }
462        instances.push(InstanceMask::new(0, 1.0, mask));
463    }
464
465    // Sort by area descending for deterministic ordering
466    instances.sort_by_key(|inst| Reverse(inst.area()));
467
468    Ok(instances)
469}
470
471// ---------------------------------------------------------------------------
472// Tests
473// ---------------------------------------------------------------------------
474
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Build a `height × width` mask with exactly the listed `(y, x)` pixels set.
    fn make_mask(height: usize, width: usize, pixels: &[(usize, usize)]) -> Array2<bool> {
        let mut m = Array2::<bool>::from_elem((height, width), false);
        for &(y, x) in pixels {
            m[[y, x]] = true;
        }
        m
    }

    #[test]
    fn test_mask_iou_identical() {
        let m = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0)]);
        let iou = mask_iou(&m, &m).expect("mask_iou should succeed");
        assert!((iou - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mask_iou_disjoint() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let iou = mask_iou(&m1, &m2).expect("mask_iou should succeed");
        assert!((iou - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_mask_iou_shape_mismatch() {
        let m1 = make_mask(2, 2, &[(0, 0)]);
        let m2 = make_mask(3, 3, &[(0, 0)]);
        assert!(mask_iou(&m1, &m2).is_err());
    }

    #[test]
    fn test_mask_nms_removes_overlap() {
        let m1 = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0), (1, 1)]);
        let m2 = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0)]);
        let instances = vec![InstanceMask::new(0, 0.9, m1), InstanceMask::new(0, 0.7, m2)];
        let kept = mask_nms(&instances, 0.5).expect("mask_nms should succeed");
        // IoU(m1, m2) = 3/4 = 0.75 > 0.5, so lower-score m2 should be suppressed
        assert_eq!(kept.len(), 1);
        assert!((kept[0].score - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_mask_nms_keeps_disjoint() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let instances = vec![InstanceMask::new(0, 0.9, m1), InstanceMask::new(0, 0.8, m2)];
        let kept = mask_nms(&instances, 0.5).expect("mask_nms should succeed");
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_panoptic_quality_perfect() {
        let m = make_mask(4, 4, &[(0, 0), (0, 1)]);
        let pred = vec![InstanceMask::new(0, 1.0, m.clone())];
        let gt = vec![InstanceMask::new(0, 1.0, m)];
        let (pq, sq, rq) = panoptic_quality(&pred, &gt).expect("panoptic_quality should succeed");
        assert!((pq - 1.0).abs() < 1e-10);
        assert!((sq - 1.0).abs() < 1e-10);
        assert!((rq - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_panoptic_quality_empty() {
        let (pq, sq, rq) = panoptic_quality(&[], &[]).expect("panoptic_quality should succeed");
        // Vacuous perfect case: all three components are 1 by convention
        assert!((pq - 1.0).abs() < 1e-10);
        assert!((sq - 1.0).abs() < 1e-10);
        assert!((rq - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_instance_overlap_true() {
        let m1 = make_mask(4, 4, &[(1, 1), (2, 2)]);
        let m2 = make_mask(4, 4, &[(2, 2), (3, 3)]);
        let i1 = InstanceMask::new(0, 1.0, m1);
        let i2 = InstanceMask::new(0, 1.0, m2);
        assert!(instance_overlap(&i1, &i2).expect("should succeed"));
    }

    #[test]
    fn test_instance_overlap_false() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let i1 = InstanceMask::new(0, 1.0, m1);
        let i2 = InstanceMask::new(0, 1.0, m2);
        assert!(!instance_overlap(&i1, &i2).expect("should succeed"));
    }

    #[test]
    fn test_watershed_instance_basic() {
        let mut gradient = Array2::<f64>::zeros((5, 5));
        // High-gradient boundary down the middle
        for y in 0..5 {
            gradient[[y, 2]] = 10.0;
        }
        let mut markers = Array2::<i32>::zeros((5, 5));
        markers[[2, 0]] = 1;
        markers[[2, 4]] = 2;
        let labels = watershed_instance(&gradient, &markers).expect("watershed should succeed");
        assert_eq!(labels.dim(), (5, 5));
        for y in 0..5 {
            // Left seed should dominate everything left of the ridge
            assert_eq!(labels[[y, 0]], 1);
            assert_eq!(labels[[y, 1]], 1);
            // Right seed should dominate everything right of the ridge
            assert_eq!(labels[[y, 3]], 2);
            assert_eq!(labels[[y, 4]], 2);
            // The ridge itself is flooded last but must still receive a label
            assert_ne!(labels[[y, 2]], 0);
        }
    }

    #[test]
    fn test_label_map_to_instances() {
        let mut lmap = Array2::<i32>::zeros((4, 4));
        lmap[[0, 0]] = 1;
        lmap[[0, 1]] = 1;
        lmap[[3, 3]] = 2;
        let instances =
            label_map_to_instances(&lmap).expect("label_map_to_instances should succeed");
        assert_eq!(instances.len(), 2);
        // Larger instance comes first
        assert_eq!(instances[0].area(), 2);
        assert_eq!(instances[0].bbox, [0, 0, 0, 1]);
        assert_eq!(instances[1].area(), 1);
        assert_eq!(instances[1].bbox, [3, 3, 3, 3]);
    }
}