Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
6This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
7It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
8for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust
16# use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs::{self, DecoderVersion} };
17# fn main() -> DecoderResult<()> {
18// Create a decoder for a YOLOv8 model with quantized int8 output with 0.25 score threshold and 0.7 IOU threshold
19let decoder = DecoderBuilder::new()
20    .with_config_yolo_det(configs::Detection {
21        anchors: None,
22        decoder: configs::DecoderType::Ultralytics,
23        quantization: Some(configs::QuantTuple(0.012345, 26)),
24        shape: vec![1, 84, 8400],
25        dshape: Vec::new(),
26        normalized: Some(true),
27    },
28    Some(DecoderVersion::Yolov8))
29    .with_score_threshold(0.25)
30    .with_iou_threshold(0.7)
31    .build()?;
32
33// Get the model output from the model. Here we load it from a test data file for demonstration purposes.
34let model_output: Vec<i8> = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/yolov8s_80_classes.bin"))
35    .iter()
36    .map(|b| *b as i8)
37    .collect();
38let model_output_array = ndarray::Array3::from_shape_vec((1, 84, 8400), model_output)?;
39
40// The capacity is used to determine the maximum number of detections to decode.
41let mut output_boxes: Vec<_> = Vec::with_capacity(10);
42let mut output_masks: Vec<_> = Vec::with_capacity(10);
43
44// Decode the quantized model output into detection boxes and segmentation masks
45// Because this model is a detection-only model, the `output_masks` vector will remain empty.
46decoder.decode_quantized(&[model_output_array.view().into()], &mut output_boxes, &mut output_masks)?;
47# Ok(())
48# }
49```
50
51# Overview
52
53The primary components of this crate are:
54- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
55- `yolo` module: Contains functions specific to decoding YOLO model outputs.
56- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
57
58The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
59It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
60When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
61If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
62the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
63not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
64
65The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
66which can be used if the model type and output formats are known in advance.
67
68
69*/
70#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Quantization parameters (scale and zero point) of a tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Scale factor applied when dequantizing
    pub scale: f32,
    /// Quantized value that corresponds to a real value of zero
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from a scale and a zero point.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247    /// Creates a new Quantization struct from a QuantTuple
248    /// # Examples
249    /// ```
250    /// # use edgefirst_decoder::Quantization;
251    /// # use edgefirst_decoder::configs::QuantTuple;
252    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
253    /// let quant = Quantization::from(quant_tuple);
254    /// assert_eq!(quant.scale, 0.1);
255    /// assert_eq!(quant.zero_point, -128);
256    /// ```
257    fn from(quant_tuple: QuantTuple) -> Quantization {
258        Quantization {
259            scale: quant_tuple.0,
260            zero_point: quant_tuple.1,
261        }
262    }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267    S: AsPrimitive<f32>,
268    Z: AsPrimitive<i32>,
269{
270    /// Creates a new Quantization struct from a tuple
271    /// # Examples
272    /// ```
273    /// # use edgefirst_decoder::Quantization;
274    /// let quant = Quantization::from((0.1_f64, -128_i64));
275    /// assert_eq!(quant.scale, 0.1);
276    /// assert_eq!(quant.zero_point, -128);
277    /// ```
278    fn from((scale, zp): (S, Z)) -> Quantization {
279        Self {
280            scale: scale.as_(),
281            zero_point: zp.as_(),
282        }
283    }
284}
285
286impl Default for Quantization {
287    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::Quantization;
291    /// let quant = Quantization::default();
292    /// assert_eq!(quant.scale, 1.0);
293    /// assert_eq!(quant.zero_point, 0);
294    /// ```
295    fn default() -> Self {
296        Self {
297            scale: 1.0,
298            zero_point: 0,
299        }
300    }
301}
302
303/// A detection box with f32 bbox and score
304#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306    pub bbox: BoundingBox,
307    /// model-specific score for this detection, higher implies more confidence
308    pub score: f32,
309    /// label index for this detection
310    pub label: usize,
311}
312
/// An XYXY bounding box stored as four `f32` coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// normalized coordinate of the left edge of the box
    pub xmin: f32,
    /// normalized coordinate of the top edge of the box
    pub ymin: f32,
    /// normalized coordinate of the right edge of the box
    pub xmax: f32,
    /// normalized coordinate of the bottom edge of the box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a BoundingBox from its four coordinates, stored as given.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose coordinates are ordered so that `xmin <= xmax`
    /// and `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (left, right) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (top, bottom) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(left, top, right, bottom)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Unpacks a BoundingBox into `[xmin, ymin, xmax, ymax]`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Packs an array laid out as `[xmin, ymin, xmax, ymax]` into a
    /// BoundingBox.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}
