1use crate::{arg_max, BBoxTypeTrait, BoundingBox, DetectBox};
5use ndarray::{
6 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
7 Array1, ArrayView2, Zip,
8};
9use num_traits::{AsPrimitive, Float};
10use rayon::slice::ParallelSliceMut;
11
12pub fn postprocess_boxes_float<
16 B: BBoxTypeTrait,
17 BOX: Float + AsPrimitive<f32> + Send + Sync,
18 SCORE: Float + AsPrimitive<f32> + Send + Sync,
19>(
20 threshold: SCORE,
21 boxes: ArrayView2<BOX>,
22 scores: ArrayView2<SCORE>,
23) -> Vec<DetectBox> {
24 assert_eq!(scores.dim().0, boxes.dim().0);
25 assert_eq!(boxes.dim().1, 4);
26 Zip::from(scores.rows())
27 .and(boxes.rows())
28 .into_par_iter()
29 .filter_map(|(score, bbox)| {
30 let (score_, label) = arg_max(score);
31 if score_ < threshold {
32 return None;
33 }
34
35 let bbox = B::ndarray_to_xyxy_float(bbox);
36 Some(DetectBox {
37 label,
38 score: score_.as_(),
39 bbox: bbox.into(),
40 })
41 })
42 .collect()
43}
44
45pub fn postprocess_boxes_index_float<
52 B: BBoxTypeTrait,
53 BOX: Float + AsPrimitive<f32> + Send + Sync,
54 SCORE: Float + AsPrimitive<f32> + Send + Sync,
55>(
56 threshold: SCORE,
57 boxes: ArrayView2<BOX>,
58 scores: ArrayView2<SCORE>,
59) -> Vec<(DetectBox, usize)> {
60 assert_eq!(scores.dim().0, boxes.dim().0);
61 assert_eq!(boxes.dim().1, 4);
62 let indices: Array1<usize> = (0..boxes.dim().0).collect();
63 Zip::from(scores.rows())
64 .and(boxes.rows())
65 .and(&indices)
66 .into_par_iter()
67 .filter_map(|(score, bbox, i)| {
68 let (score_, label) = arg_max(score);
69 if score_ < threshold {
70 return None;
71 }
72
73 let bbox = B::ndarray_to_xyxy_float(bbox);
74 Some((
75 DetectBox {
76 label,
77 score: score_.as_(),
78 bbox: bbox.into(),
79 },
80 *i,
81 ))
82 })
83 .collect()
84}
85
86#[must_use]
94pub fn nms_float(iou: f32, max_det: Option<usize>, mut boxes: Vec<DetectBox>) -> Vec<DetectBox> {
95 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
98
99 if iou >= 1.0 {
102 return match max_det {
103 Some(n) => {
104 boxes.truncate(n);
105 boxes
106 }
107 None => boxes,
108 };
109 }
110
111 let cap = max_det.unwrap_or(usize::MAX);
112 let mut survivors: usize = 0;
113
114 for i in 0..boxes.len() {
116 if boxes[i].score < 0.0 {
117 continue;
119 }
120 for j in (i + 1)..boxes.len() {
121 if boxes[j].score < 0.0 {
124 continue;
126 }
127 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
128 boxes[j].score = -1.0;
130 }
131 }
132
133 survivors += 1;
138 if survivors >= cap {
139 break;
140 }
141 }
142 boxes
145 .into_iter()
146 .filter(|b| b.score >= 0.0)
147 .take(cap)
148 .collect()
149}
150
151#[must_use]
157pub fn nms_extra_float<E: Send + Sync>(
158 iou: f32,
159 max_det: Option<usize>,
160 mut boxes: Vec<(DetectBox, E)>,
161) -> Vec<(DetectBox, E)> {
162 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
165
166 if iou >= 1.0 {
169 return match max_det {
170 Some(n) => {
171 boxes.truncate(n);
172 boxes
173 }
174 None => boxes,
175 };
176 }
177
178 let cap = max_det.unwrap_or(usize::MAX);
179 let mut survivors: usize = 0;
180
181 for i in 0..boxes.len() {
183 if boxes[i].0.score < 0.0 {
184 continue;
186 }
187 for j in (i + 1)..boxes.len() {
188 if boxes[j].0.score < 0.0 {
191 continue;
193 }
194 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
195 boxes[j].0.score = -1.0;
197 }
198 }
199 survivors += 1;
200 if survivors >= cap {
201 break;
202 }
203 }
204
205 boxes
208 .into_iter()
209 .filter(|b| b.0.score >= 0.0)
210 .take(cap)
211 .collect()
212}
213
214#[must_use]
240pub fn nms_class_aware_float(
241 iou: f32,
242 max_det: Option<usize>,
243 mut boxes: Vec<DetectBox>,
244) -> Vec<DetectBox> {
245 boxes.par_sort_by(|a, b| b.score.total_cmp(&a.score));
246
247 if iou >= 1.0 {
248 return match max_det {
249 Some(n) => {
250 boxes.truncate(n);
251 boxes
252 }
253 None => boxes,
254 };
255 }
256
257 let cap = max_det.unwrap_or(usize::MAX);
258 let mut survivors: usize = 0;
259
260 for i in 0..boxes.len() {
261 if boxes[i].score < 0.0 {
262 continue;
263 }
264 for j in (i + 1)..boxes.len() {
265 if boxes[j].score < 0.0 {
266 continue;
267 }
268 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
270 boxes[j].score = -1.0;
271 }
272 }
273 survivors += 1;
274 if survivors >= cap {
275 break;
276 }
277 }
278 boxes
279 .into_iter()
280 .filter(|b| b.score >= 0.0)
281 .take(cap)
282 .collect()
283}
284
285#[must_use]
290pub fn nms_extra_class_aware_float<E: Send + Sync>(
291 iou: f32,
292 max_det: Option<usize>,
293 mut boxes: Vec<(DetectBox, E)>,
294) -> Vec<(DetectBox, E)> {
295 boxes.par_sort_by(|a, b| b.0.score.total_cmp(&a.0.score));
296
297 if iou >= 1.0 {
300 return match max_det {
301 Some(n) => {
302 boxes.truncate(n);
303 boxes
304 }
305 None => boxes,
306 };
307 }
308
309 let cap = max_det.unwrap_or(usize::MAX);
310 let mut survivors: usize = 0;
311
312 for i in 0..boxes.len() {
313 if boxes[i].0.score < 0.0 {
314 continue;
315 }
316 for j in (i + 1)..boxes.len() {
317 if boxes[j].0.score < 0.0 {
318 continue;
319 }
320 if boxes[j].0.label == boxes[i].0.label
322 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
323 {
324 boxes[j].0.score = -1.0;
325 }
326 }
327 survivors += 1;
328 if survivors >= cap {
329 break;
330 }
331 }
332 boxes
333 .into_iter()
334 .filter(|b| b.0.score >= 0.0)
335 .take(cap)
336 .collect()
337}
338
339pub fn jaccard(a: &BoundingBox, b: &BoundingBox, iou: f32) -> bool {
352 let left = a.xmin.max(b.xmin);
353 let top = a.ymin.max(b.ymin);
354 let right = a.xmax.min(b.xmax);
355 let bottom = a.ymax.min(b.ymax);
356
357 let intersection = (right - left).max(0.0) * (bottom - top).max(0.0);
358 let area_a = (a.xmax - a.xmin) * (a.ymax - a.ymin);
359 let area_b = (b.xmax - b.xmin) * (b.ymax - b.ymin);
360
361 let union = area_a + area_b - intersection;
363
364 intersection > iou * union
365}
366
367#[inline]
376pub fn jaccard_batch4(a: &BoundingBox, boxes: &[BoundingBox; 4], iou: f32) -> [bool; 4] {
377 #[cfg(target_arch = "aarch64")]
378 {
379 unsafe { jaccard_batch4_neon(a, boxes, iou) }
381 }
382 #[cfg(not(target_arch = "aarch64"))]
383 {
384 [
385 jaccard(a, &boxes[0], iou),
386 jaccard(a, &boxes[1], iou),
387 jaccard(a, &boxes[2], iou),
388 jaccard(a, &boxes[3], iou),
389 ]
390 }
391}
392
/// NEON implementation of `jaccard_batch4`: tests `a` against four boxes in one
/// pass. Lane `k` of every vector corresponds to `boxes[k]`.
///
/// The arithmetic mirrors `jaccard` step for step: max/min of the corners,
/// overlap width/height clamped at zero, `(area_a + area_b) - intersection` as
/// the union, and finally `intersection > iou * union`.
///
/// # Safety
/// * The CPU must support the `neon` target feature (enabled for this function
///   via `#[target_feature]`).
/// * Each `vld1q_f32` below reads four consecutive `f32`s starting at
///   `boxes[k].xmin`. This assumes `BoundingBox` stores `xmin, ymin, xmax, ymax`
///   contiguously and in that order (e.g. `#[repr(C)]`).
///   NOTE(review): confirm against the struct definition; otherwise the loads
///   read the wrong fields or past the end of the array.
///
/// NOTE(review): `vmaxq_f32`/`vminq_f32` propagate NaN, whereas `f32::max`/`min`
/// (used by `jaccard`) ignore a NaN operand, so results may differ from the
/// scalar path when coordinates are NaN.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
unsafe fn jaccard_batch4_neon(a: &BoundingBox, boxes: &[BoundingBox; 4], iou: f32) -> [bool; 4] {
    use std::arch::aarch64::*;

    let zero = vdupq_n_f32(0.0);
    let iou_v = vdupq_n_f32(iou);

    // Broadcast the reference box `a` into all four lanes.
    let a_xmin = vdupq_n_f32(a.xmin);
    let a_ymin = vdupq_n_f32(a.ymin);
    let a_xmax = vdupq_n_f32(a.xmax);
    let a_ymax = vdupq_n_f32(a.ymax);
    let area_a = vmulq_f32(vsubq_f32(a_xmax, a_xmin), vsubq_f32(a_ymax, a_ymin));

    // Load each box as one vector [xmin, ymin, xmax, ymax] (see the layout
    // assumption under `# Safety`).
    let b0 = vld1q_f32(&boxes[0].xmin as *const f32);
    let b1 = vld1q_f32(&boxes[1].xmin as *const f32);
    let b2 = vld1q_f32(&boxes[2].xmin as *const f32);
    let b3 = vld1q_f32(&boxes[3].xmin as *const f32);

    // 4x4 transpose, step 1 (32-bit interleave):
    //   t01_lo = [x0min, x1min, x0max, x1max], t01_hi = [y0min, y1min, y0max, y1max]
    //   t23_lo / t23_hi hold the same fields for boxes 2 and 3.
    let t01_lo = vtrn1q_f32(b0, b1);
    let t01_hi = vtrn2q_f32(b0, b1);
    let t23_lo = vtrn1q_f32(b2, b3);
    let t23_hi = vtrn2q_f32(b2, b3);

    // Step 2 (64-bit interleave): pair up the low/high halves, so each result
    // holds one coordinate field for all four boxes, e.g.
    // b_xmin = [x0min, x1min, x2min, x3min].
    let b_xmin = vreinterpretq_f32_f64(vtrn1q_f64(
        vreinterpretq_f64_f32(t01_lo),
        vreinterpretq_f64_f32(t23_lo),
    ));
    let b_ymin = vreinterpretq_f32_f64(vtrn1q_f64(
        vreinterpretq_f64_f32(t01_hi),
        vreinterpretq_f64_f32(t23_hi),
    ));
    let b_xmax = vreinterpretq_f32_f64(vtrn2q_f64(
        vreinterpretq_f64_f32(t01_lo),
        vreinterpretq_f64_f32(t23_lo),
    ));
    let b_ymax = vreinterpretq_f32_f64(vtrn2q_f64(
        vreinterpretq_f64_f32(t01_hi),
        vreinterpretq_f64_f32(t23_hi),
    ));

    // Intersection rectangle, with width/height clamped at zero as in `jaccard`.
    let left = vmaxq_f32(a_xmin, b_xmin);
    let top = vmaxq_f32(a_ymin, b_ymin);
    let right = vminq_f32(a_xmax, b_xmax);
    let bottom = vminq_f32(a_ymax, b_ymax);
    let w = vmaxq_f32(vsubq_f32(right, left), zero);
    let h = vmaxq_f32(vsubq_f32(bottom, top), zero);
    let intersection = vmulq_f32(w, h);

    let area_b = vmulq_f32(vsubq_f32(b_xmax, b_xmin), vsubq_f32(b_ymax, b_ymin));

    // Same operation order as the scalar code: (area_a + area_b) - intersection.
    let union = vsubq_f32(vaddq_f32(area_a, area_b), intersection);

    // Division-free test: intersection > iou * union. `vcgtq_f32` sets a lane
    // to all ones when the comparison holds and to zero otherwise.
    let iou_union = vmulq_f32(iou_v, union);
    let mask = vcgtq_f32(intersection, iou_union);

    [
        vgetq_lane_u32(mask, 0) != 0,
        vgetq_lane_u32(mask, 1) != 0,
        vgetq_lane_u32(mask, 2) != 0,
        vgetq_lane_u32(mask, 3) != 0,
    ]
}
469
#[cfg(test)]
mod tests {
    use super::*;
    use crate::BoundingBox;

