Skip to main content

edgefirst_decoder/
float.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7    Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12/// Post processes boxes and scores tensors into detection boxes, filtering out
13/// any boxes below the score threshold. The boxes tensor is converted to XYXY
14/// using the given BBoxTypeTrait. The order of the boxes is preserved.
15pub fn postprocess_boxes_float<
16    B: BBoxTypeTrait,
17    BOX: Float + AsPrimitive<f32> + Send + Sync,
18    SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20    threshold: SCORE,
21    boxes: ArrayView2<BOX>,
22    scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24    assert_eq!(scores.dim().0, boxes.dim().0);
25    assert_eq!(boxes.dim().1, 4);
26    Zip::from(scores.rows())
27        .and(boxes.rows())
28        .into_par_iter()
29        .filter_map(|(score, bbox)| {
30            let (score_, label) = arg_max(score);
31            if score_ < threshold {
32                return None;
33            }
34
35            let bbox = B::ndarray_to_xyxy_float(bbox);
36            Some(DetectBox {
37                label,
38                score: score_.as_(),
39                bbox: bbox.into(),
40            })
41        })
42        .collect()
43}
44
45/// Post processes boxes and scores tensors into detection boxes, filtering out
46/// any boxes below the score threshold. The boxes tensor is converted to XYXY
47/// using the given BBoxTypeTrait. The order of the boxes is preserved.
48///
49/// This function is very similar to `postprocess_boxes_float` but will also
50/// return the index of the box. The boxes will be in ascending index order.
51pub fn postprocess_boxes_index_float<
52    B: BBoxTypeTrait,
53    BOX: Float + AsPrimitive<f32> + Send + Sync,
54    SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56    threshold: SCORE,
57    boxes: ArrayView2<BOX>,
58    scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60    assert_eq!(scores.dim().0, boxes.dim().0);
61    assert_eq!(boxes.dim().1, 4);
62    let indices: Array1<usize> = (0..boxes.dim().0).collect();
63    Zip::from(scores.rows())
64        .and(boxes.rows())
65        .and(&indices)
66        .into_par_iter()
67        .filter_map(|(score, bbox, i)| {
68            let (score_, label) = arg_max(score);
69            if score_ < threshold {
70                return None;
71            }
72
73            let bbox = B::ndarray_to_xyxy_float(bbox);
74            Some((
75                DetectBox {
76                    label,
77                    score: score_.as_(),
78                    bbox: bbox.into(),
79                },
80                *i,
81            ))
82        })
83        .collect()
84}
85
86/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
87/// then greedily selects a subset of boxes in descending order of score.
88#[must_use]
89pub fn nms_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
90    // Boxes get sorted by score in descending order so we know based on the
91    // index the scoring of the boxes and can skip parts of the loop.
92    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
93
94    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
95    // immediately
96    if iou >= 1.0 {
97        return boxes;
98    }
99
100    // Outer loop over all boxes.
101    for i in 0..boxes.len() {
102        if boxes[i].score < 0.0 {
103            // this box was merged with a different box earlier
104            continue;
105        }
106        for j in (i + 1)..boxes.len() {
107            // Inner loop over boxes with lower score (later in the list).
108
109            if boxes[j].score < 0.0 {
110                // this box was suppressed by different box earlier
111                continue;
112            }
113            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
114                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
115                boxes[j].score = -1.0;
116            }
117        }
118    }
119    // Filter out suppressed boxes.
120    boxes.into_iter().filter(|b| b.score >= 0.0).collect()
121}
122
123/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
124/// then greedily selects a subset of boxes in descending order of score.
125///
126/// This is same as `nms_float` but will also include extra information along
127/// with each box, such as the index
128#[must_use]
129pub fn nms_extra_float<E: Send + Sync>(
130    iou: f32,
131    mut boxes: Vec<(DetectBox, E)>,
132) -> Vec<(DetectBox, E)> {
133    // Boxes get sorted by score in descending order so we know based on the
134    // index the scoring of the boxes and can skip parts of the loop.
135    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
136
137    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
138    // immediately
139    if iou >= 1.0 {
140        return boxes;
141    }
142
143    // Outer loop over all boxes.
144    for i in 0..boxes.len() {
145        if boxes[i].0.score <= 0.0 {
146            // this box was merged with a different box earlier
147            continue;
148        }
149        for j in (i + 1)..boxes.len() {
150            // Inner loop over boxes with lower score (later in the list).
151
152            if boxes[j].0.score <= 0.0 {
153                // this box was suppressed by different box earlier
154                continue;
155            }
156            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
157                // max_box(boxes[j].bbox, &mut boxes[i].bbox);
158                boxes[j].0.score = 0.0;
159            }
160        }
161    }
162
163    // Filter out boxes with a score of 0.0.
164    boxes.into_iter().filter(|b| b.0.score > 0.0).collect()
165}
166
167/// Class-aware NMS: only suppress boxes with the same label.
168///
169/// Sorts boxes by score, then greedily selects a subset of boxes in descending
170/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
171/// have the same class label AND overlap above the IoU threshold.
172///
173/// # Example
174/// ```
175/// # use edgefirst_decoder::{BoundingBox, DetectBox, float::nms_class_aware_float};
176/// let boxes = vec![
177///     DetectBox {
178///         bbox: BoundingBox::new(0.0, 0.0, 0.5, 0.5),
179///         score: 0.9,
180///         label: 0,
181///     },
182///     DetectBox {
183///         bbox: BoundingBox::new(0.1, 0.1, 0.6, 0.6),
184///         score: 0.8,
185///         label: 1,
186///     }, // different class
187/// ];
188/// // Both boxes survive because they have different labels
189/// let result = nms_class_aware_float(0.3, boxes);
190/// assert_eq!(result.len(), 2);
191/// ```
192#[must_use]
193pub fn nms_class_aware_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
194    boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
195
196    if iou >= 1.0 {
197        return boxes;
198    }
199
200    for i in 0..boxes.len() {
201        if boxes[i].score < 0.0 {
202            continue;
203        }
204        for j in (i + 1)..boxes.len() {
205            if boxes[j].score < 0.0 {
206                continue;
207            }
208            // Only suppress if same class AND overlapping
209            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
210                boxes[j].score = -1.0;
211            }
212        }
213    }
214    boxes.into_iter().filter(|b| b.score >= 0.0).collect()
215}
216
217/// Class-aware NMS with extra data: only suppress boxes with the same label.
218///
219/// This is same as `nms_class_aware_float` but will also include extra
220/// information along with each box, such as the index.
221#[must_use]
222pub fn nms_extra_class_aware_float<E: Send + Sync>(
223    iou: f32,
224    mut boxes: Vec<(DetectBox, E)>,
225) -> Vec<(DetectBox, E)> {
226    boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
227
228    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
229    // immediately
230    if iou >= 1.0 {
231        return boxes;
232    }
233
234    for i in 0..boxes.len() {
235        if boxes[i].0.score <= 0.0 {
236            continue;
237        }
238        for j in (i + 1)..boxes.len() {
239            if boxes[j].0.score <= 0.0 {
240                continue;
241            }
242            // Only suppress if same class AND overlapping
243            if boxes[j].0.label == boxes[i].0.label
244                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
245            {
246                boxes[j].0.score = 0.0;
247            }
248        }
249    }
250    boxes.into_iter().filter(|b| b.0.score > 0.0).collect()
251}
252
253/// Returns true if the IOU of the given bounding boxes is greater than the iou
254/// threshold
255///
256/// # Example
257/// ```
258/// # use edgefirst_decoder::{BoundingBox, float::jaccard};
259/// let a = BoundingBox::new(0.0, 0.0, 0.2, 0.2);
260/// let b = BoundingBox::new(0.1, 0.1, 0.3, 0.3);
261/// let iou_threshold = 0.1;
262/// let result = jaccard(&a, &b, iou_threshold);
263/// assert!(result);
264/// ```
265pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
266    let left = a.xmin.max(b.xmin);
267    let top = a.ymin.max(b.ymin);
268    let right = a.xmax.min(b.xmax);
269    let bottom = a.ymax.min(b.ymax);
270
271    let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
272    let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
273    let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
274
275    // need to make sure we are not dividing by zero
276    let union = area_a + area_b - intersection;
277
278    intersection > iou * union
279}