1use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8 byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9 configs::Detection,
10 dequant_detect_box,
11 float::{nms_float, postprocess_boxes_float},
12 BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
15#[derive(Debug, Clone, PartialEq)]
18pub(crate) struct ModelPackDetectionConfig {
19 pub(crate) anchors: Vec<[f32; 2]>,
20 pub(crate) quantization: Option<Quantization>,
21}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24 type Error = DecoderError;
25
26 fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27 Ok(Self {
28 anchors: value.anchors.clone().ok_or_else(|| {
29 DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30 })?,
31 quantization: value.quantization.map(Quantization::from),
32 })
33 }
34}
35
36pub(crate) fn decode_modelpack_det<
47 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50 boxes_tensor: (ArrayView2<BOX>, Quantization),
51 scores_tensor: (ArrayView2<SCORE>, Quantization),
52 score_threshold: f32,
53 iou_threshold: f32,
54 max_det: usize,
55 output_boxes: &mut Vec<DetectBox>,
56) where
57 f32: AsPrimitive<SCORE>,
58{
59 impl_modelpack_quant::<XYXY, _, _>(
60 boxes_tensor,
61 scores_tensor,
62 score_threshold,
63 iou_threshold,
64 max_det,
65 output_boxes,
66 )
67}
68
69pub(crate) fn decode_modelpack_float<
79 BOX: Float + AsPrimitive<f32> + Send + Sync,
80 SCORE: Float + AsPrimitive<f32> + Send + Sync,
81>(
82 boxes_tensor: ArrayView2<BOX>,
83 scores_tensor: ArrayView2<SCORE>,
84 score_threshold: f32,
85 iou_threshold: f32,
86 max_det: usize,
87 output_boxes: &mut Vec<DetectBox>,
88) where
89 f32: AsPrimitive<SCORE>,
90{
91 impl_modelpack_float::<XYXY, _, _>(
92 boxes_tensor,
93 scores_tensor,
94 score_threshold,
95 iou_threshold,
96 max_det,
97 output_boxes,
98 )
99}
100
/// Test-only: decodes quantized ModelPack split outputs (one tensor per head)
/// using per-output anchors and quantization from `configs`.
///
/// Boxes are interpreted as [`XYWH`]; at most `max_det` results replace the
/// contents of `output_boxes`.
#[cfg(test)]
pub(crate) fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    impl_modelpack_split_quant::<XYWH, D>(
        outputs,
        configs,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    );
}
129
130pub(crate) fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
141 outputs: &[ArrayView3<D>],
142 configs: &[ModelPackDetectionConfig],
143 score_threshold: f32,
144 iou_threshold: f32,
145 max_det: usize,
146 output_boxes: &mut Vec<DetectBox>,
147) {
148 impl_modelpack_split_float::<XYWH, D>(
149 outputs,
150 configs,
151 score_threshold,
152 iou_threshold,
153 max_det,
154 output_boxes,
155 );
156}
157pub(crate) fn impl_modelpack_quant<
166 B: BBoxTypeTrait,
167 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
168 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
169>(
170 boxes: (ArrayView2<BOX>, Quantization),
171 scores: (ArrayView2<SCORE>, Quantization),
172 score_threshold: f32,
173 iou_threshold: f32,
174 max_det: usize,
175 output_boxes: &mut Vec<DetectBox>,
176) where
177 f32: AsPrimitive<SCORE>,
178{
179 let (boxes_tensor, quant_boxes) = boxes;
180 let (scores_tensor, quant_scores) = scores;
181 let boxes = {
182 let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
183 postprocess_boxes_quant::<B, _, _>(
184 score_threshold,
185 boxes_tensor,
186 scores_tensor,
187 quant_boxes,
188 )
189 };
190 let boxes = nms_int(iou_threshold, Some(max_det), boxes);
191 output_boxes.clear();
192 for b in boxes.into_iter().take(max_det) {
193 output_boxes.push(dequant_detect_box(&b, quant_scores));
194 }
195}
196
197pub(crate) fn impl_modelpack_float<
206 B: BBoxTypeTrait,
207 BOX: Float + AsPrimitive<f32> + Send + Sync,
208 SCORE: Float + AsPrimitive<f32> + Send + Sync,
209>(
210 boxes_tensor: ArrayView2<BOX>,
211 scores_tensor: ArrayView2<SCORE>,
212 score_threshold: f32,
213 iou_threshold: f32,
214 max_det: usize,
215 output_boxes: &mut Vec<DetectBox>,
216) where
217 f32: AsPrimitive<SCORE>,
218{
219 let boxes =
220 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
221 let boxes = nms_float(iou_threshold, Some(max_det), boxes);
222 output_boxes.clear();
223 for b in boxes.into_iter().take(max_det) {
224 output_boxes.push(b);
225 }
226}
227
/// Test-only: decodes quantized split outputs into boxes of format `B`.
///
/// Raw outputs are dequantized and sigmoid-decoded by
/// `postprocess_modelpack_split_quant`. The resulting float boxes are then filtered
/// by `score_threshold`, run through float NMS, and at most `max_det` of them
/// replace the contents of `output_boxes`.
#[cfg(test)]
pub(crate) fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    let (decoded_boxes, decoded_scores) = postprocess_modelpack_split_quant(outputs, configs);
    let candidates = postprocess_boxes_float::<B, _, _>(
        score_threshold,
        decoded_boxes.view(),
        decoded_scores.view(),
    );
    let kept = nms_float(iou_threshold, Some(max_det), candidates);

    output_boxes.clear();
    output_boxes.extend(kept.into_iter().take(max_det));
}
257
258pub(crate) fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
268 outputs: &[ArrayView3<D>],
269 configs: &[ModelPackDetectionConfig],
270 score_threshold: f32,
271 iou_threshold: f32,
272 max_det: usize,
273 output_boxes: &mut Vec<DetectBox>,
274) {
275 let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
276 let boxes = postprocess_boxes_float::<B, _, _>(
277 score_threshold,
278 boxes_tensor.view(),
279 scores_tensor.view(),
280 );
281 let boxes = nms_float(iou_threshold, Some(max_det), boxes);
282 output_boxes.clear();
283 for b in boxes.into_iter().take(max_det) {
284 output_boxes.push(b);
285 }
286}
287
/// Test-only: decodes (optionally quantized) ModelPack split outputs into box and
/// score arrays.
///
/// `outputs[i]` is paired with `config[i]`; the slices are zipped, so unpaired
/// trailing entries are ignored. For each output, the last axis holds
/// `na * (nc + 5)` values per grid cell, where `na` is the anchor count of the
/// paired config. Each anchor's chunk is read as `[x, y, w, h, objectness, classes...]`.
/// When the config has a quantization, values are dequantized as
/// `x * scale - zero_point * scale`. Every value then goes through a sigmoid.
///
/// Returns `(boxes, scores)`:
/// - `boxes` has shape `(N, 4)` and holds `[x, y, w, h]`. `x`/`y` are divided by the
///   grid width/height, and `w`/`h` are scaled by the anchor.
/// - `scores` has shape `(N, nc)`. With one class the score is the objectness alone;
///   otherwise it is `objectness * class`.
///
/// `nc` is taken from the last paired output. An empty `outputs` slice yields
/// arrays of shape `(0, 4)` and `(0, 0)`.
///
/// # Panics
///
/// - A config with no anchors divides by zero while computing `nc`.
/// - A last axis smaller than `5 * na` underflows that computation.
/// - Panics if a tensor's shape product does not match its length.
#[cfg(test)]
pub(crate) fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
    outputs: &[ArrayView3<T>],
    config: &[ModelPackDetectionConfig],
) -> (Array2<f32>, Array2<f32>) {
    // First pass: size the output buffers. `nc` ends up as the class count of the
    // last paired output.
    let mut total_capacity = 0;
    let mut nc = 0;
    for (p, detail) in outputs.iter().zip(config) {
        let shape = p.shape();
        let na = detail.anchors.len();
        nc = *shape
            .last()
            .expect("Shape must have at least one dimension")
            / na
            - 5;
        total_capacity += shape[0] * shape[1] * na;
    }
    let mut bboxes = Vec::with_capacity(total_capacity * 4);
    let mut bscores = Vec::with_capacity(total_capacity * nc);

    for (p, detail) in outputs.iter().zip(config) {
        let anchors = &detail.anchors;
        let na = anchors.len();
        let shape = p.shape();
        assert_eq!(
            shape.iter().product::<usize>(),
            p.len(),
            "Shape product doesn't match tensor length"
        );
        let p_sigmoid = if let Some(quant) = &detail.quantization {
            // Fold the zero-point offset into a single additive term.
            let scaled_zero = -quant.zero_point as f32 * quant.scale;
            p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
        } else {
            p.mapv(|x| fast_sigmoid_impl(x.as_()))
        };
        let p_sigmoid = p_sigmoid.as_standard_layout();
        let p = p_sigmoid
            .as_slice()
            .expect("Sigmoids are not in standard layout");
        let height = shape[0];
        let width = shape[1];

        let div_width = 1.0 / width as f32;
        let div_height = 1.0 / height as f32;

        // Chunks are laid out (y, x, anchor), so chunk `i` belongs to grid cell
        // `i / na`. Deriving the (x - 0.5, y - 0.5) offsets from that index avoids
        // allocating a grid buffer per output. `take` caps the loop at one chunk per
        // (cell, anchor) pair.
        for (i, (p, anchor)) in p
            .chunks_exact(nc + 5)
            .zip(anchors.iter().cycle())
            .take(height * width * na)
            .enumerate()
        {
            let cell = i / na;
            let gx = (cell % width) as f32 - 0.5;
            let gy = (cell / width) as f32 - 0.5;

            let x = (p[0] * 2.0 + gx) * div_width;
            let y = (p[1] * 2.0 + gy) * div_height;
            let (w, h) = (p[2], p[3]);
            let w = w * w * 4.0 * anchor[0];
            let h = h * h * 4.0 * anchor[1];
            bboxes.extend_from_slice(&[x, y, w, h]);

            if nc == 1 {
                // Single-class heads use the objectness directly as the score.
                bscores.push(p[4]);
            } else {
                let obj = p[4];
                bscores.extend(p[5..].iter().map(|x| *x * obj));
            }
        }
    }

    // Every decoded anchor contributes 4 box values and exactly `nc` scores, so the
    // score array is shaped from the box count. Dividing by `nc` instead would panic
    // when `nc == 0` (empty `outputs` or a zero-class head).
    let num_boxes = bboxes.len() / 4;
    debug_assert_eq!(bboxes.len(), num_boxes * 4);
    debug_assert_eq!(bscores.len(), num_boxes * nc);

    let bboxes = Array2::from_shape_vec((num_boxes, 4), bboxes)
        .expect("Failed to create bboxes array");
    let bscores = Array2::from_shape_vec((num_boxes, nc), bscores)
        .expect("Failed to create bscores array");
    (bboxes, bscores)
}
384
385pub(crate) fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
389 outputs: &[ArrayView3<T>],
390 config: &[ModelPackDetectionConfig],
391) -> (Array2<f32>, Array2<f32>) {
392 let mut total_capacity = 0;
393 let mut nc = 0;
394 for (p, detail) in outputs.iter().zip(config) {
395 let shape = p.shape();
396 let na = detail.anchors.len();
397 nc = *shape
398 .last()
399 .expect("Shape must have at least one dimension")
400 / na
401 - 5;
402 total_capacity += shape[0] * shape[1] * na;
403 }
404 let mut bboxes = Vec::with_capacity(total_capacity * 4);
405 let mut bscores = Vec::with_capacity(total_capacity * nc);
406
407 for (p, detail) in outputs.iter().zip(config) {
408 let anchors = &detail.anchors;
409 let na = detail.anchors.len();
410 let shape = p.shape();
411 assert_eq!(
412 shape.iter().product::<usize>(),
413 p.len(),
414 "Shape product doesn't match tensor length"
415 );
416 let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
417 let p_sigmoid = p_sigmoid.as_standard_layout();
418
419 let p = p_sigmoid
421 .as_slice()
422 .expect("Sigmoids are not in standard layout");
423 let height = shape[0];
424 let width = shape[1];
425
426 let div_width = 1.0 / width as f32;
427 let div_height = 1.0 / height as f32;
428
429 let mut grid = Vec::with_capacity(height * width * na * 2);
430 for y in 0..height {
431 for x in 0..width {
432 for _ in 0..na {
433 grid.push(x as f32 - 0.5);
434 grid.push(y as f32 - 0.5);
435 }
436 }
437 }
438 for ((p, g), anchor) in p
439 .chunks_exact(nc + 5)
440 .zip(grid.chunks_exact(2))
441 .zip(anchors.iter().cycle())
442 {
443 let (x, y) = (p[0], p[1]);
444 let x = (x * 2.0 + g[0]) * div_width;
445 let y = (y * 2.0 + g[1]) * div_height;
446 let (w, h) = (p[2], p[3]);
447 let w = w * w * 4.0 * anchor[0];
448 let h = h * h * 4.0 * anchor[1];
449
450 bboxes.push(x);
451 bboxes.push(y);
452 bboxes.push(w);
453 bboxes.push(h);
454
455 if nc == 1 {
456 bscores.push(p[4]);
457 } else {
458 let obj = p[4];
459 let probs = p[5..].iter().map(|x| *x * obj);
460 bscores.extend(probs);
461 }
462 }
463 }
464 debug_assert_eq!(bboxes.len() % 4, 0);
467 debug_assert_eq!(bscores.len() % nc, 0);
468
469 let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
470 .expect("Failed to create bboxes array");
471 let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
472 .expect("Failed to create bscores array");
473 (bboxes, bscores)
474}
475
476#[inline(always)]
477fn fast_sigmoid_impl(f: f32) -> f32 {
478 if f.abs() > 80.0 {
479 f.signum() * 0.5 + 0.5
480 } else {
481 1.0 / (1.0 + fast_math::exp_raw(-f))
483 }
484}
485
486pub(crate) fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
495 use argminmax::ArgMinMax;
496 assert!(
497 segmentation.shape()[2] > 1,
498 "Model Instance Segmentation should have shape (H, W, x) where x > 1"
499 );
500 let height = segmentation.shape()[0];
501 let width = segmentation.shape()[1];
502 let channels = segmentation.shape()[2];
503 let segmentation = segmentation.as_standard_layout();
504 let seg = segmentation
506 .as_slice()
507 .expect("Segmentation is not in standard layout");
508 let argmax = seg
509 .chunks_exact(channels)
510 .map(|x| x.argmax() as u8)
511 .collect::<Vec<_>>();
512
513 Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
514}
515
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use crate::configs::{DecoderType, DimName};

    use super::*;

    /// Asserts `actual` is within `tol` of `expected`. The fast sigmoid is
    /// approximate, so decoded values are compared with a tolerance.
    fn assert_close(actual: f32, expected: f32, tol: f32) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {} +/- {}, got {}",
            expected,
            tol,
            actual
        );
    }

    #[test]
    fn test_detection_config() {
        let det = Detection {
            anchors: Some(vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]]),
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };
        let config = ModelPackDetectionConfig::try_from(&det).unwrap();
        assert_eq!(
            config,
            ModelPackDetectionConfig {
                anchors: vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]],
                quantization: Some(Quantization::new(0.1, 128)),
            }
        );

        let det = Detection {
            anchors: None,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };
        let result = ModelPackDetectionConfig::try_from(&det);
        assert!(
            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
        );
    }

    #[test]
    fn test_fast_sigmoid() {
        fn full_sigmoid(x: f32) -> f32 {
            1.0 / (1.0 + (-x).exp())
        }
        for i in -2550..=2550 {
            let x = i as f32 * 0.1;
            let fast = fast_sigmoid_impl(x);
            let full = full_sigmoid(x);
            let diff = (fast - full).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
        // Beyond +/-80 the sigmoid saturates to exact values.
        assert_eq!(fast_sigmoid_impl(100.0), 1.0);
        assert_eq!(fast_sigmoid_impl(-100.0), 0.0);
    }

    #[test]
    fn test_modelpack_segmentation_to_mask() {
        // (H=2, W=2, C=3): per-pixel channel values.
        let seg = Array3::from_shape_vec(
            (2, 2, 3),
            vec![
                0u8, 10, 5, // pixel (0, 0) -> class 1
                20, 15, 25, // pixel (0, 1) -> class 2
                30, 5, 10, // pixel (1, 0) -> class 0
                0, 0, 0, // pixel (1, 1) -> all equal, class 0
            ],
        )
        .unwrap();
        let mask = modelpack_segmentation_to_mask(seg.view());
        let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(mask, expected_mask);
    }

    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let seg = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(seg.view());
    }

    #[test]
    fn test_postprocess_split_float_zero_logits() {
        // 1x2 grid, one anchor, two classes: 1 * (2 + 5) = 7 features per cell.
        // All-zero logits give sigmoid ~= 0.5 everywhere.
        let logits = Array3::<f32>::zeros((1, 2, 7));
        let configs = [ModelPackDetectionConfig {
            anchors: vec![[0.2, 0.4]],
            quantization: None,
        }];
        let (boxes, scores) = postprocess_modelpack_split_float(&[logits.view()], &configs);
        assert_eq!(boxes.shape(), &[2, 4]);
        assert_eq!(scores.shape(), &[2, 2]);

        // x = (2 * 0.5 + gx - 0.5) / width, y = (2 * 0.5 + gy - 0.5) / height,
        // w/h = (2 * 0.5)^2 * anchor.
        let expected_boxes: [[f32; 4]; 2] = [[0.25, 0.5, 0.2, 0.4], [0.75, 0.5, 0.2, 0.4]];
        for (row, expected) in boxes.outer_iter().zip(expected_boxes.iter()) {
            for (&actual, &want) in row.iter().zip(expected.iter()) {
                assert_close(actual, want, 1e-2);
            }
        }
        // Multi-class scores are objectness * class probability.
        for &score in scores.iter() {
            assert_close(score, 0.25, 1e-2);
        }
    }

    #[test]
    fn test_postprocess_split_quant_single_class() {
        // 1x1 grid, one anchor, one class: 6 features. A raw value of 128 with
        // zero point 128 dequantizes to 0, so every sigmoid is ~= 0.5.
        let raw = Array3::<u8>::from_elem((1, 1, 6), 128);
        let configs = [ModelPackDetectionConfig {
            anchors: vec![[0.3, 0.6]],
            quantization: Some(Quantization::new(0.1, 128)),
        }];
        let (boxes, scores) = postprocess_modelpack_split_quant(&[raw.view()], &configs);
        assert_eq!(boxes.shape(), &[1, 4]);
        assert_eq!(scores.shape(), &[1, 1]);

        let expected_box: [f32; 4] = [0.5, 0.5, 0.3, 0.6];
        for (&actual, &want) in boxes.iter().zip(expected_box.iter()) {
            assert_close(actual, want, 1e-2);
        }
        // With a single class the score is the objectness alone.
        assert_close(scores[[0, 0]], 0.5, 1e-2);
    }
}