Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86    yolo::yolo_segmentation_to_mask,
87};
88
89/// Trait to convert bounding box formats to XYXY float format
90pub trait BBoxTypeTrait {
91    /// Converts the bbox into XYXY float format.
92    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96        input: &[B; 4],
97        quant: Quantization,
98    ) -> [A; 4]
99    where
100        f32: AsPrimitive<A>,
101        i32: AsPrimitive<A>;
102
103    /// Converts the bbox into XYXY float format.
104    ///
105    /// # Examples
106    /// ```rust
107    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
108    /// # use ndarray::array;
109    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
110    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
111    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
112    /// ```
113    #[inline(always)]
114    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115        input: ArrayView1<B>,
116    ) -> [A; 4] {
117        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118    }
119
120    #[inline(always)]
121    /// Converts the bbox into XYXY float format.
122    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123        input: ArrayView1<B>,
124        quant: Quantization,
125    ) -> [A; 4]
126    where
127        f32: AsPrimitive<A>,
128        i32: AsPrimitive<A>,
129    {
130        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131    }
132}
133
134/// Converts XYXY bounding boxes to XYXY
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140        input.map(|b| b.as_())
141    }
142
143    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144        input: &[B; 4],
145        quant: Quantization,
146    ) -> [A; 4]
147    where
148        f32: AsPrimitive<A>,
149        i32: AsPrimitive<A>,
150    {
151        let scale = quant.scale.as_();
152        let zp = quant.zero_point.as_();
153        input.map(|b| (b.as_() - zp) * scale)
154    }
155
156    #[inline(always)]
157    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158        input: ArrayView1<B>,
159    ) -> [A; 4] {
160        [
161            input[0].as_(),
162            input[1].as_(),
163            input[2].as_(),
164            input[3].as_(),
165        ]
166    }
167}
168
169/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
170/// box
171#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175    #[inline(always)]
176    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177        let half = A::one() / (A::one() + A::one());
178        [
179            (input[0].as_()) - (input[2].as_() * half),
180            (input[1].as_()) - (input[3].as_() * half),
181            (input[0].as_()) + (input[2].as_() * half),
182            (input[1].as_()) + (input[3].as_() * half),
183        ]
184    }
185
186    #[inline(always)]
187    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188        input: &[B; 4],
189        quant: Quantization,
190    ) -> [A; 4]
191    where
192        f32: AsPrimitive<A>,
193        i32: AsPrimitive<A>,
194    {
195        let scale = quant.scale.as_();
196        let half_scale = (quant.scale * 0.5).as_();
197        let zp = quant.zero_point.as_();
198        let [x, y, w, h] = [
199            (input[0].as_() - zp) * scale,
200            (input[1].as_() - zp) * scale,
201            (input[2].as_() - zp) * half_scale,
202            (input[3].as_() - zp) * half_scale,
203        ];
204
205        [x - w, y - h, x + w, y + h]
206    }
207
208    #[inline(always)]
209    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210        input: ArrayView1<B>,
211    ) -> [A; 4] {
212        let half = A::one() / (A::one() + A::one());
213        [
214            (input[0].as_()) - (input[2].as_() * half),
215            (input[1].as_()) - (input[3].as_() * half),
216            (input[0].as_()) + (input[2].as_() * half),
217            (input[1].as_()) + (input[3].as_() * half),
218        ]
219    }
220}
221
/// Affine quantization parameters of a tensor.
///
/// The dequantization helpers in this file map a raw value `q` to
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// multiplier applied after the zero point has been subtracted
    pub scale: f32,
    /// raw value that corresponds to a dequantized value of zero
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from an explicit scale and zero point.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
242
243impl From<QuantTuple> for Quantization {
244    /// Creates a new Quantization struct from a QuantTuple
245    /// # Examples
246    /// ```
247    /// # use edgefirst_decoder::Quantization;
248    /// # use edgefirst_decoder::configs::QuantTuple;
249    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
250    /// let quant = Quantization::from(quant_tuple);
251    /// assert_eq!(quant.scale, 0.1);
252    /// assert_eq!(quant.zero_point, -128);
253    /// ```
254    fn from(quant_tuple: QuantTuple) -> Quantization {
255        Quantization {
256            scale: quant_tuple.0,
257            zero_point: quant_tuple.1,
258        }
259    }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264    S: AsPrimitive<f32>,
265    Z: AsPrimitive<i32>,
266{
267    /// Creates a new Quantization struct from a tuple
268    /// # Examples
269    /// ```
270    /// # use edgefirst_decoder::Quantization;
271    /// let quant = Quantization::from((0.1_f64, -128_i64));
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from((scale, zp): (S, Z)) -> Quantization {
276        Self {
277            scale: scale.as_(),
278            zero_point: zp.as_(),
279        }
280    }
281}
282
283impl Default for Quantization {
284    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
285    /// # Examples
286    /// ```rust
287    /// # use edgefirst_decoder::Quantization;
288    /// let quant = Quantization::default();
289    /// assert_eq!(quant.scale, 1.0);
290    /// assert_eq!(quant.zero_point, 0);
291    /// ```
292    fn default() -> Self {
293        Self {
294            scale: 1.0,
295            zero_point: 0,
296        }
297    }
298}
299
/// A detection box with f32 bbox and score
///
/// Combines the box geometry with the score and label index of a single
/// detection.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// bounding box of this detection, see [`BoundingBox`]
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
309
/// An axis-aligned bounding box stored as f32 corner coordinates in
/// `xmin, ymin, xmax, ymax` (XYXY) order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a `BoundingBox` from its four corner coordinates. The values
    /// are stored as given; no ordering is enforced.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`, swapping each pair where necessary.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let Self {
            xmin: x0,
            ymin: y0,
            xmax: x1,
            ymax: y1,
        } = *self;
        Self::new(x0.min(x1), y0.min(y1), x0.max(x1), y0.max(y1))
    }
}
355
356impl From<BoundingBox> for [f32; 4] {
357    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
358    /// xmax, ymax order
359    /// # Examples
360    /// ```
361    /// # use edgefirst_decoder::BoundingBox;
362    /// let bbox = BoundingBox {
363    ///     xmin: 0.1,
364    ///     ymin: 0.2,
365    ///     xmax: 0.3,
366    ///     ymax: 0.4,
367    /// };
368    /// let arr: [f32; 4] = bbox.into();
369    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
370    /// ```
371    fn from(b: BoundingBox) -> Self {
372        [b.xmin, b.ymin, b.xmax, b.ymax]
373    }
374}
375
376impl From<[f32; 4]> for BoundingBox {
377    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
378    // BoundingBox
379    fn from(arr: [f32; 4]) -> Self {
380        BoundingBox {
381            xmin: arr[0],
382            ymin: arr[1],
383            xmax: arr[2],
384            ymax: arr[3],
385        }
386    }
387}
388
389impl DetectBox {
390    /// Returns true if one detect box is equal to another detect box, within
391    /// the given `eps`
392    ///
393    /// # Examples
394    /// ```
395    /// # use edgefirst_decoder::DetectBox;
396    /// let box1 = DetectBox {
397    ///     bbox: edgefirst_decoder::BoundingBox {
398    ///         xmin: 0.1,
399    ///         ymin: 0.2,
400    ///         xmax: 0.3,
401    ///         ymax: 0.4,
402    ///     },
403    ///     score: 0.5,
404    ///     label: 1,
405    /// };
406    /// let box2 = DetectBox {
407    ///     bbox: edgefirst_decoder::BoundingBox {
408    ///         xmin: 0.101,
409    ///         ymin: 0.199,
410    ///         xmax: 0.301,
411    ///         ymax: 0.399,
412    ///     },
413    ///     score: 0.510,
414    ///     label: 1,
415    /// };
416    /// assert!(box1.equal_within_delta(&box2, 0.011));
417    /// ```
418    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420        self.label == rhs.label
421            && eq_delta(self.score, rhs.score)
422            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426    }
427}
428
/// A segmentation result with a segmentation mask, and a normalized bounding
/// box representing the area that the segmentation mask covers
///
/// The four `f32` fields describe the covered area in the same
/// `xmin, ymin, xmax, ymax` order used by [`BoundingBox`].
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// left-most normalized coordinate of the segmentation box
    pub xmin: f32,
    /// top-most normalized coordinate of the segmentation box
    pub ymin: f32,
    /// right-most normalized coordinate of the segmentation box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the segmentation box
    pub ymax: f32,
    /// 3D segmentation array of shape `(H, W, C)`.
    ///
    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
    /// continuous sigmoid confidence values quantized to u8 (0 = background,
    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
    /// 0.5) or use smooth interpolation for anti-aliased edges.
    ///
    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
    /// class scores where the object class is the argmax index.
    ///
    /// [`segmentation_to_mask`] collapses this array into a 2D mask using
    /// exactly that `C == 1` / `C > 1` distinction.
    pub segmentation: Array3<u8>,
}
452
453/// Prototype tensor variants for fused decode+render pipelines.
454///
455/// Carries either raw quantized data (to skip CPU dequantization and let the
456/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
457/// paths).
458#[derive(Debug, Clone)]
459pub enum ProtoTensor {
460    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
461    /// The GPU fragment shader will dequantize per-texel using the scale and
462    /// zero_point.
463    Quantized {
464        protos: Array3<i8>,
465        quantization: Quantization,
466    },
467    /// Dequantized f32 protos (from float models or legacy path).
468    Float(Array3<f32>),
469}
470
471impl ProtoTensor {
472    /// Returns `true` if this is the quantized variant.
473    pub fn is_quantized(&self) -> bool {
474        matches!(self, ProtoTensor::Quantized { .. })
475    }
476
477    /// Returns the spatial dimensions `(height, width, num_protos)`.
478    pub fn dim(&self) -> (usize, usize, usize) {
479        match self {
480            ProtoTensor::Quantized { protos, .. } => protos.dim(),
481            ProtoTensor::Float(arr) => arr.dim(),
482        }
483    }
484
485    /// Returns dequantized f32 protos. For the `Float` variant this is a
486    /// no-copy reference; for `Quantized` it allocates and dequantizes.
487    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488        match self {
489            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490            ProtoTensor::Quantized {
491                protos,
492                quantization,
493            } => {
494                let scale = quantization.scale;
495                let zp = quantization.zero_point as f32;
496                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497            }
498        }
499    }
500}
501
/// Raw prototype data for fused decode+render pipelines.
///
/// Holds post-NMS intermediate state before mask materialization, allowing the
/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
/// shader) without materializing intermediate `Array3<u8>` masks.
///
/// The prototypes are stored as a [`ProtoTensor`], so they may be either
/// quantized or already in f32 form.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
    pub protos: ProtoTensor,
}
514
515/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
516///
517///  # Examples
518/// ```
519/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
520/// let quant = Quantization::new(0.1, -128);
521/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
522/// let detect_quant = DetectBoxQuantized {
523///     bbox,
524///     score: 100_i8,
525///     label: 1,
526/// };
527/// let detect = dequant_detect_box(&detect_quant, quant);
528/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
529/// assert_eq!(detect.label, 1);
530/// assert_eq!(detect.bbox, bbox);
531/// ```
532pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533    detect: &DetectBoxQuantized<SCORE>,
534    quant_scores: Quantization,
535) -> DetectBox {
536    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537    DetectBox {
538        bbox: detect.bbox,
539        score: quant_scores.scale * detect.score.as_() + scaled_zp,
540        label: detect.label,
541    }
542}
/// A detection box with a f32 bbox and quantized score
///
/// The score is kept in the integer type `SCORE`; [`dequant_detect_box`]
/// turns a value of this type into a [`DetectBox`] with an f32 score.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    // NOTE: there is no type parameter for a quantized bbox; only the score
    // type is generic.
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// bounding box of this detection, stored as f32 rather than in
    /// quantized form
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more
    /// confidence.
    pub score: SCORE,
    /// label index for this detect
    pub label: usize,
}
557
558/// Dequantizes an ndarray from quantized values to f32 values using the given
559/// quantization parameters
560///
561/// # Examples
562/// ```
563/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
564/// let quant = Quantization::new(0.1, -128);
565/// let input: Vec<i8> = vec![0, 127, -128, 64];
566/// let input_array = ndarray::Array1::from(input);
567/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
568/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
569/// ```
570pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571    input: ArrayView<T, D>,
572    quant: Quantization,
573) -> Array<F, D>
574where
575    i32: num_traits::AsPrimitive<F>,
576    f32: num_traits::AsPrimitive<F>,
577{
578    let zero_point = quant.zero_point.as_();
579    let scale = quant.scale.as_();
580    if zero_point != F::zero() {
581        let scaled_zero = -zero_point * scale;
582        input.mapv(|d| d.as_() * scale + scaled_zero)
583    } else {
584        input.mapv(|d| d.as_() * scale)
585    }
586}
587
588/// Dequantizes a slice from quantized values to float values using the given
589/// quantization parameters
590///
591/// # Examples
592/// ```
593/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
594/// let quant = Quantization::new(0.1, -128);
595/// let input: Vec<i8> = vec![0, 127, -128, 64];
596/// let mut output: Vec<f32> = vec![0.0; input.len()];
597/// dequantize_cpu(&input, quant, &mut output);
598/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
599/// ```
600pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601    input: &[T],
602    quant: Quantization,
603    output: &mut [F],
604) where
605    f32: num_traits::AsPrimitive<F>,
606    i32: num_traits::AsPrimitive<F>,
607{
608    assert!(input.len() == output.len());
609    let zero_point = quant.zero_point.as_();
610    let scale = quant.scale.as_();
611    if zero_point != F::zero() {
612        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
613        input
614            .iter()
615            .zip(output)
616            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617    } else {
618        input
619            .iter()
620            .zip(output)
621            .for_each(|(d, deq)| *deq = d.as_() * scale);
622    }
623}
624
625/// Dequantizes a slice from quantized values to float values using the given
626/// quantization parameters, using chunked processing. This is around 5% faster
627/// than `dequantize_cpu` for large slices.
628///
629/// # Examples
630/// ```
631/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
632/// let quant = Quantization::new(0.1, -128);
633/// let input: Vec<i8> = vec![0, 127, -128, 64];
634/// let mut output: Vec<f32> = vec![0.0; input.len()];
635/// dequantize_cpu_chunked(&input, quant, &mut output);
636/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
637/// ```
638pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639    input: &[T],
640    quant: Quantization,
641    output: &mut [F],
642) where
643    f32: num_traits::AsPrimitive<F>,
644    i32: num_traits::AsPrimitive<F>,
645{
646    assert!(input.len() == output.len());
647    let zero_point = quant.zero_point.as_();
648    let scale = quant.scale.as_();
649
650    let input = input.as_chunks::<4>();
651    let output = output.as_chunks_mut::<4>();
652
653    if zero_point != F::zero() {
654        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
655
656        input
657            .0
658            .iter()
659            .zip(output.0)
660            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661        input
662            .1
663            .iter()
664            .zip(output.1)
665            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666    } else {
667        input
668            .0
669            .iter()
670            .zip(output.0)
671            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672        input
673            .1
674            .iter()
675            .zip(output.1)
676            .for_each(|(d, deq)| *deq = d.as_() * scale);
677    }
678}
679
680/// Converts a segmentation tensor into a 2D mask
681/// If the last dimension of the segmentation tensor is 1, values equal or
682/// above 128 are considered objects. Otherwise the object is the argmax index
683///
684/// # Errors
685///
686/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
687/// invalid shape.
688///
689/// # Examples
690/// ```
691/// # use edgefirst_decoder::segmentation_to_mask;
692/// let segmentation =
693///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
694/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
695/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
696/// ```
697pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698    if segmentation.shape()[2] == 0 {
699        return Err(DecoderError::InvalidShape(
700            "Segmentation tensor must have non-zero depth".to_string(),
701        ));
702    }
703    if segmentation.shape()[2] == 1 {
704        yolo_segmentation_to_mask(segmentation, 128)
705    } else {
706        Ok(modelpack_segmentation_to_mask(segmentation))
707    }
708}
709
710/// Returns the maximum value and its index from a 1D array
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712    score
713        .iter()
714        .enumerate()
715        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716            if max > *s {
717                (max, arg_max)
718            } else {
719                (*s, ind)
720            }
721        })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726    #![allow(clippy::excessive_precision)]
727    use crate::{
728        configs::{DecoderType, DimName, Protos},
729        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730        yolo::{
731            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732            decode_yolo_segdet_quant,
733        },
734        *,
735    };
736    use ndarray::{array, s, Array4};
737    use ndarray_stats::DeviationExt;
738
739    fn compare_outputs(
740        boxes: (&[DetectBox], &[DetectBox]),
741        masks: (&[Segmentation], &[Segmentation]),
742    ) {
743        let (boxes0, boxes1) = boxes;
744        let (masks0, masks1) = masks;
745
746        assert_eq!(boxes0.len(), boxes1.len());
747        assert_eq!(masks0.len(), masks1.len());
748
749        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
750            assert!(
751                b_i8.equal_within_delta(b_f32, 1e-6),
752                "{b_i8:?} is not equal to {b_f32:?}"
753            );
754        }
755
756        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
757            assert_eq!(
758                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
759                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
760            );
761            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
762            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
763            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
764            let diff = &mask_i8 - &mask_f32;
765            for x in 0..diff.shape()[0] {
766                for y in 0..diff.shape()[1] {
767                    for z in 0..diff.shape()[2] {
768                        let val = diff[[x, y, z]];
769                        assert!(
770                            val.abs() <= 1,
771                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
772                            x,
773                            y,
774                            z,
775                            val
776                        );
777                    }
778                }
779            }
780            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
781            assert!(
782                mean_sq_err < 1e-2,
783                "Mean Square Error between masks was greater than 1%: {:.2}%",
784                mean_sq_err * 100.0
785            );
786        }
787    }
788
    /// Decodes one ModelPack detection fixture three ways — the low-level
    /// `decode_modelpack_det` function, `Decoder::decode_quantized`, and
    /// `Decoder::decode_float` on dequantized inputs — and checks that all
    /// three produce matching boxes and no masks.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Raw box bytes embedded at compile time, reshaped to (1, 1935, 1, 4).
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        // Raw score bytes embedded at compile time, reshaped to (1, 1935, 1).
        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // (scale, zero_point) pairs; the concrete type is inferred from the
        // `quantization` config fields they are passed to below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        // High-level decoder configured with the same shapes, quantization
        // and thresholds as the direct call further down.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Shadow the config-side values with the form expected by
        // `decode_modelpack_det` and `dequantize_ndarray` (type inferred
        // from those calls).
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Path 1: low-level function on the quantized data, with the batch
        // (and, for boxes, padding) axes sliced away.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        // The first detection is pinned to known-good reference values.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        // Path 2: the configured decoder on the quantized tensors.
        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        // Path 3: dequantize up front, then run the float decoder.
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Both decoder paths must match path 1; comparing against an empty
        // mask slice also requires that the decoder produced no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
