// edgefirst_decoder/byte.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4#[cfg(target_arch = "aarch64")]
5use crate::arg_max_i8;
6use crate::{
7    arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
8};
9use ndarray::{
10    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
11    Array1, ArrayView1, ArrayView2, Zip,
12};
13use num_traits::{AsPrimitive, PrimInt};
14use rayon::slice::ParallelSliceMut;
15
/// NEON-accelerated column max update for the column-major argmax path.
///
/// Processes 16 elements per iteration using SIMD max + bitwise select.
/// Handles both unsigned (u8) and signed (i8) comparison semantics.
///
/// Ties use `>=` in both the SIMD (`vcgeq_*`) and scalar-tail paths, so when a
/// later class matches the running maximum it takes over the class slot.
///
/// # Safety
///
/// - `col_ptr` must point to at least `n` valid bytes.
/// - `max_ptr` must point to at least `n` valid mutable bytes.
/// - `class_ptr` must point to at least `n` valid mutable bytes.
#[cfg(target_arch = "aarch64")]
unsafe fn column_max_update_neon(
    col_ptr: *const u8,
    max_ptr: *mut u8,
    class_ptr: *mut u8,
    n: usize,
    class_idx: u8,
    signed: bool,
) {
    use std::arch::aarch64::*;

    let class_vec = vdupq_n_u8(class_idx);
    let chunks = n / 16;

    if signed {
        // Signed i8 comparison: interpret bytes as i8.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_s8(col_ptr.add(offset) as *const i8);
            let cur_max = vld1q_s8(max_ptr.add(offset) as *const i8);
            // mask[i] = 0xFF where col[i] >= cur_max[i], else 0x00
            let mask = vcgeq_s8(col, cur_max);
            // new_max = max(col, cur_max)
            let new_max = vmaxq_s8(col, cur_max);
            vst1q_s8(max_ptr.add(offset) as *mut i8, new_max);
            // Select class_idx where mask is set, keep old class otherwise.
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *(col_ptr.add(i) as *const i8);
            let cur = *(max_ptr.add(i) as *const i8);
            if val >= cur {
                *(max_ptr.add(i) as *mut i8) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    } else {
        // Unsigned u8 comparison.
        for chunk in 0..chunks {
            let offset = chunk * 16;
            let col = vld1q_u8(col_ptr.add(offset));
            let cur_max = vld1q_u8(max_ptr.add(offset));
            let mask = vcgeq_u8(col, cur_max);
            let new_max = vmaxq_u8(col, cur_max);
            vst1q_u8(max_ptr.add(offset), new_max);
            let cur_class = vld1q_u8(class_ptr.add(offset));
            let new_class = vbslq_u8(mask, class_vec, cur_class);
            vst1q_u8(class_ptr.add(offset), new_class);
        }
        // Scalar tail for the final n % 16 elements.
        for i in (chunks * 16)..n {
            let val = *col_ptr.add(i);
            let cur = *max_ptr.add(i);
            if val >= cur {
                *max_ptr.add(i) = val;
                *class_ptr.add(i) = class_idx;
            }
        }
    }
}
91
92/// Fast argmax dispatching to NEON-optimized path for i8 on aarch64.
93#[inline(always)]
94fn fast_arg_max<T: PrimInt + Copy>(score: ArrayView1<T>) -> (T, usize) {
95    #[cfg(target_arch = "aarch64")]
96    {
97        // Check if this is an i8 slice and contiguous.
98        if std::mem::size_of::<T>() == 1 && score.as_slice().is_some() {
99            let slice = score.as_slice().unwrap();
100            // Safety: T is i8 when size_of::<T>() == 1 and PrimInt.
101            // PrimInt covers i8, u8, i16, etc. We only want to use the
102            // i8 NEON path for signed i8.
103            let ptr = slice.as_ptr() as *const i8;
104            let i8_slice = unsafe { std::slice::from_raw_parts(ptr, slice.len()) };
105            // Only valid for signed i8 (not u8). Check sign bit behavior:
106            // PrimInt for i8 means min_value() is negative.
107            if T::min_value() < T::zero() {
108                let (max_val, idx) = arg_max_i8(i8_slice);
109                // Safety: transmute i8 back to T (they have the same size and
110                // representation for i8).
111                let result: T = unsafe { std::mem::transmute_copy(&max_val) };
112                return (result, idx);
113            }
114        }
115    }
116    arg_max(score)
117}
118
119/// Post processes boxes and scores tensors into quantized detection boxes,
120/// filtering out any boxes below the score threshold. The boxes tensor
121/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
122/// is preserved.
123#[doc(hidden)]
124pub fn postprocess_boxes_quant<
125    B: BBoxTypeTrait,
126    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
127    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
128>(
129    threshold: Scores,
130    boxes: ArrayView2<Boxes>,
131    scores: ArrayView2<Scores>,
132    quant_boxes: Quantization,
133) -> Vec<DetectBoxQuantized<Scores>> {
134    assert_eq!(scores.dim().0, boxes.dim().0);
135    assert_eq!(boxes.dim().1, 4);
136
137    // Use column-major path for transposed DMA-BUF views (see postprocess_boxes_index_quant).
138    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
139        return postprocess_boxes_quant_column_major::<B, _, _>(
140            threshold,
141            boxes,
142            scores,
143            quant_boxes,
144        );
145    }
146
147    Zip::from(scores.rows())
148        .and(boxes.rows())
149        .into_par_iter()
150        .filter_map(|(score, bbox)| {
151            let (score_, label) = fast_arg_max(score);
152            if score_ < threshold {
153                return None;
154            }
155
156            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
157            Some(DetectBoxQuantized {
158                label,
159                score: score_,
160                bbox: BoundingBox::from(bbox_quant),
161            })
162        })
163        .collect()
164}
165
166/// Column-major optimized path for `postprocess_boxes_quant`.
167fn postprocess_boxes_quant_column_major<
168    B: BBoxTypeTrait,
169    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
170    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
171>(
172    threshold: Scores,
173    boxes: ArrayView2<Boxes>,
174    scores: ArrayView2<Scores>,
175    quant_boxes: Quantization,
176) -> Vec<DetectBoxQuantized<Scores>> {
177    let (n_candidates, n_classes) = scores.dim();
178
179    // Column-major NEON path uses u8 class indices; fall back for >255 classes.
180    if n_classes > 255 {
181        return Zip::from(scores.rows())
182            .and(boxes.rows())
183            .into_par_iter()
184            .filter_map(|(score, bbox)| {
185                let (score_, label) = fast_arg_max(score);
186                if score_ < threshold {
187                    return None;
188                }
189                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
190                Some(DetectBoxQuantized {
191                    label,
192                    score: score_,
193                    bbox: BoundingBox::from(bbox_quant),
194                })
195            })
196            .collect();
197    }
198    let mut max_scores = vec![Scores::min_value(); n_candidates];
199    let mut max_classes = vec![0u8; n_candidates];
200
201    for class_idx in 0..n_classes {
202        let col = scores.column(class_idx);
203        if let Some(slice) = col.as_slice() {
204            #[cfg(target_arch = "aarch64")]
205            {
206                if std::mem::size_of::<Scores>() == 1 {
207                    unsafe {
208                        column_max_update_neon(
209                            slice.as_ptr() as *const u8,
210                            max_scores.as_mut_ptr() as *mut u8,
211                            max_classes.as_mut_ptr(),
212                            n_candidates,
213                            class_idx as u8,
214                            Scores::min_value() < Scores::zero(),
215                        );
216                    }
217                    continue;
218                }
219            }
220            for (i, &val) in slice.iter().enumerate() {
221                if val >= max_scores[i] {
222                    max_scores[i] = val;
223                    max_classes[i] = class_idx as u8;
224                }
225            }
226        } else {
227            for (i, &val) in col.iter().enumerate() {
228                if val >= max_scores[i] {
229                    max_scores[i] = val;
230                    max_classes[i] = class_idx as u8;
231                }
232            }
233        }
234    }
235
236    // Copy boxes column-by-column if also transposed.
237    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
238        let mut cols: [Vec<Boxes>; 4] = [
239            vec![Boxes::zero(); n_candidates],
240            vec![Boxes::zero(); n_candidates],
241            vec![Boxes::zero(); n_candidates],
242            vec![Boxes::zero(); n_candidates],
243        ];
244        for (dim, col_buf) in cols.iter_mut().enumerate() {
245            let col = boxes.column(dim);
246            if let Some(slice) = col.as_slice() {
247                col_buf.copy_from_slice(slice);
248            } else {
249                for (i, &val) in col.iter().enumerate() {
250                    col_buf[i] = val;
251                }
252            }
253        }
254        cols
255    } else {
256        [vec![], vec![], vec![], vec![]]
257    };
258    let boxes_copied = !boxes_buf[0].is_empty();
259
260    let mut result = Vec::new();
261    for i in 0..n_candidates {
262        if max_scores[i] >= threshold {
263            let bbox_quant = if boxes_copied {
264                let raw = [
265                    boxes_buf[0][i],
266                    boxes_buf[1][i],
267                    boxes_buf[2][i],
268                    boxes_buf[3][i],
269                ];
270                B::to_xyxy_dequant(&raw, quant_boxes)
271            } else {
272                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
273            };
274            result.push(DetectBoxQuantized {
275                label: max_classes[i] as usize,
276                score: max_scores[i],
277                bbox: BoundingBox::from(bbox_quant),
278            });
279        }
280    }
281
282    result
283}
284
285/// Post processes boxes and scores tensors into quantized detection boxes,
286/// filtering out any boxes below the score threshold. The boxes tensor
287/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
288/// is preserved.
289///
290/// This function is very similar to `postprocess_boxes_quant` but will also
291/// return the index of the box. The boxes will be in ascending index order.
292///
293/// When scores originate from a transposed DMA-BUF view (stride-1 along axis 0),
294/// an optimized column-major scan is used to avoid catastrophic strided reads on
295/// uncacheable memory.
296#[doc(hidden)]
297pub fn postprocess_boxes_index_quant<
298    B: BBoxTypeTrait,
299    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
300    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
301>(
302    threshold: Scores,
303    boxes: ArrayView2<Boxes>,
304    scores: ArrayView2<Scores>,
305    quant_boxes: Quantization,
306) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
307    assert_eq!(scores.dim().0, boxes.dim().0);
308    assert_eq!(boxes.dim().1, 4);
309
310    // Detect transposed C-contiguous layout (e.g., from reversed_axes() on DMA-BUF).
311    // In this layout columns are contiguous (stride-1 along axis 0) but rows are not.
312    // The column-major path reads memory sequentially, avoiding cache-hostile strides.
313    if scores.strides()[0] == 1 && scores.as_slice().is_none() {
314        return postprocess_boxes_index_quant_column_major::<B, _, _>(
315            threshold,
316            boxes,
317            scores,
318            quant_boxes,
319        );
320    }
321
322    let indices: Array1<usize> = (0..boxes.dim().0).collect();
323    Zip::from(scores.rows())
324        .and(boxes.rows())
325        .and(&indices)
326        .into_par_iter()
327        .filter_map(|(score, bbox, index)| {
328            let (score_, label) = fast_arg_max(score);
329            if score_ < threshold {
330                return None;
331            }
332
333            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
334
335            Some((
336                DetectBoxQuantized {
337                    label,
338                    score: score_,
339                    bbox: BoundingBox::from(bbox_quant),
340                },
341                *index,
342            ))
343        })
344        .collect()
345}
346
/// Column-major optimized path for `postprocess_boxes_index_quant`.
///
/// When scores come from a transposed DMA-BUF view ([N_candidates, N_classes]
/// with strides [1, N_candidates]), row iteration causes N_classes reads each
/// N_candidates bytes apart — catastrophic on uncacheable memory. Instead, this
/// iterates over classes (columns, which are contiguous), maintaining a running
/// argmax per candidate in cacheable heap buffers.
///
/// Returns `(detection, candidate_index)` pairs in ascending candidate order.
fn postprocess_boxes_index_quant_column_major<
    B: BBoxTypeTrait,
    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
