Skip to main content

oxicuda_vision/detection/
nms.rs

1//! Non-Maximum Suppression over scored bounding boxes.
2//!
3//! Object detectors emit many overlapping candidate boxes for each ground-truth
4//! object. **Non-Maximum Suppression** (NMS) collapses these into a sparse set
5//! by greedily keeping the highest-scoring box and discarding any later box
6//! whose Intersection-over-Union (IoU) with an already-kept box exceeds a
7//! threshold.
8//!
9//! **Soft-NMS** (Bodla et al., 2017) replaces the hard discard with a Gaussian
10//! score *decay*: an overlapping box keeps its identity but has its score
11//! multiplied by `exp(−IoU² / sigma)`, so heavily-overlapping low-confidence
12//! boxes fade out gradually rather than being deleted outright. Boxes whose
13//! decayed score falls at or below `score_threshold` are dropped.
14//!
15//! This module operates on the ergonomic struct-of-fields [`BBox`] type
16//! (`x1, y1, x2, y2, score`). The flat `[n × 4]` array-based variants used by
17//! the anchor / RPN pipeline live in [`crate::detection::anchor_nms`].
18
19use crate::error::{VisionError, VisionResult};
20
21// ─── BBox ────────────────────────────────────────────────────────────────────
22
/// An axis-aligned bounding box paired with a detection confidence score.
///
/// Coordinates follow the half-open convention `[x1, x2) × [y1, y2)`. A box is
/// **degenerate** (zero area) when `x2 ≤ x1` or `y2 ≤ y1`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct BBox {
    /// Left edge (inclusive).
    pub x1: f32,
    /// Top edge (inclusive).
    pub y1: f32,
    /// Right edge (exclusive).
    pub x2: f32,
    /// Bottom edge (exclusive).
    pub y2: f32,
    /// Detection confidence score.
    pub score: f32,
}

impl BBox {
    /// Construct a box from its four edges and a confidence score.
    ///
    /// The values are stored as given; no ordering or finiteness check is
    /// performed here.
    #[must_use]
    #[inline]
    pub fn new(x1: f32, y1: f32, x2: f32, y2: f32, score: f32) -> Self {
        Self {
            x1,
            y1,
            x2,
            y2,
            score,
        }
    }

