Skip to main content

edgefirst_decoder/
modelpack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8    byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9    configs::Detection,
10    dequant_detect_box,
11    float::{nms_float, postprocess_boxes_float},
12    BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
15/// Configuration for ModelPack split detection decoder. The quantization is
16/// ignored when decoding float models.
17#[derive(Debug, Clone, PartialEq)]
18pub(crate) struct ModelPackDetectionConfig {
19    pub(crate) anchors: Vec<[f32; 2]>,
20    pub(crate) quantization: Option<Quantization>,
21}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24    type Error = DecoderError;
25
26    fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27        Ok(Self {
28            anchors: value.anchors.clone().ok_or_else(|| {
29                DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30            })?,
31            quantization: value.quantization.map(Quantization::from),
32        })
33    }
34}
35
36/// Decodes ModelPack detection outputs from quantized tensors.
37///
38/// The boxes are expected to be in XYXY format.
39///
40/// Expected shapes of inputs:
41/// - boxes: (num_boxes, 4)
42/// - scores: (num_boxes, num_classes)
43///
44/// # Panics
45/// Panics if shapes don't match the expected dimensions.
46pub(crate) fn decode_modelpack_det<
47    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50    boxes_tensor: (ArrayView2<BOX>, Quantization),
51    scores_tensor: (ArrayView2<SCORE>, Quantization),
52    score_threshold: f32,
53    iou_threshold: f32,
54    max_det: usize,
55    output_boxes: &mut Vec<DetectBox>,
56) where
57    f32: AsPrimitive<SCORE>,
58{
59    impl_modelpack_quant::<XYXY, _, _>(
60        boxes_tensor,
61        scores_tensor,
62        score_threshold,
63        iou_threshold,
64        max_det,
65        output_boxes,
66    )
67}
68
69/// Decodes ModelPack detection outputs from float tensors. The boxes
70/// are expected to be in XYXY format.
71///
72/// Expected shapes of inputs:
73/// - boxes: (num_boxes, 4)
74/// - scores: (num_boxes, num_classes)
75///
76/// # Panics
77/// Panics if shapes don't match the expected dimensions.
78pub(crate) fn decode_modelpack_float<
79    BOX: Float + AsPrimitive<f32> + Send + Sync,
80    SCORE: Float + AsPrimitive<f32> + Send + Sync,
81>(
82    boxes_tensor: ArrayView2<BOX>,
83    scores_tensor: ArrayView2<SCORE>,
84    score_threshold: f32,
85    iou_threshold: f32,
86    max_det: usize,
87    output_boxes: &mut Vec<DetectBox>,
88) where
89    f32: AsPrimitive<SCORE>,
90{
91    impl_modelpack_float::<XYXY, _, _>(
92        boxes_tensor,
93        scores_tensor,
94        score_threshold,
95        iou_threshold,
96        max_det,
97        output_boxes,
98    )
99}
100
/// Decodes ModelPack split detection outputs from quantized tensors. The boxes
/// are decoded in XYWH format.
///
/// Each output is dequantized with the quantization of its matching config
/// (when present) before decoding. The `configs` must correspond to the
/// `outputs` in order.
///
/// Expected shapes of inputs:
/// - outputs: (height, width, num_anchors * (5 + num_classes))
///
/// `output_boxes` is cleared and refilled with at most `max_det` detections.
///
/// # Panics
/// Panics if shapes don't match the expected dimensions.
#[cfg(test)]
pub(crate) fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    impl_modelpack_split_quant::<XYWH, D>(
        outputs,
        configs,
        score_threshold,
        iou_threshold,
        max_det,
        output_boxes,
    );
}
129
130/// Decodes ModelPack split detection outputs from float tensors. The boxes
131/// are expected to be in XYWH format.
132///
133/// The `configs` must correspond to the `outputs` in order.
134///
135/// Expected shapes of inputs:
136/// - outputs: (width, height, num_anchors * (5 + num_classes))
137///
138/// # Panics
139/// Panics if shapes don't match the expected dimensions.
140pub(crate) fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
141    outputs: &[ArrayView3<D>],
142    configs: &[ModelPackDetectionConfig],
143    score_threshold: f32,
144    iou_threshold: f32,
145    max_det: usize,
146    output_boxes: &mut Vec<DetectBox>,
147) {
148    impl_modelpack_split_float::<XYWH, D>(
149        outputs,
150        configs,
151        score_threshold,
152        iou_threshold,
153        max_det,
154        output_boxes,
155    );
156}
157/// Implementation of ModelPack detection decoding for quantized tensors.
158///
159/// Expected shapes of inputs:
160/// - boxes: (num_boxes, 4)
161/// - scores: (num_boxes, num_classes)
162///
163/// # Panics
164/// Panics if shapes don't match the expected dimensions.
165pub(crate) fn impl_modelpack_quant<
166    B: BBoxTypeTrait,
167    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
168    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
169>(
170    boxes: (ArrayView2<BOX>, Quantization),
171    scores: (ArrayView2<SCORE>, Quantization),
172    score_threshold: f32,
173    iou_threshold: f32,
174    max_det: usize,
175    output_boxes: &mut Vec<DetectBox>,
176) where
177    f32: AsPrimitive<SCORE>,
178{
179    let (boxes_tensor, quant_boxes) = boxes;
180    let (scores_tensor, quant_scores) = scores;
181    let boxes = {
182        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
183        postprocess_boxes_quant::<B, _, _>(
184            score_threshold,
185            boxes_tensor,
186            scores_tensor,
187            quant_boxes,
188        )
189    };
190    let boxes = nms_int(iou_threshold, Some(max_det), boxes);
191    output_boxes.clear();
192    for b in boxes.into_iter().take(max_det) {
193        output_boxes.push(dequant_detect_box(&b, quant_scores));
194    }
195}
196
197/// Implementation of ModelPack detection decoding for float tensors.
198///
199/// Expected shapes of inputs:
200/// - boxes: (num_boxes, 4)
201/// - scores: (num_boxes, num_classes)
202///
203/// # Panics
204/// Panics if shapes don't match the expected dimensions.
205pub(crate) fn impl_modelpack_float<
206    B: BBoxTypeTrait,
207    BOX: Float + AsPrimitive<f32> + Send + Sync,
208    SCORE: Float + AsPrimitive<f32> + Send + Sync,
209>(
210    boxes_tensor: ArrayView2<BOX>,
211    scores_tensor: ArrayView2<SCORE>,
212    score_threshold: f32,
213    iou_threshold: f32,
214    max_det: usize,
215    output_boxes: &mut Vec<DetectBox>,
216) where
217    f32: AsPrimitive<SCORE>,
218{
219    let boxes =
220        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
221    let boxes = nms_float(iou_threshold, Some(max_det), boxes);
222    output_boxes.clear();
223    for b in boxes.into_iter().take(max_det) {
224        output_boxes.push(b);
225    }
226}
227
/// Implementation of ModelPack split detection decoding for quantized tensors.
///
/// The raw outputs are dequantized and decoded into float box/score tensors by
/// `postprocess_modelpack_split_quant`. Those tensors are passed with
/// `score_threshold` to `postprocess_boxes_float` and then filtered with
/// `nms_float`. `output_boxes` is cleared and refilled with at most `max_det`
/// detections.
///
/// The `configs` must correspond to the `outputs` in order.
///
/// Expected shapes of inputs:
/// - outputs: (height, width, num_anchors * (5 + num_classes))
///
/// # Panics
/// Panics if shapes don't match the expected dimensions.
#[cfg(test)]
pub(crate) fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
    outputs: &[ArrayView3<D>],
    configs: &[ModelPackDetectionConfig],
    score_threshold: f32,
    iou_threshold: f32,
    max_det: usize,
    output_boxes: &mut Vec<DetectBox>,
) {
    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_quant(outputs, configs);
    let boxes = postprocess_boxes_float::<B, _, _>(
        score_threshold,
        boxes_tensor.view(),
        scores_tensor.view(),
    );
    let boxes = nms_float(iou_threshold, Some(max_det), boxes);
    output_boxes.clear();
    output_boxes.extend(boxes.into_iter().take(max_det));
}
257
258/// Implementation of ModelPack split detection decoding for float tensors.
259///
260/// The `configs` must correspond to the `outputs` in order.
261///
262/// Expected shapes of inputs:
263/// - outputs: (width, height, num_anchors * (5 + num_classes))
264///
265/// # Panics
266/// Panics if shapes don't match the expected dimensions.
267pub(crate) fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
268    outputs: &[ArrayView3<D>],
269    configs: &[ModelPackDetectionConfig],
270    score_threshold: f32,
271    iou_threshold: f32,
272    max_det: usize,
273    output_boxes: &mut Vec<DetectBox>,
274) {
275    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
276    let boxes = postprocess_boxes_float::<B, _, _>(
277        score_threshold,
278        boxes_tensor.view(),
279        scores_tensor.view(),
280    );
281    let boxes = nms_float(iou_threshold, Some(max_det), boxes);
282    output_boxes.clear();
283    for b in boxes.into_iter().take(max_det) {
284        output_boxes.push(b);
285    }
286}
287
288/// Post processes ModelPack split detection into detection boxes,
289/// filtering out any boxes below the score threshold. Returns the boxes and
290/// scores tensors. Boxes are in XYWH format.
291#[cfg(test)]
292pub(crate) fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
293    outputs: &[ArrayView3<T>],
294    config: &[ModelPackDetectionConfig],
295) -> (Array2<f32>, Array2<f32>) {
296    let mut total_capacity = 0;
297    let mut nc = 0;
298    for (p, detail) in outputs.iter().zip(config) {
299        let shape = p.shape();
300        let na = detail.anchors.len();
301        nc = *shape
302            .last()
303            .expect("Shape must have at least one dimension")
304            / na
305            - 5;
306        total_capacity += shape[0] * shape[1] * na;
307    }
308    let mut bboxes = Vec::with_capacity(total_capacity * 4);
309    let mut bscores = Vec::with_capacity(total_capacity * nc);
310
311    for (p, detail) in outputs.iter().zip(config) {
312        let anchors = &detail.anchors;
313        let na = detail.anchors.len();
314        let shape = p.shape();
315        assert_eq!(
316            shape.iter().product::<usize>(),
317            p.len(),
318            "Shape product doesn't match tensor length"
319        );
320        let p_sigmoid = if let Some(quant) = &detail.quantization {
321            let scaled_zero = -quant.zero_point as f32 * quant.scale;
322            p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
323        } else {
324            p.mapv(|x| fast_sigmoid_impl(x.as_()))
325        };
326        let p_sigmoid = p_sigmoid.as_standard_layout();
327
328        // Safe to unwrap since we ensured standard layout above
329        let p = p_sigmoid
330            .as_slice()
331            .expect("Sigmoids are not in standard layout");
332        let height = shape[0];
333        let width = shape[1];
334
335        let div_width = 1.0 / width as f32;
336        let div_height = 1.0 / height as f32;
337
338        let mut grid = Vec::with_capacity(height * width * na * 2);
339        for y in 0..height {
340            for x in 0..width {
341                for _ in 0..na {
342                    grid.push(x as f32 - 0.5);
343                    grid.push(y as f32 - 0.5);
344                }
345            }
346        }
347        for ((p, g), anchor) in p
348            .chunks_exact(nc + 5)
349            .zip(grid.chunks_exact(2))
350            .zip(anchors.iter().cycle())
351        {
352            let (x, y) = (p[0], p[1]);
353            let x = (x * 2.0 + g[0]) * div_width;
354            let y = (y * 2.0 + g[1]) * div_height;
355            let (w, h) = (p[2], p[3]);
356            let w = w * w * 4.0 * anchor[0];
357            let h = h * h * 4.0 * anchor[1];
358
359            bboxes.push(x);
360            bboxes.push(y);
361            bboxes.push(w);
362            bboxes.push(h);
363
364            if nc == 1 {
365                bscores.push(p[4]);
366            } else {
367                let obj = p[4];
368                let probs = p[5..].iter().map(|x| *x * obj);
369                bscores.extend(probs);
370            }
371        }
372    }
373    // Safe to unwrap since we ensured lengths will match above
374
375    debug_assert_eq!(bboxes.len() % 4, 0);
376    debug_assert_eq!(bscores.len() % nc, 0);
377
378    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
379        .expect("Failed to create bboxes array");
380    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
381        .expect("Failed to create bscores array");
382    (bboxes, bscores)
383}
384
385/// Post processes ModelPack split detection into detection boxes,
386/// filtering out any boxes below the score threshold. Returns the boxes and
387/// scores tensors. Boxes are in XYWH format.
388pub(crate) fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
389    outputs: &[ArrayView3<T>],
390    config: &[ModelPackDetectionConfig],
391) -> (Array2<f32>, Array2<f32>) {
392    let mut total_capacity = 0;
393    let mut nc = 0;
394    for (p, detail) in outputs.iter().zip(config) {
395        let shape = p.shape();
396        let na = detail.anchors.len();
397        nc = *shape
398            .last()
399            .expect("Shape must have at least one dimension")
400            / na
401            - 5;
402        total_capacity += shape[0] * shape[1] * na;
403    }
404    let mut bboxes = Vec::with_capacity(total_capacity * 4);
405    let mut bscores = Vec::with_capacity(total_capacity * nc);
406
407    for (p, detail) in outputs.iter().zip(config) {
408        let anchors = &detail.anchors;
409        let na = detail.anchors.len();
410        let shape = p.shape();
411        assert_eq!(
412            shape.iter().product::<usize>(),
413            p.len(),
414            "Shape product doesn't match tensor length"
415        );
416        let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
417        let p_sigmoid = p_sigmoid.as_standard_layout();
418
419        // Safe to unwrap since we ensured standard layout above
420        let p = p_sigmoid
421            .as_slice()
422            .expect("Sigmoids are not in standard layout");
423        let height = shape[0];
424        let width = shape[1];
425
426        let div_width = 1.0 / width as f32;
427        let div_height = 1.0 / height as f32;
428
429        let mut grid = Vec::with_capacity(height * width * na * 2);
430        for y in 0..height {
431            for x in 0..width {
432                for _ in 0..na {
433                    grid.push(x as f32 - 0.5);
434                    grid.push(y as f32 - 0.5);
435                }
436            }
437        }
438        for ((p, g), anchor) in p
439            .chunks_exact(nc + 5)
440            .zip(grid.chunks_exact(2))
441            .zip(anchors.iter().cycle())
442        {
443            let (x, y) = (p[0], p[1]);
444            let x = (x * 2.0 + g[0]) * div_width;
445            let y = (y * 2.0 + g[1]) * div_height;
446            let (w, h) = (p[2], p[3]);
447            let w = w * w * 4.0 * anchor[0];
448            let h = h * h * 4.0 * anchor[1];
449
450            bboxes.push(x);
451            bboxes.push(y);
452            bboxes.push(w);
453            bboxes.push(h);
454
455            if nc == 1 {
456                bscores.push(p[4]);
457            } else {
458                let obj = p[4];
459                let probs = p[5..].iter().map(|x| *x * obj);
460                bscores.extend(probs);
461            }
462        }
463    }
464    // Safe to unwrap since we ensured lengths will match above
465
466    debug_assert_eq!(bboxes.len() % 4, 0);
467    debug_assert_eq!(bscores.len() % nc, 0);
468
469    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
470        .expect("Failed to create bboxes array");
471    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
472        .expect("Failed to create bscores array");
473    (bboxes, bscores)
474}
475
476#[inline(always)]
477fn fast_sigmoid_impl(f: f32) -> f32 {
478    if f.abs() > 80.0 {
479        f.signum() * 0.5 + 0.5
480    } else {
481        // these values are only valid for -88 < x < 88
482        1.0 / (1.0 + fast_math::exp_raw(-f))
483    }
484}
485
486/// Converts ModelPack segmentation into a 2D mask.
487/// The input segmentation is expected to have shape (H, W, num_classes).
488///
489/// The output mask will have shape (H, W), with values `0..num_classes` based
490/// on the argmax across the channels.
491///
492/// # Panics
493/// Panics if the input tensor does not have more than one channel.
494pub(crate) fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
495    use argminmax::ArgMinMax;
496    assert!(
497        segmentation.shape()[2] > 1,
498        "Model Instance Segmentation should have shape (H, W, x) where x > 1"
499    );
500    let height = segmentation.shape()[0];
501    let width = segmentation.shape()[1];
502    let channels = segmentation.shape()[2];
503    let segmentation = segmentation.as_standard_layout();
504    // Safe to unwrap since we ensured standard layout above
505    let seg = segmentation
506        .as_slice()
507        .expect("Segmentation is not in standard layout");
508    let argmax = seg
509        .chunks_exact(channels)
510        .map(|x| x.argmax() as u8)
511        .collect::<Vec<_>>();
512
513    Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
514}
515
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use super::*;
    use crate::configs::{DecoderType, DimName};

    /// Builds the ModelPack `Detection` used by the config tests, varying only the anchors.
    fn sample_detection(anchors: Option<Vec<[f32; 2]>>) -> Detection {
        Detection {
            anchors,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        }
    }

    #[test]
    fn test_detection_config() {
        let anchors = vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]];
        let config =
            ModelPackDetectionConfig::try_from(&sample_detection(Some(anchors.clone()))).unwrap();
        assert_eq!(
            config,
            ModelPackDetectionConfig {
                anchors,
                quantization: Some(Quantization::new(0.1, 128)),
            }
        );

        let missing = ModelPackDetectionConfig::try_from(&sample_detection(None));
        assert!(
            matches!(missing, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
        );
    }

    #[test]
    fn test_fast_sigmoid() {
        let reference = |x: f32| 1.0 / (1.0 + (-x).exp());
        for x in (-2550..=2550).map(|i| i as f32 * 0.1) {
            let diff = (fast_sigmoid_impl(x) - reference(x)).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
    }

    #[test]
    fn test_modelpack_segmentation_to_mask() {
        let pixels = vec![
            0u8, 10, 5, // pixel (0,0)
            20, 15, 25, // pixel (0,1)
            30, 5, 10, // pixel (1,0)
            0, 0, 0, // pixel (1,1)
        ];
        let seg = Array3::from_shape_vec((2, 2, 3), pixels).unwrap();
        let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(modelpack_segmentation_to_mask(seg.view()), expected_mask);
    }

    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let single_channel = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(single_channel.view());
    }
}