>(
    threshold: Scores,
    boxes: ArrayView2<Boxes>,
    scores: ArrayView2<Scores>,
    quant_boxes: Quantization,
) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
    let (n_candidates, n_classes) = scores.dim();

    // Phase 1: Column-based argmax — sequential reads over contiguous columns.
    // Use u8 for class indices (max 255 classes) for optimal NEON vectorization.
    // Fall back to row-major path for models with >255 classes.
    if n_classes > 255 {
        let indices: Array1<usize> = (0..n_candidates).collect();
        return Zip::from(scores.rows())
            .and(boxes.rows())
            .and(&indices)
            .into_par_iter()
            .filter_map(|(score, bbox, index)| {
                let (score_, label) = fast_arg_max(score);
                if score_ < threshold {
                    return None;
                }
                let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
                Some((
                    DetectBoxQuantized {
                        label,
                        score: score_,
                        bbox: BoundingBox::from(bbox_quant),
                    },
                    *index,
                ))
            })
            .collect();
    }
    // Running per-candidate maximum; min_value() guarantees any real score
    // (which uses >=) replaces the initial value, including on the first class.
    let mut max_scores = vec![Scores::min_value(); n_candidates];
    let mut max_classes = vec![0u8; n_candidates];
    for class_idx in 0..n_classes {
        let col = scores.column(class_idx);
        if let Some(slice) = col.as_slice() {
            // Use NEON-accelerated column max update on aarch64 for u8 scores.
            #[cfg(target_arch = "aarch64")]
            {
                if std::mem::size_of::<Scores>() == 1 {
                    // SAFETY: Scores is u8 or i8 (size == 1). We transmute the
                    // slice pointers to the concrete byte type for NEON processing.
                    unsafe {
                        column_max_update_neon(
                            slice.as_ptr() as *const u8,
                            max_scores.as_mut_ptr() as *mut u8,
                            max_classes.as_mut_ptr(),
                            n_candidates,
                            class_idx as u8,
                            Scores::min_value() < Scores::zero(), // signed flag
                        );
                    }
                    continue;
                }
            }
            // Scalar fallback for wider score types on aarch64 and for all
            // types on other architectures.
            // NOTE(review): `>=` means the highest class index wins score ties
            // (matching the NEON kernel's vcgeq semantics); confirm this agrees
            // with `arg_max`'s tie-breaking in the row-major path.
            for (i, &val) in slice.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        } else {
            // Non-contiguous column (unexpected for the targeted layout) —
            // element-wise iteration still produces the same running argmax.
            for (i, &val) in col.iter().enumerate() {
                if val >= max_scores[i] {
                    max_scores[i] = val;
                    max_classes[i] = class_idx as u8;
                }
            }
        }
    }

    // Phase 2: Copy boxes column-by-column into contiguous heap buffer.
    // Boxes view is also transposed [N_candidates, 4] with strides [1, N_candidates],
    // so column reads are sequential while row reads are strided.
    let boxes_buf: [Vec<Boxes>; 4] = if boxes.strides()[0] == 1 && boxes.as_slice().is_none() {
        let mut cols: [Vec<Boxes>; 4] = [
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
            vec![Boxes::zero(); n_candidates],
        ];
        for (dim, col_buf) in cols.iter_mut().enumerate() {
            let col = boxes.column(dim);
            if let Some(slice) = col.as_slice() {
                col_buf.copy_from_slice(slice);
            } else {
                for (i, &val) in col.iter().enumerate() {
                    col_buf[i] = val;
                }
            }
        }
        cols
    } else {
        // Boxes are contiguous or differently strided — read per-candidate below.
        [vec![], vec![], vec![], vec![]]
    };
    // Empty first buffer marks the "read rows directly" fallback chosen above.
    let boxes_copied = !boxes_buf[0].is_empty();

    // Phase 3: Threshold filter — collect candidates that pass.
    let mut result = Vec::new();
    for i in 0..n_candidates {
        if max_scores[i] >= threshold {
            let bbox_quant = if boxes_copied {
                let raw = [
                    boxes_buf[0][i],
                    boxes_buf[1][i],
                    boxes_buf[2][i],
                    boxes_buf[3][i],
                ];
                B::to_xyxy_dequant(&raw, quant_boxes)
            } else {
                B::ndarray_to_xyxy_dequant(boxes.row(i), quant_boxes)
            };
            result.push((
                DetectBoxQuantized {
                    label: max_classes[i] as usize,
                    score: max_scores[i],
                    bbox: BoundingBox::from(bbox_quant),
                },
                i,
            ));
        }
    }

    result
}
487
488/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
489/// then greedily selects a subset of boxes in descending order of score.
490#[doc(hidden)]
491#[must_use]
492pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
493    iou: f32,
494    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
495) -> Vec<DetectBoxQuantized<SCORE>> {
496    // Boxes get sorted by score in descending order so we know based on the
497    // index the scoring of the boxes and can skip parts of the loop.
498
499    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
500
501    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
502    // immediately
503    if iou >= 1.0 {
504        return boxes;
505    }
506
507    let min_val = SCORE::min_value();
508    // Outer loop over all boxes.
509    for i in 0..boxes.len() {
510        if boxes[i].score <= min_val {
511            // this box was merged with a different box earlier
512            continue;
513        }
514        for j in (i + 1)..boxes.len() {
515            // Inner loop over boxes with lower score (later in the list).
516
517            if boxes[j].score <= min_val {
518                // this box was suppressed by different box earlier
519                continue;
520            }
521
522            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
523                // suppress this box
524                boxes[j].score = min_val;
525            }
526        }
527    }
528    // Filter out boxes that were suppressed.
529    boxes.into_iter().filter(|b| b.score > min_val).collect()
530}
531
532/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
533/// then greedily selects a subset of boxes in descending order of score.
534///
535/// This is same as `nms_int` but will also include extra information along
536/// with each box, such as the index
537#[doc(hidden)]
538#[must_use]
539pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
540    iou: f32,
541    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
542) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
543    // Boxes get sorted by score in descending order so we know based on the
544    // index the scoring of the boxes and can skip parts of the loop.
545    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
546
547    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
548    // immediately
549    if iou >= 1.0 {
550        return boxes;
551    }
552
553    let min_val = SCORE::min_value();
554    // Outer loop over all boxes.
555    for i in 0..boxes.len() {
556        if boxes[i].0.score <= min_val {
557            // this box was merged with a different box earlier
558            continue;
559        }
560        for j in (i + 1)..boxes.len() {
561            // Inner loop over boxes with lower score (later in the list).
562
563            if boxes[j].0.score <= min_val {
564                // this box was suppressed by different box earlier
565                continue;
566            }
567            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
568                // suppress this box
569                boxes[j].0.score = min_val;
570            }
571        }
572    }
573
574    // Filter out boxes that were suppressed.
575    boxes.into_iter().filter(|b| b.0.score > min_val).collect()
576}
577
578/// Class-aware NMS for quantized boxes: only suppress boxes with the same
579/// label.
580///
581/// Sorts boxes by score, then greedily selects a subset of boxes in descending
582/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
583/// have the same class label AND overlap above the IoU threshold.
584#[doc(hidden)]
585#[must_use]
586pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
587    iou: f32,
588    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
589) -> Vec<DetectBoxQuantized<SCORE>> {
590    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
591
592    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
593    // immediately
594    if iou >= 1.0 {
595        return boxes;
596    }
597
598    let min_val = SCORE::min_value();
599    for i in 0..boxes.len() {
600        if boxes[i].score <= min_val {
601            continue;
602        }
603        for j in (i + 1)..boxes.len() {
604            if boxes[j].score <= min_val {
605                continue;
606            }
607            // Only suppress if same class AND overlapping
608            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
609                boxes[j].score = min_val;
610            }
611        }
612    }
613    boxes.into_iter().filter(|b| b.score > min_val).collect()
614}
615
616/// Class-aware NMS for quantized boxes with extra data: only suppress boxes
617/// with the same label.
618///
619/// This is same as `nms_class_aware_int` but will also include extra
620/// information along with each box, such as the index.
621#[doc(hidden)]
622#[must_use]
623pub fn nms_extra_class_aware_int<
624    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
625    E: Send + Sync,
626>(
627    iou: f32,
628    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
629) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
630    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
631
632    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
633    // immediately
634    if iou >= 1.0 {
635        return boxes;
636    }
637
638    let min_val = SCORE::min_value();
639    for i in 0..boxes.len() {
640        if boxes[i].0.score <= min_val {
641            continue;
642        }
643        for j in (i + 1)..boxes.len() {
644            if boxes[j].0.score <= min_val {
645                continue;
646            }
647            // Only suppress if same class AND overlapping
648            if boxes[j].0.label == boxes[i].0.label
649                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
650            {
651                boxes[j].0.score = min_val;
652            }
653        }
654    }
655    boxes.into_iter().filter(|b| b.0.score > min_val).collect()
656}
657
658/// Quantizes a score from f32 to the given integer type, using the following
659/// formula `(score/quant.scale + quant.zero_point).ceil()`, then clamping to
660/// the min and max value of the given integer type
661///
662/// # Examples
663/// ```rust
664/// use edgefirst_decoder::{Quantization, byte::quantize_score_threshold};
665/// let quant = Quantization {
666///     scale: 0.1,
667///     zero_point: 128,
668/// };
669/// let q: u8 = quantize_score_threshold::<u8>(0.5, quant);
670/// assert_eq!(q, 128 + 5);
671/// ```
672#[doc(hidden)]
673pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
674where
675    f32: AsPrimitive<T>,
676{
677    if quant.scale == 0.0 {
678        return T::max_value();
679    }
680    let v = (score / quant.scale + quant.zero_point as f32).ceil();
681    let v = v.clamp(T::min_value().as_(), T::max_value().as_());
682    v.as_()
683}
684
#[cfg(test)]
mod tests {
    use super::*;
    use crate::XYWH;
    use ndarray::Array2;