391
392impl DetectBox {
393    /// Returns true if one detect box is equal to another detect box, within
394    /// the given `eps`
395    ///
396    /// # Examples
397    /// ```
398    /// # use edgefirst_decoder::DetectBox;
399    /// let box1 = DetectBox {
400    ///     bbox: edgefirst_decoder::BoundingBox {
401    ///         xmin: 0.1,
402    ///         ymin: 0.2,
403    ///         xmax: 0.3,
404    ///         ymax: 0.4,
405    ///     },
406    ///     score: 0.5,
407    ///     label: 1,
408    /// };
409    /// let box2 = DetectBox {
410    ///     bbox: edgefirst_decoder::BoundingBox {
411    ///         xmin: 0.101,
412    ///         ymin: 0.199,
413    ///         xmax: 0.301,
414    ///         ymax: 0.399,
415    ///     },
416    ///     score: 0.510,
417    ///     label: 1,
418    /// };
419    /// assert!(box1.equal_within_delta(&box2, 0.011));
420    /// ```
421    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423        self.label == rhs.label
424            && eq_delta(self.score, rhs.score)
425            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429    }
430}
431
432/// A segmentation result with a segmentation mask, and a normalized bounding
433/// box representing the area that the segmentation mask covers
434#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436    /// left-most normalized coordinate of the segmentation box
437    pub xmin: f32,
438    /// top-most normalized coordinate of the segmentation box
439    pub ymin: f32,
440    /// right-most normalized coordinate of the segmentation box
441    pub xmax: f32,
442    /// bottom-most normalized coordinate of the segmentation box
443    pub ymax: f32,
444    /// 3D segmentation array of shape `(H, W, C)`.
445    ///
446    /// For instance segmentation (e.g. YOLO): `C=1` — binary per-instance
447    /// mask where values >= 128 indicate object presence.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
454/// Prototype tensor variants for fused decode+render pipelines.
455///
456/// Carries either raw quantized data (to skip CPU dequantization and let the
457/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
458/// paths).
459#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
462    /// The GPU fragment shader will dequantize per-texel using the scale and
463    /// zero_point.
464    Quantized {
465        protos: Array3<i8>,
466        quantization: Quantization,
467    },
468    /// Dequantized f32 protos (from float models or legacy path).
469    Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473    /// Returns `true` if this is the quantized variant.
474    pub fn is_quantized(&self) -> bool {
475        matches!(self, ProtoTensor::Quantized { .. })
476    }
477
478    /// Returns the spatial dimensions `(height, width, num_protos)`.
479    pub fn dim(&self) -> (usize, usize, usize) {
480        match self {
481            ProtoTensor::Quantized { protos, .. } => protos.dim(),
482            ProtoTensor::Float(arr) => arr.dim(),
483        }
484    }
485
486    /// Returns dequantized f32 protos. For the `Float` variant this is a
487    /// no-copy reference; for `Quantized` it allocates and dequantizes.
488    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489        match self {
490            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491            ProtoTensor::Quantized {
492                protos,
493                quantization,
494            } => {
495                let scale = quantization.scale;
496                let zp = quantization.zero_point as f32;
497                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498            }
499        }
500    }
501}
502
503/// Raw prototype data for fused decode+render pipelines.
504///
505/// Holds post-NMS intermediate state before mask materialization, allowing the
506/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
507/// shader) without materializing intermediate `Array3<u8>` masks.
508#[derive(Debug, Clone)]
509pub struct ProtoData {
510    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
511    pub mask_coefficients: Vec<Vec<f32>>,
512    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
513    pub protos: ProtoTensor,
514}
515
516/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
517///
518///  # Examples
519/// ```
520/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
521/// let quant = Quantization::new(0.1, -128);
522/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
523/// let detect_quant = DetectBoxQuantized {
524///     bbox,
525///     score: 100_i8,
526///     label: 1,
527/// };
528/// let detect = dequant_detect_box(&detect_quant, quant);
529/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
530/// assert_eq!(detect.label, 1);
531/// assert_eq!(detect.bbox, bbox);
532/// ```
533pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534    detect: &DetectBoxQuantized<SCORE>,
535    quant_scores: Quantization,
536) -> DetectBox {
537    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538    DetectBox {
539        bbox: detect.bbox,
540        score: quant_scores.scale * detect.score.as_() + scaled_zp,
541        label: detect.label,
542    }
543}
544/// A detection box with a f32 bbox and quantized score
545#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547    // BOX: Signed + PrimInt + AsPrimitive<f32>,
548    SCORE: PrimInt + AsPrimitive<f32>,
549> {
550    // pub bbox: BoundingBoxQuantized<BOX>,
551    pub bbox: BoundingBox,
552    /// model-specific score for this detection, higher implies more
553    /// confidence.
554    pub score: SCORE,
555    /// label index for this detect
556    pub label: usize,
557}
558
559/// Dequantizes an ndarray from quantized values to f32 values using the given
560/// quantization parameters
561///
562/// # Examples
563/// ```
564/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
565/// let quant = Quantization::new(0.1, -128);
566/// let input: Vec<i8> = vec![0, 127, -128, 64];
567/// let input_array = ndarray::Array1::from(input);
568/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
569/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
570/// ```
571pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572    input: ArrayView<T, D>,
573    quant: Quantization,
574) -> Array<F, D>
575where
576    i32: num_traits::AsPrimitive<F>,
577    f32: num_traits::AsPrimitive<F>,
578{
579    let zero_point = quant.zero_point.as_();
580    let scale = quant.scale.as_();
581    if zero_point != F::zero() {
582        let scaled_zero = -zero_point * scale;
583        input.mapv(|d| d.as_() * scale + scaled_zero)
584    } else {
585        input.mapv(|d| d.as_() * scale)
586    }
587}
588
589/// Dequantizes a slice from quantized values to float values using the given
590/// quantization parameters
591///
592/// # Examples
593/// ```
594/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
595/// let quant = Quantization::new(0.1, -128);
596/// let input: Vec<i8> = vec![0, 127, -128, 64];
597/// let mut output: Vec<f32> = vec![0.0; input.len()];
598/// dequantize_cpu(&input, quant, &mut output);
599/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
600/// ```
601pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602    input: &[T],
603    quant: Quantization,
604    output: &mut [F],
605) where
606    f32: num_traits::AsPrimitive<F>,
607    i32: num_traits::AsPrimitive<F>,
608{
609    assert!(input.len() == output.len());
610    let zero_point = quant.zero_point.as_();
611    let scale = quant.scale.as_();
612    if zero_point != F::zero() {
613        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
614        input
615            .iter()
616            .zip(output)
617            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618    } else {
619        input
620            .iter()
621            .zip(output)
622            .for_each(|(d, deq)| *deq = d.as_() * scale);
623    }
624}
625
626/// Dequantizes a slice from quantized values to float values using the given
627/// quantization parameters, using chunked processing. This is around 5% faster
628/// than `dequantize_cpu` for large slices.
629///
630/// # Examples
631/// ```
632/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
633/// let quant = Quantization::new(0.1, -128);
634/// let input: Vec<i8> = vec![0, 127, -128, 64];
635/// let mut output: Vec<f32> = vec![0.0; input.len()];
636/// dequantize_cpu_chunked(&input, quant, &mut output);
637/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
638/// ```
639pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640    input: &[T],
641    quant: Quantization,
642    output: &mut [F],
643) where
644    f32: num_traits::AsPrimitive<F>,
645    i32: num_traits::AsPrimitive<F>,
646{
647    assert!(input.len() == output.len());
648    let zero_point = quant.zero_point.as_();
649    let scale = quant.scale.as_();
650
651    let input = input.as_chunks::<4>();
652    let output = output.as_chunks_mut::<4>();
653
654    if zero_point != F::zero() {
655        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
656
657        input
658            .0
659            .iter()
660            .zip(output.0)
661            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662        input
663            .1
664            .iter()
665            .zip(output.1)
666            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667    } else {
668        input
669            .0
670            .iter()
671            .zip(output.0)
672            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673        input
674            .1
675            .iter()
676            .zip(output.1)
677            .for_each(|(d, deq)| *deq = d.as_() * scale);
678    }
679}
680
681/// Converts a segmentation tensor into a 2D mask
682/// If the last dimension of the segmentation tensor is 1, values equal or
683/// above 128 are considered objects. Otherwise the object is the argmax index
684///
685/// # Errors
686///
687/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
688/// invalid shape.
689///
690/// # Examples
691/// ```
692/// # use edgefirst_decoder::segmentation_to_mask;
693/// let segmentation =
694///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
695/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
696/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
697/// ```
698pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699    if segmentation.shape()[2] == 0 {
700        return Err(DecoderError::InvalidShape(
701            "Segmentation tensor must have non-zero depth".to_string(),
702        ));
703    }
704    if segmentation.shape()[2] == 1 {
705        yolo_segmentation_to_mask(segmentation, 128)
706    } else {
707        Ok(modelpack_segmentation_to_mask(segmentation))
708    }
709}
710
711/// Returns the maximum value and its index from a 1D array
712fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713    score
714        .iter()
715        .enumerate()
716        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717            if max > *s {
718                (max, arg_max)
719            } else {
720                (*s, ind)
721            }
722        })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727    #![allow(clippy::excessive_precision)]
728    use crate::{
729        configs::{DecoderType, DimName, Protos},
730        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731        yolo::{
732            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733            decode_yolo_segdet_quant,
734        },
735        *,
736    };
737    use ndarray::{array, s, Array4};
738    use ndarray_stats::DeviationExt;
739
740    fn compare_outputs(
741        boxes: (&[DetectBox], &[DetectBox]),
742        masks: (&[Segmentation], &[Segmentation]),
743    ) {
744        let (boxes0, boxes1) = boxes;
745        let (masks0, masks1) = masks;
746
747        assert_eq!(boxes0.len(), boxes1.len());
748        assert_eq!(masks0.len(), masks1.len());
749
750        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
751            assert!(
752                b_i8.equal_within_delta(b_f32, 1e-6),
753                "{b_i8:?} is not equal to {b_f32:?}"
754            );
755        }
756
757        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
758            assert_eq!(
759                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
760                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
761            );
762            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
763            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
764            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
765            let diff = &mask_i8 - &mask_f32;
766            for x in 0..diff.shape()[0] {
767                for y in 0..diff.shape()[1] {
768                    for z in 0..diff.shape()[2] {
769                        let val = diff[[x, y, z]];
770                        assert!(
771                            val.abs() <= 1,
772                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
773                            x,
774                            y,
775                            z,
776                            val
777                        );
778                    }
779                }
780            }
781            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
782            assert!(
783                mean_sq_err < 1e-2,
784                "Mean Square Error between masks was greater than 1%: {:.2}%",
785                mean_sq_err * 100.0
786            );
787        }
788    }
789
790    #[test]
791    fn test_decoder_modelpack() {
792        let score_threshold = 0.45;
793        let iou_threshold = 0.45;
794        let boxes = include_bytes!(concat!(
795            env!("CARGO_MANIFEST_DIR"),
796            "/../../testdata/modelpack_boxes_1935x1x4.bin"
797        ));
798        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
799
800        let scores = include_bytes!(concat!(
801            env!("CARGO_MANIFEST_DIR"),
802            "/../../testdata/modelpack_scores_1935x1.bin"
803        ));
804        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
805
806        let quant_boxes = (0.004656755365431309, 21).into();
807        let quant_scores = (0.0019603664986789227, 0).into();
808
809        let decoder = DecoderBuilder::default()
810            .with_config_modelpack_det(
811                configs::Boxes {
812                    decoder: DecoderType::ModelPack,
813                    quantization: Some(quant_boxes),
814                    shape: vec![1, 1935, 1, 4],
815                    dshape: vec![
816                        (DimName::Batch, 1),
817                        (DimName::NumBoxes, 1935),
818                        (DimName::Padding, 1),
819                        (DimName::BoxCoords, 4),
820                    ],
821                    normalized: Some(true),
822                },
823                configs::Scores {
824                    decoder: DecoderType::ModelPack,
825                    quantization: Some(quant_scores),
826                    shape: vec![1, 1935, 1],
827                    dshape: vec![
828                        (DimName::Batch, 1),
829                        (DimName::NumBoxes, 1935),
830                        (DimName::NumClasses, 1),
831                    ],
832                },
833            )
834            .with_score_threshold(score_threshold)
835            .with_iou_threshold(iou_threshold)
836            .build()
837            .unwrap();
838
839        let quant_boxes = quant_boxes.into();
840        let quant_scores = quant_scores.into();
841
842        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
843        decode_modelpack_det(
844            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
845            (scores.slice(s![0, .., ..]), quant_scores),
846            score_threshold,
847            iou_threshold,
848            &mut output_boxes,
849        );
850        assert!(output_boxes[0].equal_within_delta(
851            &DetectBox {
852                bbox: BoundingBox {
853                    xmin: 0.40513772,
854                    ymin: 0.6379755,
855                    xmax: 0.5122431,
856                    ymax: 0.7730214,
857                },
858                score: 0.4861709,
859                label: 0
860            },
861            1e-6
862        ));
863
864        let mut output_boxes1 = Vec::with_capacity(50);
865        let mut output_masks1 = Vec::with_capacity(50);
866
867        decoder
868            .decode_quantized(
869                &[boxes.view().into(), scores.view().into()],
870                &mut output_boxes1,
871                &mut output_masks1,
872            )
873            .unwrap();
874
875        let mut output_boxes_float = Vec::with_capacity(50);
876        let mut output_masks_float = Vec::with_capacity(50);
877
878        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
879        let scores = dequantize_ndarray(scores.view(), quant_scores);
880
881        decoder
882            .decode_float::<f32>(
883                &[boxes.view().into_dyn(), scores.view().into_dyn()],
884                &mut output_boxes_float,
885                &mut output_masks_float,
886            )
887            .unwrap();
888
889        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
890        compare_outputs(
891            (&output_boxes, &output_boxes_float),
892            (&[], &output_masks_float),
893        );
894    }
895
896    #[test]
897    fn test_decoder_modelpack_split_u8() {
898        let score_threshold = 0.45;
899        let iou_threshold = 0.45;
900        let detect0 = include_bytes!(concat!(
901            env!("CARGO_MANIFEST_DIR"),
902            "/../../testdata/modelpack_split_9x15x18.bin"
903        ));
904        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
905
906        let detect1 = include_bytes!(concat!(
907            env!("CARGO_MANIFEST_DIR"),
908            "/../../testdata/modelpack_split_17x30x18.bin"
909        ));
910        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
911
912        let quant0 = (0.08547406643629074, 174).into();
913        let quant1 = (0.09929127991199493, 183).into();
914        let anchors0 = vec![
915            [0.36666667461395264, 0.31481480598449707],
916            [0.38749998807907104, 0.4740740656852722],
917            [0.5333333611488342, 0.644444465637207],
918        ];
919        let anchors1 = vec![
920            [0.13750000298023224, 0.2074074000120163],
921            [0.2541666626930237, 0.21481481194496155],
922            [0.23125000298023224, 0.35185185074806213],
923        ];
924
925        let detect_config0 = configs::Detection {
926            decoder: DecoderType::ModelPack,
927            shape: vec![1, 9, 15, 18],
928            anchors: Some(anchors0.clone()),
929            quantization: Some(quant0),
930            dshape: vec![
931                (DimName::Batch, 1),
932                (DimName::Height, 9),
933                (DimName::Width, 15),
934                (DimName::NumAnchorsXFeatures, 18),
935            ],
936            normalized: Some(true),
937        };
938
939        let detect_config1 = configs::Detection {
940            decoder: DecoderType::ModelPack,
941            shape: vec![1, 17, 30, 18],
942            anchors: Some(anchors1.clone()),
943            quantization: Some(quant1),
944            dshape: vec![
945                (DimName::Batch, 1),
946                (DimName::Height, 17),
947                (DimName::Width, 30),
948                (DimName::NumAnchorsXFeatures, 18),
949            ],
950            normalized: Some(true),
951        };
952
953        let config0 = (&detect_config0).try_into().unwrap();
954        let config1 = (&detect_config1).try_into().unwrap();
955
956        let decoder = DecoderBuilder::default()
957            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
958            .with_score_threshold(score_threshold)
959            .with_iou_threshold(iou_threshold)
960            .build()
961            .unwrap();
962
963        let quant0 = quant0.into();
964        let quant1 = quant1.into();
965
966        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
967        decode_modelpack_split_quant(
968            &[
969                detect0.slice(s![0, .., .., ..]),
970                detect1.slice(s![0, .., .., ..]),
971            ],
972            &[config0, config1],
973            score_threshold,
974            iou_threshold,
975            &mut output_boxes,
976        );
977        assert!(output_boxes[0].equal_within_delta(
978            &DetectBox {
979                bbox: BoundingBox {
980                    xmin: 0.43171933,
981                    ymin: 0.68243736,
982                    xmax: 0.5626645,
983                    ymax: 0.808863,
984                },
985                score: 0.99240804,
986                label: 0
987            },
988            1e-6
989        ));
990
991        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
992        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
993        decoder
994            .decode_quantized(
995                &[detect0.view().into(), detect1.view().into()],
996                &mut output_boxes1,
997                &mut output_masks1,
998            )
999            .unwrap();
1000
1001        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1002        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1003
1004        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1005        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1006        decoder
1007            .decode_float::<f32>(
1008                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1009                &mut output_boxes1_f32,
1010                &mut output_masks1_f32,
1011            )
1012            .unwrap();
1013
1014        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1015        compare_outputs(
1016            (&output_boxes, &output_boxes1_f32),
1017            (&[], &output_masks1_f32),
1018        );
1019    }
1020
1021    #[test]
1022    fn test_decoder_parse_config_modelpack_split_u8() {
1023        let score_threshold = 0.45;
1024        let iou_threshold = 0.45;
1025        let detect0 = include_bytes!(concat!(
1026            env!("CARGO_MANIFEST_DIR"),
1027            "/../../testdata/modelpack_split_9x15x18.bin"
1028        ));
1029        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1030
1031        let detect1 = include_bytes!(concat!(
1032            env!("CARGO_MANIFEST_DIR"),
1033            "/../../testdata/modelpack_split_17x30x18.bin"
1034        ));
1035        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1036
1037        let decoder = DecoderBuilder::default()
1038            .with_config_yaml_str(
1039                include_str!(concat!(
1040                    env!("CARGO_MANIFEST_DIR"),
1041                    "/../../testdata/modelpack_split.yaml"
1042                ))
1043                .to_string(),
1044            )
1045            .with_score_threshold(score_threshold)
1046            .with_iou_threshold(iou_threshold)
1047            .build()
1048            .unwrap();
1049
1050        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1051        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1052        decoder
1053            .decode_quantized(
1054                &[
1055                    ArrayViewDQuantized::from(detect1.view()),
1056                    ArrayViewDQuantized::from(detect0.view()),
1057                ],
1058                &mut output_boxes,
1059                &mut output_masks,
1060            )
1061            .unwrap();
1062        assert!(output_boxes[0].equal_within_delta(
1063            &DetectBox {
1064                bbox: BoundingBox {
1065                    xmin: 0.43171933,
1066                    ymin: 0.68243736,
1067                    xmax: 0.5626645,
1068                    ymax: 0.808863,
1069                },
1070                score: 0.99240804,
1071                label: 0
1072            },
1073            1e-6
1074        ));
1075    }
1076
1077    #[test]
1078    fn test_modelpack_seg() {
1079        let out = include_bytes!(concat!(
1080            env!("CARGO_MANIFEST_DIR"),
1081            "/../../testdata/modelpack_seg_2x160x160.bin"
1082        ));
1083        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1084        let quant = (1.0 / 255.0, 0).into();
1085
1086        let decoder = DecoderBuilder::default()
1087            .with_config_modelpack_seg(configs::Segmentation {
1088                decoder: DecoderType::ModelPack,
1089                quantization: Some(quant),
1090                shape: vec![1, 2, 160, 160],
1091                dshape: vec![
1092                    (DimName::Batch, 1),
1093                    (DimName::NumClasses, 2),
1094                    (DimName::Height, 160),
1095                    (DimName::Width, 160),
1096                ],
1097            })
1098            .build()
1099            .unwrap();
1100        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1101        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1102        decoder
1103            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1104            .unwrap();
1105
1106        let mut mask = out.slice(s![0, .., .., ..]);
1107        mask.swap_axes(0, 1);
1108        mask.swap_axes(1, 2);
1109        let mask = [Segmentation {
1110            xmin: 0.0,
1111            ymin: 0.0,
1112            xmax: 1.0,
1113            ymax: 1.0,
1114            segmentation: mask.into_owned(),
1115        }];
1116        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1117
1118        decoder
1119            .decode_float::<f32>(
1120                &[dequantize_ndarray(out.view(), quant.into())
1121                    .view()
1122                    .into_dyn()],
1123                &mut output_boxes,
1124                &mut output_masks,
1125            )
1126            .unwrap();
1127
1128        // not expected for float decoder to have same values as quantized decoder, as
1129        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1130        // the model output. Thus the float output is the same as the quantized output
1131        // but scaled differently. However, it is expected that the mask after argmax
1132        // will be the same.
1133        compare_outputs((&[], &output_boxes), (&[], &[]));
1134        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1135        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1136
1137        assert_eq!(mask0, mask1);
1138    }
1139    #[test]
1140    fn test_modelpack_seg_quant() {
1141        let out = include_bytes!(concat!(
1142            env!("CARGO_MANIFEST_DIR"),
1143            "/../../testdata/modelpack_seg_2x160x160.bin"
1144        ));
1145        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1146        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1147        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1148        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1149        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1150        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1151
1152        let quant = (1.0 / 255.0, 0).into();
1153
1154        let decoder = DecoderBuilder::default()
1155            .with_config_modelpack_seg(configs::Segmentation {
1156                decoder: DecoderType::ModelPack,
1157                quantization: Some(quant),
1158                shape: vec![1, 2, 160, 160],
1159                dshape: vec![
1160                    (DimName::Batch, 1),
1161                    (DimName::NumClasses, 2),
1162                    (DimName::Height, 160),
1163                    (DimName::Width, 160),
1164                ],
1165            })
1166            .build()
1167            .unwrap();
1168        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1169        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1170        decoder
1171            .decode_quantized(
1172                &[out_u8.view().into()],
1173                &mut output_boxes,
1174                &mut output_masks_u8,
1175            )
1176            .unwrap();
1177
1178        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1179        decoder
1180            .decode_quantized(
1181                &[out_i8.view().into()],
1182                &mut output_boxes,
1183                &mut output_masks_i8,
1184            )
1185            .unwrap();
1186
1187        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1188        decoder
1189            .decode_quantized(
1190                &[out_u16.view().into()],
1191                &mut output_boxes,
1192                &mut output_masks_u16,
1193            )
1194            .unwrap();
1195
1196        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1197        decoder
1198            .decode_quantized(
1199                &[out_i16.view().into()],
1200                &mut output_boxes,
1201                &mut output_masks_i16,
1202            )
1203            .unwrap();
1204
1205        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1206        decoder
1207            .decode_quantized(
1208                &[out_u32.view().into()],
1209                &mut output_boxes,
1210                &mut output_masks_u32,
1211            )
1212            .unwrap();
1213
1214        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1215        decoder
1216            .decode_quantized(
1217                &[out_i32.view().into()],
1218                &mut output_boxes,
1219                &mut output_masks_i32,
1220            )
1221            .unwrap();
1222
1223        compare_outputs((&[], &output_boxes), (&[], &[]));
1224        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1225        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1226        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1227        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1228        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1229        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1230        assert_eq!(mask_u8, mask_i8);
1231        assert_eq!(mask_u8, mask_u16);
1232        assert_eq!(mask_u8, mask_i16);
1233        assert_eq!(mask_u8, mask_u32);
1234        assert_eq!(mask_u8, mask_i32);
1235    }
1236
1237    #[test]
1238    fn test_modelpack_segdet() {
1239        let score_threshold = 0.45;
1240        let iou_threshold = 0.45;
1241
1242        let boxes = include_bytes!(concat!(
1243            env!("CARGO_MANIFEST_DIR"),
1244            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1245        ));
1246        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1247
1248        let scores = include_bytes!(concat!(
1249            env!("CARGO_MANIFEST_DIR"),
1250            "/../../testdata/modelpack_scores_1935x1.bin"
1251        ));
1252        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1253
1254        let seg = include_bytes!(concat!(
1255            env!("CARGO_MANIFEST_DIR"),
1256            "/../../testdata/modelpack_seg_2x160x160.bin"
1257        ));
1258        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1259
1260        let quant_boxes = (0.004656755365431309, 21).into();
1261        let quant_scores = (0.0019603664986789227, 0).into();
1262        let quant_seg = (1.0 / 255.0, 0).into();
1263
1264        let decoder = DecoderBuilder::default()
1265            .with_config_modelpack_segdet(
1266                configs::Boxes {
1267                    decoder: DecoderType::ModelPack,
1268                    quantization: Some(quant_boxes),
1269                    shape: vec![1, 1935, 1, 4],
1270                    dshape: vec![
1271                        (DimName::Batch, 1),
1272                        (DimName::NumBoxes, 1935),
1273                        (DimName::Padding, 1),
1274                        (DimName::BoxCoords, 4),
1275                    ],
1276                    normalized: Some(true),
1277                },
1278                configs::Scores {
1279                    decoder: DecoderType::ModelPack,
1280                    quantization: Some(quant_scores),
1281                    shape: vec![1, 1935, 1],
1282                    dshape: vec![
1283                        (DimName::Batch, 1),
1284                        (DimName::NumBoxes, 1935),
1285                        (DimName::NumClasses, 1),
1286                    ],
1287                },
1288                configs::Segmentation {
1289                    decoder: DecoderType::ModelPack,
1290                    quantization: Some(quant_seg),
1291                    shape: vec![1, 2, 160, 160],
1292                    dshape: vec![
1293                        (DimName::Batch, 1),
1294                        (DimName::NumClasses, 2),
1295                        (DimName::Height, 160),
1296                        (DimName::Width, 160),
1297                    ],
1298                },
1299            )
1300            .with_iou_threshold(iou_threshold)
1301            .with_score_threshold(score_threshold)
1302            .build()
1303            .unwrap();
1304        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1305        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1306        decoder
1307            .decode_quantized(
1308                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1309                &mut output_boxes,
1310                &mut output_masks,
1311            )
1312            .unwrap();
1313
1314        let mut mask = seg.slice(s![0, .., .., ..]);
1315        mask.swap_axes(0, 1);
1316        mask.swap_axes(1, 2);
1317        let mask = [Segmentation {
1318            xmin: 0.0,
1319            ymin: 0.0,
1320            xmax: 1.0,
1321            ymax: 1.0,
1322            segmentation: mask.into_owned(),
1323        }];
1324        let correct_boxes = [DetectBox {
1325            bbox: BoundingBox {
1326                xmin: 0.40513772,
1327                ymin: 0.6379755,
1328                xmax: 0.5122431,
1329                ymax: 0.7730214,
1330            },
1331            score: 0.4861709,
1332            label: 0,
1333        }];
1334        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1335
1336        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1337        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1338        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1339        decoder
1340            .decode_float::<f32>(
1341                &[
1342                    scores.view().into_dyn(),
1343                    boxes.view().into_dyn(),
1344                    seg.view().into_dyn(),
1345                ],
1346                &mut output_boxes,
1347                &mut output_masks,
1348            )
1349            .unwrap();
1350
1351        // not expected for float segmentation decoder to have same values as quantized
1352        // segmentation decoder, as float decoder ensures the data fills 0-255,
1353        // quantized decoder uses whatever the model output. Thus the float
1354        // output is the same as the quantized output but scaled differently.
1355        // However, it is expected that the mask after argmax will be the same.
1356        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1357        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1358        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1359
1360        assert_eq!(mask0, mask1);
1361    }
1362
1363    #[test]
1364    fn test_modelpack_segdet_split() {
1365        let score_threshold = 0.8;
1366        let iou_threshold = 0.5;
1367
1368        let seg = include_bytes!(concat!(
1369            env!("CARGO_MANIFEST_DIR"),
1370            "/../../testdata/modelpack_seg_2x160x160.bin"
1371        ));
1372        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1373
1374        let detect0 = include_bytes!(concat!(
1375            env!("CARGO_MANIFEST_DIR"),
1376            "/../../testdata/modelpack_split_9x15x18.bin"
1377        ));
1378        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1379
1380        let detect1 = include_bytes!(concat!(
1381            env!("CARGO_MANIFEST_DIR"),
1382            "/../../testdata/modelpack_split_17x30x18.bin"
1383        ));
1384        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1385
1386        let quant0 = (0.08547406643629074, 174).into();
1387        let quant1 = (0.09929127991199493, 183).into();
1388        let quant_seg = (1.0 / 255.0, 0).into();
1389
1390        let anchors0 = vec![
1391            [0.36666667461395264, 0.31481480598449707],
1392            [0.38749998807907104, 0.4740740656852722],
1393            [0.5333333611488342, 0.644444465637207],
1394        ];
1395        let anchors1 = vec![
1396            [0.13750000298023224, 0.2074074000120163],
1397            [0.2541666626930237, 0.21481481194496155],
1398            [0.23125000298023224, 0.35185185074806213],
1399        ];
1400
1401        let decoder = DecoderBuilder::default()
1402            .with_config_modelpack_segdet_split(
1403                vec![
1404                    configs::Detection {
1405                        decoder: DecoderType::ModelPack,
1406                        shape: vec![1, 17, 30, 18],
1407                        anchors: Some(anchors1),
1408                        quantization: Some(quant1),
1409                        dshape: vec![
1410                            (DimName::Batch, 1),
1411                            (DimName::Height, 17),
1412                            (DimName::Width, 30),
1413                            (DimName::NumAnchorsXFeatures, 18),
1414                        ],
1415                        normalized: Some(true),
1416                    },
1417                    configs::Detection {
1418                        decoder: DecoderType::ModelPack,
1419                        shape: vec![1, 9, 15, 18],
1420                        anchors: Some(anchors0),
1421                        quantization: Some(quant0),
1422                        dshape: vec![
1423                            (DimName::Batch, 1),
1424                            (DimName::Height, 9),
1425                            (DimName::Width, 15),
1426                            (DimName::NumAnchorsXFeatures, 18),
1427                        ],
1428                        normalized: Some(true),
1429                    },
1430                ],
1431                configs::Segmentation {
1432                    decoder: DecoderType::ModelPack,
1433                    quantization: Some(quant_seg),
1434                    shape: vec![1, 2, 160, 160],
1435                    dshape: vec![
1436                        (DimName::Batch, 1),
1437                        (DimName::NumClasses, 2),
1438                        (DimName::Height, 160),
1439                        (DimName::Width, 160),
1440                    ],
1441                },
1442            )
1443            .with_score_threshold(score_threshold)
1444            .with_iou_threshold(iou_threshold)
1445            .build()
1446            .unwrap();
1447        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1448        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1449        decoder
1450            .decode_quantized(
1451                &[
1452                    detect0.view().into(),
1453                    detect1.view().into(),
1454                    seg.view().into(),
1455                ],
1456                &mut output_boxes,
1457                &mut output_masks,
1458            )
1459            .unwrap();
1460
1461        let mut mask = seg.slice(s![0, .., .., ..]);
1462        mask.swap_axes(0, 1);
1463        mask.swap_axes(1, 2);
1464        let mask = [Segmentation {
1465            xmin: 0.0,
1466            ymin: 0.0,
1467            xmax: 1.0,
1468            ymax: 1.0,
1469            segmentation: mask.into_owned(),
1470        }];
1471        let correct_boxes = [DetectBox {
1472            bbox: BoundingBox {
1473                xmin: 0.43171933,
1474                ymin: 0.68243736,
1475                xmax: 0.5626645,
1476                ymax: 0.808863,
1477            },
1478            score: 0.99240804,
1479            label: 0,
1480        }];
1481        println!("Output Boxes: {:?}", output_boxes);
1482        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1483
1484        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1485        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1486        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1487        decoder
1488            .decode_float::<f32>(
1489                &[
1490                    detect0.view().into_dyn(),
1491                    detect1.view().into_dyn(),
1492                    seg.view().into_dyn(),
1493                ],
1494                &mut output_boxes,
1495                &mut output_masks,
1496            )
1497            .unwrap();
1498
1499        // not expected for float segmentation decoder to have same values as quantized
1500        // segmentation decoder, as float decoder ensures the data fills 0-255,
1501        // quantized decoder uses whatever the model output. Thus the float
1502        // output is the same as the quantized output but scaled differently.
1503        // However, it is expected that the mask after argmax will be the same.
1504        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1505        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1506        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1507
1508        assert_eq!(mask0, mask1);
1509    }
1510
1511    #[test]
1512    fn test_dequant_chunked() {
1513        let out = include_bytes!(concat!(
1514            env!("CARGO_MANIFEST_DIR"),
1515            "/../../testdata/yolov8s_80_classes.bin"
1516        ));
1517        let mut out =
1518            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1519        out.push(123); // make sure to test non multiple of 16 length
1520
1521        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1522        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1523        let quant = Quantization::new(0.0040811873, -123);
1524        dequantize_cpu(&out, quant, &mut out_dequant);
1525
1526        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1527        assert_eq!(out_dequant, out_dequant_simd);
1528
1529        let quant = Quantization::new(0.0040811873, 0);
1530        dequantize_cpu(&out, quant, &mut out_dequant);
1531
1532        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1533        assert_eq!(out_dequant, out_dequant_simd);
1534    }
1535
1536    #[test]
1537    fn test_decoder_yolo_det() {
1538        let score_threshold = 0.25;
1539        let iou_threshold = 0.7;
1540        let out = include_bytes!(concat!(
1541            env!("CARGO_MANIFEST_DIR"),
1542            "/../../testdata/yolov8s_80_classes.bin"
1543        ));
1544        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1545        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1546        let quant = (0.0040811873, -123).into();
1547
1548        let decoder = DecoderBuilder::default()
1549            .with_config_yolo_det(
1550                configs::Detection {
1551                    decoder: DecoderType::Ultralytics,
1552                    shape: vec![1, 84, 8400],
1553                    anchors: None,
1554                    quantization: Some(quant),
1555                    dshape: vec![
1556                        (DimName::Batch, 1),
1557                        (DimName::NumFeatures, 84),
1558                        (DimName::NumBoxes, 8400),
1559                    ],
1560                    normalized: Some(true),
1561                },
1562                Some(DecoderVersion::Yolo11),
1563            )
1564            .with_score_threshold(score_threshold)
1565            .with_iou_threshold(iou_threshold)
1566            .build()
1567            .unwrap();
1568
1569        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1570        decode_yolo_det(
1571            (out.slice(s![0, .., ..]), quant.into()),
1572            score_threshold,
1573            iou_threshold,
1574            Some(configs::Nms::ClassAgnostic),
1575            &mut output_boxes,
1576        );
1577        assert!(output_boxes[0].equal_within_delta(
1578            &DetectBox {
1579                bbox: BoundingBox {
1580                    xmin: 0.5285137,
1581                    ymin: 0.05305544,
1582                    xmax: 0.87541467,
1583                    ymax: 0.9998909,
1584                },
1585                score: 0.5591227,
1586                label: 0
1587            },
1588            1e-6
1589        ));
1590
1591        assert!(output_boxes[1].equal_within_delta(
1592            &DetectBox {
1593                bbox: BoundingBox {
1594                    xmin: 0.130598,
1595                    ymin: 0.43260583,
1596                    xmax: 0.35098213,
1597                    ymax: 0.9958097,
1598                },
1599                score: 0.33057618,
1600                label: 75
1601            },
1602            1e-6
1603        ));
1604
1605        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1606        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1607        decoder
1608            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1609            .unwrap();
1610
1611        let out = dequantize_ndarray(out.view(), quant.into());
1612        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1613        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1614        decoder
1615            .decode_float::<f32>(
1616                &[out.view().into_dyn()],
1617                &mut output_boxes_f32,
1618                &mut output_masks_f32,
1619            )
1620            .unwrap();
1621
1622        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1623        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1624    }
1625
1626    #[test]
1627    fn test_decoder_masks() {
1628        let score_threshold = 0.45;
1629        let iou_threshold = 0.45;
1630        let boxes = include_bytes!(concat!(
1631            env!("CARGO_MANIFEST_DIR"),
1632            "/../../testdata/yolov8_boxes_116x8400.bin"
1633        ));
1634        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1635        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1636        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1637
1638        let protos = include_bytes!(concat!(
1639            env!("CARGO_MANIFEST_DIR"),
1640            "/../../testdata/yolov8_protos_160x160x32.bin"
1641        ));
1642        let protos =
1643            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1644        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1645        let quant_protos = Quantization::new(0.02491161972284317, -117);
1646        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1647        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1648        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1649        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1650        decode_yolo_segdet_float(
1651            seg.view(),
1652            protos.view(),
1653            score_threshold,
1654            iou_threshold,
1655            Some(configs::Nms::ClassAgnostic),
1656            &mut output_boxes,
1657            &mut output_masks,
1658        )
1659        .unwrap();
1660        assert_eq!(output_boxes.len(), 2);
1661        assert_eq!(output_boxes.len(), output_masks.len());
1662
1663        for (b, m) in output_boxes.iter().zip(&output_masks) {
1664            assert!(b.bbox.xmin >= m.xmin);
1665            assert!(b.bbox.ymin >= m.ymin);
1666            assert!(b.bbox.xmax >= m.xmax);
1667            assert!(b.bbox.ymax >= m.ymax);
1668        }
1669        assert!(output_boxes[0].equal_within_delta(
1670            &DetectBox {
1671                bbox: BoundingBox {
1672                    xmin: 0.08515105,
1673                    ymin: 0.7131401,
1674                    xmax: 0.29802868,
1675                    ymax: 0.8195788,
1676                },
1677                score: 0.91537374,
1678                label: 23
1679            },
1680            1.0 / 160.0, // wider range because mask will expand the box
1681        ));
1682
1683        assert!(output_boxes[1].equal_within_delta(
1684            &DetectBox {
1685                bbox: BoundingBox {
1686                    xmin: 0.59605736,
1687                    ymin: 0.25545314,
1688                    xmax: 0.93666154,
1689                    ymax: 0.72378385,
1690                },
1691                score: 0.91537374,
1692                label: 23
1693            },
1694            1.0 / 160.0, // wider range because mask will expand the box
1695        ));
1696
1697        let full_mask = include_bytes!(concat!(
1698            env!("CARGO_MANIFEST_DIR"),
1699            "/../../testdata/yolov8_mask_results.bin"
1700        ));
1701        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1702
1703        let cropped_mask = full_mask.slice(ndarray::s![
1704            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1705            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1706        ]);
1707
1708        assert_eq!(
1709            cropped_mask,
1710            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1711        );
1712    }
1713
1714    /// Regression test: config-driven path with NCHW protos (no dshape).
1715    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1716    /// and the YAML config has no dshape field — the exact scenario from
1717    /// hal_mask_matmul_bug.md.
1718    #[test]
1719    fn test_decoder_masks_nchw_protos() {
1720        let score_threshold = 0.45;
1721        let iou_threshold = 0.45;
1722
1723        // Load test data — boxes as [116, 8400]
1724        let boxes_raw = include_bytes!(concat!(
1725            env!("CARGO_MANIFEST_DIR"),
1726            "/../../testdata/yolov8_boxes_116x8400.bin"
1727        ));
1728        let boxes_raw =
1729            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1730        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1731        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1732
1733        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1734        let protos_raw = include_bytes!(concat!(
1735            env!("CARGO_MANIFEST_DIR"),
1736            "/../../testdata/yolov8_protos_160x160x32.bin"
1737        ));
1738        let protos_raw = unsafe {
1739            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1740        };
1741        let protos_hwc =
1742            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1743        let quant_protos = Quantization::new(0.02491161972284317, -117);
1744        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1745
1746        // ---- Reference: direct call with HWC protos (known working) ----
1747        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1748        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1749        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1750        decode_yolo_segdet_float(
1751            seg.view(),
1752            protos_f32_hwc.view(),
1753            score_threshold,
1754            iou_threshold,
1755            Some(configs::Nms::ClassAgnostic),
1756            &mut ref_boxes,
1757            &mut ref_masks,
1758        )
1759        .unwrap();
1760        assert_eq!(ref_boxes.len(), 2);
1761
1762        // ---- Config-driven path: NCHW protos, no dshape ----
1763        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1764        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1765        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1766
1767        // Build boxes as [1, 116, 8400] f32
1768        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1769
1770        // Build decoder from config with no dshape on protos
1771        let decoder = DecoderBuilder::default()
1772            .with_config_yolo_segdet(
1773                configs::Detection {
1774                    decoder: configs::DecoderType::Ultralytics,
1775                    quantization: None,
1776                    shape: vec![1, 116, 8400],
1777                    dshape: vec![],
1778                    normalized: Some(true),
1779                    anchors: None,
1780                },
1781                configs::Protos {
1782                    decoder: configs::DecoderType::Ultralytics,
1783                    quantization: None,
1784                    shape: vec![1, 32, 160, 160],
1785                    dshape: vec![], // No dshape — simulates YAML without dshape
1786                },
1787                None, // decoder version
1788            )
1789            .with_score_threshold(score_threshold)
1790            .with_iou_threshold(iou_threshold)
1791            .build()
1792            .unwrap();
1793
1794        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1795        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1796        decoder
1797            .decode_float(
1798                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1799                &mut cfg_boxes,
1800                &mut cfg_masks,
1801            )
1802            .unwrap();
1803
1804        // Must produce the same number of detections
1805        assert_eq!(
1806            cfg_boxes.len(),
1807            ref_boxes.len(),
1808            "config path produced {} boxes, reference produced {}",
1809            cfg_boxes.len(),
1810            ref_boxes.len()
1811        );
1812
1813        // Boxes must match
1814        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1815            assert!(
1816                cb.equal_within_delta(rb, 0.01),
1817                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1818            );
1819        }
1820
1821        // Masks must match pixel-for-pixel
1822        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1823            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1824            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1825            assert_eq!(
1826                cm_arr, rm_arr,
1827                "mask {i} pixel mismatch between config-driven and reference paths"
1828            );
1829        }
1830    }
1831
1832    #[test]
1833    fn test_decoder_masks_i8() {
1834        let score_threshold = 0.45;
1835        let iou_threshold = 0.45;
1836        let boxes = include_bytes!(concat!(
1837            env!("CARGO_MANIFEST_DIR"),
1838            "/../../testdata/yolov8_boxes_116x8400.bin"
1839        ));
1840        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1841        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1842        let quant_boxes = (0.021287761628627777, 31).into();
1843
1844        let protos = include_bytes!(concat!(
1845            env!("CARGO_MANIFEST_DIR"),
1846            "/../../testdata/yolov8_protos_160x160x32.bin"
1847        ));
1848        let protos =
1849            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1850        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1851        let quant_protos = (0.02491161972284317, -117).into();
1852        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1853        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1854
1855        let decoder = DecoderBuilder::default()
1856            .with_config_yolo_segdet(
1857                configs::Detection {
1858                    decoder: configs::DecoderType::Ultralytics,
1859                    quantization: Some(quant_boxes),
1860                    shape: vec![1, 116, 8400],
1861                    anchors: None,
1862                    dshape: vec![
1863                        (DimName::Batch, 1),
1864                        (DimName::NumFeatures, 116),
1865                        (DimName::NumBoxes, 8400),
1866                    ],
1867                    normalized: Some(true),
1868                },
1869                Protos {
1870                    decoder: configs::DecoderType::Ultralytics,
1871                    quantization: Some(quant_protos),
1872                    shape: vec![1, 160, 160, 32],
1873                    dshape: vec![
1874                        (DimName::Batch, 1),
1875                        (DimName::Height, 160),
1876                        (DimName::Width, 160),
1877                        (DimName::NumProtos, 32),
1878                    ],
1879                },
1880                Some(DecoderVersion::Yolo11),
1881            )
1882            .with_score_threshold(score_threshold)
1883            .with_iou_threshold(iou_threshold)
1884            .build()
1885            .unwrap();
1886
1887        let quant_boxes = quant_boxes.into();
1888        let quant_protos = quant_protos.into();
1889
1890        decode_yolo_segdet_quant(
1891            (boxes.slice(s![0, .., ..]), quant_boxes),
1892            (protos.slice(s![0, .., .., ..]), quant_protos),
1893            score_threshold,
1894            iou_threshold,
1895            Some(configs::Nms::ClassAgnostic),
1896            &mut output_boxes,
1897            &mut output_masks,
1898        )
1899        .unwrap();
1900
1901        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1902        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1903
1904        decoder
1905            .decode_quantized(
1906                &[boxes.view().into(), protos.view().into()],
1907                &mut output_boxes1,
1908                &mut output_masks1,
1909            )
1910            .unwrap();
1911
1912        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1913        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1914
1915        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1916        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1917        decode_yolo_segdet_float(
1918            seg.slice(s![0, .., ..]),
1919            protos.slice(s![0, .., .., ..]),
1920            score_threshold,
1921            iou_threshold,
1922            Some(configs::Nms::ClassAgnostic),
1923            &mut output_boxes_f32,
1924            &mut output_masks_f32,
1925        )
1926        .unwrap();
1927
1928        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1929        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1930
1931        decoder
1932            .decode_float(
1933                &[seg.view().into_dyn(), protos.view().into_dyn()],
1934                &mut output_boxes1_f32,
1935                &mut output_masks1_f32,
1936            )
1937            .unwrap();
1938
1939        compare_outputs(
1940            (&output_boxes, &output_boxes1),
1941            (&output_masks, &output_masks1),
1942        );
1943
1944        compare_outputs(
1945            (&output_boxes, &output_boxes_f32),
1946            (&output_masks, &output_masks_f32),
1947        );
1948
1949        compare_outputs(
1950            (&output_boxes_f32, &output_boxes1_f32),
1951            (&output_masks_f32, &output_masks1_f32),
1952        );
1953    }
1954
1955    #[test]
1956    fn test_decoder_yolo_split() {
1957        let score_threshold = 0.45;
1958        let iou_threshold = 0.45;
1959        let boxes = include_bytes!(concat!(
1960            env!("CARGO_MANIFEST_DIR"),
1961            "/../../testdata/yolov8_boxes_116x8400.bin"
1962        ));
1963        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1964        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1965        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1966
1967        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1968
1969        let decoder = DecoderBuilder::default()
1970            .with_config_yolo_split_det(
1971                configs::Boxes {
1972                    decoder: configs::DecoderType::Ultralytics,
1973                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1974                    shape: vec![1, 4, 8400],
1975                    dshape: vec![
1976                        (DimName::Batch, 1),
1977                        (DimName::BoxCoords, 4),
1978                        (DimName::NumBoxes, 8400),
1979                    ],
1980                    normalized: Some(true),
1981                },
1982                configs::Scores {
1983                    decoder: configs::DecoderType::Ultralytics,
1984                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1985                    shape: vec![1, 80, 8400],
1986                    dshape: vec![
1987                        (DimName::Batch, 1),
1988                        (DimName::NumClasses, 80),
1989                        (DimName::NumBoxes, 8400),
1990                    ],
1991                },
1992            )
1993            .with_score_threshold(score_threshold)
1994            .with_iou_threshold(iou_threshold)
1995            .build()
1996            .unwrap();
1997
1998        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1999        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2000
2001        decoder
2002            .decode_quantized(
2003                &[
2004                    boxes.slice(s![.., ..4, ..]).into(),
2005                    boxes.slice(s![.., 4..84, ..]).into(),
2006                ],
2007                &mut output_boxes,
2008                &mut output_masks,
2009            )
2010            .unwrap();
2011
2012        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2013        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2014        decode_yolo_det_float(
2015            seg.slice(s![0, ..84, ..]),
2016            score_threshold,
2017            iou_threshold,
2018            Some(configs::Nms::ClassAgnostic),
2019            &mut output_boxes_f32,
2020        );
2021
2022        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2023        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2024
2025        decoder
2026            .decode_float(
2027                &[
2028                    seg.slice(s![.., ..4, ..]).into_dyn(),
2029                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2030                ],
2031                &mut output_boxes1,
2032                &mut output_masks1,
2033            )
2034            .unwrap();
2035        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2036        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2037    }
2038
2039    #[test]
2040    fn test_decoder_masks_config_mixed() {
2041        let score_threshold = 0.45;
2042        let iou_threshold = 0.45;
2043        let boxes = include_bytes!(concat!(
2044            env!("CARGO_MANIFEST_DIR"),
2045            "/../../testdata/yolov8_boxes_116x8400.bin"
2046        ));
2047        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2048        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2049        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2050
2051        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2052
2053        let protos = include_bytes!(concat!(
2054            env!("CARGO_MANIFEST_DIR"),
2055            "/../../testdata/yolov8_protos_160x160x32.bin"
2056        ));
2057        let protos =
2058            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2059        let protos: Vec<_> = protos.to_vec();
2060        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2061        let quant_protos = Quantization::new(0.02491161972284317, -117);
2062
2063        let decoder = DecoderBuilder::default()
2064            .with_config_yolo_split_segdet(
2065                configs::Boxes {
2066                    decoder: configs::DecoderType::Ultralytics,
2067                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2068                    shape: vec![1, 4, 8400],
2069                    dshape: vec![
2070                        (DimName::Batch, 1),
2071                        (DimName::BoxCoords, 4),
2072                        (DimName::NumBoxes, 8400),
2073                    ],
2074                    normalized: Some(true),
2075                },
2076                configs::Scores {
2077                    decoder: configs::DecoderType::Ultralytics,
2078                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2079                    shape: vec![1, 80, 8400],
2080                    dshape: vec![
2081                        (DimName::Batch, 1),
2082                        (DimName::NumClasses, 80),
2083                        (DimName::NumBoxes, 8400),
2084                    ],
2085                },
2086                configs::MaskCoefficients {
2087                    decoder: configs::DecoderType::Ultralytics,
2088                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2089                    shape: vec![1, 32, 8400],
2090                    dshape: vec![
2091                        (DimName::Batch, 1),
2092                        (DimName::NumProtos, 32),
2093                        (DimName::NumBoxes, 8400),
2094                    ],
2095                },
2096                configs::Protos {
2097                    decoder: configs::DecoderType::Ultralytics,
2098                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2099                    shape: vec![1, 160, 160, 32],
2100                    dshape: vec![
2101                        (DimName::Batch, 1),
2102                        (DimName::Height, 160),
2103                        (DimName::Width, 160),
2104                        (DimName::NumProtos, 32),
2105                    ],
2106                },
2107            )
2108            .with_score_threshold(score_threshold)
2109            .with_iou_threshold(iou_threshold)
2110            .build()
2111            .unwrap();
2112
2113        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2114        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2115
2116        decoder
2117            .decode_quantized(
2118                &[
2119                    boxes.slice(s![.., ..4, ..]).into(),
2120                    boxes.slice(s![.., 4..84, ..]).into(),
2121                    boxes.slice(s![.., 84.., ..]).into(),
2122                    protos.view().into(),
2123                ],
2124                &mut output_boxes,
2125                &mut output_masks,
2126            )
2127            .unwrap();
2128
2129        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2130        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2131        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2132        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2133        decode_yolo_segdet_float(
2134            seg.slice(s![0, .., ..]),
2135            protos.slice(s![0, .., .., ..]),
2136            score_threshold,
2137            iou_threshold,
2138            Some(configs::Nms::ClassAgnostic),
2139            &mut output_boxes_f32,
2140            &mut output_masks_f32,
2141        )
2142        .unwrap();
2143
2144        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2145        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2146
2147        decoder
2148            .decode_float(
2149                &[
2150                    seg.slice(s![.., ..4, ..]).into_dyn(),
2151                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2152                    seg.slice(s![.., 84.., ..]).into_dyn(),
2153                    protos.view().into_dyn(),
2154                ],
2155                &mut output_boxes1,
2156                &mut output_masks1,
2157            )
2158            .unwrap();
2159        compare_outputs(
2160            (&output_boxes, &output_boxes_f32),
2161            (&output_masks, &output_masks_f32),
2162        );
2163        compare_outputs(
2164            (&output_boxes_f32, &output_boxes1),
2165            (&output_masks_f32, &output_masks1),
2166        );
2167    }
2168
2169    #[test]
2170    fn test_decoder_masks_config_i32() {
2171        let score_threshold = 0.45;
2172        let iou_threshold = 0.45;
2173        let boxes = include_bytes!(concat!(
2174            env!("CARGO_MANIFEST_DIR"),
2175            "/../../testdata/yolov8_boxes_116x8400.bin"
2176        ));
2177        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2178        let scale = 1 << 23;
2179        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2180        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2181
2182        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2183
2184        let protos = include_bytes!(concat!(
2185            env!("CARGO_MANIFEST_DIR"),
2186            "/../../testdata/yolov8_protos_160x160x32.bin"
2187        ));
2188        let protos =
2189            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2190        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2191        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2192        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2193
2194        let decoder = DecoderBuilder::default()
2195            .with_config_yolo_split_segdet(
2196                configs::Boxes {
2197                    decoder: configs::DecoderType::Ultralytics,
2198                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2199                    shape: vec![1, 4, 8400],
2200                    dshape: vec![
2201                        (DimName::Batch, 1),
2202                        (DimName::BoxCoords, 4),
2203                        (DimName::NumBoxes, 8400),
2204                    ],
2205                    normalized: Some(true),
2206                },
2207                configs::Scores {
2208                    decoder: configs::DecoderType::Ultralytics,
2209                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2210                    shape: vec![1, 80, 8400],
2211                    dshape: vec![
2212                        (DimName::Batch, 1),
2213                        (DimName::NumClasses, 80),
2214                        (DimName::NumBoxes, 8400),
2215                    ],
2216                },
2217                configs::MaskCoefficients {
2218                    decoder: configs::DecoderType::Ultralytics,
2219                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2220                    shape: vec![1, 32, 8400],
2221                    dshape: vec![
2222                        (DimName::Batch, 1),
2223                        (DimName::NumProtos, 32),
2224                        (DimName::NumBoxes, 8400),
2225                    ],
2226                },
2227                configs::Protos {
2228                    decoder: configs::DecoderType::Ultralytics,
2229                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2230                    shape: vec![1, 160, 160, 32],
2231                    dshape: vec![
2232                        (DimName::Batch, 1),
2233                        (DimName::Height, 160),
2234                        (DimName::Width, 160),
2235                        (DimName::NumProtos, 32),
2236                    ],
2237                },
2238            )
2239            .with_score_threshold(score_threshold)
2240            .with_iou_threshold(iou_threshold)
2241            .build()
2242            .unwrap();
2243
2244        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2245        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2246
2247        decoder
2248            .decode_quantized(
2249                &[
2250                    boxes.slice(s![.., ..4, ..]).into(),
2251                    boxes.slice(s![.., 4..84, ..]).into(),
2252                    boxes.slice(s![.., 84.., ..]).into(),
2253                    protos.view().into(),
2254                ],
2255                &mut output_boxes,
2256                &mut output_masks,
2257            )
2258            .unwrap();
2259
2260        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2261        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2262        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2263        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2264        decode_yolo_segdet_float(
2265            seg.slice(s![0, .., ..]),
2266            protos.slice(s![0, .., .., ..]),
2267            score_threshold,
2268            iou_threshold,
2269            Some(configs::Nms::ClassAgnostic),
2270            &mut output_boxes_f32,
2271            &mut output_masks_f32,
2272        )
2273        .unwrap();
2274
2275        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2276        assert_eq!(output_masks.len(), output_masks_f32.len());
2277
2278        compare_outputs(
2279            (&output_boxes, &output_boxes_f32),
2280            (&output_masks, &output_masks_f32),
2281        );
2282    }
2283
2284    /// test running multiple decoders concurrently
2285    #[test]
2286    fn test_context_switch() {
2287        let yolo_det = || {
2288            let score_threshold = 0.25;
2289            let iou_threshold = 0.7;
2290            let out = include_bytes!(concat!(
2291                env!("CARGO_MANIFEST_DIR"),
2292                "/../../testdata/yolov8s_80_classes.bin"
2293            ));
2294            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2295            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2296            let quant = (0.0040811873, -123).into();
2297
2298            let decoder = DecoderBuilder::default()
2299                .with_config_yolo_det(
2300                    configs::Detection {
2301                        decoder: DecoderType::Ultralytics,
2302                        shape: vec![1, 84, 8400],
2303                        anchors: None,
2304                        quantization: Some(quant),
2305                        dshape: vec![
2306                            (DimName::Batch, 1),
2307                            (DimName::NumFeatures, 84),
2308                            (DimName::NumBoxes, 8400),
2309                        ],
2310                        normalized: None,
2311                    },
2312                    None,
2313                )
2314                .with_score_threshold(score_threshold)
2315                .with_iou_threshold(iou_threshold)
2316                .build()
2317                .unwrap();
2318
2319            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2320            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2321
2322            for _ in 0..100 {
2323                decoder
2324                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2325                    .unwrap();
2326
2327                assert!(output_boxes[0].equal_within_delta(
2328                    &DetectBox {
2329                        bbox: BoundingBox {
2330                            xmin: 0.5285137,
2331                            ymin: 0.05305544,
2332                            xmax: 0.87541467,
2333                            ymax: 0.9998909,
2334                        },
2335                        score: 0.5591227,
2336                        label: 0
2337                    },
2338                    1e-6
2339                ));
2340
2341                assert!(output_boxes[1].equal_within_delta(
2342                    &DetectBox {
2343                        bbox: BoundingBox {
2344                            xmin: 0.130598,
2345                            ymin: 0.43260583,
2346                            xmax: 0.35098213,
2347                            ymax: 0.9958097,
2348                        },
2349                        score: 0.33057618,
2350                        label: 75
2351                    },
2352                    1e-6
2353                ));
2354                assert!(output_masks.is_empty());
2355            }
2356        };
2357
2358        let modelpack_det_split = || {
2359            let score_threshold = 0.8;
2360            let iou_threshold = 0.5;
2361
2362            let seg = include_bytes!(concat!(
2363                env!("CARGO_MANIFEST_DIR"),
2364                "/../../testdata/modelpack_seg_2x160x160.bin"
2365            ));
2366            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2367
2368            let detect0 = include_bytes!(concat!(
2369                env!("CARGO_MANIFEST_DIR"),
2370                "/../../testdata/modelpack_split_9x15x18.bin"
2371            ));
2372            let detect0 =
2373                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2374
2375            let detect1 = include_bytes!(concat!(
2376                env!("CARGO_MANIFEST_DIR"),
2377                "/../../testdata/modelpack_split_17x30x18.bin"
2378            ));
2379            let detect1 =
2380                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2381
2382            let mut mask = seg.slice(s![0, .., .., ..]);
2383            mask.swap_axes(0, 1);
2384            mask.swap_axes(1, 2);
2385            let mask = [Segmentation {
2386                xmin: 0.0,
2387                ymin: 0.0,
2388                xmax: 1.0,
2389                ymax: 1.0,
2390                segmentation: mask.into_owned(),
2391            }];
2392            let correct_boxes = [DetectBox {
2393                bbox: BoundingBox {
2394                    xmin: 0.43171933,
2395                    ymin: 0.68243736,
2396                    xmax: 0.5626645,
2397                    ymax: 0.808863,
2398                },
2399                score: 0.99240804,
2400                label: 0,
2401            }];
2402
2403            let quant0 = (0.08547406643629074, 174).into();
2404            let quant1 = (0.09929127991199493, 183).into();
2405            let quant_seg = (1.0 / 255.0, 0).into();
2406
2407            let anchors0 = vec![
2408                [0.36666667461395264, 0.31481480598449707],
2409                [0.38749998807907104, 0.4740740656852722],
2410                [0.5333333611488342, 0.644444465637207],
2411            ];
2412            let anchors1 = vec![
2413                [0.13750000298023224, 0.2074074000120163],
2414                [0.2541666626930237, 0.21481481194496155],
2415                [0.23125000298023224, 0.35185185074806213],
2416            ];
2417
2418            let decoder = DecoderBuilder::default()
2419                .with_config_modelpack_segdet_split(
2420                    vec![
2421                        configs::Detection {
2422                            decoder: DecoderType::ModelPack,
2423                            shape: vec![1, 17, 30, 18],
2424                            anchors: Some(anchors1),
2425                            quantization: Some(quant1),
2426                            dshape: vec![
2427                                (DimName::Batch, 1),
2428                                (DimName::Height, 17),
2429                                (DimName::Width, 30),
2430                                (DimName::NumAnchorsXFeatures, 18),
2431                            ],
2432                            normalized: None,
2433                        },
2434                        configs::Detection {
2435                            decoder: DecoderType::ModelPack,
2436                            shape: vec![1, 9, 15, 18],
2437                            anchors: Some(anchors0),
2438                            quantization: Some(quant0),
2439                            dshape: vec![
2440                                (DimName::Batch, 1),
2441                                (DimName::Height, 9),
2442                                (DimName::Width, 15),
2443                                (DimName::NumAnchorsXFeatures, 18),
2444                            ],
2445                            normalized: None,
2446                        },
2447                    ],
2448                    configs::Segmentation {
2449                        decoder: DecoderType::ModelPack,
2450                        quantization: Some(quant_seg),
2451                        shape: vec![1, 2, 160, 160],
2452                        dshape: vec![
2453                            (DimName::Batch, 1),
2454                            (DimName::NumClasses, 2),
2455                            (DimName::Height, 160),
2456                            (DimName::Width, 160),
2457                        ],
2458                    },
2459                )
2460                .with_score_threshold(score_threshold)
2461                .with_iou_threshold(iou_threshold)
2462                .build()
2463                .unwrap();
2464            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2465            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2466
2467            for _ in 0..100 {
2468                decoder
2469                    .decode_quantized(
2470                        &[
2471                            detect0.view().into(),
2472                            detect1.view().into(),
2473                            seg.view().into(),
2474                        ],
2475                        &mut output_boxes,
2476                        &mut output_masks,
2477                    )
2478                    .unwrap();
2479
2480                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2481            }
2482        };
2483
2484        let handles = vec![
2485            std::thread::spawn(yolo_det),
2486            std::thread::spawn(modelpack_det_split),
2487            std::thread::spawn(yolo_det),
2488            std::thread::spawn(modelpack_det_split),
2489            std::thread::spawn(yolo_det),
2490            std::thread::spawn(modelpack_det_split),
2491            std::thread::spawn(yolo_det),
2492            std::thread::spawn(modelpack_det_split),
2493        ];
2494        for handle in handles {
2495            handle.join().unwrap();
2496        }
2497    }
2498
2499    #[test]
2500    fn test_ndarray_to_xyxy_float() {
2501        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2502        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2503        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2504
2505        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2506        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2507        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2508    }
2509
2510    #[test]
2511    fn test_class_aware_nms_float() {
2512        use crate::float::nms_class_aware_float;
2513
2514        // Create two overlapping boxes with different classes
2515        let boxes = vec![
2516            DetectBox {
2517                bbox: BoundingBox {
2518                    xmin: 0.0,
2519                    ymin: 0.0,
2520                    xmax: 0.5,
2521                    ymax: 0.5,
2522                },
2523                score: 0.9,
2524                label: 0, // class 0
2525            },
2526            DetectBox {
2527                bbox: BoundingBox {
2528                    xmin: 0.1,
2529                    ymin: 0.1,
2530                    xmax: 0.6,
2531                    ymax: 0.6,
2532                },
2533                score: 0.8,
2534                label: 1, // class 1 - different class
2535            },
2536        ];
2537
2538        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2539        // threshold 0.3)
2540        let result = nms_class_aware_float(0.3, boxes.clone());
2541        assert_eq!(
2542            result.len(),
2543            2,
2544            "Class-aware NMS should keep both boxes with different classes"
2545        );
2546
2547        // Now test with same class - should suppress one
2548        let same_class_boxes = vec![
2549            DetectBox {
2550                bbox: BoundingBox {
2551                    xmin: 0.0,
2552                    ymin: 0.0,
2553                    xmax: 0.5,
2554                    ymax: 0.5,
2555                },
2556                score: 0.9,
2557                label: 0,
2558            },
2559            DetectBox {
2560                bbox: BoundingBox {
2561                    xmin: 0.1,
2562                    ymin: 0.1,
2563                    xmax: 0.6,
2564                    ymax: 0.6,
2565                },
2566                score: 0.8,
2567                label: 0, // same class
2568            },
2569        ];
2570
2571        let result = nms_class_aware_float(0.3, same_class_boxes);
2572        assert_eq!(
2573            result.len(),
2574            1,
2575            "Class-aware NMS should suppress overlapping box with same class"
2576        );
2577        assert_eq!(result[0].label, 0);
2578        assert!((result[0].score - 0.9).abs() < 1e-6);
2579    }
2580
2581    #[test]
2582    fn test_class_agnostic_vs_aware_nms() {
2583        use crate::float::{nms_class_aware_float, nms_float};
2584
2585        // Two overlapping boxes with different classes
2586        let boxes = vec![
2587            DetectBox {
2588                bbox: BoundingBox {
2589                    xmin: 0.0,
2590                    ymin: 0.0,
2591                    xmax: 0.5,
2592                    ymax: 0.5,
2593                },
2594                score: 0.9,
2595                label: 0,
2596            },
2597            DetectBox {
2598                bbox: BoundingBox {
2599                    xmin: 0.1,
2600                    ymin: 0.1,
2601                    xmax: 0.6,
2602                    ymax: 0.6,
2603                },
2604                score: 0.8,
2605                label: 1,
2606            },
2607        ];
2608
2609        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2610        let agnostic_result = nms_float(0.3, boxes.clone());
2611        assert_eq!(
2612            agnostic_result.len(),
2613            1,
2614            "Class-agnostic NMS should suppress overlapping boxes"
2615        );
2616
2617        // Class-aware should keep both (different classes)
2618        let aware_result = nms_class_aware_float(0.3, boxes);
2619        assert_eq!(
2620            aware_result.len(),
2621            2,
2622            "Class-aware NMS should keep boxes with different classes"
2623        );
2624    }
2625
2626    #[test]
2627    fn test_class_aware_nms_int() {
2628        use crate::byte::nms_class_aware_int;
2629
2630        // Create two overlapping boxes with different classes
2631        let boxes = vec![
2632            DetectBoxQuantized {
2633                bbox: BoundingBox {
2634                    xmin: 0.0,
2635                    ymin: 0.0,
2636                    xmax: 0.5,
2637                    ymax: 0.5,
2638                },
2639                score: 200_u8,
2640                label: 0,
2641            },
2642            DetectBoxQuantized {
2643                bbox: BoundingBox {
2644                    xmin: 0.1,
2645                    ymin: 0.1,
2646                    xmax: 0.6,
2647                    ymax: 0.6,
2648                },
2649                score: 180_u8,
2650                label: 1, // different class
2651            },
2652        ];
2653
2654        // Should keep both (different classes)
2655        let result = nms_class_aware_int(0.5, boxes);
2656        assert_eq!(
2657            result.len(),
2658            2,
2659            "Class-aware NMS (int) should keep boxes with different classes"
2660        );
2661    }
2662
2663    #[test]
2664    fn test_nms_enum_default() {
2665        // Test that Nms enum has the correct default
2666        let default_nms: configs::Nms = Default::default();
2667        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2668    }
2669
2670    #[test]
2671    fn test_decoder_nms_mode() {
2672        // Test that decoder properly stores NMS mode
2673        let decoder = DecoderBuilder::default()
2674            .with_config_yolo_det(
2675                configs::Detection {
2676                    anchors: None,
2677                    decoder: DecoderType::Ultralytics,
2678                    quantization: None,
2679                    shape: vec![1, 84, 8400],
2680                    dshape: Vec::new(),
2681                    normalized: Some(true),
2682                },
2683                None,
2684            )
2685            .with_nms(Some(configs::Nms::ClassAware))
2686            .build()
2687            .unwrap();
2688
2689        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2690    }
2691
2692    #[test]
2693    fn test_decoder_nms_bypass() {
2694        // Test that decoder can be configured with nms=None (bypass)
2695        let decoder = DecoderBuilder::default()
2696            .with_config_yolo_det(
2697                configs::Detection {
2698                    anchors: None,
2699                    decoder: DecoderType::Ultralytics,
2700                    quantization: None,
2701                    shape: vec![1, 84, 8400],
2702                    dshape: Vec::new(),
2703                    normalized: Some(true),
2704                },
2705                None,
2706            )
2707            .with_nms(None)
2708            .build()
2709            .unwrap();
2710
2711        assert_eq!(decoder.nms, None);
2712    }
2713
2714    #[test]
2715    fn test_decoder_normalized_boxes_true() {
2716        // Test that normalized_boxes returns Some(true) when explicitly set
2717        let decoder = DecoderBuilder::default()
2718            .with_config_yolo_det(
2719                configs::Detection {
2720                    anchors: None,
2721                    decoder: DecoderType::Ultralytics,
2722                    quantization: None,
2723                    shape: vec![1, 84, 8400],
2724                    dshape: Vec::new(),
2725                    normalized: Some(true),
2726                },
2727                None,
2728            )
2729            .build()
2730            .unwrap();
2731
2732        assert_eq!(decoder.normalized_boxes(), Some(true));
2733    }
2734
2735    #[test]
2736    fn test_decoder_normalized_boxes_false() {
2737        // Test that normalized_boxes returns Some(false) when config specifies
2738        // unnormalized
2739        let decoder = DecoderBuilder::default()
2740            .with_config_yolo_det(
2741                configs::Detection {
2742                    anchors: None,
2743                    decoder: DecoderType::Ultralytics,
2744                    quantization: None,
2745                    shape: vec![1, 84, 8400],
2746                    dshape: Vec::new(),
2747                    normalized: Some(false),
2748                },
2749                None,
2750            )
2751            .build()
2752            .unwrap();
2753
2754        assert_eq!(decoder.normalized_boxes(), Some(false));
2755    }
2756
2757    #[test]
2758    fn test_decoder_normalized_boxes_unknown() {
2759        // Test that normalized_boxes returns None when not specified in config
2760        let decoder = DecoderBuilder::default()
2761            .with_config_yolo_det(
2762                configs::Detection {
2763                    anchors: None,
2764                    decoder: DecoderType::Ultralytics,
2765                    quantization: None,
2766                    shape: vec![1, 84, 8400],
2767                    dshape: Vec::new(),
2768                    normalized: None,
2769                },
2770                Some(DecoderVersion::Yolo11),
2771            )
2772            .build()
2773            .unwrap();
2774
2775        assert_eq!(decoder.normalized_boxes(), None);
2776    }
2777}