Skip to main content

edgefirst_decoder/
float.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7    Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12/// Post processes boxes and scores tensors into detection boxes, filtering out
13/// any boxes below the score threshold. The boxes tensor is converted to XYXY
14/// using the given BBoxTypeTrait. The order of the boxes is preserved.
15pub fn postprocess_boxes_float<
16    B: BBoxTypeTrait,
17    BOX: Float + AsPrimitive<f32> + Send + Sync,
18    SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20    threshold: SCORE,
21    boxes: ArrayView2<BOX>,
22    scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24    assert_eq!(scores.dim().0, boxes.dim().0);
25    assert_eq!(boxes.dim().1, 4);
26    Zip::from(scores.rows())
27        .and(boxes.rows())
28        .into_par_iter()
29        .filter_map(|(score, bbox)| {
30            let (score_, label) = arg_max(score);
31            if score_ < threshold {
32                return None;
33            }
34
35            let bbox = B::ndarray_to_xyxy_float(bbox);
36            Some(DetectBox {
37                label,
38                score: score_.as_(),
39                bbox: bbox.into(),
40            })
41        })
42        .collect()
43}
44
45/// Post processes boxes and scores tensors into detection boxes, filtering out
46/// any boxes below the score threshold. The boxes tensor is converted to XYXY
47/// using the given BBoxTypeTrait. The order of the boxes is preserved.
48///
49/// This function is very similar to `postprocess_boxes_float` but will also
50/// return the index of the box. The boxes will be in ascending index order.
51pub fn postprocess_boxes_index_float<
52    B: BBoxTypeTrait,
53    BOX: Float + AsPrimitive<f32> + Send + Sync,
54    SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56    threshold: SCORE,
57    boxes: ArrayView2<BOX>,
58    scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60    assert_eq!(scores.dim().0, boxes.dim().0);
61    assert_eq!(boxes.dim().1, 4);
62    let indices: Array1<usize> = (0..boxes.dim().0).collect();
63    Zip::from(scores.rows())
64        .and(boxes.rows())
65        .and(&indices)
66        .into_par_iter()
67        .filter_map(|(score, bbox, i)| {
68            let (score_, label) = arg_max(score);
69            if score_ < threshold {
70                return None;
71            }
72
73            let bbox = B::ndarray_to_xyxy_float(bbox);
74            Some((
75                DetectBox {
76                    label,
77                    score: score_.as_(),
78                    bbox: bbox.into(),
79                },
80                *i,
81            ))
82        })
83        .collect()
84}
85
86/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
87/// then greedily selects a subset of boxes in descending order of score.
88///
89/// If `max_det` is `Some(n)`, the greedy loop stops as soon as `n` survivors
90/// have been confirmed. Because the input is sorted descending, the first `n`
91/// survivors are the highest-scoring `n`, so the post-NMS top-`n` is preserved
92/// without iterating the full O(N²) suppression loop.
93#[must_use]
94pub fn nms_float(iou: f32, max_det: Option<usize>, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
95    // Boxes get sorted by score in descending order so we know based on the
96    // index the scoring of the boxes and can skip parts of the loop.
97    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
98
99    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
100    // immediately
101    if iou >= 1.0 {
102        return match max_det {
103            Some(n) => {
104                boxes.truncate(n);
105                boxes
106            }
107            None => boxes,
108        };
109    }
110
111    let cap = max_det.unwrap_or(usize::MAX);
112    let mut survivors: usize = 0;
113
114    // Outer loop over all boxes.
115    for i in 0..boxes.len() {
116        if boxes[i].score < 0.0 {
117            // this box was merged with a different box earlier
118            continue;
119        }
120        for j in (i + 1)..boxes.len() {
121            // Inner loop over boxes with lower score (later in the list).
122
123            if boxes[j].score < 0.0 {
124                // this box was suppressed by different box earlier
125                continue;
126            }
127            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
128                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
129                boxes[j].score = -1.0;
130            }
131        }
132
133        // NOTE: jaccard_batch4_neon is available for callers that can
134        // batch unsuppressed candidates externally. It is not used
135        // inline here because score=-1 marking creates sparse gaps
136        // that prevent contiguous 4-box batching.
137        survivors += 1;
138        if survivors >= cap {
139            break;
140        }
141    }
142    // Filter out suppressed boxes; cap at `max_det` because boxes after the
143    // break may still hold positive scores but score lower than every survivor.
144    boxes
145        .into_iter()
146        .filter(|b| b.score >= 0.0)
147        .take(cap)
148        .collect()
149}
150
151/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
152/// then greedily selects a subset of boxes in descending order of score.
153///
154/// This is same as `nms_float` but will also include extra information along
155/// with each box, such as the index
156#[must_use]
157pub fn nms_extra_float<E: Send + Sync>(
158    iou: f32,
159    max_det: Option<usize>,
160    mut boxes: Vec<(DetectBox, E)>,
161) -> Vec<(DetectBox, E)> {
162    // Boxes get sorted by score in descending order so we know based on the
163    // index the scoring of the boxes and can skip parts of the loop.
164    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
165
166    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
167    // immediately
168    if iou >= 1.0 {
169        return match max_det {
170            Some(n) => {
171                boxes.truncate(n);
172                boxes
173            }
174            None => boxes,
175        };
176    }
177
178    let cap = max_det.unwrap_or(usize::MAX);
179    let mut survivors: usize = 0;
180
181    // Outer loop over all boxes.
182    for i in 0..boxes.len() {
183        if boxes[i].0.score < 0.0 {
184            // this box was merged with a different box earlier
185            continue;
186        }
187        for j in (i + 1)..boxes.len() {
188            // Inner loop over boxes with lower score (later in the list).
189
190            if boxes[j].0.score < 0.0 {
191                // this box was suppressed by different box earlier
192                continue;
193            }
194            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
195                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
196                boxes[j].0.score = -1.0;
197            }
198        }
199        survivors += 1;
200        if survivors >= cap {
201            break;
202        }
203    }
204
205    // Filter out suppressed boxes; cap at `max_det` for the same reason as
206    // `nms_float`.
207    boxes
208        .into_iter()
209        .filter(|b| b.0.score >= 0.0)
210        .take(cap)
211        .collect()
212}
213
214/// Class-aware NMS: only suppress boxes with the same label.
215///
216/// Sorts boxes by score, then greedily selects a subset of boxes in descending
217/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
218/// have the same class label AND overlap above the IoU threshold.
219///
220/// # Example
221/// ```
222/// # use edgefirst_decoder::{BoundingBox, DetectBox, float::nms_class_aware_float};
223/// let boxes = vec![
224///     DetectBox {
225///         bbox: BoundingBox::new(0.0, 0.0, 0.5, 0.5),
226///         score: 0.9,
227///         label: 0,
228///     },
229///     DetectBox {
230///         bbox: BoundingBox::new(0.1, 0.1, 0.6, 0.6),
231///         score: 0.8,
232///         label: 1,
233///     }, // different class
234/// ];
235/// // Both boxes survive because they have different labels
236/// let result = nms_class_aware_float(0.3, None, boxes);
237/// assert_eq!(result.len(), 2);
238/// ```
239#[must_use]
240pub fn nms_class_aware_float(
241    iou: f32,
242    max_det: Option<usize>,
243    mut boxes: Vec<DetectBox>,
244) -> Vec<DetectBox> {
245    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
246
247    if iou >= 1.0 {
248        return match max_det {
249            Some(n) => {
250                boxes.truncate(n);
251                boxes
252            }
253            None => boxes,
254        };
255    }
256
257    let cap = max_det.unwrap_or(usize::MAX);
258    let mut survivors: usize = 0;
259
260    for i in 0..boxes.len() {
261        if boxes[i].score < 0.0 {
262            continue;
263        }
264        for j in (i + 1)..boxes.len() {
265            if boxes[j].score < 0.0 {
266                continue;
267            }
268            // Only suppress if same class AND overlapping
269            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
270                boxes[j].score = -1.0;
271            }
272        }
273        survivors += 1;
274        if survivors >= cap {
275            break;
276        }
277    }
278    boxes
279        .into_iter()
280        .filter(|b| b.score >= 0.0)
281        .take(cap)
282        .collect()
283}
284
285/// Class-aware NMS with extra data: only suppress boxes with the same label.
286///
287/// This is same as `nms_class_aware_float` but will also include extra
288/// information along with each box, such as the index.
289#[must_use]
290pub fn nms_extra_class_aware_float<E: Send + Sync>(
291    iou: f32,
292    max_det: Option<usize>,
293    mut boxes: Vec<(DetectBox, E)>,
294) -> Vec<(DetectBox, E)> {
295    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
296
297    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
298    // immediately
299    if iou >= 1.0 {
300        return match max_det {
301            Some(n) => {
302                boxes.truncate(n);
303                boxes
304            }
305            None => boxes,
306        };
307    }
308
309    let cap = max_det.unwrap_or(usize::MAX);
310    let mut survivors: usize = 0;
311
312    for i in 0..boxes.len() {
313        if boxes[i].0.score < 0.0 {
314            continue;
315        }
316        for j in (i + 1)..boxes.len() {
317            if boxes[j].0.score < 0.0 {
318                continue;
319            }
320            // Only suppress if same class AND overlapping
321            if boxes[j].0.label == boxes[i].0.label
322                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
323            {
324                boxes[j].0.score = -1.0;
325            }
326        }
327        survivors += 1;
328        if survivors >= cap {
329            break;
330        }
331    }
332    boxes
333        .into_iter()
334        .filter(|b| b.0.score >= 0.0)
335        .take(cap)
336        .collect()
337}
338
339/// Returns true if the IOU of the given bounding boxes is greater than the iou
340/// threshold
341///
342/// # Example
343/// ```
344/// # use edgefirst_decoder::{BoundingBox, float::jaccard};
345/// let a = BoundingBox::new(0.0, 0.0, 0.2, 0.2);
346/// let b = BoundingBox::new(0.1, 0.1, 0.3, 0.3);
347/// let iou_threshold = 0.1;
348/// let result = jaccard(&a, &b, iou_threshold);
349/// assert!(result);
350/// ```
351pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
352    let left = a.xmin.max(b.xmin);
353    let top = a.ymin.max(b.ymin);
354    let right = a.xmax.min(b.xmax);
355    let bottom = a.ymax.min(b.ymax);
356
357    let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
358    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
359    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
360
361    // need to make sure we are not dividing by zero
362    let union = area_a + area_b - intersection;
363
364    intersection > iou * union
365}
366
367/// Batch IoU check: test one reference box `a` against 4 candidate boxes.
368///
369/// Returns a 4-element array of booleans: `result[i]` is true if
370/// `jaccard(a, boxes[i], iou)` would return true.
371///
372/// On aarch64, uses NEON `vmaxq_f32`/`vminq_f32` for vectorized
373/// intersection computation. On other architectures falls back to
374/// 4 scalar `jaccard` calls.
375#[inline]
376pub fn jaccard_batch4(a: &BoundingBox, boxes: &[BoundingBox; 4], iou: f32) -> [bool; 4] {
377    #[cfg(target_arch = "aarch64")]
378    {
379        // SAFETY: NEON is mandatory on aarch64.
380        unsafe { jaccard_batch4_neon(a, boxes, iou) }
381    }
382    #[cfg(not(target_arch = "aarch64"))]
383    {
384        [
385            jaccard(a, &boxes[0], iou),
386            jaccard(a, &boxes[1], iou),
387            jaccard(a, &boxes[2], iou),
388            jaccard(a, &boxes[3], iou),
389        ]
390    }
391}
392
/// NEON-vectorized batch IoU for 4 candidate boxes against one reference.
///
/// The coordinates of the 4 candidates are gathered field by field into
/// separate NEON registers (AoS→SoA), then the intersection, the union and the
/// `intersection > iou * union` test are evaluated 4-wide.
///
/// The gather reads the named fields of each `BoundingBox` instead of loading
/// 16 bytes through a pointer to a box's first field. That alternative depends
/// on a packed `[xmin, ymin, xmax, ymax]` field layout, which Rust only
/// guarantees for `#[repr(C)]` types (NOTE(review): the definition of
/// `BoundingBox` is not visible here). It also reads past the single `f32`
/// that the pointer was derived from, which Miri's Stacked Borrows model
/// rejects.
///
/// # Safety
///
/// The CPU must support NEON. NEON is a baseline feature of aarch64, the only
/// architecture this function is compiled for.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn jaccard_batch4_neon(a: &BoundingBox, boxes: &[BoundingBox; 4], iou: f32) -> [bool; 4] {
    use std::arch::aarch64::*;

    let zero = vdupq_n_f32(0.0);
    let iou_v = vdupq_n_f32(iou);

    // Reference box broadcast.
    let a_xmin = vdupq_n_f32(a.xmin);
    let a_ymin = vdupq_n_f32(a.ymin);
    let a_xmax = vdupq_n_f32(a.xmax);
    let a_ymax = vdupq_n_f32(a.ymax);
    let area_a = vmulq_f32(vsubq_f32(a_xmax, a_xmin), vsubq_f32(a_ymax, a_ymin));

    // AoS → SoA: one contiguous array per coordinate, lane k = candidate k.
    let xmin = [boxes[0].xmin, boxes[1].xmin, boxes[2].xmin, boxes[3].xmin];
    let ymin = [boxes[0].ymin, boxes[1].ymin, boxes[2].ymin, boxes[3].ymin];
    let xmax = [boxes[0].xmax, boxes[1].xmax, boxes[2].xmax, boxes[3].xmax];
    let ymax = [boxes[0].ymax, boxes[1].ymax, boxes[2].ymax, boxes[3].ymax];

    // Each array is 4 contiguous, properly aligned f32 owned by this frame, so
    // every 4-lane load stays in bounds.
    let b_xmin = vld1q_f32(xmin.as_ptr());
    let b_ymin = vld1q_f32(ymin.as_ptr());
    let b_xmax = vld1q_f32(xmax.as_ptr());
    let b_ymax = vld1q_f32(ymax.as_ptr());

    // Intersection, clamped at zero for disjoint boxes.
    let left = vmaxq_f32(a_xmin, b_xmin);
    let top = vmaxq_f32(a_ymin, b_ymin);
    let right = vminq_f32(a_xmax, b_xmax);
    let bottom = vminq_f32(a_ymax, b_ymax);
    let w = vmaxq_f32(vsubq_f32(right, left), zero);
    let h = vmaxq_f32(vsubq_f32(bottom, top), zero);
    let intersection = vmulq_f32(w, h);

    // Area B.
    let area_b = vmulq_f32(vsubq_f32(b_xmax, b_xmin), vsubq_f32(b_ymax, b_ymin));

    // Union = area_a + area_b - intersection (same operation order as `jaccard`).
    let union = vsubq_f32(vaddq_f32(area_a, area_b), intersection);

    // Test: intersection > iou * union (equivalent to IoU > threshold).
    let iou_union = vmulq_f32(iou_v, union);
    let mask = vcgtq_f32(intersection, iou_union);

    // Extract per-lane results (all-ones lane = true).
    [
        vgetq_lane_u32(mask, 0) != 0,
        vgetq_lane_u32(mask, 1) != 0,
        vgetq_lane_u32(mask, 2) != 0,
        vgetq_lane_u32(mask, 3) != 0,
    ]
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BoundingBox;

    /// Builds `count` boxes of size 10×10 spaced 100 units apart along x (so
    /// no pair overlaps), all labelled 0, with scores stepping down by 0.01
    /// from 1.0.
    fn make_nms_boxes_float(count: usize) -> Vec<DetectBox> {
        let mut out = Vec::with_capacity(count);
        for k in 0..count {
            let left = k as f32 * 100.0;
            out.push(DetectBox {
                bbox: BoundingBox {
                    xmin: left,
                    ymin: 0.0,
                    xmax: left + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: 1.0 - k as f32 * 0.01,
            });
        }
        out
    }

    #[test]
    fn nms_float_max_det_matches_full_truncated() {
        const N: usize = 5;
        let input = make_nms_boxes_float(20);
        let uncapped = nms_float(0.5, None, input.clone());
        let capped = nms_float(0.5, Some(N), input);
        assert_eq!(capped.len(), N);
        uncapped[..N]
            .iter()
            .zip(capped.iter())
            .for_each(|(expected, got)| {
                assert_eq!(expected.bbox, got.bbox);
                assert_eq!(expected.score, got.score);
            });
    }

    #[test]
    fn nms_float_max_det_zero_returns_empty() {
        assert!(nms_float(0.5, Some(0), make_nms_boxes_float(10)).is_empty());
    }

    #[test]
    fn nms_float_max_det_iou_ge_1_returns_sorted_truncated() {
        let kept = nms_float(1.0, Some(3), make_nms_boxes_float(10));
        assert_eq!(kept.len(), 3);
        assert!(kept.windows(2).all(|pair| pair[0].score >= pair[1].score));
    }

    #[test]
    fn nms_float_max_det_larger_than_input() {
        let input = make_nms_boxes_float(5);
        let expected_len = nms_float(0.5, None, input.clone()).len();
        assert_eq!(nms_float(0.5, Some(100), input).len(), expected_len);
    }

    #[test]
    fn jaccard_batch4_matches_scalar() {
        let reference = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        let candidates = [
            BoundingBox::new(5.0, 5.0, 15.0, 15.0),   // overlap
            BoundingBox::new(20.0, 20.0, 30.0, 30.0), // no overlap
            BoundingBox::new(0.0, 0.0, 10.0, 10.0),   // identical
            BoundingBox::new(8.0, 8.0, 18.0, 18.0),   // small overlap
        ];
        let threshold = 0.1;
        let batched = jaccard_batch4(&reference, &candidates, threshold);
        for (i, (candidate, &got)) in candidates.iter().zip(batched.iter()).enumerate() {
            let scalar = jaccard(&reference, candidate, threshold);
            assert_eq!(
                got, scalar,
                "batch4 mismatch at {i}: batch={} scalar={}",
                got, scalar
            );
        }
    }
}