1#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7 arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11 Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
#[cfg(target_arch = "aarch64")]
/// Merges one class column into the running per-candidate maximum using NEON.
///
/// For every `i < n`: when `col[i] >= max[i]`, `max[i]` is overwritten with
/// `col[i]` and `class[i]` with `class_idx`; otherwise both are left untouched.
/// Because the comparison is `>=`, a tie is won by the column passed in last.
/// Bytes are compared as `i8` when `signed` is true and as `u8` otherwise.
///
/// The leading `n - n % 16` elements are handled 16 lanes at a time; the
/// remaining elements go through a scalar loop with the same comparison.
///
/// # Safety
///
/// * `col_ptr` must be valid for reads of `n` bytes.
/// * `max_ptr` and `class_ptr` must each be valid for reads and writes of `n`
///   bytes.
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    const LANES: usize = 16;

    let class_vec = vdupq_n_u8(class_idx);
    // Number of leading elements that fill whole 16-byte vectors.
    let vector_len = n - n % LANES;

    if signed {
        for offset in (0..vector_len).step_by(LANES) {
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // Lanes where the column value is >= the running maximum.
            let mask = vcgeq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, vmaxq_s8(col, cur_max));
            let cur_class = vld1q_u8(class_ptr.add(offset));
            // Selected lanes take `class_idx`; the others keep their class.
            vst1q_u8(class_ptr.add(offset), vbslq_u8(mask, class_vec, cur_class));
        }
        // Scalar tail for the elements that do not fill a whole vector.
        for i in vector_len..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        for offset in (0..vector_len).step_by(LANES) {
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            // Lanes where the column value is >= the running maximum.
            let mask = vcgeq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), vmaxq_u8(col, cur_max));
            let cur_class = vld1q_u8(class_ptr.add(offset));
            // Selected lanes take `class_idx`; the others keep their class.
            vst1q_u8(class_ptr.add(offset), vbslq_u8(mask, class_vec, cur_class));
        }
        // Scalar tail for the elements that do not fill a whole vector.
        for i in vector_len..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
