1use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8 byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9 configs::Detection,
10 dequant_detect_box,
11 float::{nms_float, postprocess_boxes_float},
12 BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
15#[derive(Debug, Clone, PartialEq)]
18pub struct ModelPackDetectionConfig {
19 pub anchors: Vec<[f32; 2]>,
20 pub quantization: Option<Quantization>,
21}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24 type Error = DecoderError;
25
26 fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27 Ok(Self {
28 anchors: value.anchors.clone().ok_or_else(|| {
29 DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30 })?,
31 quantization: value.quantization.map(Quantization::from),
32 })
33 }
34}
35
36pub fn decode_modelpack_det<
47 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50 boxes_tensor: (ArrayView2<BOX>, Quantization),
51 scores_tensor: (ArrayView2<SCORE>, Quantization),
52 score_threshold: f32,
53 iou_threshold: f32,
54 output_boxes: &mut Vec<DetectBox>,
55) where
56 f32: AsPrimitive<SCORE>,
57{
58 impl_modelpack_quant::<XYXY, _, _>(
59 boxes_tensor,
60 scores_tensor,
61 score_threshold,
62 iou_threshold,
63 output_boxes,
64 )
65}
66
67pub fn decode_modelpack_float<
77 BOX: Float + AsPrimitive<f32> + Send + Sync,
78 SCORE: Float + AsPrimitive<f32> + Send + Sync,
79>(
80 boxes_tensor: ArrayView2<BOX>,
81 scores_tensor: ArrayView2<SCORE>,
82 score_threshold: f32,
83 iou_threshold: f32,
84 output_boxes: &mut Vec<DetectBox>,
85) where
86 f32: AsPrimitive<SCORE>,
87{
88 impl_modelpack_float::<XYXY, _, _>(
89 boxes_tensor,
90 scores_tensor,
91 score_threshold,
92 iou_threshold,
93 output_boxes,
94 )
95}
96
97pub fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
108 outputs: &[ArrayView3<D>],
109 configs: &[ModelPackDetectionConfig],
110 score_threshold: f32,
111 iou_threshold: f32,
112 output_boxes: &mut Vec<DetectBox>,
113) {
114 impl_modelpack_split_quant::<XYWH, D>(
115 outputs,
116 configs,
117 score_threshold,
118 iou_threshold,
119 output_boxes,
120 )
121}
122
123pub fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
134 outputs: &[ArrayView3<D>],
135 configs: &[ModelPackDetectionConfig],
136 score_threshold: f32,
137 iou_threshold: f32,
138 output_boxes: &mut Vec<DetectBox>,
139) {
140 impl_modelpack_split_float::<XYWH, D>(
141 outputs,
142 configs,
143 score_threshold,
144 iou_threshold,
145 output_boxes,
146 );
147}
148#[doc(hidden)]
157pub fn impl_modelpack_quant<
158 B: BBoxTypeTrait,
159 BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
160 SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
161>(
162 boxes: (ArrayView2<BOX>, Quantization),
163 scores: (ArrayView2<SCORE>, Quantization),
164 score_threshold: f32,
165 iou_threshold: f32,
166 output_boxes: &mut Vec<DetectBox>,
167) where
168 f32: AsPrimitive<SCORE>,
169{
170 let (boxes_tensor, quant_boxes) = boxes;
171 let (scores_tensor, quant_scores) = scores;
172 let boxes = {
173 let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
174 postprocess_boxes_quant::<B, _, _>(
175 score_threshold,
176 boxes_tensor,
177 scores_tensor,
178 quant_boxes,
179 )
180 };
181 let boxes = nms_int(iou_threshold, boxes);
182 let len = output_boxes.capacity().min(boxes.len());
183 output_boxes.clear();
184 for b in boxes.into_iter().take(len) {
185 output_boxes.push(dequant_detect_box(&b, quant_scores));
186 }
187}
188
189#[doc(hidden)]
198pub fn impl_modelpack_float<
199 B: BBoxTypeTrait,
200 BOX: Float + AsPrimitive<f32> + Send + Sync,
201 SCORE: Float + AsPrimitive<f32> + Send + Sync,
202>(
203 boxes_tensor: ArrayView2<BOX>,
204 scores_tensor: ArrayView2<SCORE>,
205 score_threshold: f32,
206 iou_threshold: f32,
207 output_boxes: &mut Vec<DetectBox>,
208) where
209 f32: AsPrimitive<SCORE>,
210{
211 let boxes =
212 postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
213 let boxes = nms_float(iou_threshold, boxes);
214 let len = output_boxes.capacity().min(boxes.len());
215 output_boxes.clear();
216 for b in boxes.into_iter().take(len) {
217 output_boxes.push(b);
218 }
219}
220
221#[doc(hidden)]
230pub fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
231 outputs: &[ArrayView3<D>],
232 configs: &[ModelPackDetectionConfig],
233 score_threshold: f32,
234 iou_threshold: f32,
235 output_boxes: &mut Vec<DetectBox>,
236) {
237 let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_quant(outputs, configs);
238 let boxes = postprocess_boxes_float::<B, _, _>(
239 score_threshold,
240 boxes_tensor.view(),
241 scores_tensor.view(),
242 );
243 let boxes = nms_float(iou_threshold, boxes);
244 let len = output_boxes.capacity().min(boxes.len());
245 output_boxes.clear();
246 for b in boxes.into_iter().take(len) {
247 output_boxes.push(b);
248 }
249}
250
251#[doc(hidden)]
261pub fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
262 outputs: &[ArrayView3<D>],
263 configs: &[ModelPackDetectionConfig],
264 score_threshold: f32,
265 iou_threshold: f32,
266 output_boxes: &mut Vec<DetectBox>,
267) {
268 let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
269 let boxes = postprocess_boxes_float::<B, _, _>(
270 score_threshold,
271 boxes_tensor.view(),
272 scores_tensor.view(),
273 );
274 let boxes = nms_float(iou_threshold, boxes);
275 let len = output_boxes.capacity().min(boxes.len());
276 output_boxes.clear();
277 for b in boxes.into_iter().take(len) {
278 output_boxes.push(b);
279 }
280}
281
282#[doc(hidden)]
286pub fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
287 outputs: &[ArrayView3<T>],
288 config: &[ModelPackDetectionConfig],
289) -> (Array2<f32>, Array2<f32>) {
290 let mut total_capacity = 0;
291 let mut nc = 0;
292 for (p, detail) in outputs.iter().zip(config) {
293 let shape = p.shape();
294 let na = detail.anchors.len();
295 nc = *shape
296 .last()
297 .expect("Shape must have at least one dimension")
298 / na
299 - 5;
300 total_capacity += shape[0] * shape[1] * na;
301 }
302 let mut bboxes = Vec::with_capacity(total_capacity * 4);
303 let mut bscores = Vec::with_capacity(total_capacity * nc);
304
305 for (p, detail) in outputs.iter().zip(config) {
306 let anchors = &detail.anchors;
307 let na = detail.anchors.len();
308 let shape = p.shape();
309 assert_eq!(
310 shape.iter().product::<usize>(),
311 p.len(),
312 "Shape product doesn't match tensor length"
313 );
314 let p_sigmoid = if let Some(quant) = &detail.quantization {
315 let scaled_zero = -quant.zero_point as f32 * quant.scale;
316 p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
317 } else {
318 p.mapv(|x| fast_sigmoid_impl(x.as_()))
319 };
320 let p_sigmoid = p_sigmoid.as_standard_layout();
321
322 let p = p_sigmoid
324 .as_slice()
325 .expect("Sigmoids are not in standard layout");
326 let height = shape[0];
327 let width = shape[1];
328
329 let div_width = 1.0 / width as f32;
330 let div_height = 1.0 / height as f32;
331
332 let mut grid = Vec::with_capacity(height * width * na * 2);
333 for y in 0..height {
334 for x in 0..width {
335 for _ in 0..na {
336 grid.push(x as f32 - 0.5);
337 grid.push(y as f32 - 0.5);
338 }
339 }
340 }
341 for ((p, g), anchor) in p
342 .chunks_exact(nc + 5)
343 .zip(grid.chunks_exact(2))
344 .zip(anchors.iter().cycle())
345 {
346 let (x, y) = (p[0], p[1]);
347 let x = (x * 2.0 + g[0]) * div_width;
348 let y = (y * 2.0 + g[1]) * div_height;
349 let (w, h) = (p[2], p[3]);
350 let w = w * w * 4.0 * anchor[0];
351 let h = h * h * 4.0 * anchor[1];
352
353 bboxes.push(x);
354 bboxes.push(y);
355 bboxes.push(w);
356 bboxes.push(h);
357
358 if nc == 1 {
359 bscores.push(p[4]);
360 } else {
361 let obj = p[4];
362 let probs = p[5..].iter().map(|x| *x * obj);
363 bscores.extend(probs);
364 }
365 }
366 }
367 debug_assert_eq!(bboxes.len() % 4, 0);
370 debug_assert_eq!(bscores.len() % nc, 0);
371
372 let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
373 .expect("Failed to create bboxes array");
374 let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
375 .expect("Failed to create bscores array");
376 (bboxes, bscores)
377}
378
379#[doc(hidden)]
383pub fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
384 outputs: &[ArrayView3<T>],
385 config: &[ModelPackDetectionConfig],
386) -> (Array2<f32>, Array2<f32>) {
387 let mut total_capacity = 0;
388 let mut nc = 0;
389 for (p, detail) in outputs.iter().zip(config) {
390 let shape = p.shape();
391 let na = detail.anchors.len();
392 nc = *shape
393 .last()
394 .expect("Shape must have at least one dimension")
395 / na
396 - 5;
397 total_capacity += shape[0] * shape[1] * na;
398 }
399 let mut bboxes = Vec::with_capacity(total_capacity * 4);
400 let mut bscores = Vec::with_capacity(total_capacity * nc);
401
402 for (p, detail) in outputs.iter().zip(config) {
403 let anchors = &detail.anchors;
404 let na = detail.anchors.len();
405 let shape = p.shape();
406 assert_eq!(
407 shape.iter().product::<usize>(),
408 p.len(),
409 "Shape product doesn't match tensor length"
410 );
411 let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
412 let p_sigmoid = p_sigmoid.as_standard_layout();
413
414 let p = p_sigmoid
416 .as_slice()
417 .expect("Sigmoids are not in standard layout");
418 let height = shape[0];
419 let width = shape[1];
420
421 let div_width = 1.0 / width as f32;
422 let div_height = 1.0 / height as f32;
423
424 let mut grid = Vec::with_capacity(height * width * na * 2);
425 for y in 0..height {
426 for x in 0..width {
427 for _ in 0..na {
428 grid.push(x as f32 - 0.5);
429 grid.push(y as f32 - 0.5);
430 }
431 }
432 }
433 for ((p, g), anchor) in p
434 .chunks_exact(nc + 5)
435 .zip(grid.chunks_exact(2))
436 .zip(anchors.iter().cycle())
437 {
438 let (x, y) = (p[0], p[1]);
439 let x = (x * 2.0 + g[0]) * div_width;
440 let y = (y * 2.0 + g[1]) * div_height;
441 let (w, h) = (p[2], p[3]);
442 let w = w * w * 4.0 * anchor[0];
443 let h = h * h * 4.0 * anchor[1];
444
445 bboxes.push(x);
446 bboxes.push(y);
447 bboxes.push(w);
448 bboxes.push(h);
449
450 if nc == 1 {
451 bscores.push(p[4]);
452 } else {
453 let obj = p[4];
454 let probs = p[5..].iter().map(|x| *x * obj);
455 bscores.extend(probs);
456 }
457 }
458 }
459 debug_assert_eq!(bboxes.len() % 4, 0);
462 debug_assert_eq!(bscores.len() % nc, 0);
463
464 let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
465 .expect("Failed to create bboxes array");
466 let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
467 .expect("Failed to create bscores array");
468 (bboxes, bscores)
469}
470
471#[inline(always)]
472fn fast_sigmoid_impl(f: f32) -> f32 {
473 if f.abs() > 80.0 {
474 f.signum() * 0.5 + 0.5
475 } else {
476 1.0 / (1.0 + fast_math::exp_raw(-f))
478 }
479}
480
481pub fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
490 use argminmax::ArgMinMax;
491 assert!(
492 segmentation.shape()[2] > 1,
493 "Model Instance Segmentation should have shape (H, W, x) where x > 1"
494 );
495 let height = segmentation.shape()[0];
496 let width = segmentation.shape()[1];
497 let channels = segmentation.shape()[2];
498 let segmentation = segmentation.as_standard_layout();
499 let seg = segmentation
501 .as_slice()
502 .expect("Segmentation is not in standard layout");
503 let argmax = seg
504 .chunks_exact(channels)
505 .map(|x| x.argmax() as u8)
506 .collect::<Vec<_>>();
507
508 Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
509}
510
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use crate::configs::{DecoderType, DimName};

    use super::*;

    /// Builds the `Detection` fixture shared by the config tests; only the
    /// anchors vary between cases.
    fn split_detection(anchors: Option<Vec<[f32; 2]>>) -> Detection {
        Detection {
            anchors,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        }
    }

    /// A detection with anchors converts; one without anchors is rejected.
    #[test]
    fn test_detection_config() {
        let with_anchors = split_detection(Some(vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]]));
        let config = ModelPackDetectionConfig::try_from(&with_anchors).unwrap();
        let expected = ModelPackDetectionConfig {
            anchors: vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]],
            quantization: Some(Quantization::new(0.1, 128)),
        };
        assert_eq!(config, expected);

        let without_anchors = split_detection(None);
        let result = ModelPackDetectionConfig::try_from(&without_anchors);
        assert!(
            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
        );
    }

    /// The approximation stays within 0.0005 of the exact sigmoid on
    /// [-255, 255] sampled in steps of 0.1.
    #[test]
    fn test_fast_sigmoid() {
        fn full_sigmoid(x: f32) -> f32 {
            1.0 / (1.0 + (-x).exp())
        }

        for step in -2550..=2550 {
            let x = step as f32 * 0.1;
            let diff = (fast_sigmoid_impl(x) - full_sigmoid(x)).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
    }

    /// Each pixel maps to the index of its largest channel; ties resolve to
    /// the first channel in the all-zero pixel.
    #[test]
    fn test_modelpack_segmentation_to_mask() {
        // Four pixels with three channels each, in row-major order.
        let pixels = vec![0u8, 10, 5, 20, 15, 25, 30, 5, 10, 0, 0, 0];
        let seg = Array3::from_shape_vec((2, 2, 3), pixels).unwrap();

        let mask = modelpack_segmentation_to_mask(seg.view());

        let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(mask, expected_mask);
    }

    /// A single-channel tensor violates the `x > 1` precondition.
    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let seg = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(seg.view());
    }
}