1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7    arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11    Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
/// NEON-accelerated column max update for the column-major argmax path.
///
/// Processes 16 elements per iteration using SIMD max + bitwise select.
/// Handles both unsigned (u8) and signed (i8) comparison semantics. Both the
/// vector (`vcgeq_*`) and scalar paths use `>=`, so on a score tie the later
/// (higher-numbered) class wins — the scalar tail is consistent with SIMD.
///
/// # Safety
///
/// - `col_ptr` must point to at least `n` valid bytes.
/// - `max_ptr` must point to at least `n` valid mutable bytes.
/// - `class_ptr` must point to at least `n` valid mutable bytes.
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;

    if signed {
        // Signed i8 comparison: interpret bytes as i8.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // mask[i] = 0xFF where col[i] >= cur_max[i], else 0x00
            let mask = vcgeq_s8(col, cur_max);
            // new_max = max(col, cur_max)
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            // Select class_idx where mask is set, keep old class otherwise.
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        // Unsigned u8 comparison.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
92/// Fast argmax dispatching to NEON-optimized path for i8 on aarch64.
93#[inline(always)]
94fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
95    #[cfg(target_arch = "aarch64")]
96    {
97        // Check if this is an i8 slice and contiguous.
98        if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
99            let slice = score.as_slice().unwrap();
100            // Safety: T is i8 when size_of::<T>() == 1 and PrimInt.
101            // PrimInt covers i8, u8, i16, etc. We only want to use the
102            // i8 NEON path for signed i8.
103            let ptr = slice.as_ptr() as *const i8;
104            let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
105            // Only valid for signed i8 (not u8). Check sign bit behavior:
106            // PrimInt for i8 means min_value() is negative.
107            if T::min_value() < T::zero() {
108                let (max_val, idx) = arg_max_i8(i8_slice);
109                // Safety: transmute i8 back to T (they have the same size and
110                // representation for i8).
111                let result: T = unsafe { std::mem::transmute_copy(&max_val) };
112                return (result, idx);
113            }
114        }
115    }
116    arg_max(score)
117}
118
119/// Post processes boxes and scores tensors into quantized detection boxes,
120/// filtering out any boxes below the score threshold. The boxes tensor
121/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
122/// is preserved.
123#[doc(hidden)]
124pub fn postprocess_boxes_quant<
125    B: BBoxTypeTrait,
126    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
127    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
128>(
129    threshold: Scores,
130    boxes: ArrayView2<Boxes>,
131    scores: ArrayView2<Scores>,
132    quant_boxes: Quantization,
133) -> Vec<DetectBoxQuantized<Scores>> {
134    assert_eq!(scores.dim().0, boxes.dim().0);
135    assert_eq!(boxes.dim().1, 4);
136
137    // Use column-major path for transposed DMA-BUF views (see postprocess_boxes_index_quant).
138    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
139        return postprocess_boxes_quant_column_major::<B, _, _>(
140            threshold,
141            boxes,
142            scores,
143            quant_boxes,
144        );
145    }
146
147    Zip::from(scores.rows())
148        .and(boxes.rows())
149        .into_par_iter()
150        .filter_map(|(score, bbox)| {
151            let (score_, label) = fast_arg_max(score);
152            if score_ < threshold {
153                return None;
154            }
155
156            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
157            Some(DetectBoxQuantized {
158                label,
159                score: score_,
160                bbox: BoundingBox::from(bbox_quant),
161            })
162        })
163        .collect()
164}
165
/// Column-major optimized path for `postprocess_boxes_quant`.
///
/// Used when `scores` is a transposed view ([N_candidates, N_classes] with
/// stride 1 along axis 0): columns are contiguous while rows are strided, so
/// iterating over classes (columns) keeps memory reads sequential. A running
/// per-candidate max score and class index are kept in heap buffers, then a
/// final pass applies the threshold and dequantizes the surviving boxes,
/// preserving candidate order.
fn postprocess_boxes_quant_column_major<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<DetectBoxQuantized<Scores>> {
    let (n_candidates, n_classes) = scores.dim();

    // Column-major NEON path uses u8 class indices; fall back for >255 classes.
    if n_classes > 255 {
        return Zip::from(scores.rows())
            .and(boxes.rows())
            .into_par_iter()
            .filter_map(|(score, bbox)| {
                let (score_, label) = fast_arg_max(score);
                if score_ < threshold {
                    return None;
                }
                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
                Some(DetectBoxQuantized {
                    label,
                    score: score_,
                    bbox: BoundingBox::from(bbox_quant),
                })
            })
            .collect();
    }
    // Running argmax state: one score and one class index per candidate.
    let mut max_scores = vec![Scores::min_value(); n_candidates];
    let mut max_classes = vec![0u8; n_candidates];

    for class_idx in 0..n_classes {
        let col = scores.column(class_idx);
        if let Some(slice) = col.as_slice() {
            #[cfg(target_arch = "aarch64")]
            {
                if std::mem::size_of::<Scores>() == 1 {
                    // SAFETY: `Scores` is a single-byte integer (u8 or i8), so
                    // reinterpreting the buffers as raw bytes is sound, and all
                    // three buffers hold at least `n_candidates` elements.
                    unsafe {
                        column_max_update_neon(
                            slice.as_ptr() as *const u8,
                            max_scores.as_mut_ptr() as *mut u8,
                            max_classes.as_mut_ptr(),
                            n_candidates,
                            class_idx as u8,
                            Scores::min_value() < Scores::zero(),
                        );
                    }
                    continue;
                }
            }
            // Scalar update; `>=` means a later class wins score ties,
            // matching the NEON kernel's select semantics.
            for (i, &val) in slice.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        } else {
            // Strided-column fallback: same update via element iteration.
            for (i, &val) in col.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        }
    }

    // Copy boxes column-by-column if also transposed.
    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
        let mut cols: [Vec<Boxes>; 4] = [
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
        ];
        for (dim, col_buf) in cols.iter_mut().enumerate() {
            let col = boxes.column(dim);
            if let Some(slice) = col.as_slice() {
                col_buf.copy_from_slice(slice);
            } else {
                for (i, &val) in col.iter().enumerate() {
                    col_buf[i] = val;
                }
            }
        }
        cols
    } else {
        // Empty buffers signal "read boxes per candidate via boxes.row(i)".
        [vec![], vec![], vec![], vec![]]
    };
    let boxes_copied = !boxes_buf[0].is_empty();

    // Final pass: threshold filter + dequantize, preserving candidate order.
    let mut result = Vec::new();
    for i in 0..n_candidates {
        if max_scores[i] >= threshold {
            let bbox_quant = if boxes_copied {
                let raw = [
                    boxes_buf[0][i],
                    boxes_buf[1][i],
                    boxes_buf[2][i],
                    boxes_buf[3][i],
                ];
                B::to_xyxy_dequant(&raw, quant_boxes)
            } else {
                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
            };
            result.push(DetectBoxQuantized {
                label: max_classes[i] as usize,
                score: max_scores[i],
                bbox: BoundingBox::from(bbox_quant),
            });
        }
    }

    result
}
284
285/// Post processes boxes and scores tensors into quantized detection boxes,
286/// filtering out any boxes below the score threshold. The boxes tensor
287/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
288/// is preserved.
289///
290/// This function is very similar to `postprocess_boxes_quant` but will also
291/// return the index of the box. The boxes will be in ascending index order.
292///
293/// When scores originate from a transposed DMA-BUF view (stride-1 along axis 0),
294/// an optimized column-major scan is used to avoid catastrophic strided reads on
295/// uncacheable memory.
296#[doc(hidden)]
297pub fn postprocess_boxes_index_quant<
298    B: BBoxTypeTrait,
299    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
300    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
301>(
302    threshold: Scores,
303    boxes: ArrayView2<Boxes>,
304    scores: ArrayView2<Scores>,
305    quant_boxes: Quantization,
306) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
307    assert_eq!(scores.dim().0, boxes.dim().0);
308    assert_eq!(boxes.dim().1, 4);
309
310    // Detect transposed C-contiguous layout (e.g., from reversed_axes() on DMA-BUF).
311    // In this layout columns are contiguous (stride-1 along axis 0) but rows are not.
312    // The column-major path reads memory sequentially, avoiding cache-hostile strides.
313    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
314        return postprocess_boxes_index_quant_column_major::<B, _, _>(
315            threshold,
316            boxes,
317            scores,
318            quant_boxes,
319        );
320    }
321
322    let indices: Array1<usize> = (0..boxes.dim().0).collect();
323    Zip::from(scores.rows())
324        .and(boxes.rows())
325        .and(&indices)
326        .into_par_iter()
327        .filter_map(|(score, bbox, index)| {
328            let (score_, label) = fast_arg_max(score);
329            if score_ < threshold {
330                return None;
331            }
332
333            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
334
335            Some((
336                DetectBoxQuantized {
337                    label,
338                    score: score_,
339                    bbox: BoundingBox::from(bbox_quant),
340                },
341                *index,
342            ))
343        })
344        .collect()
345}
346
/// Column-major optimized path for `postprocess_boxes_index_quant`.
///
/// When scores come from a transposed DMA-BUF view ([N_candidates, N_classes]
/// with strides [1, N_candidates]), row iteration causes N_classes reads each
/// N_candidates bytes apart — catastrophic on uncacheable memory. Instead, this
/// iterates over classes (columns, which are contiguous), maintaining a running
/// argmax per candidate in cacheable heap buffers. Results are emitted in
/// ascending candidate-index order.
fn postprocess_boxes_index_quant_column_major<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    let (n_candidates, n_classes) = scores.dim();

    // Phase 1: Column-based argmax — sequential reads over contiguous columns.
    // Use u8 for class indices (max 255 classes) for optimal NEON vectorization.
    // Fall back to row-major path for models with >255 classes.
    if n_classes > 255 {
        let indices: Array1<usize> = (0..n_candidates).collect();
        return Zip::from(scores.rows())
            .and(boxes.rows())
            .and(&indices)
            .into_par_iter()
            .filter_map(|(score, bbox, index)| {
                let (score_, label) = fast_arg_max(score);
                if score_ < threshold {
                    return None;
                }
                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
                Some((
                    DetectBoxQuantized {
                        label,
                        score: score_,
                        bbox: BoundingBox::from(bbox_quant),
                    },
                    *index,
                ))
            })
            .collect();
    }
    // Running argmax state: one score and one class index per candidate.
    let mut max_scores = vec![Scores::min_value(); n_candidates];
    let mut max_classes = vec![0u8; n_candidates];
    for class_idx in 0..n_classes {
        let col = scores.column(class_idx);
        if let Some(slice) = col.as_slice() {
            // Use NEON-accelerated column max update on aarch64 for u8 scores.
            #[cfg(target_arch = "aarch64")]
            {
                if std::mem::size_of::<Scores>() == 1 {
                    // SAFETY: Scores is u8 or i8 (size == 1). We transmute the
                    // slice pointers to the concrete byte type for NEON processing.
                    // All three buffers hold at least `n_candidates` elements.
                    unsafe {
                        column_max_update_neon(
                            slice.as_ptr() as *const u8,
                            max_scores.as_mut_ptr() as *mut u8,
                            max_classes.as_mut_ptr(),
                            n_candidates,
                            class_idx as u8,
                            Scores::min_value() < Scores::zero(), // signed flag
                        );
                    }
                    continue;
                }
            }
            // Scalar update; `>=` means a later class wins score ties, keeping
            // the tie-break consistent with the NEON kernel.
            for (i, &val) in slice.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        } else {
            // Strided-column fallback: same update via element iteration.
            for (i, &val) in col.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        }
    }

    // Phase 2: Copy boxes column-by-column into contiguous heap buffer.
    // Boxes view is also transposed [N_candidates, 4] with strides [1, N_candidates],
    // so column reads are sequential while row reads are strided.
    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
        let mut cols: [Vec<Boxes>; 4] = [
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
        ];
        for (dim, col_buf) in cols.iter_mut().enumerate() {
            let col = boxes.column(dim);
            if let Some(slice) = col.as_slice() {
                col_buf.copy_from_slice(slice);
            } else {
                for (i, &val) in col.iter().enumerate() {
                    col_buf[i] = val;
                }
            }
        }
        cols
    } else {
        // Boxes are contiguous or differently strided — read per-candidate below.
        [vec![], vec![], vec![], vec![]]
    };
    let boxes_copied = !boxes_buf[0].is_empty();

    // Phase 3: Threshold filter — collect candidates that pass, in index order.
    let mut result = Vec::new();
    for i in 0..n_candidates {
        if max_scores[i] >= threshold {
            let bbox_quant = if boxes_copied {
                let raw = [
                    boxes_buf[0][i],
                    boxes_buf[1][i],
                    boxes_buf[2][i],
                    boxes_buf[3][i],
                ];
                B::to_xyxy_dequant(&raw, quant_boxes)
            } else {
                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
            };
            result.push((
                DetectBoxQuantized {
                    label: max_classes[i] as usize,
                    score: max_scores[i],
                    bbox: BoundingBox::from(bbox_quant),
                },
                i,
            ));
        }
    }

    result
}
487
488/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
489/// then greedily selects a subset of boxes in descending order of score.
490#[doc(hidden)]
491#[must_use]
492pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
493    iou: f32,
494    max_det: Option<usize>,
495    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
496) -> Vec<DetectBoxQuantized<SCORE>> {
497    // Boxes get sorted by score in descending order so we know based on the
498    // index the scoring of the boxes and can skip parts of the loop.
499
500    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
501
502    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
503    // immediately
504    if iou >= 1.0 {
505        return match max_det {
506            Some(n) => {
507                boxes.truncate(n);
508                boxes
509            }
510            None => boxes,
511        };
512    }
513
514    let min_val = SCORE::min_value();
515    let cap = max_det.unwrap_or(usize::MAX);
516    let mut survivors: usize = 0;
517    // Outer loop over all boxes.
518    for i in 0..boxes.len() {
519        if boxes[i].score <= min_val {
520            // this box was merged with a different box earlier
521            continue;
522        }
523        for j in (i + 1)..boxes.len() {
524            // Inner loop over boxes with lower score (later in the list).
525
526            if boxes[j].score <= min_val {
527                // this box was suppressed by different box earlier
528                continue;
529            }
530
531            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
532                // suppress this box
533                boxes[j].score = min_val;
534            }
535        }
536        survivors += 1;
537        if survivors >= cap {
538            break;
539        }
540    }
541    // Filter out boxes that were suppressed; cap because boxes after the
542    // break may still hold positive scores but are all lower than survivors.
543    boxes
544        .into_iter()
545        .filter(|b| b.score > min_val)
546        .take(cap)
547        .collect()
548}
549
550/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
551/// then greedily selects a subset of boxes in descending order of score.
552///
553/// This is same as `nms_int` but will also include extra information along
554/// with each box, such as the index
555#[doc(hidden)]
556#[must_use]
557pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
558    iou: f32,
559    max_det: Option<usize>,
560    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
561) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
562    // Boxes get sorted by score in descending order so we know based on the
563    // index the scoring of the boxes and can skip parts of the loop.
564    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
565
566    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
567    // immediately
568    if iou >= 1.0 {
569        return match max_det {
570            Some(n) => {
571                boxes.truncate(n);
572                boxes
573            }
574            None => boxes,
575        };
576    }
577
578    let min_val = SCORE::min_value();
579    let cap = max_det.unwrap_or(usize::MAX);
580    let mut survivors: usize = 0;
581    // Outer loop over all boxes.
582    for i in 0..boxes.len() {
583        if boxes[i].0.score <= min_val {
584            // this box was merged with a different box earlier
585            continue;
586        }
587        for j in (i + 1)..boxes.len() {
588            // Inner loop over boxes with lower score (later in the list).
589
590            if boxes[j].0.score <= min_val {
591                // this box was suppressed by different box earlier
592                continue;
593            }
594            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
595                // suppress this box
596                boxes[j].0.score = min_val;
597            }
598        }
599        survivors += 1;
600        if survivors >= cap {
601            break;
602        }
603    }
604
605    // Filter out boxes that were suppressed; cap at `max_det`.
606    boxes
607        .into_iter()
608        .filter(|b| b.0.score > min_val)
609        .take(cap)
610        .collect()
611}
612
613/// Class-aware NMS for quantized boxes: only suppress boxes with the same
614/// label.
615///
616/// Sorts boxes by score, then greedily selects a subset of boxes in descending
617/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
618/// have the same class label AND overlap above the IoU threshold.
619#[doc(hidden)]
620#[must_use]
621pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
622    iou: f32,
623    max_det: Option<usize>,
624    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
625) -> Vec<DetectBoxQuantized<SCORE>> {
626    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
627
628    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
629    // immediately
630    if iou >= 1.0 {
631        return match max_det {
632            Some(n) => {
633                boxes.truncate(n);
634                boxes
635            }
636            None => boxes,
637        };
638    }
639
640    let min_val = SCORE::min_value();
641    let cap = max_det.unwrap_or(usize::MAX);
642    let mut survivors: usize = 0;
643    for i in 0..boxes.len() {
644        if boxes[i].score <= min_val {
645            continue;
646        }
647        for j in (i + 1)..boxes.len() {
648            if boxes[j].score <= min_val {
649                continue;
650            }
651            // Only suppress if same class AND overlapping
652            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
653                boxes[j].score = min_val;
654            }
655        }
656        survivors += 1;
657        if survivors >= cap {
658            break;
659        }
660    }
661    boxes
662        .into_iter()
663        .filter(|b| b.score > min_val)
664        .take(cap)
665        .collect()
666}
667
668/// Class-aware NMS for quantized boxes with extra data: only suppress boxes
669/// with the same label.
670///
671/// This is same as `nms_class_aware_int` but will also include extra
672/// information along with each box, such as the index.
673#[doc(hidden)]
674#[must_use]
675pub fn nms_extra_class_aware_int<
676    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
677    E: Send + Sync,
678>(
679    iou: f32,
680    max_det: Option<usize>,
681    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
682) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
683    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
684
685    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
686    // immediately
687    if iou >= 1.0 {
688        return match max_det {
689            Some(n) => {
690                boxes.truncate(n);
691                boxes
692            }
693            None => boxes,
694        };
695    }
696
697    let min_val = SCORE::min_value();
698    let cap = max_det.unwrap_or(usize::MAX);
699    let mut survivors: usize = 0;
700    for i in 0..boxes.len() {
701        if boxes[i].0.score <= min_val {
702            continue;
703        }
704        for j in (i + 1)..boxes.len() {
705            if boxes[j].0.score <= min_val {
706                continue;
707            }
708            // Only suppress if same class AND overlapping
709            if boxes[j].0.label == boxes[i].0.label
710                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
711            {
712                boxes[j].0.score = min_val;
713            }
714        }
715        survivors += 1;
716        if survivors >= cap {
717            break;
718        }
719    }
720    boxes
721        .into_iter()
722        .filter(|b| b.0.score > min_val)
723        .take(cap)
724        .collect()
725}
726
727/// Quantizes a score from f32 to the given integer type, using the following
728/// formula `(score/quant.scale + quant.zero_point).ceil()`, then clamping to
729/// the min and max value of the given integer type
730///
731/// # Examples
732/// ```rust
733/// use edgefirst_decoder::{Quantization, byte::quantize_score_threshold};
734/// let quant = Quantization {
735///     scale: 0.1,
736///     zero_point: 128,
737/// };
738/// let q: u8 = quantize_score_threshold::<u8>(0.5, quant);
739/// assert_eq!(q, 128 + 5);
740/// ```
741#[doc(hidden)]
742pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
743where
744    f32: AsPrimitive<T>,
745{
746    if quant.scale == 0.0 {
747        return T::max_value();
748    }
749    let v = (score / quant.scale + quant.zero_point as f32).ceil();
750    let v = v.clamp(T::min_value().as_(), T::max_value().as_());
751    v.as_()
752}
753
754#[cfg(test)]
755mod tests {
756    use super::*;
757    use crate::XYWH;
758    use ndarray::Array2;
759
    /// Verify that the column-major path produces identical results to the
    /// row-major path for a transposed (non-contiguous) score array.
    #[test]
    fn column_major_matches_row_major() {
        // Create scores in "model output" layout: [num_classes, num_candidates]
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let mut scores_physical = Array2::<u8>::zeros((n_classes, n_candidates));
        // Fill with known pattern: class c, candidate i → (c * 3 + i * 7) % 256
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c * 3 + i * 7) % 256) as u8;
            }
        }

        // Create boxes: [4, num_candidates] i16
        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16; // x
            boxes_physical[[1, i]] = (i * 20) as i16; // y
            boxes_physical[[2, i]] = (i * 10 + 50) as i16; // w
            boxes_physical[[3, i]] = (i * 20 + 100) as i16; // h
        }

        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };

        let threshold: u8 = 10;

        // Row-major path: `to_owned()` on the reversed view materializes a
        // contiguous [n_candidates, n_classes] array, so the standard path runs.
        let scores_contiguous = scores_physical.clone().reversed_axes().to_owned();
        let boxes_contiguous = boxes_physical.clone().reversed_axes().to_owned();
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        // Column-major path: reversed_axes() on the *view* flips strides
        // without copying, producing the non-contiguous layout the fast path
        // detects (stride 1 along axis 0, no full slice).
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        // Both paths should produce the same results
        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }
828
    /// Test column-major path with i8 scores (signed, matches NEON argmax path).
    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let mut scores_physical = Array2::<i8>::zeros((n_classes, n_candidates));
        // Deterministic pattern covering the signed range: (… % 256) - 128
        // lands in [-128, 127] for these class/candidate counts.
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8;
            }
        }

        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16;
            boxes_physical[[1, i]] = (i * 20) as i16;
            boxes_physical[[2, i]] = (i * 10 + 50) as i16;
            boxes_physical[[3, i]] = (i * 20 + 100) as i16;
        }

        // Negative zero_point and i8 threshold exercise the signed compare path.
        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        // Row-major reference: contiguous copies of the reversed layout.
        let scores_contiguous = scores_physical.clone().reversed_axes().to_owned();
        let boxes_contiguous = boxes_physical.clone().reversed_axes().to_owned();
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        // Column-major path: non-contiguous reversed views (strides flipped).
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }
876
877    /// Helper: create `n` non-overlapping boxes with descending u8 scores.
878    fn make_nms_boxes_int(n: usize) -> Vec<DetectBoxQuantized<u8>> {
879        (0..n)
880            .map(|i| DetectBoxQuantized {
881                bbox: BoundingBox {
882                    xmin: i as f32 * 100.0,
883                    ymin: 0.0,
884                    xmax: i as f32 * 100.0 + 10.0,
885                    ymax: 10.0,
886                },
887                label: 0,
888                score: (200 - i as u32).min(255) as u8,
889            })
890            .collect()
891    }
892
893    #[test]
894    fn nms_int_max_det_matches_full_truncated() {
895        let boxes = make_nms_boxes_int(20);
896        let n = 5;
897        let full = nms_int(0.5, None, boxes.clone());
898        let capped = nms_int(0.5, Some(n), boxes);
899        assert_eq!(capped.len(), n);
900        assert_eq!(&full[..n], &capped[..]);
901    }
902
903    #[test]
904    fn nms_int_max_det_zero_returns_empty() {
905        let boxes = make_nms_boxes_int(10);
906        let result = nms_int(0.5, Some(0), boxes);
907        assert!(result.is_empty());
908    }
909
910    #[test]
911    fn nms_int_max_det_iou_ge_1_returns_sorted_truncated() {
912        let boxes = make_nms_boxes_int(10);
913        let result = nms_int(1.0, Some(3), boxes);
914        assert_eq!(result.len(), 3);
915        // Scores should be in descending order (sorted).
916        assert!(result[0].score >= result[1].score);
917        assert!(result[1].score >= result[2].score);
918    }
919
920    #[test]
921    fn nms_int_max_det_larger_than_input() {
922        let boxes = make_nms_boxes_int(5);
923        let full = nms_int(0.5, None, boxes.clone());
924        let capped = nms_int(0.5, Some(100), boxes);
925        assert_eq!(full.len(), capped.len());
926    }
927}