edgefirst_decoder/float.rs
1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7 Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12/// Post processes boxes and scores tensors into detection boxes, filtering out
13/// any boxes below the score threshold. The boxes tensor is converted to XYXY
14/// using the given BBoxTypeTrait. The order of the boxes is preserved.
15pub fn postprocess_boxes_float<
16 B: BBoxTypeTrait,
17 BOX: Float + AsPrimitive<f32> + Send + Sync,
18 SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20 threshold: SCORE,
21 boxes: ArrayView2<BOX>,
22 scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24 assert_eq!(scores.dim().0, boxes.dim().0);
25 assert_eq!(boxes.dim().1, 4);
26 Zip::from(scores.rows())
27 .and(boxes.rows())
28 .into_par_iter()
29 .filter_map(|(score, bbox)| {
30 let (score_, label) = arg_max(score);
31 if score_ < threshold {
32 return None;
33 }
34
35 let bbox = B::ndarray_to_xyxy_float(bbox);
36 Some(DetectBox {
37 label,
38 score: score_.as_(),
39 bbox: bbox.into(),
40 })
41 })
42 .collect()
43}
44
45/// Post processes boxes and scores tensors into detection boxes, filtering out
46/// any boxes below the score threshold. The boxes tensor is converted to XYXY
47/// using the given BBoxTypeTrait. The order of the boxes is preserved.
48///
49/// This function is very similar to `postprocess_boxes_float` but will also
50/// return the index of the box. The boxes will be in ascending index order.
51pub fn postprocess_boxes_index_float<
52 B: BBoxTypeTrait,
53 BOX: Float + AsPrimitive<f32> + Send + Sync,
54 SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56 threshold: SCORE,
57 boxes: ArrayView2<BOX>,
58 scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60 assert_eq!(scores.dim().0, boxes.dim().0);
61 assert_eq!(boxes.dim().1, 4);
62 let indices: Array1<usize> = (0..boxes.dim().0).collect();
63 Zip::from(scores.rows())
64 .and(boxes.rows())
65 .and(&indices)
66 .into_par_iter()
67 .filter_map(|(score, bbox, i)| {
68 let (score_, label) = arg_max(score);
69 if score_ < threshold {
70 return None;
71 }
72
73 let bbox = B::ndarray_to_xyxy_float(bbox);
74 Some((
75 DetectBox {
76 label,
77 score: score_.as_(),
78 bbox: bbox.into(),
79 },
80 *i,
81 ))
82 })
83 .collect()
84}
85
86/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
87/// then greedily selects a subset of boxes in descending order of score.
88#[must_use]
89pub fn nms_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
90 // Boxes get sorted by score in descending order so we know based on the
91 // index the scoring of the boxes and can skip parts of the loop.
92 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
93
94 // When the iou is 1.0 or larger, no boxes will be filtered so we just return
95 // immediately
96 if iou >= 1.0 {
97 return boxes;
98 }
99
100 // Outer loop over all boxes.
101 for i in 0..boxes.len() {
102 if boxes[i].score < 0.0 {
103 // this box was merged with a different box earlier
104 continue;
105 }
106 for j in (i + 1)..boxes.len() {
107 // Inner loop over boxes with lower score (later in the list).
108
109 if boxes[j].score < 0.0 {
110 // this box was suppressed by different box earlier
111 continue;
112 }
113 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
114 // max_box(boxes[j].bbox, &mut boxes[i].bbox);
115 boxes[j].score = -1.0;
116 }
117 }
118 }
119 // Filter out suppressed boxes.
120 boxes.into_iter().filter(|b| b.score >= 0.0).collect()
121}
122
123/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
124/// then greedily selects a subset of boxes in descending order of score.
125///
126/// This is same as `nms_float` but will also include extra information along
127/// with each box, such as the index
128#[must_use]
129pub fn nms_extra_float<E: Send + Sync>(
130 iou: f32,
131 mut boxes: Vec<(DetectBox, E)>,
132) -> Vec<(DetectBox, E)> {
133 // Boxes get sorted by score in descending order so we know based on the
134 // index the scoring of the boxes and can skip parts of the loop.
135 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
136
137 // When the iou is 1.0 or larger, no boxes will be filtered so we just return
138 // immediately
139 if iou >= 1.0 {
140 return boxes;
141 }
142
143 // Outer loop over all boxes.
144 for i in 0..boxes.len() {
145 if boxes[i].0.score <= 0.0 {
146 // this box was merged with a different box earlier
147 continue;
148 }
149 for j in (i + 1)..boxes.len() {
150 // Inner loop over boxes with lower score (later in the list).
151
152 if boxes[j].0.score <= 0.0 {
153 // this box was suppressed by different box earlier
154 continue;
155 }
156 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
157 // max_box(boxes[j].bbox, &mut boxes[i].bbox);
158 boxes[j].0.score = 0.0;
159 }
160 }
161 }
162
163 // Filter out boxes with a score of 0.0.
164 boxes.into_iter().filter(|b| b.0.score > 0.0).collect()
165}
166
167/// Class-aware NMS: only suppress boxes with the same label.
168///
169/// Sorts boxes by score, then greedily selects a subset of boxes in descending
170/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
171/// have the same class label AND overlap above the IoU threshold.
172///
173/// # Example
174/// ```
175/// # use edgefirst_decoder::{BoundingBox, DetectBox, float::nms_class_aware_float};
176/// let boxes = vec![
177/// DetectBox {
178/// bbox: BoundingBox::new(0.0, 0.0, 0.5, 0.5),
179/// score: 0.9,
180/// label: 0,
181/// },
182/// DetectBox {
183/// bbox: BoundingBox::new(0.1, 0.1, 0.6, 0.6),
184/// score: 0.8,
185/// label: 1,
186/// }, // different class
187/// ];
188/// // Both boxes survive because they have different labels
189/// let result = nms_class_aware_float(0.3, boxes);
190/// assert_eq!(result.len(), 2);
191/// ```
192#[must_use]
193pub fn nms_class_aware_float(iou: f32, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
194 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
195
196 if iou >= 1.0 {
197 return boxes;
198 }
199
200 for i in 0..boxes.len() {
201 if boxes[i].score < 0.0 {
202 continue;
203 }
204 for j in (i + 1)..boxes.len() {
205 if boxes[j].score < 0.0 {
206 continue;
207 }
208 // Only suppress if same class AND overlapping
209 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
210 boxes[j].score = -1.0;
211 }
212 }
213 }
214 boxes.into_iter().filter(|b| b.score >= 0.0).collect()
215}
216
217/// Class-aware NMS with extra data: only suppress boxes with the same label.
218///
219/// This is same as `nms_class_aware_float` but will also include extra
220/// information along with each box, such as the index.
221#[must_use]
222pub fn nms_extra_class_aware_float<E: Send + Sync>(
223 iou: f32,
224 mut boxes: Vec<(DetectBox, E)>,
225) -> Vec<(DetectBox, E)> {
226 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
227
228 // When the iou is 1.0 or larger, no boxes will be filtered so we just return
229 // immediately
230 if iou >= 1.0 {
231 return boxes;
232 }
233
234 for i in 0..boxes.len() {
235 if boxes[i].0.score <= 0.0 {
236 continue;
237 }
238 for j in (i + 1)..boxes.len() {
239 if boxes[j].0.score <= 0.0 {
240 continue;
241 }
242 // Only suppress if same class AND overlapping
243 if boxes[j].0.label == boxes[i].0.label
244 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
245 {
246 boxes[j].0.score = 0.0;
247 }
248 }
249 }
250 boxes.into_iter().filter(|b| b.0.score > 0.0).collect()
251}
252
253/// Returns true if the IOU of the given bounding boxes is greater than the iou
254/// threshold
255///
256/// # Example
257/// ```
258/// # use edgefirst_decoder::{BoundingBox, float::jaccard};
259/// let a = BoundingBox::new(0.0, 0.0, 0.2, 0.2);
260/// let b = BoundingBox::new(0.1, 0.1, 0.3, 0.3);
261/// let iou_threshold = 0.1;
262/// let result = jaccard(&a, &b, iou_threshold);
263/// assert!(result);
264/// ```
265pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
266 let left = a.xmin.max(b.xmin);
267 let top = a.ymin.max(b.ymin);
268 let right = a.xmax.min(b.xmax);
269 let bottom = a.ymax.min(b.ymax);
270
271 let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
272 let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
273 let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
274
275 // need to make sure we are not dividing by zero
276 let union = area_a + area_b - intersection;
277
278 intersection > iou * union
279}