894
895    #[test]
896    fn test_decoder_modelpack_split_u8() {
897        let score_threshold = 0.45;
898        let iou_threshold = 0.45;
899        let detect0 = include_bytes!(concat!(
900            env!("CARGO_MANIFEST_DIR"),
901            "/../../testdata/modelpack_split_9x15x18.bin"
902        ));
903        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
904
905        let detect1 = include_bytes!(concat!(
906            env!("CARGO_MANIFEST_DIR"),
907            "/../../testdata/modelpack_split_17x30x18.bin"
908        ));
909        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
910
911        let quant0 = (0.08547406643629074, 174).into();
912        let quant1 = (0.09929127991199493, 183).into();
913        let anchors0 = vec![
914            [0.36666667461395264, 0.31481480598449707],
915            [0.38749998807907104, 0.4740740656852722],
916            [0.5333333611488342, 0.644444465637207],
917        ];
918        let anchors1 = vec![
919            [0.13750000298023224, 0.2074074000120163],
920            [0.2541666626930237, 0.21481481194496155],
921            [0.23125000298023224, 0.35185185074806213],
922        ];
923
924        let detect_config0 = configs::Detection {
925            decoder: DecoderType::ModelPack,
926            shape: vec![1, 9, 15, 18],
927            anchors: Some(anchors0.clone()),
928            quantization: Some(quant0),
929            dshape: vec![
930                (DimName::Batch, 1),
931                (DimName::Height, 9),
932                (DimName::Width, 15),
933                (DimName::NumAnchorsXFeatures, 18),
934            ],
935            normalized: Some(true),
936        };
937
938        let detect_config1 = configs::Detection {
939            decoder: DecoderType::ModelPack,
940            shape: vec![1, 17, 30, 18],
941            anchors: Some(anchors1.clone()),
942            quantization: Some(quant1),
943            dshape: vec![
944                (DimName::Batch, 1),
945                (DimName::Height, 17),
946                (DimName::Width, 30),
947                (DimName::NumAnchorsXFeatures, 18),
948            ],
949            normalized: Some(true),
950        };
951
952        let config0 = (&detect_config0).try_into().unwrap();
953        let config1 = (&detect_config1).try_into().unwrap();
954
955        let decoder = DecoderBuilder::default()
956            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
957            .with_score_threshold(score_threshold)
958            .with_iou_threshold(iou_threshold)
959            .build()
960            .unwrap();
961
962        let quant0 = quant0.into();
963        let quant1 = quant1.into();
964
965        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
966        decode_modelpack_split_quant(
967            &[
968                detect0.slice(s![0, .., .., ..]),
969                detect1.slice(s![0, .., .., ..]),
970            ],
971            &[config0, config1],
972            score_threshold,
973            iou_threshold,
974            &mut output_boxes,
975        );
976        assert!(output_boxes[0].equal_within_delta(
977            &DetectBox {
978                bbox: BoundingBox {
979                    xmin: 0.43171933,
980                    ymin: 0.68243736,
981                    xmax: 0.5626645,
982                    ymax: 0.808863,
983                },
984                score: 0.99240804,
985                label: 0
986            },
987            1e-6
988        ));
989
990        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
991        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
992        decoder
993            .decode_quantized(
994                &[detect0.view().into(), detect1.view().into()],
995                &mut output_boxes1,
996                &mut output_masks1,
997            )
998            .unwrap();
999
1000        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1001        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1002
1003        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1004        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1005        decoder
1006            .decode_float::<f32>(
1007                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1008                &mut output_boxes1_f32,
1009                &mut output_masks1_f32,
1010            )
1011            .unwrap();
1012
1013        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1014        compare_outputs(
1015            (&output_boxes, &output_boxes1_f32),
1016            (&[], &output_masks1_f32),
1017        );
1018    }
1019
1020    #[test]
1021    fn test_decoder_parse_config_modelpack_split_u8() {
1022        let score_threshold = 0.45;
1023        let iou_threshold = 0.45;
1024        let detect0 = include_bytes!(concat!(
1025            env!("CARGO_MANIFEST_DIR"),
1026            "/../../testdata/modelpack_split_9x15x18.bin"
1027        ));
1028        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1029
1030        let detect1 = include_bytes!(concat!(
1031            env!("CARGO_MANIFEST_DIR"),
1032            "/../../testdata/modelpack_split_17x30x18.bin"
1033        ));
1034        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1035
1036        let decoder = DecoderBuilder::default()
1037            .with_config_yaml_str(
1038                include_str!(concat!(
1039                    env!("CARGO_MANIFEST_DIR"),
1040                    "/../../testdata/modelpack_split.yaml"
1041                ))
1042                .to_string(),
1043            )
1044            .with_score_threshold(score_threshold)
1045            .with_iou_threshold(iou_threshold)
1046            .build()
1047            .unwrap();
1048
1049        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1050        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1051        decoder
1052            .decode_quantized(
1053                &[
1054                    ArrayViewDQuantized::from(detect1.view()),
1055                    ArrayViewDQuantized::from(detect0.view()),
1056                ],
1057                &mut output_boxes,
1058                &mut output_masks,
1059            )
1060            .unwrap();
1061        assert!(output_boxes[0].equal_within_delta(
1062            &DetectBox {
1063                bbox: BoundingBox {
1064                    xmin: 0.43171933,
1065                    ymin: 0.68243736,
1066                    xmax: 0.5626645,
1067                    ymax: 0.808863,
1068                },
1069                score: 0.99240804,
1070                label: 0
1071            },
1072            1e-6
1073        ));
1074    }
1075
1076    #[test]
1077    fn test_modelpack_seg() {
1078        let out = include_bytes!(concat!(
1079            env!("CARGO_MANIFEST_DIR"),
1080            "/../../testdata/modelpack_seg_2x160x160.bin"
1081        ));
1082        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1083        let quant = (1.0 / 255.0, 0).into();
1084
1085        let decoder = DecoderBuilder::default()
1086            .with_config_modelpack_seg(configs::Segmentation {
1087                decoder: DecoderType::ModelPack,
1088                quantization: Some(quant),
1089                shape: vec![1, 2, 160, 160],
1090                dshape: vec![
1091                    (DimName::Batch, 1),
1092                    (DimName::NumClasses, 2),
1093                    (DimName::Height, 160),
1094                    (DimName::Width, 160),
1095                ],
1096            })
1097            .build()
1098            .unwrap();
1099        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1100        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1101        decoder
1102            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1103            .unwrap();
1104
1105        let mut mask = out.slice(s![0, .., .., ..]);
1106        mask.swap_axes(0, 1);
1107        mask.swap_axes(1, 2);
1108        let mask = [Segmentation {
1109            xmin: 0.0,
1110            ymin: 0.0,
1111            xmax: 1.0,
1112            ymax: 1.0,
1113            segmentation: mask.into_owned(),
1114        }];
1115        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1116
1117        decoder
1118            .decode_float::<f32>(
1119                &[dequantize_ndarray(out.view(), quant.into())
1120                    .view()
1121                    .into_dyn()],
1122                &mut output_boxes,
1123                &mut output_masks,
1124            )
1125            .unwrap();
1126
1127        // not expected for float decoder to have same values as quantized decoder, as
1128        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1129        // the model output. Thus the float output is the same as the quantized output
1130        // but scaled differently. However, it is expected that the mask after argmax
1131        // will be the same.
1132        compare_outputs((&[], &output_boxes), (&[], &[]));
1133        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1134        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1135
1136        assert_eq!(mask0, mask1);
1137    }
1138    #[test]
1139    fn test_modelpack_seg_quant() {
1140        let out = include_bytes!(concat!(
1141            env!("CARGO_MANIFEST_DIR"),
1142            "/../../testdata/modelpack_seg_2x160x160.bin"
1143        ));
1144        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1145        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1146        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1147        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1148        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1149        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1150
1151        let quant = (1.0 / 255.0, 0).into();
1152
1153        let decoder = DecoderBuilder::default()
1154            .with_config_modelpack_seg(configs::Segmentation {
1155                decoder: DecoderType::ModelPack,
1156                quantization: Some(quant),
1157                shape: vec![1, 2, 160, 160],
1158                dshape: vec![
1159                    (DimName::Batch, 1),
1160                    (DimName::NumClasses, 2),
1161                    (DimName::Height, 160),
1162                    (DimName::Width, 160),
1163                ],
1164            })
1165            .build()
1166            .unwrap();
1167        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1168        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1169        decoder
1170            .decode_quantized(
1171                &[out_u8.view().into()],
1172                &mut output_boxes,
1173                &mut output_masks_u8,
1174            )
1175            .unwrap();
1176
1177        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1178        decoder
1179            .decode_quantized(
1180                &[out_i8.view().into()],
1181                &mut output_boxes,
1182                &mut output_masks_i8,
1183            )
1184            .unwrap();
1185
1186        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1187        decoder
1188            .decode_quantized(
1189                &[out_u16.view().into()],
1190                &mut output_boxes,
1191                &mut output_masks_u16,
1192            )
1193            .unwrap();
1194
1195        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1196        decoder
1197            .decode_quantized(
1198                &[out_i16.view().into()],
1199                &mut output_boxes,
1200                &mut output_masks_i16,
1201            )
1202            .unwrap();
1203
1204        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1205        decoder
1206            .decode_quantized(
1207                &[out_u32.view().into()],
1208                &mut output_boxes,
1209                &mut output_masks_u32,
1210            )
1211            .unwrap();
1212
1213        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1214        decoder
1215            .decode_quantized(
1216                &[out_i32.view().into()],
1217                &mut output_boxes,
1218                &mut output_masks_i32,
1219            )
1220            .unwrap();
1221
1222        compare_outputs((&[], &output_boxes), (&[], &[]));
1223        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1224        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1225        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1226        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1227        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1228        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1229        assert_eq!(mask_u8, mask_i8);
1230        assert_eq!(mask_u8, mask_u16);
1231        assert_eq!(mask_u8, mask_i16);
1232        assert_eq!(mask_u8, mask_u32);
1233        assert_eq!(mask_u8, mask_i32);
1234    }
1235
1236    #[test]
1237    fn test_modelpack_segdet() {
1238        let score_threshold = 0.45;
1239        let iou_threshold = 0.45;
1240
1241        let boxes = include_bytes!(concat!(
1242            env!("CARGO_MANIFEST_DIR"),
1243            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1244        ));
1245        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1246
1247        let scores = include_bytes!(concat!(
1248            env!("CARGO_MANIFEST_DIR"),
1249            "/../../testdata/modelpack_scores_1935x1.bin"
1250        ));
1251        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1252
1253        let seg = include_bytes!(concat!(
1254            env!("CARGO_MANIFEST_DIR"),
1255            "/../../testdata/modelpack_seg_2x160x160.bin"
1256        ));
1257        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1258
1259        let quant_boxes = (0.004656755365431309, 21).into();
1260        let quant_scores = (0.0019603664986789227, 0).into();
1261        let quant_seg = (1.0 / 255.0, 0).into();
1262
1263        let decoder = DecoderBuilder::default()
1264            .with_config_modelpack_segdet(
1265                configs::Boxes {
1266                    decoder: DecoderType::ModelPack,
1267                    quantization: Some(quant_boxes),
1268                    shape: vec![1, 1935, 1, 4],
1269                    dshape: vec![
1270                        (DimName::Batch, 1),
1271                        (DimName::NumBoxes, 1935),
1272                        (DimName::Padding, 1),
1273                        (DimName::BoxCoords, 4),
1274                    ],
1275                    normalized: Some(true),
1276                },
1277                configs::Scores {
1278                    decoder: DecoderType::ModelPack,
1279                    quantization: Some(quant_scores),
1280                    shape: vec![1, 1935, 1],
1281                    dshape: vec![
1282                        (DimName::Batch, 1),
1283                        (DimName::NumBoxes, 1935),
1284                        (DimName::NumClasses, 1),
1285                    ],
1286                },
1287                configs::Segmentation {
1288                    decoder: DecoderType::ModelPack,
1289                    quantization: Some(quant_seg),
1290                    shape: vec![1, 2, 160, 160],
1291                    dshape: vec![
1292                        (DimName::Batch, 1),
1293                        (DimName::NumClasses, 2),
1294                        (DimName::Height, 160),
1295                        (DimName::Width, 160),
1296                    ],
1297                },
1298            )
1299            .with_iou_threshold(iou_threshold)
1300            .with_score_threshold(score_threshold)
1301            .build()
1302            .unwrap();
1303        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1304        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1305        decoder
1306            .decode_quantized(
1307                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1308                &mut output_boxes,
1309                &mut output_masks,
1310            )
1311            .unwrap();
1312
1313        let mut mask = seg.slice(s![0, .., .., ..]);
1314        mask.swap_axes(0, 1);
1315        mask.swap_axes(1, 2);
1316        let mask = [Segmentation {
1317            xmin: 0.0,
1318            ymin: 0.0,
1319            xmax: 1.0,
1320            ymax: 1.0,
1321            segmentation: mask.into_owned(),
1322        }];
1323        let correct_boxes = [DetectBox {
1324            bbox: BoundingBox {
1325                xmin: 0.40513772,
1326                ymin: 0.6379755,
1327                xmax: 0.5122431,
1328                ymax: 0.7730214,
1329            },
1330            score: 0.4861709,
1331            label: 0,
1332        }];
1333        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1334
1335        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1336        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1337        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1338        decoder
1339            .decode_float::<f32>(
1340                &[
1341                    scores.view().into_dyn(),
1342                    boxes.view().into_dyn(),
1343                    seg.view().into_dyn(),
1344                ],
1345                &mut output_boxes,
1346                &mut output_masks,
1347            )
1348            .unwrap();
1349
1350        // not expected for float segmentation decoder to have same values as quantized
1351        // segmentation decoder, as float decoder ensures the data fills 0-255,
1352        // quantized decoder uses whatever the model output. Thus the float
1353        // output is the same as the quantized output but scaled differently.
1354        // However, it is expected that the mask after argmax will be the same.
1355        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1356        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1357        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1358
1359        assert_eq!(mask0, mask1);
1360    }
1361
1362    #[test]
1363    fn test_modelpack_segdet_split() {
1364        let score_threshold = 0.8;
1365        let iou_threshold = 0.5;
1366
1367        let seg = include_bytes!(concat!(
1368            env!("CARGO_MANIFEST_DIR"),
1369            "/../../testdata/modelpack_seg_2x160x160.bin"
1370        ));
1371        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1372
1373        let detect0 = include_bytes!(concat!(
1374            env!("CARGO_MANIFEST_DIR"),
1375            "/../../testdata/modelpack_split_9x15x18.bin"
1376        ));
1377        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1378
1379        let detect1 = include_bytes!(concat!(
1380            env!("CARGO_MANIFEST_DIR"),
1381            "/../../testdata/modelpack_split_17x30x18.bin"
1382        ));
1383        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1384
1385        let quant0 = (0.08547406643629074, 174).into();
1386        let quant1 = (0.09929127991199493, 183).into();
1387        let quant_seg = (1.0 / 255.0, 0).into();
1388
1389        let anchors0 = vec![
1390            [0.36666667461395264, 0.31481480598449707],
1391            [0.38749998807907104, 0.4740740656852722],
1392            [0.5333333611488342, 0.644444465637207],
1393        ];
1394        let anchors1 = vec![
1395            [0.13750000298023224, 0.2074074000120163],
1396            [0.2541666626930237, 0.21481481194496155],
1397            [0.23125000298023224, 0.35185185074806213],
1398        ];
1399
1400        let decoder = DecoderBuilder::default()
1401            .with_config_modelpack_segdet_split(
1402                vec![
1403                    configs::Detection {
1404                        decoder: DecoderType::ModelPack,
1405                        shape: vec![1, 17, 30, 18],
1406                        anchors: Some(anchors1),
1407                        quantization: Some(quant1),
1408                        dshape: vec![
1409                            (DimName::Batch, 1),
1410                            (DimName::Height, 17),
1411                            (DimName::Width, 30),
1412                            (DimName::NumAnchorsXFeatures, 18),
1413                        ],
1414                        normalized: Some(true),
1415                    },
1416                    configs::Detection {
1417                        decoder: DecoderType::ModelPack,
1418                        shape: vec![1, 9, 15, 18],
1419                        anchors: Some(anchors0),
1420                        quantization: Some(quant0),
1421                        dshape: vec![
1422                            (DimName::Batch, 1),
1423                            (DimName::Height, 9),
1424                            (DimName::Width, 15),
1425                            (DimName::NumAnchorsXFeatures, 18),
1426                        ],
1427                        normalized: Some(true),
1428                    },
1429                ],
1430                configs::Segmentation {
1431                    decoder: DecoderType::ModelPack,
1432                    quantization: Some(quant_seg),
1433                    shape: vec![1, 2, 160, 160],
1434                    dshape: vec![
1435                        (DimName::Batch, 1),
1436                        (DimName::NumClasses, 2),
1437                        (DimName::Height, 160),
1438                        (DimName::Width, 160),
1439                    ],
1440                },
1441            )
1442            .with_score_threshold(score_threshold)
1443            .with_iou_threshold(iou_threshold)
1444            .build()
1445            .unwrap();
1446        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1447        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1448        decoder
1449            .decode_quantized(
1450                &[
1451                    detect0.view().into(),
1452                    detect1.view().into(),
1453                    seg.view().into(),
1454                ],
1455                &mut output_boxes,
1456                &mut output_masks,
1457            )
1458            .unwrap();
1459
1460        let mut mask = seg.slice(s![0, .., .., ..]);
1461        mask.swap_axes(0, 1);
1462        mask.swap_axes(1, 2);
1463        let mask = [Segmentation {
1464            xmin: 0.0,
1465            ymin: 0.0,
1466            xmax: 1.0,
1467            ymax: 1.0,
1468            segmentation: mask.into_owned(),
1469        }];
1470        let correct_boxes = [DetectBox {
1471            bbox: BoundingBox {
1472                xmin: 0.43171933,
1473                ymin: 0.68243736,
1474                xmax: 0.5626645,
1475                ymax: 0.808863,
1476            },
1477            score: 0.99240804,
1478            label: 0,
1479        }];
1480        println!("Output Boxes: {:?}", output_boxes);
1481        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1482
1483        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1484        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1485        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1486        decoder
1487            .decode_float::<f32>(
1488                &[
1489                    detect0.view().into_dyn(),
1490                    detect1.view().into_dyn(),
1491                    seg.view().into_dyn(),
1492                ],
1493                &mut output_boxes,
1494                &mut output_masks,
1495            )
1496            .unwrap();
1497
1498        // not expected for float segmentation decoder to have same values as quantized
1499        // segmentation decoder, as float decoder ensures the data fills 0-255,
1500        // quantized decoder uses whatever the model output. Thus the float
1501        // output is the same as the quantized output but scaled differently.
1502        // However, it is expected that the mask after argmax will be the same.
1503        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1504        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1505        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1506
1507        assert_eq!(mask0, mask1);
1508    }
1509
1510    #[test]
1511    fn test_dequant_chunked() {
1512        let out = include_bytes!(concat!(
1513            env!("CARGO_MANIFEST_DIR"),
1514            "/../../testdata/yolov8s_80_classes.bin"
1515        ));
1516        let mut out =
1517            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1518        out.push(123); // make sure to test non multiple of 16 length
1519
1520        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1521        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1522        let quant = Quantization::new(0.0040811873, -123);
1523        dequantize_cpu(&out, quant, &mut out_dequant);
1524
1525        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1526        assert_eq!(out_dequant, out_dequant_simd);
1527
1528        let quant = Quantization::new(0.0040811873, 0);
1529        dequantize_cpu(&out, quant, &mut out_dequant);
1530
1531        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1532        assert_eq!(out_dequant, out_dequant_simd);
1533    }
1534
1535    #[test]
1536    fn test_dequant_ground_truth() {
1537        // Formula: output = (input - zero_point) * scale
1538        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1539
1540        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1541        let quant = Quantization::new(0.1, -128);
1542        let input: Vec<i8> = vec![0, 127, -128, 64];
1543        let mut output = vec![0.0f32; 4];
1544        let mut output_chunked = vec![0.0f32; 4];
1545        dequantize_cpu(&input, quant, &mut output);
1546        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1547        // (0 - (-128)) * 0.1 = 12.8
1548        // (127 - (-128)) * 0.1 = 25.5
1549        // (-128 - (-128)) * 0.1 = 0.0
1550        // (64 - (-128)) * 0.1 = 19.2
1551        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1552        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1553            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1554        }
1555        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1556            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1557        }
1558
1559        // Case 2: scale=1.0, zero_point=0 (identity-like)
1560        let quant = Quantization::new(1.0, 0);
1561        dequantize_cpu(&input, quant, &mut output);
1562        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1563        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1564        assert_eq!(output, expected);
1565        assert_eq!(output_chunked, expected);
1566
1567        // Case 3: scale=0.5, zero_point=0
1568        let quant = Quantization::new(0.5, 0);
1569        dequantize_cpu(&input, quant, &mut output);
1570        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1571        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1572        assert_eq!(output, expected);
1573        assert_eq!(output_chunked, expected);
1574
1575        // Case 4: i8 min/max boundaries with typical quantization params
1576        let quant = Quantization::new(0.021287762, 31);
1577        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1578        let mut output = vec![0.0f32; 6];
1579        let mut output_chunked = vec![0.0f32; 6];
1580        dequantize_cpu(&input, quant, &mut output);
1581        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1582        for i in 0..6 {
1583            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1584            assert!(
1585                (output[i] - expected).abs() < 1e-5,
1586                "cpu[{i}]: {} != {expected}",
1587                output[i]
1588            );
1589            assert!(
1590                (output_chunked[i] - expected).abs() < 1e-5,
1591                "chunked[{i}]: {} != {expected}",
1592                output_chunked[i]
1593            );
1594        }
1595    }
1596
1597    #[test]
1598    fn test_decoder_yolo_det() {
1599        let score_threshold = 0.25;
1600        let iou_threshold = 0.7;
1601        let out = include_bytes!(concat!(
1602            env!("CARGO_MANIFEST_DIR"),
1603            "/../../testdata/yolov8s_80_classes.bin"
1604        ));
1605        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1606        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1607        let quant = (0.0040811873, -123).into();
1608
1609        let decoder = DecoderBuilder::default()
1610            .with_config_yolo_det(
1611                configs::Detection {
1612                    decoder: DecoderType::Ultralytics,
1613                    shape: vec![1, 84, 8400],
1614                    anchors: None,
1615                    quantization: Some(quant),
1616                    dshape: vec![
1617                        (DimName::Batch, 1),
1618                        (DimName::NumFeatures, 84),
1619                        (DimName::NumBoxes, 8400),
1620                    ],
1621                    normalized: Some(true),
1622                },
1623                Some(DecoderVersion::Yolo11),
1624            )
1625            .with_score_threshold(score_threshold)
1626            .with_iou_threshold(iou_threshold)
1627            .build()
1628            .unwrap();
1629
1630        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1631        decode_yolo_det(
1632            (out.slice(s![0, .., ..]), quant.into()),
1633            score_threshold,
1634            iou_threshold,
1635            Some(configs::Nms::ClassAgnostic),
1636            &mut output_boxes,
1637        );
1638        assert!(output_boxes[0].equal_within_delta(
1639            &DetectBox {
1640                bbox: BoundingBox {
1641                    xmin: 0.5285137,
1642                    ymin: 0.05305544,
1643                    xmax: 0.87541467,
1644                    ymax: 0.9998909,
1645                },
1646                score: 0.5591227,
1647                label: 0
1648            },
1649            1e-6
1650        ));
1651
1652        assert!(output_boxes[1].equal_within_delta(
1653            &DetectBox {
1654                bbox: BoundingBox {
1655                    xmin: 0.130598,
1656                    ymin: 0.43260583,
1657                    xmax: 0.35098213,
1658                    ymax: 0.9958097,
1659                },
1660                score: 0.33057618,
1661                label: 75
1662            },
1663            1e-6
1664        ));
1665
1666        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1667        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1668        decoder
1669            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1670            .unwrap();
1671
1672        let out = dequantize_ndarray(out.view(), quant.into());
1673        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1674        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1675        decoder
1676            .decode_float::<f32>(
1677                &[out.view().into_dyn()],
1678                &mut output_boxes_f32,
1679                &mut output_masks_f32,
1680            )
1681            .unwrap();
1682
1683        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1684        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1685    }
1686
1687    #[test]
1688    fn test_decoder_masks() {
1689        let score_threshold = 0.45;
1690        let iou_threshold = 0.45;
1691        let boxes = include_bytes!(concat!(
1692            env!("CARGO_MANIFEST_DIR"),
1693            "/../../testdata/yolov8_boxes_116x8400.bin"
1694        ));
1695        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1696        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1697        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1698
1699        let protos = include_bytes!(concat!(
1700            env!("CARGO_MANIFEST_DIR"),
1701            "/../../testdata/yolov8_protos_160x160x32.bin"
1702        ));
1703        let protos =
1704            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1705        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1706        let quant_protos = Quantization::new(0.02491161972284317, -117);
1707        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1708        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1709        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1710        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1711        decode_yolo_segdet_float(
1712            seg.view(),
1713            protos.view(),
1714            score_threshold,
1715            iou_threshold,
1716            Some(configs::Nms::ClassAgnostic),
1717            &mut output_boxes,
1718            &mut output_masks,
1719        )
1720        .unwrap();
1721        assert_eq!(output_boxes.len(), 2);
1722        assert_eq!(output_boxes.len(), output_masks.len());
1723
1724        for (b, m) in output_boxes.iter().zip(&output_masks) {
1725            assert!(b.bbox.xmin >= m.xmin);
1726            assert!(b.bbox.ymin >= m.ymin);
1727            assert!(b.bbox.xmax >= m.xmax);
1728            assert!(b.bbox.ymax >= m.ymax);
1729        }
1730        assert!(output_boxes[0].equal_within_delta(
1731            &DetectBox {
1732                bbox: BoundingBox {
1733                    xmin: 0.08515105,
1734                    ymin: 0.7131401,
1735                    xmax: 0.29802868,
1736                    ymax: 0.8195788,
1737                },
1738                score: 0.91537374,
1739                label: 23
1740            },
1741            1.0 / 160.0, // wider range because mask will expand the box
1742        ));
1743
1744        assert!(output_boxes[1].equal_within_delta(
1745            &DetectBox {
1746                bbox: BoundingBox {
1747                    xmin: 0.59605736,
1748                    ymin: 0.25545314,
1749                    xmax: 0.93666154,
1750                    ymax: 0.72378385,
1751                },
1752                score: 0.91537374,
1753                label: 23
1754            },
1755            1.0 / 160.0, // wider range because mask will expand the box
1756        ));
1757
1758        let full_mask = include_bytes!(concat!(
1759            env!("CARGO_MANIFEST_DIR"),
1760            "/../../testdata/yolov8_mask_results.bin"
1761        ));
1762        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1763
1764        let cropped_mask = full_mask.slice(ndarray::s![
1765            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1766            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1767        ]);
1768
1769        assert_eq!(
1770            cropped_mask,
1771            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1772        );
1773    }
1774
1775    /// Regression test: config-driven path with NCHW protos (no dshape).
1776    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1777    /// and the YAML config has no dshape field — the exact scenario from
1778    /// hal_mask_matmul_bug.md.
1779    #[test]
1780    fn test_decoder_masks_nchw_protos() {
1781        let score_threshold = 0.45;
1782        let iou_threshold = 0.45;
1783
1784        // Load test data — boxes as [116, 8400]
1785        let boxes_raw = include_bytes!(concat!(
1786            env!("CARGO_MANIFEST_DIR"),
1787            "/../../testdata/yolov8_boxes_116x8400.bin"
1788        ));
1789        let boxes_raw =
1790            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1791        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1792        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1793
1794        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1795        let protos_raw = include_bytes!(concat!(
1796            env!("CARGO_MANIFEST_DIR"),
1797            "/../../testdata/yolov8_protos_160x160x32.bin"
1798        ));
1799        let protos_raw = unsafe {
1800            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1801        };
1802        let protos_hwc =
1803            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1804        let quant_protos = Quantization::new(0.02491161972284317, -117);
1805        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1806
1807        // ---- Reference: direct call with HWC protos (known working) ----
1808        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1809        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1810        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1811        decode_yolo_segdet_float(
1812            seg.view(),
1813            protos_f32_hwc.view(),
1814            score_threshold,
1815            iou_threshold,
1816            Some(configs::Nms::ClassAgnostic),
1817            &mut ref_boxes,
1818            &mut ref_masks,
1819        )
1820        .unwrap();
1821        assert_eq!(ref_boxes.len(), 2);
1822
1823        // ---- Config-driven path: NCHW protos, no dshape ----
1824        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1825        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1826        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1827
1828        // Build boxes as [1, 116, 8400] f32
1829        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1830
1831        // Build decoder from config with no dshape on protos
1832        let decoder = DecoderBuilder::default()
1833            .with_config_yolo_segdet(
1834                configs::Detection {
1835                    decoder: configs::DecoderType::Ultralytics,
1836                    quantization: None,
1837                    shape: vec![1, 116, 8400],
1838                    dshape: vec![],
1839                    normalized: Some(true),
1840                    anchors: None,
1841                },
1842                configs::Protos {
1843                    decoder: configs::DecoderType::Ultralytics,
1844                    quantization: None,
1845                    shape: vec![1, 32, 160, 160],
1846                    dshape: vec![], // No dshape — simulates YAML without dshape
1847                },
1848                None, // decoder version
1849            )
1850            .with_score_threshold(score_threshold)
1851            .with_iou_threshold(iou_threshold)
1852            .build()
1853            .unwrap();
1854
1855        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1856        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1857        decoder
1858            .decode_float(
1859                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1860                &mut cfg_boxes,
1861                &mut cfg_masks,
1862            )
1863            .unwrap();
1864
1865        // Must produce the same number of detections
1866        assert_eq!(
1867            cfg_boxes.len(),
1868            ref_boxes.len(),
1869            "config path produced {} boxes, reference produced {}",
1870            cfg_boxes.len(),
1871            ref_boxes.len()
1872        );
1873
1874        // Boxes must match
1875        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1876            assert!(
1877                cb.equal_within_delta(rb, 0.01),
1878                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1879            );
1880        }
1881
1882        // Masks must match pixel-for-pixel
1883        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1884            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1885            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1886            assert_eq!(
1887                cm_arr, rm_arr,
1888                "mask {i} pixel mismatch between config-driven and reference paths"
1889            );
1890        }
1891    }
1892
1893    #[test]
1894    fn test_decoder_masks_i8() {
1895        let score_threshold = 0.45;
1896        let iou_threshold = 0.45;
1897        let boxes = include_bytes!(concat!(
1898            env!("CARGO_MANIFEST_DIR"),
1899            "/../../testdata/yolov8_boxes_116x8400.bin"
1900        ));
1901        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1902        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1903        let quant_boxes = (0.021287761628627777, 31).into();
1904
1905        let protos = include_bytes!(concat!(
1906            env!("CARGO_MANIFEST_DIR"),
1907            "/../../testdata/yolov8_protos_160x160x32.bin"
1908        ));
1909        let protos =
1910            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1911        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1912        let quant_protos = (0.02491161972284317, -117).into();
1913        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1914        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1915
1916        let decoder = DecoderBuilder::default()
1917            .with_config_yolo_segdet(
1918                configs::Detection {
1919                    decoder: configs::DecoderType::Ultralytics,
1920                    quantization: Some(quant_boxes),
1921                    shape: vec![1, 116, 8400],
1922                    anchors: None,
1923                    dshape: vec![
1924                        (DimName::Batch, 1),
1925                        (DimName::NumFeatures, 116),
1926                        (DimName::NumBoxes, 8400),
1927                    ],
1928                    normalized: Some(true),
1929                },
1930                Protos {
1931                    decoder: configs::DecoderType::Ultralytics,
1932                    quantization: Some(quant_protos),
1933                    shape: vec![1, 160, 160, 32],
1934                    dshape: vec![
1935                        (DimName::Batch, 1),
1936                        (DimName::Height, 160),
1937                        (DimName::Width, 160),
1938                        (DimName::NumProtos, 32),
1939                    ],
1940                },
1941                Some(DecoderVersion::Yolo11),
1942            )
1943            .with_score_threshold(score_threshold)
1944            .with_iou_threshold(iou_threshold)
1945            .build()
1946            .unwrap();
1947
1948        let quant_boxes = quant_boxes.into();
1949        let quant_protos = quant_protos.into();
1950
1951        decode_yolo_segdet_quant(
1952            (boxes.slice(s![0, .., ..]), quant_boxes),
1953            (protos.slice(s![0, .., .., ..]), quant_protos),
1954            score_threshold,
1955            iou_threshold,
1956            Some(configs::Nms::ClassAgnostic),
1957            &mut output_boxes,
1958            &mut output_masks,
1959        )
1960        .unwrap();
1961
1962        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1963        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1964
1965        decoder
1966            .decode_quantized(
1967                &[boxes.view().into(), protos.view().into()],
1968                &mut output_boxes1,
1969                &mut output_masks1,
1970            )
1971            .unwrap();
1972
1973        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1974        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1975
1976        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1977        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1978        decode_yolo_segdet_float(
1979            seg.slice(s![0, .., ..]),
1980            protos.slice(s![0, .., .., ..]),
1981            score_threshold,
1982            iou_threshold,
1983            Some(configs::Nms::ClassAgnostic),
1984            &mut output_boxes_f32,
1985            &mut output_masks_f32,
1986        )
1987        .unwrap();
1988
1989        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1990        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1991
1992        decoder
1993            .decode_float(
1994                &[seg.view().into_dyn(), protos.view().into_dyn()],
1995                &mut output_boxes1_f32,
1996                &mut output_masks1_f32,
1997            )
1998            .unwrap();
1999
2000        compare_outputs(
2001            (&output_boxes, &output_boxes1),
2002            (&output_masks, &output_masks1),
2003        );
2004
2005        compare_outputs(
2006            (&output_boxes, &output_boxes_f32),
2007            (&output_masks, &output_masks_f32),
2008        );
2009
2010        compare_outputs(
2011            (&output_boxes_f32, &output_boxes1_f32),
2012            (&output_masks_f32, &output_masks1_f32),
2013        );
2014    }
2015
2016    #[test]
2017    fn test_decoder_yolo_split() {
2018        let score_threshold = 0.45;
2019        let iou_threshold = 0.45;
2020        let boxes = include_bytes!(concat!(
2021            env!("CARGO_MANIFEST_DIR"),
2022            "/../../testdata/yolov8_boxes_116x8400.bin"
2023        ));
2024        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2025        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2026        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2027
2028        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2029
2030        let decoder = DecoderBuilder::default()
2031            .with_config_yolo_split_det(
2032                configs::Boxes {
2033                    decoder: configs::DecoderType::Ultralytics,
2034                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2035                    shape: vec![1, 4, 8400],
2036                    dshape: vec![
2037                        (DimName::Batch, 1),
2038                        (DimName::BoxCoords, 4),
2039                        (DimName::NumBoxes, 8400),
2040                    ],
2041                    normalized: Some(true),
2042                },
2043                configs::Scores {
2044                    decoder: configs::DecoderType::Ultralytics,
2045                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2046                    shape: vec![1, 80, 8400],
2047                    dshape: vec![
2048                        (DimName::Batch, 1),
2049                        (DimName::NumClasses, 80),
2050                        (DimName::NumBoxes, 8400),
2051                    ],
2052                },
2053            )
2054            .with_score_threshold(score_threshold)
2055            .with_iou_threshold(iou_threshold)
2056            .build()
2057            .unwrap();
2058
2059        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2060        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2061
2062        decoder
2063            .decode_quantized(
2064                &[
2065                    boxes.slice(s![.., ..4, ..]).into(),
2066                    boxes.slice(s![.., 4..84, ..]).into(),
2067                ],
2068                &mut output_boxes,
2069                &mut output_masks,
2070            )
2071            .unwrap();
2072
2073        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2074        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2075        decode_yolo_det_float(
2076            seg.slice(s![0, ..84, ..]),
2077            score_threshold,
2078            iou_threshold,
2079            Some(configs::Nms::ClassAgnostic),
2080            &mut output_boxes_f32,
2081        );
2082
2083        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2084        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2085
2086        decoder
2087            .decode_float(
2088                &[
2089                    seg.slice(s![.., ..4, ..]).into_dyn(),
2090                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2091                ],
2092                &mut output_boxes1,
2093                &mut output_masks1,
2094            )
2095            .unwrap();
2096        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2097        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2098    }
2099
2100    #[test]
2101    fn test_decoder_masks_config_mixed() {
2102        let score_threshold = 0.45;
2103        let iou_threshold = 0.45;
2104        let boxes = include_bytes!(concat!(
2105            env!("CARGO_MANIFEST_DIR"),
2106            "/../../testdata/yolov8_boxes_116x8400.bin"
2107        ));
2108        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2109        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2110        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2111
2112        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2113
2114        let protos = include_bytes!(concat!(
2115            env!("CARGO_MANIFEST_DIR"),
2116            "/../../testdata/yolov8_protos_160x160x32.bin"
2117        ));
2118        let protos =
2119            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2120        let protos: Vec<_> = protos.to_vec();
2121        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2122        let quant_protos = Quantization::new(0.02491161972284317, -117);
2123
2124        let decoder = DecoderBuilder::default()
2125            .with_config_yolo_split_segdet(
2126                configs::Boxes {
2127                    decoder: configs::DecoderType::Ultralytics,
2128                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2129                    shape: vec![1, 4, 8400],
2130                    dshape: vec![
2131                        (DimName::Batch, 1),
2132                        (DimName::BoxCoords, 4),
2133                        (DimName::NumBoxes, 8400),
2134                    ],
2135                    normalized: Some(true),
2136                },
2137                configs::Scores {
2138                    decoder: configs::DecoderType::Ultralytics,
2139                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2140                    shape: vec![1, 80, 8400],
2141                    dshape: vec![
2142                        (DimName::Batch, 1),
2143                        (DimName::NumClasses, 80),
2144                        (DimName::NumBoxes, 8400),
2145                    ],
2146                },
2147                configs::MaskCoefficients {
2148                    decoder: configs::DecoderType::Ultralytics,
2149                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2150                    shape: vec![1, 32, 8400],
2151                    dshape: vec![
2152                        (DimName::Batch, 1),
2153                        (DimName::NumProtos, 32),
2154                        (DimName::NumBoxes, 8400),
2155                    ],
2156                },
2157                configs::Protos {
2158                    decoder: configs::DecoderType::Ultralytics,
2159                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2160                    shape: vec![1, 160, 160, 32],
2161                    dshape: vec![
2162                        (DimName::Batch, 1),
2163                        (DimName::Height, 160),
2164                        (DimName::Width, 160),
2165                        (DimName::NumProtos, 32),
2166                    ],
2167                },
2168            )
2169            .with_score_threshold(score_threshold)
2170            .with_iou_threshold(iou_threshold)
2171            .build()
2172            .unwrap();
2173
2174        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2175        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2176
2177        decoder
2178            .decode_quantized(
2179                &[
2180                    boxes.slice(s![.., ..4, ..]).into(),
2181                    boxes.slice(s![.., 4..84, ..]).into(),
2182                    boxes.slice(s![.., 84.., ..]).into(),
2183                    protos.view().into(),
2184                ],
2185                &mut output_boxes,
2186                &mut output_masks,
2187            )
2188            .unwrap();
2189
2190        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2191        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2192        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2193        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2194        decode_yolo_segdet_float(
2195            seg.slice(s![0, .., ..]),
2196            protos.slice(s![0, .., .., ..]),
2197            score_threshold,
2198            iou_threshold,
2199            Some(configs::Nms::ClassAgnostic),
2200            &mut output_boxes_f32,
2201            &mut output_masks_f32,
2202        )
2203        .unwrap();
2204
2205        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2206        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2207
2208        decoder
2209            .decode_float(
2210                &[
2211                    seg.slice(s![.., ..4, ..]).into_dyn(),
2212                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2213                    seg.slice(s![.., 84.., ..]).into_dyn(),
2214                    protos.view().into_dyn(),
2215                ],
2216                &mut output_boxes1,
2217                &mut output_masks1,
2218            )
2219            .unwrap();
2220        compare_outputs(
2221            (&output_boxes, &output_boxes_f32),
2222            (&output_masks, &output_masks_f32),
2223        );
2224        compare_outputs(
2225            (&output_boxes_f32, &output_boxes1),
2226            (&output_masks_f32, &output_masks1),
2227        );
2228    }
2229
2230    #[test]
2231    fn test_decoder_masks_config_i32() {
2232        let score_threshold = 0.45;
2233        let iou_threshold = 0.45;
2234        let boxes = include_bytes!(concat!(
2235            env!("CARGO_MANIFEST_DIR"),
2236            "/../../testdata/yolov8_boxes_116x8400.bin"
2237        ));
2238        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2239        let scale = 1 << 23;
2240        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2241        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2242
2243        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2244
2245        let protos = include_bytes!(concat!(
2246            env!("CARGO_MANIFEST_DIR"),
2247            "/../../testdata/yolov8_protos_160x160x32.bin"
2248        ));
2249        let protos =
2250            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2251        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2252        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2253        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2254
2255        let decoder = DecoderBuilder::default()
2256            .with_config_yolo_split_segdet(
2257                configs::Boxes {
2258                    decoder: configs::DecoderType::Ultralytics,
2259                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2260                    shape: vec![1, 4, 8400],
2261                    dshape: vec![
2262                        (DimName::Batch, 1),
2263                        (DimName::BoxCoords, 4),
2264                        (DimName::NumBoxes, 8400),
2265                    ],
2266                    normalized: Some(true),
2267                },
2268                configs::Scores {
2269                    decoder: configs::DecoderType::Ultralytics,
2270                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2271                    shape: vec![1, 80, 8400],
2272                    dshape: vec![
2273                        (DimName::Batch, 1),
2274                        (DimName::NumClasses, 80),
2275                        (DimName::NumBoxes, 8400),
2276                    ],
2277                },
2278                configs::MaskCoefficients {
2279                    decoder: configs::DecoderType::Ultralytics,
2280                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2281                    shape: vec![1, 32, 8400],
2282                    dshape: vec![
2283                        (DimName::Batch, 1),
2284                        (DimName::NumProtos, 32),
2285                        (DimName::NumBoxes, 8400),
2286                    ],
2287                },
2288                configs::Protos {
2289                    decoder: configs::DecoderType::Ultralytics,
2290                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2291                    shape: vec![1, 160, 160, 32],
2292                    dshape: vec![
2293                        (DimName::Batch, 1),
2294                        (DimName::Height, 160),
2295                        (DimName::Width, 160),
2296                        (DimName::NumProtos, 32),
2297                    ],
2298                },
2299            )
2300            .with_score_threshold(score_threshold)
2301            .with_iou_threshold(iou_threshold)
2302            .build()
2303            .unwrap();
2304
2305        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2306        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2307
2308        decoder
2309            .decode_quantized(
2310                &[
2311                    boxes.slice(s![.., ..4, ..]).into(),
2312                    boxes.slice(s![.., 4..84, ..]).into(),
2313                    boxes.slice(s![.., 84.., ..]).into(),
2314                    protos.view().into(),
2315                ],
2316                &mut output_boxes,
2317                &mut output_masks,
2318            )
2319            .unwrap();
2320
2321        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2322        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2323        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2324        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2325        decode_yolo_segdet_float(
2326            seg.slice(s![0, .., ..]),
2327            protos.slice(s![0, .., .., ..]),
2328            score_threshold,
2329            iou_threshold,
2330            Some(configs::Nms::ClassAgnostic),
2331            &mut output_boxes_f32,
2332            &mut output_masks_f32,
2333        )
2334        .unwrap();
2335
2336        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2337        assert_eq!(output_masks.len(), output_masks_f32.len());
2338
2339        compare_outputs(
2340            (&output_boxes, &output_boxes_f32),
2341            (&output_masks, &output_masks_f32),
2342        );
2343    }
2344
2345    /// test running multiple decoders concurrently
2346    #[test]
2347    fn test_context_switch() {
2348        let yolo_det = || {
2349            let score_threshold = 0.25;
2350            let iou_threshold = 0.7;
2351            let out = include_bytes!(concat!(
2352                env!("CARGO_MANIFEST_DIR"),
2353                "/../../testdata/yolov8s_80_classes.bin"
2354            ));
2355            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2356            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2357            let quant = (0.0040811873, -123).into();
2358
2359            let decoder = DecoderBuilder::default()
2360                .with_config_yolo_det(
2361                    configs::Detection {
2362                        decoder: DecoderType::Ultralytics,
2363                        shape: vec![1, 84, 8400],
2364                        anchors: None,
2365                        quantization: Some(quant),
2366                        dshape: vec![
2367                            (DimName::Batch, 1),
2368                            (DimName::NumFeatures, 84),
2369                            (DimName::NumBoxes, 8400),
2370                        ],
2371                        normalized: None,
2372                    },
2373                    None,
2374                )
2375                .with_score_threshold(score_threshold)
2376                .with_iou_threshold(iou_threshold)
2377                .build()
2378                .unwrap();
2379
2380            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2381            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2382
2383            for _ in 0..100 {
2384                decoder
2385                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2386                    .unwrap();
2387
2388                assert!(output_boxes[0].equal_within_delta(
2389                    &DetectBox {
2390                        bbox: BoundingBox {
2391                            xmin: 0.5285137,
2392                            ymin: 0.05305544,
2393                            xmax: 0.87541467,
2394                            ymax: 0.9998909,
2395                        },
2396                        score: 0.5591227,
2397                        label: 0
2398                    },
2399                    1e-6
2400                ));
2401
2402                assert!(output_boxes[1].equal_within_delta(
2403                    &DetectBox {
2404                        bbox: BoundingBox {
2405                            xmin: 0.130598,
2406                            ymin: 0.43260583,
2407                            xmax: 0.35098213,
2408                            ymax: 0.9958097,
2409                        },
2410                        score: 0.33057618,
2411                        label: 75
2412                    },
2413                    1e-6
2414                ));
2415                assert!(output_masks.is_empty());
2416            }
2417        };
2418
2419        let modelpack_det_split = || {
2420            let score_threshold = 0.8;
2421            let iou_threshold = 0.5;
2422
2423            let seg = include_bytes!(concat!(
2424                env!("CARGO_MANIFEST_DIR"),
2425                "/../../testdata/modelpack_seg_2x160x160.bin"
2426            ));
2427            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2428
2429            let detect0 = include_bytes!(concat!(
2430                env!("CARGO_MANIFEST_DIR"),
2431                "/../../testdata/modelpack_split_9x15x18.bin"
2432            ));
2433            let detect0 =
2434                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2435
2436            let detect1 = include_bytes!(concat!(
2437                env!("CARGO_MANIFEST_DIR"),
2438                "/../../testdata/modelpack_split_17x30x18.bin"
2439            ));
2440            let detect1 =
2441                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2442
2443            let mut mask = seg.slice(s![0, .., .., ..]);
2444            mask.swap_axes(0, 1);
2445            mask.swap_axes(1, 2);
2446            let mask = [Segmentation {
2447                xmin: 0.0,
2448                ymin: 0.0,
2449                xmax: 1.0,
2450                ymax: 1.0,
2451                segmentation: mask.into_owned(),
2452            }];
2453            let correct_boxes = [DetectBox {
2454                bbox: BoundingBox {
2455                    xmin: 0.43171933,
2456                    ymin: 0.68243736,
2457                    xmax: 0.5626645,
2458                    ymax: 0.808863,
2459                },
2460                score: 0.99240804,
2461                label: 0,
2462            }];
2463
2464            let quant0 = (0.08547406643629074, 174).into();
2465            let quant1 = (0.09929127991199493, 183).into();
2466            let quant_seg = (1.0 / 255.0, 0).into();
2467
2468            let anchors0 = vec![
2469                [0.36666667461395264, 0.31481480598449707],
2470                [0.38749998807907104, 0.4740740656852722],
2471                [0.5333333611488342, 0.644444465637207],
2472            ];
2473            let anchors1 = vec![
2474                [0.13750000298023224, 0.2074074000120163],
2475                [0.2541666626930237, 0.21481481194496155],
2476                [0.23125000298023224, 0.35185185074806213],
2477            ];
2478
2479            let decoder = DecoderBuilder::default()
2480                .with_config_modelpack_segdet_split(
2481                    vec![
2482                        configs::Detection {
2483                            decoder: DecoderType::ModelPack,
2484                            shape: vec![1, 17, 30, 18],
2485                            anchors: Some(anchors1),
2486                            quantization: Some(quant1),
2487                            dshape: vec![
2488                                (DimName::Batch, 1),
2489                                (DimName::Height, 17),
2490                                (DimName::Width, 30),
2491                                (DimName::NumAnchorsXFeatures, 18),
2492                            ],
2493                            normalized: None,
2494                        },
2495                        configs::Detection {
2496                            decoder: DecoderType::ModelPack,
2497                            shape: vec![1, 9, 15, 18],
2498                            anchors: Some(anchors0),
2499                            quantization: Some(quant0),
2500                            dshape: vec![
2501                                (DimName::Batch, 1),
2502                                (DimName::Height, 9),
2503                                (DimName::Width, 15),
2504                                (DimName::NumAnchorsXFeatures, 18),
2505                            ],
2506                            normalized: None,
2507                        },
2508                    ],
2509                    configs::Segmentation {
2510                        decoder: DecoderType::ModelPack,
2511                        quantization: Some(quant_seg),
2512                        shape: vec![1, 2, 160, 160],
2513                        dshape: vec![
2514                            (DimName::Batch, 1),
2515                            (DimName::NumClasses, 2),
2516                            (DimName::Height, 160),
2517                            (DimName::Width, 160),
2518                        ],
2519                    },
2520                )
2521                .with_score_threshold(score_threshold)
2522                .with_iou_threshold(iou_threshold)
2523                .build()
2524                .unwrap();
2525            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2526            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2527
2528            for _ in 0..100 {
2529                decoder
2530                    .decode_quantized(
2531                        &[
2532                            detect0.view().into(),
2533                            detect1.view().into(),
2534                            seg.view().into(),
2535                        ],
2536                        &mut output_boxes,
2537                        &mut output_masks,
2538                    )
2539                    .unwrap();
2540
2541                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2542            }
2543        };
2544
2545        let handles = vec![
2546            std::thread::spawn(yolo_det),
2547            std::thread::spawn(modelpack_det_split),
2548            std::thread::spawn(yolo_det),
2549            std::thread::spawn(modelpack_det_split),
2550            std::thread::spawn(yolo_det),
2551            std::thread::spawn(modelpack_det_split),
2552            std::thread::spawn(yolo_det),
2553            std::thread::spawn(modelpack_det_split),
2554        ];
2555        for handle in handles {
2556            handle.join().unwrap();
2557        }
2558    }
2559
2560    #[test]
2561    fn test_ndarray_to_xyxy_float() {
2562        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2563        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2564        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2565
2566        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2567        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2568        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2569    }
2570
2571    #[test]
2572    fn test_class_aware_nms_float() {
2573        use crate::float::nms_class_aware_float;
2574
2575        // Create two overlapping boxes with different classes
2576        let boxes = vec![
2577            DetectBox {
2578                bbox: BoundingBox {
2579                    xmin: 0.0,
2580                    ymin: 0.0,
2581                    xmax: 0.5,
2582                    ymax: 0.5,
2583                },
2584                score: 0.9,
2585                label: 0, // class 0
2586            },
2587            DetectBox {
2588                bbox: BoundingBox {
2589                    xmin: 0.1,
2590                    ymin: 0.1,
2591                    xmax: 0.6,
2592                    ymax: 0.6,
2593                },
2594                score: 0.8,
2595                label: 1, // class 1 - different class
2596            },
2597        ];
2598
2599        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2600        // threshold 0.3)
2601        let result = nms_class_aware_float(0.3, boxes.clone());
2602        assert_eq!(
2603            result.len(),
2604            2,
2605            "Class-aware NMS should keep both boxes with different classes"
2606        );
2607
2608        // Now test with same class - should suppress one
2609        let same_class_boxes = vec![
2610            DetectBox {
2611                bbox: BoundingBox {
2612                    xmin: 0.0,
2613                    ymin: 0.0,
2614                    xmax: 0.5,
2615                    ymax: 0.5,
2616                },
2617                score: 0.9,
2618                label: 0,
2619            },
2620            DetectBox {
2621                bbox: BoundingBox {
2622                    xmin: 0.1,
2623                    ymin: 0.1,
2624                    xmax: 0.6,
2625                    ymax: 0.6,
2626                },
2627                score: 0.8,
2628                label: 0, // same class
2629            },
2630        ];
2631
2632        let result = nms_class_aware_float(0.3, same_class_boxes);
2633        assert_eq!(
2634            result.len(),
2635            1,
2636            "Class-aware NMS should suppress overlapping box with same class"
2637        );
2638        assert_eq!(result[0].label, 0);
2639        assert!((result[0].score - 0.9).abs() < 1e-6);
2640    }
2641
2642    #[test]
2643    fn test_class_agnostic_vs_aware_nms() {
2644        use crate::float::{nms_class_aware_float, nms_float};
2645
2646        // Two overlapping boxes with different classes
2647        let boxes = vec![
2648            DetectBox {
2649                bbox: BoundingBox {
2650                    xmin: 0.0,
2651                    ymin: 0.0,
2652                    xmax: 0.5,
2653                    ymax: 0.5,
2654                },
2655                score: 0.9,
2656                label: 0,
2657            },
2658            DetectBox {
2659                bbox: BoundingBox {
2660                    xmin: 0.1,
2661                    ymin: 0.1,
2662                    xmax: 0.6,
2663                    ymax: 0.6,
2664                },
2665                score: 0.8,
2666                label: 1,
2667            },
2668        ];
2669
2670        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2671        let agnostic_result = nms_float(0.3, boxes.clone());
2672        assert_eq!(
2673            agnostic_result.len(),
2674            1,
2675            "Class-agnostic NMS should suppress overlapping boxes"
2676        );
2677
2678        // Class-aware should keep both (different classes)
2679        let aware_result = nms_class_aware_float(0.3, boxes);
2680        assert_eq!(
2681            aware_result.len(),
2682            2,
2683            "Class-aware NMS should keep boxes with different classes"
2684        );
2685    }
2686
2687    #[test]
2688    fn test_class_aware_nms_int() {
2689        use crate::byte::nms_class_aware_int;
2690
2691        // Create two overlapping boxes with different classes
2692        let boxes = vec![
2693            DetectBoxQuantized {
2694                bbox: BoundingBox {
2695                    xmin: 0.0,
2696                    ymin: 0.0,
2697                    xmax: 0.5,
2698                    ymax: 0.5,
2699                },
2700                score: 200_u8,
2701                label: 0,
2702            },
2703            DetectBoxQuantized {
2704                bbox: BoundingBox {
2705                    xmin: 0.1,
2706                    ymin: 0.1,
2707                    xmax: 0.6,
2708                    ymax: 0.6,
2709                },
2710                score: 180_u8,
2711                label: 1, // different class
2712            },
2713        ];
2714
2715        // Should keep both (different classes)
2716        let result = nms_class_aware_int(0.5, boxes);
2717        assert_eq!(
2718            result.len(),
2719            2,
2720            "Class-aware NMS (int) should keep boxes with different classes"
2721        );
2722    }
2723
2724    #[test]
2725    fn test_nms_enum_default() {
2726        // Test that Nms enum has the correct default
2727        let default_nms: configs::Nms = Default::default();
2728        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2729    }
2730
2731    #[test]
2732    fn test_decoder_nms_mode() {
2733        // Test that decoder properly stores NMS mode
2734        let decoder = DecoderBuilder::default()
2735            .with_config_yolo_det(
2736                configs::Detection {
2737                    anchors: None,
2738                    decoder: DecoderType::Ultralytics,
2739                    quantization: None,
2740                    shape: vec![1, 84, 8400],
2741                    dshape: Vec::new(),
2742                    normalized: Some(true),
2743                },
2744                None,
2745            )
2746            .with_nms(Some(configs::Nms::ClassAware))
2747            .build()
2748            .unwrap();
2749
2750        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2751    }
2752
2753    #[test]
2754    fn test_decoder_nms_bypass() {
2755        // Test that decoder can be configured with nms=None (bypass)
2756        let decoder = DecoderBuilder::default()
2757            .with_config_yolo_det(
2758                configs::Detection {
2759                    anchors: None,
2760                    decoder: DecoderType::Ultralytics,
2761                    quantization: None,
2762                    shape: vec![1, 84, 8400],
2763                    dshape: Vec::new(),
2764                    normalized: Some(true),
2765                },
2766                None,
2767            )
2768            .with_nms(None)
2769            .build()
2770            .unwrap();
2771
2772        assert_eq!(decoder.nms, None);
2773    }
2774
2775    #[test]
2776    fn test_decoder_normalized_boxes_true() {
2777        // Test that normalized_boxes returns Some(true) when explicitly set
2778        let decoder = DecoderBuilder::default()
2779            .with_config_yolo_det(
2780                configs::Detection {
2781                    anchors: None,
2782                    decoder: DecoderType::Ultralytics,
2783                    quantization: None,
2784                    shape: vec![1, 84, 8400],
2785                    dshape: Vec::new(),
2786                    normalized: Some(true),
2787                },
2788                None,
2789            )
2790            .build()
2791            .unwrap();
2792
2793        assert_eq!(decoder.normalized_boxes(), Some(true));
2794    }
2795
2796    #[test]
2797    fn test_decoder_normalized_boxes_false() {
2798        // Test that normalized_boxes returns Some(false) when config specifies
2799        // unnormalized
2800        let decoder = DecoderBuilder::default()
2801            .with_config_yolo_det(
2802                configs::Detection {
2803                    anchors: None,
2804                    decoder: DecoderType::Ultralytics,
2805                    quantization: None,
2806                    shape: vec![1, 84, 8400],
2807                    dshape: Vec::new(),
2808                    normalized: Some(false),
2809                },
2810                None,
2811            )
2812            .build()
2813            .unwrap();
2814
2815        assert_eq!(decoder.normalized_boxes(), Some(false));
2816    }
2817
2818    #[test]
2819    fn test_decoder_normalized_boxes_unknown() {
2820        // Test that normalized_boxes returns None when not specified in config
2821        let decoder = DecoderBuilder::default()
2822            .with_config_yolo_det(
2823                configs::Detection {
2824                    anchors: None,
2825                    decoder: DecoderType::Ultralytics,
2826                    quantization: None,
2827                    shape: vec![1, 84, 8400],
2828                    dshape: Vec::new(),
2829                    normalized: None,
2830                },
2831                Some(DecoderVersion::Yolo11),
2832            )
2833            .build()
2834            .unwrap();
2835
2836        assert_eq!(decoder.normalized_boxes(), None);
2837    }
2838}
2839
2840#[cfg(feature = "tracker")]
2841#[cfg(test)]
2842#[cfg_attr(coverage_nightly, coverage(off))]
2843mod decoder_tracked_tests {
2844
2845    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2846    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2847    use num_traits::{AsPrimitive, Float, PrimInt};
2848    use rand::{RngExt, SeedableRng};
2849    use rand_distr::StandardNormal;
2850
2851    use crate::{
2852        configs::{self, DimName},
2853        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2854    };
2855
2856    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2857        input: ArrayView<F, D>,
2858        quant: Quantization,
2859    ) -> Array<T, D>
2860    where
2861        i32: num_traits::AsPrimitive<F>,
2862        f32: num_traits::AsPrimitive<F>,
2863    {
2864        let zero_point = quant.zero_point.as_();
2865        let div_scale = F::one() / quant.scale.as_();
2866        if zero_point != F::zero() {
2867            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2868        } else {
2869            input.mapv(|d| (d * div_scale).round().as_())
2870        }
2871    }
2872
2873    #[test]
2874    fn test_decoder_tracked_random_jitter() {
2875        use crate::configs::{DecoderType, Nms};
2876        use crate::DecoderBuilder;
2877
2878        let score_threshold = 0.25;
2879        let iou_threshold = 0.1;
2880        let out = include_bytes!(concat!(
2881            env!("CARGO_MANIFEST_DIR"),
2882            "/../../testdata/yolov8s_80_classes.bin"
2883        ));
2884        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2885        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2886        let quant = (0.0040811873, -123).into();
2887
2888        let decoder = DecoderBuilder::default()
2889            .with_config_yolo_det(
2890                crate::configs::Detection {
2891                    decoder: DecoderType::Ultralytics,
2892                    shape: vec![1, 84, 8400],
2893                    anchors: None,
2894                    quantization: Some(quant),
2895                    dshape: vec![
2896                        (crate::configs::DimName::Batch, 1),
2897                        (crate::configs::DimName::NumFeatures, 84),
2898                        (crate::configs::DimName::NumBoxes, 8400),
2899                    ],
2900                    normalized: Some(true),
2901                },
2902                None,
2903            )
2904            .with_score_threshold(score_threshold)
2905            .with_iou_threshold(iou_threshold)
2906            .with_nms(Some(Nms::ClassAgnostic))
2907            .build()
2908            .unwrap();
2909        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
2910
2911        let expected_boxes = [
2912            crate::DetectBox {
2913                bbox: crate::BoundingBox {
2914                    xmin: 0.5285137,
2915                    ymin: 0.05305544,
2916                    xmax: 0.87541467,
2917                    ymax: 0.9998909,
2918                },
2919                score: 0.5591227,
2920                label: 0,
2921            },
2922            crate::DetectBox {
2923                bbox: crate::BoundingBox {
2924                    xmin: 0.130598,
2925                    ymin: 0.43260583,
2926                    xmax: 0.35098213,
2927                    ymax: 0.9958097,
2928                },
2929                score: 0.33057618,
2930                label: 75,
2931            },
2932        ];
2933
2934        let mut tracker = ByteTrackBuilder::new()
2935            .track_update(0.1)
2936            .track_high_conf(0.3)
2937            .build();
2938
2939        let mut output_boxes = Vec::with_capacity(50);
2940        let mut output_masks = Vec::with_capacity(50);
2941        let mut output_tracks = Vec::with_capacity(50);
2942
2943        decoder
2944            .decode_tracked_quantized(
2945                &mut tracker,
2946                0,
2947                &[out.view().into()],
2948                &mut output_boxes,
2949                &mut output_masks,
2950                &mut output_tracks,
2951            )
2952            .unwrap();
2953
2954        assert_eq!(output_boxes.len(), 2);
2955        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2956        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2957
2958        let mut last_boxes = output_boxes.clone();
2959
2960        for i in 1..=100 {
2961            let mut out = out.clone();
2962            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
2963            let mut x_values = out.slice_mut(s![0, 0, ..]);
2964            for x in x_values.iter_mut() {
2965                let r: f32 = rng.sample(StandardNormal);
2966                let r = r.clamp(-2.0, 2.0) / 2.0;
2967                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2968            }
2969
2970            let mut y_values = out.slice_mut(s![0, 1, ..]);
2971            for y in y_values.iter_mut() {
2972                let r: f32 = rng.sample(StandardNormal);
2973                let r = r.clamp(-2.0, 2.0) / 2.0;
2974                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2975            }
2976
2977            decoder
2978                .decode_tracked_quantized(
2979                    &mut tracker,
2980                    100_000_000 * i / 3, // simulate 33.333ms between frames
2981                    &[out.view().into()],
2982                    &mut output_boxes,
2983                    &mut output_masks,
2984                    &mut output_tracks,
2985                )
2986                .unwrap();
2987
2988            assert_eq!(output_boxes.len(), 2);
2989            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2990            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2991
2992            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2993            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2994            last_boxes = output_boxes.clone();
2995        }
2996    }
2997
2998    // ─── Shared helpers for tracked decoder tests ────────────────────
2999
3000    fn real_data_expected_boxes() -> [DetectBox; 2] {
3001        [
3002            DetectBox {
3003                bbox: BoundingBox {
3004                    xmin: 0.08515105,
3005                    ymin: 0.7131401,
3006                    xmax: 0.29802868,
3007                    ymax: 0.8195788,
3008                },
3009                score: 0.91537374,
3010                label: 23,
3011            },
3012            DetectBox {
3013                bbox: BoundingBox {
3014                    xmin: 0.59605736,
3015                    ymin: 0.25545314,
3016                    xmax: 0.93666154,
3017                    ymax: 0.72378385,
3018                },
3019                score: 0.91537374,
3020                label: 23,
3021            },
3022        ]
3023    }
3024
3025    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3026        [DetectBox {
3027            bbox: BoundingBox {
3028                xmin: 0.12549022,
3029                ymin: 0.12549022,
3030                xmax: 0.23529413,
3031                ymax: 0.23529413,
3032            },
3033            score: 0.98823535,
3034            label: 2,
3035        }]
3036    }
3037
3038    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3039        [DetectBox {
3040            bbox: BoundingBox {
3041                xmin: 0.1234,
3042                ymin: 0.1234,
3043                xmax: 0.2345,
3044                ymax: 0.2345,
3045            },
3046            score: 0.9876,
3047            label: 2,
3048        }]
3049    }
3050
3051    fn build_split_decoder(
3052        score_threshold: f32,
3053        iou_threshold: f32,
3054        quant_boxes: (f32, i32),
3055        quant_protos: (f32, i32),
3056    ) -> crate::Decoder {
3057        DecoderBuilder::default()
3058            .with_config_yolo_split_segdet(
3059                configs::Boxes {
3060                    decoder: configs::DecoderType::Ultralytics,
3061                    quantization: Some(quant_boxes.into()),
3062                    shape: vec![1, 4, 8400],
3063                    dshape: vec![
3064                        (DimName::Batch, 1),
3065                        (DimName::BoxCoords, 4),
3066                        (DimName::NumBoxes, 8400),
3067                    ],
3068                    normalized: Some(true),
3069                },
3070                configs::Scores {
3071                    decoder: configs::DecoderType::Ultralytics,
3072                    quantization: Some(quant_boxes.into()),
3073                    shape: vec![1, 80, 8400],
3074                    dshape: vec![
3075                        (DimName::Batch, 1),
3076                        (DimName::NumClasses, 80),
3077                        (DimName::NumBoxes, 8400),
3078                    ],
3079                },
3080                configs::MaskCoefficients {
3081                    decoder: configs::DecoderType::Ultralytics,
3082                    quantization: Some(quant_boxes.into()),
3083                    shape: vec![1, 32, 8400],
3084                    dshape: vec![
3085                        (DimName::Batch, 1),
3086                        (DimName::NumProtos, 32),
3087                        (DimName::NumBoxes, 8400),
3088                    ],
3089                },
3090                configs::Protos {
3091                    decoder: configs::DecoderType::Ultralytics,
3092                    quantization: Some(quant_protos.into()),
3093                    shape: vec![1, 160, 160, 32],
3094                    dshape: vec![
3095                        (DimName::Batch, 1),
3096                        (DimName::Height, 160),
3097                        (DimName::Width, 160),
3098                        (DimName::NumProtos, 32),
3099                    ],
3100                },
3101            )
3102            .with_score_threshold(score_threshold)
3103            .with_iou_threshold(iou_threshold)
3104            .build()
3105            .unwrap()
3106    }
3107
3108    // ─── Real-data tracked test macro ───────────────────────────────
3109    //
3110    // Generates tests that load i8 binary test data from testdata/ and
3111    // exercise all (quant/float) × (combined/split) × (masks/proto)
3112    // decoder paths.
3113
3114    macro_rules! real_data_tracked_test {
3115        ($name:ident, quantized, $layout:ident, $output:ident) => {
3116            #[test]
3117            fn $name() {
3118                use crate::configs::Nms;
3119                let is_split = matches!(stringify!($layout), "split");
3120                let is_proto = matches!(stringify!($output), "proto");
3121
3122                let score_threshold = 0.45;
3123                let iou_threshold = 0.45;
3124                let quant_boxes = (0.021287762_f32, 31_i32);
3125                let quant_protos = (0.02491162_f32, -117_i32);
3126
3127                let raw_boxes = include_bytes!(concat!(
3128                    env!("CARGO_MANIFEST_DIR"),
3129                    "/../../testdata/yolov8_boxes_116x8400.bin"
3130                ));
3131                let raw_boxes = unsafe {
3132                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3133                };
3134                let boxes_i8 =
3135                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3136
3137                let raw_protos = include_bytes!(concat!(
3138                    env!("CARGO_MANIFEST_DIR"),
3139                    "/../../testdata/yolov8_protos_160x160x32.bin"
3140                ));
3141                let raw_protos = unsafe {
3142                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3143                };
3144                let protos_i8 =
3145                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3146                        .unwrap();
3147
3148                // Pre-split (unused for combined, but harmless)
3149                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
3150                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
3151                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
3152                let mut boxes_combined = boxes_i8;
3153
3154                let decoder = if is_split {
3155                    build_split_decoder(score_threshold, iou_threshold, quant_boxes, quant_protos)
3156                } else {
3157                    let config_yaml = include_str!(concat!(
3158                        env!("CARGO_MANIFEST_DIR"),
3159                        "/../../testdata/yolov8_seg.yaml"
3160                    ));
3161                    DecoderBuilder::default()
3162                        .with_config_yaml_str(config_yaml.to_string())
3163                        .with_score_threshold(score_threshold)
3164                        .with_iou_threshold(iou_threshold)
3165                        .with_nms(Some(Nms::ClassAgnostic))
3166                        .build()
3167                        .unwrap()
3168                };
3169
3170                let expected = real_data_expected_boxes();
3171                let mut tracker = ByteTrackBuilder::new()
3172                    .track_update(0.1)
3173                    .track_high_conf(0.7)
3174                    .build();
3175                let mut output_boxes = Vec::with_capacity(50);
3176                let mut output_tracks = Vec::with_capacity(50);
3177
3178                // Frame 1: decode
3179                if is_proto {
3180                    {
3181                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3182                            vec![
3183                                boxes_split.view().into(),
3184                                scores_split.view().into(),
3185                                mask_split.view().into(),
3186                                protos_i8.view().into(),
3187                            ]
3188                        } else {
3189                            vec![boxes_combined.view().into(), protos_i8.view().into()]
3190                        };
3191                        decoder
3192                            .decode_tracked_quantized_proto(
3193                                &mut tracker,
3194                                0,
3195                                &inputs,
3196                                &mut output_boxes,
3197                                &mut output_tracks,
3198                            )
3199                            .unwrap();
3200                    }
3201                    assert_eq!(output_boxes.len(), 2);
3202                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3203                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3204
3205                    // Zero scores for frame 2
3206                    if is_split {
3207                        for score in scores_split.iter_mut() {
3208                            *score = i8::MIN;
3209                        }
3210                    } else {
3211                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3212                            *score = i8::MIN;
3213                        }
3214                    }
3215
3216                    let proto_result = {
3217                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3218                            vec![
3219                                boxes_split.view().into(),
3220                                scores_split.view().into(),
3221                                mask_split.view().into(),
3222                                protos_i8.view().into(),
3223                            ]
3224                        } else {
3225                            vec![boxes_combined.view().into(), protos_i8.view().into()]
3226                        };
3227                        decoder
3228                            .decode_tracked_quantized_proto(
3229                                &mut tracker,
3230                                100_000_000 / 3,
3231                                &inputs,
3232                                &mut output_boxes,
3233                                &mut output_tracks,
3234                            )
3235                            .unwrap()
3236                    };
3237                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3238                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3239                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3240                } else {
3241                    let mut output_masks = Vec::with_capacity(50);
3242                    {
3243                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3244                            vec![
3245                                boxes_split.view().into(),
3246                                scores_split.view().into(),
3247                                mask_split.view().into(),
3248                                protos_i8.view().into(),
3249                            ]
3250                        } else {
3251                            vec![boxes_combined.view().into(), protos_i8.view().into()]
3252                        };
3253                        decoder
3254                            .decode_tracked_quantized(
3255                                &mut tracker,
3256                                0,
3257                                &inputs,
3258                                &mut output_boxes,
3259                                &mut output_masks,
3260                                &mut output_tracks,
3261                            )
3262                            .unwrap();
3263                    }
3264                    assert_eq!(output_boxes.len(), 2);
3265                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3266                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3267
3268                    if is_split {
3269                        for score in scores_split.iter_mut() {
3270                            *score = i8::MIN;
3271                        }
3272                    } else {
3273                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3274                            *score = i8::MIN;
3275                        }
3276                    }
3277
3278                    {
3279                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3280                            vec![
3281                                boxes_split.view().into(),
3282                                scores_split.view().into(),
3283                                mask_split.view().into(),
3284                                protos_i8.view().into(),
3285                            ]
3286                        } else {
3287                            vec![boxes_combined.view().into(), protos_i8.view().into()]
3288                        };
3289                        decoder
3290                            .decode_tracked_quantized(
3291                                &mut tracker,
3292                                100_000_000 / 3,
3293                                &inputs,
3294                                &mut output_boxes,
3295                                &mut output_masks,
3296                                &mut output_tracks,
3297                            )
3298                            .unwrap();
3299                    }
3300                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3301                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3302                    assert!(output_masks.is_empty());
3303                }
3304            }
3305        };
3306        ($name:ident, float, $layout:ident, $output:ident) => {
3307            #[test]
3308            fn $name() {
3309                use crate::configs::Nms;
3310                let is_split = matches!(stringify!($layout), "split");
3311                let is_proto = matches!(stringify!($output), "proto");
3312
3313                let score_threshold = 0.45;
3314                let iou_threshold = 0.45;
3315                let quant_boxes = (0.021287762_f32, 31_i32);
3316                let quant_protos = (0.02491162_f32, -117_i32);
3317
3318                let raw_boxes = include_bytes!(concat!(
3319                    env!("CARGO_MANIFEST_DIR"),
3320                    "/../../testdata/yolov8_boxes_116x8400.bin"
3321                ));
3322                let raw_boxes = unsafe {
3323                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3324                };
3325                let boxes_i8 =
3326                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3327                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
3328
3329                let raw_protos = include_bytes!(concat!(
3330                    env!("CARGO_MANIFEST_DIR"),
3331                    "/../../testdata/yolov8_protos_160x160x32.bin"
3332                ));
3333                let raw_protos = unsafe {
3334                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3335                };
3336                let protos_i8 =
3337                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3338                        .unwrap();
3339                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
3340
3341                // Pre-split from dequantized data
3342                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
3343                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
3344                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
3345                let mut boxes_combined = boxes_f32;
3346
3347                let decoder = if is_split {
3348                    build_split_decoder(score_threshold, iou_threshold, quant_boxes, quant_protos)
3349                } else {
3350                    let config_yaml = include_str!(concat!(
3351                        env!("CARGO_MANIFEST_DIR"),
3352                        "/../../testdata/yolov8_seg.yaml"
3353                    ));
3354                    DecoderBuilder::default()
3355                        .with_config_yaml_str(config_yaml.to_string())
3356                        .with_score_threshold(score_threshold)
3357                        .with_iou_threshold(iou_threshold)
3358                        .with_nms(Some(Nms::ClassAgnostic))
3359                        .build()
3360                        .unwrap()
3361                };
3362
3363                let expected = real_data_expected_boxes();
3364                let mut tracker = ByteTrackBuilder::new()
3365                    .track_update(0.1)
3366                    .track_high_conf(0.7)
3367                    .build();
3368                let mut output_boxes = Vec::with_capacity(50);
3369                let mut output_tracks = Vec::with_capacity(50);
3370
3371                if is_proto {
3372                    {
3373                        let inputs = if is_split {
3374                            vec![
3375                                boxes_split.view().into_dyn(),
3376                                scores_split.view().into_dyn(),
3377                                mask_split.view().into_dyn(),
3378                                protos_f32.view().into_dyn(),
3379                            ]
3380                        } else {
3381                            vec![
3382                                boxes_combined.view().into_dyn(),
3383                                protos_f32.view().into_dyn(),
3384                            ]
3385                        };
3386                        decoder
3387                            .decode_tracked_float_proto(
3388                                &mut tracker,
3389                                0,
3390                                &inputs,
3391                                &mut output_boxes,
3392                                &mut output_tracks,
3393                            )
3394                            .unwrap();
3395                    }
3396                    assert_eq!(output_boxes.len(), 2);
3397                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3398                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3399
3400                    if is_split {
3401                        for score in scores_split.iter_mut() {
3402                            *score = 0.0;
3403                        }
3404                    } else {
3405                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3406                            *score = 0.0;
3407                        }
3408                    }
3409
3410                    let proto_result = {
3411                        let inputs = if is_split {
3412                            vec![
3413                                boxes_split.view().into_dyn(),
3414                                scores_split.view().into_dyn(),
3415                                mask_split.view().into_dyn(),
3416                                protos_f32.view().into_dyn(),
3417                            ]
3418                        } else {
3419                            vec![
3420                                boxes_combined.view().into_dyn(),
3421                                protos_f32.view().into_dyn(),
3422                            ]
3423                        };
3424                        decoder
3425                            .decode_tracked_float_proto(
3426                                &mut tracker,
3427                                100_000_000 / 3,
3428                                &inputs,
3429                                &mut output_boxes,
3430                                &mut output_tracks,
3431                            )
3432                            .unwrap()
3433                    };
3434                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3435                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3436                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3437                } else {
3438                    let mut output_masks = Vec::with_capacity(50);
3439                    {
3440                        let inputs = if is_split {
3441                            vec![
3442                                boxes_split.view().into_dyn(),
3443                                scores_split.view().into_dyn(),
3444                                mask_split.view().into_dyn(),
3445                                protos_f32.view().into_dyn(),
3446                            ]
3447                        } else {
3448                            vec![
3449                                boxes_combined.view().into_dyn(),
3450                                protos_f32.view().into_dyn(),
3451                            ]
3452                        };
3453                        decoder
3454                            .decode_tracked_float(
3455                                &mut tracker,
3456                                0,
3457                                &inputs,
3458                                &mut output_boxes,
3459                                &mut output_masks,
3460                                &mut output_tracks,
3461                            )
3462                            .unwrap();
3463                    }
3464                    assert_eq!(output_boxes.len(), 2);
3465                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3466                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3467
3468                    if is_split {
3469                        for score in scores_split.iter_mut() {
3470                            *score = 0.0;
3471                        }
3472                    } else {
3473                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
3474                            *score = 0.0;
3475                        }
3476                    }
3477
3478                    {
3479                        let inputs = if is_split {
3480                            vec![
3481                                boxes_split.view().into_dyn(),
3482                                scores_split.view().into_dyn(),
3483                                mask_split.view().into_dyn(),
3484                                protos_f32.view().into_dyn(),
3485                            ]
3486                        } else {
3487                            vec![
3488                                boxes_combined.view().into_dyn(),
3489                                protos_f32.view().into_dyn(),
3490                            ]
3491                        };
3492                        decoder
3493                            .decode_tracked_float(
3494                                &mut tracker,
3495                                100_000_000 / 3,
3496                                &inputs,
3497                                &mut output_boxes,
3498                                &mut output_masks,
3499                                &mut output_tracks,
3500                            )
3501                            .unwrap();
3502                    }
3503                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3504                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
3505                    assert!(output_masks.is_empty());
3506                }
3507            }
3508        };
3509    }
3510
3511    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
3512    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
3513    real_data_tracked_test!(
3514        test_decoder_tracked_segdet_proto,
3515        quantized,
3516        combined,
3517        proto
3518    );
3519    real_data_tracked_test!(
3520        test_decoder_tracked_segdet_proto_float,
3521        float,
3522        combined,
3523        proto
3524    );
3525    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
3526    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
3527    real_data_tracked_test!(
3528        test_decoder_tracked_segdet_split_proto,
3529        quantized,
3530        split,
3531        proto
3532    );
3533    real_data_tracked_test!(
3534        test_decoder_tracked_segdet_split_proto_float,
3535        float,
3536        split,
3537        proto
3538    );
3539
3540    // ─── End-to-end tracked test macro ──────────────────────────────
3541    //
3542    // Generates tests with synthetic data to exercise all tracked
3543    // decode paths without needing real model output files.
3544
3545    const E2E_COMBINED_CONFIG: &str = "
3546decoder_version: yolo26
3547outputs:
3548 - type: detection
3549   decoder: ultralytics
3550   quantization: [0.00784313725490196, 0]
3551   shape: [1, 10, 38]
3552   dshape:
3553    - [batch, 1]
3554    - [num_boxes, 10]
3555    - [num_features, 38]
3556   normalized: true
3557 - type: protos
3558   decoder: ultralytics
3559   quantization: [0.0039215686274509803921568627451, 128]
3560   shape: [1, 160, 160, 32]
3561   dshape:
3562    - [batch, 1]
3563    - [height, 160]
3564    - [width, 160]
3565    - [num_protos, 32]
3566";
3567
3568    const E2E_SPLIT_CONFIG: &str = "
3569decoder_version: yolo26
3570outputs:
3571 - type: boxes
3572   decoder: ultralytics
3573   quantization: [0.00784313725490196, 0]
3574   shape: [1, 10, 4]
3575   dshape:
3576    - [batch, 1]
3577    - [num_boxes, 10]
3578    - [box_coords, 4]
3579   normalized: true
3580 - type: scores
3581   decoder: ultralytics
3582   quantization: [0.00784313725490196, 0]
3583   shape: [1, 10, 1]
3584   dshape:
3585    - [batch, 1]
3586    - [num_boxes, 10]
3587    - [num_classes, 1]
3588 - type: classes
3589   decoder: ultralytics
3590   quantization: [0.00784313725490196, 0]
3591   shape: [1, 10, 1]
3592   dshape:
3593    - [batch, 1]
3594    - [num_boxes, 10]
3595    - [num_classes, 1]
3596 - type: mask_coefficients
3597   decoder: ultralytics
3598   quantization: [0.00784313725490196, 0]
3599   shape: [1, 10, 32]
3600   dshape:
3601    - [batch, 1]
3602    - [num_boxes, 10]
3603    - [num_protos, 32]
3604 - type: protos
3605   decoder: ultralytics
3606   quantization: [0.0039215686274509803921568627451, 128]
3607   shape: [1, 160, 160, 32]
3608   dshape:
3609    - [batch, 1]
3610    - [height, 160]
3611    - [width, 160]
3612    - [num_protos, 32]
3613";
3614
3615    macro_rules! e2e_tracked_test {
3616        ($name:ident, quantized, $layout:ident, $output:ident) => {
3617            #[test]
3618            fn $name() {
3619                let is_split = matches!(stringify!($layout), "split");
3620                let is_proto = matches!(stringify!($output), "proto");
3621
3622                let score_threshold = 0.45;
3623                let iou_threshold = 0.45;
3624
3625                let mut boxes = Array2::zeros((10, 4));
3626                let mut scores = Array2::zeros((10, 1));
3627                let mut classes = Array2::zeros((10, 1));
3628                let mask = Array2::zeros((10, 32));
3629                let protos = Array3::<f64>::zeros((160, 160, 32));
3630                let protos = protos.insert_axis(Axis(0));
3631                let protos_quant = (1.0 / 255.0, 0.0);
3632                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3633
3634                boxes
3635                    .slice_mut(s![0, ..])
3636                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3637                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3638                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3639
3640                let detect_quant = (2.0 / 255.0, 0.0);
3641
3642                let decoder = if is_split {
3643                    DecoderBuilder::default()
3644                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
3645                        .with_score_threshold(score_threshold)
3646                        .with_iou_threshold(iou_threshold)
3647                        .build()
3648                        .unwrap()
3649                } else {
3650                    DecoderBuilder::default()
3651                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
3652                        .with_score_threshold(score_threshold)
3653                        .with_iou_threshold(iou_threshold)
3654                        .build()
3655                        .unwrap()
3656                };
3657
3658                let expected = e2e_expected_boxes_quant();
3659                let mut tracker = ByteTrackBuilder::new()
3660                    .track_update(0.1)
3661                    .track_high_conf(0.7)
3662                    .build();
3663                let mut output_boxes = Vec::with_capacity(50);
3664                let mut output_tracks = Vec::with_capacity(50);
3665
3666                if is_split {
3667                    let boxes = boxes.insert_axis(Axis(0));
3668                    let scores = scores.insert_axis(Axis(0));
3669                    let classes = classes.insert_axis(Axis(0));
3670                    let mask = mask.insert_axis(Axis(0));
3671
3672                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3673                    let mut scores: Array3<u8> =
3674                        quantize_ndarray(scores.view(), detect_quant.into());
3675                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3676                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3677
3678                    if is_proto {
3679                        {
3680                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3681                                boxes.view().into(),
3682                                scores.view().into(),
3683                                classes.view().into(),
3684                                mask.view().into(),
3685                                protos.view().into(),
3686                            ];
3687                            decoder
3688                                .decode_tracked_quantized_proto(
3689                                    &mut tracker,
3690                                    0,
3691                                    &inputs,
3692                                    &mut output_boxes,
3693                                    &mut output_tracks,
3694                                )
3695                                .unwrap();
3696                        }
3697                        assert_eq!(output_boxes.len(), 1);
3698                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3699
3700                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3701                            *score = u8::MIN;
3702                        }
3703                        let proto_result = {
3704                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3705                                boxes.view().into(),
3706                                scores.view().into(),
3707                                classes.view().into(),
3708                                mask.view().into(),
3709                                protos.view().into(),
3710                            ];
3711                            decoder
3712                                .decode_tracked_quantized_proto(
3713                                    &mut tracker,
3714                                    100_000_000 / 3,
3715                                    &inputs,
3716                                    &mut output_boxes,
3717                                    &mut output_tracks,
3718                                )
3719                                .unwrap()
3720                        };
3721                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3722                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3723                    } else {
3724                        let mut output_masks = Vec::with_capacity(50);
3725                        {
3726                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3727                                boxes.view().into(),
3728                                scores.view().into(),
3729                                classes.view().into(),
3730                                mask.view().into(),
3731                                protos.view().into(),
3732                            ];
3733                            decoder
3734                                .decode_tracked_quantized(
3735                                    &mut tracker,
3736                                    0,
3737                                    &inputs,
3738                                    &mut output_boxes,
3739                                    &mut output_masks,
3740                                    &mut output_tracks,
3741                                )
3742                                .unwrap();
3743                        }
3744                        assert_eq!(output_boxes.len(), 1);
3745                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3746
3747                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3748                            *score = u8::MIN;
3749                        }
3750                        {
3751                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3752                                boxes.view().into(),
3753                                scores.view().into(),
3754                                classes.view().into(),
3755                                mask.view().into(),
3756                                protos.view().into(),
3757                            ];
3758                            decoder
3759                                .decode_tracked_quantized(
3760                                    &mut tracker,
3761                                    100_000_000 / 3,
3762                                    &inputs,
3763                                    &mut output_boxes,
3764                                    &mut output_masks,
3765                                    &mut output_tracks,
3766                                )
3767                                .unwrap();
3768                        }
3769                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3770                        assert!(output_masks.is_empty());
3771                    }
3772                } else {
3773                    // Combined layout
3774                    let detect = ndarray::concatenate![
3775                        Axis(1),
3776                        boxes.view(),
3777                        scores.view(),
3778                        classes.view(),
3779                        mask.view()
3780                    ];
3781                    let detect = detect.insert_axis(Axis(0));
3782                    assert_eq!(detect.shape(), &[1, 10, 38]);
3783                    let mut detect: Array3<u8> =
3784                        quantize_ndarray(detect.view(), detect_quant.into());
3785
3786                    if is_proto {
3787                        {
3788                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3789                                vec![detect.view().into(), protos.view().into()];
3790                            decoder
3791                                .decode_tracked_quantized_proto(
3792                                    &mut tracker,
3793                                    0,
3794                                    &inputs,
3795                                    &mut output_boxes,
3796                                    &mut output_tracks,
3797                                )
3798                                .unwrap();
3799                        }
3800                        assert_eq!(output_boxes.len(), 1);
3801                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3802
3803                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
3804                            *score = u8::MIN;
3805                        }
3806                        let proto_result = {
3807                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3808                                vec![detect.view().into(), protos.view().into()];
3809                            decoder
3810                                .decode_tracked_quantized_proto(
3811                                    &mut tracker,
3812                                    100_000_000 / 3,
3813                                    &inputs,
3814                                    &mut output_boxes,
3815                                    &mut output_tracks,
3816                                )
3817                                .unwrap()
3818                        };
3819                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3820                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3821                    } else {
3822                        let mut output_masks = Vec::with_capacity(50);
3823                        {
3824                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3825                                vec![detect.view().into(), protos.view().into()];
3826                            decoder
3827                                .decode_tracked_quantized(
3828                                    &mut tracker,
3829                                    0,
3830                                    &inputs,
3831                                    &mut output_boxes,
3832                                    &mut output_masks,
3833                                    &mut output_tracks,
3834                                )
3835                                .unwrap();
3836                        }
3837                        assert_eq!(output_boxes.len(), 1);
3838                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3839
3840                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
3841                            *score = u8::MIN;
3842                        }
3843                        {
3844                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3845                                vec![detect.view().into(), protos.view().into()];
3846                            decoder
3847                                .decode_tracked_quantized(
3848                                    &mut tracker,
3849                                    100_000_000 / 3,
3850                                    &inputs,
3851                                    &mut output_boxes,
3852                                    &mut output_masks,
3853                                    &mut output_tracks,
3854                                )
3855                                .unwrap();
3856                        }
3857                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3858                        assert!(output_masks.is_empty());
3859                    }
3860                }
3861            }
3862        };
3863        ($name:ident, float, $layout:ident, $output:ident) => {
3864            #[test]
3865            fn $name() {
3866                let is_split = matches!(stringify!($layout), "split");
3867                let is_proto = matches!(stringify!($output), "proto");
3868
3869                let score_threshold = 0.45;
3870                let iou_threshold = 0.45;
3871
3872                let mut boxes = Array2::zeros((10, 4));
3873                let mut scores = Array2::zeros((10, 1));
3874                let mut classes = Array2::zeros((10, 1));
3875                let mask: Array2<f64> = Array2::zeros((10, 32));
3876                let protos = Array3::<f64>::zeros((160, 160, 32));
3877                let protos = protos.insert_axis(Axis(0));
3878
3879                boxes
3880                    .slice_mut(s![0, ..])
3881                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3882                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3883                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3884
3885                let decoder = if is_split {
3886                    DecoderBuilder::default()
3887                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
3888                        .with_score_threshold(score_threshold)
3889                        .with_iou_threshold(iou_threshold)
3890                        .build()
3891                        .unwrap()
3892                } else {
3893                    DecoderBuilder::default()
3894                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
3895                        .with_score_threshold(score_threshold)
3896                        .with_iou_threshold(iou_threshold)
3897                        .build()
3898                        .unwrap()
3899                };
3900
3901                let expected = e2e_expected_boxes_float();
3902                let mut tracker = ByteTrackBuilder::new()
3903                    .track_update(0.1)
3904                    .track_high_conf(0.7)
3905                    .build();
3906                let mut output_boxes = Vec::with_capacity(50);
3907                let mut output_tracks = Vec::with_capacity(50);
3908
3909                if is_split {
3910                    let boxes = boxes.insert_axis(Axis(0));
3911                    let mut scores = scores.insert_axis(Axis(0));
3912                    let classes = classes.insert_axis(Axis(0));
3913                    let mask = mask.insert_axis(Axis(0));
3914
3915                    if is_proto {
3916                        {
3917                            let inputs = vec![
3918                                boxes.view().into_dyn(),
3919                                scores.view().into_dyn(),
3920                                classes.view().into_dyn(),
3921                                mask.view().into_dyn(),
3922                                protos.view().into_dyn(),
3923                            ];
3924                            decoder
3925                                .decode_tracked_float_proto(
3926                                    &mut tracker,
3927                                    0,
3928                                    &inputs,
3929                                    &mut output_boxes,
3930                                    &mut output_tracks,
3931                                )
3932                                .unwrap();
3933                        }
3934                        assert_eq!(output_boxes.len(), 1);
3935                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3936
3937                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3938                            *score = 0.0;
3939                        }
3940                        let proto_result = {
3941                            let inputs = vec![
3942                                boxes.view().into_dyn(),
3943                                scores.view().into_dyn(),
3944                                classes.view().into_dyn(),
3945                                mask.view().into_dyn(),
3946                                protos.view().into_dyn(),
3947                            ];
3948                            decoder
3949                                .decode_tracked_float_proto(
3950                                    &mut tracker,
3951                                    100_000_000 / 3,
3952                                    &inputs,
3953                                    &mut output_boxes,
3954                                    &mut output_tracks,
3955                                )
3956                                .unwrap()
3957                        };
3958                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3959                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
3960                    } else {
3961                        let mut output_masks = Vec::with_capacity(50);
3962                        {
3963                            let inputs = vec![
3964                                boxes.view().into_dyn(),
3965                                scores.view().into_dyn(),
3966                                classes.view().into_dyn(),
3967                                mask.view().into_dyn(),
3968                                protos.view().into_dyn(),
3969                            ];
3970                            decoder
3971                                .decode_tracked_float(
3972                                    &mut tracker,
3973                                    0,
3974                                    &inputs,
3975                                    &mut output_boxes,
3976                                    &mut output_masks,
3977                                    &mut output_tracks,
3978                                )
3979                                .unwrap();
3980                        }
3981                        assert_eq!(output_boxes.len(), 1);
3982                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3983
3984                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
3985                            *score = 0.0;
3986                        }
3987                        {
3988                            let inputs = vec![
3989                                boxes.view().into_dyn(),
3990                                scores.view().into_dyn(),
3991                                classes.view().into_dyn(),
3992                                mask.view().into_dyn(),
3993                                protos.view().into_dyn(),
3994                            ];
3995                            decoder
3996                                .decode_tracked_float(
3997                                    &mut tracker,
3998                                    100_000_000 / 3,
3999                                    &inputs,
4000                                    &mut output_boxes,
4001                                    &mut output_masks,
4002                                    &mut output_tracks,
4003                                )
4004                                .unwrap();
4005                        }
4006                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4007                        assert!(output_masks.is_empty());
4008                    }
4009                } else {
4010                    // Combined layout
4011                    let detect = ndarray::concatenate![
4012                        Axis(1),
4013                        boxes.view(),
4014                        scores.view(),
4015                        classes.view(),
4016                        mask.view()
4017                    ];
4018                    let mut detect = detect.insert_axis(Axis(0));
4019                    assert_eq!(detect.shape(), &[1, 10, 38]);
4020
4021                    if is_proto {
4022                        {
4023                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4024                            decoder
4025                                .decode_tracked_float_proto(
4026                                    &mut tracker,
4027                                    0,
4028                                    &inputs,
4029                                    &mut output_boxes,
4030                                    &mut output_tracks,
4031                                )
4032                                .unwrap();
4033                        }
4034                        assert_eq!(output_boxes.len(), 1);
4035                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4036
4037                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4038                            *score = 0.0;
4039                        }
4040                        let proto_result = {
4041                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4042                            decoder
4043                                .decode_tracked_float_proto(
4044                                    &mut tracker,
4045                                    100_000_000 / 3,
4046                                    &inputs,
4047                                    &mut output_boxes,
4048                                    &mut output_tracks,
4049                                )
4050                                .unwrap()
4051                        };
4052                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4053                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4054                    } else {
4055                        let mut output_masks = Vec::with_capacity(50);
4056                        {
4057                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4058                            decoder
4059                                .decode_tracked_float(
4060                                    &mut tracker,
4061                                    0,
4062                                    &inputs,
4063                                    &mut output_boxes,
4064                                    &mut output_masks,
4065                                    &mut output_tracks,
4066                                )
4067                                .unwrap();
4068                        }
4069                        assert_eq!(output_boxes.len(), 1);
4070                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4071
4072                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4073                            *score = 0.0;
4074                        }
4075                        {
4076                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4077                            decoder
4078                                .decode_tracked_float(
4079                                    &mut tracker,
4080                                    100_000_000 / 3,
4081                                    &inputs,
4082                                    &mut output_boxes,
4083                                    &mut output_masks,
4084                                    &mut output_tracks,
4085                                )
4086                                .unwrap();
4087                        }
4088                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4089                        assert!(output_masks.is_empty());
4090                    }
4091                }
4092            }
4093        };
4094    }
4095
    // Instantiate the tracked end-to-end segmentation/detection test matrix.
    //
    // Arguments are (test name, numeric path, tensor layout, mask output kind):
    //   * numeric path selects the macro arm: `quantized` or `float`;
    //   * layout `split` feeds boxes/scores/classes/mask coefficients as
    //     separate tensors, anything else concatenates them into one
    //     detection tensor (as seen in the `float` arm);
    //   * output `proto` exercises the `*_proto` decode entry point, anything
    //     else decodes into an `output_masks` vector (as seen in the `float`
    //     arm).
    // The `float` arm expands to a `#[test]` function named after the first
    // argument; the `quantized` arm presumably does the same - TODO confirm
    // against the macro header.

    // Quantized, combined tensor, decoded masks.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet,
        quantized,
        combined,
        masks
    );
    // Float, combined tensor, decoded masks.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_float,
        float,
        combined,
        masks
    );
    // Quantized, combined tensor, proto output.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    // Float, combined tensor, proto output.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    // Quantized, split tensors, decoded masks.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    // Float, split tensors, decoded masks.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    // Quantized, split tensors, proto output.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    // Float, split tensors, proto output.
    e2e_tracked_test!(
        test_decoder_tracked_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
4144
4145    #[test]
4146    fn test_decoder_tracked_linear_motion() {
4147        use crate::configs::{DecoderType, Nms};
4148        use crate::DecoderBuilder;
4149
4150        let score_threshold = 0.25;
4151        let iou_threshold = 0.1;
4152        let out = include_bytes!(concat!(
4153            env!("CARGO_MANIFEST_DIR"),
4154            "/../../testdata/yolov8s_80_classes.bin"
4155        ));
4156        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4157        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4158        let quant = (0.0040811873, -123).into();
4159
4160        let decoder = DecoderBuilder::default()
4161            .with_config_yolo_det(
4162                crate::configs::Detection {
4163                    decoder: DecoderType::Ultralytics,
4164                    shape: vec![1, 84, 8400],
4165                    anchors: None,
4166                    quantization: Some(quant),
4167                    dshape: vec![
4168                        (crate::configs::DimName::Batch, 1),
4169                        (crate::configs::DimName::NumFeatures, 84),
4170                        (crate::configs::DimName::NumBoxes, 8400),
4171                    ],
4172                    normalized: Some(true),
4173                },
4174                None,
4175            )
4176            .with_score_threshold(score_threshold)
4177            .with_iou_threshold(iou_threshold)
4178            .with_nms(Some(Nms::ClassAgnostic))
4179            .build()
4180            .unwrap();
4181
4182        let mut expected_boxes = [
4183            DetectBox {
4184                bbox: BoundingBox {
4185                    xmin: 0.5285137,
4186                    ymin: 0.05305544,
4187                    xmax: 0.87541467,
4188                    ymax: 0.9998909,
4189                },
4190                score: 0.5591227,
4191                label: 0,
4192            },
4193            DetectBox {
4194                bbox: BoundingBox {
4195                    xmin: 0.130598,
4196                    ymin: 0.43260583,
4197                    xmax: 0.35098213,
4198                    ymax: 0.9958097,
4199                },
4200                score: 0.33057618,
4201                label: 75,
4202            },
4203        ];
4204
4205        let mut tracker = ByteTrackBuilder::new()
4206            .track_update(0.1)
4207            .track_high_conf(0.3)
4208            .build();
4209
4210        let mut output_boxes = Vec::with_capacity(50);
4211        let mut output_masks = Vec::with_capacity(50);
4212        let mut output_tracks = Vec::with_capacity(50);
4213
4214        decoder
4215            .decode_tracked_quantized(
4216                &mut tracker,
4217                0,
4218                &[out.view().into()],
4219                &mut output_boxes,
4220                &mut output_masks,
4221                &mut output_tracks,
4222            )
4223            .unwrap();
4224
4225        assert_eq!(output_boxes.len(), 2);
4226        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4227        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
4228
4229        for i in 1..=100 {
4230            let mut out = out.clone();
4231            // introduce linear movement into the XY coordinates
4232            let mut x_values = out.slice_mut(s![0, 0, ..]);
4233            for x in x_values.iter_mut() {
4234                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
4235            }
4236
4237            decoder
4238                .decode_tracked_quantized(
4239                    &mut tracker,
4240                    100_000_000 * i / 3, // simulate 33.333ms between frames
4241                    &[out.view().into()],
4242                    &mut output_boxes,
4243                    &mut output_masks,
4244                    &mut output_tracks,
4245                )
4246                .unwrap();
4247
4248            assert_eq!(output_boxes.len(), 2);
4249        }
4250        let tracks = tracker.get_active_tracks();
4251        let predicted_boxes: Vec<_> = tracks
4252            .iter()
4253            .map(|track| {
4254                let mut l = track.last_box;
4255                l.bbox = track.info.tracked_location.into();
4256                l
4257            })
4258            .collect();
4259        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
4260        expected_boxes[0].bbox.xmax += 0.1;
4261        expected_boxes[1].bbox.xmin += 0.1;
4262        expected_boxes[1].bbox.xmax += 0.1;
4263
4264        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
4265        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
4266
4267        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4268        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
4269        for score in scores_values.iter_mut() {
4270            *score = i8::MIN; // set all scores to minimum to simulate no detections
4271        }
4272        decoder
4273            .decode_tracked_quantized(
4274                &mut tracker,
4275                100_000_000 * 101 / 3,
4276                &[out.view().into()],
4277                &mut output_boxes,
4278                &mut output_masks,
4279                &mut output_tracks,
4280            )
4281            .unwrap();
4282        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
4283        expected_boxes[0].bbox.xmax += 0.001;
4284        expected_boxes[1].bbox.xmin += 0.001;
4285        expected_boxes[1].bbox.xmax += 0.001;
4286
4287        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
4288        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
4289    }
4290
4291    #[test]
4292    fn test_decoder_tracked_end_to_end_float() {
4293        let score_threshold = 0.45;
4294        let iou_threshold = 0.45;
4295
4296        let mut boxes = Array2::zeros((10, 4));
4297        let mut scores = Array2::zeros((10, 1));
4298        let mut classes = Array2::zeros((10, 1));
4299
4300        boxes
4301            .slice_mut(s![0, ..,])
4302            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4303        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4304        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4305
4306        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
4307        let mut detect = detect.insert_axis(Axis(0));
4308        assert_eq!(detect.shape(), &[1, 10, 6]);
4309        let config = "
4310decoder_version: yolo26
4311outputs:
4312 - type: detection
4313   decoder: ultralytics
4314   quantization: [0.00784313725490196, 0]
4315   shape: [1, 10, 6]
4316   dshape:
4317    - [batch, 1]
4318    - [num_boxes, 10]
4319    - [num_features, 6]
4320   normalized: true
4321";
4322
4323        let decoder = DecoderBuilder::default()
4324            .with_config_yaml_str(config.to_string())
4325            .with_score_threshold(score_threshold)
4326            .with_iou_threshold(iou_threshold)
4327            .build()
4328            .unwrap();
4329
4330        let expected_boxes = [DetectBox {
4331            bbox: BoundingBox {
4332                xmin: 0.1234,
4333                ymin: 0.1234,
4334                xmax: 0.2345,
4335                ymax: 0.2345,
4336            },
4337            score: 0.9876,
4338            label: 2,
4339        }];
4340
4341        let mut tracker = ByteTrackBuilder::new()
4342            .track_update(0.1)
4343            .track_high_conf(0.7)
4344            .build();
4345
4346        let mut output_boxes = Vec::with_capacity(50);
4347        let mut output_masks = Vec::with_capacity(50);
4348        let mut output_tracks = Vec::with_capacity(50);
4349
4350        decoder
4351            .decode_tracked_float(
4352                &mut tracker,
4353                0,
4354                &[detect.view().into_dyn()],
4355                &mut output_boxes,
4356                &mut output_masks,
4357                &mut output_tracks,
4358            )
4359            .unwrap();
4360
4361        assert_eq!(output_boxes.len(), 1);
4362        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4363
4364        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4365
4366        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4367            *score = 0.0; // set all scores to minimum to simulate no detections
4368        }
4369
4370        decoder
4371            .decode_tracked_float(
4372                &mut tracker,
4373                100_000_000 / 3,
4374                &[detect.view().into_dyn()],
4375                &mut output_boxes,
4376                &mut output_masks,
4377                &mut output_tracks,
4378            )
4379            .unwrap();
4380        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4381    }
4382}