Skip to main content

edgefirst_decoder/
byte.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use crate::{
5    arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
6};
7use ndarray::{
8    parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
9    Array1, ArrayView2, Zip,
10};
11use num_traits::{AsPrimitive, PrimInt};
12use rayon::slice::ParallelSliceMut;
13
14/// Post processes boxes and scores tensors into quantized detection boxes,
15/// filtering out any boxes below the score threshold. The boxes tensor
16/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
17/// is preserved.
18#[doc(hidden)]
19pub fn postprocess_boxes_quant<
20    B: BBoxTypeTrait,
21    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
22    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
23>(
24    threshold: Scores,
25    boxes: ArrayView2<Boxes>,
26    scores: ArrayView2<Scores>,
27    quant_boxes: Quantization,
28) -> Vec<DetectBoxQuantized<Scores>> {
29    assert_eq!(scores.dim().0, boxes.dim().0);
30    assert_eq!(boxes.dim().1, 4);
31    Zip::from(scores.rows())
32        .and(boxes.rows())
33        .into_par_iter()
34        .filter_map(|(score, bbox)| {
35            let (score_, label) = arg_max(score);
36            if score_ < threshold {
37                return None;
38            }
39
40            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
41            Some(DetectBoxQuantized {
42                label,
43                score: score_,
44                bbox: BoundingBox::from(bbox_quant),
45            })
46        })
47        .collect()
48}
49
50/// Post processes boxes and scores tensors into quantized detection boxes,
51/// filtering out any boxes below the score threshold. The boxes tensor
52/// is converted to XYXY using the given BBoxTypeTrait. The order of the boxes
53/// is preserved.
54///
55/// This function is very similar to `postprocess_boxes_quant` but will also
56/// return the index of the box. The boxes will be in ascending index order.
57#[doc(hidden)]
58pub fn postprocess_boxes_index_quant<
59    B: BBoxTypeTrait,
60    Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
61    Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
62>(
63    threshold: Scores,
64    boxes: ArrayView2<Boxes>,
65    scores: ArrayView2<Scores>,
66    quant_boxes: Quantization,
67) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
68    assert_eq!(scores.dim().0, boxes.dim().0);
69    assert_eq!(boxes.dim().1, 4);
70    let indices: Array1<usize> = (0..boxes.dim().0).collect();
71    Zip::from(scores.rows())
72        .and(boxes.rows())
73        .and(&indices)
74        .into_par_iter()
75        .filter_map(|(score, bbox, index)| {
76            let (score_, label) = arg_max(score);
77            if score_ < threshold {
78                return None;
79            }
80
81            let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
82
83            Some((
84                DetectBoxQuantized {
85                    label,
86                    score: score_,
87                    bbox: BoundingBox::from(bbox_quant),
88                },
89                *index,
90            ))
91        })
92        .collect()
93}
94
95/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
96/// then greedily selects a subset of boxes in descending order of score.
97#[doc(hidden)]
98#[must_use]
99pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
100    iou: f32,
101    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
102) -> Vec<DetectBoxQuantized<SCORE>> {
103    // Boxes get sorted by score in descending order so we know based on the
104    // index the scoring of the boxes and can skip parts of the loop.
105
106    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
107
108    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
109    // immediately
110    if iou >= 1.0 {
111        return boxes;
112    }
113
114    let min_val = SCORE::min_value();
115    // Outer loop over all boxes.
116    for i in 0..boxes.len() {
117        if boxes[i].score <= min_val {
118            // this box was merged with a different box earlier
119            continue;
120        }
121        for j in (i + 1)..boxes.len() {
122            // Inner loop over boxes with lower score (later in the list).
123
124            if boxes[j].score <= min_val {
125                // this box was suppressed by different box earlier
126                continue;
127            }
128
129            if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
130                // suppress this box
131                boxes[j].score = min_val;
132            }
133        }
134    }
135    // Filter out boxes that were suppressed.
136    boxes.into_iter().filter(|b| b.score > min_val).collect()
137}
138
139/// Uses NMS to filter boxes based on the score and iou. Sorts boxes by score,
140/// then greedily selects a subset of boxes in descending order of score.
141///
142/// This is same as `nms_int` but will also include extra information along
143/// with each box, such as the index
144#[doc(hidden)]
145#[must_use]
146pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
147    iou: f32,
148    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
149) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
150    // Boxes get sorted by score in descending order so we know based on the
151    // index the scoring of the boxes and can skip parts of the loop.
152    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
153
154    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
155    // immediately
156    if iou >= 1.0 {
157        return boxes;
158    }
159
160    let min_val = SCORE::min_value();
161    // Outer loop over all boxes.
162    for i in 0..boxes.len() {
163        if boxes[i].0.score <= min_val {
164            // this box was merged with a different box earlier
165            continue;
166        }
167        for j in (i + 1)..boxes.len() {
168            // Inner loop over boxes with lower score (later in the list).
169
170            if boxes[j].0.score <= min_val {
171                // this box was suppressed by different box earlier
172                continue;
173            }
174            if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
175                // suppress this box
176                boxes[j].0.score = min_val;
177            }
178        }
179    }
180
181    // Filter out boxes that were suppressed.
182    boxes.into_iter().filter(|b| b.0.score > min_val).collect()
183}
184
185/// Class-aware NMS for quantized boxes: only suppress boxes with the same
186/// label.
187///
188/// Sorts boxes by score, then greedily selects a subset of boxes in descending
189/// order of score. Unlike class-agnostic NMS, boxes are only suppressed if they
190/// have the same class label AND overlap above the IoU threshold.
191#[doc(hidden)]
192#[must_use]
193pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
194    iou: f32,
195    mut boxes: Vec<DetectBoxQuantized<SCORE>>,
196) -> Vec<DetectBoxQuantized<SCORE>> {
197    boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
198
199    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
200    // immediately
201    if iou >= 1.0 {
202        return boxes;
203    }
204
205    let min_val = SCORE::min_value();
206    for i in 0..boxes.len() {
207        if boxes[i].score <= min_val {
208            continue;
209        }
210        for j in (i + 1)..boxes.len() {
211            if boxes[j].score <= min_val {
212                continue;
213            }
214            // Only suppress if same class AND overlapping
215            if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
216                boxes[j].score = min_val;
217            }
218        }
219    }
220    boxes.into_iter().filter(|b| b.score > min_val).collect()
221}
222
223/// Class-aware NMS for quantized boxes with extra data: only suppress boxes
224/// with the same label.
225///
226/// This is same as `nms_class_aware_int` but will also include extra
227/// information along with each box, such as the index.
228#[doc(hidden)]
229#[must_use]
230pub fn nms_extra_class_aware_int<
231    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
232    E: Send + Sync,
233>(
234    iou: f32,
235    mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
236) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
237    boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
238
239    // When the iou is 1.0 or larger, no boxes will be filtered so we just return
240    // immediately
241    if iou >= 1.0 {
242        return boxes;
243    }
244
245    let min_val = SCORE::min_value();
246    for i in 0..boxes.len() {
247        if boxes[i].0.score <= min_val {
248            continue;
249        }
250        for j in (i + 1)..boxes.len() {
251            if boxes[j].0.score <= min_val {
252                continue;
253            }
254            // Only suppress if same class AND overlapping
255            if boxes[j].0.label == boxes[i].0.label
256                && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
257            {
258                boxes[j].0.score = min_val;
259            }
260        }
261    }
262    boxes.into_iter().filter(|b| b.0.score > min_val).collect()
263}
264
265/// Quantizes a score from f32 to the given integer type, using the following
266/// formula `(score/quant.scale + quant.zero_point).ceil()`, then clamping to
267/// the min and max value of the given integer type
268///
269/// # Examples
270/// ```rust
271/// use edgefirst_decoder::{Quantization, byte::quantize_score_threshold};
272/// let quant = Quantization {
273///     scale: 0.1,
274///     zero_point: 128,
275/// };
276/// let q: u8 = quantize_score_threshold::<u8>(0.5, quant);
277/// assert_eq!(q, 128 + 5);
278/// ```
279#[doc(hidden)]
280pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
281where
282    f32: AsPrimitive<T>,
283{
284    if quant.scale == 0.0 {
285        return T::max_value();
286    }
287    let v = (score / quant.scale + quant.zero_point as f32).ceil();
288    let v = v.clamp(T::min_value().as_(), T::max_value().as_());
289    v.as_()
290}