    /// Area of the box, or `0.0` when it has no positive extent.
    ///
    /// A box whose width or height is zero, negative, or NaN (for example
    /// because a coordinate is NaN, or both edges are the same infinity) is
    /// treated as degenerate and reports `0.0`, so the result is never NaN.
    #[must_use]
    #[inline]
    pub fn area(&self) -> f32 {
        let width = self.x2 - self.x1;
        let height = self.y2 - self.y1;
        // Phrased as a positive test on purpose: every comparison with NaN is
        // false, so a NaN extent falls through to the degenerate branch
        // instead of leaking `NaN * h` to callers.
        if width > 0.0 && height > 0.0 {
            width * height
        } else {
            0.0
        }
    }
}
64
65// ─── IoU ─────────────────────────────────────────────────────────────────────
66
67/// Jaccard Intersection-over-Union of two boxes.
68///
69/// Degenerate boxes (zero area) yield `0.0`. The result is clamped to `[0, 1]`.
70#[must_use]
71pub fn iou(a: &BBox, b: &BBox) -> f32 {
72    let area_a = a.area();
73    let area_b = b.area();
74    if area_a <= 0.0 || area_b <= 0.0 {
75        return 0.0;
76    }
77    let ix1 = a.x1.max(b.x1);
78    let iy1 = a.y1.max(b.y1);
79    let ix2 = a.x2.min(b.x2);
80    let iy2 = a.y2.min(b.y2);
81    let iw = (ix2 - ix1).max(0.0);
82    let ih = (iy2 - iy1).max(0.0);
83    let inter = iw * ih;
84    let union = area_a + area_b - inter;
85    if union <= 0.0 {
86        return 0.0;
87    }
88    (inter / union).clamp(0.0, 1.0)
89}
90
91// ─── NMS ─────────────────────────────────────────────────────────────────────
92
93/// Greedy hard Non-Maximum Suppression.
94///
95/// Sorts the input boxes by descending score (ties broken by ascending input
96/// index for determinism), then sweeps through them keeping a box iff its IoU
97/// with every previously-kept box is **≤** `iou_threshold`. Returns the kept
98/// **indices into the original `boxes` slice**, in descending-score order.
99///
100/// An empty input returns an empty `Vec` (no error). With
101/// `iou_threshold == 1.0`, only IoU strictly greater than 1 suppresses — i.e.
102/// nothing — so every box is kept.
103///
104/// # Errors
105/// - [`VisionError::NonFinite`] if `iou_threshold` is NaN/∞ or outside `[0, 1]`.
106pub fn nms(boxes: &[BBox], iou_threshold: f32) -> VisionResult<Vec<usize>> {
107    if !(0.0..=1.0).contains(&iou_threshold) || !iou_threshold.is_finite() {
108        return Err(VisionError::NonFinite("nms iou_threshold"));
109    }
110    if boxes.is_empty() {
111        return Ok(Vec::new());
112    }
113
114    let mut order: Vec<usize> = (0..boxes.len()).collect();
115    order.sort_by(|&a, &b| {
116        boxes[b]
117            .score
118            .partial_cmp(&boxes[a].score)
119            .unwrap_or(std::cmp::Ordering::Equal)
120            .then(a.cmp(&b))
121    });
122
123    let mut kept: Vec<usize> = Vec::new();
124    for &idx in &order {
125        let candidate = &boxes[idx];
126        let mut suppressed = false;
127        for &k in &kept {
128            if iou(candidate, &boxes[k]) > iou_threshold {
129                suppressed = true;
130                break;
131            }
132        }
133        if !suppressed {
134            kept.push(idx);
135        }
136    }
137    Ok(kept)
138}
139
140// ─── Soft-NMS ────────────────────────────────────────────────────────────────
141
142/// Gaussian Soft-NMS (Bodla et al., 2017).
143///
144/// Instead of hard-suppressing overlapping boxes, repeatedly select the
145/// highest-scoring remaining box as the *pivot*, emit it, then multiply every
146/// other remaining box's score by the Gaussian decay `exp(−IoU(pivot, ·)² /
147/// sigma)`. Boxes whose decayed score drops **at or below** `score_threshold`
148/// are removed and never emitted.
149///
150/// Returns `(original_index, decayed_score)` pairs in descending decayed-score
151/// (emission) order. An empty input returns an empty `Vec`.
152///
153/// # Errors
154/// - [`VisionError::NonFinite`] if `sigma ≤ 0`, or if `sigma` /
155///   `score_threshold` is non-finite.
156pub fn soft_nms(
157    boxes: &[BBox],
158    sigma: f32,
159    score_threshold: f32,
160) -> VisionResult<Vec<(usize, f32)>> {
161    if sigma <= 0.0 || !sigma.is_finite() {
162        return Err(VisionError::NonFinite("soft_nms sigma"));
163    }
164    if !score_threshold.is_finite() {
165        return Err(VisionError::NonFinite("soft_nms score_threshold"));
166    }
167    if boxes.is_empty() {
168        return Ok(Vec::new());
169    }
170
171    let inv_sigma = 1.0_f32 / sigma;
172    // Pool of (original index, current score) — boxes themselves are read by
173    // index from `boxes`.
174    let mut pool: Vec<(usize, f32)> = boxes.iter().map(|b| (b.score, b)).enumerate().fold(
175        Vec::with_capacity(boxes.len()),
176        |mut acc, (i, (s, _))| {
177            acc.push((i, s));
178            acc
179        },
180    );
181
182    let mut out: Vec<(usize, f32)> = Vec::new();
183
184    while !pool.is_empty() {
185        // Highest-scoring remaining entry becomes the pivot.
186        let (max_pos, max_score) = pool.iter().enumerate().fold(
187            (0usize, f32::NEG_INFINITY),
188            |(best_i, best_s), (i, &(_, s))| {
189                if s > best_s { (i, s) } else { (best_i, best_s) }
190            },
191        );
192
193        if max_score <= score_threshold {
194            // Every remaining score ≤ max_score ≤ threshold: stop.
195            break;
196        }
197
198        let pivot = pool.swap_remove(max_pos);
199        out.push(pivot);
200
201        let pivot_box = &boxes[pivot.0];
202        for entry in pool.iter_mut() {
203            let ov = iou(pivot_box, &boxes[entry.0]);
204            let decay = (-(ov * ov) * inv_sigma).exp();
205            entry.1 *= decay;
206        }
207    }
208
209    Ok(out)
210}
211
212// ─── Tests ───────────────────────────────────────────────────────────────────
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand constructor so the fixtures below stay on one line.
    fn b(x1: f32, y1: f32, x2: f32, y2: f32, s: f32) -> BBox {
        BBox::new(x1, y1, x2, y2, s)
    }

    // ── IoU ──────────────────────────────────────────────────────────────────

    #[test]
    fn iou_identical_is_1() {
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        assert!((iou(&a, &a) - 1.0).abs() < 1e-6);
    }

    #[test]
    fn iou_disjoint_is_0() {
        let a = b(0.0, 0.0, 1.0, 1.0, 0.9);
        let c = b(2.0, 2.0, 3.0, 3.0, 0.8);
        assert!(iou(&a, &c).abs() < 1e-7);
    }

    #[test]
    fn iou_half_overlap() {
        // a = [0,0,10,10] (area 100), c = [5,0,15,10] (area 100)
        // inter = 5·10 = 50, union = 150 ⇒ iou = 1/3.
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        let c = b(5.0, 0.0, 15.0, 10.0, 0.8);
        assert!((iou(&a, &c) - 1.0 / 3.0).abs() < 1e-5);
    }

    #[test]
    fn iou_degenerate_is_0() {
        let a = b(0.0, 0.0, 10.0, 10.0, 0.9);
        let degenerate = b(5.0, 5.0, 5.0, 5.0, 0.8);
        assert!(iou(&a, &degenerate).abs() < 1e-7);
    }

    // ── BBox ─────────────────────────────────────────────────────────────────

    #[test]
    fn area_positive_and_degenerate() {
        // Regular box: 2 × 3.
        assert!((b(0.0, 0.0, 2.0, 3.0, 0.5).area() - 6.0).abs() < 1e-6);
        // Zero extent and inverted edges both count as degenerate.
        assert!(b(5.0, 5.0, 5.0, 5.0, 0.5).area().abs() < 1e-7);
        assert!(b(3.0, 0.0, 1.0, 4.0, 0.5).area().abs() < 1e-7);
    }

    // ── Hard NMS ─────────────────────────────────────────────────────────────

    #[test]
    fn nms_keeps_highest() {
        // Two heavily overlapping boxes; index 1 has the higher score.
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 0.4), b(1.0, 1.0, 10.0, 10.0, 0.9)];
        let kept = nms(&boxes, 0.3).expect("ok");
        assert_eq!(kept, vec![1]);
    }

    #[test]
    fn nms_suppresses_overlap() {
        // Three boxes: 0 and 1 overlap heavily, 2 is disjoint.
        let boxes = vec![
            b(0.0, 0.0, 10.0, 10.0, 0.9),
            b(0.5, 0.5, 10.5, 10.5, 0.8),
            b(50.0, 50.0, 60.0, 60.0, 0.7),
        ];
        let kept = nms(&boxes, 0.3).expect("ok");
        // Box 1 suppressed by 0; 0 and 2 survive.
        assert_eq!(kept, vec![0, 2]);
    }

    #[test]
    fn nms_keeps_disjoint() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9), b(5.0, 5.0, 6.0, 6.0, 0.8)];
        let kept = nms(&boxes, 0.5).expect("ok");
        assert_eq!(kept, vec![0, 1]);
    }

    #[test]
    fn nms_returns_descending_score_order() {
        let boxes = vec![
            b(0.0, 0.0, 1.0, 1.0, 0.3),
            b(5.0, 0.0, 6.0, 1.0, 0.9),
            b(10.0, 0.0, 11.0, 1.0, 0.5),
        ];
        let kept = nms(&boxes, 0.5).expect("ok");
        assert_eq!(kept, vec![1, 2, 0]);
    }

    #[test]
    fn nms_ties_break_by_input_index() {
        // Equal scores, pairwise disjoint: output must follow input order.
        let boxes = vec![
            b(0.0, 0.0, 1.0, 1.0, 0.5),
            b(5.0, 0.0, 6.0, 1.0, 0.5),
            b(10.0, 0.0, 11.0, 1.0, 0.5),
        ];
        let kept = nms(&boxes, 0.5).expect("ok");
        assert_eq!(kept, vec![0, 1, 2]);
    }

    #[test]
    fn threshold_1_keeps_all() {
        // Even identical boxes are kept when threshold == 1.0.
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9), b(0.0, 0.0, 1.0, 1.0, 0.8)];
        let kept = nms(&boxes, 1.0).expect("ok");
        assert_eq!(kept.len(), 2);
    }

    #[test]
    fn nms_invalid_threshold_errors() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9)];
        assert!(matches!(nms(&boxes, 1.5), Err(VisionError::NonFinite(_))));
        assert!(matches!(nms(&boxes, -0.1), Err(VisionError::NonFinite(_))));
        assert!(matches!(
            nms(&boxes, f32::NAN),
            Err(VisionError::NonFinite(_))
        ));
        assert!(matches!(
            nms(&boxes, f32::INFINITY),
            Err(VisionError::NonFinite(_))
        ));
    }

    // ── Soft-NMS ─────────────────────────────────────────────────────────────

    #[test]
    fn soft_nms_decays_scores() {
        // Two overlapping boxes; the lower-scored survivor must be decayed.
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 0.9), b(1.0, 1.0, 10.0, 10.0, 0.8)];
        let out = soft_nms(&boxes, 0.5, 0.0).expect("ok");
        assert_eq!(out.len(), 2);
        assert_eq!(out[0].0, 0);
        assert_eq!(out[1].0, 1);
        assert!(out[1].1 < 0.8, "expected decay, got {}", out[1].1);
    }

    #[test]
    fn soft_nms_threshold_filters() {
        // Two identical boxes ⇒ iou = 1, decay = exp(-1/0.5) = exp(-2) ≈ 0.135.
        // The second box's decayed score 0.5·0.135 ≈ 0.068 < 0.4 ⇒ dropped.
        let boxes = vec![b(0.0, 0.0, 10.0, 10.0, 1.0), b(0.0, 0.0, 10.0, 10.0, 0.5)];
        let out = soft_nms(&boxes, 0.5, 0.4).expect("ok");
        assert_eq!(out.len(), 1);
        assert_eq!(out[0].0, 0);
    }

    #[test]
    fn soft_nms_disjoint_no_decay() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.7), b(5.0, 5.0, 6.0, 6.0, 0.6)];
        let out = soft_nms(&boxes, 0.5, 0.0).expect("ok");
        assert_eq!(out.len(), 2);
        assert!((out[0].1 - 0.7).abs() < 1e-5);
        assert!((out[1].1 - 0.6).abs() < 1e-5);
    }

    #[test]
    fn soft_nms_invalid_sigma_errors() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9)];
        assert!(matches!(
            soft_nms(&boxes, 0.0, 0.0),
            Err(VisionError::NonFinite(_))
        ));
        assert!(matches!(
            soft_nms(&boxes, -1.0, 0.0),
            Err(VisionError::NonFinite(_))
        ));
    }

    #[test]
    fn soft_nms_non_finite_params_error() {
        let boxes = vec![b(0.0, 0.0, 1.0, 1.0, 0.9)];
        for sigma in [f32::NAN, f32::INFINITY] {
            assert!(matches!(
                soft_nms(&boxes, sigma, 0.0),
                Err(VisionError::NonFinite(_))
            ));
        }
        for threshold in [f32::NAN, f32::INFINITY, f32::NEG_INFINITY] {
            assert!(matches!(
                soft_nms(&boxes, 0.5, threshold),
                Err(VisionError::NonFinite(_))
            ));
        }
    }

    // ── Shared edge cases ────────────────────────────────────────────────────

    #[test]
    fn empty_boxes() {
        let boxes: Vec<BBox> = Vec::new();
        assert_eq!(nms(&boxes, 0.5).expect("ok"), Vec::<usize>::new());
        assert_eq!(
            soft_nms(&boxes, 0.5, 0.0).expect("ok"),
            Vec::<(usize, f32)>::new()
        );
    }
}