    /// Builds `n` disjoint, same-class boxes with strictly decreasing scores,
    /// so NMS never suppresses any of them.
    fn make_nms_boxes_float(n: usize) -> Vec<DetectBox> {
        (0..n)
            .map(|i| DetectBox {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: 1.0 - i as f32 * 0.01,
            })
            .collect()
    }

    /// Builds a corner-form bounding box.
    fn bb(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> BoundingBox {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Builds a label-0 detection from corner coordinates and a score.
    fn det(xmin: f32, ymin: f32, xmax: f32, ymax: f32, score: f32) -> DetectBox {
        DetectBox {
            bbox: bb(xmin, ymin, xmax, ymax),
            label: 0,
            score,
        }
    }

    #[test]
    fn nms_float_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_float(20);
        let n = 5;
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        for (f, c) in full[..n].iter().zip(capped.iter()) {
            assert_eq!(f.bbox, c.bbox);
            assert_eq!(f.score, c.score);
        }
    }

    #[test]
    fn nms_float_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    #[test]
    fn nms_float_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_float(10);
        let result = nms_float(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    #[test]
    fn nms_float_max_det_larger_than_input() {
        let boxes = make_nms_boxes_float(5);
        let full = nms_float(0.5, None, boxes.clone());
        let capped = nms_float(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }

    #[test]
    fn nms_float_suppresses_overlapping_lower_score_box() {
        // IoU(a, b) = 81 / 119 ~= 0.68; c is disjoint from both.
        let a = det(0.0, 0.0, 10.0, 10.0, 0.9);
        let b = det(1.0, 1.0, 11.0, 11.0, 0.8);
        let c = det(50.0, 50.0, 60.0, 60.0, 0.7);
        // Deliberately unsorted input.
        let kept = nms_float(0.5, None, vec![b, c, a]);
        assert_eq!(kept.len(), 2);
        assert_eq!(kept[0].score, 0.9);
        assert_eq!(kept[0].bbox, bb(0.0, 0.0, 10.0, 10.0));
        assert_eq!(kept[1].score, 0.7);
        assert_eq!(kept[1].bbox, bb(50.0, 50.0, 60.0, 60.0));
    }

    #[test]
    fn class_aware_nms_only_suppresses_within_a_class() {
        let a = det(0.0, 0.0, 10.0, 10.0, 0.9);
        let mut b = det(1.0, 1.0, 11.0, 11.0, 0.8);
        b.label = 1;

        let agnostic = nms_float(0.5, None, vec![a.clone(), b.clone()]);
        assert_eq!(agnostic.len(), 1);
        assert_eq!(agnostic[0].score, 0.9);

        let aware = nms_class_aware_float(0.5, None, vec![a.clone(), b]);
        assert_eq!(aware.len(), 2);
        assert_eq!(aware[0].score, 0.9);
        assert_eq!(aware[1].score, 0.8);

        // Same class again: the overlap is suppressed.
        let same = nms_class_aware_float(0.5, None, vec![a, det(1.0, 1.0, 11.0, 11.0, 0.8)]);
        assert_eq!(same.len(), 1);
    }

    #[test]
    fn nms_extra_float_keeps_payload_paired_with_box() {
        let input = vec![
            (det(1.0, 1.0, 11.0, 11.0, 0.8), "b"),
            (det(50.0, 50.0, 60.0, 60.0, 0.7), "c"),
            (det(0.0, 0.0, 10.0, 10.0, 0.9), "a"),
        ];
        let kept = nms_extra_float(0.5, None, input);
        let tags: Vec<&str> = kept.iter().map(|(_, t)| *t).collect();
        assert_eq!(tags, ["a", "c"]);
    }

    #[test]
    fn nms_extra_class_aware_float_keeps_other_classes() {
        let mut b = det(1.0, 1.0, 11.0, 11.0, 0.8);
        b.label = 1;
        let input = vec![
            (det(1.0, 1.0, 11.0, 11.0, 0.7), "c"),
            (b, "b"),
            (det(0.0, 0.0, 10.0, 10.0, 0.9), "a"),
        ];
        let kept = nms_extra_class_aware_float(0.5, None, input);
        let tags: Vec<&str> = kept.iter().map(|(_, t)| *t).collect();
        assert_eq!(tags, ["a", "b"]);
    }

    #[test]
    fn jaccard_batch4_matches_scalar() {
        let a = BoundingBox::new(0.0, 0.0, 10.0, 10.0);
        let boxes = [
            BoundingBox::new(5.0, 5.0, 15.0, 15.0),
            BoundingBox::new(20.0, 20.0, 30.0, 30.0),
            BoundingBox::new(0.0, 0.0, 10.0, 10.0),
            BoundingBox::new(8.0, 8.0, 18.0, 18.0),
        ];
        let iou_threshold = 0.1;
        let batch = jaccard_batch4(&a, &boxes, iou_threshold);
        for (i, b) in boxes.iter().enumerate() {
            let scalar = jaccard(&a, b, iou_threshold);
            assert_eq!(
                batch[i], scalar,
                "batch4 mismatch at {i}: batch={} scalar={}",
                batch[i], scalar
            );
        }
    }
}