Skip to main content

edgefirst_decoder/
modelpack.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use ndarray::{Array2, ArrayView2, ArrayView3};
5use num_traits::{AsPrimitive, Float, PrimInt};
6
7use crate::{
8    byte::{nms_int, postprocess_boxes_quant, quantize_score_threshold},
9    configs::Detection,
10    dequant_detect_box,
11    float::{nms_float, postprocess_boxes_float},
12    BBoxTypeTrait, DecoderError, DetectBox, Quantization, XYWH, XYXY,
13};
14
15/// Configuration for ModelPack split detection decoder. The quantization is
16/// ignored when decoding float models.
17#[derive(Debug, Clone, PartialEq)]
18pub struct ModelPackDetectionConfig {
19    pub anchors: Vec<[f32; 2]>,
20    pub quantization: Option<Quantization>,
21}
22
23impl TryFrom<&Detection> for ModelPackDetectionConfig {
24    type Error = DecoderError;
25
26    fn try_from(value: &Detection) -> Result<Self, DecoderError> {
27        Ok(Self {
28            anchors: value.anchors.clone().ok_or_else(|| {
29                DecoderError::InvalidConfig("ModelPack Split Detection missing anchors".to_string())
30            })?,
31            quantization: value.quantization.map(Quantization::from),
32        })
33    }
34}
35
36/// Decodes ModelPack detection outputs from quantized tensors.
37///
38/// The boxes are expected to be in XYXY format.
39///
40/// Expected shapes of inputs:
41/// - boxes: (num_boxes, 4)
42/// - scores: (num_boxes, num_classes)
43///
44/// # Panics
45/// Panics if shapes don't match the expected dimensions.
46pub fn decode_modelpack_det<
47    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
48    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
49>(
50    boxes_tensor: (ArrayView2<BOX>, Quantization),
51    scores_tensor: (ArrayView2<SCORE>, Quantization),
52    score_threshold: f32,
53    iou_threshold: f32,
54    output_boxes: &mut Vec<DetectBox>,
55) where
56    f32: AsPrimitive<SCORE>,
57{
58    impl_modelpack_quant::<XYXY, _, _>(
59        boxes_tensor,
60        scores_tensor,
61        score_threshold,
62        iou_threshold,
63        output_boxes,
64    )
65}
66
67/// Decodes ModelPack detection outputs from float tensors. The boxes
68/// are expected to be in XYXY format.
69///
70/// Expected shapes of inputs:
71/// - boxes: (num_boxes, 4)
72/// - scores: (num_boxes, num_classes)
73///
74/// # Panics
75/// Panics if shapes don't match the expected dimensions.
76pub fn decode_modelpack_float<
77    BOX: Float + AsPrimitive<f32> + Send + Sync,
78    SCORE: Float + AsPrimitive<f32> + Send + Sync,
79>(
80    boxes_tensor: ArrayView2<BOX>,
81    scores_tensor: ArrayView2<SCORE>,
82    score_threshold: f32,
83    iou_threshold: f32,
84    output_boxes: &mut Vec<DetectBox>,
85) where
86    f32: AsPrimitive<SCORE>,
87{
88    impl_modelpack_float::<XYXY, _, _>(
89        boxes_tensor,
90        scores_tensor,
91        score_threshold,
92        iou_threshold,
93        output_boxes,
94    )
95}
96
97/// Decodes ModelPack split detection outputs from quantized tensors. The boxes
98/// are expected to be in XYWH format.
99///
100/// The `configs` must correspond to the `outputs` in order.
101///
102/// Expected shapes of inputs:
103/// - outputs: (width, height, num_anchors * (5 + num_classes))
104///
105/// # Panics
106/// Panics if shapes don't match the expected dimensions.
107pub fn decode_modelpack_split_quant<D: AsPrimitive<f32>>(
108    outputs: &[ArrayView3<D>],
109    configs: &[ModelPackDetectionConfig],
110    score_threshold: f32,
111    iou_threshold: f32,
112    output_boxes: &mut Vec<DetectBox>,
113) {
114    impl_modelpack_split_quant::<XYWH, D>(
115        outputs,
116        configs,
117        score_threshold,
118        iou_threshold,
119        output_boxes,
120    )
121}
122
123/// Decodes ModelPack split detection outputs from float tensors. The boxes
124/// are expected to be in XYWH format.
125///
126/// The `configs` must correspond to the `outputs` in order.
127///
128/// Expected shapes of inputs:
129/// - outputs: (width, height, num_anchors * (5 + num_classes))
130///
131/// # Panics
132/// Panics if shapes don't match the expected dimensions.
133pub fn decode_modelpack_split_float<D: AsPrimitive<f32>>(
134    outputs: &[ArrayView3<D>],
135    configs: &[ModelPackDetectionConfig],
136    score_threshold: f32,
137    iou_threshold: f32,
138    output_boxes: &mut Vec<DetectBox>,
139) {
140    impl_modelpack_split_float::<XYWH, D>(
141        outputs,
142        configs,
143        score_threshold,
144        iou_threshold,
145        output_boxes,
146    );
147}
148/// Implementation of ModelPack detection decoding for quantized tensors.
149///
150/// Expected shapes of inputs:
151/// - boxes: (num_boxes, 4)
152/// - scores: (num_boxes, num_classes)
153///
154/// # Panics
155/// Panics if shapes don't match the expected dimensions.
156pub(crate) fn impl_modelpack_quant<
157    B: BBoxTypeTrait,
158    BOX: PrimInt + AsPrimitive<f32> + Send + Sync,
159    SCORE: PrimInt + AsPrimitive<f32> + Send + Sync,
160>(
161    boxes: (ArrayView2<BOX>, Quantization),
162    scores: (ArrayView2<SCORE>, Quantization),
163    score_threshold: f32,
164    iou_threshold: f32,
165    output_boxes: &mut Vec<DetectBox>,
166) where
167    f32: AsPrimitive<SCORE>,
168{
169    let (boxes_tensor, quant_boxes) = boxes;
170    let (scores_tensor, quant_scores) = scores;
171    let boxes = {
172        let score_threshold = quantize_score_threshold(score_threshold, quant_scores);
173        postprocess_boxes_quant::<B, _, _>(
174            score_threshold,
175            boxes_tensor,
176            scores_tensor,
177            quant_boxes,
178        )
179    };
180    let boxes = nms_int(iou_threshold, boxes);
181    let len = output_boxes.capacity().min(boxes.len());
182    output_boxes.clear();
183    for b in boxes.into_iter().take(len) {
184        output_boxes.push(dequant_detect_box(&b, quant_scores));
185    }
186}
187
188/// Implementation of ModelPack detection decoding for float tensors.
189///
190/// Expected shapes of inputs:
191/// - boxes: (num_boxes, 4)
192/// - scores: (num_boxes, num_classes)
193///
194/// # Panics
195/// Panics if shapes don't match the expected dimensions.
196pub(crate) fn impl_modelpack_float<
197    B: BBoxTypeTrait,
198    BOX: Float + AsPrimitive<f32> + Send + Sync,
199    SCORE: Float + AsPrimitive<f32> + Send + Sync,
200>(
201    boxes_tensor: ArrayView2<BOX>,
202    scores_tensor: ArrayView2<SCORE>,
203    score_threshold: f32,
204    iou_threshold: f32,
205    output_boxes: &mut Vec<DetectBox>,
206) where
207    f32: AsPrimitive<SCORE>,
208{
209    let boxes =
210        postprocess_boxes_float::<B, _, _>(score_threshold.as_(), boxes_tensor, scores_tensor);
211    let boxes = nms_float(iou_threshold, boxes);
212    let len = output_boxes.capacity().min(boxes.len());
213    output_boxes.clear();
214    for b in boxes.into_iter().take(len) {
215        output_boxes.push(b);
216    }
217}
218
219/// Implementation of ModelPack split detection decoding for quantized tensors.
220///
221/// Expected shapes of inputs:
222/// - boxes: (num_boxes, 4)
223/// - scores: (num_boxes, num_classes)
224///
225/// # Panics
226/// Panics if shapes don't match the expected dimensions.
227pub(crate) fn impl_modelpack_split_quant<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
228    outputs: &[ArrayView3<D>],
229    configs: &[ModelPackDetectionConfig],
230    score_threshold: f32,
231    iou_threshold: f32,
232    output_boxes: &mut Vec<DetectBox>,
233) {
234    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_quant(outputs, configs);
235    let boxes = postprocess_boxes_float::<B, _, _>(
236        score_threshold,
237        boxes_tensor.view(),
238        scores_tensor.view(),
239    );
240    let boxes = nms_float(iou_threshold, boxes);
241    let len = output_boxes.capacity().min(boxes.len());
242    output_boxes.clear();
243    for b in boxes.into_iter().take(len) {
244        output_boxes.push(b);
245    }
246}
247
248/// Implementation of ModelPack split detection decoding for float tensors.
249///
250/// The `configs` must correspond to the `outputs` in order.
251///
252/// Expected shapes of inputs:
253/// - outputs: (width, height, num_anchors * (5 + num_classes))
254///
255/// # Panics
256/// Panics if shapes don't match the expected dimensions.
257pub(crate) fn impl_modelpack_split_float<B: BBoxTypeTrait, D: AsPrimitive<f32>>(
258    outputs: &[ArrayView3<D>],
259    configs: &[ModelPackDetectionConfig],
260    score_threshold: f32,
261    iou_threshold: f32,
262    output_boxes: &mut Vec<DetectBox>,
263) {
264    let (boxes_tensor, scores_tensor) = postprocess_modelpack_split_float(outputs, configs);
265    let boxes = postprocess_boxes_float::<B, _, _>(
266        score_threshold,
267        boxes_tensor.view(),
268        scores_tensor.view(),
269    );
270    let boxes = nms_float(iou_threshold, boxes);
271    let len = output_boxes.capacity().min(boxes.len());
272    output_boxes.clear();
273    for b in boxes.into_iter().take(len) {
274        output_boxes.push(b);
275    }
276}
277
278/// Post processes ModelPack split detection into detection boxes,
279/// filtering out any boxes below the score threshold. Returns the boxes and
280/// scores tensors. Boxes are in XYWH format.
281pub(crate) fn postprocess_modelpack_split_quant<T: AsPrimitive<f32>>(
282    outputs: &[ArrayView3<T>],
283    config: &[ModelPackDetectionConfig],
284) -> (Array2<f32>, Array2<f32>) {
285    let mut total_capacity = 0;
286    let mut nc = 0;
287    for (p, detail) in outputs.iter().zip(config) {
288        let shape = p.shape();
289        let na = detail.anchors.len();
290        nc = *shape
291            .last()
292            .expect("Shape must have at least one dimension")
293            / na
294            - 5;
295        total_capacity += shape[0] * shape[1] * na;
296    }
297    let mut bboxes = Vec::with_capacity(total_capacity * 4);
298    let mut bscores = Vec::with_capacity(total_capacity * nc);
299
300    for (p, detail) in outputs.iter().zip(config) {
301        let anchors = &detail.anchors;
302        let na = detail.anchors.len();
303        let shape = p.shape();
304        assert_eq!(
305            shape.iter().product::<usize>(),
306            p.len(),
307            "Shape product doesn't match tensor length"
308        );
309        let p_sigmoid = if let Some(quant) = &detail.quantization {
310            let scaled_zero = -quant.zero_point as f32 * quant.scale;
311            p.mapv(|x| fast_sigmoid_impl(x.as_() * quant.scale + scaled_zero))
312        } else {
313            p.mapv(|x| fast_sigmoid_impl(x.as_()))
314        };
315        let p_sigmoid = p_sigmoid.as_standard_layout();
316
317        // Safe to unwrap since we ensured standard layout above
318        let p = p_sigmoid
319            .as_slice()
320            .expect("Sigmoids are not in standard layout");
321        let height = shape[0];
322        let width = shape[1];
323
324        let div_width = 1.0 / width as f32;
325        let div_height = 1.0 / height as f32;
326
327        let mut grid = Vec::with_capacity(height * width * na * 2);
328        for y in 0..height {
329            for x in 0..width {
330                for _ in 0..na {
331                    grid.push(x as f32 - 0.5);
332                    grid.push(y as f32 - 0.5);
333                }
334            }
335        }
336        for ((p, g), anchor) in p
337            .chunks_exact(nc + 5)
338            .zip(grid.chunks_exact(2))
339            .zip(anchors.iter().cycle())
340        {
341            let (x, y) = (p[0], p[1]);
342            let x = (x * 2.0 + g[0]) * div_width;
343            let y = (y * 2.0 + g[1]) * div_height;
344            let (w, h) = (p[2], p[3]);
345            let w = w * w * 4.0 * anchor[0];
346            let h = h * h * 4.0 * anchor[1];
347
348            bboxes.push(x);
349            bboxes.push(y);
350            bboxes.push(w);
351            bboxes.push(h);
352
353            if nc == 1 {
354                bscores.push(p[4]);
355            } else {
356                let obj = p[4];
357                let probs = p[5..].iter().map(|x| *x * obj);
358                bscores.extend(probs);
359            }
360        }
361    }
362    // Safe to unwrap since we ensured lengths will match above
363
364    debug_assert_eq!(bboxes.len() % 4, 0);
365    debug_assert_eq!(bscores.len() % nc, 0);
366
367    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
368        .expect("Failed to create bboxes array");
369    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
370        .expect("Failed to create bscores array");
371    (bboxes, bscores)
372}
373
374/// Post processes ModelPack split detection into detection boxes,
375/// filtering out any boxes below the score threshold. Returns the boxes and
376/// scores tensors. Boxes are in XYWH format.
377pub(crate) fn postprocess_modelpack_split_float<T: AsPrimitive<f32>>(
378    outputs: &[ArrayView3<T>],
379    config: &[ModelPackDetectionConfig],
380) -> (Array2<f32>, Array2<f32>) {
381    let mut total_capacity = 0;
382    let mut nc = 0;
383    for (p, detail) in outputs.iter().zip(config) {
384        let shape = p.shape();
385        let na = detail.anchors.len();
386        nc = *shape
387            .last()
388            .expect("Shape must have at least one dimension")
389            / na
390            - 5;
391        total_capacity += shape[0] * shape[1] * na;
392    }
393    let mut bboxes = Vec::with_capacity(total_capacity * 4);
394    let mut bscores = Vec::with_capacity(total_capacity * nc);
395
396    for (p, detail) in outputs.iter().zip(config) {
397        let anchors = &detail.anchors;
398        let na = detail.anchors.len();
399        let shape = p.shape();
400        assert_eq!(
401            shape.iter().product::<usize>(),
402            p.len(),
403            "Shape product doesn't match tensor length"
404        );
405        let p_sigmoid = p.mapv(|x| fast_sigmoid_impl(x.as_()));
406        let p_sigmoid = p_sigmoid.as_standard_layout();
407
408        // Safe to unwrap since we ensured standard layout above
409        let p = p_sigmoid
410            .as_slice()
411            .expect("Sigmoids are not in standard layout");
412        let height = shape[0];
413        let width = shape[1];
414
415        let div_width = 1.0 / width as f32;
416        let div_height = 1.0 / height as f32;
417
418        let mut grid = Vec::with_capacity(height * width * na * 2);
419        for y in 0..height {
420            for x in 0..width {
421                for _ in 0..na {
422                    grid.push(x as f32 - 0.5);
423                    grid.push(y as f32 - 0.5);
424                }
425            }
426        }
427        for ((p, g), anchor) in p
428            .chunks_exact(nc + 5)
429            .zip(grid.chunks_exact(2))
430            .zip(anchors.iter().cycle())
431        {
432            let (x, y) = (p[0], p[1]);
433            let x = (x * 2.0 + g[0]) * div_width;
434            let y = (y * 2.0 + g[1]) * div_height;
435            let (w, h) = (p[2], p[3]);
436            let w = w * w * 4.0 * anchor[0];
437            let h = h * h * 4.0 * anchor[1];
438
439            bboxes.push(x);
440            bboxes.push(y);
441            bboxes.push(w);
442            bboxes.push(h);
443
444            if nc == 1 {
445                bscores.push(p[4]);
446            } else {
447                let obj = p[4];
448                let probs = p[5..].iter().map(|x| *x * obj);
449                bscores.extend(probs);
450            }
451        }
452    }
453    // Safe to unwrap since we ensured lengths will match above
454
455    debug_assert_eq!(bboxes.len() % 4, 0);
456    debug_assert_eq!(bscores.len() % nc, 0);
457
458    let bboxes = Array2::from_shape_vec((bboxes.len() / 4, 4), bboxes)
459        .expect("Failed to create bboxes array");
460    let bscores = Array2::from_shape_vec((bscores.len() / nc, nc), bscores)
461        .expect("Failed to create bscores array");
462    (bboxes, bscores)
463}
464
465#[inline(always)]
466fn fast_sigmoid_impl(f: f32) -> f32 {
467    if f.abs() > 80.0 {
468        f.signum() * 0.5 + 0.5
469    } else {
470        // these values are only valid for -88 < x < 88
471        1.0 / (1.0 + fast_math::exp_raw(-f))
472    }
473}
474
475/// Converts ModelPack segmentation into a 2D mask.
476/// The input segmentation is expected to have shape (H, W, num_classes).
477///
478/// The output mask will have shape (H, W), with values `0..num_classes` based
479/// on the argmax across the channels.
480///
481/// # Panics
482/// Panics if the input tensor does not have more than one channel.
483pub fn modelpack_segmentation_to_mask(segmentation: ArrayView3<u8>) -> Array2<u8> {
484    use argminmax::ArgMinMax;
485    assert!(
486        segmentation.shape()[2] > 1,
487        "Model Instance Segmentation should have shape (H, W, x) where x > 1"
488    );
489    let height = segmentation.shape()[0];
490    let width = segmentation.shape()[1];
491    let channels = segmentation.shape()[2];
492    let segmentation = segmentation.as_standard_layout();
493    // Safe to unwrap since we ensured standard layout above
494    let seg = segmentation
495        .as_slice()
496        .expect("Segmentation is not in standard layout");
497    let argmax = seg
498        .chunks_exact(channels)
499        .map(|x| x.argmax() as u8)
500        .collect::<Vec<_>>();
501
502    Array2::from_shape_vec((height, width), argmax).expect("Failed to create mask array")
503}
504
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod modelpack_tests {
    #![allow(clippy::excessive_precision)]
    use ndarray::Array3;

    use crate::configs::{DecoderType, DimName};

    use super::*;

    /// Builds the `Detection` fixture shared by the config tests; only the
    /// anchors differ between cases.
    fn detection_with_anchors(anchors: Option<Vec<[f32; 2]>>) -> Detection {
        Detection {
            anchors,
            quantization: Some((0.1, 128).into()),
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 17, 18],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 17),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        }
    }

    /// A detection with anchors converts; one without is rejected.
    #[test]
    fn test_detection_config() {
        let anchors: Vec<[f32; 2]> = vec![[0.1, 0.13], [0.16, 0.30], [0.33, 0.23]];

        let with_anchors = detection_with_anchors(Some(anchors.clone()));
        let config = ModelPackDetectionConfig::try_from(&with_anchors).unwrap();
        let expected = ModelPackDetectionConfig {
            anchors,
            quantization: Some(Quantization::new(0.1, 128)),
        };
        assert_eq!(config, expected);

        let without_anchors = detection_with_anchors(None);
        let result = ModelPackDetectionConfig::try_from(&without_anchors);
        assert!(
            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors")
        );
    }

    /// The fast sigmoid stays within 0.0005 of the exact one on [-255, 255].
    #[test]
    fn test_fast_sigmoid() {
        let full_sigmoid = |x: f32| -> f32 { 1.0 / (1.0 + (-x).exp()) };

        for step in -2550i32..=2550 {
            let x = step as f32 * 0.1;
            let diff = (fast_sigmoid_impl(x) - full_sigmoid(x)).abs();
            assert!(
                diff < 0.0005,
                "Fast sigmoid differs from full sigmoid by {} at input {}",
                diff,
                x
            );
        }
    }

    /// Each pixel maps to the index of its largest channel.
    #[test]
    fn test_modelpack_segmentation_to_mask() {
        let channel_values = vec![
            0u8, 10, 5, // pixel (0,0)
            20, 15, 25, // pixel (0,1)
            30, 5, 10, // pixel (1,0)
            0, 0, 0, // pixel (1,1)
        ];
        let seg = Array3::from_shape_vec((2, 2, 3), channel_values).unwrap();

        let mask = modelpack_segmentation_to_mask(seg.view());

        let expected_mask = Array2::from_shape_vec((2, 2), vec![1u8, 2, 0, 0]).unwrap();
        assert_eq!(mask, expected_mask);
    }

    /// A single-channel segmentation tensor is rejected with a panic.
    #[test]
    #[should_panic(
        expected = "Model Instance Segmentation should have shape (H, W, x) where x > 1"
    )]
    fn test_modelpack_segmentation_to_mask_invalid() {
        let single_channel = Array3::from_shape_vec((2, 2, 1), vec![0u8, 10, 20, 30]).unwrap();
        let _ = modelpack_segmentation_to_mask(single_channel.view());
    }
}
601}