1#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7 arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11 Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
/// Folds one class column into the running per-candidate maximum using NEON.
///
/// For every candidate `i < n`: if `col[i] >= max[i]` then `max[i] = col[i]` and
/// `class[i] = class_idx`, so on ties the later class wins. When `signed` is true the
/// bytes are compared as `i8`, otherwise as `u8`. Full 16-byte chunks go through NEON;
/// the trailing `n % 16` elements are handled by a scalar loop with the same rule.
///
/// # Safety
///
/// * `col_ptr` must be valid for reads of `n` bytes.
/// * `max_ptr` and `class_ptr` must each be valid for reads and writes of `n` bytes,
///   and the two ranges must not overlap.
/// * No alignment is required (only unaligned `vld1q`/`vst1q` loads/stores are used).
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;
    let tail_start = chunks * 16;

    if signed {
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // Lanes where the new column wins (ties included) take the new class.
            let mask = vcgeq_s8(col, cur_max);
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the last `n % 16` elements.
        for i in tail_start..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            // Lanes where the new column wins (ties included) take the new class.
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the last `n % 16` elements.
        for i in tail_start..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
92#[inline(always)]
94fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
95 #[cfg(target_arch = "aarch64")]
96 {
97 if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
99 let slice = score.as_slice().unwrap();
100 let ptr = slice.as_ptr() as *const i8;
104 let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
105 if T::min_value() < T::zero() {
108 let (max_val, idx) = arg_max_i8(i8_slice);
109 let result: T = unsafe { std::mem::transmute_copy(&max_val) };
112 return (result, idx);
113 }
114 }
115 }
116 arg_max(score)
117}
118
119#[doc(hidden)]
124pub fn postprocess_boxes_quant<
125 B: BBoxTypeTrait,
126 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
127 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
128>(
129 threshold: Scores,
130 boxes: ArrayView2<Boxes>,
131 scores: ArrayView2<Scores>,
132 quant_boxes: Quantization,
133) -> Vec<DetectBoxQuantized<Scores>> {
134 assert_eq!(scores.dim().0, boxes.dim().0);
135 assert_eq!(boxes.dim().1, 4);
136
137 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
139 return postprocess_boxes_quant_column_major::<B, _, _>(
140 threshold,
141 boxes,
142 scores,
143 quant_boxes,
144 );
145 }
146
147 Zip::from(scores.rows())
148 .and(boxes.rows())
149 .into_par_iter()
150 .filter_map(|(score, bbox)| {
151 let (score_, label) = fast_arg_max(score);
152 if score_ < threshold {
153 return None;
154 }
155
156 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
157 Some(DetectBoxQuantized {
158 label,
159 score: score_,
160 bbox: BoundingBox::from(bbox_quant),
161 })
162 })
163 .collect()
164}
165
166fn postprocess_boxes_quant_column_major<
168 B: BBoxTypeTrait,
169 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
170 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
171>(
172 threshold: Scores,
173 boxes: ArrayView2<Boxes>,
174 scores: ArrayView2<Scores>,
175 quant_boxes: Quantization,
176) -> Vec<DetectBoxQuantized<Scores>> {
177 let (n_candidates, n_classes) = scores.dim();
178
179 if n_classes > 255 {
181 return Zip::from(scores.rows())
182 .and(boxes.rows())
183 .into_par_iter()
184 .filter_map(|(score, bbox)| {
185 let (score_, label) = fast_arg_max(score);
186 if score_ < threshold {
187 return None;
188 }
189 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
190 Some(DetectBoxQuantized {
191 label,
192 score: score_,
193 bbox: BoundingBox::from(bbox_quant),
194 })
195 })
196 .collect();
197 }
198 let mut max_scores = vec![Scores::min_value(); n_candidates];
199 let mut max_classes = vec![0u8; n_candidates];
200
201 for class_idx in 0..n_classes {
202 let col = scores.column(class_idx);
203 if let Some(slice) = col.as_slice() {
204 #[cfg(target_arch = "aarch64")]
205 {
206 if std::mem::size_of::<Scores>() == 1 {
207 unsafe {
208 column_max_update_neon(
209 slice.as_ptr() as *const u8,
210 max_scores.as_mut_ptr() as *mut u8,
211 max_classes.as_mut_ptr(),
212 n_candidates,
213 class_idx as u8,
214 Scores::min_value() < Scores::zero(),
215 );
216 }
217 continue;
218 }
219 }
220 for (i, &val) in slice.iter().enumerate() {
221 if val >= max_scores[i] {
222 max_scores[i] = val;
223 max_classes[i] = class_idx as u8;
224 }
225 }
226 } else {
227 for (i, &val) in col.iter().enumerate() {
228 if val >= max_scores[i] {
229 max_scores[i] = val;
230 max_classes[i] = class_idx as u8;
231 }
232 }
233 }
234 }
235
236 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
238 let mut cols: [Vec<Boxes>; 4] = [
239 vec![Boxes::zero(); n_candidates],
240 vec![Boxes::zero(); n_candidates],
241 vec![Boxes::zero(); n_candidates],
242 vec![Boxes::zero(); n_candidates],
243 ];
244 for (dim, col_buf) in cols.iter_mut().enumerate() {
245 let col = boxes.column(dim);
246 if let Some(slice) = col.as_slice() {
247 col_buf.copy_from_slice(slice);
248 } else {
249 for (i, &val) in col.iter().enumerate() {
250 col_buf[i] = val;
251 }
252 }
253 }
254 cols
255 } else {
256 [vec![], vec![], vec![], vec![]]
257 };
258 let boxes_copied = !boxes_buf[0].is_empty();
259
260 let mut result = Vec::new();
261 for i in 0..n_candidates {
262 if max_scores[i] >= threshold {
263 let bbox_quant = if boxes_copied {
264 let raw = [
265 boxes_buf[0][i],
266 boxes_buf[1][i],
267 boxes_buf[2][i],
268 boxes_buf[3][i],
269 ];
270 B::to_xyxy_dequant(&raw, quant_boxes)
271 } else {
272 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
273 };
274 result.push(DetectBoxQuantized {
275 label: max_classes[i] as usize,
276 score: max_scores[i],
277 bbox: BoundingBox::from(bbox_quant),
278 });
279 }
280 }
281
282 result
283}
284
285#[doc(hidden)]
297pub fn postprocess_boxes_index_quant<
298 B: BBoxTypeTrait,
299 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
300 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
301>(
302 threshold: Scores,
303 boxes: ArrayView2<Boxes>,
304 scores: ArrayView2<Scores>,
305 quant_boxes: Quantization,
306) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
307 assert_eq!(scores.dim().0, boxes.dim().0);
308 assert_eq!(boxes.dim().1, 4);
309
310 if scores.strides()[0] == 1 && scores.as_slice().is_none() {
314 return postprocess_boxes_index_quant_column_major::<B, _, _>(
315 threshold,
316 boxes,
317 scores,
318 quant_boxes,
319 );
320 }
321
322 let indices: Array1<usize> = (0..boxes.dim().0).collect();
323 Zip::from(scores.rows())
324 .and(boxes.rows())
325 .and(&indices)
326 .into_par_iter()
327 .filter_map(|(score, bbox, index)| {
328 let (score_, label) = fast_arg_max(score);
329 if score_ < threshold {
330 return None;
331 }
332
333 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
334
335 Some((
336 DetectBoxQuantized {
337 label,
338 score: score_,
339 bbox: BoundingBox::from(bbox_quant),
340 },
341 *index,
342 ))
343 })
344 .collect()
345}
346
347fn postprocess_boxes_index_quant_column_major<
355 B: BBoxTypeTrait,
356 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
357 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
358>(
359 threshold: Scores,
360 boxes: ArrayView2<Boxes>,
361 scores: ArrayView2<Scores>,
362 quant_boxes: Quantization,
363) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
364 let (n_candidates, n_classes) = scores.dim();
365
366 if n_classes > 255 {
370 let indices: Array1<usize> = (0..n_candidates).collect();
371 return Zip::from(scores.rows())
372 .and(boxes.rows())
373 .and(&indices)
374 .into_par_iter()
375 .filter_map(|(score, bbox, index)| {
376 let (score_, label) = fast_arg_max(score);
377 if score_ < threshold {
378 return None;
379 }
380 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
381 Some((
382 DetectBoxQuantized {
383 label,
384 score: score_,
385 bbox: BoundingBox::from(bbox_quant),
386 },
387 *index,
388 ))
389 })
390 .collect();
391 }
392 let mut max_scores = vec![Scores::min_value(); n_candidates];
393 let mut max_classes = vec![0u8; n_candidates];
394 for class_idx in 0..n_classes {
395 let col = scores.column(class_idx);
396 if let Some(slice) = col.as_slice() {
397 #[cfg(target_arch = "aarch64")]
399 {
400 if std::mem::size_of::<Scores>() == 1 {
401 unsafe {
404 column_max_update_neon(
405 slice.as_ptr() as *const u8,
406 max_scores.as_mut_ptr() as *mut u8,
407 max_classes.as_mut_ptr(),
408 n_candidates,
409 class_idx as u8,
410 Scores::min_value() < Scores::zero(), );
412 }
413 continue;
414 }
415 }
416 for (i, &val) in slice.iter().enumerate() {
417 if val >= max_scores[i] {
418 max_scores[i] = val;
419 max_classes[i] = class_idx as u8;
420 }
421 }
422 } else {
423 for (i, &val) in col.iter().enumerate() {
424 if val >= max_scores[i] {
425 max_scores[i] = val;
426 max_classes[i] = class_idx as u8;
427 }
428 }
429 }
430 }
431
432 let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
436 let mut cols: [Vec<Boxes>; 4] = [
437 vec![Boxes::zero(); n_candidates],
438 vec![Boxes::zero(); n_candidates],
439 vec![Boxes::zero(); n_candidates],
440 vec![Boxes::zero(); n_candidates],
441 ];
442 for (dim, col_buf) in cols.iter_mut().enumerate() {
443 let col = boxes.column(dim);
444 if let Some(slice) = col.as_slice() {
445 col_buf.copy_from_slice(slice);
446 } else {
447 for (i, &val) in col.iter().enumerate() {
448 col_buf[i] = val;
449 }
450 }
451 }
452 cols
453 } else {
454 [vec![], vec![], vec![], vec![]]
456 };
457 let boxes_copied = !boxes_buf[0].is_empty();
458
459 let mut result = Vec::new();
461 for i in 0..n_candidates {
462 if max_scores[i] >= threshold {
463 let bbox_quant = if boxes_copied {
464 let raw = [
465 boxes_buf[0][i],
466 boxes_buf[1][i],
467 boxes_buf[2][i],
468 boxes_buf[3][i],
469 ];
470 B::to_xyxy_dequant(&raw, quant_boxes)
471 } else {
472 B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
473 };
474 result.push((
475 DetectBoxQuantized {
476 label: max_classes[i] as usize,
477 score: max_scores[i],
478 bbox: BoundingBox::from(bbox_quant),
479 },
480 i,
481 ));
482 }
483 }
484
485 result
486}
487
488#[doc(hidden)]
491#[must_use]
492pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
493 iou: f32,
494 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
495) -> Vec<DetectBoxQuantized<SCORE>> {
496 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
500
501 if iou >= 1.0 {
504 return boxes;
505 }
506
507 let min_val = SCORE::min_value();
508 for i in 0..boxes.len() {
510 if boxes[i].score <= min_val {
511 continue;
513 }
514 for j in (i + 1)..boxes.len() {
515 if boxes[j].score <= min_val {
518 continue;
520 }
521
522 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
523 boxes[j].score = min_val;
525 }
526 }
527 }
528 boxes.into_iter().filter(|b| b.score > min_val).collect()
530}
531
532#[doc(hidden)]
538#[must_use]
539pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
540 iou: f32,
541 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
542) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
543 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
546
547 if iou >= 1.0 {
550 return boxes;
551 }
552
553 let min_val = SCORE::min_value();
554 for i in 0..boxes.len() {
556 if boxes[i].0.score <= min_val {
557 continue;
559 }
560 for j in (i + 1)..boxes.len() {
561 if boxes[j].0.score <= min_val {
564 continue;
566 }
567 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
568 boxes[j].0.score = min_val;
570 }
571 }
572 }
573
574 boxes.into_iter().filter(|b| b.0.score > min_val).collect()
576}
577
578#[doc(hidden)]
585#[must_use]
586pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
587 iou: f32,
588 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
589) -> Vec<DetectBoxQuantized<SCORE>> {
590 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
591
592 if iou >= 1.0 {
595 return boxes;
596 }
597
598 let min_val = SCORE::min_value();
599 for i in 0..boxes.len() {
600 if boxes[i].score <= min_val {
601 continue;
602 }
603 for j in (i + 1)..boxes.len() {
604 if boxes[j].score <= min_val {
605 continue;
606 }
607 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
609 boxes[j].score = min_val;
610 }
611 }
612 }
613 boxes.into_iter().filter(|b| b.score > min_val).collect()
614}
615
616#[doc(hidden)]
622#[must_use]
623pub fn nms_extra_class_aware_int<
624 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
625 E: Send + Sync,
626>(
627 iou: f32,
628 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
629) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
630 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
631
632 if iou >= 1.0 {
635 return boxes;
636 }
637
638 let min_val = SCORE::min_value();
639 for i in 0..boxes.len() {
640 if boxes[i].0.score <= min_val {
641 continue;
642 }
643 for j in (i + 1)..boxes.len() {
644 if boxes[j].0.score <= min_val {
645 continue;
646 }
647 if boxes[j].0.label == boxes[i].0.label
649 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
650 {
651 boxes[j].0.score = min_val;
652 }
653 }
654 }
655 boxes.into_iter().filter(|b| b.0.score > min_val).collect()
656}
657
658#[doc(hidden)]
673pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
674where
675 f32: AsPrimitive<T>,
676{
677 if quant.scale == 0.0 {
678 return T::max_value();
679 }
680 let v = (score / quant.scale + quant.zero_point as f32).ceil();
681 let v = v.clamp(T::min_value().as_(), T::max_value().as_());
682 v.as_()
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Returns a standard-layout (row-major) copy of the transpose of `physical`.
    ///
    /// `physical.clone().reversed_axes().to_owned()` is NOT a substitute: `to_owned`
    /// keeps the memory order of contiguous input, so that copy stays column-major. A
    /// "row-major" reference built from it would silently take the column-major path
    /// too, making the comparison vacuous.
    fn row_major_transpose<T: Copy>(physical: &Array2<T>) -> Array2<T> {
        let (rows, cols) = physical.dim();
        let out = Array2::from_shape_fn((cols, rows), |(r, c)| physical[[c, r]]);
        assert!(out.is_standard_layout());
        out
    }

    /// Class-major `(n_classes, n_candidates)` u8 scores. Because 3 is coprime with
    /// 256, no candidate has tied class scores for up to 256 classes.
    fn u8_scores(n_classes: usize, n_candidates: usize) -> Array2<u8> {
        Array2::from_shape_fn((n_classes, n_candidates), |(c, i)| {
            ((c * 3 + i * 7) % 256) as u8
        })
    }

    /// Coordinate-major `(4, n_candidates)` boxes.
    fn physical_boxes(n_candidates: usize) -> Array2<i16> {
        Array2::from_shape_fn((4, n_candidates), |(dim, i)| {
            let v = match dim {
                0 => i * 10,
                1 => i * 20,
                2 => i * 10 + 50,
                _ => i * 20 + 100,
            };
            v as i16
        })
    }

    /// Asserts that two indexed result lists are identical and non-empty.
    fn assert_same_indexed<S: PrimInt + AsPrimitive<f32> + Send + Sync + std::fmt::Debug>(
        row_result: &[(DetectBoxQuantized<S>, usize)],
        col_result: &[(DetectBoxQuantized<S>, usize)],
    ) {
        assert!(!row_result.is_empty(), "fixture should produce detections");
        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    #[test]
    fn column_major_matches_row_major() {
        let n_candidates = 100usize;
        let scores_physical = u8_scores(80, n_candidates);
        let boxes_physical = physical_boxes(n_candidates);
        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };
        let threshold: u8 = 10;

        let scores_row = row_major_transpose(&scores_physical);
        let boxes_row = row_major_transpose(&boxes_physical);
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_row.view(),
            scores_row.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_same_indexed(&row_result, &col_result);
    }

    #[test]
    fn column_major_matches_row_major_i8() {
        let n_candidates = 50usize;
        let scores_physical = Array2::from_shape_fn((80, n_candidates), |(c, i)| {
            ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8
        });
        let boxes_physical = physical_boxes(n_candidates);
        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        let scores_row = row_major_transpose(&scores_physical);
        let boxes_row = row_major_transpose(&boxes_physical);
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_row.view(),
            scores_row.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_same_indexed(&row_result, &col_result);
    }

    #[test]
    fn column_major_matches_row_major_256_classes() {
        // 256 classes is the largest count whose labels fit in a `u8`.
        let n_candidates = 40usize;
        let scores_physical = u8_scores(256, n_candidates);
        let boxes_physical = physical_boxes(n_candidates);
        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };
        let threshold: u8 = 10;

        let scores_row = row_major_transpose(&scores_physical);
        let boxes_row = row_major_transpose(&boxes_physical);
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_row.view(),
            scores_row.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_same_indexed(&row_result, &col_result);
    }

    #[test]
    fn column_major_matches_row_major_without_indices() {
        let n_candidates = 100usize;
        let scores_physical = u8_scores(80, n_candidates);
        let boxes_physical = physical_boxes(n_candidates);
        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };
        let threshold: u8 = 10;

        let scores_row = row_major_transpose(&scores_physical);
        let boxes_row = row_major_transpose(&boxes_physical);
        let row_result = postprocess_boxes_quant::<XYWH, _, _>(
            threshold,
            boxes_row.view(),
            scores_row.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert!(!row_result.is_empty(), "fixture should produce detections");
        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(&col_result).enumerate() {
            assert_eq!(row.label, col.label, "candidate {i}: label mismatch");
            assert_eq!(row.score, col.score, "candidate {i}: score mismatch");
            assert_eq!(row.bbox, col.bbox, "candidate {i}: bbox mismatch");
        }
    }
}