1use crate::error::{Result, VisionError};
10use scirs2_core::ndarray::Array2;
11use std::cmp::Reverse;
12use std::collections::{BinaryHeap, HashMap, HashSet};
13
14#[derive(Debug, Clone)]
20pub struct InstanceMask {
21 pub class_id: usize,
23 pub score: f64,
25 pub mask: Array2<bool>,
27 pub bbox: [usize; 4],
29}
30
31impl InstanceMask {
32 pub fn new(class_id: usize, score: f64, mask: Array2<bool>) -> Self {
36 let bbox = compute_bbox(&mask);
37 Self {
38 class_id,
39 score,
40 mask,
41 bbox,
42 }
43 }
44
45 pub fn area(&self) -> usize {
47 self.mask.iter().filter(|&&v| v).count()
48 }
49}
50
51fn compute_bbox(mask: &Array2<bool>) -> [usize; 4] {
55 let (height, width) = mask.dim();
56 let mut y_min = height;
57 let mut y_max = 0usize;
58 let mut x_min = width;
59 let mut x_max = 0usize;
60 let mut found = false;
61
62 for y in 0..height {
63 for x in 0..width {
64 if mask[[y, x]] {
65 found = true;
66 if y < y_min {
67 y_min = y;
68 }
69 if y > y_max {
70 y_max = y;
71 }
72 if x < x_min {
73 x_min = x;
74 }
75 if x > x_max {
76 x_max = x;
77 }
78 }
79 }
80 }
81
82 if found {
83 [y_min, x_min, y_max, x_max]
84 } else {
85 [0, 0, 0, 0]
86 }
87}
88
89#[derive(PartialEq)]
95struct WatershedEntry {
96 neg_gradient: ordered_float::NotNan<f64>,
98 y: usize,
99 x: usize,
100}
101
102impl Eq for WatershedEntry {}
103
104impl PartialOrd for WatershedEntry {
105 fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
106 Some(self.cmp(other))
107 }
108}
109
110impl Ord for WatershedEntry {
111 fn cmp(&self, other: &Self) -> std::cmp::Ordering {
112 self.neg_gradient.cmp(&other.neg_gradient)
114 }
115}
116
117pub fn watershed_instance(gradient: &Array2<f64>, markers: &Array2<i32>) -> Result<Array2<i32>> {
129 let (height, width) = gradient.dim();
130 let (mh, mw) = markers.dim();
131 if height != mh || width != mw {
132 return Err(VisionError::DimensionMismatch(format!(
133 "gradient ({height}×{width}) and markers ({mh}×{mw}) must have the same shape"
134 )));
135 }
136 if height == 0 || width == 0 {
137 return Err(VisionError::InvalidParameter(
138 "gradient must be non-empty".to_string(),
139 ));
140 }
141
142 let mut output = markers.to_owned();
143 let mut in_queue = Array2::<bool>::from_elem((height, width), false);
144
145 let mut heap: BinaryHeap<WatershedEntry> = BinaryHeap::new();
146
147 for y in 0..height {
149 for x in 0..width {
150 if markers[[y, x]] == 0 {
151 continue;
152 }
153 let neighbours: [(i64, i64); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
154 for (dy, dx) in neighbours {
155 let ny = y as i64 + dy;
156 let nx = x as i64 + dx;
157 if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
158 continue;
159 }
160 let ny = ny as usize;
161 let nx = nx as usize;
162 if output[[ny, nx]] == 0 && !in_queue[[ny, nx]] {
163 in_queue[[ny, nx]] = true;
164 let neg = ordered_float::NotNan::new(-gradient[[ny, nx]])
165 .unwrap_or_else(|_| ordered_float::NotNan::default());
166 heap.push(WatershedEntry {
167 neg_gradient: neg,
168 y: ny,
169 x: nx,
170 });
171 }
172 }
173 }
174 }
175
176 while let Some(entry) = heap.pop() {
178 let y = entry.y;
179 let x = entry.x;
180
181 let mut best_label = 0i32;
183 let mut best_grad = f64::INFINITY;
184
185 let neighbours: [(i64, i64); 4] = [(-1, 0), (1, 0), (0, -1), (0, 1)];
186 for (dy, dx) in neighbours {
187 let ny = y as i64 + dy;
188 let nx = x as i64 + dx;
189 if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
190 continue;
191 }
192 let ny = ny as usize;
193 let nx = nx as usize;
194 let nb_label = output[[ny, nx]];
195 if nb_label != 0 {
196 let cost = gradient[[ny, nx]];
198 if cost < best_grad {
199 best_grad = cost;
200 best_label = nb_label;
201 }
202 }
203 }
204
205 if best_label != 0 {
206 output[[y, x]] = best_label;
207
208 for (dy, dx) in neighbours {
210 let ny = y as i64 + dy;
211 let nx = x as i64 + dx;
212 if ny < 0 || ny >= height as i64 || nx < 0 || nx >= width as i64 {
213 continue;
214 }
215 let ny = ny as usize;
216 let nx = nx as usize;
217 if output[[ny, nx]] == 0 && !in_queue[[ny, nx]] {
218 in_queue[[ny, nx]] = true;
219 let neg = ordered_float::NotNan::new(-gradient[[ny, nx]])
220 .unwrap_or_else(|_| ordered_float::NotNan::default());
221 heap.push(WatershedEntry {
222 neg_gradient: neg,
223 y: ny,
224 x: nx,
225 });
226 }
227 }
228 }
229 }
230
231 Ok(output)
232}
233
234pub fn mask_iou(mask1: &Array2<bool>, mask2: &Array2<bool>) -> Result<f64> {
242 let (h1, w1) = mask1.dim();
243 let (h2, w2) = mask2.dim();
244 if h1 != h2 || w1 != w2 {
245 return Err(VisionError::DimensionMismatch(format!(
246 "mask1 ({h1}×{w1}) and mask2 ({h2}×{w2}) must have the same shape"
247 )));
248 }
249
250 let mut intersection = 0usize;
251 let mut union_ = 0usize;
252
253 for y in 0..h1 {
254 for x in 0..w1 {
255 let a = mask1[[y, x]];
256 let b = mask2[[y, x]];
257 if a && b {
258 intersection += 1;
259 }
260 if a || b {
261 union_ += 1;
262 }
263 }
264 }
265
266 if union_ == 0 {
267 Ok(0.0)
268 } else {
269 Ok(intersection as f64 / union_ as f64)
270 }
271}
272
273pub fn mask_nms(instances: &[InstanceMask], iou_threshold: f64) -> Result<Vec<InstanceMask>> {
286 if instances.is_empty() {
287 return Ok(Vec::new());
288 }
289
290 let mut indices: Vec<usize> = (0..instances.len()).collect();
292 indices.sort_by(|&a, &b| {
293 instances[b]
294 .score
295 .partial_cmp(&instances[a].score)
296 .unwrap_or(std::cmp::Ordering::Equal)
297 });
298
299 let mut kept: Vec<InstanceMask> = Vec::new();
300
301 'outer: for &idx in &indices {
302 let candidate = &instances[idx];
303 for already_kept in &kept {
304 if already_kept.class_id != candidate.class_id {
306 continue;
307 }
308 let iou = mask_iou(&candidate.mask, &already_kept.mask)?;
309 if iou > iou_threshold {
310 continue 'outer;
311 }
312 }
313 kept.push(candidate.clone());
314 }
315
316 Ok(kept)
317}
318
319pub fn panoptic_quality(
344 predicted: &[InstanceMask],
345 ground_truth: &[InstanceMask],
346) -> Result<(f64, f64, f64)> {
347 let iou_threshold = 0.5f64;
348
349 let n_pred = predicted.len();
351 let n_gt = ground_truth.len();
352
353 if n_pred == 0 && n_gt == 0 {
354 return Ok((1.0, 1.0, 1.0));
356 }
357
358 let mut iou_pairs: Vec<(f64, usize, usize)> = Vec::new();
360 for (pi, pred_inst) in predicted.iter().enumerate().take(n_pred) {
361 for (gi, gt_inst) in ground_truth.iter().enumerate().take(n_gt) {
362 let (ph, pw) = pred_inst.mask.dim();
364 let (gh, gw) = gt_inst.mask.dim();
365 if ph != gh || pw != gw {
366 continue;
367 }
368 let iou = mask_iou(&pred_inst.mask, >_inst.mask)?;
369 if iou > iou_threshold {
370 iou_pairs.push((iou, pi, gi));
371 }
372 }
373 }
374
375 iou_pairs.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
377
378 let mut matched_pred: HashSet<usize> = HashSet::new();
379 let mut matched_gt: HashSet<usize> = HashSet::new();
380 let mut tp_iou_sum = 0.0f64;
381 let mut tp_count = 0usize;
382
383 for (iou, pi, gi) in &iou_pairs {
384 if matched_pred.contains(pi) || matched_gt.contains(gi) {
385 continue;
386 }
387 matched_pred.insert(*pi);
388 matched_gt.insert(*gi);
389 tp_iou_sum += iou;
390 tp_count += 1;
391 }
392
393 let fp = n_pred - matched_pred.len();
394 let fn_ = n_gt - matched_gt.len();
395
396 let tp_f = tp_count as f64;
397 let fp_f = fp as f64;
398 let fn_f = fn_ as f64;
399
400 let sq = if tp_count > 0 { tp_iou_sum / tp_f } else { 0.0 };
401
402 let denom = tp_f + 0.5 * fp_f + 0.5 * fn_f;
403 let rq = if denom > 0.0 { tp_f / denom } else { 0.0 };
404 let pq = sq * rq;
405
406 Ok((pq, sq, rq))
407}
408
409pub fn instance_overlap(inst1: &InstanceMask, inst2: &InstanceMask) -> Result<bool> {
417 let (h1, w1) = inst1.mask.dim();
418 let (h2, w2) = inst2.mask.dim();
419 if h1 != h2 || w1 != w2 {
420 return Err(VisionError::DimensionMismatch(format!(
421 "inst1 mask ({h1}×{w1}) and inst2 mask ({h2}×{w2}) must have the same shape"
422 )));
423 }
424 for y in 0..h1 {
425 for x in 0..w1 {
426 if inst1.mask[[y, x]] && inst2.mask[[y, x]] {
427 return Ok(true);
428 }
429 }
430 }
431 Ok(false)
432}
433
434pub fn label_map_to_instances(label_map: &Array2<i32>) -> Result<Vec<InstanceMask>> {
443 let (height, width) = label_map.dim();
444 let mut label_set: HashMap<i32, Vec<(usize, usize)>> = HashMap::new();
445
446 for y in 0..height {
447 for x in 0..width {
448 let lbl = label_map[[y, x]];
449 if lbl == 0 {
450 continue;
451 }
452 label_set.entry(lbl).or_default().push((y, x));
453 }
454 }
455
456 let mut instances: Vec<InstanceMask> = Vec::new();
457 for (_, pixels) in label_set {
458 let mut mask = Array2::<bool>::from_elem((height, width), false);
459 for (y, x) in pixels {
460 mask[[y, x]] = true;
461 }
462 instances.push(InstanceMask::new(0, 1.0, mask));
463 }
464
465 instances.sort_by_key(|inst| Reverse(inst.area()));
467
468 Ok(instances)
469}
470
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray::Array2;

    /// Builds a `height`×`width` mask in which exactly the listed `(y, x)` pixels are set.
    fn make_mask(height: usize, width: usize, pixels: &[(usize, usize)]) -> Array2<bool> {
        let mut m = Array2::<bool>::from_elem((height, width), false);
        for &(y, x) in pixels {
            m[[y, x]] = true;
        }
        m
    }

    #[test]
    fn test_mask_iou_identical() {
        let m = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0)]);
        let iou = mask_iou(&m, &m).expect("mask_iou should succeed");
        assert!((iou - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_mask_iou_disjoint() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let iou = mask_iou(&m1, &m2).expect("mask_iou should succeed");
        assert!((iou - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_mask_nms_removes_overlap() {
        // IoU of the two masks is 3/4, above the 0.5 threshold.
        let m1 = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0), (1, 1)]);
        let m2 = make_mask(4, 4, &[(0, 0), (0, 1), (1, 0)]);
        let instances = vec![
            InstanceMask::new(0, 0.9, m1.clone()),
            InstanceMask::new(0, 0.7, m2.clone()),
        ];
        let kept = mask_nms(&instances, 0.5).expect("mask_nms should succeed");
        assert_eq!(kept.len(), 1);
        // The higher-scoring instance is the survivor.
        assert!((kept[0].score - 0.9).abs() < 1e-10);
    }

    #[test]
    fn test_mask_nms_keeps_disjoint() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let instances = vec![InstanceMask::new(0, 0.9, m1), InstanceMask::new(0, 0.8, m2)];
        let kept = mask_nms(&instances, 0.5).expect("mask_nms should succeed");
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn test_panoptic_quality_perfect() {
        let m = make_mask(4, 4, &[(0, 0), (0, 1)]);
        let pred = vec![InstanceMask::new(0, 1.0, m.clone())];
        let truth = vec![InstanceMask::new(0, 1.0, m)];
        let (pq, sq, rq) =
            panoptic_quality(&pred, &truth).expect("panoptic_quality should succeed");
        assert!((pq - 1.0).abs() < 1e-10);
        assert!((sq - 1.0).abs() < 1e-10);
        assert!((rq - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_panoptic_quality_empty() {
        // With nothing predicted and nothing expected, all three metrics are 1.
        let (pq, sq, rq) = panoptic_quality(&[], &[]).expect("panoptic_quality should succeed");
        assert!((pq - 1.0).abs() < 1e-10);
        assert!((sq - 1.0).abs() < 1e-10);
        assert!((rq - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_instance_overlap_true() {
        let m1 = make_mask(4, 4, &[(1, 1), (2, 2)]);
        let m2 = make_mask(4, 4, &[(2, 2), (3, 3)]);
        let i1 = InstanceMask::new(0, 1.0, m1);
        let i2 = InstanceMask::new(0, 1.0, m2);
        assert!(instance_overlap(&i1, &i2).expect("should succeed"));
    }

    #[test]
    fn test_instance_overlap_false() {
        let m1 = make_mask(4, 4, &[(0, 0)]);
        let m2 = make_mask(4, 4, &[(3, 3)]);
        let i1 = InstanceMask::new(0, 1.0, m1);
        let i2 = InstanceMask::new(0, 1.0, m2);
        assert!(!instance_overlap(&i1, &i2).expect("should succeed"));
    }

    #[test]
    fn test_watershed_instance_basic() {
        // A high-gradient ridge down column 2 separates the two seeds.
        let mut gradient = Array2::<f64>::zeros((5, 5));
        for y in 0..5 {
            gradient[[y, 2]] = 10.0;
        }
        let mut markers = Array2::<i32>::zeros((5, 5));
        markers[[2, 0]] = 1;
        markers[[2, 4]] = 2;
        let labels = watershed_instance(&gradient, &markers).expect("watershed should succeed");
        assert_eq!(labels.dim(), (5, 5));
        assert_eq!(labels[[2, 0]], 1);
        assert_eq!(labels[[2, 4]], 2);
        for y in 0..5 {
            // The flat basins flood completely before the ridge is touched.
            for x in 0..2 {
                assert_eq!(labels[[y, x]], 1);
            }
            for x in 3..5 {
                assert_eq!(labels[[y, x]], 2);
            }
            // Ridge pixels end up in one of the two basins.
            assert_ne!(labels[[y, 2]], 0);
        }
    }

    #[test]
    fn test_label_map_to_instances() {
        let mut lmap = Array2::<i32>::zeros((4, 4));
        lmap[[0, 0]] = 1;
        lmap[[0, 1]] = 1;
        lmap[[3, 3]] = 2;
        let instances =
            label_map_to_instances(&lmap).expect("label_map_to_instances should succeed");
        assert_eq!(instances.len(), 2);
        // Largest region first.
        assert_eq!(instances[0].area(), 2);
        assert_eq!(instances[1].area(), 1);
    }
}