Skip to main content

edgefirst_decoder/
float.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7    Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12/// Post processes boxes and scores tensors into detection boxes, filtering out
13/// any boxes below the score threshold. The boxes tensor is converted to XYXY
14/// using the given BBoxTypeTrait. The order of the boxes is preserved.
15pub fn postprocess_boxes_float<
16    B: BBoxTypeTrait,
17    BOX: Float + AsPrimitive<f32> + Send + Sync,
18    SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20    threshold: SCORE,
21    boxes: ArrayView2<BOX>,
22    scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24    assert_eq!(scores.dim().0, boxes.dim().0);
25    assert_eq!(boxes.dim().1, 4);
26    Zip::from(scores.rows())
27        .and(boxes.rows())
28        .into_par_iter()
29        .filter_map(|(score, bbox)| {
30            let (score_, label) = arg_max(score);
31            if score_ < threshold {
32                return None;
33            }
34
35            let bbox = B::ndarray_to_xyxy_float(bbox);
36            Some(DetectBox {
37                label,
38                score: score_.as_(),
39                bbox: bbox.into(),
40            })
41        })
42        .collect()
43}
44
45/// Post processes boxes and scores tensors into detection boxes, filtering out
46/// any boxes below the score threshold. The boxes tensor is converted to XYXY
47/// using the given BBoxTypeTrait. The order of the boxes is preserved.
48///
49/// This function is very similar to `postprocess_boxes_float` but will also
50/// return the index of the box. The boxes will be in ascending index order.
51pub fn postprocess_boxes_index_float<
52    B: BBoxTypeTrait,
53    BOX: Float + AsPrimitive<f32> + Send + Sync,
54    SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56    threshold: SCORE,
57    boxes: ArrayView2<BOX>,
58    scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60    assert_eq!(scores.dim().0, boxes.dim().0);
61    assert_eq!(boxes.dim().1, 4);
62    let indices: Array1<usize> = (0..boxes.dim().0).collect();
63    Zip::from(scores.rows())
64        .and(boxes.rows())
65        .and(&indices)
66        .into_par_iter()
67        .filter_map(|(score, bbox, i)| {
68            let (score_, label) = arg_max(score);
69            if score_ < threshold {
70                return None;
71            }
72
73            let bbox = B::ndarray_to_xyxy_float(bbox);
74            Some((
75                DetectBox {
76                    label,
77                    score: score_.as_(),
78                    bbox: bbox.into(),
79                },
80                *i,
81            ))
82        })
83        .collect()
84}
85
86/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
87/// then greedily selects a subset of boxes in descending order of score.
88///
89/// If `max_det` is `Some(n)`, the greedy loop stops as soon as `n` survivors
90/// have been confirmed. Because the input is sorted descending, the first `n`
91/// survivors are the highest-scoring `n`, so the post-NMS top-`n` is preserved
92/// without iterating the full O(N²) suppression loop.
93#[must_use]
94pub fn nms_float(iou: f32, max_det: Option<usize>, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
95    // Boxes get sorted by score in descending order so we know based on the
96    // index the scoring of the boxes and can skip parts of the loop.
97    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
98
99    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
100    // immediately
101    if iou >= 1.0 {
102        return match max_det {
103            Some(n) => {
104                boxes.truncate(n);
105                boxes
106            }
107            None => boxes,
108        };
109    }
110
111    let cap = max_det.unwrap_or(usize::MAX);
112    let mut survivors: usize = 0;
113
114    // Outer loop over all boxes.
115    for i in 0..boxes.len() {
116        if boxes[i].score < 0.0 {
117            // this box was merged with a different box earlier
118            continue;
119        }
120        for j in (i + 1)..boxes.len() {
121            // Inner loop over boxes with lower score (later in the list).
122
123            if boxes[j].score < 0.0 {
124                // this box was suppressed by different box earlier
125                continue;
126            }
127            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
128                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
129                boxes[j].score = -1.0;
130            }
131        }
132        survivors += 1;
133        if survivors >= cap {
134            break;
135        }
136    }
137    // Filter out suppressed boxes; cap at `max_det` because boxes after the
138    // break may still hold positive scores but score lower than every survivor.
139    boxes
140        .into_iter()
141        .filter(|b| b.score >= 0.0)
142        .take(cap)
143        .collect()
144}
145
146/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
147/// then greedily selects a subset of boxes in descending order of score.
148///
149/// This is same as `nms_float` but will also include extra information along
150/// with each box, such as the index
151#[must_use]
152pub fn nms_extra_float<E: Send + Sync>(
153    iou: f32,
154    max_det: Option<usize>,
155    mut boxes: Vec<(DetectBox, E)>,
156) -> Vec<(DetectBox, E)> {
157    // Boxes get sorted by score in descending order so we know based on the
158    // index the scoring of the boxes and can skip parts of the loop.
159    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
160
161    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
162    // immediately
163    if iou >= 1.0 {
164        return match max_det {
165            Some(n) => {
166                boxes.truncate(n);
167                boxes
168            }
169            None => boxes,
170        };
171    }
172
173    let cap = max_det.unwrap_or(usize::MAX);
174    let mut survivors: usize = 0;
175
176    // Outer loop over all boxes.
177    for i in 0..boxes.len() {
178        if boxes[i].0.score < 0.0 {
179            // this box was merged with a different box earlier
180            continue;
181        }
182        for j in (i + 1)..boxes.len() {
183            // Inner loop over boxes with lower score (later in the list).
184
185            if boxes[j].0.score < 0.0 {
186                // this box was suppressed by different box earlier
187                continue;
188            }
189            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
190                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
191                boxes[j].0.score = -1.0;
192            }
193        }
194        survivors += 1;
195        if survivors >= cap {
196            break;
197        }
198    }
199
200    // Filter out suppressed boxes; cap at `max_det` for the same reason as
201    // `nms_float`.
202    boxes
203        .into_iter()
204        .filter(|b| b.0.score >= 0.0)
205        .take(cap)
206        .collect()
207}
208
209/// Class-aware NMS: only suppress boxes with the same label.
210///
211/// Sorts boxes by score, then greedily selects a subset of boxes in descending
212/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
213/// have the same class label AND overlap above the IoU threshold.
214///
215/// # Example
216/// ```
217/// # use edgefirst_decoder::{BoundingBox, DetectBox, float::nms_class_aware_float};
218/// let boxes = vec![
219///     DetectBox {
220///         bbox: BoundingBox::new(0.0, 0.0, 0.5, 0.5),
221///         score: 0.9,
222///         label: 0,
223///     },
224///     DetectBox {
225///         bbox: BoundingBox::new(0.1, 0.1, 0.6, 0.6),
226///         score: 0.8,
227///         label: 1,
228///     }, // different class
229/// ];
230/// // Both boxes survive because they have different labels
231/// let result = nms_class_aware_float(0.3, None, boxes);
232/// assert_eq!(result.len(), 2);
233/// ```
234#[must_use]
235pub fn nms_class_aware_float(
236    iou: f32,
237    max_det: Option<usize>,
238    mut boxes: Vec<DetectBox>,
239) -> Vec<DetectBox> {
240    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
241
242    if iou >= 1.0 {
243        return match max_det {
244            Some(n) => {
245                boxes.truncate(n);
246                boxes
247            }
248            None => boxes,
249        };
250    }
251
252    let cap = max_det.unwrap_or(usize::MAX);
253    let mut survivors: usize = 0;
254
255    for i in 0..boxes.len() {
256        if boxes[i].score < 0.0 {
257            continue;
258        }
259        for j in (i + 1)..boxes.len() {
260            if boxes[j].score < 0.0 {
261                continue;
262            }
263            // Only suppress if same class AND overlapping
264            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
265                boxes[j].score = -1.0;
266            }
267        }
268        survivors += 1;
269        if survivors >= cap {
270            break;
271        }
272    }
273    boxes
274        .into_iter()
275        .filter(|b| b.score >= 0.0)
276        .take(cap)
277        .collect()
278}
279
280/// Class-aware NMS with extra data: only suppress boxes with the same label.
281///
282/// This is same as `nms_class_aware_float` but will also include extra
283/// information along with each box, such as the index.
284#[must_use]
285pub fn nms_extra_class_aware_float<E: Send + Sync>(
286    iou: f32,
287    max_det: Option<usize>,
288    mut boxes: Vec<(DetectBox, E)>,
289) -> Vec<(DetectBox, E)> {
290    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
291
292    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
293    // immediately
294    if iou >= 1.0 {
295        return match max_det {
296            Some(n) => {
297                boxes.truncate(n);
298                boxes
299            }
300            None => boxes,
301        };
302    }
303
304    let cap = max_det.unwrap_or(usize::MAX);
305    let mut survivors: usize = 0;
306
307    for i in 0..boxes.len() {
308        if boxes[i].0.score < 0.0 {
309            continue;
310        }
311        for j in (i + 1)..boxes.len() {
312            if boxes[j].0.score < 0.0 {
313                continue;
314            }
315            // Only suppress if same class AND overlapping
316            if boxes[j].0.label == boxes[i].0.label
317                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
318            {
319                boxes[j].0.score = -1.0;
320            }
321        }
322        survivors += 1;
323        if survivors >= cap {
324            break;
325        }
326    }
327    boxes
328        .into_iter()
329        .filter(|b| b.0.score >= 0.0)
330        .take(cap)
331        .collect()
332}
333
334/// Returns true if the IOU of the given bounding boxes is greater than the iou
335/// threshold
336///
337/// # Example
338/// ```
339/// # use edgefirst_decoder::{BoundingBox, float::jaccard};
340/// let a = BoundingBox::new(0.0, 0.0, 0.2, 0.2);
341/// let b = BoundingBox::new(0.1, 0.1, 0.3, 0.3);
342/// let iou_threshold = 0.1;
343/// let result = jaccard(&a, &b, iou_threshold);
344/// assert!(result);
345/// ```
346pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
347    let left = a.xmin.max(b.xmin);
348    let top = a.ymin.max(b.ymin);
349    let right = a.xmax.min(b.xmax);
350    let bottom = a.ymax.min(b.ymax);
351
352    let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
353    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
354    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
355
356    // need to make sure we are not dividing by zero
357    let union = area_a + area_b - intersection;
358
359    intersection > iou * union
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BoundingBox;

    /// Helper: create `n` non-overlapping boxes with descending f32 scores.
    fn make_nms_boxes_float(n: usize) -> Vec<DetectBox> {
        (0..n)
            .map(|i| DetectBox {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: 1.0 - i as f32 * 0.01,
            })
            .collect()
    }

    /// Helper: like `make_nms_boxes_float`, but the second box is moved so it
    /// overlaps the first one with an IoU of 90 / 110. Requires `n >= 2`.
    fn make_nms_boxes_with_overlap(n: usize) -> Vec<DetectBox> {
        let mut boxes = make_nms_boxes_float(n);
        boxes[1].bbox = BoundingBox {
            xmin: 1.0,
            ymin: 0.0,
            xmax: 11.0,
            ymax: 10.0,
        };
        boxes
    }

    #[test]
    fn nms_float_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_float(20);
        let n = 5;
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        for (f, c) in full[..n].iter().zip(capped.iter()) {
            assert_eq!(f.bbox, c.bbox);
            assert_eq!(f.score, c.score);
        }
    }

    #[test]
    fn nms_float_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    #[test]
    fn nms_float_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    #[test]
    fn nms_float_max_det_larger_than_input() {
        let boxes = make_nms_boxes_float(5);
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }

    #[test]
    fn nms_float_suppresses_overlapping_lower_score() {
        let result = nms_float(0.5, None, make_nms_boxes_with_overlap(3));
        // The second box overlaps the first and scores lower, so it is removed.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].bbox.xmin, 0.0);
        assert_eq!(result[1].bbox.xmin, 200.0);
        // Survivors keep their original scores.
        assert_eq!(result[0].score, 1.0);
    }

    #[test]
    fn nms_float_max_det_counts_survivors_not_candidates() {
        let result = nms_float(0.5, Some(2), make_nms_boxes_with_overlap(4));
        // The suppressed second box must not use up one of the two slots.
        assert_eq!(result.len(), 2);
        assert_eq!(result[0].bbox.xmin, 0.0);
        assert_eq!(result[1].bbox.xmin, 200.0);
    }

    #[test]
    fn nms_class_aware_float_only_suppresses_same_label() {
        let same_label = make_nms_boxes_with_overlap(2);
        let mut other_label = same_label.clone();
        other_label[1].label = 1;

        assert_eq!(nms_class_aware_float(0.5, None, same_label).len(), 1);
        assert_eq!(nms_class_aware_float(0.5, None, other_label.clone()).len(), 2);
        // Class-agnostic NMS ignores the label and still suppresses.
        assert_eq!(nms_float(0.5, None, other_label).len(), 1);
    }

    #[test]
    fn nms_extra_float_keeps_extra_with_its_box() {
        let tagged: Vec<(DetectBox, usize)> = make_nms_boxes_with_overlap(3)
            .into_iter()
            .enumerate()
            .map(|(i, b)| (b, i))
            .collect();
        let result = nms_extra_float(0.5, None, tagged);
        let extras: Vec<usize> = result.iter().map(|(_, i)| *i).collect();
        assert_eq!(extras, vec![0, 2]);
    }

    #[test]
    fn nms_extra_class_aware_float_only_suppresses_same_label() {
        let mut boxes = make_nms_boxes_with_overlap(3);
        boxes[1].label = 1;
        let tagged: Vec<(DetectBox, usize)> = boxes
            .into_iter()
            .enumerate()
            .map(|(i, b)| (b, i))
            .collect();
        let result = nms_extra_class_aware_float(0.5, None, tagged);
        let extras: Vec<usize> = result.iter().map(|(_, i)| *i).collect();
        assert_eq!(extras, vec![0, 1, 2]);
    }
}