1use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8 byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9 configs::Detection,
10 dequant_detect_box,
11 float::{nms_float, postprocess_boxes_float},
12 BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
/// Per-output settings needed to decode one ModelPack split detection tensor.
///
/// Instances are normally built from a [`Detection`] through the `TryFrom`
/// implementation in this file, which fails when the source has no anchors.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelPackDetectionConfig {
    /// Anchor sizes for this output. In the split decoders element `0` scales
    /// the decoded box width and element `1` scales the decoded box height.
    pub anchors: Vec<[f32; 2]>,
    /// Quantization parameters of the output tensor, if any. When present,
    /// the quantized split post-processing dequantizes each value before
    /// applying the sigmoid.
    pub quantization: Option<Quantization>,
}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24 type Error = DecoderError;
25
26 fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27 Ok(Self {
28 anchors: value.anchors.clone().ok_or_else(|| {
29 DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30 })?,
31 quantization: value.quantization.map(Quantization::from),
32 })
33 }
34}
35
36pub fn decode_modelpack_det<
47 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50 boxes_tensor: (ArrayView2<BOX>, Quantization),
51 scores_tensor: (ArrayView2<SCORE>, Quantization),
52 score_threshold: f32,
53 iou_threshold: f32,
54 output_boxes: &mut Vec<DetectBox>,
55) where
56 f32: AsPrimitive<SCORE>,
57{
58 impl_modelpack_quant::<XYXY, _, _>(
59 boxes_tensor,
60 scores_tensor,
61 score_threshold,
62 iou_threshold,
63 output_boxes,
64 )
65}
66
67pub fn decode_modelpack_float<
77 BOX: Float + AsPrimitive<f32> + Send + Sync,
78 SCORE: Float + AsPrimitive<f32> + Send + Sync,
79>(
80 boxes_tensor: ArrayView2<BOX>,
81 scores_tensor: ArrayView2<SCORE>,
82 score_threshold: f32,
83 iou_threshold: f32,
84 output_boxes: &mut Vec<DetectBox>,
85) where
86 f32: AsPrimitive<SCORE>,
87{
88 impl_modelpack_float::<XYXY, _, _>(
89 boxes_tensor,
90 scores_tensor,
91 score_threshold,
92 iou_threshold,
93 output_boxes,
94 )
95}
96
97pub fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
108 outputs: &[ArrayView3<D>],
109 configs: &[ModelPackDetectionConfig],
110 score_threshold: f32,
111 iou_threshold: f32,
112 output_boxes: &mut Vec<DetectBox>,
113) {
114 impl_modelpack_split_quant::<XYWH, D>(
115 outputs,
116 configs,
117 score_threshold,
118 iou_threshold,
119 output_boxes,
120 )
121}
122
123pub fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
134 outputs: &[ArrayView3<D>],
135 configs: &[ModelPackDetectionConfig],
136 score_threshold: f32,
137 iou_threshold: f32,
138 output_boxes: &mut Vec<DetectBox>,
139) {
140 impl_modelpack_split_float::<XYWH, D>(
141 outputs,
142 configs,
143 score_threshold,
144 iou_threshold,
145 output_boxes,
146 );
147}
148pub(crate) fn impl_modelpack_quant<
157 B: BBoxTypeTrait,
158 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
159 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
160>(
161 boxes: (ArrayView2<BOX>, Quantization),
162 scores: (ArrayView2<SCORE>, Quantization),
163 score_threshold: f32,
164 iou_threshold: f32,
165 output_boxes: &mut Vec<DetectBox>,
166) where
167 f32: AsPrimitive<SCORE>,
168{
169 let (boxes_tensor, quant_boxes) = boxes;
170 let (scores_tensor, quant_scores) = scores;
171 let boxes = {
172 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
173 postprocess_boxes_quant::<B, _, _>(
174 score_threshold,
175 boxes_tensor,
176 scores_tensor,
177 quant_boxes,
178 )
179 };
180 let boxes = nms_int(iou_threshold, boxes);
181 let len = output_boxes.capacity().min(boxes.len());
182 output_boxes.clear();
183 for b in boxes.into_iter().take(len) {
184 output_boxes.push(dequant_detect_box(&b, quant_scores));
185 }
186}
187
188pub(crate) fn impl_modelpack_float<
197 B: BBoxTypeTrait,
198 BOX: Float + AsPrimitive<f32> + Send + Sync,
199 SCORE: Float + AsPrimitive<f32> + Send + Sync,
200>(
201 boxes_tensor: ArrayView2<BOX>,
202 scores_tensor: ArrayView2<SCORE>,
203 score_threshold: f32,
204 iou_threshold: f32,
205 output_boxes: &mut Vec<DetectBox>,
206) where
207 f32: AsPrimitive<SCORE>,
208{
209 let boxes =
210 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
211 let boxes = nms_float(iou_threshold, boxes);
212 let len = output_boxes.capacity().min(boxes.len());
213 output_boxes.clear();
214 for b in boxes.into_iter().take(len) {
215 output_boxes.push(b);
216 }
217}
218
219pub(crate) fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
228 outputs: &[ArrayView3<D>],
229 configs: &[ModelPackDetectionConfig],
230 score_threshold: f32,
231 iou_threshold: f32,
232 output_boxes: &mut Vec<DetectBox>,
233) {
234 let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_quant(outputs, configs);
235 let boxes = postprocess_boxes_float::<B, _, _>(
236 score_threshold,
237 boxes_tensor.view(),
238 scores_tensor.view(),
239 );
240 let boxes = nms_float(iou_threshold, boxes);
241 let len = output_boxes.capacity().min(boxes.len());
242 output_boxes.clear();
243 for b in boxes.into_iter().take(len) {
244 output_boxes.push(b);
245 }
246}
247
248pub(crate) fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
258 outputs: &[ArrayView3<D>],
259 configs: &[ModelPackDetectionConfig],
260 score_threshold: f32,
261 iou_threshold: f32,
262 output_boxes: &mut Vec<DetectBox>,
263) {
264 let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
265 let boxes = postprocess_boxes_float::<B, _, _>(
266 score_threshold,
267 boxes_tensor.view(),
268 scores_tensor.view(),
269 );
270 let boxes = nms_float(iou_threshold, boxes);
271 let len = output_boxes.capacity().min(boxes.len());
272 output_boxes.clear();
273 for b in boxes.into_iter().take(len) {
274 output_boxes.push(b);
275 }
276}
277
278pub(crate) fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
282 outputs: &[ArrayView3<T>],
283 config: &[ModelPackDetectionConfig],
284) -> (Array2<f32>, Array2<f32>) {
285 let mut total_capacity = 0;
286 let mut nc = 0;
287 for (p, detail) in outputs.iter().zip(config) {
288 let shape = p.shape();
289 let na = detail.anchors.len();
290 nc = *shape
291 .last()
292 .expect("Shape must have at least one dimension")
293 / na
294 - 5;
295 total_capacity += shape[0] * shape[1] * na;
296 }
297 let mut bboxes = Vec::with_capacity(total_capacity * 4);
298 let mut bscores = Vec::with_capacity(total_capacity * nc);
299
300 for (p, detail) in outputs.iter().zip(config) {
301 let anchors = &detail.anchors;
302 let na = detail.anchors.len();
303 let shape = p.shape();
304 assert_eq!(
305 shape.iter().product::<usize>(),
306 p.len(),
307 "Shape product doesn't match tensor length"
308 );
309 let p_sigmoid = if let Some(quant) = &detail.quantization {
310 let scaled_zero = -quant.zero_point as f32 * quant.scale;
311 p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
312 } else {
313 p.mapv(|x| fast_sigmoid_impl(x.as_()))
314 };
315 let p_sigmoid = p_sigmoid.as_standard_layout();
316
317 let p = p_sigmoid
319 .as_slice()
320 .expect("Sigmoids are not in standard layout");
321 let height = shape[0];
322 let width = shape[1];
323
324 let div_width = 1.0 / width as f32;
325 let div_height = 1.0 / height as f32;
326
327 let mut grid = Vec::with_capacity(height * width * na * 2);
328 for y in 0..height {
329 for x in 0..width {
330 for _ in 0..na {
331 grid.push(x as f32 - 0.5);
332 grid.push(y as f32 - 0.5);
333 }
334 }
335 }
336 for ((p, g), anchor) in p
337 .chunks_exact(nc + 5)
338 .zip(grid.chunks_exact(2))
339 .zip(anchors.iter().cycle())
340 {
341 let (x, y) = (p[0], p[1]);
342 let x = (x * 2.0 + g[0]) * div_width;
343 let y = (y * 2.0 + g[1]) * div_height;
344 let (w, h) = (p[2], p[3]);
345 let w = w * w * 4.0 * anchor[0];
346 let h = h * h * 4.0 * anchor[1];
347
348 bboxes.push(x);
349 bboxes.push(y);
350 bboxes.push(w);
351 bboxes.push(h);
352
353 if nc == 1 {
354 bscores.push(p[4]);
355 } else {
356 let obj = p[4];
357 let probs = p[5..].iter().map(|x| *x * obj);
358 bscores.extend(probs);
359 }
360 }
361 }
362 debug_assert_eq!(bboxes.len() % 4, 0);
365 debug_assert_eq!(bscores.len() % nc, 0);
366
367 let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
368 .expect("Failed to create bboxes array");
369 let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
370 .expect("Failed to create bscores array");
371 (bboxes, bscores)
372}
373
374pub(crate) fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
378 outputs: &[ArrayView3<T>],
379 config: &[ModelPackDetectionConfig],
380) -> (Array2<f32>, Array2<f32>) {
381 let mut total_capacity = 0;
382 let mut nc = 0;
383 for (p, detail) in outputs.iter().zip(config) {
384 let shape = p.shape();
385 let na = detail.anchors.len();
386 nc = *shape
387 .last()
388 .expect("Shape must have at least one dimension")
389 / na
390 - 5;
391 total_capacity += shape[0] * shape[1] * na;
392 }
393 let mut bboxes = Vec::with_capacity(total_capacity * 4);
394 let mut bscores = Vec::with_capacity(total_capacity * nc);
395
396 for (p, detail) in outputs.iter().zip(config) {
397 let anchors = &detail.anchors;
398 let na = detail.anchors.len();
399 let shape = p.shape();
400 assert_eq!(
401 shape.iter().product::<usize>(),
402 p.len(),
403 "Shape product doesn't match tensor length"
404 );
405 let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
406 let p_sigmoid = p_sigmoid.as_standard_layout();
407
408 let p = p_sigmoid
410 .as_slice()
411 .expect("Sigmoids are not in standard layout");
412 let height = shape[0];
413 let width = shape[1];
414
415 let div_width = 1.0 / width as f32;
416 let div_height = 1.0 / height as f32;
417
418 let mut grid = Vec::with_capacity(height * width * na * 2);
419 for y in 0..height {
420 for x in 0..width {
421 for _ in 0..na {
422 grid.push(x as f32 - 0.5);
423 grid.push(y as f32 - 0.5);
424 }
425 }
426 }
427 for ((p, g), anchor) in p
428 .chunks_exact(nc + 5)
429 .zip(grid.chunks_exact(2))
430 .zip(anchors.iter().cycle())
431 {
432 let (x, y) = (p[0], p[1]);
433 let x = (x * 2.0 + g[0]) * div_width;
434 let y = (y * 2.0 + g[1]) * div_height;
435 let (w, h) = (p[2], p[3]);
436 let w = w * w * 4.0 * anchor[0];
437 let h = h * h * 4.0 * anchor[1];
438
439 bboxes.push(x);
440 bboxes.push(y);
441 bboxes.push(w);
442 bboxes.push(h);
443
444 if nc == 1 {
445 bscores.push(p[4]);
446 } else {
447 let obj = p[4];
448 let probs = p[5..].iter().map(|x| *x * obj);
449 bscores.extend(probs);
450 }
451 }
452 }
453 debug_assert_eq!(bboxes.len() % 4, 0);
456 debug_assert_eq!(bscores.len() % nc, 0);
457
458 let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
459 .expect("Failed to create bboxes array");
460 let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
461 .expect("Failed to create bscores array");
462 (bboxes, bscores)
463}
464
465#[inline(always)]
466fn fast_sigmoid_impl(f: f32) -> f32 {
467 if f.abs() > 80.0 {
468 f.signum() * 0.5 + 0.5
469 } else {
470 1.0 / (1.0 + fast_math::exp_raw(-f))
472 }
473}
474
475pub fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
484 use argminmax::ArgMinMax;
485 assert!(
486 segmentation.shape()[2] > 1,
487 "Model Instance Segmentation should have shape (H, W, x) where x > 1"
488 );
489 let height = segmentation.shape()[0];
490 let width = segmentation.shape()[1];
491 let channels = segmentation.shape()[2];
492 let segmentation = segmentation.as_standard_layout();
493 let seg = segmentation
495 .as_slice()
496 .expect("Segmentation is not in standard layout");
497 let argmax = seg
498 .chunks_exact(channels)
499 .map(|x| x.argmax() as u8)
500 .collect::<Vec<_>>();
501
502 Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
503}
504
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use crate::configs::{DecoderType, DimName};

    use super::*;

    /// Builds the `Detection` fixture shared by the config tests; only the
    /// anchors vary between cases.
    fn detection_fixture(anchors: Option<Vec<[f32; 2]>>) -> Detection {
        Detection {
            anchors,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        }
    }

    /// Conversion succeeds with anchors and fails with `InvalidConfig` without.
    #[test]
    fn test_detection_config() {
        let anchors: Vec<[f32; 2]> = vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]];

        let with_anchors = detection_fixture(Some(anchors.clone()));
        let config = ModelPackDetectionConfig::try_from(&with_anchors).unwrap();
        let expected = ModelPackDetectionConfig {
            anchors,
            quantization: Some(Quantization::new(0.1, 128)),
        };
        assert_eq!(config, expected);

        let without_anchors = detection_fixture(None);
        let result = ModelPackDetectionConfig::try_from(&without_anchors);
        assert!(
            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
        );
    }

    /// The fast sigmoid stays within 5e-4 of the exact one over [-255, 255].
    #[test]
    fn test_fast_sigmoid() {
        let reference = |x: f32| 1.0 / (1.0 + (-x).exp());

        for step in -2550..=2550 {
            let x = step as f32 * 0.1;
            let diff = (fast_sigmoid_impl(x) - reference(x)).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
    }

    /// Per-pixel argmax over three channels, including an all-equal pixel.
    #[test]
    fn test_modelpack_segmentation_to_mask() {
        // Four pixels, three channels each.
        let pixels = vec![
            0u8, 10, 5, //
            20, 15, 25, //
            30, 5, 10, //
            0, 0, 0, //
        ];
        let seg = Array3::from_shape_vec((2, 2, 3), pixels).unwrap();

        let mask = modelpack_segmentation_to_mask(seg.view());

        let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(mask, expected_mask);
    }

    /// A single-channel tensor is rejected.
    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let single_channel = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(single_channel.view());
    }
}