1use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7 Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12pub fn postprocess_boxes_float<
16 B: BBoxTypeTrait,
17 BOX: Float + AsPrimitive<f32> + Send + Sync,
18 SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20 threshold: SCORE,
21 boxes: ArrayView2<BOX>,
22 scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24 assert_eq!(scores.dim().0, boxes.dim().0);
25 assert_eq!(boxes.dim().1, 4);
26 Zip::from(scores.rows())
27 .and(boxes.rows())
28 .into_par_iter()
29 .filter_map(|(score, bbox)| {
30 let (score_, label) = arg_max(score);
31 if score_ < threshold {
32 return None;
33 }
34
35 let bbox = B::ndarray_to_xyxy_float(bbox);
36 Some(DetectBox {
37 label,
38 score: score_.as_(),
39 bbox: bbox.into(),
40 })
41 })
42 .collect()
43}
44
45pub fn postprocess_boxes_index_float<
52 B: BBoxTypeTrait,
53 BOX: Float + AsPrimitive<f32> + Send + Sync,
54 SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56 threshold: SCORE,
57 boxes: ArrayView2<BOX>,
58 scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60 assert_eq!(scores.dim().0, boxes.dim().0);
61 assert_eq!(boxes.dim().1, 4);
62 let indices: Array1<usize> = (0..boxes.dim().0).collect();
63 Zip::from(scores.rows())
64 .and(boxes.rows())
65 .and(&indices)
66 .into_par_iter()
67 .filter_map(|(score, bbox, i)| {
68 let (score_, label) = arg_max(score);
69 if score_ < threshold {
70 return None;
71 }
72
73 let bbox = B::ndarray_to_xyxy_float(bbox);
74 Some((
75 DetectBox {
76 label,
77 score: score_.as_(),
78 bbox: bbox.into(),
79 },
80 *i,
81 ))
82 })
83 .collect()
84}
85
86#[must_use]
94pub fn nms_float(iou: f32, max_det: Option<usize>, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
95 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
98
99 if iou >= 1.0 {
102 return match max_det {
103 Some(n) => {
104 boxes.truncate(n);
105 boxes
106 }
107 None => boxes,
108 };
109 }
110
111 let cap = max_det.unwrap_or(usize::MAX);
112 let mut survivors: usize = 0;
113
114 for i in 0..boxes.len() {
116 if boxes[i].score < 0.0 {
117 continue;
119 }
120 for j in (i + 1)..boxes.len() {
121 if boxes[j].score < 0.0 {
124 continue;
126 }
127 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
128 boxes[j].score = -1.0;
130 }
131 }
132 survivors += 1;
133 if survivors >= cap {
134 break;
135 }
136 }
137 boxes
140 .into_iter()
141 .filter(|b| b.score >= 0.0)
142 .take(cap)
143 .collect()
144}
145
146#[must_use]
152pub fn nms_extra_float<E: Send + Sync>(
153 iou: f32,
154 max_det: Option<usize>,
155 mut boxes: Vec<(DetectBox, E)>,
156) -> Vec<(DetectBox, E)> {
157 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
160
161 if iou >= 1.0 {
164 return match max_det {
165 Some(n) => {
166 boxes.truncate(n);
167 boxes
168 }
169 None => boxes,
170 };
171 }
172
173 let cap = max_det.unwrap_or(usize::MAX);
174 let mut survivors: usize = 0;
175
176 for i in 0..boxes.len() {
178 if boxes[i].0.score < 0.0 {
179 continue;
181 }
182 for j in (i + 1)..boxes.len() {
183 if boxes[j].0.score < 0.0 {
186 continue;
188 }
189 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
190 boxes[j].0.score = -1.0;
192 }
193 }
194 survivors += 1;
195 if survivors >= cap {
196 break;
197 }
198 }
199
200 boxes
203 .into_iter()
204 .filter(|b| b.0.score >= 0.0)
205 .take(cap)
206 .collect()
207}
208
209#[must_use]
235pub fn nms_class_aware_float(
236 iou: f32,
237 max_det: Option<usize>,
238 mut boxes: Vec<DetectBox>,
239) -> Vec<DetectBox> {
240 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
241
242 if iou >= 1.0 {
243 return match max_det {
244 Some(n) => {
245 boxes.truncate(n);
246 boxes
247 }
248 None => boxes,
249 };
250 }
251
252 let cap = max_det.unwrap_or(usize::MAX);
253 let mut survivors: usize = 0;
254
255 for i in 0..boxes.len() {
256 if boxes[i].score < 0.0 {
257 continue;
258 }
259 for j in (i + 1)..boxes.len() {
260 if boxes[j].score < 0.0 {
261 continue;
262 }
263 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
265 boxes[j].score = -1.0;
266 }
267 }
268 survivors += 1;
269 if survivors >= cap {
270 break;
271 }
272 }
273 boxes
274 .into_iter()
275 .filter(|b| b.score >= 0.0)
276 .take(cap)
277 .collect()
278}
279
280#[must_use]
285pub fn nms_extra_class_aware_float<E: Send + Sync>(
286 iou: f32,
287 max_det: Option<usize>,
288 mut boxes: Vec<(DetectBox, E)>,
289) -> Vec<(DetectBox, E)> {
290 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
291
292 if iou >= 1.0 {
295 return match max_det {
296 Some(n) => {
297 boxes.truncate(n);
298 boxes
299 }
300 None => boxes,
301 };
302 }
303
304 let cap = max_det.unwrap_or(usize::MAX);
305 let mut survivors: usize = 0;
306
307 for i in 0..boxes.len() {
308 if boxes[i].0.score < 0.0 {
309 continue;
310 }
311 for j in (i + 1)..boxes.len() {
312 if boxes[j].0.score < 0.0 {
313 continue;
314 }
315 if boxes[j].0.label == boxes[i].0.label
317 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
318 {
319 boxes[j].0.score = -1.0;
320 }
321 }
322 survivors += 1;
323 if survivors >= cap {
324 break;
325 }
326 }
327 boxes
328 .into_iter()
329 .filter(|b| b.0.score >= 0.0)
330 .take(cap)
331 .collect()
332}
333
334pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
347 let left = a.xmin.max(b.xmin);
348 let top = a.ymin.max(b.ymin);
349 let right = a.xmax.min(b.xmax);
350 let bottom = a.ymax.min(b.ymax);
351
352 let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
353 let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
354 let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
355
356 let union = area_a + area_b - intersection;
358
359 intersection > iou * union
360}
361
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BoundingBox;

    /// `n` disjoint 10x10 boxes spaced 100 apart, all label 0, with strictly
    /// decreasing scores (so sorted order equals construction order).
    fn make_nms_boxes_float(n: usize) -> Vec<DetectBox> {
        (0..n)
            .map(|i| DetectBox {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: 1.0 - i as f32 * 0.01,
            })
            .collect()
    }

    /// `n_pairs` pairs of heavily overlapping 10x10 boxes. Within a pair the
    /// second box is shifted by 1 (IoU = 90/110). Pairs are 100 apart, so
    /// boxes from different pairs never overlap. Scores strictly decrease with
    /// the construction index. The two boxes of a pair get different labels,
    /// so class-aware NMS keeps both while class-agnostic NMS keeps only the
    /// first.
    fn make_overlapping_pairs(n_pairs: usize) -> Vec<DetectBox> {
        (0..n_pairs * 2)
            .map(|i| {
                let x = (i / 2) as f32 * 100.0 + (i % 2) as f32;
                DetectBox {
                    bbox: BoundingBox {
                        xmin: x,
                        ymin: 0.0,
                        xmax: x + 10.0,
                        ymax: 10.0,
                    },
                    label: if i % 2 == 0 { 0 } else { 1 },
                    score: 1.0 - i as f32 * 0.01,
                }
            })
            .collect()
    }

    /// Attaches each box's construction index as its payload.
    fn with_index(boxes: Vec<DetectBox>) -> Vec<(DetectBox, usize)> {
        boxes.into_iter().enumerate().map(|(i, b)| (b, i)).collect()
    }

    /// Extracts the payload indices of NMS output, in output order.
    fn indices(boxes: &[(DetectBox, usize)]) -> Vec<usize> {
        boxes.iter().map(|(_, i)| *i).collect()
    }

    #[test]
    fn nms_float_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_float(20);
        let n = 5;
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        for (f, c) in full[..n].iter().zip(capped.iter()) {
            assert_eq!(f.bbox, c.bbox);
            assert_eq!(f.score, c.score);
        }
    }

    #[test]
    fn nms_float_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    #[test]
    fn nms_float_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    #[test]
    fn nms_float_max_det_larger_than_input() {
        let boxes = make_nms_boxes_float(5);
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }

    #[test]
    fn nms_float_suppresses_overlaps_before_capping() {
        let boxes = make_overlapping_pairs(10);
        let full = nms_float(0.5, None, boxes.clone());
        let expected_xmin: Vec<f32> = (0..10).map(|p| p as f32 * 100.0).collect();
        assert_eq!(
            full.iter().map(|b| b.bbox.xmin).collect::<Vec<_>>(),
            expected_xmin
        );

        let capped = nms_float(0.5, Some(4), boxes);
        assert_eq!(capped.len(), 4);
        for (f, c) in full.iter().zip(&capped) {
            assert_eq!(f.bbox, c.bbox);
        }
    }

    #[test]
    fn nms_class_aware_float_keeps_overlaps_with_different_labels() {
        let boxes = make_overlapping_pairs(5);
        assert_eq!(nms_float(0.5, None, boxes.clone()).len(), 5);
        assert_eq!(nms_class_aware_float(0.5, None, boxes.clone()).len(), 10);
        assert_eq!(nms_class_aware_float(0.5, Some(3), boxes).len(), 3);
    }

    #[test]
    fn nms_extra_float_preserves_payload_and_respects_max_det() {
        let boxes = with_index(make_overlapping_pairs(5));
        assert_eq!(
            indices(&nms_extra_float(0.5, None, boxes.clone())),
            [0, 2, 4, 6, 8]
        );
        assert_eq!(indices(&nms_extra_float(0.5, Some(2), boxes)), [0, 2]);
    }

    #[test]
    fn nms_extra_class_aware_float_keeps_overlaps_with_different_labels() {
        let boxes = with_index(make_overlapping_pairs(3));
        assert_eq!(
            indices(&nms_extra_class_aware_float(0.5, None, boxes.clone())),
            [0, 1, 2, 3, 4, 5]
        );
        assert_eq!(
            indices(&nms_extra_class_aware_float(0.5, Some(4), boxes)),
            [0, 1, 2, 3]
        );
    }

    #[test]
    fn jaccard_identical_and_disjoint_boxes() {
        let a = BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 10.0,
            ymax: 10.0,
        };
        let far = BoundingBox {
            xmin: 20.0,
            ymin: 0.0,
            xmax: 30.0,
            ymax: 10.0,
        };
        assert!(jaccard(&a, &a, 0.5));
        assert!(!jaccard(&a, &far, 0.0));
    }
}