Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
6This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
7It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
8for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust
16# use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs::{self, DecoderVersion} };
17# fn main() -> DecoderResult<()> {
18// Create a decoder for a YOLOv8 model with quantized int8 output with 0.25 score threshold and 0.7 IOU threshold
19let decoder = DecoderBuilder::new()
20    .with_config_yolo_det(configs::Detection {
21        anchors: None,
22        decoder: configs::DecoderType::Ultralytics,
23        quantization: Some(configs::QuantTuple(0.012345, 26)),
24        shape: vec![1, 84, 8400],
25        dshape: Vec::new(),
26        normalized: Some(true),
27    },
28    Some(DecoderVersion::Yolov8))
29    .with_score_threshold(0.25)
30    .with_iou_threshold(0.7)
31    .build()?;
32
33// Get the model output from the model. Here we load it from a test data file for demonstration purposes.
34let model_output: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
35    .iter()
36    .map(|b| *b as i8)
37    .collect();
38let model_output_array = ndarray::Array3::from_shape_vec((1, 84, 8400), model_output)?;
39
40// The capacity is used to determine the maximum number of detections to decode.
41let mut output_boxes: Vec<_> = Vec::with_capacity(10);
42let mut output_masks: Vec<_> = Vec::with_capacity(10);
43
44// Decode the quantized model output into detection boxes and segmentation masks
45// Because this model is a detection-only model, the `output_masks` vector will remain empty.
46decoder.decode_quantized(&[model_output_array.view().into()], &mut output_boxes, &mut output_masks)?;
47# Ok(())
48# }
49```
50
51# Overview
52
53The primary components of this crate are:
54- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
55- `yolo` module: Contains functions specific to decoding YOLO model outputs.
56- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
57
58The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
59It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
60When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
61If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
62the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
63to not dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
64
65The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
66which can be used if the model type and output formats are known in advance.
67
68
69*/
70#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Affine quantization parameters of a tensor.
///
/// The dequantization helpers in this crate compute
/// `(value - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been subtracted.
    pub scale: f32,
    /// Quantized value that corresponds to a real value of zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from a scale and a zero point.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247    /// Creates a new Quantization struct from a QuantTuple
248    /// # Examples
249    /// ```
250    /// # use edgefirst_decoder::Quantization;
251    /// # use edgefirst_decoder::configs::QuantTuple;
252    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
253    /// let quant = Quantization::from(quant_tuple);
254    /// assert_eq!(quant.scale, 0.1);
255    /// assert_eq!(quant.zero_point, -128);
256    /// ```
257    fn from(quant_tuple: QuantTuple) -> Quantization {
258        Quantization {
259            scale: quant_tuple.0,
260            zero_point: quant_tuple.1,
261        }
262    }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267    S: AsPrimitive<f32>,
268    Z: AsPrimitive<i32>,
269{
270    /// Creates a new Quantization struct from a tuple
271    /// # Examples
272    /// ```
273    /// # use edgefirst_decoder::Quantization;
274    /// let quant = Quantization::from((0.1_f64, -128_i64));
275    /// assert_eq!(quant.scale, 0.1);
276    /// assert_eq!(quant.zero_point, -128);
277    /// ```
278    fn from((scale, zp): (S, Z)) -> Quantization {
279        Self {
280            scale: scale.as_(),
281            zero_point: zp.as_(),
282        }
283    }
284}
285
286impl Default for Quantization {
287    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::Quantization;
291    /// let quant = Quantization::default();
292    /// assert_eq!(quant.scale, 1.0);
293    /// assert_eq!(quant.zero_point, 0);
294    /// ```
295    fn default() -> Self {
296        Self {
297            scale: 1.0,
298            zero_point: 0,
299        }
300    }
301}
302
303/// A detection box with f32 bbox and score
304#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306    pub bbox: BoundingBox,
307    /// model-specific score for this detection, higher implies more confidence
308    pub score: f32,
309    /// label index for this detection
310    pub label: usize,
311}
312
/// An axis-aligned bounding box with `f32` coordinates in XYXY format.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left-most normalized coordinate of the bounding box.
    pub xmin: f32,
    /// Top-most normalized coordinate of the bounding box.
    pub ymin: f32,
    /// Right-most normalized coordinate of the bounding box.
    pub xmax: f32,
    /// Bottom-most normalized coordinate of the bounding box.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a `BoundingBox` from its four coordinates. The values are
    /// stored as given; no ordering is enforced.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy of the box in which `xmin <= xmax` and `ymin <= ymax`,
    /// swapping the coordinates of an axis when they are reversed.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (x_low, x_high) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_low, y_high) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_low, y_low, x_high, y_high)
    }
}
358
359impl From<BoundingBox> for [f32; 4] {
360    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
361    /// xmax, ymax order
362    /// # Examples
363    /// ```
364    /// # use edgefirst_decoder::BoundingBox;
365    /// let bbox = BoundingBox {
366    ///     xmin: 0.1,
367    ///     ymin: 0.2,
368    ///     xmax: 0.3,
369    ///     ymax: 0.4,
370    /// };
371    /// let arr: [f32; 4] = bbox.into();
372    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
373    /// ```
374    fn from(b: BoundingBox) -> Self {
375        [b.xmin, b.ymin, b.xmax, b.ymax]
376    }
377}
378
379impl From<[f32; 4]> for BoundingBox {
380    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
381    // BoundingBox
382    fn from(arr: [f32; 4]) -> Self {
383        BoundingBox {
384            xmin: arr[0],
385            ymin: arr[1],
386            xmax: arr[2],
387            ymax: arr[3],
388        }
389    }
390}
391
392impl DetectBox {
393    /// Returns true if one detect box is equal to another detect box, within
394    /// the given `eps`
395    ///
396    /// # Examples
397    /// ```
398    /// # use edgefirst_decoder::DetectBox;
399    /// let box1 = DetectBox {
400    ///     bbox: edgefirst_decoder::BoundingBox {
401    ///         xmin: 0.1,
402    ///         ymin: 0.2,
403    ///         xmax: 0.3,
404    ///         ymax: 0.4,
405    ///     },
406    ///     score: 0.5,
407    ///     label: 1,
408    /// };
409    /// let box2 = DetectBox {
410    ///     bbox: edgefirst_decoder::BoundingBox {
411    ///         xmin: 0.101,
412    ///         ymin: 0.199,
413    ///         xmax: 0.301,
414    ///         ymax: 0.399,
415    ///     },
416    ///     score: 0.510,
417    ///     label: 1,
418    /// };
419    /// assert!(box1.equal_within_delta(&box2, 0.011));
420    /// ```
421    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423        self.label == rhs.label
424            && eq_delta(self.score, rhs.score)
425            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429    }
430}
431
432/// A segmentation result with a segmentation mask, and a normalized bounding
433/// box representing the area that the segmentation mask covers
434#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436    /// left-most normalized coordinate of the segmentation box
437    pub xmin: f32,
438    /// top-most normalized coordinate of the segmentation box
439    pub ymin: f32,
440    /// right-most normalized coordinate of the segmentation box
441    pub xmax: f32,
442    /// bottom-most normalized coordinate of the segmentation box
443    pub ymax: f32,
444    /// 3D segmentation array of shape `(H, W, C)`.
445    ///
446    /// For instance segmentation (e.g. YOLO): `C=1` — binary per-instance
447    /// mask where values >= 128 indicate object presence.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
454/// Prototype tensor variants for fused decode+render pipelines.
455///
456/// Carries either raw quantized data (to skip CPU dequantization and let the
457/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
458/// paths).
459#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
462    /// The GPU fragment shader will dequantize per-texel using the scale and
463    /// zero_point.
464    Quantized {
465        protos: Array3<i8>,
466        quantization: Quantization,
467    },
468    /// Dequantized f32 protos (from float models or legacy path).
469    Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473    /// Returns `true` if this is the quantized variant.
474    pub fn is_quantized(&self) -> bool {
475        matches!(self, ProtoTensor::Quantized { .. })
476    }
477
478    /// Returns the spatial dimensions `(height, width, num_protos)`.
479    pub fn dim(&self) -> (usize, usize, usize) {
480        match self {
481            ProtoTensor::Quantized { protos, .. } => protos.dim(),
482            ProtoTensor::Float(arr) => arr.dim(),
483        }
484    }
485
486    /// Returns dequantized f32 protos. For the `Float` variant this is a
487    /// no-copy reference; for `Quantized` it allocates and dequantizes.
488    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489        match self {
490            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491            ProtoTensor::Quantized {
492                protos,
493                quantization,
494            } => {
495                let scale = quantization.scale;
496                let zp = quantization.zero_point as f32;
497                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498            }
499        }
500    }
501}
502
503/// Raw prototype data for fused decode+render pipelines.
504///
505/// Holds post-NMS intermediate state before mask materialization, allowing the
506/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
507/// shader) without materializing intermediate `Array3<u8>` masks.
508#[derive(Debug, Clone)]
509pub struct ProtoData {
510    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
511    pub mask_coefficients: Vec<Vec<f32>>,
512    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
513    pub protos: ProtoTensor,
514}
515
516/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
517///
518///  # Examples
519/// ```
520/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
521/// let quant = Quantization::new(0.1, -128);
522/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
523/// let detect_quant = DetectBoxQuantized {
524///     bbox,
525///     score: 100_i8,
526///     label: 1,
527/// };
528/// let detect = dequant_detect_box(&detect_quant, quant);
529/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
530/// assert_eq!(detect.label, 1);
531/// assert_eq!(detect.bbox, bbox);
532/// ```
533pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534    detect: &DetectBoxQuantized<SCORE>,
535    quant_scores: Quantization,
536) -> DetectBox {
537    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538    DetectBox {
539        bbox: detect.bbox,
540        score: quant_scores.scale * detect.score.as_() + scaled_zp,
541        label: detect.label,
542    }
543}
544/// A detection box with a f32 bbox and quantized score
545#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547    // BOX: Signed + PrimInt + AsPrimitive<f32>,
548    SCORE: PrimInt + AsPrimitive<f32>,
549> {
550    // pub bbox: BoundingBoxQuantized<BOX>,
551    pub bbox: BoundingBox,
552    /// model-specific score for this detection, higher implies more
553    /// confidence.
554    pub score: SCORE,
555    /// label index for this detect
556    pub label: usize,
557}
558
559/// Dequantizes an ndarray from quantized values to f32 values using the given
560/// quantization parameters
561///
562/// # Examples
563/// ```
564/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
565/// let quant = Quantization::new(0.1, -128);
566/// let input: Vec<i8> = vec![0, 127, -128, 64];
567/// let input_array = ndarray::Array1::from(input);
568/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
569/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
570/// ```
571pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572    input: ArrayView<T, D>,
573    quant: Quantization,
574) -> Array<F, D>
575where
576    i32: num_traits::AsPrimitive<F>,
577    f32: num_traits::AsPrimitive<F>,
578{
579    let zero_point = quant.zero_point.as_();
580    let scale = quant.scale.as_();
581    if zero_point != F::zero() {
582        let scaled_zero = -zero_point * scale;
583        input.mapv(|d| d.as_() * scale + scaled_zero)
584    } else {
585        input.mapv(|d| d.as_() * scale)
586    }
587}
588
589/// Dequantizes a slice from quantized values to float values using the given
590/// quantization parameters
591///
592/// # Examples
593/// ```
594/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
595/// let quant = Quantization::new(0.1, -128);
596/// let input: Vec<i8> = vec![0, 127, -128, 64];
597/// let mut output: Vec<f32> = vec![0.0; input.len()];
598/// dequantize_cpu(&input, quant, &mut output);
599/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
600/// ```
601pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602    input: &[T],
603    quant: Quantization,
604    output: &mut [F],
605) where
606    f32: num_traits::AsPrimitive<F>,
607    i32: num_traits::AsPrimitive<F>,
608{
609    assert!(input.len() == output.len());
610    let zero_point = quant.zero_point.as_();
611    let scale = quant.scale.as_();
612    if zero_point != F::zero() {
613        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
614        input
615            .iter()
616            .zip(output)
617            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618    } else {
619        input
620            .iter()
621            .zip(output)
622            .for_each(|(d, deq)| *deq = d.as_() * scale);
623    }
624}
625
626/// Dequantizes a slice from quantized values to float values using the given
627/// quantization parameters, using chunked processing. This is around 5% faster
628/// than `dequantize_cpu` for large slices.
629///
630/// # Examples
631/// ```
632/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
633/// let quant = Quantization::new(0.1, -128);
634/// let input: Vec<i8> = vec![0, 127, -128, 64];
635/// let mut output: Vec<f32> = vec![0.0; input.len()];
636/// dequantize_cpu_chunked(&input, quant, &mut output);
637/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
638/// ```
639pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640    input: &[T],
641    quant: Quantization,
642    output: &mut [F],
643) where
644    f32: num_traits::AsPrimitive<F>,
645    i32: num_traits::AsPrimitive<F>,
646{
647    assert!(input.len() == output.len());
648    let zero_point = quant.zero_point.as_();
649    let scale = quant.scale.as_();
650
651    let input = input.as_chunks::<4>();
652    let output = output.as_chunks_mut::<4>();
653
654    if zero_point != F::zero() {
655        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
656
657        input
658            .0
659            .iter()
660            .zip(output.0)
661            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662        input
663            .1
664            .iter()
665            .zip(output.1)
666            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667    } else {
668        input
669            .0
670            .iter()
671            .zip(output.0)
672            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673        input
674            .1
675            .iter()
676            .zip(output.1)
677            .for_each(|(d, deq)| *deq = d.as_() * scale);
678    }
679}
680
681/// Converts a segmentation tensor into a 2D mask
682/// If the last dimension of the segmentation tensor is 1, values equal or
683/// above 128 are considered objects. Otherwise the object is the argmax index
684///
685/// # Errors
686///
687/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
688/// invalid shape.
689///
690/// # Examples
691/// ```
692/// # use edgefirst_decoder::segmentation_to_mask;
693/// let segmentation =
694///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
695/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
696/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
697/// ```
698pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699    if segmentation.shape()[2] == 0 {
700        return Err(DecoderError::InvalidShape(
701            "Segmentation tensor must have non-zero depth".to_string(),
702        ));
703    }
704    if segmentation.shape()[2] == 1 {
705        yolo_segmentation_to_mask(segmentation, 128)
706    } else {
707        Ok(modelpack_segmentation_to_mask(segmentation))
708    }
709}
710
711/// Returns the maximum value and its index from a 1D array
712fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713    score
714        .iter()
715        .enumerate()
716        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717            if max > *s {
718                (max, arg_max)
719            } else {
720                (*s, ind)
721            }
722        })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727    #![allow(clippy::excessive_precision)]
728    use crate::{
729        configs::{DecoderType, DimName, Protos},
730        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731        yolo::{
732            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733            decode_yolo_segdet_quant,
734        },
735        *,
736    };
737    use ndarray::{array, s, Array4};
738    use ndarray_stats::DeviationExt;
739
740    fn compare_outputs(
741        boxes: (&[DetectBox], &[DetectBox]),
742        masks: (&[Segmentation], &[Segmentation]),
743    ) {
744        let (boxes0, boxes1) = boxes;
745        let (masks0, masks1) = masks;
746
747        assert_eq!(boxes0.len(), boxes1.len());
748        assert_eq!(masks0.len(), masks1.len());
749
750        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
751            assert!(
752                b_i8.equal_within_delta(b_f32, 1e-6),
753                "{b_i8:?} is not equal to {b_f32:?}"
754            );
755        }
756
757        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
758            assert_eq!(
759                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
760                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
761            );
762            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
763            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
764            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
765            let diff = &mask_i8 - &mask_f32;
766            for x in 0..diff.shape()[0] {
767                for y in 0..diff.shape()[1] {
768                    for z in 0..diff.shape()[2] {
769                        let val = diff[[x, y, z]];
770                        assert!(
771                            val.abs() <= 1,
772                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
773                            x,
774                            y,
775                            z,
776                            val
777                        );
778                    }
779                }
780            }
781            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
782            assert!(
783                mean_sq_err < 1e-2,
784                "Mean Square Error between masks was greater than 1%: {:.2}%",
785                mean_sq_err * 100.0
786            );
787        }
788    }
789
790    #[test]
791    fn test_decoder_modelpack() {
792        let score_threshold = 0.45;
793        let iou_threshold = 0.45;
794        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
795        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
796
797        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
798        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
799
800        let quant_boxes = (0.004656755365431309, 21).into();
801        let quant_scores = (0.0019603664986789227, 0).into();
802
803        let decoder = DecoderBuilder::default()
804            .with_config_modelpack_det(
805                configs::Boxes {
806                    decoder: DecoderType::ModelPack,
807                    quantization: Some(quant_boxes),
808                    shape: vec![1, 1935, 1, 4],
809                    dshape: vec![
810                        (DimName::Batch, 1),
811                        (DimName::NumBoxes, 1935),
812                        (DimName::Padding, 1),
813                        (DimName::BoxCoords, 4),
814                    ],
815                    normalized: Some(true),
816                },
817                configs::Scores {
818                    decoder: DecoderType::ModelPack,
819                    quantization: Some(quant_scores),
820                    shape: vec![1, 1935, 1],
821                    dshape: vec![
822                        (DimName::Batch, 1),
823                        (DimName::NumBoxes, 1935),
824                        (DimName::NumClasses, 1),
825                    ],
826                },
827            )
828            .with_score_threshold(score_threshold)
829            .with_iou_threshold(iou_threshold)
830            .build()
831            .unwrap();
832
833        let quant_boxes = quant_boxes.into();
834        let quant_scores = quant_scores.into();
835
836        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
837        decode_modelpack_det(
838            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
839            (scores.slice(s![0, .., ..]), quant_scores),
840            score_threshold,
841            iou_threshold,
842            &mut output_boxes,
843        );
844        assert!(output_boxes[0].equal_within_delta(
845            &DetectBox {
846                bbox: BoundingBox {
847                    xmin: 0.40513772,
848                    ymin: 0.6379755,
849                    xmax: 0.5122431,
850                    ymax: 0.7730214,
851                },
852                score: 0.4861709,
853                label: 0
854            },
855            1e-6
856        ));
857
858        let mut output_boxes1 = Vec::with_capacity(50);
859        let mut output_masks1 = Vec::with_capacity(50);
860
861        decoder
862            .decode_quantized(
863                &[boxes.view().into(), scores.view().into()],
864                &mut output_boxes1,
865                &mut output_masks1,
866            )
867            .unwrap();
868
869        let mut output_boxes_float = Vec::with_capacity(50);
870        let mut output_masks_float = Vec::with_capacity(50);
871
872        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
873        let scores = dequantize_ndarray(scores.view(), quant_scores);
874
875        decoder
876            .decode_float::<f32>(
877                &[boxes.view().into_dyn(), scores.view().into_dyn()],
878                &mut output_boxes_float,
879                &mut output_masks_float,
880            )
881            .unwrap();
882
883        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
884        compare_outputs(
885            (&output_boxes, &output_boxes_float),
886            (&[], &output_masks_float),
887        );
888    }
889
890    #[test]
891    fn test_decoder_modelpack_split_u8() {
892        let score_threshold = 0.45;
893        let iou_threshold = 0.45;
894        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
895        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
896
897        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
898        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
899
900        let quant0 = (0.08547406643629074, 174).into();
901        let quant1 = (0.09929127991199493, 183).into();
902        let anchors0 = vec![
903            [0.36666667461395264, 0.31481480598449707],
904            [0.38749998807907104, 0.4740740656852722],
905            [0.5333333611488342, 0.644444465637207],
906        ];
907        let anchors1 = vec![
908            [0.13750000298023224, 0.2074074000120163],
909            [0.2541666626930237, 0.21481481194496155],
910            [0.23125000298023224, 0.35185185074806213],
911        ];
912
913        let detect_config0 = configs::Detection {
914            decoder: DecoderType::ModelPack,
915            shape: vec![1, 9, 15, 18],
916            anchors: Some(anchors0.clone()),
917            quantization: Some(quant0),
918            dshape: vec![
919                (DimName::Batch, 1),
920                (DimName::Height, 9),
921                (DimName::Width, 15),
922                (DimName::NumAnchorsXFeatures, 18),
923            ],
924            normalized: Some(true),
925        };
926
927        let detect_config1 = configs::Detection {
928            decoder: DecoderType::ModelPack,
929            shape: vec![1, 17, 30, 18],
930            anchors: Some(anchors1.clone()),
931            quantization: Some(quant1),
932            dshape: vec![
933                (DimName::Batch, 1),
934                (DimName::Height, 17),
935                (DimName::Width, 30),
936                (DimName::NumAnchorsXFeatures, 18),
937            ],
938            normalized: Some(true),
939        };
940
941        let config0 = (&detect_config0).try_into().unwrap();
942        let config1 = (&detect_config1).try_into().unwrap();
943
944        let decoder = DecoderBuilder::default()
945            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
946            .with_score_threshold(score_threshold)
947            .with_iou_threshold(iou_threshold)
948            .build()
949            .unwrap();
950
951        let quant0 = quant0.into();
952        let quant1 = quant1.into();
953
954        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
955        decode_modelpack_split_quant(
956            &[
957                detect0.slice(s![0, .., .., ..]),
958                detect1.slice(s![0, .., .., ..]),
959            ],
960            &[config0, config1],
961            score_threshold,
962            iou_threshold,
963            &mut output_boxes,
964        );
965        assert!(output_boxes[0].equal_within_delta(
966            &DetectBox {
967                bbox: BoundingBox {
968                    xmin: 0.43171933,
969                    ymin: 0.68243736,
970                    xmax: 0.5626645,
971                    ymax: 0.808863,
972                },
973                score: 0.99240804,
974                label: 0
975            },
976            1e-6
977        ));
978
979        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
980        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
981        decoder
982            .decode_quantized(
983                &[detect0.view().into(), detect1.view().into()],
984                &mut output_boxes1,
985                &mut output_masks1,
986            )
987            .unwrap();
988
989        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
990        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
991
992        let detect0 = dequantize_ndarray(detect0.view(), quant0);
993        let detect1 = dequantize_ndarray(detect1.view(), quant1);
994        decoder
995            .decode_float::<f32>(
996                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
997                &mut output_boxes1_f32,
998                &mut output_masks1_f32,
999            )
1000            .unwrap();
1001
1002        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1003        compare_outputs(
1004            (&output_boxes, &output_boxes1_f32),
1005            (&[], &output_masks1_f32),
1006        );
1007    }
1008
1009    #[test]
1010    fn test_decoder_parse_config_modelpack_split_u8() {
1011        let score_threshold = 0.45;
1012        let iou_threshold = 0.45;
1013        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1014        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1015
1016        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1017        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1018
1019        let decoder = DecoderBuilder::default()
1020            .with_config_yaml_str(
1021                include_str!("../../../testdata/modelpack_split.yaml").to_string(),
1022            )
1023            .with_score_threshold(score_threshold)
1024            .with_iou_threshold(iou_threshold)
1025            .build()
1026            .unwrap();
1027
1028        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1029        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1030        decoder
1031            .decode_quantized(
1032                &[
1033                    ArrayViewDQuantized::from(detect1.view()),
1034                    ArrayViewDQuantized::from(detect0.view()),
1035                ],
1036                &mut output_boxes,
1037                &mut output_masks,
1038            )
1039            .unwrap();
1040        assert!(output_boxes[0].equal_within_delta(
1041            &DetectBox {
1042                bbox: BoundingBox {
1043                    xmin: 0.43171933,
1044                    ymin: 0.68243736,
1045                    xmax: 0.5626645,
1046                    ymax: 0.808863,
1047                },
1048                score: 0.99240804,
1049                label: 0
1050            },
1051            1e-6
1052        ));
1053    }
1054
1055    #[test]
1056    fn test_modelpack_seg() {
1057        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1058        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1059        let quant = (1.0 / 255.0, 0).into();
1060
1061        let decoder = DecoderBuilder::default()
1062            .with_config_modelpack_seg(configs::Segmentation {
1063                decoder: DecoderType::ModelPack,
1064                quantization: Some(quant),
1065                shape: vec![1, 2, 160, 160],
1066                dshape: vec![
1067                    (DimName::Batch, 1),
1068                    (DimName::NumClasses, 2),
1069                    (DimName::Height, 160),
1070                    (DimName::Width, 160),
1071                ],
1072            })
1073            .build()
1074            .unwrap();
1075        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1076        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1077        decoder
1078            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1079            .unwrap();
1080
1081        let mut mask = out.slice(s![0, .., .., ..]);
1082        mask.swap_axes(0, 1);
1083        mask.swap_axes(1, 2);
1084        let mask = [Segmentation {
1085            xmin: 0.0,
1086            ymin: 0.0,
1087            xmax: 1.0,
1088            ymax: 1.0,
1089            segmentation: mask.into_owned(),
1090        }];
1091        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1092
1093        decoder
1094            .decode_float::<f32>(
1095                &[dequantize_ndarray(out.view(), quant.into())
1096                    .view()
1097                    .into_dyn()],
1098                &mut output_boxes,
1099                &mut output_masks,
1100            )
1101            .unwrap();
1102
1103        // not expected for float decoder to have same values as quantized decoder, as
1104        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1105        // the model output. Thus the float output is the same as the quantized output
1106        // but scaled differently. However, it is expected that the mask after argmax
1107        // will be the same.
1108        compare_outputs((&[], &output_boxes), (&[], &[]));
1109        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1110        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1111
1112        assert_eq!(mask0, mask1);
1113    }
1114    #[test]
1115    fn test_modelpack_seg_quant() {
1116        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1117        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1118        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1119        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1120        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1121        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1122        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1123
1124        let quant = (1.0 / 255.0, 0).into();
1125
1126        let decoder = DecoderBuilder::default()
1127            .with_config_modelpack_seg(configs::Segmentation {
1128                decoder: DecoderType::ModelPack,
1129                quantization: Some(quant),
1130                shape: vec![1, 2, 160, 160],
1131                dshape: vec![
1132                    (DimName::Batch, 1),
1133                    (DimName::NumClasses, 2),
1134                    (DimName::Height, 160),
1135                    (DimName::Width, 160),
1136                ],
1137            })
1138            .build()
1139            .unwrap();
1140        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1141        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1142        decoder
1143            .decode_quantized(
1144                &[out_u8.view().into()],
1145                &mut output_boxes,
1146                &mut output_masks_u8,
1147            )
1148            .unwrap();
1149
1150        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1151        decoder
1152            .decode_quantized(
1153                &[out_i8.view().into()],
1154                &mut output_boxes,
1155                &mut output_masks_i8,
1156            )
1157            .unwrap();
1158
1159        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1160        decoder
1161            .decode_quantized(
1162                &[out_u16.view().into()],
1163                &mut output_boxes,
1164                &mut output_masks_u16,
1165            )
1166            .unwrap();
1167
1168        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1169        decoder
1170            .decode_quantized(
1171                &[out_i16.view().into()],
1172                &mut output_boxes,
1173                &mut output_masks_i16,
1174            )
1175            .unwrap();
1176
1177        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1178        decoder
1179            .decode_quantized(
1180                &[out_u32.view().into()],
1181                &mut output_boxes,
1182                &mut output_masks_u32,
1183            )
1184            .unwrap();
1185
1186        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1187        decoder
1188            .decode_quantized(
1189                &[out_i32.view().into()],
1190                &mut output_boxes,
1191                &mut output_masks_i32,
1192            )
1193            .unwrap();
1194
1195        compare_outputs((&[], &output_boxes), (&[], &[]));
1196        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1197        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1198        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1199        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1200        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1201        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1202        assert_eq!(mask_u8, mask_i8);
1203        assert_eq!(mask_u8, mask_u16);
1204        assert_eq!(mask_u8, mask_i16);
1205        assert_eq!(mask_u8, mask_u32);
1206        assert_eq!(mask_u8, mask_i32);
1207    }
1208
1209    #[test]
1210    fn test_modelpack_segdet() {
1211        let score_threshold = 0.45;
1212        let iou_threshold = 0.45;
1213
1214        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
1215        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1216
1217        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
1218        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1219
1220        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1221        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1222
1223        let quant_boxes = (0.004656755365431309, 21).into();
1224        let quant_scores = (0.0019603664986789227, 0).into();
1225        let quant_seg = (1.0 / 255.0, 0).into();
1226
1227        let decoder = DecoderBuilder::default()
1228            .with_config_modelpack_segdet(
1229                configs::Boxes {
1230                    decoder: DecoderType::ModelPack,
1231                    quantization: Some(quant_boxes),
1232                    shape: vec![1, 1935, 1, 4],
1233                    dshape: vec![
1234                        (DimName::Batch, 1),
1235                        (DimName::NumBoxes, 1935),
1236                        (DimName::Padding, 1),
1237                        (DimName::BoxCoords, 4),
1238                    ],
1239                    normalized: Some(true),
1240                },
1241                configs::Scores {
1242                    decoder: DecoderType::ModelPack,
1243                    quantization: Some(quant_scores),
1244                    shape: vec![1, 1935, 1],
1245                    dshape: vec![
1246                        (DimName::Batch, 1),
1247                        (DimName::NumBoxes, 1935),
1248                        (DimName::NumClasses, 1),
1249                    ],
1250                },
1251                configs::Segmentation {
1252                    decoder: DecoderType::ModelPack,
1253                    quantization: Some(quant_seg),
1254                    shape: vec![1, 2, 160, 160],
1255                    dshape: vec![
1256                        (DimName::Batch, 1),
1257                        (DimName::NumClasses, 2),
1258                        (DimName::Height, 160),
1259                        (DimName::Width, 160),
1260                    ],
1261                },
1262            )
1263            .with_iou_threshold(iou_threshold)
1264            .with_score_threshold(score_threshold)
1265            .build()
1266            .unwrap();
1267        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1268        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1269        decoder
1270            .decode_quantized(
1271                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1272                &mut output_boxes,
1273                &mut output_masks,
1274            )
1275            .unwrap();
1276
1277        let mut mask = seg.slice(s![0, .., .., ..]);
1278        mask.swap_axes(0, 1);
1279        mask.swap_axes(1, 2);
1280        let mask = [Segmentation {
1281            xmin: 0.0,
1282            ymin: 0.0,
1283            xmax: 1.0,
1284            ymax: 1.0,
1285            segmentation: mask.into_owned(),
1286        }];
1287        let correct_boxes = [DetectBox {
1288            bbox: BoundingBox {
1289                xmin: 0.40513772,
1290                ymin: 0.6379755,
1291                xmax: 0.5122431,
1292                ymax: 0.7730214,
1293            },
1294            score: 0.4861709,
1295            label: 0,
1296        }];
1297        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1298
1299        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1300        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1301        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1302        decoder
1303            .decode_float::<f32>(
1304                &[
1305                    scores.view().into_dyn(),
1306                    boxes.view().into_dyn(),
1307                    seg.view().into_dyn(),
1308                ],
1309                &mut output_boxes,
1310                &mut output_masks,
1311            )
1312            .unwrap();
1313
1314        // not expected for float segmentation decoder to have same values as quantized
1315        // segmentation decoder, as float decoder ensures the data fills 0-255,
1316        // quantized decoder uses whatever the model output. Thus the float
1317        // output is the same as the quantized output but scaled differently.
1318        // However, it is expected that the mask after argmax will be the same.
1319        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1320        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1321        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1322
1323        assert_eq!(mask0, mask1);
1324    }
1325
1326    #[test]
1327    fn test_modelpack_segdet_split() {
1328        let score_threshold = 0.8;
1329        let iou_threshold = 0.5;
1330
1331        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1332        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1333
1334        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1335        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1336
1337        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1338        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1339
1340        let quant0 = (0.08547406643629074, 174).into();
1341        let quant1 = (0.09929127991199493, 183).into();
1342        let quant_seg = (1.0 / 255.0, 0).into();
1343
1344        let anchors0 = vec![
1345            [0.36666667461395264, 0.31481480598449707],
1346            [0.38749998807907104, 0.4740740656852722],
1347            [0.5333333611488342, 0.644444465637207],
1348        ];
1349        let anchors1 = vec![
1350            [0.13750000298023224, 0.2074074000120163],
1351            [0.2541666626930237, 0.21481481194496155],
1352            [0.23125000298023224, 0.35185185074806213],
1353        ];
1354
1355        let decoder = DecoderBuilder::default()
1356            .with_config_modelpack_segdet_split(
1357                vec![
1358                    configs::Detection {
1359                        decoder: DecoderType::ModelPack,
1360                        shape: vec![1, 17, 30, 18],
1361                        anchors: Some(anchors1),
1362                        quantization: Some(quant1),
1363                        dshape: vec![
1364                            (DimName::Batch, 1),
1365                            (DimName::Height, 17),
1366                            (DimName::Width, 30),
1367                            (DimName::NumAnchorsXFeatures, 18),
1368                        ],
1369                        normalized: Some(true),
1370                    },
1371                    configs::Detection {
1372                        decoder: DecoderType::ModelPack,
1373                        shape: vec![1, 9, 15, 18],
1374                        anchors: Some(anchors0),
1375                        quantization: Some(quant0),
1376                        dshape: vec![
1377                            (DimName::Batch, 1),
1378                            (DimName::Height, 9),
1379                            (DimName::Width, 15),
1380                            (DimName::NumAnchorsXFeatures, 18),
1381                        ],
1382                        normalized: Some(true),
1383                    },
1384                ],
1385                configs::Segmentation {
1386                    decoder: DecoderType::ModelPack,
1387                    quantization: Some(quant_seg),
1388                    shape: vec![1, 2, 160, 160],
1389                    dshape: vec![
1390                        (DimName::Batch, 1),
1391                        (DimName::NumClasses, 2),
1392                        (DimName::Height, 160),
1393                        (DimName::Width, 160),
1394                    ],
1395                },
1396            )
1397            .with_score_threshold(score_threshold)
1398            .with_iou_threshold(iou_threshold)
1399            .build()
1400            .unwrap();
1401        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1402        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1403        decoder
1404            .decode_quantized(
1405                &[
1406                    detect0.view().into(),
1407                    detect1.view().into(),
1408                    seg.view().into(),
1409                ],
1410                &mut output_boxes,
1411                &mut output_masks,
1412            )
1413            .unwrap();
1414
1415        let mut mask = seg.slice(s![0, .., .., ..]);
1416        mask.swap_axes(0, 1);
1417        mask.swap_axes(1, 2);
1418        let mask = [Segmentation {
1419            xmin: 0.0,
1420            ymin: 0.0,
1421            xmax: 1.0,
1422            ymax: 1.0,
1423            segmentation: mask.into_owned(),
1424        }];
1425        let correct_boxes = [DetectBox {
1426            bbox: BoundingBox {
1427                xmin: 0.43171933,
1428                ymin: 0.68243736,
1429                xmax: 0.5626645,
1430                ymax: 0.808863,
1431            },
1432            score: 0.99240804,
1433            label: 0,
1434        }];
1435        println!("Output Boxes: {:?}", output_boxes);
1436        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1437
1438        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1439        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1440        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1441        decoder
1442            .decode_float::<f32>(
1443                &[
1444                    detect0.view().into_dyn(),
1445                    detect1.view().into_dyn(),
1446                    seg.view().into_dyn(),
1447                ],
1448                &mut output_boxes,
1449                &mut output_masks,
1450            )
1451            .unwrap();
1452
1453        // not expected for float segmentation decoder to have same values as quantized
1454        // segmentation decoder, as float decoder ensures the data fills 0-255,
1455        // quantized decoder uses whatever the model output. Thus the float
1456        // output is the same as the quantized output but scaled differently.
1457        // However, it is expected that the mask after argmax will be the same.
1458        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1459        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1460        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1461
1462        assert_eq!(mask0, mask1);
1463    }
1464
1465    #[test]
1466    fn test_dequant_chunked() {
1467        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1468        let mut out =
1469            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1470        out.push(123); // make sure to test non multiple of 16 length
1471
1472        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1473        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1474        let quant = Quantization::new(0.0040811873, -123);
1475        dequantize_cpu(&out, quant, &mut out_dequant);
1476
1477        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1478        assert_eq!(out_dequant, out_dequant_simd);
1479
1480        let quant = Quantization::new(0.0040811873, 0);
1481        dequantize_cpu(&out, quant, &mut out_dequant);
1482
1483        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1484        assert_eq!(out_dequant, out_dequant_simd);
1485    }
1486
1487    #[test]
1488    fn test_decoder_yolo_det() {
1489        let score_threshold = 0.25;
1490        let iou_threshold = 0.7;
1491        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1492        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1493        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1494        let quant = (0.0040811873, -123).into();
1495
1496        let decoder = DecoderBuilder::default()
1497            .with_config_yolo_det(
1498                configs::Detection {
1499                    decoder: DecoderType::Ultralytics,
1500                    shape: vec![1, 84, 8400],
1501                    anchors: None,
1502                    quantization: Some(quant),
1503                    dshape: vec![
1504                        (DimName::Batch, 1),
1505                        (DimName::NumFeatures, 84),
1506                        (DimName::NumBoxes, 8400),
1507                    ],
1508                    normalized: Some(true),
1509                },
1510                Some(DecoderVersion::Yolo11),
1511            )
1512            .with_score_threshold(score_threshold)
1513            .with_iou_threshold(iou_threshold)
1514            .build()
1515            .unwrap();
1516
1517        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1518        decode_yolo_det(
1519            (out.slice(s![0, .., ..]), quant.into()),
1520            score_threshold,
1521            iou_threshold,
1522            Some(configs::Nms::ClassAgnostic),
1523            &mut output_boxes,
1524        );
1525        assert!(output_boxes[0].equal_within_delta(
1526            &DetectBox {
1527                bbox: BoundingBox {
1528                    xmin: 0.5285137,
1529                    ymin: 0.05305544,
1530                    xmax: 0.87541467,
1531                    ymax: 0.9998909,
1532                },
1533                score: 0.5591227,
1534                label: 0
1535            },
1536            1e-6
1537        ));
1538
1539        assert!(output_boxes[1].equal_within_delta(
1540            &DetectBox {
1541                bbox: BoundingBox {
1542                    xmin: 0.130598,
1543                    ymin: 0.43260583,
1544                    xmax: 0.35098213,
1545                    ymax: 0.9958097,
1546                },
1547                score: 0.33057618,
1548                label: 75
1549            },
1550            1e-6
1551        ));
1552
1553        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1554        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1555        decoder
1556            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1557            .unwrap();
1558
1559        let out = dequantize_ndarray(out.view(), quant.into());
1560        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1561        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1562        decoder
1563            .decode_float::<f32>(
1564                &[out.view().into_dyn()],
1565                &mut output_boxes_f32,
1566                &mut output_masks_f32,
1567            )
1568            .unwrap();
1569
1570        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1571        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1572    }
1573
1574    #[test]
1575    fn test_decoder_masks() {
1576        let score_threshold = 0.45;
1577        let iou_threshold = 0.45;
1578        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1579        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1580        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1581        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1582
1583        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1584        let protos =
1585            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1586        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1587        let quant_protos = Quantization::new(0.02491161972284317, -117);
1588        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1589        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1590        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1591        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1592        decode_yolo_segdet_float(
1593            seg.view(),
1594            protos.view(),
1595            score_threshold,
1596            iou_threshold,
1597            Some(configs::Nms::ClassAgnostic),
1598            &mut output_boxes,
1599            &mut output_masks,
1600        )
1601        .unwrap();
1602        assert_eq!(output_boxes.len(), 2);
1603        assert_eq!(output_boxes.len(), output_masks.len());
1604
1605        for (b, m) in output_boxes.iter().zip(&output_masks) {
1606            assert!(b.bbox.xmin >= m.xmin);
1607            assert!(b.bbox.ymin >= m.ymin);
1608            assert!(b.bbox.xmax >= m.xmax);
1609            assert!(b.bbox.ymax >= m.ymax);
1610        }
1611        assert!(output_boxes[0].equal_within_delta(
1612            &DetectBox {
1613                bbox: BoundingBox {
1614                    xmin: 0.08515105,
1615                    ymin: 0.7131401,
1616                    xmax: 0.29802868,
1617                    ymax: 0.8195788,
1618                },
1619                score: 0.91537374,
1620                label: 23
1621            },
1622            1.0 / 160.0, // wider range because mask will expand the box
1623        ));
1624
1625        assert!(output_boxes[1].equal_within_delta(
1626            &DetectBox {
1627                bbox: BoundingBox {
1628                    xmin: 0.59605736,
1629                    ymin: 0.25545314,
1630                    xmax: 0.93666154,
1631                    ymax: 0.72378385,
1632                },
1633                score: 0.91537374,
1634                label: 23
1635            },
1636            1.0 / 160.0, // wider range because mask will expand the box
1637        ));
1638
1639        let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
1640        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1641
1642        let cropped_mask = full_mask.slice(ndarray::s![
1643            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1644            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1645        ]);
1646
1647        assert_eq!(
1648            cropped_mask,
1649            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1650        );
1651    }
1652
1653    /// Regression test: config-driven path with NCHW protos (no dshape).
1654    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1655    /// and the YAML config has no dshape field — the exact scenario from
1656    /// hal_mask_matmul_bug.md.
1657    #[test]
1658    fn test_decoder_masks_nchw_protos() {
1659        let score_threshold = 0.45;
1660        let iou_threshold = 0.45;
1661
1662        // Load test data — boxes as [116, 8400]
1663        let boxes_raw = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1664        let boxes_raw =
1665            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1666        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1667        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1668
1669        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1670        let protos_raw = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1671        let protos_raw = unsafe {
1672            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1673        };
1674        let protos_hwc =
1675            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1676        let quant_protos = Quantization::new(0.02491161972284317, -117);
1677        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1678
1679        // ---- Reference: direct call with HWC protos (known working) ----
1680        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1681        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1682        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1683        decode_yolo_segdet_float(
1684            seg.view(),
1685            protos_f32_hwc.view(),
1686            score_threshold,
1687            iou_threshold,
1688            Some(configs::Nms::ClassAgnostic),
1689            &mut ref_boxes,
1690            &mut ref_masks,
1691        )
1692        .unwrap();
1693        assert_eq!(ref_boxes.len(), 2);
1694
1695        // ---- Config-driven path: NCHW protos, no dshape ----
1696        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1697        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1698        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1699
1700        // Build boxes as [1, 116, 8400] f32
1701        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1702
1703        // Build decoder from config with no dshape on protos
1704        let decoder = DecoderBuilder::default()
1705            .with_config_yolo_segdet(
1706                configs::Detection {
1707                    decoder: configs::DecoderType::Ultralytics,
1708                    quantization: None,
1709                    shape: vec![1, 116, 8400],
1710                    dshape: vec![],
1711                    normalized: Some(true),
1712                    anchors: None,
1713                },
1714                configs::Protos {
1715                    decoder: configs::DecoderType::Ultralytics,
1716                    quantization: None,
1717                    shape: vec![1, 32, 160, 160],
1718                    dshape: vec![], // No dshape — simulates YAML without dshape
1719                },
1720                None, // decoder version
1721            )
1722            .with_score_threshold(score_threshold)
1723            .with_iou_threshold(iou_threshold)
1724            .build()
1725            .unwrap();
1726
1727        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1728        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1729        decoder
1730            .decode_float(
1731                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1732                &mut cfg_boxes,
1733                &mut cfg_masks,
1734            )
1735            .unwrap();
1736
1737        // Must produce the same number of detections
1738        assert_eq!(
1739            cfg_boxes.len(),
1740            ref_boxes.len(),
1741            "config path produced {} boxes, reference produced {}",
1742            cfg_boxes.len(),
1743            ref_boxes.len()
1744        );
1745
1746        // Boxes must match
1747        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1748            assert!(
1749                cb.equal_within_delta(rb, 0.01),
1750                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1751            );
1752        }
1753
1754        // Masks must match pixel-for-pixel
1755        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1756            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1757            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1758            assert_eq!(
1759                cm_arr, rm_arr,
1760                "mask {i} pixel mismatch between config-driven and reference paths"
1761            );
1762        }
1763    }
1764
1765    #[test]
1766    fn test_decoder_masks_i8() {
1767        let score_threshold = 0.45;
1768        let iou_threshold = 0.45;
1769        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1770        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1771        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1772        let quant_boxes = (0.021287761628627777, 31).into();
1773
1774        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1775        let protos =
1776            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1777        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1778        let quant_protos = (0.02491161972284317, -117).into();
1779        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1780        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1781
1782        let decoder = DecoderBuilder::default()
1783            .with_config_yolo_segdet(
1784                configs::Detection {
1785                    decoder: configs::DecoderType::Ultralytics,
1786                    quantization: Some(quant_boxes),
1787                    shape: vec![1, 116, 8400],
1788                    anchors: None,
1789                    dshape: vec![
1790                        (DimName::Batch, 1),
1791                        (DimName::NumFeatures, 116),
1792                        (DimName::NumBoxes, 8400),
1793                    ],
1794                    normalized: Some(true),
1795                },
1796                Protos {
1797                    decoder: configs::DecoderType::Ultralytics,
1798                    quantization: Some(quant_protos),
1799                    shape: vec![1, 160, 160, 32],
1800                    dshape: vec![
1801                        (DimName::Batch, 1),
1802                        (DimName::Height, 160),
1803                        (DimName::Width, 160),
1804                        (DimName::NumProtos, 32),
1805                    ],
1806                },
1807                Some(DecoderVersion::Yolo11),
1808            )
1809            .with_score_threshold(score_threshold)
1810            .with_iou_threshold(iou_threshold)
1811            .build()
1812            .unwrap();
1813
1814        let quant_boxes = quant_boxes.into();
1815        let quant_protos = quant_protos.into();
1816
1817        decode_yolo_segdet_quant(
1818            (boxes.slice(s![0, .., ..]), quant_boxes),
1819            (protos.slice(s![0, .., .., ..]), quant_protos),
1820            score_threshold,
1821            iou_threshold,
1822            Some(configs::Nms::ClassAgnostic),
1823            &mut output_boxes,
1824            &mut output_masks,
1825        )
1826        .unwrap();
1827
1828        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1829        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1830
1831        decoder
1832            .decode_quantized(
1833                &[boxes.view().into(), protos.view().into()],
1834                &mut output_boxes1,
1835                &mut output_masks1,
1836            )
1837            .unwrap();
1838
1839        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1840        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1841
1842        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1843        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1844        decode_yolo_segdet_float(
1845            seg.slice(s![0, .., ..]),
1846            protos.slice(s![0, .., .., ..]),
1847            score_threshold,
1848            iou_threshold,
1849            Some(configs::Nms::ClassAgnostic),
1850            &mut output_boxes_f32,
1851            &mut output_masks_f32,
1852        )
1853        .unwrap();
1854
1855        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1856        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1857
1858        decoder
1859            .decode_float(
1860                &[seg.view().into_dyn(), protos.view().into_dyn()],
1861                &mut output_boxes1_f32,
1862                &mut output_masks1_f32,
1863            )
1864            .unwrap();
1865
1866        compare_outputs(
1867            (&output_boxes, &output_boxes1),
1868            (&output_masks, &output_masks1),
1869        );
1870
1871        compare_outputs(
1872            (&output_boxes, &output_boxes_f32),
1873            (&output_masks, &output_masks_f32),
1874        );
1875
1876        compare_outputs(
1877            (&output_boxes_f32, &output_boxes1_f32),
1878            (&output_masks_f32, &output_masks1_f32),
1879        );
1880    }
1881
1882    #[test]
1883    fn test_decoder_yolo_split() {
1884        let score_threshold = 0.45;
1885        let iou_threshold = 0.45;
1886        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1887        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1888        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1889        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1890
1891        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1892
1893        let decoder = DecoderBuilder::default()
1894            .with_config_yolo_split_det(
1895                configs::Boxes {
1896                    decoder: configs::DecoderType::Ultralytics,
1897                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1898                    shape: vec![1, 4, 8400],
1899                    dshape: vec![
1900                        (DimName::Batch, 1),
1901                        (DimName::BoxCoords, 4),
1902                        (DimName::NumBoxes, 8400),
1903                    ],
1904                    normalized: Some(true),
1905                },
1906                configs::Scores {
1907                    decoder: configs::DecoderType::Ultralytics,
1908                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1909                    shape: vec![1, 80, 8400],
1910                    dshape: vec![
1911                        (DimName::Batch, 1),
1912                        (DimName::NumClasses, 80),
1913                        (DimName::NumBoxes, 8400),
1914                    ],
1915                },
1916            )
1917            .with_score_threshold(score_threshold)
1918            .with_iou_threshold(iou_threshold)
1919            .build()
1920            .unwrap();
1921
1922        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1923        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1924
1925        decoder
1926            .decode_quantized(
1927                &[
1928                    boxes.slice(s![.., ..4, ..]).into(),
1929                    boxes.slice(s![.., 4..84, ..]).into(),
1930                ],
1931                &mut output_boxes,
1932                &mut output_masks,
1933            )
1934            .unwrap();
1935
1936        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1937        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1938        decode_yolo_det_float(
1939            seg.slice(s![0, ..84, ..]),
1940            score_threshold,
1941            iou_threshold,
1942            Some(configs::Nms::ClassAgnostic),
1943            &mut output_boxes_f32,
1944        );
1945
1946        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1947        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1948
1949        decoder
1950            .decode_float(
1951                &[
1952                    seg.slice(s![.., ..4, ..]).into_dyn(),
1953                    seg.slice(s![.., 4..84, ..]).into_dyn(),
1954                ],
1955                &mut output_boxes1,
1956                &mut output_masks1,
1957            )
1958            .unwrap();
1959        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
1960        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
1961    }
1962
1963    #[test]
1964    fn test_decoder_masks_config_mixed() {
1965        let score_threshold = 0.45;
1966        let iou_threshold = 0.45;
1967        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1968        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1969        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1970        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1971
1972        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1973
1974        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1975        let protos =
1976            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1977        let protos: Vec<_> = protos.to_vec();
1978        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1979        let quant_protos = Quantization::new(0.02491161972284317, -117);
1980
1981        let decoder = DecoderBuilder::default()
1982            .with_config_yolo_split_segdet(
1983                configs::Boxes {
1984                    decoder: configs::DecoderType::Ultralytics,
1985                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1986                    shape: vec![1, 4, 8400],
1987                    dshape: vec![
1988                        (DimName::Batch, 1),
1989                        (DimName::BoxCoords, 4),
1990                        (DimName::NumBoxes, 8400),
1991                    ],
1992                    normalized: Some(true),
1993                },
1994                configs::Scores {
1995                    decoder: configs::DecoderType::Ultralytics,
1996                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1997                    shape: vec![1, 80, 8400],
1998                    dshape: vec![
1999                        (DimName::Batch, 1),
2000                        (DimName::NumClasses, 80),
2001                        (DimName::NumBoxes, 8400),
2002                    ],
2003                },
2004                configs::MaskCoefficients {
2005                    decoder: configs::DecoderType::Ultralytics,
2006                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2007                    shape: vec![1, 32, 8400],
2008                    dshape: vec![
2009                        (DimName::Batch, 1),
2010                        (DimName::NumProtos, 32),
2011                        (DimName::NumBoxes, 8400),
2012                    ],
2013                },
2014                configs::Protos {
2015                    decoder: configs::DecoderType::Ultralytics,
2016                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2017                    shape: vec![1, 160, 160, 32],
2018                    dshape: vec![
2019                        (DimName::Batch, 1),
2020                        (DimName::Height, 160),
2021                        (DimName::Width, 160),
2022                        (DimName::NumProtos, 32),
2023                    ],
2024                },
2025            )
2026            .with_score_threshold(score_threshold)
2027            .with_iou_threshold(iou_threshold)
2028            .build()
2029            .unwrap();
2030
2031        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2032        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2033
2034        decoder
2035            .decode_quantized(
2036                &[
2037                    boxes.slice(s![.., ..4, ..]).into(),
2038                    boxes.slice(s![.., 4..84, ..]).into(),
2039                    boxes.slice(s![.., 84.., ..]).into(),
2040                    protos.view().into(),
2041                ],
2042                &mut output_boxes,
2043                &mut output_masks,
2044            )
2045            .unwrap();
2046
2047        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2048        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2049        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2050        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2051        decode_yolo_segdet_float(
2052            seg.slice(s![0, .., ..]),
2053            protos.slice(s![0, .., .., ..]),
2054            score_threshold,
2055            iou_threshold,
2056            Some(configs::Nms::ClassAgnostic),
2057            &mut output_boxes_f32,
2058            &mut output_masks_f32,
2059        )
2060        .unwrap();
2061
2062        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2063        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2064
2065        decoder
2066            .decode_float(
2067                &[
2068                    seg.slice(s![.., ..4, ..]).into_dyn(),
2069                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2070                    seg.slice(s![.., 84.., ..]).into_dyn(),
2071                    protos.view().into_dyn(),
2072                ],
2073                &mut output_boxes1,
2074                &mut output_masks1,
2075            )
2076            .unwrap();
2077        compare_outputs(
2078            (&output_boxes, &output_boxes_f32),
2079            (&output_masks, &output_masks_f32),
2080        );
2081        compare_outputs(
2082            (&output_boxes_f32, &output_boxes1),
2083            (&output_masks_f32, &output_masks1),
2084        );
2085    }
2086
2087    #[test]
2088    fn test_decoder_masks_config_i32() {
2089        let score_threshold = 0.45;
2090        let iou_threshold = 0.45;
2091        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
2092        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2093        let scale = 1 << 23;
2094        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2095        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2096
2097        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2098
2099        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
2100        let protos =
2101            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2102        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2103        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2104        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2105
2106        let decoder = DecoderBuilder::default()
2107            .with_config_yolo_split_segdet(
2108                configs::Boxes {
2109                    decoder: configs::DecoderType::Ultralytics,
2110                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2111                    shape: vec![1, 4, 8400],
2112                    dshape: vec![
2113                        (DimName::Batch, 1),
2114                        (DimName::BoxCoords, 4),
2115                        (DimName::NumBoxes, 8400),
2116                    ],
2117                    normalized: Some(true),
2118                },
2119                configs::Scores {
2120                    decoder: configs::DecoderType::Ultralytics,
2121                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2122                    shape: vec![1, 80, 8400],
2123                    dshape: vec![
2124                        (DimName::Batch, 1),
2125                        (DimName::NumClasses, 80),
2126                        (DimName::NumBoxes, 8400),
2127                    ],
2128                },
2129                configs::MaskCoefficients {
2130                    decoder: configs::DecoderType::Ultralytics,
2131                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2132                    shape: vec![1, 32, 8400],
2133                    dshape: vec![
2134                        (DimName::Batch, 1),
2135                        (DimName::NumProtos, 32),
2136                        (DimName::NumBoxes, 8400),
2137                    ],
2138                },
2139                configs::Protos {
2140                    decoder: configs::DecoderType::Ultralytics,
2141                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2142                    shape: vec![1, 160, 160, 32],
2143                    dshape: vec![
2144                        (DimName::Batch, 1),
2145                        (DimName::Height, 160),
2146                        (DimName::Width, 160),
2147                        (DimName::NumProtos, 32),
2148                    ],
2149                },
2150            )
2151            .with_score_threshold(score_threshold)
2152            .with_iou_threshold(iou_threshold)
2153            .build()
2154            .unwrap();
2155
2156        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2157        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2158
2159        decoder
2160            .decode_quantized(
2161                &[
2162                    boxes.slice(s![.., ..4, ..]).into(),
2163                    boxes.slice(s![.., 4..84, ..]).into(),
2164                    boxes.slice(s![.., 84.., ..]).into(),
2165                    protos.view().into(),
2166                ],
2167                &mut output_boxes,
2168                &mut output_masks,
2169            )
2170            .unwrap();
2171
2172        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2173        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2174        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2175        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2176        decode_yolo_segdet_float(
2177            seg.slice(s![0, .., ..]),
2178            protos.slice(s![0, .., .., ..]),
2179            score_threshold,
2180            iou_threshold,
2181            Some(configs::Nms::ClassAgnostic),
2182            &mut output_boxes_f32,
2183            &mut output_masks_f32,
2184        )
2185        .unwrap();
2186
2187        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2188        assert_eq!(output_masks.len(), output_masks_f32.len());
2189
2190        compare_outputs(
2191            (&output_boxes, &output_boxes_f32),
2192            (&output_masks, &output_masks_f32),
2193        );
2194    }
2195
2196    /// test running multiple decoders concurrently
2197    #[test]
2198    fn test_context_switch() {
2199        let yolo_det = || {
2200            let score_threshold = 0.25;
2201            let iou_threshold = 0.7;
2202            let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
2203            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2204            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2205            let quant = (0.0040811873, -123).into();
2206
2207            let decoder = DecoderBuilder::default()
2208                .with_config_yolo_det(
2209                    configs::Detection {
2210                        decoder: DecoderType::Ultralytics,
2211                        shape: vec![1, 84, 8400],
2212                        anchors: None,
2213                        quantization: Some(quant),
2214                        dshape: vec![
2215                            (DimName::Batch, 1),
2216                            (DimName::NumFeatures, 84),
2217                            (DimName::NumBoxes, 8400),
2218                        ],
2219                        normalized: None,
2220                    },
2221                    None,
2222                )
2223                .with_score_threshold(score_threshold)
2224                .with_iou_threshold(iou_threshold)
2225                .build()
2226                .unwrap();
2227
2228            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2229            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2230
2231            for _ in 0..100 {
2232                decoder
2233                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2234                    .unwrap();
2235
2236                assert!(output_boxes[0].equal_within_delta(
2237                    &DetectBox {
2238                        bbox: BoundingBox {
2239                            xmin: 0.5285137,
2240                            ymin: 0.05305544,
2241                            xmax: 0.87541467,
2242                            ymax: 0.9998909,
2243                        },
2244                        score: 0.5591227,
2245                        label: 0
2246                    },
2247                    1e-6
2248                ));
2249
2250                assert!(output_boxes[1].equal_within_delta(
2251                    &DetectBox {
2252                        bbox: BoundingBox {
2253                            xmin: 0.130598,
2254                            ymin: 0.43260583,
2255                            xmax: 0.35098213,
2256                            ymax: 0.9958097,
2257                        },
2258                        score: 0.33057618,
2259                        label: 75
2260                    },
2261                    1e-6
2262                ));
2263                assert!(output_masks.is_empty());
2264            }
2265        };
2266
2267        let modelpack_det_split = || {
2268            let score_threshold = 0.8;
2269            let iou_threshold = 0.5;
2270
2271            let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
2272            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2273
2274            let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
2275            let detect0 =
2276                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2277
2278            let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
2279            let detect1 =
2280                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2281
2282            let mut mask = seg.slice(s![0, .., .., ..]);
2283            mask.swap_axes(0, 1);
2284            mask.swap_axes(1, 2);
2285            let mask = [Segmentation {
2286                xmin: 0.0,
2287                ymin: 0.0,
2288                xmax: 1.0,
2289                ymax: 1.0,
2290                segmentation: mask.into_owned(),
2291            }];
2292            let correct_boxes = [DetectBox {
2293                bbox: BoundingBox {
2294                    xmin: 0.43171933,
2295                    ymin: 0.68243736,
2296                    xmax: 0.5626645,
2297                    ymax: 0.808863,
2298                },
2299                score: 0.99240804,
2300                label: 0,
2301            }];
2302
2303            let quant0 = (0.08547406643629074, 174).into();
2304            let quant1 = (0.09929127991199493, 183).into();
2305            let quant_seg = (1.0 / 255.0, 0).into();
2306
2307            let anchors0 = vec![
2308                [0.36666667461395264, 0.31481480598449707],
2309                [0.38749998807907104, 0.4740740656852722],
2310                [0.5333333611488342, 0.644444465637207],
2311            ];
2312            let anchors1 = vec![
2313                [0.13750000298023224, 0.2074074000120163],
2314                [0.2541666626930237, 0.21481481194496155],
2315                [0.23125000298023224, 0.35185185074806213],
2316            ];
2317
2318            let decoder = DecoderBuilder::default()
2319                .with_config_modelpack_segdet_split(
2320                    vec![
2321                        configs::Detection {
2322                            decoder: DecoderType::ModelPack,
2323                            shape: vec![1, 17, 30, 18],
2324                            anchors: Some(anchors1),
2325                            quantization: Some(quant1),
2326                            dshape: vec![
2327                                (DimName::Batch, 1),
2328                                (DimName::Height, 17),
2329                                (DimName::Width, 30),
2330                                (DimName::NumAnchorsXFeatures, 18),
2331                            ],
2332                            normalized: None,
2333                        },
2334                        configs::Detection {
2335                            decoder: DecoderType::ModelPack,
2336                            shape: vec![1, 9, 15, 18],
2337                            anchors: Some(anchors0),
2338                            quantization: Some(quant0),
2339                            dshape: vec![
2340                                (DimName::Batch, 1),
2341                                (DimName::Height, 9),
2342                                (DimName::Width, 15),
2343                                (DimName::NumAnchorsXFeatures, 18),
2344                            ],
2345                            normalized: None,
2346                        },
2347                    ],
2348                    configs::Segmentation {
2349                        decoder: DecoderType::ModelPack,
2350                        quantization: Some(quant_seg),
2351                        shape: vec![1, 2, 160, 160],
2352                        dshape: vec![
2353                            (DimName::Batch, 1),
2354                            (DimName::NumClasses, 2),
2355                            (DimName::Height, 160),
2356                            (DimName::Width, 160),
2357                        ],
2358                    },
2359                )
2360                .with_score_threshold(score_threshold)
2361                .with_iou_threshold(iou_threshold)
2362                .build()
2363                .unwrap();
2364            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2365            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2366
2367            for _ in 0..100 {
2368                decoder
2369                    .decode_quantized(
2370                        &[
2371                            detect0.view().into(),
2372                            detect1.view().into(),
2373                            seg.view().into(),
2374                        ],
2375                        &mut output_boxes,
2376                        &mut output_masks,
2377                    )
2378                    .unwrap();
2379
2380                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2381            }
2382        };
2383
2384        let handles = vec![
2385            std::thread::spawn(yolo_det),
2386            std::thread::spawn(modelpack_det_split),
2387            std::thread::spawn(yolo_det),
2388            std::thread::spawn(modelpack_det_split),
2389            std::thread::spawn(yolo_det),
2390            std::thread::spawn(modelpack_det_split),
2391            std::thread::spawn(yolo_det),
2392            std::thread::spawn(modelpack_det_split),
2393        ];
2394        for handle in handles {
2395            handle.join().unwrap();
2396        }
2397    }
2398
2399    #[test]
2400    fn test_ndarray_to_xyxy_float() {
2401        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2402        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2403        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2404
2405        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2406        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2407        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2408    }
2409
2410    #[test]
2411    fn test_class_aware_nms_float() {
2412        use crate::float::nms_class_aware_float;
2413
2414        // Create two overlapping boxes with different classes
2415        let boxes = vec![
2416            DetectBox {
2417                bbox: BoundingBox {
2418                    xmin: 0.0,
2419                    ymin: 0.0,
2420                    xmax: 0.5,
2421                    ymax: 0.5,
2422                },
2423                score: 0.9,
2424                label: 0, // class 0
2425            },
2426            DetectBox {
2427                bbox: BoundingBox {
2428                    xmin: 0.1,
2429                    ymin: 0.1,
2430                    xmax: 0.6,
2431                    ymax: 0.6,
2432                },
2433                score: 0.8,
2434                label: 1, // class 1 - different class
2435            },
2436        ];
2437
2438        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2439        // threshold 0.3)
2440        let result = nms_class_aware_float(0.3, boxes.clone());
2441        assert_eq!(
2442            result.len(),
2443            2,
2444            "Class-aware NMS should keep both boxes with different classes"
2445        );
2446
2447        // Now test with same class - should suppress one
2448        let same_class_boxes = vec![
2449            DetectBox {
2450                bbox: BoundingBox {
2451                    xmin: 0.0,
2452                    ymin: 0.0,
2453                    xmax: 0.5,
2454                    ymax: 0.5,
2455                },
2456                score: 0.9,
2457                label: 0,
2458            },
2459            DetectBox {
2460                bbox: BoundingBox {
2461                    xmin: 0.1,
2462                    ymin: 0.1,
2463                    xmax: 0.6,
2464                    ymax: 0.6,
2465                },
2466                score: 0.8,
2467                label: 0, // same class
2468            },
2469        ];
2470
2471        let result = nms_class_aware_float(0.3, same_class_boxes);
2472        assert_eq!(
2473            result.len(),
2474            1,
2475            "Class-aware NMS should suppress overlapping box with same class"
2476        );
2477        assert_eq!(result[0].label, 0);
2478        assert!((result[0].score - 0.9).abs() < 1e-6);
2479    }
2480
2481    #[test]
2482    fn test_class_agnostic_vs_aware_nms() {
2483        use crate::float::{nms_class_aware_float, nms_float};
2484
2485        // Two overlapping boxes with different classes
2486        let boxes = vec![
2487            DetectBox {
2488                bbox: BoundingBox {
2489                    xmin: 0.0,
2490                    ymin: 0.0,
2491                    xmax: 0.5,
2492                    ymax: 0.5,
2493                },
2494                score: 0.9,
2495                label: 0,
2496            },
2497            DetectBox {
2498                bbox: BoundingBox {
2499                    xmin: 0.1,
2500                    ymin: 0.1,
2501                    xmax: 0.6,
2502                    ymax: 0.6,
2503                },
2504                score: 0.8,
2505                label: 1,
2506            },
2507        ];
2508
2509        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2510        let agnostic_result = nms_float(0.3, boxes.clone());
2511        assert_eq!(
2512            agnostic_result.len(),
2513            1,
2514            "Class-agnostic NMS should suppress overlapping boxes"
2515        );
2516
2517        // Class-aware should keep both (different classes)
2518        let aware_result = nms_class_aware_float(0.3, boxes);
2519        assert_eq!(
2520            aware_result.len(),
2521            2,
2522            "Class-aware NMS should keep boxes with different classes"
2523        );
2524    }
2525
2526    #[test]
2527    fn test_class_aware_nms_int() {
2528        use crate::byte::nms_class_aware_int;
2529
2530        // Create two overlapping boxes with different classes
2531        let boxes = vec![
2532            DetectBoxQuantized {
2533                bbox: BoundingBox {
2534                    xmin: 0.0,
2535                    ymin: 0.0,
2536                    xmax: 0.5,
2537                    ymax: 0.5,
2538                },
2539                score: 200_u8,
2540                label: 0,
2541            },
2542            DetectBoxQuantized {
2543                bbox: BoundingBox {
2544                    xmin: 0.1,
2545                    ymin: 0.1,
2546                    xmax: 0.6,
2547                    ymax: 0.6,
2548                },
2549                score: 180_u8,
2550                label: 1, // different class
2551            },
2552        ];
2553
2554        // Should keep both (different classes)
2555        let result = nms_class_aware_int(0.5, boxes);
2556        assert_eq!(
2557            result.len(),
2558            2,
2559            "Class-aware NMS (int) should keep boxes with different classes"
2560        );
2561    }
2562
2563    #[test]
2564    fn test_nms_enum_default() {
2565        // Test that Nms enum has the correct default
2566        let default_nms: configs::Nms = Default::default();
2567        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2568    }
2569
2570    #[test]
2571    fn test_decoder_nms_mode() {
2572        // Test that decoder properly stores NMS mode
2573        let decoder = DecoderBuilder::default()
2574            .with_config_yolo_det(
2575                configs::Detection {
2576                    anchors: None,
2577                    decoder: DecoderType::Ultralytics,
2578                    quantization: None,
2579                    shape: vec![1, 84, 8400],
2580                    dshape: Vec::new(),
2581                    normalized: Some(true),
2582                },
2583                None,
2584            )
2585            .with_nms(Some(configs::Nms::ClassAware))
2586            .build()
2587            .unwrap();
2588
2589        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2590    }
2591
2592    #[test]
2593    fn test_decoder_nms_bypass() {
2594        // Test that decoder can be configured with nms=None (bypass)
2595        let decoder = DecoderBuilder::default()
2596            .with_config_yolo_det(
2597                configs::Detection {
2598                    anchors: None,
2599                    decoder: DecoderType::Ultralytics,
2600                    quantization: None,
2601                    shape: vec![1, 84, 8400],
2602                    dshape: Vec::new(),
2603                    normalized: Some(true),
2604                },
2605                None,
2606            )
2607            .with_nms(None)
2608            .build()
2609            .unwrap();
2610
2611        assert_eq!(decoder.nms, None);
2612    }
2613
2614    #[test]
2615    fn test_decoder_normalized_boxes_true() {
2616        // Test that normalized_boxes returns Some(true) when explicitly set
2617        let decoder = DecoderBuilder::default()
2618            .with_config_yolo_det(
2619                configs::Detection {
2620                    anchors: None,
2621                    decoder: DecoderType::Ultralytics,
2622                    quantization: None,
2623                    shape: vec![1, 84, 8400],
2624                    dshape: Vec::new(),
2625                    normalized: Some(true),
2626                },
2627                None,
2628            )
2629            .build()
2630            .unwrap();
2631
2632        assert_eq!(decoder.normalized_boxes(), Some(true));
2633    }
2634
2635    #[test]
2636    fn test_decoder_normalized_boxes_false() {
2637        // Test that normalized_boxes returns Some(false) when config specifies
2638        // unnormalized
2639        let decoder = DecoderBuilder::default()
2640            .with_config_yolo_det(
2641                configs::Detection {
2642                    anchors: None,
2643                    decoder: DecoderType::Ultralytics,
2644                    quantization: None,
2645                    shape: vec![1, 84, 8400],
2646                    dshape: Vec::new(),
2647                    normalized: Some(false),
2648                },
2649                None,
2650            )
2651            .build()
2652            .unwrap();
2653
2654        assert_eq!(decoder.normalized_boxes(), Some(false));
2655    }
2656
2657    #[test]
2658    fn test_decoder_normalized_boxes_unknown() {
2659        // Test that normalized_boxes returns None when not specified in config
2660        let decoder = DecoderBuilder::default()
2661            .with_config_yolo_det(
2662                configs::Detection {
2663                    anchors: None,
2664                    decoder: DecoderType::Ultralytics,
2665                    quantization: None,
2666                    shape: vec![1, 84, 8400],
2667                    dshape: Vec::new(),
2668                    normalized: None,
2669                },
2670                Some(DecoderVersion::Yolo11),
2671            )
2672            .build()
2673            .unwrap();
2674
2675        assert_eq!(decoder.normalized_boxes(), None);
2676    }
2677}