1#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7 arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11 Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
/// Folds one score column into the running per-candidate maximum using NEON.
///
/// For every index `i < n`: when `col[i] >= max[i]`, `max[i]` is overwritten with
/// `col[i]` and `class[i]` is overwritten with `class_idx`. Because the comparison is
/// `>=`, a tie is won by the column that is processed last.
///
/// The bulk of the data is handled 16 bytes at a time; the trailing `n % 16`
/// elements are handled by an equivalent scalar loop.
///
/// * `col_ptr`   - scores of the current class, one byte per candidate.
/// * `max_ptr`   - running maximum score per candidate (updated in place).
/// * `class_ptr` - class index that produced the running maximum (updated in place).
/// * `n`         - number of candidates (bytes) in each of the three buffers.
/// * `class_idx` - class index to record where the current column wins.
/// * `signed`    - `true` to compare the bytes as `i8`, `false` to compare as `u8`.
///
/// # Safety
///
/// * `col_ptr` must be valid for reads of `n` bytes.
/// * `max_ptr` and `class_ptr` must each be valid for reads and writes of `n` bytes.
/// * The function reads and writes through raw pointers, so the caller must make
///   sure no other reference observes `max_ptr` / `class_ptr` during the call.
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    // Every lane carries the class index so it can be blended in with a mask.
    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;

    if signed {
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // Lanes where the new column is >= the running maximum.
            let mask = vcgeq_s8(col, cur_max);
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            // Take `class_idx` in the winning lanes, keep the old class elsewhere.
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the elements that do not fill a whole vector.
        for i in (chunks * 16)..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            // Lanes where the new column is >= the running maximum.
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            // Take `class_idx` in the winning lanes, keep the old class elsewhere.
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the elements that do not fill a whole vector.
        for i in (chunks * 16)..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
/// Variant of the NEON column-max update that issues a software prefetch hint.
///
/// For every index `i < n`: when `col[i] >= max[i]`, `max[i]` becomes `col[i]` and
/// `class[i]` becomes `class_idx`. Before each 16-byte step a `prfm pldl1strm` hint is
/// issued for the column data 128 bytes further on, as long as that address is still
/// inside the column (`offset + 128 < n`).
///
/// `signed` selects whether the bytes are compared as `i8` (`true`) or `u8` (`false`).
///
/// # Safety
///
/// * `col_ptr` must be valid for reads of `n` bytes.
/// * `max_ptr` and `class_ptr` must each be valid for reads and writes of `n` bytes.
#[cfg(target_arch = "aarch64")]
// Not referenced by anything visible in this file, hence the explicit allow.
#[allow(dead_code)]
unsafe fn column_max_update_neon_prefetch(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    const PREFETCH_AHEAD: usize = 128;
    const LANES: usize = 16;

    /// Issues a streaming L1 prefetch hint for `addr`.
    #[inline(always)]
    unsafe fn prefetch_stream(addr: *const u8) {
        core::arch::asm!(
            "prfm pldl1strm, [{ptr}]",
            ptr = in(reg) addr,
            options(nostack, preserves_flags),
        );
    }

    let label_lanes = vdupq_n_u8(class_idx);
    // First index that is not covered by a full 16-byte vector.
    let vector_end = (n / LANES) * LANES;

    if signed {
        for base in (0..vector_end).step_by(LANES) {
            if base + PREFETCH_AHEAD < n {
                prefetch_stream(col_ptr.add(base + PREFETCH_AHEAD));
            }
            let incoming = vld1q_s8(col_ptr.add(base) as *const i8);
            let best = vld1q_s8(max_ptr.add(base) as *const i8);
            let wins = vcgeq_s8(incoming, best);
            vst1q_s8(max_ptr.add(base) as *mut i8, vmaxq_s8(incoming, best));
            let old_labels = vld1q_u8(class_ptr.add(base));
            vst1q_u8(class_ptr.add(base), vbslq_u8(wins, label_lanes, old_labels));
        }
        // Scalar tail.
        for i in vector_end..n {
            let incoming = *(col_ptr.add(i) as *const i8);
            let best = *(max_ptr.add(i) as *const i8);
            if incoming >= best {
                *(max_ptr.add(i) as *mut i8) = incoming;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        for base in (0..vector_end).step_by(LANES) {
            if base + PREFETCH_AHEAD < n {
                prefetch_stream(col_ptr.add(base + PREFETCH_AHEAD));
            }
            let incoming = vld1q_u8(col_ptr.add(base));
            let best = vld1q_u8(max_ptr.add(base));
            let wins = vcgeq_u8(incoming, best);
            vst1q_u8(max_ptr.add(base), vmaxq_u8(incoming, best));
            let old_labels = vld1q_u8(class_ptr.add(base));
            vst1q_u8(class_ptr.add(base), vbslq_u8(wins, label_lanes, old_labels));
        }
        // Scalar tail.
        for i in vector_end..n {
            let incoming = *col_ptr.add(i);
            let best = *max_ptr.add(i);
            if incoming >= best {
                *max_ptr.add(i) = incoming;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
178
179#[inline(always)]
181fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
182 #[cfg(target_arch = "aarch64")]
183 {
184 if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
186 let slice = score.as_slice().unwrap();
187 let ptr = slice.as_ptr() as *const i8;
191 let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
192 if T::min_value() < T::zero() {
195 let (max_val, idx) = arg_max_i8(i8_slice);
196 let result: T = unsafe { std::mem::transmute_copy(&max_val) };
199 return (result, idx);
200 }
201 }
202 }
203 arg_max(score)
204}
205
206#[doc(hidden)]
211pub fn postprocess_boxes_quant<
212 B: BBoxTypeTrait,
213 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
214 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
215>(
216 threshold: Scores,
217 boxes: ArrayView2<Boxes>,
218 scores: ArrayView2<Scores>,
219 quant_boxes: Quantization,
220) -> Vec<DetectBoxQuantized<Scores>> {
221 assert_eq!(scores.dim().0, boxes.dim().0);
222 assert_eq!(boxes.dim().1, 4);
223
224 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
226 return postprocess_boxes_quant_column_major::<B, _, _>(
227 threshold,
228 boxes,
229 scores,
230 quant_boxes,
231 );
232 }
233
234 Zip::from(scores.rows())
235 .and(boxes.rows())
236 .into_par_iter()
237 .filter_map(|(score, bbox)| {
238 let (score_, label) = fast_arg_max(score);
239 if score_ < threshold {
240 return None;
241 }
242
243 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
244 Some(DetectBoxQuantized {
245 label,
246 score: score_,
247 bbox: BoundingBox::from(bbox_quant),
248 })
249 })
250 .collect()
251}
252
253fn postprocess_boxes_quant_column_major<
255 B: BBoxTypeTrait,
256 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
257 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
258>(
259 threshold: Scores,
260 boxes: ArrayView2<Boxes>,
261 scores: ArrayView2<Scores>,
262 quant_boxes: Quantization,
263) -> Vec<DetectBoxQuantized<Scores>> {
264 let (n_candidates, n_classes) = scores.dim();
265
266 if n_classes > 255 {
268 return Zip::from(scores.rows())
269 .and(boxes.rows())
270 .into_par_iter()
271 .filter_map(|(score, bbox)| {
272 let (score_, label) = fast_arg_max(score);
273 if score_ < threshold {
274 return None;
275 }
276 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
277 Some(DetectBoxQuantized {
278 label,
279 score: score_,
280 bbox: BoundingBox::from(bbox_quant),
281 })
282 })
283 .collect();
284 }
285 let mut max_scores = vec![Scores::min_value(); n_candidates];
286 let mut max_classes = vec![0u8; n_candidates];
287
288 for class_idx in 0..n_classes {
289 let col = scores.column(class_idx);
290 if let Some(slice) = col.as_slice() {
291 #[cfg(target_arch = "aarch64")]
292 {
293 if std::mem::size_of::<Scores>() == 1 {
294 unsafe {
295 column_max_update_neon(
298 slice.as_ptr() as *const u8,
299 max_scores.as_mut_ptr() as *mut u8,
300 max_classes.as_mut_ptr(),
301 n_candidates,
302 class_idx as u8,
303 Scores::min_value() < Scores::zero(),
304 );
305 }
306 continue;
307 }
308 }
309 for (i, &val) in slice.iter().enumerate() {
310 if val >= max_scores[i] {
311 max_scores[i] = val;
312 max_classes[i] = class_idx as u8;
313 }
314 }
315 } else {
316 for (i, &val) in col.iter().enumerate() {
317 if val >= max_scores[i] {
318 max_scores[i] = val;
319 max_classes[i] = class_idx as u8;
320 }
321 }
322 }
323 }
324
325 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
327 let mut cols: [Vec<Boxes>; 4] = [
328 vec![Boxes::zero(); n_candidates],
329 vec![Boxes::zero(); n_candidates],
330 vec![Boxes::zero(); n_candidates],
331 vec![Boxes::zero(); n_candidates],
332 ];
333 for (dim, col_buf) in cols.iter_mut().enumerate() {
334 let col = boxes.column(dim);
335 if let Some(slice) = col.as_slice() {
336 col_buf.copy_from_slice(slice);
337 } else {
338 for (i, &val) in col.iter().enumerate() {
339 col_buf[i] = val;
340 }
341 }
342 }
343 cols
344 } else {
345 [vec![], vec![], vec![], vec![]]
346 };
347 let boxes_copied = !boxes_buf[0].is_empty();
348
349 let mut result = Vec::new();
350 for i in 0..n_candidates {
351 if max_scores[i] >= threshold {
352 let bbox_quant = if boxes_copied {
353 let raw = [
354 boxes_buf[0][i],
355 boxes_buf[1][i],
356 boxes_buf[2][i],
357 boxes_buf[3][i],
358 ];
359 B::to_xyxy_dequant(&raw, quant_boxes)
360 } else {
361 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
362 };
363 result.push(DetectBoxQuantized {
364 label: max_classes[i] as usize,
365 score: max_scores[i],
366 bbox: BoundingBox::from(bbox_quant),
367 });
368 }
369 }
370
371 result
372}
373
374#[doc(hidden)]
386pub fn postprocess_boxes_index_quant<
387 B: BBoxTypeTrait,
388 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
389 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
390>(
391 threshold: Scores,
392 boxes: ArrayView2<Boxes>,
393 scores: ArrayView2<Scores>,
394 quant_boxes: Quantization,
395) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
396 assert_eq!(scores.dim().0, boxes.dim().0);
397 assert_eq!(boxes.dim().1, 4);
398
399 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
403 return postprocess_boxes_index_quant_column_major::<B, _, _>(
404 threshold,
405 boxes,
406 scores,
407 quant_boxes,
408 );
409 }
410
411 let indices: Array1<usize> = (0..boxes.dim().0).collect();
412 Zip::from(scores.rows())
413 .and(boxes.rows())
414 .and(&indices)
415 .into_par_iter()
416 .filter_map(|(score, bbox, index)| {
417 let (score_, label) = fast_arg_max(score);
418 if score_ < threshold {
419 return None;
420 }
421
422 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
423
424 Some((
425 DetectBoxQuantized {
426 label,
427 score: score_,
428 bbox: BoundingBox::from(bbox_quant),
429 },
430 *index,
431 ))
432 })
433 .collect()
434}
435
436fn postprocess_boxes_index_quant_column_major<
444 B: BBoxTypeTrait,
445 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
446 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
447>(
448 threshold: Scores,
449 boxes: ArrayView2<Boxes>,
450 scores: ArrayView2<Scores>,
451 quant_boxes: Quantization,
452) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
453 let (n_candidates, n_classes) = scores.dim();
454
455 if n_classes > 255 {
459 let indices: Array1<usize> = (0..n_candidates).collect();
460 return Zip::from(scores.rows())
461 .and(boxes.rows())
462 .and(&indices)
463 .into_par_iter()
464 .filter_map(|(score, bbox, index)| {
465 let (score_, label) = fast_arg_max(score);
466 if score_ < threshold {
467 return None;
468 }
469 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
470 Some((
471 DetectBoxQuantized {
472 label,
473 score: score_,
474 bbox: BoundingBox::from(bbox_quant),
475 },
476 *index,
477 ))
478 })
479 .collect();
480 }
481 let mut max_scores = vec![Scores::min_value(); n_candidates];
482 let mut max_classes = vec![0u8; n_candidates];
483 for class_idx in 0..n_classes {
484 let col = scores.column(class_idx);
485 if let Some(slice) = col.as_slice() {
486 #[cfg(target_arch = "aarch64")]
488 {
489 if std::mem::size_of::<Scores>() == 1 {
490 unsafe {
493 column_max_update_neon(
494 slice.as_ptr() as *const u8,
495 max_scores.as_mut_ptr() as *mut u8,
496 max_classes.as_mut_ptr(),
497 n_candidates,
498 class_idx as u8,
499 Scores::min_value() < Scores::zero(), );
501 }
502 continue;
503 }
504 }
505 for (i, &val) in slice.iter().enumerate() {
506 if val >= max_scores[i] {
507 max_scores[i] = val;
508 max_classes[i] = class_idx as u8;
509 }
510 }
511 } else {
512 for (i, &val) in col.iter().enumerate() {
513 if val >= max_scores[i] {
514 max_scores[i] = val;
515 max_classes[i] = class_idx as u8;
516 }
517 }
518 }
519 }
520
521 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
525 let mut cols: [Vec<Boxes>; 4] = [
526 vec![Boxes::zero(); n_candidates],
527 vec![Boxes::zero(); n_candidates],
528 vec![Boxes::zero(); n_candidates],
529 vec![Boxes::zero(); n_candidates],
530 ];
531 for (dim, col_buf) in cols.iter_mut().enumerate() {
532 let col = boxes.column(dim);
533 if let Some(slice) = col.as_slice() {
534 col_buf.copy_from_slice(slice);
535 } else {
536 for (i, &val) in col.iter().enumerate() {
537 col_buf[i] = val;
538 }
539 }
540 }
541 cols
542 } else {
543 [vec![], vec![], vec![], vec![]]
545 };
546 let boxes_copied = !boxes_buf[0].is_empty();
547
548 let mut result = Vec::new();
550 for i in 0..n_candidates {
551 if max_scores[i] >= threshold {
552 let bbox_quant = if boxes_copied {
553 let raw = [
554 boxes_buf[0][i],
555 boxes_buf[1][i],
556 boxes_buf[2][i],
557 boxes_buf[3][i],
558 ];
559 B::to_xyxy_dequant(&raw, quant_boxes)
560 } else {
561 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
562 };
563 result.push((
564 DetectBoxQuantized {
565 label: max_classes[i] as usize,
566 score: max_scores[i],
567 bbox: BoundingBox::from(bbox_quant),
568 },
569 i,
570 ));
571 }
572 }
573
574 result
575}
576
577#[doc(hidden)]
580#[must_use]
581pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
582 iou: f32,
583 max_det: Option<usize>,
584 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
585) -> Vec<DetectBoxQuantized<SCORE>> {
586 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
590
591 if iou >= 1.0 {
594 return match max_det {
595 Some(n) => {
596 boxes.truncate(n);
597 boxes
598 }
599 None => boxes,
600 };
601 }
602
603 let min_val = SCORE::min_value();
604 let cap = max_det.unwrap_or(usize::MAX);
605 let mut survivors: usize = 0;
606 for i in 0..boxes.len() {
608 if boxes[i].score <= min_val {
609 continue;
611 }
612 for j in (i + 1)..boxes.len() {
613 if boxes[j].score <= min_val {
616 continue;
618 }
619
620 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
621 boxes[j].score = min_val;
623 }
624 }
625 survivors += 1;
626 if survivors >= cap {
627 break;
628 }
629 }
630 boxes
633 .into_iter()
634 .filter(|b| b.score > min_val)
635 .take(cap)
636 .collect()
637}
638
639#[doc(hidden)]
645#[must_use]
646pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
647 iou: f32,
648 max_det: Option<usize>,
649 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
650) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
651 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
654
655 if iou >= 1.0 {
658 return match max_det {
659 Some(n) => {
660 boxes.truncate(n);
661 boxes
662 }
663 None => boxes,
664 };
665 }
666
667 let min_val = SCORE::min_value();
668 let cap = max_det.unwrap_or(usize::MAX);
669 let mut survivors: usize = 0;
670 for i in 0..boxes.len() {
672 if boxes[i].0.score <= min_val {
673 continue;
675 }
676 for j in (i + 1)..boxes.len() {
677 if boxes[j].0.score <= min_val {
680 continue;
682 }
683 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
684 boxes[j].0.score = min_val;
686 }
687 }
688 survivors += 1;
689 if survivors >= cap {
690 break;
691 }
692 }
693
694 boxes
696 .into_iter()
697 .filter(|b| b.0.score > min_val)
698 .take(cap)
699 .collect()
700}
701
702#[doc(hidden)]
709#[must_use]
710pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
711 iou: f32,
712 max_det: Option<usize>,
713 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
714) -> Vec<DetectBoxQuantized<SCORE>> {
715 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
716
717 if iou >= 1.0 {
720 return match max_det {
721 Some(n) => {
722 boxes.truncate(n);
723 boxes
724 }
725 None => boxes,
726 };
727 }
728
729 let min_val = SCORE::min_value();
730 let cap = max_det.unwrap_or(usize::MAX);
731 let mut survivors: usize = 0;
732 for i in 0..boxes.len() {
733 if boxes[i].score <= min_val {
734 continue;
735 }
736 for j in (i + 1)..boxes.len() {
737 if boxes[j].score <= min_val {
738 continue;
739 }
740 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
742 boxes[j].score = min_val;
743 }
744 }
745 survivors += 1;
746 if survivors >= cap {
747 break;
748 }
749 }
750 boxes
751 .into_iter()
752 .filter(|b| b.score > min_val)
753 .take(cap)
754 .collect()
755}
756
757#[doc(hidden)]
763#[must_use]
764pub fn nms_extra_class_aware_int<
765 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
766 E: Send + Sync,
767>(
768 iou: f32,
769 max_det: Option<usize>,
770 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
771) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
772 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
773
774 if iou >= 1.0 {
777 return match max_det {
778 Some(n) => {
779 boxes.truncate(n);
780 boxes
781 }
782 None => boxes,
783 };
784 }
785
786 let min_val = SCORE::min_value();
787 let cap = max_det.unwrap_or(usize::MAX);
788 let mut survivors: usize = 0;
789 for i in 0..boxes.len() {
790 if boxes[i].0.score <= min_val {
791 continue;
792 }
793 for j in (i + 1)..boxes.len() {
794 if boxes[j].0.score <= min_val {
795 continue;
796 }
797 if boxes[j].0.label == boxes[i].0.label
799 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
800 {
801 boxes[j].0.score = min_val;
802 }
803 }
804 survivors += 1;
805 if survivors >= cap {
806 break;
807 }
808 }
809 boxes
810 .into_iter()
811 .filter(|b| b.0.score > min_val)
812 .take(cap)
813 .collect()
814}
815
816#[doc(hidden)]
831pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
832where
833 f32: AsPrimitive<T>,
834{
835 if quant.scale == 0.0 {
836 return T::max_value();
837 }
838 let v = (score / quant.scale + quant.zero_point as f32).ceil();
839 let v = v.clamp(T::min_value().as_(), T::max_value().as_());
840 v.as_()
841}
842
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Builds a box buffer stored physically as `(4, n_candidates)`, i.e. one contiguous
    /// row per coordinate. Transposing it yields a `(n_candidates, 4)` column-major view.
    fn physical_boxes(n_candidates: usize) -> Array2<i16> {
        let mut boxes = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes[[0, i]] = (i * 10) as i16;
            boxes[[1, i]] = (i * 20) as i16;
            boxes[[2, i]] = (i * 10 + 50) as i16;
            boxes[[3, i]] = (i * 20 + 100) as i16;
        }
        boxes
    }

    /// The column-major fast path must agree with the row-major path for `u8` scores.
    #[test]
    fn column_major_matches_row_major() {
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let mut scores_physical = Array2::<u8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c * 3 + i * 7) % 256) as u8;
            }
        }

        let boxes_physical = physical_boxes(n_candidates);

        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };

        let threshold: u8 = 10;

        // `reversed_axes().to_owned()` would keep the transposed (column-major) memory
        // layout, so the reference would silently take the column-major path as well.
        // Force a standard-layout copy so the row-major path is really exercised.
        let scores_contiguous = scores_physical.t().as_standard_layout().into_owned();
        let boxes_contiguous = boxes_physical.t().as_standard_layout().into_owned();
        assert!(
            scores_contiguous.as_slice().is_some(),
            "reference scores should be row-major"
        );
        assert!(
            boxes_contiguous.as_slice().is_some(),
            "reference boxes should be row-major"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    /// The column-major fast path must agree with the row-major path for `i8` scores.
    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let mut scores_physical = Array2::<i8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8;
            }
        }

        let boxes_physical = physical_boxes(n_candidates);

        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        // See `column_major_matches_row_major`: the reference must be a genuine
        // standard-layout copy, otherwise both calls take the column-major path.
        let scores_contiguous = scores_physical.t().as_standard_layout().into_owned();
        let boxes_contiguous = boxes_physical.t().as_standard_layout().into_owned();
        assert!(
            scores_contiguous.as_slice().is_some(),
            "reference scores should be row-major"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }

    /// Produces `n` boxes that are 10x10, spaced 100 apart on the x axis, all with label 0
    /// and scores counting down from 200.
    fn make_nms_boxes_int(n: usize) -> Vec<DetectBoxQuantized<u8>> {
        (0..n)
            .map(|i| DetectBoxQuantized {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                score: (200 - i as u32).min(255) as u8,
            })
            .collect()
    }

    /// Capping with `max_det` must yield the same boxes as truncating the uncapped result.
    #[test]
    fn nms_int_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_int(20);
        let n = 5;
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        assert_eq!(&full[..n], &capped[..]);
    }

    /// `max_det == 0` returns nothing.
    #[test]
    fn nms_int_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    /// With `iou >= 1.0` the result is just the score-sorted input, truncated.
    #[test]
    fn nms_int_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    /// A cap larger than the input does not change the result size.
    #[test]
    fn nms_int_max_det_larger_than_input() {
        let boxes = make_nms_boxes_int(5);
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }
}