Skip to main content

edgefirst_decoder/
modelpack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8    byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9    configs::Detection,
10    dequant_detect_box,
11    float::{nms_float, postprocess_boxes_float},
12    BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
/// Configuration for one output of the ModelPack split detection decoder.
///
/// A slice of these configs is matched, in order, with the decoder outputs.
/// The quantization is ignored when decoding float models.
#[derive(Debug, Clone, PartialEq)]
pub struct ModelPackDetectionConfig {
    /// Anchor dimensions for this output, one pair per anchor. Element `0`
    /// scales the decoded box width and element `1` the decoded box height.
    /// The number of anchors also determines how the output's last dimension
    /// is split into per-anchor rows.
    pub anchors: Vec<[f32; 2]>,
    /// Optional scale / zero-point used to dequantize this output before the
    /// sigmoid is applied. `None` means the raw values are used as-is.
    pub quantization: Option<Quantization>,
}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24    type Error = DecoderError;
25
26    fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27        Ok(Self {
28            anchors: value.anchors.clone().ok_or_else(|| {
29                DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30            })?,
31            quantization: value.quantization.map(Quantization::from),
32        })
33    }
34}
35
36/// Decodes ModelPack detection outputs from quantized tensors.
37///
38/// The boxes are expected to be in XYXY format.
39///
40/// Expected shapes of inputs:
41/// - boxes: (num_boxes, 4)
42/// - scores: (num_boxes, num_classes)
43///
44/// # Panics
45/// Panics if shapes don't match the expected dimensions.
46pub fn decode_modelpack_det<
47    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50    boxes_tensor: (ArrayView2<BOX>, Quantization),
51    scores_tensor: (ArrayView2<SCORE>, Quantization),
52    score_threshold: f32,
53    iou_threshold: f32,
54    output_boxes: &mut Vec<DetectBox>,
55) where
56    f32: AsPrimitive<SCORE>,
57{
58    impl_modelpack_quant::<XYXY, _, _>(
59        boxes_tensor,
60        scores_tensor,
61        score_threshold,
62        iou_threshold,
63        output_boxes,
64    )
65}
66
67/// Decodes ModelPack detection outputs from float tensors. The boxes
68/// are expected to be in XYXY format.
69///
70/// Expected shapes of inputs:
71/// - boxes: (num_boxes, 4)
72/// - scores: (num_boxes, num_classes)
73///
74/// # Panics
75/// Panics if shapes don't match the expected dimensions.
76pub fn decode_modelpack_float<
77    BOX: Float + AsPrimitive<f32> + Send + Sync,
78    SCORE: Float + AsPrimitive<f32> + Send + Sync,
79>(
80    boxes_tensor: ArrayView2<BOX>,
81    scores_tensor: ArrayView2<SCORE>,
82    score_threshold: f32,
83    iou_threshold: f32,
84    output_boxes: &mut Vec<DetectBox>,
85) where
86    f32: AsPrimitive<SCORE>,
87{
88    impl_modelpack_float::<XYXY, _, _>(
89        boxes_tensor,
90        scores_tensor,
91        score_threshold,
92        iou_threshold,
93        output_boxes,
94    )
95}
96
97/// Decodes ModelPack split detection outputs from quantized tensors. The boxes
98/// are expected to be in XYWH format.
99///
100/// The `configs` must correspond to the `outputs` in order.
101///
102/// Expected shapes of inputs:
103/// - outputs: (width, height, num_anchors * (5 + num_classes))
104///
105/// # Panics
106/// Panics if shapes don't match the expected dimensions.
107pub fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
108    outputs: &[ArrayView3<D>],
109    configs: &[ModelPackDetectionConfig],
110    score_threshold: f32,
111    iou_threshold: f32,
112    output_boxes: &mut Vec<DetectBox>,
113) {
114    impl_modelpack_split_quant::<XYWH, D>(
115        outputs,
116        configs,
117        score_threshold,
118        iou_threshold,
119        output_boxes,
120    )
121}
122
123/// Decodes ModelPack split detection outputs from float tensors. The boxes
124/// are expected to be in XYWH format.
125///
126/// The `configs` must correspond to the `outputs` in order.
127///
128/// Expected shapes of inputs:
129/// - outputs: (width, height, num_anchors * (5 + num_classes))
130///
131/// # Panics
132/// Panics if shapes don't match the expected dimensions.
133pub fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
134    outputs: &[ArrayView3<D>],
135    configs: &[ModelPackDetectionConfig],
136    score_threshold: f32,
137    iou_threshold: f32,
138    output_boxes: &mut Vec<DetectBox>,
139) {
140    impl_modelpack_split_float::<XYWH, D>(
141        outputs,
142        configs,
143        score_threshold,
144        iou_threshold,
145        output_boxes,
146    );
147}
148/// Implementation of ModelPack detection decoding for quantized tensors.
149///
150/// Expected shapes of inputs:
151/// - boxes: (num_boxes, 4)
152/// - scores: (num_boxes, num_classes)
153///
154/// # Panics
155/// Panics if shapes don't match the expected dimensions.
156#[doc(hidden)]
157pub fn impl_modelpack_quant<
158    B: BBoxTypeTrait,
159    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
160    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
161>(
162    boxes: (ArrayView2<BOX>, Quantization),
163    scores: (ArrayView2<SCORE>, Quantization),
164    score_threshold: f32,
165    iou_threshold: f32,
166    output_boxes: &mut Vec<DetectBox>,
167) where
168    f32: AsPrimitive<SCORE>,
169{
170    let (boxes_tensor, quant_boxes) = boxes;
171    let (scores_tensor, quant_scores) = scores;
172    let boxes = {
173        let score_threshold = quantize_score_threshold(score_threshold, quant_boxes);
174        postprocess_boxes_quant::<B, _, _>(
175            score_threshold,
176            boxes_tensor,
177            scores_tensor,
178            quant_boxes,
179        )
180    };
181    let boxes = nms_int(iou_threshold, boxes);
182    let len = output_boxes.capacity().min(boxes.len());
183    output_boxes.clear();
184    for b in boxes.into_iter().take(len) {
185        output_boxes.push(dequant_detect_box(&b, quant_scores));
186    }
187}
188
189/// Implementation of ModelPack detection decoding for float tensors.
190///
191/// Expected shapes of inputs:
192/// - boxes: (num_boxes, 4)
193/// - scores: (num_boxes, num_classes)
194///
195/// # Panics
196/// Panics if shapes don't match the expected dimensions.
197#[doc(hidden)]
198pub fn impl_modelpack_float<
199    B: BBoxTypeTrait,
200    BOX: Float + AsPrimitive<f32> + Send + Sync,
201    SCORE: Float + AsPrimitive<f32> + Send + Sync,
202>(
203    boxes_tensor: ArrayView2<BOX>,
204    scores_tensor: ArrayView2<SCORE>,
205    score_threshold: f32,
206    iou_threshold: f32,
207    output_boxes: &mut Vec<DetectBox>,
208) where
209    f32: AsPrimitive<SCORE>,
210{
211    let boxes =
212        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
213    let boxes = nms_float(iou_threshold, boxes);
214    let len = output_boxes.capacity().min(boxes.len());
215    output_boxes.clear();
216    for b in boxes.into_iter().take(len) {
217        output_boxes.push(b);
218    }
219}
220
221/// Implementation of ModelPack split detection decoding for quantized tensors.
222///
223/// Expected shapes of inputs:
224/// - boxes: (num_boxes, 4)
225/// - scores: (num_boxes, num_classes)
226///
227/// # Panics
228/// Panics if shapes don't match the expected dimensions.
229#[doc(hidden)]
230pub fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
231    outputs: &[ArrayView3<D>],
232    configs: &[ModelPackDetectionConfig],
233    score_threshold: f32,
234    iou_threshold: f32,
235    output_boxes: &mut Vec<DetectBox>,
236) {
237    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_quant(outputs, configs);
238    let boxes = postprocess_boxes_float::<B, _, _>(
239        score_threshold,
240        boxes_tensor.view(),
241        scores_tensor.view(),
242    );
243    let boxes = nms_float(iou_threshold, boxes);
244    let len = output_boxes.capacity().min(boxes.len());
245    output_boxes.clear();
246    for b in boxes.into_iter().take(len) {
247        output_boxes.push(b);
248    }
249}
250
251/// Implementation of ModelPack split detection decoding for float tensors.
252///
253/// The `configs` must correspond to the `outputs` in order.
254///
255/// Expected shapes of inputs:
256/// - outputs: (width, height, num_anchors * (5 + num_classes))
257///
258/// # Panics
259/// Panics if shapes don't match the expected dimensions.
260#[doc(hidden)]
261pub fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
262    outputs: &[ArrayView3<D>],
263    configs: &[ModelPackDetectionConfig],
264    score_threshold: f32,
265    iou_threshold: f32,
266    output_boxes: &mut Vec<DetectBox>,
267) {
268    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
269    let boxes = postprocess_boxes_float::<B, _, _>(
270        score_threshold,
271        boxes_tensor.view(),
272        scores_tensor.view(),
273    );
274    let boxes = nms_float(iou_threshold, boxes);
275    let len = output_boxes.capacity().min(boxes.len());
276    output_boxes.clear();
277    for b in boxes.into_iter().take(len) {
278        output_boxes.push(b);
279    }
280}
281
282/// Post processes ModelPack split detection into detection boxes,
283/// filtering out any boxes below the score threshold. Returns the boxes and
284/// scores tensors. Boxes are in XYWH format.
285#[doc(hidden)]
286pub fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
287    outputs: &[ArrayView3<T>],
288    config: &[ModelPackDetectionConfig],
289) -> (Array2<f32>, Array2<f32>) {
290    let mut total_capacity = 0;
291    let mut nc = 0;
292    for (p, detail) in outputs.iter().zip(config) {
293        let shape = p.shape();
294        let na = detail.anchors.len();
295        nc = *shape
296            .last()
297            .expect("Shape must have at least one dimension")
298            / na
299            - 5;
300        total_capacity += shape[0] * shape[1] * na;
301    }
302    let mut bboxes = Vec::with_capacity(total_capacity * 4);
303    let mut bscores = Vec::with_capacity(total_capacity * nc);
304
305    for (p, detail) in outputs.iter().zip(config) {
306        let anchors = &detail.anchors;
307        let na = detail.anchors.len();
308        let shape = p.shape();
309        assert_eq!(
310            shape.iter().product::<usize>(),
311            p.len(),
312            "Shape product doesn't match tensor length"
313        );
314        let p_sigmoid = if let Some(quant) = &detail.quantization {
315            let scaled_zero = -quant.zero_point as f32 * quant.scale;
316            p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
317        } else {
318            p.mapv(|x| fast_sigmoid_impl(x.as_()))
319        };
320        let p_sigmoid = p_sigmoid.as_standard_layout();
321
322        // Safe to unwrap since we ensured standard layout above
323        let p = p_sigmoid
324            .as_slice()
325            .expect("Sigmoids are not in standard layout");
326        let height = shape[0];
327        let width = shape[1];
328
329        let div_width = 1.0 / width as f32;
330        let div_height = 1.0 / height as f32;
331
332        let mut grid = Vec::with_capacity(height * width * na * 2);
333        for y in 0..height {
334            for x in 0..width {
335                for _ in 0..na {
336                    grid.push(x as f32 - 0.5);
337                    grid.push(y as f32 - 0.5);
338                }
339            }
340        }
341        for ((p, g), anchor) in p
342            .chunks_exact(nc + 5)
343            .zip(grid.chunks_exact(2))
344            .zip(anchors.iter().cycle())
345        {
346            let (x, y) = (p[0], p[1]);
347            let x = (x * 2.0 + g[0]) * div_width;
348            let y = (y * 2.0 + g[1]) * div_height;
349            let (w, h) = (p[2], p[3]);
350            let w = w * w * 4.0 * anchor[0];
351            let h = h * h * 4.0 * anchor[1];
352
353            bboxes.push(x);
354            bboxes.push(y);
355            bboxes.push(w);
356            bboxes.push(h);
357
358            if nc == 1 {
359                bscores.push(p[4]);
360            } else {
361                let obj = p[4];
362                let probs = p[5..].iter().map(|x| *x * obj);
363                bscores.extend(probs);
364            }
365        }
366    }
367    // Safe to unwrap since we ensured lengths will match above
368
369    debug_assert_eq!(bboxes.len() % 4, 0);
370    debug_assert_eq!(bscores.len() % nc, 0);
371
372    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
373        .expect("Failed to create bboxes array");
374    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
375        .expect("Failed to create bscores array");
376    (bboxes, bscores)
377}
378
379/// Post processes ModelPack split detection into detection boxes,
380/// filtering out any boxes below the score threshold. Returns the boxes and
381/// scores tensors. Boxes are in XYWH format.
382#[doc(hidden)]
383pub fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
384    outputs: &[ArrayView3<T>],
385    config: &[ModelPackDetectionConfig],
386) -> (Array2<f32>, Array2<f32>) {
387    let mut total_capacity = 0;
388    let mut nc = 0;
389    for (p, detail) in outputs.iter().zip(config) {
390        let shape = p.shape();
391        let na = detail.anchors.len();
392        nc = *shape
393            .last()
394            .expect("Shape must have at least one dimension")
395            / na
396            - 5;
397        total_capacity += shape[0] * shape[1] * na;
398    }
399    let mut bboxes = Vec::with_capacity(total_capacity * 4);
400    let mut bscores = Vec::with_capacity(total_capacity * nc);
401
402    for (p, detail) in outputs.iter().zip(config) {
403        let anchors = &detail.anchors;
404        let na = detail.anchors.len();
405        let shape = p.shape();
406        assert_eq!(
407            shape.iter().product::<usize>(),
408            p.len(),
409            "Shape product doesn't match tensor length"
410        );
411        let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
412        let p_sigmoid = p_sigmoid.as_standard_layout();
413
414        // Safe to unwrap since we ensured standard layout above
415        let p = p_sigmoid
416            .as_slice()
417            .expect("Sigmoids are not in standard layout");
418        let height = shape[0];
419        let width = shape[1];
420
421        let div_width = 1.0 / width as f32;
422        let div_height = 1.0 / height as f32;
423
424        let mut grid = Vec::with_capacity(height * width * na * 2);
425        for y in 0..height {
426            for x in 0..width {
427                for _ in 0..na {
428                    grid.push(x as f32 - 0.5);
429                    grid.push(y as f32 - 0.5);
430                }
431            }
432        }
433        for ((p, g), anchor) in p
434            .chunks_exact(nc + 5)
435            .zip(grid.chunks_exact(2))
436            .zip(anchors.iter().cycle())
437        {
438            let (x, y) = (p[0], p[1]);
439            let x = (x * 2.0 + g[0]) * div_width;
440            let y = (y * 2.0 + g[1]) * div_height;
441            let (w, h) = (p[2], p[3]);
442            let w = w * w * 4.0 * anchor[0];
443            let h = h * h * 4.0 * anchor[1];
444
445            bboxes.push(x);
446            bboxes.push(y);
447            bboxes.push(w);
448            bboxes.push(h);
449
450            if nc == 1 {
451                bscores.push(p[4]);
452            } else {
453                let obj = p[4];
454                let probs = p[5..].iter().map(|x| *x * obj);
455                bscores.extend(probs);
456            }
457        }
458    }
459    // Safe to unwrap since we ensured lengths will match above
460
461    debug_assert_eq!(bboxes.len() % 4, 0);
462    debug_assert_eq!(bscores.len() % nc, 0);
463
464    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
465        .expect("Failed to create bboxes array");
466    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
467        .expect("Failed to create bscores array");
468    (bboxes, bscores)
469}
470
471#[inline(always)]
472fn fast_sigmoid_impl(f: f32) -> f32 {
473    if f.abs() > 80.0 {
474        f.signum() * 0.5 + 0.5
475    } else {
476        // these values are only valid for -88 < x < 88
477        1.0 / (1.0 + fast_math::exp_raw(-f))
478    }
479}
480
481/// Converts ModelPack segmentation into a 2D mask.
482/// The input segmentation is expected to have shape (H, W, num_classes).
483///
484/// The output mask will have shape (H, W), with values `0..num_classes` based
485/// on the argmax across the channels.
486///
487/// # Panics
488/// Panics if the input tensor does not have more than one channel.
489pub fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
490    use argminmax::ArgMinMax;
491    assert!(
492        segmentation.shape()[2] > 1,
493        "Model Instance Segmentation should have shape (H, W, x) where x > 1"
494    );
495    let height = segmentation.shape()[0];
496    let width = segmentation.shape()[1];
497    let channels = segmentation.shape()[2];
498    let segmentation = segmentation.as_standard_layout();
499    // Safe to unwrap since we ensured standard layout above
500    let seg = segmentation
501        .as_slice()
502        .expect("Segmentation is not in standard layout");
503    let argmax = seg
504        .chunks_exact(channels)
505        .map(|x| x.argmax() as u8)
506        .collect::<Vec<_>>();
507
508    Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
509}
510
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use super::*;
    use crate::configs::{DecoderType, DimName};

    /// Builds the `Detection` description shared by the config tests, varying
    /// only the anchors.
    fn detection_with_anchors(anchors: Option<Vec<[f32; 2]>>) -> Detection {
        Detection {
            anchors,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        }
    }

    #[test]
    fn test_detection_config() {
        let anchors = vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]];
        let parsed =
            ModelPackDetectionConfig::try_from(&detection_with_anchors(Some(anchors.clone())))
                .unwrap();
        let expected = ModelPackDetectionConfig {
            anchors,
            quantization: Some(Quantization::new(0.1, 128)),
        };
        assert_eq!(parsed, expected);

        let missing = ModelPackDetectionConfig::try_from(&detection_with_anchors(None));
        assert!(matches!(
            missing,
            Err(DecoderError::InvalidConfig(msg)) if msg == "ModelPack Split Detection missing anchors"
        ));
    }

    #[test]
    fn test_fast_sigmoid() {
        let reference = |x: f32| 1.0 / (1.0 + (-x).exp());
        for x in (-2550..=2550).map(|i| i as f32 * 0.1) {
            let diff = (fast_sigmoid_impl(x) - reference(x)).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
    }

    #[test]
    fn test_modelpack_segmentation_to_mask() {
        // Three class channels per pixel, row-major over a 2x2 image.
        let logits: Vec<u8> = vec![
            0, 10, 5, // (0,0) -> class 1
            20, 15, 25, // (0,1) -> class 2
            30, 5, 10, // (1,0) -> class 0
            0, 0, 0, // (1,1) -> all equal, expected class 0
        ];
        let seg = Array3::from_shape_vec((2, 2, 3), logits).unwrap();
        let expected = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(modelpack_segmentation_to_mask(seg.view()), expected);
    }

    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let single_channel = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(single_channel.view());
    }
}