    /// Verify that the column-major path produces identical results to the
    /// row-major path for a transposed (non-contiguous) score array.
    #[test]
    fn column_major_matches_row_major() {
        // Create scores in "model output" layout: [num_classes, num_candidates]
        let n_classes = 80usize;
        let n_candidates = 100usize;
        let mut scores_physical = Array2::<u8>::zeros((n_classes, n_candidates));
        // Fill with known pattern: class c, candidate i → (c * 3 + i * 7) % 256
        for c in 0..n_classes {
            for i in 0..n_candidates {
                scores_physical[[c, i]] = ((c * 3 + i * 7) % 256) as u8;
            }
        }

        // Create boxes: [4, num_candidates] i16
        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16; // x
            boxes_physical[[1, i]] = (i * 20) as i16; // y
            boxes_physical[[2, i]] = (i * 10 + 50) as i16; // w
            boxes_physical[[3, i]] = (i * 20 + 100) as i16; // h
        }

        let quant = Quantization {
            scale: 0.00390625,
            zero_point: 0,
        };

        let threshold: u8 = 10;

        // Row-major path: `reversed_axes()` on an owned clone plus `to_owned()`
        // materializes a contiguous [n_candidates, n_classes] array, so the
        // dispatcher takes the parallel row-major branch.
        let scores_contiguous = scores_physical.clone().reversed_axes().to_owned();
        let boxes_contiguous = boxes_physical.clone().reversed_axes().to_owned();
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        // Column-major path: a reversed *view* keeps the original storage, so
        // axis 0 has stride 1 and the dispatcher takes the column-major branch.
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        assert!(scores_view.as_slice().is_none(), "should be non-contiguous");
        assert_eq!(scores_view.strides()[0], 1);
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        // Both paths should produce the same results
        assert_eq!(
            row_result.len(),
            col_result.len(),
            "different number of results: row={}, col={}",
            row_result.len(),
            col_result.len()
        );
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(
                row.0.label, col.0.label,
                "candidate {i}: label mismatch row={} col={}",
                row.0.label, col.0.label
            );
            assert_eq!(row.0.score, col.0.score, "candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "candidate {i}: index mismatch");
            assert_eq!(row.0.bbox, col.0.bbox, "candidate {i}: bbox mismatch");
        }
    }

    /// Test column-major path with i8 scores (signed, matches NEON argmax path).
    #[test]
    fn column_major_matches_row_major_i8() {
        let n_classes = 80usize;
        let n_candidates = 50usize;
        let mut scores_physical = Array2::<i8>::zeros((n_classes, n_candidates));
        for c in 0..n_classes {
            for i in 0..n_candidates {
                // Pattern shifted into i8 range: (c*3 + i*7) % 256 - 128 ∈ [-128, 127].
                scores_physical[[c, i]] = ((c as i16 * 3 + i as i16 * 7) % 256 - 128) as i8;
            }
        }

        let mut boxes_physical = Array2::<i16>::zeros((4, n_candidates));
        for i in 0..n_candidates {
            boxes_physical[[0, i]] = (i * 10) as i16;
            boxes_physical[[1, i]] = (i * 20) as i16;
            boxes_physical[[2, i]] = (i * 10 + 50) as i16;
            boxes_physical[[3, i]] = (i * 20 + 100) as i16;
        }

        let quant = Quantization {
            scale: 0.0256,
            zero_point: -116,
        };
        let threshold: i8 = -100;

        // Contiguous copies exercise the row-major branch (reference results).
        let scores_contiguous = scores_physical.clone().reversed_axes().to_owned();
        let boxes_contiguous = boxes_physical.clone().reversed_axes().to_owned();
        let row_result = postprocess_boxes_index_quant::<XYWH, _, _>(
            threshold,
            boxes_contiguous.view(),
            scores_contiguous.view(),
            quant,
        );

        // Reversed views exercise the column-major branch.
        let scores_view = scores_physical.view().reversed_axes();
        let boxes_view = boxes_physical.view().reversed_axes();
        let col_result =
            postprocess_boxes_index_quant::<XYWH, _, _>(threshold, boxes_view, scores_view, quant);

        assert_eq!(row_result.len(), col_result.len());
        for (i, (row, col)) in row_result.iter().zip(col_result.iter()).enumerate() {
            assert_eq!(row.0.label, col.0.label, "i8 candidate {i}: label mismatch");
            assert_eq!(row.0.score, col.0.score, "i8 candidate {i}: score mismatch");
            assert_eq!(row.1, col.1, "i8 candidate {i}: index mismatch");
        }
    }
}