edgefirst_decoder/
byte.rs1use crate::{
5 arg_max, float::jaccard, BBoxTypeTrait, BoundingBox, DetectBoxQuantized, Quantization,
6};
7use ndarray::{
8 parallel::prelude::{IntoParallelIterator, ParallelIterator as _},
9 Array1, ArrayView2, Zip,
10};
11use num_traits::{AsPrimitive, PrimInt};
12use rayon::slice::ParallelSliceMut;
13
14#[doc(hidden)]
19pub fn postprocess_boxes_quant<
20 B: BBoxTypeTrait,
21 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
22 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
23>(
24 threshold: Scores,
25 boxes: ArrayView2<Boxes>,
26 scores: ArrayView2<Scores>,
27 quant_boxes: Quantization,
28) -> Vec<DetectBoxQuantized<Scores>> {
29 assert_eq!(scores.dim().0, boxes.dim().0);
30 assert_eq!(boxes.dim().1, 4);
31 Zip::from(scores.rows())
32 .and(boxes.rows())
33 .into_par_iter()
34 .filter_map(|(score, bbox)| {
35 let (score_, label) = arg_max(score);
36 if score_ < threshold {
37 return None;
38 }
39
40 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
41 Some(DetectBoxQuantized {
42 label,
43 score: score_,
44 bbox: BoundingBox::from(bbox_quant),
45 })
46 })
47 .collect()
48}
49
50#[doc(hidden)]
58pub fn postprocess_boxes_index_quant<
59 B: BBoxTypeTrait,
60 Boxes: PrimInt + AsPrimitive<f32> + Send + Sync,
61 Scores: PrimInt + AsPrimitive<f32> + Send + Sync,
62>(
63 threshold: Scores,
64 boxes: ArrayView2<Boxes>,
65 scores: ArrayView2<Scores>,
66 quant_boxes: Quantization,
67) -> Vec<(DetectBoxQuantized<Scores>, usize)> {
68 assert_eq!(scores.dim().0, boxes.dim().0);
69 assert_eq!(boxes.dim().1, 4);
70 let indices: Array1<usize> = (0..boxes.dim().0).collect();
71 Zip::from(scores.rows())
72 .and(boxes.rows())
73 .and(&indices)
74 .into_par_iter()
75 .filter_map(|(score, bbox, index)| {
76 let (score_, label) = arg_max(score);
77 if score_ < threshold {
78 return None;
79 }
80
81 let bbox_quant = B::ndarray_to_xyxy_dequant(bbox.view(), quant_boxes);
82
83 Some((
84 DetectBoxQuantized {
85 label,
86 score: score_,
87 bbox: BoundingBox::from(bbox_quant),
88 },
89 *index,
90 ))
91 })
92 .collect()
93}
94
95#[doc(hidden)]
98#[must_use]
99pub fn nms_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
100 iou: f32,
101 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
102) -> Vec<DetectBoxQuantized<SCORE>> {
103 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
107
108 if iou >= 1.0 {
111 return boxes;
112 }
113
114 let min_val = SCORE::min_value();
115 for i in 0..boxes.len() {
117 if boxes[i].score <= min_val {
118 continue;
120 }
121 for j in (i + 1)..boxes.len() {
122 if boxes[j].score <= min_val {
125 continue;
127 }
128
129 if jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
130 boxes[j].score = min_val;
132 }
133 }
134 }
135 boxes.into_iter().filter(|b| b.score > min_val).collect()
137}
138
139#[doc(hidden)]
145#[must_use]
146pub fn nms_extra_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync, E: Send + Sync>(
147 iou: f32,
148 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
149) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
150 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
153
154 if iou >= 1.0 {
157 return boxes;
158 }
159
160 let min_val = SCORE::min_value();
161 for i in 0..boxes.len() {
163 if boxes[i].0.score <= min_val {
164 continue;
166 }
167 for j in (i + 1)..boxes.len() {
168 if boxes[j].0.score <= min_val {
171 continue;
173 }
174 if jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou) {
175 boxes[j].0.score = min_val;
177 }
178 }
179 }
180
181 boxes.into_iter().filter(|b| b.0.score > min_val).collect()
183}
184
185#[doc(hidden)]
192#[must_use]
193pub fn nms_class_aware_int<SCORE: PrimInt + AsPrimitive<f32> + Send + Sync>(
194 iou: f32,
195 mut boxes: Vec<DetectBoxQuantized<SCORE>>,
196) -> Vec<DetectBoxQuantized<SCORE>> {
197 boxes.par_sort_by(|a, b| b.score.cmp(&a.score));
198
199 if iou >= 1.0 {
202 return boxes;
203 }
204
205 let min_val = SCORE::min_value();
206 for i in 0..boxes.len() {
207 if boxes[i].score <= min_val {
208 continue;
209 }
210 for j in (i + 1)..boxes.len() {
211 if boxes[j].score <= min_val {
212 continue;
213 }
214 if boxes[j].label == boxes[i].label && jaccard(&boxes[j].bbox, &boxes[i].bbox, iou) {
216 boxes[j].score = min_val;
217 }
218 }
219 }
220 boxes.into_iter().filter(|b| b.score > min_val).collect()
221}
222
223#[doc(hidden)]
229#[must_use]
230pub fn nms_extra_class_aware_int<
231 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
232 E: Send + Sync,
233>(
234 iou: f32,
235 mut boxes: Vec<(DetectBoxQuantized<SCORE>, E)>,
236) -> Vec<(DetectBoxQuantized<SCORE>, E)> {
237 boxes.par_sort_by(|a, b| b.0.score.cmp(&a.0.score));
238
239 if iou >= 1.0 {
242 return boxes;
243 }
244
245 let min_val = SCORE::min_value();
246 for i in 0..boxes.len() {
247 if boxes[i].0.score <= min_val {
248 continue;
249 }
250 for j in (i + 1)..boxes.len() {
251 if boxes[j].0.score <= min_val {
252 continue;
253 }
254 if boxes[j].0.label == boxes[i].0.label
256 && jaccard(&boxes[j].0.bbox, &boxes[i].0.bbox, iou)
257 {
258 boxes[j].0.score = min_val;
259 }
260 }
261 }
262 boxes.into_iter().filter(|b| b.0.score > min_val).collect()
263}
264
265#[doc(hidden)]
280pub fn quantize_score_threshold<T: PrimInt + AsPrimitive<f32>>(score: f32, quant: Quantization) -> T
281where
282 f32: AsPrimitive<T>,
283{
284 if quant.scale == 0.0 {
285 return T::max_value();
286 }
287 let v = (score / quant.scale + quant.zero_point as f32).ceil();
288 let v = v.clamp(T::min_value().as_(), T::max_value().as_());
289 v.as_()
290}