Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust
16# use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs::{self, DecoderVersion} };
17# fn main() -> DecoderResult<()> {
18// Create a decoder for a YOLOv8 model with quantized int8 output with 0.25 score threshold and 0.7 IOU threshold
19let decoder = DecoderBuilder::new()
20    .with_config_yolo_det(configs::Detection {
21        anchors: None,
22        decoder: configs::DecoderType::Ultralytics,
23        quantization: Some(configs::QuantTuple(0.012345, 26)),
24        shape: vec![1, 84, 8400],
25        dshape: Vec::new(),
26        normalized: Some(true),
27    },
28    Some(DecoderVersion::Yolov8))
29    .with_score_threshold(0.25)
30    .with_iou_threshold(0.7)
31    .build()?;
32
33// Get the model output from the model. Here we load it from a test data file for demonstration purposes.
34let model_output: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
35    .iter()
36    .map(|b| *b as i8)
37    .collect();
38let model_output_array = ndarray::Array3::from_shape_vec((1, 84, 8400), model_output)?;
39
// The capacity is used to determine the maximum number of detections to decode.
41let mut output_boxes: Vec<_> = Vec::with_capacity(10);
42let mut output_masks: Vec<_> = Vec::with_capacity(10);
43
44// Decode the quantized model output into detection boxes and segmentation masks
45// Because this model is a detection-only model, the `output_masks` vector will remain empty.
46decoder.decode_quantized(&[model_output_array.view().into()], &mut output_boxes, &mut output_masks)?;
47# Ok(())
48# }
49```
50
51# Overview
52
53The primary components of this crate are:
54- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
55- `yolo` module: Contains functions specific to decoding YOLO model outputs.
56- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
57
58The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
59It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
60When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
64
65The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
66which can be used if the model type and output formats are known in advance.
67
68
69*/
70#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Quantization parameters of a tensor. A quantized value `q` is turned back
/// into a real value as `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied to the zero-point-adjusted quantized value
    pub scale: f32,
    /// Quantized value that corresponds to a real value of zero
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from a scale and a zero point
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247    /// Creates a new Quantization struct from a QuantTuple
248    /// # Examples
249    /// ```
250    /// # use edgefirst_decoder::Quantization;
251    /// # use edgefirst_decoder::configs::QuantTuple;
252    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
253    /// let quant = Quantization::from(quant_tuple);
254    /// assert_eq!(quant.scale, 0.1);
255    /// assert_eq!(quant.zero_point, -128);
256    /// ```
257    fn from(quant_tuple: QuantTuple) -> Quantization {
258        Quantization {
259            scale: quant_tuple.0,
260            zero_point: quant_tuple.1,
261        }
262    }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267    S: AsPrimitive<f32>,
268    Z: AsPrimitive<i32>,
269{
270    /// Creates a new Quantization struct from a tuple
271    /// # Examples
272    /// ```
273    /// # use edgefirst_decoder::Quantization;
274    /// let quant = Quantization::from((0.1_f64, -128_i64));
275    /// assert_eq!(quant.scale, 0.1);
276    /// assert_eq!(quant.zero_point, -128);
277    /// ```
278    fn from((scale, zp): (S, Z)) -> Quantization {
279        Self {
280            scale: scale.as_(),
281            zero_point: zp.as_(),
282        }
283    }
284}
285
286impl Default for Quantization {
287    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::Quantization;
291    /// let quant = Quantization::default();
292    /// assert_eq!(quant.scale, 1.0);
293    /// assert_eq!(quant.zero_point, 0);
294    /// ```
295    fn default() -> Self {
296        Self {
297            scale: 1.0,
298            zero_point: 0,
299        }
300    }
301}
302
/// A detection box with f32 bbox and score
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// bounding box of this detection, in XYXY format
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
312
/// A bounding box with f32 coordinates in XYXY format
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a BoundingBox from its four edge coordinates
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy of this box with each axis ordered so that
    /// `xmin <= xmax` and `ymin <= ymax`
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (x_lo, x_hi) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_lo, y_hi) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_lo, y_lo, x_hi, y_hi)
    }
}
358
359impl From<BoundingBox> for [f32; 4] {
360    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
361    /// xmax, ymax order
362    /// # Examples
363    /// ```
364    /// # use edgefirst_decoder::BoundingBox;
365    /// let bbox = BoundingBox {
366    ///     xmin: 0.1,
367    ///     ymin: 0.2,
368    ///     xmax: 0.3,
369    ///     ymax: 0.4,
370    /// };
371    /// let arr: [f32; 4] = bbox.into();
372    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
373    /// ```
374    fn from(b: BoundingBox) -> Self {
375        [b.xmin, b.ymin, b.xmax, b.ymax]
376    }
377}
378
379impl From<[f32; 4]> for BoundingBox {
380    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
381    // BoundingBox
382    fn from(arr: [f32; 4]) -> Self {
383        BoundingBox {
384            xmin: arr[0],
385            ymin: arr[1],
386            xmax: arr[2],
387            ymax: arr[3],
388        }
389    }
390}
391
392impl DetectBox {
393    /// Returns true if one detect box is equal to another detect box, within
394    /// the given `eps`
395    ///
396    /// # Examples
397    /// ```
398    /// # use edgefirst_decoder::DetectBox;
399    /// let box1 = DetectBox {
400    ///     bbox: edgefirst_decoder::BoundingBox {
401    ///         xmin: 0.1,
402    ///         ymin: 0.2,
403    ///         xmax: 0.3,
404    ///         ymax: 0.4,
405    ///     },
406    ///     score: 0.5,
407    ///     label: 1,
408    /// };
409    /// let box2 = DetectBox {
410    ///     bbox: edgefirst_decoder::BoundingBox {
411    ///         xmin: 0.101,
412    ///         ymin: 0.199,
413    ///         xmax: 0.301,
414    ///         ymax: 0.399,
415    ///     },
416    ///     score: 0.510,
417    ///     label: 1,
418    /// };
419    /// assert!(box1.equal_within_delta(&box2, 0.011));
420    /// ```
421    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423        self.label == rhs.label
424            && eq_delta(self.score, rhs.score)
425            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429    }
430}
431
/// A segmentation result with a segmentation mask, and a normalized bounding
/// box representing the area that the segmentation mask covers
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// left-most normalized coordinate of the segmentation box
    pub xmin: f32,
    /// top-most normalized coordinate of the segmentation box
    pub ymin: f32,
    /// right-most normalized coordinate of the segmentation box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the segmentation box
    pub ymax: f32,
    /// 3D segmentation array. If the last dimension is 1, values equal or above
    /// 128 are considered objects. Otherwise the object is the argmax index.
    /// [`segmentation_to_mask`] applies this rule to produce a 2D mask.
    pub segmentation: Array3<u8>,
}
448
449/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
450///
451///  # Examples
452/// ```
453/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
454/// let quant = Quantization::new(0.1, -128);
455/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
456/// let detect_quant = DetectBoxQuantized {
457///     bbox,
458///     score: 100_i8,
459///     label: 1,
460/// };
461/// let detect = dequant_detect_box(&detect_quant, quant);
462/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
463/// assert_eq!(detect.label, 1);
464/// assert_eq!(detect.bbox, bbox);
465/// ```
466pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
467    detect: &DetectBoxQuantized<SCORE>,
468    quant_scores: Quantization,
469) -> DetectBox {
470    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
471    DetectBox {
472        bbox: detect.bbox,
473        score: quant_scores.scale * detect.score.as_() + scaled_zp,
474        label: detect.label,
475    }
476}
/// A detection box with a f32 bbox and quantized score
///
/// The score is kept in its raw integer form; [`dequant_detect_box`] turns
/// the whole box into a [`DetectBox`] with an f32 score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // BOX: Signed + PrimInt + AsPrimitive<f32>,
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    // pub bbox: BoundingBoxQuantized<BOX>,
    /// bounding box of this detection, stored as f32 coordinates
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more
    /// confidence.
    pub score: SCORE,
    /// label index for this detect
    pub label: usize,
}
491
492/// Dequantizes an ndarray from quantized values to f32 values using the given
493/// quantization parameters
494///
495/// # Examples
496/// ```
497/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
498/// let quant = Quantization::new(0.1, -128);
499/// let input: Vec<i8> = vec![0, 127, -128, 64];
500/// let input_array = ndarray::Array1::from(input);
501/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
502/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
503/// ```
504pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
505    input: ArrayView<T, D>,
506    quant: Quantization,
507) -> Array<F, D>
508where
509    i32: num_traits::AsPrimitive<F>,
510    f32: num_traits::AsPrimitive<F>,
511{
512    let zero_point = quant.zero_point.as_();
513    let scale = quant.scale.as_();
514    if zero_point != F::zero() {
515        let scaled_zero = -zero_point * scale;
516        input.mapv(|d| d.as_() * scale + scaled_zero)
517    } else {
518        input.mapv(|d| d.as_() * scale)
519    }
520}
521
522/// Dequantizes a slice from quantized values to float values using the given
523/// quantization parameters
524///
525/// # Examples
526/// ```
527/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
528/// let quant = Quantization::new(0.1, -128);
529/// let input: Vec<i8> = vec![0, 127, -128, 64];
530/// let mut output: Vec<f32> = vec![0.0; input.len()];
531/// dequantize_cpu(&input, quant, &mut output);
532/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
533/// ```
534pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
535    input: &[T],
536    quant: Quantization,
537    output: &mut [F],
538) where
539    f32: num_traits::AsPrimitive<F>,
540    i32: num_traits::AsPrimitive<F>,
541{
542    assert!(input.len() == output.len());
543    let zero_point = quant.zero_point.as_();
544    let scale = quant.scale.as_();
545    if zero_point != F::zero() {
546        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
547        input
548            .iter()
549            .zip(output)
550            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
551    } else {
552        input
553            .iter()
554            .zip(output)
555            .for_each(|(d, deq)| *deq = d.as_() * scale);
556    }
557}
558
559/// Dequantizes a slice from quantized values to float values using the given
560/// quantization parameters, using chunked processing. This is around 5% faster
561/// than `dequantize_cpu` for large slices.
562///
563/// # Examples
564/// ```
565/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
566/// let quant = Quantization::new(0.1, -128);
567/// let input: Vec<i8> = vec![0, 127, -128, 64];
568/// let mut output: Vec<f32> = vec![0.0; input.len()];
569/// dequantize_cpu_chunked(&input, quant, &mut output);
570/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
571/// ```
572pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
573    input: &[T],
574    quant: Quantization,
575    output: &mut [F],
576) where
577    f32: num_traits::AsPrimitive<F>,
578    i32: num_traits::AsPrimitive<F>,
579{
580    assert!(input.len() == output.len());
581    let zero_point = quant.zero_point.as_();
582    let scale = quant.scale.as_();
583
584    let input = input.as_chunks::<4>();
585    let output = output.as_chunks_mut::<4>();
586
587    if zero_point != F::zero() {
588        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
589
590        input
591            .0
592            .iter()
593            .zip(output.0)
594            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
595        input
596            .1
597            .iter()
598            .zip(output.1)
599            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
600    } else {
601        input
602            .0
603            .iter()
604            .zip(output.0)
605            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
606        input
607            .1
608            .iter()
609            .zip(output.1)
610            .for_each(|(d, deq)| *deq = d.as_() * scale);
611    }
612}
613
614/// Converts a segmentation tensor into a 2D mask
615/// If the last dimension of the segmentation tensor is 1, values equal or
616/// above 128 are considered objects. Otherwise the object is the argmax index
617///
618/// # Errors
619///
620/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
621/// invalid shape.
622///
623/// # Examples
624/// ```
625/// # use edgefirst_decoder::segmentation_to_mask;
626/// let segmentation =
627///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
628/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
629/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
630/// ```
631pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
632    if segmentation.shape()[2] == 0 {
633        return Err(DecoderError::InvalidShape(
634            "Segmentation tensor must have non-zero depth".to_string(),
635        ));
636    }
637    if segmentation.shape()[2] == 1 {
638        yolo_segmentation_to_mask(segmentation, 128)
639    } else {
640        Ok(modelpack_segmentation_to_mask(segmentation))
641    }
642}
643
644/// Returns the maximum value and its index from a 1D array
645fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
646    score
647        .iter()
648        .enumerate()
649        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
650            if max > *s {
651                (max, arg_max)
652            } else {
653                (*s, ind)
654            }
655        })
656}
657#[cfg(test)]
658#[cfg_attr(coverage_nightly, coverage(off))]
659mod decoder_tests {
660    #![allow(clippy::excessive_precision)]
661    use crate::{
662        configs::{DecoderType, DimName, Protos},
663        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
664        yolo::{
665            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
666            decode_yolo_segdet_quant,
667        },
668        *,
669    };
670    use ndarray::{array, s, Array4};
671    use ndarray_stats::DeviationExt;
672
673    fn compare_outputs(
674        boxes: (&[DetectBox], &[DetectBox]),
675        masks: (&[Segmentation], &[Segmentation]),
676    ) {
677        let (boxes0, boxes1) = boxes;
678        let (masks0, masks1) = masks;
679
680        assert_eq!(boxes0.len(), boxes1.len());
681        assert_eq!(masks0.len(), masks1.len());
682
683        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
684            assert!(
685                b_i8.equal_within_delta(b_f32, 1e-6),
686                "{b_i8:?} is not equal to {b_f32:?}"
687            );
688        }
689
690        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
691            assert_eq!(
692                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
693                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
694            );
695            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
696            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
697            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
698            let diff = &mask_i8 - &mask_f32;
699            for x in 0..diff.shape()[0] {
700                for y in 0..diff.shape()[1] {
701                    for z in 0..diff.shape()[2] {
702                        let val = diff[[x, y, z]];
703                        assert!(
704                            val.abs() <= 1,
705                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
706                            x,
707                            y,
708                            z,
709                            val
710                        );
711                    }
712                }
713            }
714            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
715            assert!(
716                mean_sq_err < 1e-2,
717                "Mean Square Error between masks was greater than 1%: {:.2}%",
718                mean_sq_err * 100.0
719            );
720        }
721    }
722
    /// Decodes the same quantized ModelPack boxes/scores three ways (the
    /// low-level `decode_modelpack_det`, `Decoder::decode_quantized`, and
    /// `Decoder::decode_float` on dequantized data) and checks the results
    /// against each other.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Raw u8 model outputs embedded from the test data files.
        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // The target type of these `.into()` calls is inferred from the
        // `quantization` config fields they are passed to below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Shadow the config values with the form the low-level functions
        // take; the target type is inferred from their use below.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Path 1: low-level ModelPack decode on the quantized tensors.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        // The first detection is pinned to a known reference value.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        // Path 2: high-level decoder on the quantized tensors.
        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Path 3: dequantize first, then use the high-level float decoder.
        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Both decoder paths must match the low-level result and produce no
        // masks (the expected mask slice is empty).
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
822
823    #[test]
824    fn test_decoder_modelpack_split_u8() {
825        let score_threshold = 0.45;
826        let iou_threshold = 0.45;
827        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
828        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
829
830        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
831        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
832
833        let quant0 = (0.08547406643629074, 174).into();
834        let quant1 = (0.09929127991199493, 183).into();
835        let anchors0 = vec![
836            [0.36666667461395264, 0.31481480598449707],
837            [0.38749998807907104, 0.4740740656852722],
838            [0.5333333611488342, 0.644444465637207],
839        ];
840        let anchors1 = vec![
841            [0.13750000298023224, 0.2074074000120163],
842            [0.2541666626930237, 0.21481481194496155],
843            [0.23125000298023224, 0.35185185074806213],
844        ];
845
846        let detect_config0 = configs::Detection {
847            decoder: DecoderType::ModelPack,
848            shape: vec![1, 9, 15, 18],
849            anchors: Some(anchors0.clone()),
850            quantization: Some(quant0),
851            dshape: vec![
852                (DimName::Batch, 1),
853                (DimName::Height, 9),
854                (DimName::Width, 15),
855                (DimName::NumAnchorsXFeatures, 18),
856            ],
857            normalized: Some(true),
858        };
859
860        let detect_config1 = configs::Detection {
861            decoder: DecoderType::ModelPack,
862            shape: vec![1, 17, 30, 18],
863            anchors: Some(anchors1.clone()),
864            quantization: Some(quant1),
865            dshape: vec![
866                (DimName::Batch, 1),
867                (DimName::Height, 17),
868                (DimName::Width, 30),
869                (DimName::NumAnchorsXFeatures, 18),
870            ],
871            normalized: Some(true),
872        };
873
874        let config0 = (&detect_config0).try_into().unwrap();
875        let config1 = (&detect_config1).try_into().unwrap();
876
877        let decoder = DecoderBuilder::default()
878            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
879            .with_score_threshold(score_threshold)
880            .with_iou_threshold(iou_threshold)
881            .build()
882            .unwrap();
883
884        let quant0 = quant0.into();
885        let quant1 = quant1.into();
886
887        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
888        decode_modelpack_split_quant(
889            &[
890                detect0.slice(s![0, .., .., ..]),
891                detect1.slice(s![0, .., .., ..]),
892            ],
893            &[config0, config1],
894            score_threshold,
895            iou_threshold,
896            &mut output_boxes,
897        );
898        assert!(output_boxes[0].equal_within_delta(
899            &DetectBox {
900                bbox: BoundingBox {
901                    xmin: 0.43171933,
902                    ymin: 0.68243736,
903                    xmax: 0.5626645,
904                    ymax: 0.808863,
905                },
906                score: 0.99240804,
907                label: 0
908            },
909            1e-6
910        ));
911
912        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
913        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
914        decoder
915            .decode_quantized(
916                &[detect0.view().into(), detect1.view().into()],
917                &mut output_boxes1,
918                &mut output_masks1,
919            )
920            .unwrap();
921
922        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
923        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
924
925        let detect0 = dequantize_ndarray(detect0.view(), quant0);
926        let detect1 = dequantize_ndarray(detect1.view(), quant1);
927        decoder
928            .decode_float::<f32>(
929                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
930                &mut output_boxes1_f32,
931                &mut output_masks1_f32,
932            )
933            .unwrap();
934
935        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
936        compare_outputs(
937            (&output_boxes, &output_boxes1_f32),
938            (&[], &output_masks1_f32),
939        );
940    }
941
942    #[test]
943    fn test_decoder_parse_config_modelpack_split_u8() {
944        let score_threshold = 0.45;
945        let iou_threshold = 0.45;
946        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
947        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
948
949        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
950        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
951
952        let decoder = DecoderBuilder::default()
953            .with_config_yaml_str(
954                include_str!("../../../testdata/modelpack_split.yaml").to_string(),
955            )
956            .with_score_threshold(score_threshold)
957            .with_iou_threshold(iou_threshold)
958            .build()
959            .unwrap();
960
961        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
962        let mut output_masks: Vec<_> = Vec::with_capacity(10);
963        decoder
964            .decode_quantized(
965                &[
966                    ArrayViewDQuantized::from(detect1.view()),
967                    ArrayViewDQuantized::from(detect0.view()),
968                ],
969                &mut output_boxes,
970                &mut output_masks,
971            )
972            .unwrap();
973        assert!(output_boxes[0].equal_within_delta(
974            &DetectBox {
975                bbox: BoundingBox {
976                    xmin: 0.43171933,
977                    ymin: 0.68243736,
978                    xmax: 0.5626645,
979                    ymax: 0.808863,
980                },
981                score: 0.99240804,
982                label: 0
983            },
984            1e-6
985        ));
986    }
987
988    #[test]
989    fn test_modelpack_seg() {
990        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
991        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
992        let quant = (1.0 / 255.0, 0).into();
993
994        let decoder = DecoderBuilder::default()
995            .with_config_modelpack_seg(configs::Segmentation {
996                decoder: DecoderType::ModelPack,
997                quantization: Some(quant),
998                shape: vec![1, 2, 160, 160],
999                dshape: vec![
1000                    (DimName::Batch, 1),
1001                    (DimName::NumClasses, 2),
1002                    (DimName::Height, 160),
1003                    (DimName::Width, 160),
1004                ],
1005            })
1006            .build()
1007            .unwrap();
1008        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1009        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1010        decoder
1011            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1012            .unwrap();
1013
1014        let mut mask = out.slice(s![0, .., .., ..]);
1015        mask.swap_axes(0, 1);
1016        mask.swap_axes(1, 2);
1017        let mask = [Segmentation {
1018            xmin: 0.0,
1019            ymin: 0.0,
1020            xmax: 1.0,
1021            ymax: 1.0,
1022            segmentation: mask.into_owned(),
1023        }];
1024        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1025
1026        decoder
1027            .decode_float::<f32>(
1028                &[dequantize_ndarray(out.view(), quant.into())
1029                    .view()
1030                    .into_dyn()],
1031                &mut output_boxes,
1032                &mut output_masks,
1033            )
1034            .unwrap();
1035
1036        // not expected for float decoder to have same values as quantized decoder, as
1037        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1038        // the model output. Thus the float output is the same as the quantized output
1039        // but scaled differently. However, it is expected that the mask after argmax
1040        // will be the same.
1041        compare_outputs((&[], &output_boxes), (&[], &[]));
1042        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1043        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1044
1045        assert_eq!(mask0, mask1);
1046    }
1047    #[test]
1048    fn test_modelpack_seg_quant() {
1049        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1050        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1051        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1052        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1053        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1054        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1055        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1056
1057        let quant = (1.0 / 255.0, 0).into();
1058
1059        let decoder = DecoderBuilder::default()
1060            .with_config_modelpack_seg(configs::Segmentation {
1061                decoder: DecoderType::ModelPack,
1062                quantization: Some(quant),
1063                shape: vec![1, 2, 160, 160],
1064                dshape: vec![
1065                    (DimName::Batch, 1),
1066                    (DimName::NumClasses, 2),
1067                    (DimName::Height, 160),
1068                    (DimName::Width, 160),
1069                ],
1070            })
1071            .build()
1072            .unwrap();
1073        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1074        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1075        decoder
1076            .decode_quantized(
1077                &[out_u8.view().into()],
1078                &mut output_boxes,
1079                &mut output_masks_u8,
1080            )
1081            .unwrap();
1082
1083        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1084        decoder
1085            .decode_quantized(
1086                &[out_i8.view().into()],
1087                &mut output_boxes,
1088                &mut output_masks_i8,
1089            )
1090            .unwrap();
1091
1092        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1093        decoder
1094            .decode_quantized(
1095                &[out_u16.view().into()],
1096                &mut output_boxes,
1097                &mut output_masks_u16,
1098            )
1099            .unwrap();
1100
1101        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1102        decoder
1103            .decode_quantized(
1104                &[out_i16.view().into()],
1105                &mut output_boxes,
1106                &mut output_masks_i16,
1107            )
1108            .unwrap();
1109
1110        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1111        decoder
1112            .decode_quantized(
1113                &[out_u32.view().into()],
1114                &mut output_boxes,
1115                &mut output_masks_u32,
1116            )
1117            .unwrap();
1118
1119        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1120        decoder
1121            .decode_quantized(
1122                &[out_i32.view().into()],
1123                &mut output_boxes,
1124                &mut output_masks_i32,
1125            )
1126            .unwrap();
1127
1128        compare_outputs((&[], &output_boxes), (&[], &[]));
1129        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1130        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1131        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1132        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1133        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1134        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1135        assert_eq!(mask_u8, mask_i8);
1136        assert_eq!(mask_u8, mask_u16);
1137        assert_eq!(mask_u8, mask_i16);
1138        assert_eq!(mask_u8, mask_u32);
1139        assert_eq!(mask_u8, mask_i32);
1140    }
1141
1142    #[test]
1143    fn test_modelpack_segdet() {
1144        let score_threshold = 0.45;
1145        let iou_threshold = 0.45;
1146
1147        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
1148        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1149
1150        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
1151        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1152
1153        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1154        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1155
1156        let quant_boxes = (0.004656755365431309, 21).into();
1157        let quant_scores = (0.0019603664986789227, 0).into();
1158        let quant_seg = (1.0 / 255.0, 0).into();
1159
1160        let decoder = DecoderBuilder::default()
1161            .with_config_modelpack_segdet(
1162                configs::Boxes {
1163                    decoder: DecoderType::ModelPack,
1164                    quantization: Some(quant_boxes),
1165                    shape: vec![1, 1935, 1, 4],
1166                    dshape: vec![
1167                        (DimName::Batch, 1),
1168                        (DimName::NumBoxes, 1935),
1169                        (DimName::Padding, 1),
1170                        (DimName::BoxCoords, 4),
1171                    ],
1172                    normalized: Some(true),
1173                },
1174                configs::Scores {
1175                    decoder: DecoderType::ModelPack,
1176                    quantization: Some(quant_scores),
1177                    shape: vec![1, 1935, 1],
1178                    dshape: vec![
1179                        (DimName::Batch, 1),
1180                        (DimName::NumBoxes, 1935),
1181                        (DimName::NumClasses, 1),
1182                    ],
1183                },
1184                configs::Segmentation {
1185                    decoder: DecoderType::ModelPack,
1186                    quantization: Some(quant_seg),
1187                    shape: vec![1, 2, 160, 160],
1188                    dshape: vec![
1189                        (DimName::Batch, 1),
1190                        (DimName::NumClasses, 2),
1191                        (DimName::Height, 160),
1192                        (DimName::Width, 160),
1193                    ],
1194                },
1195            )
1196            .with_iou_threshold(iou_threshold)
1197            .with_score_threshold(score_threshold)
1198            .build()
1199            .unwrap();
1200        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1201        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1202        decoder
1203            .decode_quantized(
1204                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1205                &mut output_boxes,
1206                &mut output_masks,
1207            )
1208            .unwrap();
1209
1210        let mut mask = seg.slice(s![0, .., .., ..]);
1211        mask.swap_axes(0, 1);
1212        mask.swap_axes(1, 2);
1213        let mask = [Segmentation {
1214            xmin: 0.0,
1215            ymin: 0.0,
1216            xmax: 1.0,
1217            ymax: 1.0,
1218            segmentation: mask.into_owned(),
1219        }];
1220        let correct_boxes = [DetectBox {
1221            bbox: BoundingBox {
1222                xmin: 0.40513772,
1223                ymin: 0.6379755,
1224                xmax: 0.5122431,
1225                ymax: 0.7730214,
1226            },
1227            score: 0.4861709,
1228            label: 0,
1229        }];
1230        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1231
1232        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1233        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1234        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1235        decoder
1236            .decode_float::<f32>(
1237                &[
1238                    scores.view().into_dyn(),
1239                    boxes.view().into_dyn(),
1240                    seg.view().into_dyn(),
1241                ],
1242                &mut output_boxes,
1243                &mut output_masks,
1244            )
1245            .unwrap();
1246
1247        // not expected for float segmentation decoder to have same values as quantized
1248        // segmentation decoder, as float decoder ensures the data fills 0-255,
1249        // quantized decoder uses whatever the model output. Thus the float
1250        // output is the same as the quantized output but scaled differently.
1251        // However, it is expected that the mask after argmax will be the same.
1252        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1253        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1254        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1255
1256        assert_eq!(mask0, mask1);
1257    }
1258
1259    #[test]
1260    fn test_modelpack_segdet_split() {
1261        let score_threshold = 0.8;
1262        let iou_threshold = 0.5;
1263
1264        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1265        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1266
1267        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1268        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1269
1270        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1271        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1272
1273        let quant0 = (0.08547406643629074, 174).into();
1274        let quant1 = (0.09929127991199493, 183).into();
1275        let quant_seg = (1.0 / 255.0, 0).into();
1276
1277        let anchors0 = vec![
1278            [0.36666667461395264, 0.31481480598449707],
1279            [0.38749998807907104, 0.4740740656852722],
1280            [0.5333333611488342, 0.644444465637207],
1281        ];
1282        let anchors1 = vec![
1283            [0.13750000298023224, 0.2074074000120163],
1284            [0.2541666626930237, 0.21481481194496155],
1285            [0.23125000298023224, 0.35185185074806213],
1286        ];
1287
1288        let decoder = DecoderBuilder::default()
1289            .with_config_modelpack_segdet_split(
1290                vec![
1291                    configs::Detection {
1292                        decoder: DecoderType::ModelPack,
1293                        shape: vec![1, 17, 30, 18],
1294                        anchors: Some(anchors1),
1295                        quantization: Some(quant1),
1296                        dshape: vec![
1297                            (DimName::Batch, 1),
1298                            (DimName::Height, 17),
1299                            (DimName::Width, 30),
1300                            (DimName::NumAnchorsXFeatures, 18),
1301                        ],
1302                        normalized: Some(true),
1303                    },
1304                    configs::Detection {
1305                        decoder: DecoderType::ModelPack,
1306                        shape: vec![1, 9, 15, 18],
1307                        anchors: Some(anchors0),
1308                        quantization: Some(quant0),
1309                        dshape: vec![
1310                            (DimName::Batch, 1),
1311                            (DimName::Height, 9),
1312                            (DimName::Width, 15),
1313                            (DimName::NumAnchorsXFeatures, 18),
1314                        ],
1315                        normalized: Some(true),
1316                    },
1317                ],
1318                configs::Segmentation {
1319                    decoder: DecoderType::ModelPack,
1320                    quantization: Some(quant_seg),
1321                    shape: vec![1, 2, 160, 160],
1322                    dshape: vec![
1323                        (DimName::Batch, 1),
1324                        (DimName::NumClasses, 2),
1325                        (DimName::Height, 160),
1326                        (DimName::Width, 160),
1327                    ],
1328                },
1329            )
1330            .with_score_threshold(score_threshold)
1331            .with_iou_threshold(iou_threshold)
1332            .build()
1333            .unwrap();
1334        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1335        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1336        decoder
1337            .decode_quantized(
1338                &[
1339                    detect0.view().into(),
1340                    detect1.view().into(),
1341                    seg.view().into(),
1342                ],
1343                &mut output_boxes,
1344                &mut output_masks,
1345            )
1346            .unwrap();
1347
1348        let mut mask = seg.slice(s![0, .., .., ..]);
1349        mask.swap_axes(0, 1);
1350        mask.swap_axes(1, 2);
1351        let mask = [Segmentation {
1352            xmin: 0.0,
1353            ymin: 0.0,
1354            xmax: 1.0,
1355            ymax: 1.0,
1356            segmentation: mask.into_owned(),
1357        }];
1358        let correct_boxes = [DetectBox {
1359            bbox: BoundingBox {
1360                xmin: 0.43171933,
1361                ymin: 0.68243736,
1362                xmax: 0.5626645,
1363                ymax: 0.808863,
1364            },
1365            score: 0.99240804,
1366            label: 0,
1367        }];
1368        println!("Output Boxes: {:?}", output_boxes);
1369        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1370
1371        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1372        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1373        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1374        decoder
1375            .decode_float::<f32>(
1376                &[
1377                    detect0.view().into_dyn(),
1378                    detect1.view().into_dyn(),
1379                    seg.view().into_dyn(),
1380                ],
1381                &mut output_boxes,
1382                &mut output_masks,
1383            )
1384            .unwrap();
1385
1386        // not expected for float segmentation decoder to have same values as quantized
1387        // segmentation decoder, as float decoder ensures the data fills 0-255,
1388        // quantized decoder uses whatever the model output. Thus the float
1389        // output is the same as the quantized output but scaled differently.
1390        // However, it is expected that the mask after argmax will be the same.
1391        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1392        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1393        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1394
1395        assert_eq!(mask0, mask1);
1396    }
1397
1398    #[test]
1399    fn test_dequant_chunked() {
1400        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1401        let mut out =
1402            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1403        out.push(123); // make sure to test non multiple of 16 length
1404
1405        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1406        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1407        let quant = Quantization::new(0.0040811873, -123);
1408        dequantize_cpu(&out, quant, &mut out_dequant);
1409
1410        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1411        assert_eq!(out_dequant, out_dequant_simd);
1412
1413        let quant = Quantization::new(0.0040811873, 0);
1414        dequantize_cpu(&out, quant, &mut out_dequant);
1415
1416        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1417        assert_eq!(out_dequant, out_dequant_simd);
1418    }
1419
1420    #[test]
1421    fn test_decoder_yolo_det() {
1422        let score_threshold = 0.25;
1423        let iou_threshold = 0.7;
1424        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1425        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1426        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1427        let quant = (0.0040811873, -123).into();
1428
1429        let decoder = DecoderBuilder::default()
1430            .with_config_yolo_det(
1431                configs::Detection {
1432                    decoder: DecoderType::Ultralytics,
1433                    shape: vec![1, 84, 8400],
1434                    anchors: None,
1435                    quantization: Some(quant),
1436                    dshape: vec![
1437                        (DimName::Batch, 1),
1438                        (DimName::NumFeatures, 84),
1439                        (DimName::NumBoxes, 8400),
1440                    ],
1441                    normalized: Some(true),
1442                },
1443                Some(DecoderVersion::Yolo11),
1444            )
1445            .with_score_threshold(score_threshold)
1446            .with_iou_threshold(iou_threshold)
1447            .build()
1448            .unwrap();
1449
1450        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1451        decode_yolo_det(
1452            (out.slice(s![0, .., ..]), quant.into()),
1453            score_threshold,
1454            iou_threshold,
1455            Some(configs::Nms::ClassAgnostic),
1456            &mut output_boxes,
1457        );
1458        assert!(output_boxes[0].equal_within_delta(
1459            &DetectBox {
1460                bbox: BoundingBox {
1461                    xmin: 0.5285137,
1462                    ymin: 0.05305544,
1463                    xmax: 0.87541467,
1464                    ymax: 0.9998909,
1465                },
1466                score: 0.5591227,
1467                label: 0
1468            },
1469            1e-6
1470        ));
1471
1472        assert!(output_boxes[1].equal_within_delta(
1473            &DetectBox {
1474                bbox: BoundingBox {
1475                    xmin: 0.130598,
1476                    ymin: 0.43260583,
1477                    xmax: 0.35098213,
1478                    ymax: 0.9958097,
1479                },
1480                score: 0.33057618,
1481                label: 75
1482            },
1483            1e-6
1484        ));
1485
1486        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1487        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1488        decoder
1489            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1490            .unwrap();
1491
1492        let out = dequantize_ndarray(out.view(), quant.into());
1493        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1494        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1495        decoder
1496            .decode_float::<f32>(
1497                &[out.view().into_dyn()],
1498                &mut output_boxes_f32,
1499                &mut output_masks_f32,
1500            )
1501            .unwrap();
1502
1503        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1504        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1505    }
1506
1507    #[test]
1508    fn test_decoder_masks() {
1509        let score_threshold = 0.45;
1510        let iou_threshold = 0.45;
1511        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1512        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1513        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1514        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1515
1516        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1517        let protos =
1518            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1519        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1520        let quant_protos = Quantization::new(0.02491161972284317, -117);
1521        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1522        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1523        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1524        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1525        decode_yolo_segdet_float(
1526            seg.view(),
1527            protos.view(),
1528            score_threshold,
1529            iou_threshold,
1530            Some(configs::Nms::ClassAgnostic),
1531            &mut output_boxes,
1532            &mut output_masks,
1533        );
1534        assert_eq!(output_boxes.len(), 2);
1535        assert_eq!(output_boxes.len(), output_masks.len());
1536
1537        for (b, m) in output_boxes.iter().zip(&output_masks) {
1538            assert!(b.bbox.xmin >= m.xmin);
1539            assert!(b.bbox.ymin >= m.ymin);
1540            assert!(b.bbox.xmax >= m.xmax);
1541            assert!(b.bbox.ymax >= m.ymax);
1542        }
1543        assert!(output_boxes[0].equal_within_delta(
1544            &DetectBox {
1545                bbox: BoundingBox {
1546                    xmin: 0.08515105,
1547                    ymin: 0.7131401,
1548                    xmax: 0.29802868,
1549                    ymax: 0.8195788,
1550                },
1551                score: 0.91537374,
1552                label: 23
1553            },
1554            1.0 / 160.0, // wider range because mask will expand the box
1555        ));
1556
1557        assert!(output_boxes[1].equal_within_delta(
1558            &DetectBox {
1559                bbox: BoundingBox {
1560                    xmin: 0.59605736,
1561                    ymin: 0.25545314,
1562                    xmax: 0.93666154,
1563                    ymax: 0.72378385,
1564                },
1565                score: 0.91537374,
1566                label: 23
1567            },
1568            1.0 / 160.0, // wider range because mask will expand the box
1569        ));
1570
1571        let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
1572        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1573
1574        let cropped_mask = full_mask.slice(ndarray::s![
1575            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1576            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1577        ]);
1578
1579        assert_eq!(
1580            cropped_mask,
1581            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1582        );
1583    }
1584
/// Cross-checks four decode paths on the same int8 YOLO segmentation fixture.
///
/// The boxes (`1x116x8400`) and protos (`1x160x160x32`) tensors are decoded by:
///
/// 1. `decode_yolo_segdet_quant` called directly,
/// 2. the built decoder's `decode_quantized`,
/// 3. `decode_yolo_segdet_float` on the dequantized tensors,
/// 4. the built decoder's `decode_float` on the dequantized tensors.
///
/// The results are then compared pairwise with `compare_outputs`.
#[test]
fn test_decoder_masks_i8() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // The fixtures hold raw int8 tensors. A per-byte `u8 -> i8` cast keeps the
    // bit pattern and avoids an `unsafe` pointer reinterpretation.
    let boxes: Vec<i8> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
        .iter()
        .map(|&b| b as i8)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
    let quant_boxes = (0.021287761628627777, 31).into();

    let protos: Vec<i8> = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin")
        .iter()
        .map(|&b| b as i8)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = (0.02491161972284317, -117).into();
    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_boxes),
                shape: vec![1, 116, 8400],
                anchors: None,
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(quant_protos),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            Some(DecoderVersion::Yolo11),
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Convert the config-style quantization tuples into the form the
    // low-level decode/dequantize functions take.
    let quant_boxes = quant_boxes.into();
    let quant_protos = quant_protos.into();

    // Path 1: low-level quantized decode.
    decode_yolo_segdet_quant(
        (boxes.slice(s![0, .., ..]), quant_boxes),
        (protos.slice(s![0, .., .., ..]), quant_protos),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes,
        &mut output_masks,
    );

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Path 2: quantized decode through the configured decoder.
    decoder
        .decode_quantized(
            &[boxes.view().into(), protos.view().into()],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

    // Path 3: low-level float decode on the dequantized tensors.
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    );

    let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

    // Path 4: float decode through the configured decoder.
    decoder
        .decode_float(
            &[seg.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes1_f32,
            &mut output_masks1_f32,
        )
        .unwrap();

    compare_outputs(
        (&output_boxes, &output_boxes1),
        (&output_masks, &output_masks1),
    );

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );

    compare_outputs(
        (&output_boxes_f32, &output_boxes1_f32),
        (&output_masks_f32, &output_masks1_f32),
    );
}
1699
/// Checks a decoder configured with separate boxes and scores outputs.
///
/// The int8 fixture is widened to i16 (values and zero point multiplied by 256,
/// scale divided by 256) and split along the feature axis: rows `..4` are fed
/// as boxes and rows `4..84` as scores. The remaining rows of the fixture are
/// not used here. The quantized result, the low-level float result and the
/// decoder's float result are compared with `compare_outputs`.
#[test]
fn test_decoder_yolo_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Safe per-byte `u8 -> i8` reinterpretation, then widen to i16 and rescale.
    let boxes: Vec<i16> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
        .iter()
        .map(|&b| i16::from(b as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    // Scale and zero point are adjusted to match the x256 widening above.
    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_det(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Quantized decode through the configured decoder.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: low-level float decode on the first 84 dequantized rows.
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_det_float(
        seg.slice(s![0, ..84, ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
    );

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Float decode through the configured decoder.
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();

    // Mask outputs are compared against empty slices.
    compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
    compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
}
1780
/// Checks a split segmentation+detection decoder fed with mixed integer widths.
///
/// Boxes, scores and mask coefficients come from the int8 fixture widened to
/// i16 (x256, with the quantization adjusted to match), while the protos stay
/// int8. The quantized result, the low-level float result and the decoder's
/// float result are compared with `compare_outputs`.
#[test]
fn test_decoder_masks_config_mixed() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Safe per-byte `u8 -> i8` reinterpretation, then widen to i16 and rescale.
    let boxes: Vec<i16> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
        .iter()
        .map(|&b| i16::from(b as i8) * 256)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    // Scale and zero point are adjusted to match the x256 widening above.
    let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

    // Protos are kept as int8; the buffer is built once and moved into the array.
    let protos: Vec<i8> = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin")
        .iter()
        .map(|&b| b as i8)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317, -117);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Quantized decode: the feature axis is split into boxes / scores / mask
    // coefficients, with the protos passed as the fourth input.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: low-level float decode on the dequantized tensors.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    );

    let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
    let mut output_masks1: Vec<_> = Vec::with_capacity(500);

    // Float decode through the configured decoder.
    decoder
        .decode_float(
            &[
                seg.slice(s![.., ..4, ..]).into_dyn(),
                seg.slice(s![.., 4..84, ..]).into_dyn(),
                seg.slice(s![.., 84.., ..]).into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes1,
            &mut output_masks1,
        )
        .unwrap();
    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
    compare_outputs(
        (&output_boxes_f32, &output_boxes1),
        (&output_masks_f32, &output_masks1),
    );
}
1903
/// Checks a split segmentation+detection decoder fed with i32 tensors.
///
/// Both int8 fixtures are widened to i32 by multiplying every value (and the
/// zero point) by `1 << 23` and dividing the scale by the same factor. The
/// quantized result from the decoder is compared with the low-level float
/// decode of the dequantized tensors.
#[test]
fn test_decoder_masks_config_i32() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Widening factor; `i8 * 2^23` stays within the `i32` range.
    let scale: i32 = 1 << 23;

    // Safe per-byte `u8 -> i8` reinterpretation, then widen to i32 and rescale.
    let boxes: Vec<i32> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
        .iter()
        .map(|&b| i32::from(b as i8) * scale)
        .collect();
    let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

    let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);

    // The widened buffer is moved straight into the array (no extra copy).
    let protos: Vec<i32> = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin")
        .iter()
        .map(|&b| i32::from(b as i8) * scale)
        .collect();
    let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
    let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);

    let decoder = DecoderBuilder::default()
        .with_config_yolo_split_segdet(
            configs::Boxes {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 4, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::BoxCoords, 4),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            configs::Scores {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 80, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 80),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::MaskCoefficients {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                shape: vec![1, 32, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumProtos, 32),
                    (DimName::NumBoxes, 8400),
                ],
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
                shape: vec![1, 160, 160, 32],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let mut output_boxes: Vec<_> = Vec::with_capacity(500);
    let mut output_masks: Vec<_> = Vec::with_capacity(500);

    // Quantized decode: the feature axis is split into boxes / scores / mask
    // coefficients, with the protos passed as the fourth input.
    decoder
        .decode_quantized(
            &[
                boxes.slice(s![.., ..4, ..]).into(),
                boxes.slice(s![.., 4..84, ..]).into(),
                boxes.slice(s![.., 84.., ..]).into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

    // Reference: low-level float decode on the dequantized tensors.
    let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
    let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
    let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
    let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
    decode_yolo_segdet_float(
        seg.slice(s![0, .., ..]),
        protos.slice(s![0, .., .., ..]),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut output_boxes_f32,
        &mut output_masks_f32,
    );

    assert_eq!(output_boxes.len(), output_boxes_f32.len());
    assert_eq!(output_masks.len(), output_masks_f32.len());

    compare_outputs(
        (&output_boxes, &output_boxes_f32),
        (&output_masks, &output_masks_f32),
    );
}
2011
/// Runs multiple decoders concurrently.
///
/// Two workloads are defined as closures, each building its own decoder and
/// decoding the same fixture 100 times while asserting the expected output:
/// a YOLO detection decoder and a ModelPack split detection + segmentation
/// decoder. Four threads of each workload are spawned and then joined; a
/// failed assertion in any thread makes `join().unwrap()` panic.
#[test]
fn test_context_switch() {
    let yolo_det = || {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;

        // The fixture holds a raw int8 tensor. A per-byte `u8 -> i8` cast keeps
        // the bit pattern and avoids an `unsafe` pointer reinterpretation.
        let out: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
            .iter()
            .map(|&b| b as i8)
            .collect();
        let out = Array3::from_shape_vec((1, 84, 8400), out).unwrap();
        let quant = (0.0040811873, -123).into();

        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: None,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        let mut output_masks: Vec<_> = Vec::with_capacity(50);

        // The same output buffers are passed to every iteration.
        for _ in 0..100 {
            decoder
                .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                .unwrap();

            assert!(output_boxes[0].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.5285137,
                        ymin: 0.05305544,
                        xmax: 0.87541467,
                        ymax: 0.9998909,
                    },
                    score: 0.5591227,
                    label: 0
                },
                1e-6
            ));

            assert!(output_boxes[1].equal_within_delta(
                &DetectBox {
                    bbox: BoundingBox {
                        xmin: 0.130598,
                        ymin: 0.43260583,
                        xmax: 0.35098213,
                        ymax: 0.9958097,
                    },
                    score: 0.33057618,
                    label: 75
                },
                1e-6
            ));
            assert!(output_masks.is_empty());
        }
    };

    let modelpack_det_split = || {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Expected mask: the segmentation input with its axes permuted from
        // (2, 160, 160) to (160, 160, 2).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: None,
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);

        // The same output buffers are passed to every iteration.
        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut output_boxes,
                    &mut output_masks,
                )
                .unwrap();

            compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
        }
    };

    // Both closures capture nothing, so each can be handed to several threads.
    let handles = vec![
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
        std::thread::spawn(yolo_det),
        std::thread::spawn(modelpack_det_split),
    ];
    for handle in handles {
        handle.join().unwrap();
    }
}
2214
/// Converts the same four-element array through both box formats and pins the
/// resulting `[xmin, ymin, xmax, ymax]` values.
#[test]
fn test_ndarray_to_xyxy_float() {
    let input = array![10.0_f32, 20.0, 20.0, 20.0];

    // As XYWH, (10, 20, 20, 20) is expected to become corners (0, 10, 20, 30).
    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    // As XYXY, the values are expected to pass through unchanged.
    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2225
/// Class-aware float NMS keeps overlapping boxes whose labels differ and
/// suppresses the lower-scoring one when the labels match.
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Case 1: two overlapping boxes (IoU ~0.47 > threshold 0.3) with different
    // labels — both must be kept.
    let class0_box = DetectBox {
        label: 0,
        score: 0.9,
        bbox: BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 0.5,
            ymax: 0.5,
        },
    };
    let class1_box = DetectBox {
        label: 1,
        score: 0.8,
        bbox: BoundingBox {
            xmin: 0.1,
            ymin: 0.1,
            xmax: 0.6,
            ymax: 0.6,
        },
    };
    let kept = nms_class_aware_float(0.3, vec![class0_box, class1_box]);
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    // Case 2: the same geometry with identical labels — only the
    // higher-scoring box must remain.
    let stronger = DetectBox {
        label: 0,
        score: 0.9,
        bbox: BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 0.5,
            ymax: 0.5,
        },
    };
    let weaker = DetectBox {
        label: 0,
        score: 0.8,
        bbox: BoundingBox {
            xmin: 0.1,
            ymin: 0.1,
            xmax: 0.6,
            ymax: 0.6,
        },
    };
    let kept = nms_class_aware_float(0.3, vec![stronger, weaker]);
    assert_eq!(
        kept.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    let survivor = &kept[0];
    assert_eq!(survivor.label, 0);
    assert!((survivor.score - 0.9).abs() < 1e-6);
}
2296
/// Contrasts the two float NMS variants on identical input: class-agnostic NMS
/// suppresses an overlapping box of another class, class-aware NMS does not.
#[test]
fn test_class_agnostic_vs_aware_nms() {
    use crate::float::{nms_class_aware_float, nms_float};

    // Builds a fresh copy of the fixture for each call: two overlapping boxes
    // (IoU ~0.47) carrying different labels.
    let overlapping_pair = || {
        vec![
            DetectBox {
                label: 0,
                score: 0.9,
                bbox: BoundingBox {
                    xmin: 0.0,
                    ymin: 0.0,
                    xmax: 0.5,
                    ymax: 0.5,
                },
            },
            DetectBox {
                label: 1,
                score: 0.8,
                bbox: BoundingBox {
                    xmin: 0.1,
                    ymin: 0.1,
                    xmax: 0.6,
                    ymax: 0.6,
                },
            },
        ]
    };

    // Labels are ignored: IoU above the 0.3 threshold removes one box.
    let agnostic_result = nms_float(0.3, overlapping_pair());
    assert_eq!(
        agnostic_result.len(),
        1,
        "Class-agnostic NMS should suppress overlapping boxes"
    );

    // Labels differ, so both boxes survive.
    let aware_result = nms_class_aware_float(0.3, overlapping_pair());
    assert_eq!(
        aware_result.len(),
        2,
        "Class-aware NMS should keep boxes with different classes"
    );
}
2341
/// Class-aware NMS on quantized-score boxes keeps overlapping boxes whose
/// labels differ.
///
/// The IoU threshold (0.3) is deliberately below the boxes' IoU (~0.47), the
/// same setup as the float NMS tests. With a threshold above the IoU nothing
/// would be suppressed regardless of label, and the assertion would hold even
/// if the class check were missing.
#[test]
fn test_class_aware_nms_int() {
    use crate::byte::nms_class_aware_int;

    // Two overlapping boxes with different classes.
    let boxes = vec![
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 0.5,
                ymax: 0.5,
            },
            score: 200_u8,
            label: 0,
        },
        DetectBoxQuantized {
            bbox: BoundingBox {
                xmin: 0.1,
                ymin: 0.1,
                xmax: 0.6,
                ymax: 0.6,
            },
            score: 180_u8,
            label: 1, // different class
        },
    ];

    // IoU ~0.47 > threshold 0.3: only the differing labels keep both boxes.
    let result = nms_class_aware_int(0.3, boxes);
    assert_eq!(
        result.len(),
        2,
        "Class-aware NMS (int) should keep boxes with different classes"
    );
}
2378
/// The default `configs::Nms` value must be `ClassAgnostic`.
#[test]
fn test_nms_enum_default() {
    assert_eq!(configs::Nms::default(), configs::Nms::ClassAgnostic);
}
2385
/// A decoder built with `with_nms(Some(ClassAware))` exposes that mode in its
/// `nms` field.
#[test]
fn test_decoder_nms_mode() {
    let detection = configs::Detection {
        decoder: DecoderType::Ultralytics,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        anchors: None,
        quantization: None,
        normalized: Some(true),
    };

    let builder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(Some(configs::Nms::ClassAware));
    let decoder = builder.build().unwrap();

    assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
}
2407
2408    #[test]
2409    fn test_decoder_nms_bypass() {
2410        // Test that decoder can be configured with nms=None (bypass)
2411        let decoder = DecoderBuilder::default()
2412            .with_config_yolo_det(
2413                configs::Detection {
2414                    anchors: None,
2415                    decoder: DecoderType::Ultralytics,
2416                    quantization: None,
2417                    shape: vec![1, 84, 8400],
2418                    dshape: Vec::new(),
2419                    normalized: Some(true),
2420                },
2421                None,
2422            )
2423            .with_nms(None)
2424            .build()
2425            .unwrap();
2426
2427        assert_eq!(decoder.nms, None);
2428    }
2429
2430    #[test]
2431    fn test_decoder_normalized_boxes_true() {
2432        // Test that normalized_boxes returns Some(true) when explicitly set
2433        let decoder = DecoderBuilder::default()
2434            .with_config_yolo_det(
2435                configs::Detection {
2436                    anchors: None,
2437                    decoder: DecoderType::Ultralytics,
2438                    quantization: None,
2439                    shape: vec![1, 84, 8400],
2440                    dshape: Vec::new(),
2441                    normalized: Some(true),
2442                },
2443                None,
2444            )
2445            .build()
2446            .unwrap();
2447
2448        assert_eq!(decoder.normalized_boxes(), Some(true));
2449    }
2450
2451    #[test]
2452    fn test_decoder_normalized_boxes_false() {
2453        // Test that normalized_boxes returns Some(false) when config specifies
2454        // unnormalized
2455        let decoder = DecoderBuilder::default()
2456            .with_config_yolo_det(
2457                configs::Detection {
2458                    anchors: None,
2459                    decoder: DecoderType::Ultralytics,
2460                    quantization: None,
2461                    shape: vec![1, 84, 8400],
2462                    dshape: Vec::new(),
2463                    normalized: Some(false),
2464                },
2465                None,
2466            )
2467            .build()
2468            .unwrap();
2469
2470        assert_eq!(decoder.normalized_boxes(), Some(false));
2471    }
2472
2473    #[test]
2474    fn test_decoder_normalized_boxes_unknown() {
2475        // Test that normalized_boxes returns None when not specified in config
2476        let decoder = DecoderBuilder::default()
2477            .with_config_yolo_det(
2478                configs::Detection {
2479                    anchors: None,
2480                    decoder: DecoderType::Ultralytics,
2481                    quantization: None,
2482                    shape: vec![1, 84, 8400],
2483                    dshape: Vec::new(),
2484                    normalized: None,
2485                },
2486                Some(DecoderVersion::Yolo11),
2487            )
2488            .build()
2489            .unwrap();
2490
2491        assert_eq!(decoder.normalized_boxes(), None);
2492    }
2493}