Skip to main content

yscv_detect/
nms.rs

1use crate::{BoundingBox, DetectError, Detection};
2
3/// Computes IoU for two axis-aligned boxes.
4pub fn iou(a: BoundingBox, b: BoundingBox) -> f32 {
5    let inter_x1 = a.x1.max(b.x1);
6    let inter_y1 = a.y1.max(b.y1);
7    let inter_x2 = a.x2.min(b.x2);
8    let inter_y2 = a.y2.min(b.y2);
9
10    let inter_w = (inter_x2 - inter_x1).max(0.0);
11    let inter_h = (inter_y2 - inter_y1).max(0.0);
12    let inter = inter_w * inter_h;
13    let union = a.area() + b.area() - inter;
14    if union <= 0.0 { 0.0 } else { inter / union }
15}
16
17/// Standard score-sorted NMS.
18pub fn non_max_suppression(
19    detections: &[Detection],
20    iou_threshold: f32,
21    max_detections: usize,
22) -> Vec<Detection> {
23    let mut sorted = detections
24        .iter()
25        .copied()
26        .filter(is_finite_detection)
27        .collect::<Vec<_>>();
28    sorted.sort_by(|a, b| b.score.total_cmp(&a.score));
29
30    let mut selected: Vec<Detection> = Vec::new();
31    for candidate in sorted {
32        if selected.len() >= max_detections {
33            break;
34        }
35        let mut suppressed = false;
36        for chosen in &selected {
37            if chosen.class_id == candidate.class_id
38                && iou(chosen.bbox, candidate.bbox) > iou_threshold
39            {
40                suppressed = true;
41                break;
42            }
43        }
44        if !suppressed {
45            selected.push(candidate);
46        }
47    }
48    selected
49}
50
51pub(crate) fn validate_nms_args(
52    iou_threshold: f32,
53    max_detections: usize,
54) -> Result<(), DetectError> {
55    if !iou_threshold.is_finite() || !(0.0..=1.0).contains(&iou_threshold) {
56        return Err(DetectError::InvalidIouThreshold { iou_threshold });
57    }
58    if max_detections == 0 {
59        return Err(DetectError::InvalidMaxDetections { max_detections });
60    }
61    Ok(())
62}
63
64fn is_finite_detection(detection: &Detection) -> bool {
65    detection.score.is_finite()
66        && detection.bbox.x1.is_finite()
67        && detection.bbox.y1.is_finite()
68        && detection.bbox.x2.is_finite()
69        && detection.bbox.y2.is_finite()
70}
71
72/// Soft-NMS with Gaussian decay.
73///
74/// Instead of hard suppression, overlapping detections have their scores
75/// decayed by `score *= exp(-(iou² / sigma))`. Detections whose score
76/// falls below `score_threshold` are removed. The vector is modified in
77/// place.
78pub fn soft_nms(detections: &mut Vec<Detection>, sigma: f32, score_threshold: f32) {
79    // Filter out non-finite detections first.
80    detections.retain(is_finite_detection);
81
82    // Process each position: pick the highest-scoring remaining detection,
83    // swap it to the current position, then decay all subsequent detections.
84    let mut i = 0;
85    while i < detections.len() {
86        // Find the index of the max-score detection in [i..].
87        let mut max_idx = i;
88        for j in (i + 1)..detections.len() {
89            if detections[j].score > detections[max_idx].score {
90                max_idx = j;
91            }
92        }
93        detections.swap(i, max_idx);
94
95        // Decay scores of all subsequent detections based on IoU with detections[i].
96        let current = detections[i];
97        let mut j = i + 1;
98        while j < detections.len() {
99            let overlap = iou(current.bbox, detections[j].bbox);
100            detections[j].score *= (-overlap * overlap / sigma).exp();
101            if detections[j].score < score_threshold {
102                detections.swap_remove(j);
103                // Don't increment j; the swapped element needs checking too.
104            } else {
105                j += 1;
106            }
107        }
108        i += 1;
109    }
110}
111
112/// Per-class (batched) NMS.
113///
114/// Groups detections by `class_id`, runs standard `non_max_suppression` on
115/// each group independently, then merges and returns results sorted by score
116/// descending.
117pub fn batched_nms(detections: &[Detection], iou_threshold: f32) -> Vec<Detection> {
118    use std::collections::HashMap;
119
120    let mut by_class: HashMap<usize, Vec<Detection>> = HashMap::new();
121    for det in detections {
122        by_class.entry(det.class_id).or_default().push(*det);
123    }
124
125    let mut results: Vec<Detection> = Vec::new();
126    for class_dets in by_class.values() {
127        let kept = non_max_suppression(class_dets, iou_threshold, class_dets.len());
128        results.extend(kept);
129    }
130
131    results.sort_by(|a, b| b.score.total_cmp(&a.score));
132    results
133}