92#[inline(always)]
94fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
95 #[cfg(target_arch = "aarch64")]
96 {
97 if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
99 let slice = score.as_slice().unwrap();
100 let ptr = slice.as_ptr() as *const i8;
104 let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
105 if T::min_value() < T::zero() {
108 let (max_val, idx) = arg_max_i8(i8_slice);
109 let result: T = unsafe { std::mem::transmute_copy(&max_val) };
112 return (result, idx);
113 }
114 }
115 }
116 arg_max(score)
117}
118
119#[doc(hidden)]
124pub fn postprocess_boxes_quant<
125 B: BBoxTypeTrait,
126 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
127 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
128>(
129 threshold: Scores,
130 boxes: ArrayView2<Boxes>,
131 scores: ArrayView2<Scores>,
132 quant_boxes: Quantization,
133) -> Vec<DetectBoxQuantized<Scores>> {
134 assert_eq!(scores.dim().0, boxes.dim().0);
135 assert_eq!(boxes.dim().1, 4);
136
137 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
139 return postprocess_boxes_quant_column_major::<B, _, _>(
140 threshold,
141 boxes,
142 scores,
143 quant_boxes,
144 );
145 }
146
147 Zip::from(scores.rows())
148 .and(boxes.rows())
149 .into_par_iter()
150 .filter_map(|(score, bbox)| {
151 let (score_, label) = fast_arg_max(score);
152 if score_ < threshold {
153 return None;
154 }
155
156 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
157 Some(DetectBoxQuantized {
158 label,
159 score: score_,
160 bbox: BoundingBox::from(bbox_quant),
161 })
162 })
163 .collect()
164}
165
166fn postprocess_boxes_quant_column_major<
168 B: BBoxTypeTrait,
169 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
170 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
171>(
172 threshold: Scores,
173 boxes: ArrayView2<Boxes>,
174 scores: ArrayView2<Scores>,
175 quant_boxes: Quantization,
176) -> Vec<DetectBoxQuantized<Scores>> {
177 let (n_candidates, n_classes) = scores.dim();
178
179 if n_classes > 255 {
181 return Zip::from(scores.rows())
182 .and(boxes.rows())
183 .into_par_iter()
184 .filter_map(|(score, bbox)| {
185 let (score_, label) = fast_arg_max(score);
186 if score_ < threshold {
187 return None;
188 }
189 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
190 Some(DetectBoxQuantized {
191 label,
192 score: score_,
193 bbox: BoundingBox::from(bbox_quant),
194 })
195 })
196 .collect();
197 }
198 let mut max_scores = vec![Scores::min_value(); n_candidates];
199 let mut max_classes = vec![0u8; n_candidates];
200
201 for class_idx in 0..n_classes {
202 let col = scores.column(class_idx);
203 if let Some(slice) = col.as_slice() {
204 #[cfg(target_arch = "aarch64")]
205 {
206 if std::mem::size_of::<Scores>() == 1 {
207 unsafe {
208 column_max_update_neon(
209 slice.as_ptr() as *const u8,
210 max_scores.as_mut_ptr() as *mut u8,
211 max_classes.as_mut_ptr(),
212 n_candidates,
213 class_idx as u8,
214 Scores::min_value() < Scores::zero(),
215 );
216 }
217 continue;
218 }
219 }
220 for (i, &val) in slice.iter().enumerate() {
221 if val >= max_scores[i] {
222 max_scores[i] = val;
223 max_classes[i] = class_idx as u8;
224 }
225 }
226 } else {
227 for (i, &val) in col.iter().enumerate() {
228 if val >= max_scores[i] {
229 max_scores[i] = val;
230 max_classes[i] = class_idx as u8;
231 }
232 }
233 }
234 }
235
236 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
238 let mut cols: [Vec<Boxes>; 4] = [
239 vec![Boxes::zero(); n_candidates],
240 vec![Boxes::zero(); n_candidates],
241 vec![Boxes::zero(); n_candidates],
242 vec![Boxes::zero(); n_candidates],
243 ];
244 for (dim, col_buf) in cols.iter_mut().enumerate() {
245 let col = boxes.column(dim);
246 if let Some(slice) = col.as_slice() {
247 col_buf.copy_from_slice(slice);
248 } else {
249 for (i, &val) in col.iter().enumerate() {
250 col_buf[i] = val;
251 }
252 }
253 }
254 cols
255 } else {
256 [vec![], vec![], vec![], vec![]]
257 };
258 let boxes_copied = !boxes_buf[0].is_empty();
259
260 let mut result = Vec::new();
261 for i in 0..n_candidates {
262 if max_scores[i] >= threshold {
263 let bbox_quant = if boxes_copied {
264 let raw = [
265 boxes_buf[0][i],
266 boxes_buf[1][i],
267 boxes_buf[2][i],
268 boxes_buf[3][i],
269 ];
270 B::to_xyxy_dequant(&raw, quant_boxes)
271 } else {
272 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
273 };
274 result.push(DetectBoxQuantized {
275 label: max_classes[i] as usize,
276 score: max_scores[i],
277 bbox: BoundingBox::from(bbox_quant),
278 });
279 }
280 }
281
282 result
283}
284
285#[doc(hidden)]
297pub fn postprocess_boxes_index_quant<
298 B: BBoxTypeTrait,
299 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
300 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
301>(
302 threshold: Scores,
303 boxes: ArrayView2<Boxes>,
304 scores: ArrayView2<Scores>,
305 quant_boxes: Quantization,
306) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
307 assert_eq!(scores.dim().0, boxes.dim().0);
308 assert_eq!(boxes.dim().1, 4);
309
310 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
314 return postprocess_boxes_index_quant_column_major::<B, _, _>(
315 threshold,
316 boxes,
317 scores,
318 quant_boxes,
319 );
320 }
321
322 let indices: Array1<usize> = (0..boxes.dim().0).collect();
323 Zip::from(scores.rows())
324 .and(boxes.rows())
325 .and(&indices)
326 .into_par_iter()
327 .filter_map(|(score, bbox, index)| {
328 let (score_, label) = fast_arg_max(score);
329 if score_ < threshold {
330 return None;
331 }
332
333 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
334
335 Some((
336 DetectBoxQuantized {
337 label,
338 score: score_,
339 bbox: BoundingBox::from(bbox_quant),
340 },
341 *index,
342 ))
343 })
344 .collect()
345}
346
347fn postprocess_boxes_index_quant_column_major<
355 B: BBoxTypeTrait,
356 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
357 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
358>(
359 threshold: Scores,
360 boxes: ArrayView2<Boxes>,
361 scores: ArrayView2<Scores>,
362 quant_boxes: Quantization,
363) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
364 let (n_candidates, n_classes) = scores.dim();
365
366 if n_classes > 255 {
370 let indices: Array1<usize> = (0..n_candidates).collect();
371 return Zip::from(scores.rows())
372 .and(boxes.rows())
373 .and(&indices)
374 .into_par_iter()
375 .filter_map(|(score, bbox, index)| {
376 let (score_, label) = fast_arg_max(score);
377 if score_ < threshold {
378 return None;
379 }
380 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
381 Some((
382 DetectBoxQuantized {
383 label,
384 score: score_,
385 bbox: BoundingBox::from(bbox_quant),
386 },
387 *index,
388 ))
389 })
390 .collect();
391 }
392 let mut max_scores = vec![Scores::min_value(); n_candidates];
393 let mut max_classes = vec![0u8; n_candidates];
394 for class_idx in 0..n_classes {
395 let col = scores.column(class_idx);
396 if let Some(slice) = col.as_slice() {
397 #[cfg(target_arch = "aarch64")]
399 {
400 if std::mem::size_of::<Scores>() == 1 {
401 unsafe {
404 column_max_update_neon(
405 slice.as_ptr() as *const u8,
406 max_scores.as_mut_ptr() as *mut u8,
407 max_classes.as_mut_ptr(),
408 n_candidates,
409 class_idx as u8,
410 Scores::min_value() < Scores::zero(), );
412 }
413 continue;
414 }
415 }
416 for (i, &val) in slice.iter().enumerate() {
417 if val >= max_scores[i] {
418 max_scores[i] = val;
419 max_classes[i] = class_idx as u8;
420 }
421 }
422 } else {
423 for (i, &val) in col.iter().enumerate() {
424 if val >= max_scores[i] {
425 max_scores[i] = val;
426 max_classes[i] = class_idx as u8;
427 }
428 }
429 }
430 }
431
432 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
436 let mut cols: [Vec<Boxes>; 4] = [
437 vec![Boxes::zero(); n_candidates],
438 vec![Boxes::zero(); n_candidates],
439 vec![Boxes::zero(); n_candidates],
440 vec![Boxes::zero(); n_candidates],
441 ];
442 for (dim, col_buf) in cols.iter_mut().enumerate() {
443 let col = boxes.column(dim);
444 if let Some(slice) = col.as_slice() {
445 col_buf.copy_from_slice(slice);
446 } else {
447 for (i, &val) in col.iter().enumerate() {
448 col_buf[i] = val;
449 }
450 }
451 }
452 cols
453 } else {
454 [vec![], vec![], vec![], vec![]]
456 };
457 let boxes_copied = !boxes_buf[0].is_empty();
458
459 let mut result = Vec::new();
461 for i in 0..n_candidates {
462 if max_scores[i] >= threshold {
463 let bbox_quant = if boxes_copied {
464 let raw = [
465 boxes_buf[0][i],
466 boxes_buf[1][i],
467 boxes_buf[2][i],
468 boxes_buf[3][i],
469 ];
470 B::to_xyxy_dequant(&raw, quant_boxes)
471 } else {
472 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
473 };
474 result.push((
475 DetectBoxQuantized {
476 label: max_classes[i] as usize,
477 score: max_scores[i],
478 bbox: BoundingBox::from(bbox_quant),
479 },
480 i,
481 ));
482 }
483 }
484
485 result
486}
487
488#[doc(hidden)]
491#[must_use]
492pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
493 iou: f32,
494 max_det: Option<usize>,
495 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
496) -> Vec<DetectBoxQuantized<SCORE>> {
497 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
501
502 if iou >= 1.0 {
505 return match max_det {
506 Some(n) => {
507 boxes.truncate(n);
508 boxes
509 }
510 None => boxes,
511 };
512 }
513
514 let min_val = SCORE::min_value();
515 let cap = max_det.unwrap_or(usize::MAX);
516 let mut survivors: usize = 0;
517 for i in 0..boxes.len() {
519 if boxes[i].score <= min_val {
520 continue;
522 }
523 for j in (i + 1)..boxes.len() {
524 if boxes[j].score <= min_val {
527 continue;
529 }
530
531 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
532 boxes[j].score = min_val;
534 }
535 }
536 survivors += 1;
537 if survivors >= cap {
538 break;
539 }
540 }
541 boxes
544 .into_iter()
545 .filter(|b| b.score > min_val)
546 .take(cap)
547 .collect()
548}
549
550#[doc(hidden)]
556#[must_use]
557pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
558 iou: f32,
559 max_det: Option<usize>,
560 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
561) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
562 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
565
566 if iou >= 1.0 {
569 return match max_det {
570 Some(n) => {
571 boxes.truncate(n);
572 boxes
573 }
574 None => boxes,
575 };
576 }
577
578 let min_val = SCORE::min_value();
579 let cap = max_det.unwrap_or(usize::MAX);
580 let mut survivors: usize = 0;
581 for i in 0..boxes.len() {
583 if boxes[i].0.score <= min_val {
584 continue;
586 }
587 for j in (i + 1)..boxes.len() {
588 if boxes[j].0.score <= min_val {
591 continue;
593 }
594 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
595 boxes[j].0.score = min_val;
597 }
598 }
599 survivors += 1;
600 if survivors >= cap {
601 break;
602 }
603 }
604
605 boxes
607 .into_iter()
608 .filter(|b| b.0.score > min_val)
609 .take(cap)
610 .collect()
611}
612
613#[doc(hidden)]
620#[must_use]
621pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
622 iou: f32,
623 max_det: Option<usize>,
624 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
625) -> Vec<DetectBoxQuantized<SCORE>> {
626 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
627
628 if iou >= 1.0 {
631 return match max_det {
632 Some(n) => {
633 boxes.truncate(n);
634 boxes
635 }
636 None => boxes,
637 };
638 }
639
640 let min_val = SCORE::min_value();
641 let cap = max_det.unwrap_or(usize::MAX);
642 let mut survivors: usize = 0;
643 for i in 0..boxes.len() {
644 if boxes[i].score <= min_val {
645 continue;
646 }
647 for j in (i + 1)..boxes.len() {
648 if boxes[j].score <= min_val {
649 continue;
650 }
651 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
653 boxes[j].score = min_val;
654 }
655 }
656 survivors += 1;
657 if survivors >= cap {
658 break;
659 }
660 }
661 boxes
662 .into_iter()
663 .filter(|b| b.score > min_val)
664 .take(cap)
665 .collect()
666}
667
668#[doc(hidden)]
674#[must_use]
675pub fn nms_extra_class_aware_int<
676 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
677 E: Send + Sync,
678>(
679 iou: f32,
680 max_det: Option<usize>,
681 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
682) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
683 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
684
685 if iou >= 1.0 {
688 return match max_det {
689 Some(n) => {
690 boxes.truncate(n);
691 boxes
692 }
693 None => boxes,
694 };
695 }
696
697 let min_val = SCORE::min_value();
698 let cap = max_det.unwrap_or(usize::MAX);
699 let mut survivors: usize = 0;
700 for i in 0..boxes.len() {
701 if boxes[i].0.score <= min_val {
702 continue;
703 }
704 for j in (i + 1)..boxes.len() {
705 if boxes[j].0.score <= min_val {
706 continue;
707 }
708 if boxes[j].0.label == boxes[i].0.label
710 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
711 {
712 boxes[j].0.score = min_val;
713 }
714 }
715 survivors += 1;
716 if survivors >= cap {
717 break;
718 }
719 }
720 boxes
721 .into_iter()
722 .filter(|b| b.0.score > min_val)
723 .take(cap)
724 .collect()
725}
726
727#[doc(hidden)]
742pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
743where
744 f32: AsPrimitive<T>,
745{
746 if quant.scale == 0.0 {
747 return T::max_value();
748 }
749 let v = (score / quant.scale + quant.zero_point as f32).ceil();
750 let v = v.clamp(T::min_value().as_(), T::max_value().as_());
751 v.as_()
752}
753
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Builds a `(4, n_candidates)` box tensor whose four rows are simple
    /// linear ramps of the candidate index.
    fn class_major_boxes(n_candidates: usize) -> Array2<i16> {
        Array2::from_shape_fn((4, n_candidates), |(dim, i)| {
            let value = match dim {
                0 => i * 10,
                1 => i * 20,
                2 => i * 10 + 50,
                _ => i * 20 + 100,
            };
            value as i16
        })
    }

    /// Copies the transpose of `physical` into a freshly allocated array in
    /// standard (row-major) layout.
    ///
    /// `reversed_axes().to_owned()` is deliberately not used here: `to_owned`
    /// keeps the memory layout of a contiguous input, so the result would still
    /// have unit stride on axis 0 and would be routed to the column-major path.
    fn transposed_row_major<T: Copy>(physical: &Array2<T>) -> Array2<T> {
        let (rows, cols) = physical.dim();
        Array2::from_shape_fn((cols, rows), |(r, c)| physical[[c, r]])
    }

    /// The column-major fast path must agree with the row-major path on `u8`
    /// scores: same detections, labels, scores, indices and boxes.
    #[test]
    fn column_major_matches_row_major() {
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let scores_physical = Array2::from_shape_fn((n_classes, n_candidates), |(c, i)| {
            ((c * 3 + i * 7) % 256) as u8
        });
        let boxes_physical = class_major_boxes(n_candidates);

        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };
        let threshold: u8 = 10;

        // Reference: the same data in standard layout, which takes the row path.
        let scores_row_major = transposed_row_major(&scores_physical);
        let boxes_row_major = transposed_row_major(&boxes_physical);
        assert!(
            scores_row_major.as_slice().is_some(),
            "reference scores should be in standard layout"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_row_major.view(),
            scores_row_major.view(),
            quant,
        );

        // Transposed views over the physical buffers take the column-major path.
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    /// Same comparison for signed `i8` scores.
    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let scores_physical = Array2::from_shape_fn((n_classes, n_candidates), |(c, i)| {
            ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8
        });
        let boxes_physical = class_major_boxes(n_candidates);

        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        // Reference: the same data in standard layout, which takes the row path.
        let scores_row_major = transposed_row_major(&scores_physical);
        let boxes_row_major = transposed_row_major(&boxes_physical);
        assert!(
            scores_row_major.as_slice().is_some(),
            "reference scores should be in standard layout"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_row_major.view(),
            scores_row_major.view(),
            quant,
        );

        // Transposed views over the physical buffers take the column-major path.
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }

    /// Builds `n` detections of label 0 whose boxes are 100 units apart on the
    /// x axis (so they never overlap) and whose scores count down from 200.
    fn make_nms_boxes_int(n: usize) -> Vec<DetectBoxQuantized<u8>> {
        (0..n)
            .map(|i| DetectBoxQuantized {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: (200 - i as u32).min(255) as u8,
            })
            .collect()
    }

    /// Capping with `max_det` must give the same boxes as truncating the
    /// uncapped result.
    #[test]
    fn nms_int_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_int(20);
        let n = 5;
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        assert_eq!(&full[..n], &capped[..]);
    }

    /// `max_det == 0` yields no detections.
    #[test]
    fn nms_int_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    /// With `iou >= 1.0` the result is the sorted input truncated to `max_det`.
    #[test]
    fn nms_int_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    /// A cap larger than the input changes nothing.
    #[test]
    fn nms_int_max_det_larger_than_input() {
        let boxes = make_nms_boxes_int(5);
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }
}