// edgefirst_decoder/byte.rs

// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
// SPDX-License-Identifier: Apache-2.0

4#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7    arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11    Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
/// NEON-accelerated column max update for the column-major argmax path.
///
/// Processes 16 elements per iteration using SIMD max + bitwise select.
/// Handles both unsigned (u8) and signed (i8) comparison semantics. The
/// scalar tail loops cover the final `n % 16` elements, so no separate
/// remainder bookkeeping is needed.
///
/// # Safety
///
/// - `col_ptr` must point to at least `n` valid bytes.
/// - `max_ptr` must point to at least `n` valid mutable bytes.
/// - `class_ptr` must point to at least `n` valid mutable bytes.
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;

    if signed {
        // Signed i8 comparison: interpret bytes as i8.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // mask[i] = 0xFF where col[i] >= cur_max[i], else 0x00
            let mask = vcgeq_s8(col, cur_max);
            // new_max = max(col, cur_max)
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            // Select class_idx where mask is set, keep old class otherwise.
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail.
        for i in (chunks * 16)..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        // Unsigned u8 comparison.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail.
        for i in (chunks * 16)..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
/// NEON-accelerated column max update with software prefetch for DMA-BUF.
///
/// Same running-argmax update as `column_max_update_neon`, but adds PRFM
/// (prefetch for load, L1 data cache, streaming) hints 2 cache lines ahead
/// of the current read position to hide CMA memory latency.
/// On Cortex-A55 (64B cache lines) this prefetches 128 bytes ahead; on
/// Cortex-A76 (64B L1, 128B L2 lines) the same offset covers one L2 line.
///
/// # Safety
///
/// - `col_ptr` must point to at least `n` valid bytes.
/// - `max_ptr` must point to at least `n` valid mutable bytes.
/// - `class_ptr` must point to at least `n` valid mutable bytes.
#[cfg(target_arch = "aarch64")]
#[allow(dead_code)] // Reserved for DMA-BUF CMA paths where latency hiding helps.
unsafe fn column_max_update_neon_prefetch(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    const PREFETCH_AHEAD: usize = 128; // 2 cache lines on A55 (64B each)

    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;

    if signed {
        // Signed i8 comparison: interpret bytes as i8.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            // Software prefetch: hint the next read 2 cache lines ahead.
            // Guarded so the hint stays within the valid `n`-byte range.
            if offset + PREFETCH_AHEAD < n {
                core::arch::asm!(
                    "prfm pldl1strm, [{ptr}]",
                    ptr = in(reg) col_ptr.add(offset + PREFETCH_AHEAD),
                    options(nostack, preserves_flags),
                );
            }
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // mask[i] = 0xFF where col[i] >= cur_max[i], else 0x00.
            let mask = vcgeq_s8(col, cur_max);
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            // Select class_idx where mask is set, keep old class otherwise.
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail covers the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        // Unsigned u8 comparison.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            if offset + PREFETCH_AHEAD < n {
                core::arch::asm!(
                    "prfm pldl1strm, [{ptr}]",
                    ptr = in(reg) col_ptr.add(offset + PREFETCH_AHEAD),
                    options(nostack, preserves_flags),
                );
            }
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail covers the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
178
179/// Fast argmax dispatching to NEON-optimized path for i8 on aarch64.
180#[inline(always)]
181fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
182    #[cfg(target_arch = "aarch64")]
183    {
184        // Check if this is an i8 slice and contiguous.
185        if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
186            let slice = score.as_slice().unwrap();
187            // Safety: T is i8 when size_of::<T>() == 1 and PrimInt.
188            // PrimInt covers i8, u8, i16, etc. We only want to use the
189            // i8 NEON path for signed i8.
190            let ptr = slice.as_ptr() as *const i8;
191            let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
192            // Only valid for signed i8 (not u8). Check sign bit behavior:
193            // PrimInt for i8 means min_value() is negative.
194            if T::min_value() < T::zero() {
195                let (max_val, idx) = arg_max_i8(i8_slice);
196                // Safety: transmute i8 back to T (they have the same size and
197                // representation for i8).
198                let result: T = unsafe { std::mem::transmute_copy(&max_val) };
199                return (result, idx);
200            }
201        }
202    }
203    arg_max(score)
204}
205
206/// Post processes boxes and scores tensors into quantized detection boxes,
207/// filtering out any boxes below the score threshold. The boxes tensor
208/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
209/// is preserved.
210#[doc(hidden)]
211pub fn postprocess_boxes_quant<
212    B: BBoxTypeTrait,
213    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
214    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
215>(
216    threshold: Scores,
217    boxes: ArrayView2<Boxes>,
218    scores: ArrayView2<Scores>,
219    quant_boxes: Quantization,
220) -> Vec<DetectBoxQuantized<Scores>> {
221    assert_eq!(scores.dim().0, boxes.dim().0);
222    assert_eq!(boxes.dim().1, 4);
223
224    // Use column-major path for transposed DMA-BUF views (see postprocess_boxes_index_quant).
225    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
226        return postprocess_boxes_quant_column_major::<B, _, _>(
227            threshold,
228            boxes,
229            scores,
230            quant_boxes,
231        );
232    }
233
234    Zip::from(scores.rows())
235        .and(boxes.rows())
236        .into_par_iter()
237        .filter_map(|(score, bbox)| {
238            let (score_, label) = fast_arg_max(score);
239            if score_ < threshold {
240                return None;
241            }
242
243            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
244            Some(DetectBoxQuantized {
245                label,
246                score: score_,
247                bbox: BoundingBox::from(bbox_quant),
248            })
249        })
250        .collect()
251}
252
253/// Column-major optimized path for `postprocess_boxes_quant`.
254fn postprocess_boxes_quant_column_major<
255    B: BBoxTypeTrait,
256    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
257    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
258>(
259    threshold: Scores,
260    boxes: ArrayView2<Boxes>,
261    scores: ArrayView2<Scores>,
262    quant_boxes: Quantization,
263) -> Vec<DetectBoxQuantized<Scores>> {
264    let (n_candidates, n_classes) = scores.dim();
265
266    // Column-major NEON path uses u8 class indices; fall back for >255 classes.
267    if n_classes > 255 {
268        return Zip::from(scores.rows())
269            .and(boxes.rows())
270            .into_par_iter()
271            .filter_map(|(score, bbox)| {
272                let (score_, label) = fast_arg_max(score);
273                if score_ < threshold {
274                    return None;
275                }
276                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
277                Some(DetectBoxQuantized {
278                    label,
279                    score: score_,
280                    bbox: BoundingBox::from(bbox_quant),
281                })
282            })
283            .collect();
284    }
285    let mut max_scores = vec![Scores::min_value(); n_candidates];
286    let mut max_classes = vec![0u8; n_candidates];
287
288    for class_idx in 0..n_classes {
289        let col = scores.column(class_idx);
290        if let Some(slice) = col.as_slice() {
291            #[cfg(target_arch = "aarch64")]
292            {
293                if std::mem::size_of::<Scores>() == 1 {
294                    unsafe {
295                        // Non-prefetch variant for heap-backed tensors;
296                        // prefetch variant reserved for DMA-BUF CMA paths.
297                        column_max_update_neon(
298                            slice.as_ptr() as *const u8,
299                            max_scores.as_mut_ptr() as *mut u8,
300                            max_classes.as_mut_ptr(),
301                            n_candidates,
302                            class_idx as u8,
303                            Scores::min_value() < Scores::zero(),
304                        );
305                    }
306                    continue;
307                }
308            }
309            for (i, &val) in slice.iter().enumerate() {
310                if val >= max_scores[i] {
311                    max_scores[i] = val;
312                    max_classes[i] = class_idx as u8;
313                }
314            }
315        } else {
316            for (i, &val) in col.iter().enumerate() {
317                if val >= max_scores[i] {
318                    max_scores[i] = val;
319                    max_classes[i] = class_idx as u8;
320                }
321            }
322        }
323    }
324
325    // Copy boxes column-by-column if also transposed.
326    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
327        let mut cols: [Vec<Boxes>; 4] = [
328            vec![Boxes::zero(); n_candidates],
329            vec![Boxes::zero(); n_candidates],
330            vec![Boxes::zero(); n_candidates],
331            vec![Boxes::zero(); n_candidates],
332        ];
333        for (dim, col_buf) in cols.iter_mut().enumerate() {
334            let col = boxes.column(dim);
335            if let Some(slice) = col.as_slice() {
336                col_buf.copy_from_slice(slice);
337            } else {
338                for (i, &val) in col.iter().enumerate() {
339                    col_buf[i] = val;
340                }
341            }
342        }
343        cols
344    } else {
345        [vec![], vec![], vec![], vec![]]
346    };
347    let boxes_copied = !boxes_buf[0].is_empty();
348
349    let mut result = Vec::new();
350    for i in 0..n_candidates {
351        if max_scores[i] >= threshold {
352            let bbox_quant = if boxes_copied {
353                let raw = [
354                    boxes_buf[0][i],
355                    boxes_buf[1][i],
356                    boxes_buf[2][i],
357                    boxes_buf[3][i],
358                ];
359                B::to_xyxy_dequant(&raw, quant_boxes)
360            } else {
361                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
362            };
363            result.push(DetectBoxQuantized {
364                label: max_classes[i] as usize,
365                score: max_scores[i],
366                bbox: BoundingBox::from(bbox_quant),
367            });
368        }
369    }
370
371    result
372}
373
374/// Post processes boxes and scores tensors into quantized detection boxes,
375/// filtering out any boxes below the score threshold. The boxes tensor
376/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
377/// is preserved.
378///
379/// This function is very similar to `postprocess_boxes_quant` but will also
380/// return the index of the box. The boxes will be in ascending index order.
381///
382/// When scores originate from a transposed DMA-BUF view (stride-1 along axis 0),
383/// an optimized column-major scan is used to avoid catastrophic strided reads on
384/// uncacheable memory.
385#[doc(hidden)]
386pub fn postprocess_boxes_index_quant<
387    B: BBoxTypeTrait,
388    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
389    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
390>(
391    threshold: Scores,
392    boxes: ArrayView2<Boxes>,
393    scores: ArrayView2<Scores>,
394    quant_boxes: Quantization,
395) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
396    assert_eq!(scores.dim().0, boxes.dim().0);
397    assert_eq!(boxes.dim().1, 4);
398
399    // Detect transposed C-contiguous layout (e.g., from reversed_axes() on DMA-BUF).
400    // In this layout columns are contiguous (stride-1 along axis 0) but rows are not.
401    // The column-major path reads memory sequentially, avoiding cache-hostile strides.
402    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
403        return postprocess_boxes_index_quant_column_major::<B, _, _>(
404            threshold,
405            boxes,
406            scores,
407            quant_boxes,
408        );
409    }
410
411    let indices: Array1<usize> = (0..boxes.dim().0).collect();
412    Zip::from(scores.rows())
413        .and(boxes.rows())
414        .and(&indices)
415        .into_par_iter()
416        .filter_map(|(score, bbox, index)| {
417            let (score_, label) = fast_arg_max(score);
418            if score_ < threshold {
419                return None;
420            }
421
422            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
423
424            Some((
425                DetectBoxQuantized {
426                    label,
427                    score: score_,
428                    bbox: BoundingBox::from(bbox_quant),
429                },
430                *index,
431            ))
432        })
433        .collect()
434}
435
/// Column-major optimized path for `postprocess_boxes_index_quant`.
///
/// When scores come from a transposed DMA-BUF view ([N_candidates, N_classes]
/// with strides [1, N_candidates]), row iteration causes N_classes reads each
/// N_candidates bytes apart — catastrophic on uncacheable memory. Instead, this
/// iterates over classes (columns, which are contiguous), maintaining a running
/// argmax per candidate in cacheable heap buffers.
fn postprocess_boxes_index_quant_column_major<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    let (n_candidates, n_classes) = scores.dim();

    // Phase 1: Column-based argmax — sequential reads over contiguous columns.
    // Use u8 for class indices (max 255 classes) for optimal NEON vectorization.
    // Fall back to row-major path for models with >255 classes.
    if n_classes > 255 {
        let indices: Array1<usize> = (0..n_candidates).collect();
        return Zip::from(scores.rows())
            .and(boxes.rows())
            .and(&indices)
            .into_par_iter()
            .filter_map(|(score, bbox, index)| {
                let (score_, label) = fast_arg_max(score);
                if score_ < threshold {
                    return None;
                }
                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
                Some((
                    DetectBoxQuantized {
                        label,
                        score: score_,
                        bbox: BoundingBox::from(bbox_quant),
                    },
                    *index,
                ))
            })
            .collect();
    }
    // Running per-candidate maximum score and the class index that produced it.
    // NOTE: `>=` comparisons below mean the highest class index wins on ties.
    let mut max_scores = vec![Scores::min_value(); n_candidates];
    let mut max_classes = vec![0u8; n_candidates];
    for class_idx in 0..n_classes {
        let col = scores.column(class_idx);
        if let Some(slice) = col.as_slice() {
            // Use NEON-accelerated column max update on aarch64 for u8 scores.
            #[cfg(target_arch = "aarch64")]
            {
                if std::mem::size_of::<Scores>() == 1 {
                    // SAFETY: Scores is u8 or i8 (size == 1). We transmute the
                    // slice pointers to the concrete byte type for NEON processing.
                    unsafe {
                        column_max_update_neon(
                            slice.as_ptr() as *const u8,
                            max_scores.as_mut_ptr() as *mut u8,
                            max_classes.as_mut_ptr(),
                            n_candidates,
                            class_idx as u8,
                            Scores::min_value() < Scores::zero(), // signed flag
                        );
                    }
                    continue;
                }
            }
            // Scalar fallback for contiguous columns of wider score types
            // (and for non-aarch64 builds).
            for (i, &val) in slice.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        } else {
            // Non-contiguous column: iterate through the strided view.
            for (i, &val) in col.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        }
    }

    // Phase 2: Copy boxes column-by-column into contiguous heap buffer.
    // Boxes view is also transposed [N_candidates, 4] with strides [1, N_candidates],
    // so column reads are sequential while row reads are strided.
    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
        let mut cols: [Vec<Boxes>; 4] = [
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
        ];
        for (dim, col_buf) in cols.iter_mut().enumerate() {
            let col = boxes.column(dim);
            if let Some(slice) = col.as_slice() {
                col_buf.copy_from_slice(slice);
            } else {
                for (i, &val) in col.iter().enumerate() {
                    col_buf[i] = val;
                }
            }
        }
        cols
    } else {
        // Boxes are contiguous or differently strided — read per-candidate below.
        [vec![], vec![], vec![], vec![]]
    };
    let boxes_copied = !boxes_buf[0].is_empty();

    // Phase 3: Threshold filter — collect candidates that pass, in ascending
    // candidate order.
    let mut result = Vec::new();
    for i in 0..n_candidates {
        if max_scores[i] >= threshold {
            let bbox_quant = if boxes_copied {
                let raw = [
                    boxes_buf[0][i],
                    boxes_buf[1][i],
                    boxes_buf[2][i],
                    boxes_buf[3][i],
                ];
                B::to_xyxy_dequant(&raw, quant_boxes)
            } else {
                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
            };
            result.push((
                DetectBoxQuantized {
                    label: max_classes[i] as usize,
                    score: max_scores[i],
                    bbox: BoundingBox::from(bbox_quant),
                },
                i,
            ));
        }
    }

    result
}
576
577/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
578/// then greedily selects a subset of boxes in descending order of score.
579#[doc(hidden)]
580#[must_use]
581pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
582    iou: f32,
583    max_det: Option<usize>,
584    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
585) -> Vec<DetectBoxQuantized<SCORE>> {
586    // Boxes get sorted by score in descending order so we know based on the
587    // index the scoring of the boxes and can skip parts of the loop.
588
589    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
590
591    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
592    // immediately
593    if iou >= 1.0 {
594        return match max_det {
595            Some(n) => {
596                boxes.truncate(n);
597                boxes
598            }
599            None => boxes,
600        };
601    }
602
603    let min_val = SCORE::min_value();
604    let cap = max_det.unwrap_or(usize::MAX);
605    let mut survivors: usize = 0;
606    // Outer loop over all boxes.
607    for i in 0..boxes.len() {
608        if boxes[i].score <= min_val {
609            // this box was merged with a different box earlier
610            continue;
611        }
612        for j in (i + 1)..boxes.len() {
613            // Inner loop over boxes with lower score (later in the list).
614
615            if boxes[j].score <= min_val {
616                // this box was suppressed by different box earlier
617                continue;
618            }
619
620            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
621                // suppress this box
622                boxes[j].score = min_val;
623            }
624        }
625        survivors += 1;
626        if survivors >= cap {
627            break;
628        }
629    }
630    // Filter out boxes that were suppressed; cap because boxes after the
631    // break may still hold positive scores but are all lower than survivors.
632    boxes
633        .into_iter()
634        .filter(|b| b.score > min_val)
635        .take(cap)
636        .collect()
637}
638
639/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
640/// then greedily selects a subset of boxes in descending order of score.
641///
642/// This is same as `nms_int` but will also include extra information along
643/// with each box, such as the index
644#[doc(hidden)]
645#[must_use]
646pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
647    iou: f32,
648    max_det: Option<usize>,
649    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
650) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
651    // Boxes get sorted by score in descending order so we know based on the
652    // index the scoring of the boxes and can skip parts of the loop.
653    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
654
655    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
656    // immediately
657    if iou >= 1.0 {
658        return match max_det {
659            Some(n) => {
660                boxes.truncate(n);
661                boxes
662            }
663            None => boxes,
664        };
665    }
666
667    let min_val = SCORE::min_value();
668    let cap = max_det.unwrap_or(usize::MAX);
669    let mut survivors: usize = 0;
670    // Outer loop over all boxes.
671    for i in 0..boxes.len() {
672        if boxes[i].0.score <= min_val {
673            // this box was merged with a different box earlier
674            continue;
675        }
676        for j in (i + 1)..boxes.len() {
677            // Inner loop over boxes with lower score (later in the list).
678
679            if boxes[j].0.score <= min_val {
680                // this box was suppressed by different box earlier
681                continue;
682            }
683            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
684                // suppress this box
685                boxes[j].0.score = min_val;
686            }
687        }
688        survivors += 1;
689        if survivors >= cap {
690            break;
691        }
692    }
693
694    // Filter out boxes that were suppressed; cap at `max_det`.
695    boxes
696        .into_iter()
697        .filter(|b| b.0.score > min_val)
698        .take(cap)
699        .collect()
700}
701
702/// Class-aware NMS for quantized boxes: only suppress boxes with the same
703/// label.
704///
705/// Sorts boxes by score, then greedily selects a subset of boxes in descending
706/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
707/// have the same class label AND overlap above the IoU threshold.
708#[doc(hidden)]
709#[must_use]
710pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
711    iou: f32,
712    max_det: Option<usize>,
713    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
714) -> Vec<DetectBoxQuantized<SCORE>> {
715    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
716
717    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
718    // immediately
719    if iou >= 1.0 {
720        return match max_det {
721            Some(n) => {
722                boxes.truncate(n);
723                boxes
724            }
725            None => boxes,
726        };
727    }
728
729    let min_val = SCORE::min_value();
730    let cap = max_det.unwrap_or(usize::MAX);
731    let mut survivors: usize = 0;
732    for i in 0..boxes.len() {
733        if boxes[i].score <= min_val {
734            continue;
735        }
736        for j in (i + 1)..boxes.len() {
737            if boxes[j].score <= min_val {
738                continue;
739            }
740            // Only suppress if same class AND overlapping
741            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
742                boxes[j].score = min_val;
743            }
744        }
745        survivors += 1;
746        if survivors >= cap {
747            break;
748        }
749    }
750    boxes
751        .into_iter()
752        .filter(|b| b.score > min_val)
753        .take(cap)
754        .collect()
755}
756
757/// Class-aware NMS for quantized boxes with extra data: only suppress boxes
758/// with the same label.
759///
760/// This is same as `nms_class_aware_int` but will also include extra
761/// information along with each box, such as the index.
762#[doc(hidden)]
763#[must_use]
764pub fn nms_extra_class_aware_int<
765    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
766    E: Send + Sync,
767>(
768    iou: f32,
769    max_det: Option<usize>,
770    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
771) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
772    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
773
774    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
775    // immediately
776    if iou >= 1.0 {
777        return match max_det {
778            Some(n) => {
779                boxes.truncate(n);
780                boxes
781            }
782            None => boxes,
783        };
784    }
785
786    let min_val = SCORE::min_value();
787    let cap = max_det.unwrap_or(usize::MAX);
788    let mut survivors: usize = 0;
789    for i in 0..boxes.len() {
790        if boxes[i].0.score <= min_val {
791            continue;
792        }
793        for j in (i + 1)..boxes.len() {
794            if boxes[j].0.score <= min_val {
795                continue;
796            }
797            // Only suppress if same class AND overlapping
798            if boxes[j].0.label == boxes[i].0.label
799                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
800            {
801                boxes[j].0.score = min_val;
802            }
803        }
804        survivors += 1;
805        if survivors >= cap {
806            break;
807        }
808    }
809    boxes
810        .into_iter()
811        .filter(|b| b.0.score > min_val)
812        .take(cap)
813        .collect()
814}
815
816/// Quantizes a score from f32 to the given integer type, using the following
817/// formula `(score/quant.scale + quant.zero_point).ceil()`, then clamping to
818/// the min and max value of the given integer type
819///
820/// # Examples
821/// ```rust
822/// use edgefirst_decoder::{Quantization, byte::quantize_score_threshold};
823/// let quant = Quantization {
824///     scale: 0.1,
825///     zero_point: 128,
826/// };
827/// let q: u8 = quantize_score_threshold::<u8>(0.5, quant);
828/// assert_eq!(q, 128 + 5);
829/// ```
830#[doc(hidden)]
831pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
832where
833    f32: AsPrimitive<T>,
834{
835    if quant.scale == 0.0 {
836        return T::max_value();
837    }
838    let v = (score / quant.scale + quant.zero_point as f32).ceil();
839    let v = v.clamp(T::min_value().as_(), T::max_value().as_());
840    v.as_()
841}
842
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Verify that the column-major path produces identical results to the
    /// row-major path for a transposed (non-contiguous) score array.
    #[test]
    fn column_major_matches_row_major() {
        // Create scores in "model output" layout: [num_classes, num_candidates]
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let mut scores_physical = Array2::<u8>::zeros((n_classes, n_candidates));
        // Fill with known pattern: class c, candidate i → (c * 3 + i * 7) % 256
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c * 3 + i * 7) % 256) as u8;
            }
        }

        // Create boxes: [4, num_candidates] i16
        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16; // x
            boxes_physical[[1, i]] = (i * 20) as i16; // y
            boxes_physical[[2, i]] = (i * 10 + 50) as i16; // w
            boxes_physical[[3, i]] = (i * 20 + 100) as i16; // h
        }

        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };

        let threshold: u8 = 10;

        // Row-major path: contiguous [n_candidates, n_classes] array.
        // NOTE: the previous `clone().reversed_axes().to_owned()` kept the
        // reversed (column-major) strides, because ndarray's `to_owned`
        // preserves the memory layout of contiguous inputs — so the
        // "row-major" array was never actually row-major and both branches of
        // this test exercised the column-major path. Force standard (C)
        // layout explicitly and assert it.
        let scores_contiguous = scores_physical.t().as_standard_layout().into_owned();
        let boxes_contiguous = boxes_physical.t().as_standard_layout().into_owned();
        assert!(
            scores_contiguous.as_slice().is_some(),
            "row-major input should be contiguous"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        // Column-major path: non-contiguous reversed view
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        // Both paths should produce the same results
        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    /// Test column-major path with i8 scores (signed, matches NEON argmax path).
    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let mut scores_physical = Array2::<i8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8;
            }
        }

        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16;
            boxes_physical[[1, i]] = (i * 20) as i16;
            boxes_physical[[2, i]] = (i * 10 + 50) as i16;
            boxes_physical[[3, i]] = (i * 20 + 100) as i16;
        }

        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        // Force standard (C) layout so this branch really takes the
        // row-major path; see the layout note in
        // `column_major_matches_row_major`.
        let scores_contiguous = scores_physical.t().as_standard_layout().into_owned();
        let boxes_contiguous = boxes_physical.t().as_standard_layout().into_owned();
        assert!(
            scores_contiguous.as_slice().is_some(),
            "row-major input should be contiguous"
        );
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }

    /// Helper: create `n` non-overlapping boxes with descending u8 scores.
    ///
    /// Scores saturate at 0 for `n > 200` instead of underflowing.
    fn make_nms_boxes_int(n: usize) -> Vec<DetectBoxQuantized<u8>> {
        (0..n)
            .map(|i| DetectBoxQuantized {
                bbox: BoundingBox {
                    xmin: i as f32 * 100.0,
                    ymin: 0.0,
                    xmax: i as f32 * 100.0 + 10.0,
                    ymax: 10.0,
                },
                label: 0,
                // `saturating_sub` avoids a debug-build overflow panic if a
                // caller ever requests more than 200 boxes.
                score: 200u32.saturating_sub(i as u32).min(255) as u8,
            })
            .collect()
    }

    /// With a cap, the result must be the prefix of the uncapped result.
    #[test]
    fn nms_int_max_det_matches_full_truncated() {
        let boxes = make_nms_boxes_int(20);
        let n = 5;
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(n), boxes);
        assert_eq!(capped.len(), n);
        assert_eq!(&full[..n], &capped[..]);
    }

    /// `max_det == 0` is a valid request for zero detections.
    #[test]
    fn nms_int_max_det_zero_returns_empty() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(0.5, Some(0), boxes);
        assert!(result.is_empty());
    }

    /// `iou >= 1.0` skips suppression but still sorts and truncates.
    #[test]
    fn nms_int_max_det_iou_ge_1_returns_sorted_truncated() {
        let boxes = make_nms_boxes_int(10);
        let result = nms_int(1.0, Some(3), boxes);
        assert_eq!(result.len(), 3);
        // Scores should be in descending order (sorted).
        assert!(result[0].score >= result[1].score);
        assert!(result[1].score >= result[2].score);
    }

    /// A cap larger than the input must behave like no cap at all.
    #[test]
    fn nms_int_max_det_larger_than_input() {
        let boxes = make_nms_boxes_int(5);
        let full = nms_int(0.5, None, boxes.clone());
        let capped = nms_int(0.5, Some(100), boxes);
        assert_eq!(full.len(), capped.len());
    }
}