Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and for ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
- `Decoder`/`DecoderBuilder` structs: Provide high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize_*` functions provided in this crate, and then use the floating-point decoding functions. Otherwise, it is recommended
not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod schema;
77pub mod yolo;
78
79mod decoder;
80pub use decoder::*;
81
82pub use configs::{DecoderVersion, Nms};
83pub use error::{DecoderError, DecoderResult};
84
85use crate::{
86    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
87    yolo::yolo_segmentation_to_mask,
88};
89
90/// Trait to convert bounding box formats to XYXY float format
91pub trait BBoxTypeTrait {
92    /// Converts the bbox into XYXY float format.
93    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
94
95    /// Converts the bbox into XYXY float format.
96    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
97        input: &[B; 4],
98        quant: Quantization,
99    ) -> [A; 4]
100    where
101        f32: AsPrimitive<A>,
102        i32: AsPrimitive<A>;
103
104    /// Converts the bbox into XYXY float format.
105    ///
106    /// # Examples
107    /// ```rust
108    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
109    /// # use ndarray::array;
110    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
111    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
112    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
113    /// ```
114    #[inline(always)]
115    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
116        input: ArrayView1<B>,
117    ) -> [A; 4] {
118        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
119    }
120
121    #[inline(always)]
122    /// Converts the bbox into XYXY float format.
123    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
124        input: ArrayView1<B>,
125        quant: Quantization,
126    ) -> [A; 4]
127    where
128        f32: AsPrimitive<A>,
129        i32: AsPrimitive<A>,
130    {
131        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
132    }
133}
134
135/// Converts XYXY bounding boxes to XYXY
136#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137pub struct XYXY {}
138
139impl BBoxTypeTrait for XYXY {
140    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
141        input.map(|b| b.as_())
142    }
143
144    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
145        input: &[B; 4],
146        quant: Quantization,
147    ) -> [A; 4]
148    where
149        f32: AsPrimitive<A>,
150        i32: AsPrimitive<A>,
151    {
152        let scale = quant.scale.as_();
153        let zp = quant.zero_point.as_();
154        input.map(|b| (b.as_() - zp) * scale)
155    }
156
157    #[inline(always)]
158    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
159        input: ArrayView1<B>,
160    ) -> [A; 4] {
161        [
162            input[0].as_(),
163            input[1].as_(),
164            input[2].as_(),
165            input[3].as_(),
166        ]
167    }
168}
169
170/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
171/// box
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub struct XYWH {}
174
175impl BBoxTypeTrait for XYWH {
176    #[inline(always)]
177    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
178        let half = A::one() / (A::one() + A::one());
179        [
180            (input[0].as_()) - (input[2].as_() * half),
181            (input[1].as_()) - (input[3].as_() * half),
182            (input[0].as_()) + (input[2].as_() * half),
183            (input[1].as_()) + (input[3].as_() * half),
184        ]
185    }
186
187    #[inline(always)]
188    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
189        input: &[B; 4],
190        quant: Quantization,
191    ) -> [A; 4]
192    where
193        f32: AsPrimitive<A>,
194        i32: AsPrimitive<A>,
195    {
196        let scale = quant.scale.as_();
197        let half_scale = (quant.scale * 0.5).as_();
198        let zp = quant.zero_point.as_();
199        let [x, y, w, h] = [
200            (input[0].as_() - zp) * scale,
201            (input[1].as_() - zp) * scale,
202            (input[2].as_() - zp) * half_scale,
203            (input[3].as_() - zp) * half_scale,
204        ];
205
206        [x - w, y - h, x + w, y + h]
207    }
208
209    #[inline(always)]
210    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
211        input: ArrayView1<B>,
212    ) -> [A; 4] {
213        let half = A::one() / (A::one() + A::one());
214        [
215            (input[0].as_()) - (input[2].as_() * half),
216            (input[1].as_()) - (input[3].as_() * half),
217            (input[0].as_()) + (input[2].as_() * half),
218            (input[1].as_()) + (input[3].as_() * half),
219        ]
220    }
221}
222
/// Affine quantization parameters for a tensor. The dequantization helpers
/// in this crate map a quantized value `q` to `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Real-valued step between adjacent quantized levels.
    pub scale: f32,
    /// Quantized value that maps to real zero.
    pub zero_point: i32,
}
229
230impl Quantization {
231    /// Creates a new Quantization struct
232    /// # Examples
233    /// ```
234    /// # use edgefirst_decoder::Quantization;
235    /// let quant = Quantization::new(0.1, -128);
236    /// assert_eq!(quant.scale, 0.1);
237    /// assert_eq!(quant.zero_point, -128);
238    /// ```
239    pub fn new(scale: f32, zero_point: i32) -> Self {
240        Self { scale, zero_point }
241    }
242}
243
244impl From<QuantTuple> for Quantization {
245    /// Creates a new Quantization struct from a QuantTuple
246    /// # Examples
247    /// ```
248    /// # use edgefirst_decoder::Quantization;
249    /// # use edgefirst_decoder::configs::QuantTuple;
250    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
251    /// let quant = Quantization::from(quant_tuple);
252    /// assert_eq!(quant.scale, 0.1);
253    /// assert_eq!(quant.zero_point, -128);
254    /// ```
255    fn from(quant_tuple: QuantTuple) -> Quantization {
256        Quantization {
257            scale: quant_tuple.0,
258            zero_point: quant_tuple.1,
259        }
260    }
261}
262
263impl<S, Z> From<(S, Z)> for Quantization
264where
265    S: AsPrimitive<f32>,
266    Z: AsPrimitive<i32>,
267{
268    /// Creates a new Quantization struct from a tuple
269    /// # Examples
270    /// ```
271    /// # use edgefirst_decoder::Quantization;
272    /// let quant = Quantization::from((0.1_f64, -128_i64));
273    /// assert_eq!(quant.scale, 0.1);
274    /// assert_eq!(quant.zero_point, -128);
275    /// ```
276    fn from((scale, zp): (S, Z)) -> Quantization {
277        Self {
278            scale: scale.as_(),
279            zero_point: zp.as_(),
280        }
281    }
282}
283
284impl Default for Quantization {
285    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
286    /// # Examples
287    /// ```rust
288    /// # use edgefirst_decoder::Quantization;
289    /// let quant = Quantization::default();
290    /// assert_eq!(quant.scale, 1.0);
291    /// assert_eq!(quant.zero_point, 0);
292    /// ```
293    fn default() -> Self {
294        Self {
295            scale: 1.0,
296            zero_point: 0,
297        }
298    }
299}
300
301/// A detection box with f32 bbox and score
302#[derive(Debug, Clone, Copy, PartialEq, Default)]
303pub struct DetectBox {
304    pub bbox: BoundingBox,
305    /// model-specific score for this detection, higher implies more confidence
306    pub score: f32,
307    /// label index for this detection
308    pub label: usize,
309}
310
/// Axis-aligned bounding box stored as f32 corner coordinates
/// (`xmin, ymin, xmax, ymax`).
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left edge: the normalized minimum x coordinate.
    pub xmin: f32,
    /// Top edge: the normalized minimum y coordinate.
    pub ymin: f32,
    /// Right edge: the normalized maximum x coordinate.
    pub xmax: f32,
    /// Bottom edge: the normalized maximum y coordinate.
    pub ymax: f32,
}
323
324impl BoundingBox {
325    /// Creates a new BoundingBox from the given coordinates
326    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
327        Self {
328            xmin,
329            ymin,
330            xmax,
331            ymax,
332        }
333    }
334
335    /// Transforms BoundingBox so that `xmin <= xmax` and `ymin <= ymax`
336    ///
337    /// ```
338    /// # use edgefirst_decoder::BoundingBox;
339    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
340    /// let canonical_bbox = bbox.to_canonical();
341    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
342    /// ```
343    pub fn to_canonical(&self) -> Self {
344        let xmin = self.xmin.min(self.xmax);
345        let xmax = self.xmin.max(self.xmax);
346        let ymin = self.ymin.min(self.ymax);
347        let ymax = self.ymin.max(self.ymax);
348        BoundingBox {
349            xmin,
350            ymin,
351            xmax,
352            ymax,
353        }
354    }
355}
356
357impl From<BoundingBox> for [f32; 4] {
358    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
359    /// xmax, ymax order
360    /// # Examples
361    /// ```
362    /// # use edgefirst_decoder::BoundingBox;
363    /// let bbox = BoundingBox {
364    ///     xmin: 0.1,
365    ///     ymin: 0.2,
366    ///     xmax: 0.3,
367    ///     ymax: 0.4,
368    /// };
369    /// let arr: [f32; 4] = bbox.into();
370    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
371    /// ```
372    fn from(b: BoundingBox) -> Self {
373        [b.xmin, b.ymin, b.xmax, b.ymax]
374    }
375}
376
377impl From<[f32; 4]> for BoundingBox {
378    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
379    // BoundingBox
380    fn from(arr: [f32; 4]) -> Self {
381        BoundingBox {
382            xmin: arr[0],
383            ymin: arr[1],
384            xmax: arr[2],
385            ymax: arr[3],
386        }
387    }
388}
389
390impl DetectBox {
391    /// Returns true if one detect box is equal to another detect box, within
392    /// the given `eps`
393    ///
394    /// # Examples
395    /// ```
396    /// # use edgefirst_decoder::DetectBox;
397    /// let box1 = DetectBox {
398    ///     bbox: edgefirst_decoder::BoundingBox {
399    ///         xmin: 0.1,
400    ///         ymin: 0.2,
401    ///         xmax: 0.3,
402    ///         ymax: 0.4,
403    ///     },
404    ///     score: 0.5,
405    ///     label: 1,
406    /// };
407    /// let box2 = DetectBox {
408    ///     bbox: edgefirst_decoder::BoundingBox {
409    ///         xmin: 0.101,
410    ///         ymin: 0.199,
411    ///         xmax: 0.301,
412    ///         ymax: 0.399,
413    ///     },
414    ///     score: 0.510,
415    ///     label: 1,
416    /// };
417    /// assert!(box1.equal_within_delta(&box2, 0.011));
418    /// ```
419    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
420        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
421        self.label == rhs.label
422            && eq_delta(self.score, rhs.score)
423            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
424            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
425            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
426            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
427    }
428}
429
430/// A segmentation result with a segmentation mask, and a normalized bounding
431/// box representing the area that the segmentation mask covers
432#[derive(Debug, Clone, PartialEq, Default)]
433pub struct Segmentation {
434    /// left-most normalized coordinate of the segmentation box
435    pub xmin: f32,
436    /// top-most normalized coordinate of the segmentation box
437    pub ymin: f32,
438    /// right-most normalized coordinate of the segmentation box
439    pub xmax: f32,
440    /// bottom-most normalized coordinate of the segmentation box
441    pub ymax: f32,
442    /// 3D segmentation array of shape `(H, W, C)`.
443    ///
444    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
445    /// continuous sigmoid confidence values quantized to u8 (0 = background,
446    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
447    /// 0.5) or use smooth interpolation for anti-aliased edges.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
454/// Prototype tensor variants for fused decode+render pipelines.
455///
456/// Carries either raw quantized data (to skip CPU dequantization and let the
457/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
458/// paths).
459#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
462    /// The GPU fragment shader will dequantize per-texel using the scale and
463    /// zero_point.
464    Quantized {
465        protos: Array3<i8>,
466        quantization: Quantization,
467    },
468    /// Dequantized f32 protos (from float models or legacy path).
469    Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473    /// Returns `true` if this is the quantized variant.
474    pub fn is_quantized(&self) -> bool {
475        matches!(self, ProtoTensor::Quantized { .. })
476    }
477
478    /// Returns the spatial dimensions `(height, width, num_protos)`.
479    pub fn dim(&self) -> (usize, usize, usize) {
480        match self {
481            ProtoTensor::Quantized { protos, .. } => protos.dim(),
482            ProtoTensor::Float(arr) => arr.dim(),
483        }
484    }
485
486    /// Returns dequantized f32 protos. For the `Float` variant this is a
487    /// no-copy reference; for `Quantized` it allocates and dequantizes.
488    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489        match self {
490            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491            ProtoTensor::Quantized {
492                protos,
493                quantization,
494            } => {
495                let scale = quantization.scale;
496                let zp = quantization.zero_point as f32;
497                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498            }
499        }
500    }
501}
502
503/// Raw prototype data for fused decode+render pipelines.
504///
505/// Holds post-NMS intermediate state before mask materialization, allowing the
506/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
507/// shader) without materializing intermediate `Array3<u8>` masks.
508#[derive(Debug, Clone)]
509pub struct ProtoData {
510    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
511    pub mask_coefficients: Vec<Vec<f32>>,
512    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
513    pub protos: ProtoTensor,
514}
515
516/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
517///
518///  # Examples
519/// ```
520/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
521/// let quant = Quantization::new(0.1, -128);
522/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
523/// let detect_quant = DetectBoxQuantized {
524///     bbox,
525///     score: 100_i8,
526///     label: 1,
527/// };
528/// let detect = dequant_detect_box(&detect_quant, quant);
529/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
530/// assert_eq!(detect.label, 1);
531/// assert_eq!(detect.bbox, bbox);
532/// ```
533pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534    detect: &DetectBoxQuantized<SCORE>,
535    quant_scores: Quantization,
536) -> DetectBox {
537    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538    DetectBox {
539        bbox: detect.bbox,
540        score: quant_scores.scale * detect.score.as_() + scaled_zp,
541        label: detect.label,
542    }
543}
544/// A detection box with a f32 bbox and quantized score
545#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547    // BOX: Signed + PrimInt + AsPrimitive<f32>,
548    SCORE: PrimInt + AsPrimitive<f32>,
549> {
550    // pub bbox: BoundingBoxQuantized<BOX>,
551    pub bbox: BoundingBox,
552    /// model-specific score for this detection, higher implies more
553    /// confidence.
554    pub score: SCORE,
555    /// label index for this detect
556    pub label: usize,
557}
558
559/// Dequantizes an ndarray from quantized values to f32 values using the given
560/// quantization parameters
561///
562/// # Examples
563/// ```
564/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
565/// let quant = Quantization::new(0.1, -128);
566/// let input: Vec<i8> = vec![0, 127, -128, 64];
567/// let input_array = ndarray::Array1::from(input);
568/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
569/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
570/// ```
571pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572    input: ArrayView<T, D>,
573    quant: Quantization,
574) -> Array<F, D>
575where
576    i32: num_traits::AsPrimitive<F>,
577    f32: num_traits::AsPrimitive<F>,
578{
579    let zero_point = quant.zero_point.as_();
580    let scale = quant.scale.as_();
581    if zero_point != F::zero() {
582        let scaled_zero = -zero_point * scale;
583        input.mapv(|d| d.as_() * scale + scaled_zero)
584    } else {
585        input.mapv(|d| d.as_() * scale)
586    }
587}
588
589/// Dequantizes a slice from quantized values to float values using the given
590/// quantization parameters
591///
592/// # Examples
593/// ```
594/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
595/// let quant = Quantization::new(0.1, -128);
596/// let input: Vec<i8> = vec![0, 127, -128, 64];
597/// let mut output: Vec<f32> = vec![0.0; input.len()];
598/// dequantize_cpu(&input, quant, &mut output);
599/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
600/// ```
601pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602    input: &[T],
603    quant: Quantization,
604    output: &mut [F],
605) where
606    f32: num_traits::AsPrimitive<F>,
607    i32: num_traits::AsPrimitive<F>,
608{
609    assert!(input.len() == output.len());
610    let zero_point = quant.zero_point.as_();
611    let scale = quant.scale.as_();
612    if zero_point != F::zero() {
613        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
614        input
615            .iter()
616            .zip(output)
617            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618    } else {
619        input
620            .iter()
621            .zip(output)
622            .for_each(|(d, deq)| *deq = d.as_() * scale);
623    }
624}
625
626/// Dequantizes a slice from quantized values to float values using the given
627/// quantization parameters, using chunked processing. This is around 5% faster
628/// than `dequantize_cpu` for large slices.
629///
630/// # Examples
631/// ```
632/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
633/// let quant = Quantization::new(0.1, -128);
634/// let input: Vec<i8> = vec![0, 127, -128, 64];
635/// let mut output: Vec<f32> = vec![0.0; input.len()];
636/// dequantize_cpu_chunked(&input, quant, &mut output);
637/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
638/// ```
639pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640    input: &[T],
641    quant: Quantization,
642    output: &mut [F],
643) where
644    f32: num_traits::AsPrimitive<F>,
645    i32: num_traits::AsPrimitive<F>,
646{
647    assert!(input.len() == output.len());
648    let zero_point = quant.zero_point.as_();
649    let scale = quant.scale.as_();
650
651    let input = input.as_chunks::<4>();
652    let output = output.as_chunks_mut::<4>();
653
654    if zero_point != F::zero() {
655        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
656
657        input
658            .0
659            .iter()
660            .zip(output.0)
661            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662        input
663            .1
664            .iter()
665            .zip(output.1)
666            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667    } else {
668        input
669            .0
670            .iter()
671            .zip(output.0)
672            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673        input
674            .1
675            .iter()
676            .zip(output.1)
677            .for_each(|(d, deq)| *deq = d.as_() * scale);
678    }
679}
680
681/// Converts a segmentation tensor into a 2D mask
682/// If the last dimension of the segmentation tensor is 1, values equal or
683/// above 128 are considered objects. Otherwise the object is the argmax index
684///
685/// # Errors
686///
687/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
688/// invalid shape.
689///
690/// # Examples
691/// ```
692/// # use edgefirst_decoder::segmentation_to_mask;
693/// let segmentation =
694///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
695/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
696/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
697/// ```
698pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699    if segmentation.shape()[2] == 0 {
700        return Err(DecoderError::InvalidShape(
701            "Segmentation tensor must have non-zero depth".to_string(),
702        ));
703    }
704    if segmentation.shape()[2] == 1 {
705        yolo_segmentation_to_mask(segmentation, 128)
706    } else {
707        Ok(modelpack_segmentation_to_mask(segmentation))
708    }
709}
710
711/// Returns the maximum value and its index from a 1D array
712fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713    score
714        .iter()
715        .enumerate()
716        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717            if max > *s {
718                (max, arg_max)
719            } else {
720                (*s, ind)
721            }
722        })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727    #![allow(clippy::excessive_precision)]
728    use crate::{
729        configs::{DecoderType, DimName, Protos},
730        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731        yolo::{
732            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733            decode_yolo_segdet_quant,
734        },
735        *,
736    };
737    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
738    use ndarray::Dimension;
739    use ndarray::{array, s, Array2, Array3, Array4, Axis};
740    use ndarray_stats::DeviationExt;
741    use num_traits::{AsPrimitive, PrimInt};
742
743    fn compare_outputs(
744        boxes: (&[DetectBox], &[DetectBox]),
745        masks: (&[Segmentation], &[Segmentation]),
746    ) {
747        let (boxes0, boxes1) = boxes;
748        let (masks0, masks1) = masks;
749
750        assert_eq!(boxes0.len(), boxes1.len());
751        assert_eq!(masks0.len(), masks1.len());
752
753        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
754            assert!(
755                b_i8.equal_within_delta(b_f32, 1e-6),
756                "{b_i8:?} is not equal to {b_f32:?}"
757            );
758        }
759
760        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
761            assert_eq!(
762                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
763                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
764            );
765            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
766            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
767            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
768            let diff = &mask_i8 - &mask_f32;
769            for x in 0..diff.shape()[0] {
770                for y in 0..diff.shape()[1] {
771                    for z in 0..diff.shape()[2] {
772                        let val = diff[[x, y, z]];
773                        assert!(
774                            val.abs() <= 1,
775                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
776                            x,
777                            y,
778                            z,
779                            val
780                        );
781                    }
782                }
783            }
784            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
785            assert!(
786                mean_sq_err < 1e-2,
787                "Mean Square Error between masks was greater than 1%: {:.2}%",
788                mean_sq_err * 100.0
789            );
790        }
791    }
792
793    // ─── Shared test data loaders ────────────────────────
794
795    fn load_yolov8_boxes() -> Array3<i8> {
796        let raw = include_bytes!(concat!(
797            env!("CARGO_MANIFEST_DIR"),
798            "/../../testdata/yolov8_boxes_116x8400.bin"
799        ));
800        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
801        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
802    }
803
804    fn load_yolov8_protos() -> Array4<i8> {
805        let raw = include_bytes!(concat!(
806            env!("CARGO_MANIFEST_DIR"),
807            "/../../testdata/yolov8_protos_160x160x32.bin"
808        ));
809        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
810        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
811    }
812
813    fn load_yolov8s_det() -> Array3<i8> {
814        let raw = include_bytes!(concat!(
815            env!("CARGO_MANIFEST_DIR"),
816            "/../../testdata/yolov8s_80_classes.bin"
817        ));
818        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
819        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
820    }
821
822    #[test]
823    fn test_decoder_modelpack() {
824        let score_threshold = 0.45;
825        let iou_threshold = 0.45;
826        let boxes = include_bytes!(concat!(
827            env!("CARGO_MANIFEST_DIR"),
828            "/../../testdata/modelpack_boxes_1935x1x4.bin"
829        ));
830        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
831
832        let scores = include_bytes!(concat!(
833            env!("CARGO_MANIFEST_DIR"),
834            "/../../testdata/modelpack_scores_1935x1.bin"
835        ));
836        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
837
838        let quant_boxes = (0.004656755365431309, 21).into();
839        let quant_scores = (0.0019603664986789227, 0).into();
840
841        let decoder = DecoderBuilder::default()
842            .with_config_modelpack_det(
843                configs::Boxes {
844                    decoder: DecoderType::ModelPack,
845                    quantization: Some(quant_boxes),
846                    shape: vec![1, 1935, 1, 4],
847                    dshape: vec![
848                        (DimName::Batch, 1),
849                        (DimName::NumBoxes, 1935),
850                        (DimName::Padding, 1),
851                        (DimName::BoxCoords, 4),
852                    ],
853                    normalized: Some(true),
854                },
855                configs::Scores {
856                    decoder: DecoderType::ModelPack,
857                    quantization: Some(quant_scores),
858                    shape: vec![1, 1935, 1],
859                    dshape: vec![
860                        (DimName::Batch, 1),
861                        (DimName::NumBoxes, 1935),
862                        (DimName::NumClasses, 1),
863                    ],
864                },
865            )
866            .with_score_threshold(score_threshold)
867            .with_iou_threshold(iou_threshold)
868            .build()
869            .unwrap();
870
871        let quant_boxes = quant_boxes.into();
872        let quant_scores = quant_scores.into();
873
874        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
875        decode_modelpack_det(
876            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
877            (scores.slice(s![0, .., ..]), quant_scores),
878            score_threshold,
879            iou_threshold,
880            &mut output_boxes,
881        );
882        assert!(output_boxes[0].equal_within_delta(
883            &DetectBox {
884                bbox: BoundingBox {
885                    xmin: 0.40513772,
886                    ymin: 0.6379755,
887                    xmax: 0.5122431,
888                    ymax: 0.7730214,
889                },
890                score: 0.4861709,
891                label: 0
892            },
893            1e-6
894        ));
895
896        let mut output_boxes1 = Vec::with_capacity(50);
897        let mut output_masks1 = Vec::with_capacity(50);
898
899        decoder
900            .decode_quantized(
901                &[boxes.view().into(), scores.view().into()],
902                &mut output_boxes1,
903                &mut output_masks1,
904            )
905            .unwrap();
906
907        let mut output_boxes_float = Vec::with_capacity(50);
908        let mut output_masks_float = Vec::with_capacity(50);
909
910        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
911        let scores = dequantize_ndarray(scores.view(), quant_scores);
912
913        decoder
914            .decode_float::<f32>(
915                &[boxes.view().into_dyn(), scores.view().into_dyn()],
916                &mut output_boxes_float,
917                &mut output_masks_float,
918            )
919            .unwrap();
920
921        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
922        compare_outputs(
923            (&output_boxes, &output_boxes_float),
924            (&[], &output_masks_float),
925        );
926    }
927
928    #[test]
929    fn test_decoder_modelpack_split_u8() {
930        let score_threshold = 0.45;
931        let iou_threshold = 0.45;
932        let detect0 = include_bytes!(concat!(
933            env!("CARGO_MANIFEST_DIR"),
934            "/../../testdata/modelpack_split_9x15x18.bin"
935        ));
936        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
937
938        let detect1 = include_bytes!(concat!(
939            env!("CARGO_MANIFEST_DIR"),
940            "/../../testdata/modelpack_split_17x30x18.bin"
941        ));
942        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
943
944        let quant0 = (0.08547406643629074, 174).into();
945        let quant1 = (0.09929127991199493, 183).into();
946        let anchors0 = vec![
947            [0.36666667461395264, 0.31481480598449707],
948            [0.38749998807907104, 0.4740740656852722],
949            [0.5333333611488342, 0.644444465637207],
950        ];
951        let anchors1 = vec![
952            [0.13750000298023224, 0.2074074000120163],
953            [0.2541666626930237, 0.21481481194496155],
954            [0.23125000298023224, 0.35185185074806213],
955        ];
956
957        let detect_config0 = configs::Detection {
958            decoder: DecoderType::ModelPack,
959            shape: vec![1, 9, 15, 18],
960            anchors: Some(anchors0.clone()),
961            quantization: Some(quant0),
962            dshape: vec![
963                (DimName::Batch, 1),
964                (DimName::Height, 9),
965                (DimName::Width, 15),
966                (DimName::NumAnchorsXFeatures, 18),
967            ],
968            normalized: Some(true),
969        };
970
971        let detect_config1 = configs::Detection {
972            decoder: DecoderType::ModelPack,
973            shape: vec![1, 17, 30, 18],
974            anchors: Some(anchors1.clone()),
975            quantization: Some(quant1),
976            dshape: vec![
977                (DimName::Batch, 1),
978                (DimName::Height, 17),
979                (DimName::Width, 30),
980                (DimName::NumAnchorsXFeatures, 18),
981            ],
982            normalized: Some(true),
983        };
984
985        let config0 = (&detect_config0).try_into().unwrap();
986        let config1 = (&detect_config1).try_into().unwrap();
987
988        let decoder = DecoderBuilder::default()
989            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
990            .with_score_threshold(score_threshold)
991            .with_iou_threshold(iou_threshold)
992            .build()
993            .unwrap();
994
995        let quant0 = quant0.into();
996        let quant1 = quant1.into();
997
998        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
999        decode_modelpack_split_quant(
1000            &[
1001                detect0.slice(s![0, .., .., ..]),
1002                detect1.slice(s![0, .., .., ..]),
1003            ],
1004            &[config0, config1],
1005            score_threshold,
1006            iou_threshold,
1007            &mut output_boxes,
1008        );
1009        assert!(output_boxes[0].equal_within_delta(
1010            &DetectBox {
1011                bbox: BoundingBox {
1012                    xmin: 0.43171933,
1013                    ymin: 0.68243736,
1014                    xmax: 0.5626645,
1015                    ymax: 0.808863,
1016                },
1017                score: 0.99240804,
1018                label: 0
1019            },
1020            1e-6
1021        ));
1022
1023        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1024        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1025        decoder
1026            .decode_quantized(
1027                &[detect0.view().into(), detect1.view().into()],
1028                &mut output_boxes1,
1029                &mut output_masks1,
1030            )
1031            .unwrap();
1032
1033        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1034        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1035
1036        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1037        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1038        decoder
1039            .decode_float::<f32>(
1040                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1041                &mut output_boxes1_f32,
1042                &mut output_masks1_f32,
1043            )
1044            .unwrap();
1045
1046        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1047        compare_outputs(
1048            (&output_boxes, &output_boxes1_f32),
1049            (&[], &output_masks1_f32),
1050        );
1051    }
1052
1053    #[test]
1054    fn test_decoder_parse_config_modelpack_split_u8() {
1055        let score_threshold = 0.45;
1056        let iou_threshold = 0.45;
1057        let detect0 = include_bytes!(concat!(
1058            env!("CARGO_MANIFEST_DIR"),
1059            "/../../testdata/modelpack_split_9x15x18.bin"
1060        ));
1061        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1062
1063        let detect1 = include_bytes!(concat!(
1064            env!("CARGO_MANIFEST_DIR"),
1065            "/../../testdata/modelpack_split_17x30x18.bin"
1066        ));
1067        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1068
1069        let decoder = DecoderBuilder::default()
1070            .with_config_yaml_str(
1071                include_str!(concat!(
1072                    env!("CARGO_MANIFEST_DIR"),
1073                    "/../../testdata/modelpack_split.yaml"
1074                ))
1075                .to_string(),
1076            )
1077            .with_score_threshold(score_threshold)
1078            .with_iou_threshold(iou_threshold)
1079            .build()
1080            .unwrap();
1081
1082        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1083        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1084        decoder
1085            .decode_quantized(
1086                &[
1087                    ArrayViewDQuantized::from(detect1.view()),
1088                    ArrayViewDQuantized::from(detect0.view()),
1089                ],
1090                &mut output_boxes,
1091                &mut output_masks,
1092            )
1093            .unwrap();
1094        assert!(output_boxes[0].equal_within_delta(
1095            &DetectBox {
1096                bbox: BoundingBox {
1097                    xmin: 0.43171933,
1098                    ymin: 0.68243736,
1099                    xmax: 0.5626645,
1100                    ymax: 0.808863,
1101                },
1102                score: 0.99240804,
1103                label: 0
1104            },
1105            1e-6
1106        ));
1107    }
1108
1109    #[test]
1110    fn test_modelpack_seg() {
1111        let out = include_bytes!(concat!(
1112            env!("CARGO_MANIFEST_DIR"),
1113            "/../../testdata/modelpack_seg_2x160x160.bin"
1114        ));
1115        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1116        let quant = (1.0 / 255.0, 0).into();
1117
1118        let decoder = DecoderBuilder::default()
1119            .with_config_modelpack_seg(configs::Segmentation {
1120                decoder: DecoderType::ModelPack,
1121                quantization: Some(quant),
1122                shape: vec![1, 2, 160, 160],
1123                dshape: vec![
1124                    (DimName::Batch, 1),
1125                    (DimName::NumClasses, 2),
1126                    (DimName::Height, 160),
1127                    (DimName::Width, 160),
1128                ],
1129            })
1130            .build()
1131            .unwrap();
1132        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1133        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1134        decoder
1135            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1136            .unwrap();
1137
1138        let mut mask = out.slice(s![0, .., .., ..]);
1139        mask.swap_axes(0, 1);
1140        mask.swap_axes(1, 2);
1141        let mask = [Segmentation {
1142            xmin: 0.0,
1143            ymin: 0.0,
1144            xmax: 1.0,
1145            ymax: 1.0,
1146            segmentation: mask.into_owned(),
1147        }];
1148        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1149
1150        decoder
1151            .decode_float::<f32>(
1152                &[dequantize_ndarray(out.view(), quant.into())
1153                    .view()
1154                    .into_dyn()],
1155                &mut output_boxes,
1156                &mut output_masks,
1157            )
1158            .unwrap();
1159
1160        // not expected for float decoder to have same values as quantized decoder, as
1161        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1162        // the model output. Thus the float output is the same as the quantized output
1163        // but scaled differently. However, it is expected that the mask after argmax
1164        // will be the same.
1165        compare_outputs((&[], &output_boxes), (&[], &[]));
1166        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1167        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1168
1169        assert_eq!(mask0, mask1);
1170    }
1171    #[test]
1172    fn test_modelpack_seg_quant() {
1173        let out = include_bytes!(concat!(
1174            env!("CARGO_MANIFEST_DIR"),
1175            "/../../testdata/modelpack_seg_2x160x160.bin"
1176        ));
1177        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1178        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1179        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1180        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1181        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1182        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1183
1184        let quant = (1.0 / 255.0, 0).into();
1185
1186        let decoder = DecoderBuilder::default()
1187            .with_config_modelpack_seg(configs::Segmentation {
1188                decoder: DecoderType::ModelPack,
1189                quantization: Some(quant),
1190                shape: vec![1, 2, 160, 160],
1191                dshape: vec![
1192                    (DimName::Batch, 1),
1193                    (DimName::NumClasses, 2),
1194                    (DimName::Height, 160),
1195                    (DimName::Width, 160),
1196                ],
1197            })
1198            .build()
1199            .unwrap();
1200        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1201        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1202        decoder
1203            .decode_quantized(
1204                &[out_u8.view().into()],
1205                &mut output_boxes,
1206                &mut output_masks_u8,
1207            )
1208            .unwrap();
1209
1210        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1211        decoder
1212            .decode_quantized(
1213                &[out_i8.view().into()],
1214                &mut output_boxes,
1215                &mut output_masks_i8,
1216            )
1217            .unwrap();
1218
1219        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1220        decoder
1221            .decode_quantized(
1222                &[out_u16.view().into()],
1223                &mut output_boxes,
1224                &mut output_masks_u16,
1225            )
1226            .unwrap();
1227
1228        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1229        decoder
1230            .decode_quantized(
1231                &[out_i16.view().into()],
1232                &mut output_boxes,
1233                &mut output_masks_i16,
1234            )
1235            .unwrap();
1236
1237        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1238        decoder
1239            .decode_quantized(
1240                &[out_u32.view().into()],
1241                &mut output_boxes,
1242                &mut output_masks_u32,
1243            )
1244            .unwrap();
1245
1246        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1247        decoder
1248            .decode_quantized(
1249                &[out_i32.view().into()],
1250                &mut output_boxes,
1251                &mut output_masks_i32,
1252            )
1253            .unwrap();
1254
1255        compare_outputs((&[], &output_boxes), (&[], &[]));
1256        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1257        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1258        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1259        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1260        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1261        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1262        assert_eq!(mask_u8, mask_i8);
1263        assert_eq!(mask_u8, mask_u16);
1264        assert_eq!(mask_u8, mask_i16);
1265        assert_eq!(mask_u8, mask_u32);
1266        assert_eq!(mask_u8, mask_i32);
1267    }
1268
1269    #[test]
1270    fn test_modelpack_segdet() {
1271        let score_threshold = 0.45;
1272        let iou_threshold = 0.45;
1273
1274        let boxes = include_bytes!(concat!(
1275            env!("CARGO_MANIFEST_DIR"),
1276            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1277        ));
1278        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1279
1280        let scores = include_bytes!(concat!(
1281            env!("CARGO_MANIFEST_DIR"),
1282            "/../../testdata/modelpack_scores_1935x1.bin"
1283        ));
1284        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1285
1286        let seg = include_bytes!(concat!(
1287            env!("CARGO_MANIFEST_DIR"),
1288            "/../../testdata/modelpack_seg_2x160x160.bin"
1289        ));
1290        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1291
1292        let quant_boxes = (0.004656755365431309, 21).into();
1293        let quant_scores = (0.0019603664986789227, 0).into();
1294        let quant_seg = (1.0 / 255.0, 0).into();
1295
1296        let decoder = DecoderBuilder::default()
1297            .with_config_modelpack_segdet(
1298                configs::Boxes {
1299                    decoder: DecoderType::ModelPack,
1300                    quantization: Some(quant_boxes),
1301                    shape: vec![1, 1935, 1, 4],
1302                    dshape: vec![
1303                        (DimName::Batch, 1),
1304                        (DimName::NumBoxes, 1935),
1305                        (DimName::Padding, 1),
1306                        (DimName::BoxCoords, 4),
1307                    ],
1308                    normalized: Some(true),
1309                },
1310                configs::Scores {
1311                    decoder: DecoderType::ModelPack,
1312                    quantization: Some(quant_scores),
1313                    shape: vec![1, 1935, 1],
1314                    dshape: vec![
1315                        (DimName::Batch, 1),
1316                        (DimName::NumBoxes, 1935),
1317                        (DimName::NumClasses, 1),
1318                    ],
1319                },
1320                configs::Segmentation {
1321                    decoder: DecoderType::ModelPack,
1322                    quantization: Some(quant_seg),
1323                    shape: vec![1, 2, 160, 160],
1324                    dshape: vec![
1325                        (DimName::Batch, 1),
1326                        (DimName::NumClasses, 2),
1327                        (DimName::Height, 160),
1328                        (DimName::Width, 160),
1329                    ],
1330                },
1331            )
1332            .with_iou_threshold(iou_threshold)
1333            .with_score_threshold(score_threshold)
1334            .build()
1335            .unwrap();
1336        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1337        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1338        decoder
1339            .decode_quantized(
1340                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1341                &mut output_boxes,
1342                &mut output_masks,
1343            )
1344            .unwrap();
1345
1346        let mut mask = seg.slice(s![0, .., .., ..]);
1347        mask.swap_axes(0, 1);
1348        mask.swap_axes(1, 2);
1349        let mask = [Segmentation {
1350            xmin: 0.0,
1351            ymin: 0.0,
1352            xmax: 1.0,
1353            ymax: 1.0,
1354            segmentation: mask.into_owned(),
1355        }];
1356        let correct_boxes = [DetectBox {
1357            bbox: BoundingBox {
1358                xmin: 0.40513772,
1359                ymin: 0.6379755,
1360                xmax: 0.5122431,
1361                ymax: 0.7730214,
1362            },
1363            score: 0.4861709,
1364            label: 0,
1365        }];
1366        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1367
1368        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1369        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1370        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1371        decoder
1372            .decode_float::<f32>(
1373                &[
1374                    scores.view().into_dyn(),
1375                    boxes.view().into_dyn(),
1376                    seg.view().into_dyn(),
1377                ],
1378                &mut output_boxes,
1379                &mut output_masks,
1380            )
1381            .unwrap();
1382
1383        // not expected for float segmentation decoder to have same values as quantized
1384        // segmentation decoder, as float decoder ensures the data fills 0-255,
1385        // quantized decoder uses whatever the model output. Thus the float
1386        // output is the same as the quantized output but scaled differently.
1387        // However, it is expected that the mask after argmax will be the same.
1388        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1389        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1390        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1391
1392        assert_eq!(mask0, mask1);
1393    }
1394
1395    #[test]
1396    fn test_modelpack_segdet_split() {
1397        let score_threshold = 0.8;
1398        let iou_threshold = 0.5;
1399
1400        let seg = include_bytes!(concat!(
1401            env!("CARGO_MANIFEST_DIR"),
1402            "/../../testdata/modelpack_seg_2x160x160.bin"
1403        ));
1404        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1405
1406        let detect0 = include_bytes!(concat!(
1407            env!("CARGO_MANIFEST_DIR"),
1408            "/../../testdata/modelpack_split_9x15x18.bin"
1409        ));
1410        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1411
1412        let detect1 = include_bytes!(concat!(
1413            env!("CARGO_MANIFEST_DIR"),
1414            "/../../testdata/modelpack_split_17x30x18.bin"
1415        ));
1416        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1417
1418        let quant0 = (0.08547406643629074, 174).into();
1419        let quant1 = (0.09929127991199493, 183).into();
1420        let quant_seg = (1.0 / 255.0, 0).into();
1421
1422        let anchors0 = vec![
1423            [0.36666667461395264, 0.31481480598449707],
1424            [0.38749998807907104, 0.4740740656852722],
1425            [0.5333333611488342, 0.644444465637207],
1426        ];
1427        let anchors1 = vec![
1428            [0.13750000298023224, 0.2074074000120163],
1429            [0.2541666626930237, 0.21481481194496155],
1430            [0.23125000298023224, 0.35185185074806213],
1431        ];
1432
1433        let decoder = DecoderBuilder::default()
1434            .with_config_modelpack_segdet_split(
1435                vec![
1436                    configs::Detection {
1437                        decoder: DecoderType::ModelPack,
1438                        shape: vec![1, 17, 30, 18],
1439                        anchors: Some(anchors1),
1440                        quantization: Some(quant1),
1441                        dshape: vec![
1442                            (DimName::Batch, 1),
1443                            (DimName::Height, 17),
1444                            (DimName::Width, 30),
1445                            (DimName::NumAnchorsXFeatures, 18),
1446                        ],
1447                        normalized: Some(true),
1448                    },
1449                    configs::Detection {
1450                        decoder: DecoderType::ModelPack,
1451                        shape: vec![1, 9, 15, 18],
1452                        anchors: Some(anchors0),
1453                        quantization: Some(quant0),
1454                        dshape: vec![
1455                            (DimName::Batch, 1),
1456                            (DimName::Height, 9),
1457                            (DimName::Width, 15),
1458                            (DimName::NumAnchorsXFeatures, 18),
1459                        ],
1460                        normalized: Some(true),
1461                    },
1462                ],
1463                configs::Segmentation {
1464                    decoder: DecoderType::ModelPack,
1465                    quantization: Some(quant_seg),
1466                    shape: vec![1, 2, 160, 160],
1467                    dshape: vec![
1468                        (DimName::Batch, 1),
1469                        (DimName::NumClasses, 2),
1470                        (DimName::Height, 160),
1471                        (DimName::Width, 160),
1472                    ],
1473                },
1474            )
1475            .with_score_threshold(score_threshold)
1476            .with_iou_threshold(iou_threshold)
1477            .build()
1478            .unwrap();
1479        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1480        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1481        decoder
1482            .decode_quantized(
1483                &[
1484                    detect0.view().into(),
1485                    detect1.view().into(),
1486                    seg.view().into(),
1487                ],
1488                &mut output_boxes,
1489                &mut output_masks,
1490            )
1491            .unwrap();
1492
1493        let mut mask = seg.slice(s![0, .., .., ..]);
1494        mask.swap_axes(0, 1);
1495        mask.swap_axes(1, 2);
1496        let mask = [Segmentation {
1497            xmin: 0.0,
1498            ymin: 0.0,
1499            xmax: 1.0,
1500            ymax: 1.0,
1501            segmentation: mask.into_owned(),
1502        }];
1503        let correct_boxes = [DetectBox {
1504            bbox: BoundingBox {
1505                xmin: 0.43171933,
1506                ymin: 0.68243736,
1507                xmax: 0.5626645,
1508                ymax: 0.808863,
1509            },
1510            score: 0.99240804,
1511            label: 0,
1512        }];
1513        println!("Output Boxes: {:?}", output_boxes);
1514        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1515
1516        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1517        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1518        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1519        decoder
1520            .decode_float::<f32>(
1521                &[
1522                    detect0.view().into_dyn(),
1523                    detect1.view().into_dyn(),
1524                    seg.view().into_dyn(),
1525                ],
1526                &mut output_boxes,
1527                &mut output_masks,
1528            )
1529            .unwrap();
1530
1531        // not expected for float segmentation decoder to have same values as quantized
1532        // segmentation decoder, as float decoder ensures the data fills 0-255,
1533        // quantized decoder uses whatever the model output. Thus the float
1534        // output is the same as the quantized output but scaled differently.
1535        // However, it is expected that the mask after argmax will be the same.
1536        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1537        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1538        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1539
1540        assert_eq!(mask0, mask1);
1541    }
1542
1543    #[test]
1544    fn test_dequant_chunked() {
1545        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1546        out.push(123); // make sure to test non multiple of 16 length
1547
1548        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1549        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1550        let quant = Quantization::new(0.0040811873, -123);
1551        dequantize_cpu(&out, quant, &mut out_dequant);
1552
1553        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1554        assert_eq!(out_dequant, out_dequant_simd);
1555
1556        let quant = Quantization::new(0.0040811873, 0);
1557        dequantize_cpu(&out, quant, &mut out_dequant);
1558
1559        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1560        assert_eq!(out_dequant, out_dequant_simd);
1561    }
1562
1563    #[test]
1564    fn test_dequant_ground_truth() {
1565        // Formula: output = (input - zero_point) * scale
1566        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1567
1568        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1569        let quant = Quantization::new(0.1, -128);
1570        let input: Vec<i8> = vec![0, 127, -128, 64];
1571        let mut output = vec![0.0f32; 4];
1572        let mut output_chunked = vec![0.0f32; 4];
1573        dequantize_cpu(&input, quant, &mut output);
1574        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1575        // (0 - (-128)) * 0.1 = 12.8
1576        // (127 - (-128)) * 0.1 = 25.5
1577        // (-128 - (-128)) * 0.1 = 0.0
1578        // (64 - (-128)) * 0.1 = 19.2
1579        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1580        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1581            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1582        }
1583        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1584            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1585        }
1586
1587        // Case 2: scale=1.0, zero_point=0 (identity-like)
1588        let quant = Quantization::new(1.0, 0);
1589        dequantize_cpu(&input, quant, &mut output);
1590        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1591        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1592        assert_eq!(output, expected);
1593        assert_eq!(output_chunked, expected);
1594
1595        // Case 3: scale=0.5, zero_point=0
1596        let quant = Quantization::new(0.5, 0);
1597        dequantize_cpu(&input, quant, &mut output);
1598        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1599        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1600        assert_eq!(output, expected);
1601        assert_eq!(output_chunked, expected);
1602
1603        // Case 4: i8 min/max boundaries with typical quantization params
1604        let quant = Quantization::new(0.021287762, 31);
1605        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1606        let mut output = vec![0.0f32; 6];
1607        let mut output_chunked = vec![0.0f32; 6];
1608        dequantize_cpu(&input, quant, &mut output);
1609        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1610        for i in 0..6 {
1611            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1612            assert!(
1613                (output[i] - expected).abs() < 1e-5,
1614                "cpu[{i}]: {} != {expected}",
1615                output[i]
1616            );
1617            assert!(
1618                (output_chunked[i] - expected).abs() < 1e-5,
1619                "chunked[{i}]: {} != {expected}",
1620                output_chunked[i]
1621            );
1622        }
1623    }
1624
1625    #[test]
1626    fn test_decoder_yolo_det() {
1627        let score_threshold = 0.25;
1628        let iou_threshold = 0.7;
1629        let out = load_yolov8s_det();
1630        let quant = (0.0040811873, -123).into();
1631
1632        let decoder = DecoderBuilder::default()
1633            .with_config_yolo_det(
1634                configs::Detection {
1635                    decoder: DecoderType::Ultralytics,
1636                    shape: vec![1, 84, 8400],
1637                    anchors: None,
1638                    quantization: Some(quant),
1639                    dshape: vec![
1640                        (DimName::Batch, 1),
1641                        (DimName::NumFeatures, 84),
1642                        (DimName::NumBoxes, 8400),
1643                    ],
1644                    normalized: Some(true),
1645                },
1646                Some(DecoderVersion::Yolo11),
1647            )
1648            .with_score_threshold(score_threshold)
1649            .with_iou_threshold(iou_threshold)
1650            .build()
1651            .unwrap();
1652
1653        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1654        decode_yolo_det(
1655            (out.slice(s![0, .., ..]), quant.into()),
1656            score_threshold,
1657            iou_threshold,
1658            Some(configs::Nms::ClassAgnostic),
1659            &mut output_boxes,
1660        );
1661        assert!(output_boxes[0].equal_within_delta(
1662            &DetectBox {
1663                bbox: BoundingBox {
1664                    xmin: 0.5285137,
1665                    ymin: 0.05305544,
1666                    xmax: 0.87541467,
1667                    ymax: 0.9998909,
1668                },
1669                score: 0.5591227,
1670                label: 0
1671            },
1672            1e-6
1673        ));
1674
1675        assert!(output_boxes[1].equal_within_delta(
1676            &DetectBox {
1677                bbox: BoundingBox {
1678                    xmin: 0.130598,
1679                    ymin: 0.43260583,
1680                    xmax: 0.35098213,
1681                    ymax: 0.9958097,
1682                },
1683                score: 0.33057618,
1684                label: 75
1685            },
1686            1e-6
1687        ));
1688
1689        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1690        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1691        decoder
1692            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1693            .unwrap();
1694
1695        let out = dequantize_ndarray(out.view(), quant.into());
1696        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1697        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1698        decoder
1699            .decode_float::<f32>(
1700                &[out.view().into_dyn()],
1701                &mut output_boxes_f32,
1702                &mut output_masks_f32,
1703            )
1704            .unwrap();
1705
1706        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1707        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1708    }
1709
1710    #[test]
1711    fn test_decoder_masks() {
1712        let score_threshold = 0.45;
1713        let iou_threshold = 0.45;
1714        let boxes = load_yolov8_boxes();
1715        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1716
1717        let protos = load_yolov8_protos();
1718        let quant_protos = Quantization::new(0.02491161972284317, -117);
1719        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1720        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1721        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1722        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1723        decode_yolo_segdet_float(
1724            seg.slice(s![0, .., ..]),
1725            protos.slice(s![0, .., .., ..]),
1726            score_threshold,
1727            iou_threshold,
1728            Some(configs::Nms::ClassAgnostic),
1729            &mut output_boxes,
1730            &mut output_masks,
1731        )
1732        .unwrap();
1733        assert_eq!(output_boxes.len(), 2);
1734        assert_eq!(output_boxes.len(), output_masks.len());
1735
1736        for (b, m) in output_boxes.iter().zip(&output_masks) {
1737            assert!(b.bbox.xmin >= m.xmin);
1738            assert!(b.bbox.ymin >= m.ymin);
1739            assert!(b.bbox.xmax >= m.xmax);
1740            assert!(b.bbox.ymax >= m.ymax);
1741        }
1742        assert!(output_boxes[0].equal_within_delta(
1743            &DetectBox {
1744                bbox: BoundingBox {
1745                    xmin: 0.08515105,
1746                    ymin: 0.7131401,
1747                    xmax: 0.29802868,
1748                    ymax: 0.8195788,
1749                },
1750                score: 0.91537374,
1751                label: 23
1752            },
1753            1.0 / 160.0, // wider range because mask will expand the box
1754        ));
1755
1756        assert!(output_boxes[1].equal_within_delta(
1757            &DetectBox {
1758                bbox: BoundingBox {
1759                    xmin: 0.59605736,
1760                    ymin: 0.25545314,
1761                    xmax: 0.93666154,
1762                    ymax: 0.72378385,
1763                },
1764                score: 0.91537374,
1765                label: 23
1766            },
1767            1.0 / 160.0, // wider range because mask will expand the box
1768        ));
1769
1770        let full_mask = include_bytes!(concat!(
1771            env!("CARGO_MANIFEST_DIR"),
1772            "/../../testdata/yolov8_mask_results.bin"
1773        ));
1774        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1775
1776        let cropped_mask = full_mask.slice(ndarray::s![
1777            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1778            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1779        ]);
1780
1781        assert_eq!(
1782            cropped_mask,
1783            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1784        );
1785    }
1786
1787    /// Regression test: config-driven path with NCHW protos (no dshape).
1788    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1789    /// and the YAML config has no dshape field — the exact scenario from
1790    /// hal_mask_matmul_bug.md.
1791    #[test]
1792    fn test_decoder_masks_nchw_protos() {
1793        let score_threshold = 0.45;
1794        let iou_threshold = 0.45;
1795
1796        // Load test data — boxes as [116, 8400]
1797        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1798        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1799
1800        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1801        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1802        let quant_protos = Quantization::new(0.02491161972284317, -117);
1803        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1804
1805        // ---- Reference: direct call with HWC protos (known working) ----
1806        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1807        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1808        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1809        decode_yolo_segdet_float(
1810            seg.view(),
1811            protos_f32_hwc.view(),
1812            score_threshold,
1813            iou_threshold,
1814            Some(configs::Nms::ClassAgnostic),
1815            &mut ref_boxes,
1816            &mut ref_masks,
1817        )
1818        .unwrap();
1819        assert_eq!(ref_boxes.len(), 2);
1820
1821        // ---- Config-driven path: NCHW protos, no dshape ----
1822        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1823        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1824        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1825
1826        // Build boxes as [1, 116, 8400] f32
1827        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1828
1829        // Build decoder from config with no dshape on protos
1830        let decoder = DecoderBuilder::default()
1831            .with_config_yolo_segdet(
1832                configs::Detection {
1833                    decoder: configs::DecoderType::Ultralytics,
1834                    quantization: None,
1835                    shape: vec![1, 116, 8400],
1836                    dshape: vec![],
1837                    normalized: Some(true),
1838                    anchors: None,
1839                },
1840                configs::Protos {
1841                    decoder: configs::DecoderType::Ultralytics,
1842                    quantization: None,
1843                    shape: vec![1, 32, 160, 160],
1844                    dshape: vec![], // No dshape — simulates YAML without dshape
1845                },
1846                None, // decoder version
1847            )
1848            .with_score_threshold(score_threshold)
1849            .with_iou_threshold(iou_threshold)
1850            .build()
1851            .unwrap();
1852
1853        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1854        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1855        decoder
1856            .decode_float(
1857                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1858                &mut cfg_boxes,
1859                &mut cfg_masks,
1860            )
1861            .unwrap();
1862
1863        // Must produce the same number of detections
1864        assert_eq!(
1865            cfg_boxes.len(),
1866            ref_boxes.len(),
1867            "config path produced {} boxes, reference produced {}",
1868            cfg_boxes.len(),
1869            ref_boxes.len()
1870        );
1871
1872        // Boxes must match
1873        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1874            assert!(
1875                cb.equal_within_delta(rb, 0.01),
1876                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1877            );
1878        }
1879
1880        // Masks must match pixel-for-pixel
1881        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1882            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1883            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1884            assert_eq!(
1885                cm_arr, rm_arr,
1886                "mask {i} pixel mismatch between config-driven and reference paths"
1887            );
1888        }
1889    }
1890
1891    #[test]
1892    fn test_decoder_masks_i8() {
1893        let score_threshold = 0.45;
1894        let iou_threshold = 0.45;
1895        let boxes = load_yolov8_boxes();
1896        let quant_boxes = (0.021287761628627777, 31).into();
1897
1898        let protos = load_yolov8_protos();
1899        let quant_protos = (0.02491161972284317, -117).into();
1900        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1901        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1902
1903        let decoder = DecoderBuilder::default()
1904            .with_config_yolo_segdet(
1905                configs::Detection {
1906                    decoder: configs::DecoderType::Ultralytics,
1907                    quantization: Some(quant_boxes),
1908                    shape: vec![1, 116, 8400],
1909                    anchors: None,
1910                    dshape: vec![
1911                        (DimName::Batch, 1),
1912                        (DimName::NumFeatures, 116),
1913                        (DimName::NumBoxes, 8400),
1914                    ],
1915                    normalized: Some(true),
1916                },
1917                Protos {
1918                    decoder: configs::DecoderType::Ultralytics,
1919                    quantization: Some(quant_protos),
1920                    shape: vec![1, 160, 160, 32],
1921                    dshape: vec![
1922                        (DimName::Batch, 1),
1923                        (DimName::Height, 160),
1924                        (DimName::Width, 160),
1925                        (DimName::NumProtos, 32),
1926                    ],
1927                },
1928                Some(DecoderVersion::Yolo11),
1929            )
1930            .with_score_threshold(score_threshold)
1931            .with_iou_threshold(iou_threshold)
1932            .build()
1933            .unwrap();
1934
1935        let quant_boxes = quant_boxes.into();
1936        let quant_protos = quant_protos.into();
1937
1938        decode_yolo_segdet_quant(
1939            (boxes.slice(s![0, .., ..]), quant_boxes),
1940            (protos.slice(s![0, .., .., ..]), quant_protos),
1941            score_threshold,
1942            iou_threshold,
1943            Some(configs::Nms::ClassAgnostic),
1944            &mut output_boxes,
1945            &mut output_masks,
1946        )
1947        .unwrap();
1948
1949        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1950        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1951
1952        decoder
1953            .decode_quantized(
1954                &[boxes.view().into(), protos.view().into()],
1955                &mut output_boxes1,
1956                &mut output_masks1,
1957            )
1958            .unwrap();
1959
1960        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1961        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1962
1963        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1964        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1965        decode_yolo_segdet_float(
1966            seg.slice(s![0, .., ..]),
1967            protos.slice(s![0, .., .., ..]),
1968            score_threshold,
1969            iou_threshold,
1970            Some(configs::Nms::ClassAgnostic),
1971            &mut output_boxes_f32,
1972            &mut output_masks_f32,
1973        )
1974        .unwrap();
1975
1976        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1977        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1978
1979        decoder
1980            .decode_float(
1981                &[seg.view().into_dyn(), protos.view().into_dyn()],
1982                &mut output_boxes1_f32,
1983                &mut output_masks1_f32,
1984            )
1985            .unwrap();
1986
1987        compare_outputs(
1988            (&output_boxes, &output_boxes1),
1989            (&output_masks, &output_masks1),
1990        );
1991
1992        compare_outputs(
1993            (&output_boxes, &output_boxes_f32),
1994            (&output_masks, &output_masks_f32),
1995        );
1996
1997        compare_outputs(
1998            (&output_boxes_f32, &output_boxes1_f32),
1999            (&output_masks_f32, &output_masks1_f32),
2000        );
2001    }
2002
2003    #[test]
2004    fn test_decoder_yolo_split() {
2005        let score_threshold = 0.45;
2006        let iou_threshold = 0.45;
2007        let boxes = load_yolov8_boxes();
2008        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2009        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2010
2011        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2012
2013        let decoder = DecoderBuilder::default()
2014            .with_config_yolo_split_det(
2015                configs::Boxes {
2016                    decoder: configs::DecoderType::Ultralytics,
2017                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2018                    shape: vec![1, 4, 8400],
2019                    dshape: vec![
2020                        (DimName::Batch, 1),
2021                        (DimName::BoxCoords, 4),
2022                        (DimName::NumBoxes, 8400),
2023                    ],
2024                    normalized: Some(true),
2025                },
2026                configs::Scores {
2027                    decoder: configs::DecoderType::Ultralytics,
2028                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2029                    shape: vec![1, 80, 8400],
2030                    dshape: vec![
2031                        (DimName::Batch, 1),
2032                        (DimName::NumClasses, 80),
2033                        (DimName::NumBoxes, 8400),
2034                    ],
2035                },
2036            )
2037            .with_score_threshold(score_threshold)
2038            .with_iou_threshold(iou_threshold)
2039            .build()
2040            .unwrap();
2041
2042        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2043        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2044
2045        decoder
2046            .decode_quantized(
2047                &[
2048                    boxes.slice(s![.., ..4, ..]).into(),
2049                    boxes.slice(s![.., 4..84, ..]).into(),
2050                ],
2051                &mut output_boxes,
2052                &mut output_masks,
2053            )
2054            .unwrap();
2055
2056        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2057        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2058        decode_yolo_det_float(
2059            seg.slice(s![0, ..84, ..]),
2060            score_threshold,
2061            iou_threshold,
2062            Some(configs::Nms::ClassAgnostic),
2063            &mut output_boxes_f32,
2064        );
2065
2066        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2067        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2068
2069        decoder
2070            .decode_float(
2071                &[
2072                    seg.slice(s![.., ..4, ..]).into_dyn(),
2073                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2074                ],
2075                &mut output_boxes1,
2076                &mut output_masks1,
2077            )
2078            .unwrap();
2079        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2080        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2081    }
2082
2083    #[test]
2084    fn test_decoder_masks_config_mixed() {
2085        let score_threshold = 0.45;
2086        let iou_threshold = 0.45;
2087        let boxes_raw = load_yolov8_boxes();
2088        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2089        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2090
2091        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2092
2093        let protos = load_yolov8_protos();
2094        let quant_protos = (0.02491161972284317, -117);
2095
2096        let decoder = build_yolo_split_segdet_decoder(
2097            score_threshold,
2098            iou_threshold,
2099            quant_boxes,
2100            quant_protos,
2101        );
2102        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2103        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2104
2105        decoder
2106            .decode_quantized(
2107                &[
2108                    boxes.slice(s![.., ..4, ..]).into(),
2109                    boxes.slice(s![.., 4..84, ..]).into(),
2110                    boxes.slice(s![.., 84.., ..]).into(),
2111                    protos.view().into(),
2112                ],
2113                &mut output_boxes,
2114                &mut output_masks,
2115            )
2116            .unwrap();
2117
2118        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2119        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2120        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2121        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2122        decode_yolo_segdet_float(
2123            seg.slice(s![0, .., ..]),
2124            protos.slice(s![0, .., .., ..]),
2125            score_threshold,
2126            iou_threshold,
2127            Some(configs::Nms::ClassAgnostic),
2128            &mut output_boxes_f32,
2129            &mut output_masks_f32,
2130        )
2131        .unwrap();
2132
2133        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2134        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2135
2136        decoder
2137            .decode_float(
2138                &[
2139                    seg.slice(s![.., ..4, ..]).into_dyn(),
2140                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2141                    seg.slice(s![.., 84.., ..]).into_dyn(),
2142                    protos.view().into_dyn(),
2143                ],
2144                &mut output_boxes1,
2145                &mut output_masks1,
2146            )
2147            .unwrap();
2148        compare_outputs(
2149            (&output_boxes, &output_boxes_f32),
2150            (&output_masks, &output_masks_f32),
2151        );
2152        compare_outputs(
2153            (&output_boxes_f32, &output_boxes1),
2154            (&output_masks_f32, &output_masks1),
2155        );
2156    }
2157
2158    fn build_yolo_split_segdet_decoder(
2159        score_threshold: f32,
2160        iou_threshold: f32,
2161        quant_boxes: (f32, i32),
2162        quant_protos: (f32, i32),
2163    ) -> crate::Decoder {
2164        DecoderBuilder::default()
2165            .with_config_yolo_split_segdet(
2166                configs::Boxes {
2167                    decoder: configs::DecoderType::Ultralytics,
2168                    quantization: Some(quant_boxes.into()),
2169                    shape: vec![1, 4, 8400],
2170                    dshape: vec![
2171                        (DimName::Batch, 1),
2172                        (DimName::BoxCoords, 4),
2173                        (DimName::NumBoxes, 8400),
2174                    ],
2175                    normalized: Some(true),
2176                },
2177                configs::Scores {
2178                    decoder: configs::DecoderType::Ultralytics,
2179                    quantization: Some(quant_boxes.into()),
2180                    shape: vec![1, 80, 8400],
2181                    dshape: vec![
2182                        (DimName::Batch, 1),
2183                        (DimName::NumClasses, 80),
2184                        (DimName::NumBoxes, 8400),
2185                    ],
2186                },
2187                configs::MaskCoefficients {
2188                    decoder: configs::DecoderType::Ultralytics,
2189                    quantization: Some(quant_boxes.into()),
2190                    shape: vec![1, 32, 8400],
2191                    dshape: vec![
2192                        (DimName::Batch, 1),
2193                        (DimName::NumProtos, 32),
2194                        (DimName::NumBoxes, 8400),
2195                    ],
2196                },
2197                configs::Protos {
2198                    decoder: configs::DecoderType::Ultralytics,
2199                    quantization: Some(quant_protos.into()),
2200                    shape: vec![1, 160, 160, 32],
2201                    dshape: vec![
2202                        (DimName::Batch, 1),
2203                        (DimName::Height, 160),
2204                        (DimName::Width, 160),
2205                        (DimName::NumProtos, 32),
2206                    ],
2207                },
2208            )
2209            .with_score_threshold(score_threshold)
2210            .with_iou_threshold(iou_threshold)
2211            .build()
2212            .unwrap()
2213    }
2214
2215    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2216        let config_yaml = include_str!(concat!(
2217            env!("CARGO_MANIFEST_DIR"),
2218            "/../../testdata/yolov8_seg.yaml"
2219        ));
2220        DecoderBuilder::default()
2221            .with_config_yaml_str(config_yaml.to_string())
2222            .with_score_threshold(score_threshold)
2223            .with_iou_threshold(iou_threshold)
2224            .build()
2225            .unwrap()
2226    }
2227    #[test]
2228    fn test_decoder_masks_config_i32() {
2229        let score_threshold = 0.45;
2230        let iou_threshold = 0.45;
2231        let boxes_raw = load_yolov8_boxes();
2232        let scale = 1 << 23;
2233        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2234        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2235
2236        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2237
2238        let protos_raw = load_yolov8_protos();
2239        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2240        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2241        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2242
2243        let decoder = build_yolo_split_segdet_decoder(
2244            score_threshold,
2245            iou_threshold,
2246            quant_boxes,
2247            quant_protos,
2248        );
2249
2250        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2251        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2252
2253        decoder
2254            .decode_quantized(
2255                &[
2256                    boxes.slice(s![.., ..4, ..]).into(),
2257                    boxes.slice(s![.., 4..84, ..]).into(),
2258                    boxes.slice(s![.., 84.., ..]).into(),
2259                    protos.view().into(),
2260                ],
2261                &mut output_boxes,
2262                &mut output_masks,
2263            )
2264            .unwrap();
2265
2266        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2267        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2268        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2269        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2270        decode_yolo_segdet_float(
2271            seg.slice(s![0, .., ..]),
2272            protos.slice(s![0, .., .., ..]),
2273            score_threshold,
2274            iou_threshold,
2275            Some(configs::Nms::ClassAgnostic),
2276            &mut output_boxes_f32,
2277            &mut output_masks_f32,
2278        )
2279        .unwrap();
2280
2281        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2282        assert_eq!(output_masks.len(), output_masks_f32.len());
2283
2284        compare_outputs(
2285            (&output_boxes, &output_boxes_f32),
2286            (&output_masks, &output_masks_f32),
2287        );
2288    }
2289
2290    /// test running multiple decoders concurrently
2291    #[test]
2292    fn test_context_switch() {
2293        let yolo_det = || {
2294            let score_threshold = 0.25;
2295            let iou_threshold = 0.7;
2296            let out = load_yolov8s_det();
2297            let quant = (0.0040811873, -123).into();
2298
2299            let decoder = DecoderBuilder::default()
2300                .with_config_yolo_det(
2301                    configs::Detection {
2302                        decoder: DecoderType::Ultralytics,
2303                        shape: vec![1, 84, 8400],
2304                        anchors: None,
2305                        quantization: Some(quant),
2306                        dshape: vec![
2307                            (DimName::Batch, 1),
2308                            (DimName::NumFeatures, 84),
2309                            (DimName::NumBoxes, 8400),
2310                        ],
2311                        normalized: None,
2312                    },
2313                    None,
2314                )
2315                .with_score_threshold(score_threshold)
2316                .with_iou_threshold(iou_threshold)
2317                .build()
2318                .unwrap();
2319
2320            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2321            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2322
2323            for _ in 0..100 {
2324                decoder
2325                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2326                    .unwrap();
2327
2328                assert!(output_boxes[0].equal_within_delta(
2329                    &DetectBox {
2330                        bbox: BoundingBox {
2331                            xmin: 0.5285137,
2332                            ymin: 0.05305544,
2333                            xmax: 0.87541467,
2334                            ymax: 0.9998909,
2335                        },
2336                        score: 0.5591227,
2337                        label: 0
2338                    },
2339                    1e-6
2340                ));
2341
2342                assert!(output_boxes[1].equal_within_delta(
2343                    &DetectBox {
2344                        bbox: BoundingBox {
2345                            xmin: 0.130598,
2346                            ymin: 0.43260583,
2347                            xmax: 0.35098213,
2348                            ymax: 0.9958097,
2349                        },
2350                        score: 0.33057618,
2351                        label: 75
2352                    },
2353                    1e-6
2354                ));
2355                assert!(output_masks.is_empty());
2356            }
2357        };
2358
2359        let modelpack_det_split = || {
2360            let score_threshold = 0.8;
2361            let iou_threshold = 0.5;
2362
2363            let seg = include_bytes!(concat!(
2364                env!("CARGO_MANIFEST_DIR"),
2365                "/../../testdata/modelpack_seg_2x160x160.bin"
2366            ));
2367            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2368
2369            let detect0 = include_bytes!(concat!(
2370                env!("CARGO_MANIFEST_DIR"),
2371                "/../../testdata/modelpack_split_9x15x18.bin"
2372            ));
2373            let detect0 =
2374                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2375
2376            let detect1 = include_bytes!(concat!(
2377                env!("CARGO_MANIFEST_DIR"),
2378                "/../../testdata/modelpack_split_17x30x18.bin"
2379            ));
2380            let detect1 =
2381                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2382
2383            let mut mask = seg.slice(s![0, .., .., ..]);
2384            mask.swap_axes(0, 1);
2385            mask.swap_axes(1, 2);
2386            let mask = [Segmentation {
2387                xmin: 0.0,
2388                ymin: 0.0,
2389                xmax: 1.0,
2390                ymax: 1.0,
2391                segmentation: mask.into_owned(),
2392            }];
2393            let correct_boxes = [DetectBox {
2394                bbox: BoundingBox {
2395                    xmin: 0.43171933,
2396                    ymin: 0.68243736,
2397                    xmax: 0.5626645,
2398                    ymax: 0.808863,
2399                },
2400                score: 0.99240804,
2401                label: 0,
2402            }];
2403
2404            let quant0 = (0.08547406643629074, 174).into();
2405            let quant1 = (0.09929127991199493, 183).into();
2406            let quant_seg = (1.0 / 255.0, 0).into();
2407
2408            let anchors0 = vec![
2409                [0.36666667461395264, 0.31481480598449707],
2410                [0.38749998807907104, 0.4740740656852722],
2411                [0.5333333611488342, 0.644444465637207],
2412            ];
2413            let anchors1 = vec![
2414                [0.13750000298023224, 0.2074074000120163],
2415                [0.2541666626930237, 0.21481481194496155],
2416                [0.23125000298023224, 0.35185185074806213],
2417            ];
2418
2419            let decoder = DecoderBuilder::default()
2420                .with_config_modelpack_segdet_split(
2421                    vec![
2422                        configs::Detection {
2423                            decoder: DecoderType::ModelPack,
2424                            shape: vec![1, 17, 30, 18],
2425                            anchors: Some(anchors1),
2426                            quantization: Some(quant1),
2427                            dshape: vec![
2428                                (DimName::Batch, 1),
2429                                (DimName::Height, 17),
2430                                (DimName::Width, 30),
2431                                (DimName::NumAnchorsXFeatures, 18),
2432                            ],
2433                            normalized: None,
2434                        },
2435                        configs::Detection {
2436                            decoder: DecoderType::ModelPack,
2437                            shape: vec![1, 9, 15, 18],
2438                            anchors: Some(anchors0),
2439                            quantization: Some(quant0),
2440                            dshape: vec![
2441                                (DimName::Batch, 1),
2442                                (DimName::Height, 9),
2443                                (DimName::Width, 15),
2444                                (DimName::NumAnchorsXFeatures, 18),
2445                            ],
2446                            normalized: None,
2447                        },
2448                    ],
2449                    configs::Segmentation {
2450                        decoder: DecoderType::ModelPack,
2451                        quantization: Some(quant_seg),
2452                        shape: vec![1, 2, 160, 160],
2453                        dshape: vec![
2454                            (DimName::Batch, 1),
2455                            (DimName::NumClasses, 2),
2456                            (DimName::Height, 160),
2457                            (DimName::Width, 160),
2458                        ],
2459                    },
2460                )
2461                .with_score_threshold(score_threshold)
2462                .with_iou_threshold(iou_threshold)
2463                .build()
2464                .unwrap();
2465            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2466            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2467
2468            for _ in 0..100 {
2469                decoder
2470                    .decode_quantized(
2471                        &[
2472                            detect0.view().into(),
2473                            detect1.view().into(),
2474                            seg.view().into(),
2475                        ],
2476                        &mut output_boxes,
2477                        &mut output_masks,
2478                    )
2479                    .unwrap();
2480
2481                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2482            }
2483        };
2484
2485        let handles = vec![
2486            std::thread::spawn(yolo_det),
2487            std::thread::spawn(modelpack_det_split),
2488            std::thread::spawn(yolo_det),
2489            std::thread::spawn(modelpack_det_split),
2490            std::thread::spawn(yolo_det),
2491            std::thread::spawn(modelpack_det_split),
2492            std::thread::spawn(yolo_det),
2493            std::thread::spawn(modelpack_det_split),
2494        ];
2495        for handle in handles {
2496            handle.join().unwrap();
2497        }
2498    }
2499
2500    #[test]
2501    fn test_ndarray_to_xyxy_float() {
2502        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2503        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2504        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2505
2506        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2507        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2508        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2509    }
2510
2511    #[test]
2512    fn test_class_aware_nms_float() {
2513        use crate::float::nms_class_aware_float;
2514
2515        // Create two overlapping boxes with different classes
2516        let boxes = vec![
2517            DetectBox {
2518                bbox: BoundingBox {
2519                    xmin: 0.0,
2520                    ymin: 0.0,
2521                    xmax: 0.5,
2522                    ymax: 0.5,
2523                },
2524                score: 0.9,
2525                label: 0, // class 0
2526            },
2527            DetectBox {
2528                bbox: BoundingBox {
2529                    xmin: 0.1,
2530                    ymin: 0.1,
2531                    xmax: 0.6,
2532                    ymax: 0.6,
2533                },
2534                score: 0.8,
2535                label: 1, // class 1 - different class
2536            },
2537        ];
2538
2539        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2540        // threshold 0.3)
2541        let result = nms_class_aware_float(0.3, boxes.clone());
2542        assert_eq!(
2543            result.len(),
2544            2,
2545            "Class-aware NMS should keep both boxes with different classes"
2546        );
2547
2548        // Now test with same class - should suppress one
2549        let same_class_boxes = vec![
2550            DetectBox {
2551                bbox: BoundingBox {
2552                    xmin: 0.0,
2553                    ymin: 0.0,
2554                    xmax: 0.5,
2555                    ymax: 0.5,
2556                },
2557                score: 0.9,
2558                label: 0,
2559            },
2560            DetectBox {
2561                bbox: BoundingBox {
2562                    xmin: 0.1,
2563                    ymin: 0.1,
2564                    xmax: 0.6,
2565                    ymax: 0.6,
2566                },
2567                score: 0.8,
2568                label: 0, // same class
2569            },
2570        ];
2571
2572        let result = nms_class_aware_float(0.3, same_class_boxes);
2573        assert_eq!(
2574            result.len(),
2575            1,
2576            "Class-aware NMS should suppress overlapping box with same class"
2577        );
2578        assert_eq!(result[0].label, 0);
2579        assert!((result[0].score - 0.9).abs() < 1e-6);
2580    }
2581
2582    #[test]
2583    fn test_class_agnostic_vs_aware_nms() {
2584        use crate::float::{nms_class_aware_float, nms_float};
2585
2586        // Two overlapping boxes with different classes
2587        let boxes = vec![
2588            DetectBox {
2589                bbox: BoundingBox {
2590                    xmin: 0.0,
2591                    ymin: 0.0,
2592                    xmax: 0.5,
2593                    ymax: 0.5,
2594                },
2595                score: 0.9,
2596                label: 0,
2597            },
2598            DetectBox {
2599                bbox: BoundingBox {
2600                    xmin: 0.1,
2601                    ymin: 0.1,
2602                    xmax: 0.6,
2603                    ymax: 0.6,
2604                },
2605                score: 0.8,
2606                label: 1,
2607            },
2608        ];
2609
2610        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2611        let agnostic_result = nms_float(0.3, boxes.clone());
2612        assert_eq!(
2613            agnostic_result.len(),
2614            1,
2615            "Class-agnostic NMS should suppress overlapping boxes"
2616        );
2617
2618        // Class-aware should keep both (different classes)
2619        let aware_result = nms_class_aware_float(0.3, boxes);
2620        assert_eq!(
2621            aware_result.len(),
2622            2,
2623            "Class-aware NMS should keep boxes with different classes"
2624        );
2625    }
2626
2627    #[test]
2628    fn test_class_aware_nms_int() {
2629        use crate::byte::nms_class_aware_int;
2630
2631        // Create two overlapping boxes with different classes
2632        let boxes = vec![
2633            DetectBoxQuantized {
2634                bbox: BoundingBox {
2635                    xmin: 0.0,
2636                    ymin: 0.0,
2637                    xmax: 0.5,
2638                    ymax: 0.5,
2639                },
2640                score: 200_u8,
2641                label: 0,
2642            },
2643            DetectBoxQuantized {
2644                bbox: BoundingBox {
2645                    xmin: 0.1,
2646                    ymin: 0.1,
2647                    xmax: 0.6,
2648                    ymax: 0.6,
2649                },
2650                score: 180_u8,
2651                label: 1, // different class
2652            },
2653        ];
2654
2655        // Should keep both (different classes)
2656        let result = nms_class_aware_int(0.5, boxes);
2657        assert_eq!(
2658            result.len(),
2659            2,
2660            "Class-aware NMS (int) should keep boxes with different classes"
2661        );
2662    }
2663
2664    #[test]
2665    fn test_nms_enum_default() {
2666        // Test that Nms enum has the correct default
2667        let default_nms: configs::Nms = Default::default();
2668        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2669    }
2670
2671    #[test]
2672    fn test_decoder_nms_mode() {
2673        // Test that decoder properly stores NMS mode
2674        let decoder = DecoderBuilder::default()
2675            .with_config_yolo_det(
2676                configs::Detection {
2677                    anchors: None,
2678                    decoder: DecoderType::Ultralytics,
2679                    quantization: None,
2680                    shape: vec![1, 84, 8400],
2681                    dshape: Vec::new(),
2682                    normalized: Some(true),
2683                },
2684                None,
2685            )
2686            .with_nms(Some(configs::Nms::ClassAware))
2687            .build()
2688            .unwrap();
2689
2690        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2691    }
2692
2693    #[test]
2694    fn test_decoder_nms_bypass() {
2695        // Test that decoder can be configured with nms=None (bypass)
2696        let decoder = DecoderBuilder::default()
2697            .with_config_yolo_det(
2698                configs::Detection {
2699                    anchors: None,
2700                    decoder: DecoderType::Ultralytics,
2701                    quantization: None,
2702                    shape: vec![1, 84, 8400],
2703                    dshape: Vec::new(),
2704                    normalized: Some(true),
2705                },
2706                None,
2707            )
2708            .with_nms(None)
2709            .build()
2710            .unwrap();
2711
2712        assert_eq!(decoder.nms, None);
2713    }
2714
2715    #[test]
2716    fn test_decoder_normalized_boxes_true() {
2717        // Test that normalized_boxes returns Some(true) when explicitly set
2718        let decoder = DecoderBuilder::default()
2719            .with_config_yolo_det(
2720                configs::Detection {
2721                    anchors: None,
2722                    decoder: DecoderType::Ultralytics,
2723                    quantization: None,
2724                    shape: vec![1, 84, 8400],
2725                    dshape: Vec::new(),
2726                    normalized: Some(true),
2727                },
2728                None,
2729            )
2730            .build()
2731            .unwrap();
2732
2733        assert_eq!(decoder.normalized_boxes(), Some(true));
2734    }
2735
2736    #[test]
2737    fn test_decoder_normalized_boxes_false() {
2738        // Test that normalized_boxes returns Some(false) when config specifies
2739        // unnormalized
2740        let decoder = DecoderBuilder::default()
2741            .with_config_yolo_det(
2742                configs::Detection {
2743                    anchors: None,
2744                    decoder: DecoderType::Ultralytics,
2745                    quantization: None,
2746                    shape: vec![1, 84, 8400],
2747                    dshape: Vec::new(),
2748                    normalized: Some(false),
2749                },
2750                None,
2751            )
2752            .build()
2753            .unwrap();
2754
2755        assert_eq!(decoder.normalized_boxes(), Some(false));
2756    }
2757
2758    #[test]
2759    fn test_decoder_normalized_boxes_unknown() {
2760        // Test that normalized_boxes returns None when not specified in config
2761        let decoder = DecoderBuilder::default()
2762            .with_config_yolo_det(
2763                configs::Detection {
2764                    anchors: None,
2765                    decoder: DecoderType::Ultralytics,
2766                    quantization: None,
2767                    shape: vec![1, 84, 8400],
2768                    dshape: Vec::new(),
2769                    normalized: None,
2770                },
2771                Some(DecoderVersion::Yolo11),
2772            )
2773            .build()
2774            .unwrap();
2775
2776        assert_eq!(decoder.normalized_boxes(), None);
2777    }
2778
2779    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2780        input: ArrayView<F, D>,
2781        quant: Quantization,
2782    ) -> Array<T, D>
2783    where
2784        i32: num_traits::AsPrimitive<F>,
2785        f32: num_traits::AsPrimitive<F>,
2786    {
2787        let zero_point = quant.zero_point.as_();
2788        let div_scale = F::one() / quant.scale.as_();
2789        if zero_point != F::zero() {
2790            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2791        } else {
2792            input.mapv(|d| (d * div_scale).round().as_())
2793        }
2794    }
2795
2796    fn real_data_expected_boxes() -> [DetectBox; 2] {
2797        [
2798            DetectBox {
2799                bbox: BoundingBox {
2800                    xmin: 0.08515105,
2801                    ymin: 0.7131401,
2802                    xmax: 0.29802868,
2803                    ymax: 0.8195788,
2804                },
2805                score: 0.91537374,
2806                label: 23,
2807            },
2808            DetectBox {
2809                bbox: BoundingBox {
2810                    xmin: 0.59605736,
2811                    ymin: 0.25545314,
2812                    xmax: 0.93666154,
2813                    ymax: 0.72378385,
2814                },
2815                score: 0.91537374,
2816                label: 23,
2817            },
2818        ]
2819    }
2820
2821    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2822        [DetectBox {
2823            bbox: BoundingBox {
2824                xmin: 0.12549022,
2825                ymin: 0.12549022,
2826                xmax: 0.23529413,
2827                ymax: 0.23529413,
2828            },
2829            score: 0.98823535,
2830            label: 2,
2831        }]
2832    }
2833
2834    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2835        [DetectBox {
2836            bbox: BoundingBox {
2837                xmin: 0.1234,
2838                ymin: 0.1234,
2839                xmax: 0.2345,
2840                ymax: 0.2345,
2841            },
2842            score: 0.9876,
2843            label: 2,
2844        }]
2845    }
2846
2847    macro_rules! real_data_proto_test {
2848        ($name:ident, quantized, $layout:ident) => {
2849            #[test]
2850            fn $name() {
2851                let is_split = matches!(stringify!($layout), "split");
2852
2853                let score_threshold = 0.45;
2854                let iou_threshold = 0.45;
2855                let quant_boxes = (0.021287762_f32, 31_i32);
2856                let quant_protos = (0.02491162_f32, -117_i32);
2857
2858                let raw_boxes = include_bytes!(concat!(
2859                    env!("CARGO_MANIFEST_DIR"),
2860                    "/../../testdata/yolov8_boxes_116x8400.bin"
2861                ));
2862                let raw_boxes = unsafe {
2863                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2864                };
2865                let boxes_i8 =
2866                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2867
2868                let raw_protos = include_bytes!(concat!(
2869                    env!("CARGO_MANIFEST_DIR"),
2870                    "/../../testdata/yolov8_protos_160x160x32.bin"
2871                ));
2872                let raw_protos = unsafe {
2873                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2874                };
2875                let protos_i8 =
2876                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2877                        .unwrap();
2878
2879                // Pre-split (unused for combined, but harmless)
2880                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2881                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2882                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2883                let boxes_combined = boxes_i8;
2884
2885                let decoder = if is_split {
2886                    build_yolo_split_segdet_decoder(
2887                        score_threshold,
2888                        iou_threshold,
2889                        quant_boxes,
2890                        quant_protos,
2891                    )
2892                } else {
2893                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2894                };
2895
2896                let expected = real_data_expected_boxes();
2897                let mut output_boxes = Vec::with_capacity(50);
2898
2899                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
2900                    vec![
2901                        boxes_split.view().into(),
2902                        scores_split.view().into(),
2903                        mask_split.view().into(),
2904                        protos_i8.view().into(),
2905                    ]
2906                } else {
2907                    vec![boxes_combined.view().into(), protos_i8.view().into()]
2908                };
2909                decoder
2910                    .decode_quantized_proto(&inputs, &mut output_boxes)
2911                    .unwrap();
2912
2913                assert_eq!(output_boxes.len(), 2);
2914                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2915                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2916            }
2917        };
2918        ($name:ident, float, $layout:ident) => {
2919            #[test]
2920            fn $name() {
2921                let is_split = matches!(stringify!($layout), "split");
2922
2923                let score_threshold = 0.45;
2924                let iou_threshold = 0.45;
2925                let quant_boxes = (0.021287762_f32, 31_i32);
2926                let quant_protos = (0.02491162_f32, -117_i32);
2927
2928                let raw_boxes = include_bytes!(concat!(
2929                    env!("CARGO_MANIFEST_DIR"),
2930                    "/../../testdata/yolov8_boxes_116x8400.bin"
2931                ));
2932                let raw_boxes = unsafe {
2933                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2934                };
2935                let boxes_i8 =
2936                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2937                let boxes_f32: Array3<f32> =
2938                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
2939
2940                let raw_protos = include_bytes!(concat!(
2941                    env!("CARGO_MANIFEST_DIR"),
2942                    "/../../testdata/yolov8_protos_160x160x32.bin"
2943                ));
2944                let raw_protos = unsafe {
2945                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2946                };
2947                let protos_i8 =
2948                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2949                        .unwrap();
2950                let protos_f32: Array4<f32> =
2951                    dequantize_ndarray(protos_i8.view(), quant_protos.into());
2952
2953                // Pre-split from dequantized data
2954                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
2955                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
2956                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
2957                let boxes_combined = boxes_f32;
2958
2959                let decoder = if is_split {
2960                    build_yolo_split_segdet_decoder(
2961                        score_threshold,
2962                        iou_threshold,
2963                        quant_boxes,
2964                        quant_protos,
2965                    )
2966                } else {
2967                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2968                };
2969
2970                let expected = real_data_expected_boxes();
2971                let mut output_boxes = Vec::with_capacity(50);
2972
2973                let inputs = if is_split {
2974                    vec![
2975                        boxes_split.view().into_dyn(),
2976                        scores_split.view().into_dyn(),
2977                        mask_split.view().into_dyn(),
2978                        protos_f32.view().into_dyn(),
2979                    ]
2980                } else {
2981                    vec![
2982                        boxes_combined.view().into_dyn(),
2983                        protos_f32.view().into_dyn(),
2984                    ]
2985                };
2986                decoder
2987                    .decode_float_proto(&inputs, &mut output_boxes)
2988                    .unwrap();
2989
2990                assert_eq!(output_boxes.len(), 2);
2991                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2992                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2993            }
2994        };
2995    }
2996
2997    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
2998    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
2999    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
3000    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3001
    /// YOLO26 end-to-end detection config with a single combined output.
    ///
    /// - The `detection` tensor is `[1, 10, 6]` (batch, num_boxes, num_features).
    /// - It is quantized with scale 2/255 (0.00784313725490196) and zero point 0.
    /// - It is marked as producing normalized boxes.
    // NOTE(review): the 6 features presumably pack 4 box coords + score + class
    // (the 38-feature segdet layout minus 32 mask coefficients) — confirm.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3015
    /// YOLO26 end-to-end segmentation+detection config with combined detections.
    ///
    /// - `detection` is `[1, 10, 38]`, quantized with scale 2/255 and zero
    ///   point 0, with normalized boxes. The 38 features match the e2e segdet
    ///   test layout: 4 box coords, 1 score, 1 class and 32 mask coefficients
    ///   concatenated on the last axis.
    /// - `protos` is `[1, 160, 160, 32]` (NHWC), quantized with scale 1/255
    ///   and zero point 128.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3038
    /// YOLO26 end-to-end detection config with split outputs.
    ///
    /// - `boxes` is `[1, 10, 4]`, with normalized coordinates.
    /// - `scores` is `[1, 10, 1]`.
    /// - `classes` is `[1, 10, 1]`.
    ///
    /// All three are quantized with scale 2/255 (0.00784313725490196) and
    /// zero point 0.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3068
    /// YOLO26 end-to-end segmentation+detection config with fully split outputs.
    ///
    /// These tensors are quantized with scale 2/255 and zero point 0:
    /// - `boxes` `[1, 10, 4]`, with normalized coordinates.
    /// - `scores` `[1, 10, 1]`.
    /// - `classes` `[1, 10, 1]`.
    /// - `mask_coefficients` `[1, 10, 32]`.
    ///
    /// `protos` is `[1, 160, 160, 32]`, quantized with scale 1/255 and zero
    /// point 128.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3115
3116    macro_rules! e2e_segdet_test {
3117        ($name:ident, quantized, $layout:ident, $output:ident) => {
3118            #[test]
3119            fn $name() {
3120                let is_split = matches!(stringify!($layout), "split");
3121                let is_proto = matches!(stringify!($output), "proto");
3122
3123                let score_threshold = 0.45;
3124                let iou_threshold = 0.45;
3125
3126                let mut boxes = Array2::zeros((10, 4));
3127                let mut scores = Array2::zeros((10, 1));
3128                let mut classes = Array2::zeros((10, 1));
3129                let mask = Array2::zeros((10, 32));
3130                let protos = Array3::<f64>::zeros((160, 160, 32));
3131                let protos = protos.insert_axis(Axis(0));
3132                let protos_quant = (1.0 / 255.0, 0.0);
3133                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3134
3135                boxes
3136                    .slice_mut(s![0, ..])
3137                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3138                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3139                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3140
3141                let detect_quant = (2.0 / 255.0, 0.0);
3142
3143                let decoder = if is_split {
3144                    DecoderBuilder::default()
3145                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3146                        .with_score_threshold(score_threshold)
3147                        .with_iou_threshold(iou_threshold)
3148                        .build()
3149                        .unwrap()
3150                } else {
3151                    DecoderBuilder::default()
3152                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3153                        .with_score_threshold(score_threshold)
3154                        .with_iou_threshold(iou_threshold)
3155                        .build()
3156                        .unwrap()
3157                };
3158
3159                let expected = e2e_expected_boxes_quant();
3160                let mut output_boxes = Vec::with_capacity(50);
3161
3162                if is_split {
3163                    let boxes = boxes.insert_axis(Axis(0));
3164                    let scores = scores.insert_axis(Axis(0));
3165                    let classes = classes.insert_axis(Axis(0));
3166                    let mask = mask.insert_axis(Axis(0));
3167
3168                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3169                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3170                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3171                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3172
3173                    if is_proto {
3174                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3175                            boxes.view().into(),
3176                            scores.view().into(),
3177                            classes.view().into(),
3178                            mask.view().into(),
3179                            protos.view().into(),
3180                        ];
3181                        decoder
3182                            .decode_quantized_proto(&inputs, &mut output_boxes)
3183                            .unwrap();
3184
3185                        assert_eq!(output_boxes.len(), 1);
3186                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3187                    } else {
3188                        let mut output_masks = Vec::with_capacity(50);
3189                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3190                            boxes.view().into(),
3191                            scores.view().into(),
3192                            classes.view().into(),
3193                            mask.view().into(),
3194                            protos.view().into(),
3195                        ];
3196                        decoder
3197                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3198                            .unwrap();
3199
3200                        assert_eq!(output_boxes.len(), 1);
3201                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3202                    }
3203                } else {
3204                    // Combined layout
3205                    let detect = ndarray::concatenate![
3206                        Axis(1),
3207                        boxes.view(),
3208                        scores.view(),
3209                        classes.view(),
3210                        mask.view()
3211                    ];
3212                    let detect = detect.insert_axis(Axis(0));
3213                    assert_eq!(detect.shape(), &[1, 10, 38]);
3214                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3215
3216                    if is_proto {
3217                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3218                            vec![detect.view().into(), protos.view().into()];
3219                        decoder
3220                            .decode_quantized_proto(&inputs, &mut output_boxes)
3221                            .unwrap();
3222
3223                        assert_eq!(output_boxes.len(), 1);
3224                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3225                    } else {
3226                        let mut output_masks = Vec::with_capacity(50);
3227                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3228                            vec![detect.view().into(), protos.view().into()];
3229                        decoder
3230                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3231                            .unwrap();
3232
3233                        assert_eq!(output_boxes.len(), 1);
3234                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3235                    }
3236                }
3237            }
3238        };
3239        ($name:ident, float, $layout:ident, $output:ident) => {
3240            #[test]
3241            fn $name() {
3242                let is_split = matches!(stringify!($layout), "split");
3243                let is_proto = matches!(stringify!($output), "proto");
3244
3245                let score_threshold = 0.45;
3246                let iou_threshold = 0.45;
3247
3248                let mut boxes = Array2::zeros((10, 4));
3249                let mut scores = Array2::zeros((10, 1));
3250                let mut classes = Array2::zeros((10, 1));
3251                let mask: Array2<f64> = Array2::zeros((10, 32));
3252                let protos = Array3::<f64>::zeros((160, 160, 32));
3253                let protos = protos.insert_axis(Axis(0));
3254
3255                boxes
3256                    .slice_mut(s![0, ..])
3257                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3258                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3259                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3260
3261                let decoder = if is_split {
3262                    DecoderBuilder::default()
3263                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3264                        .with_score_threshold(score_threshold)
3265                        .with_iou_threshold(iou_threshold)
3266                        .build()
3267                        .unwrap()
3268                } else {
3269                    DecoderBuilder::default()
3270                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3271                        .with_score_threshold(score_threshold)
3272                        .with_iou_threshold(iou_threshold)
3273                        .build()
3274                        .unwrap()
3275                };
3276
3277                let expected = e2e_expected_boxes_float();
3278                let mut output_boxes = Vec::with_capacity(50);
3279
3280                if is_split {
3281                    let boxes = boxes.insert_axis(Axis(0));
3282                    let scores = scores.insert_axis(Axis(0));
3283                    let classes = classes.insert_axis(Axis(0));
3284                    let mask = mask.insert_axis(Axis(0));
3285
3286                    if is_proto {
3287                        let inputs = vec![
3288                            boxes.view().into_dyn(),
3289                            scores.view().into_dyn(),
3290                            classes.view().into_dyn(),
3291                            mask.view().into_dyn(),
3292                            protos.view().into_dyn(),
3293                        ];
3294                        decoder
3295                            .decode_float_proto(&inputs, &mut output_boxes)
3296                            .unwrap();
3297
3298                        assert_eq!(output_boxes.len(), 1);
3299                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3300                    } else {
3301                        let mut output_masks = Vec::with_capacity(50);
3302                        let inputs = vec![
3303                            boxes.view().into_dyn(),
3304                            scores.view().into_dyn(),
3305                            classes.view().into_dyn(),
3306                            mask.view().into_dyn(),
3307                            protos.view().into_dyn(),
3308                        ];
3309                        decoder
3310                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3311                            .unwrap();
3312
3313                        assert_eq!(output_boxes.len(), 1);
3314                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3315                    }
3316                } else {
3317                    // Combined layout
3318                    let detect = ndarray::concatenate![
3319                        Axis(1),
3320                        boxes.view(),
3321                        scores.view(),
3322                        classes.view(),
3323                        mask.view()
3324                    ];
3325                    let detect = detect.insert_axis(Axis(0));
3326                    assert_eq!(detect.shape(), &[1, 10, 38]);
3327
3328                    if is_proto {
3329                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3330                        decoder
3331                            .decode_float_proto(&inputs, &mut output_boxes)
3332                            .unwrap();
3333
3334                        assert_eq!(output_boxes.len(), 1);
3335                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3336                    } else {
3337                        let mut output_masks = Vec::with_capacity(50);
3338                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3339                        decoder
3340                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3341                            .unwrap();
3342
3343                        assert_eq!(output_boxes.len(), 1);
3344                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3345                    }
3346                }
3347            }
3348        };
3349    }
3350
3351    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3352    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3353    e2e_segdet_test!(
3354        test_decoder_end_to_end_segdet_proto,
3355        quantized,
3356        combined,
3357        proto
3358    );
3359    e2e_segdet_test!(
3360        test_decoder_end_to_end_segdet_proto_float,
3361        float,
3362        combined,
3363        proto
3364    );
3365    e2e_segdet_test!(
3366        test_decoder_end_to_end_segdet_split,
3367        quantized,
3368        split,
3369        masks
3370    );
3371    e2e_segdet_test!(
3372        test_decoder_end_to_end_segdet_split_float,
3373        float,
3374        split,
3375        masks
3376    );
3377    e2e_segdet_test!(
3378        test_decoder_end_to_end_segdet_split_proto,
3379        quantized,
3380        split,
3381        proto
3382    );
3383    e2e_segdet_test!(
3384        test_decoder_end_to_end_segdet_split_proto_float,
3385        float,
3386        split,
3387        proto
3388    );
3389
3390    macro_rules! e2e_det_test {
3391        ($name:ident, quantized, $layout:ident) => {
3392            #[test]
3393            fn $name() {
3394                let is_split = matches!(stringify!($layout), "split");
3395
3396                let score_threshold = 0.45;
3397                let iou_threshold = 0.45;
3398
3399                let mut boxes = Array3::zeros((1, 10, 4));
3400                let mut scores = Array3::zeros((1, 10, 1));
3401                let mut classes = Array3::zeros((1, 10, 1));
3402
3403                boxes
3404                    .slice_mut(s![0, 0, ..])
3405                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3406                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3407                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3408
3409                let detect_quant = (2.0 / 255.0, 0_i32);
3410
3411                let decoder = if is_split {
3412                    DecoderBuilder::default()
3413                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3414                        .with_score_threshold(score_threshold)
3415                        .with_iou_threshold(iou_threshold)
3416                        .build()
3417                        .unwrap()
3418                } else {
3419                    DecoderBuilder::default()
3420                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3421                        .with_score_threshold(score_threshold)
3422                        .with_iou_threshold(iou_threshold)
3423                        .build()
3424                        .unwrap()
3425                };
3426
3427                let expected = e2e_expected_boxes_quant();
3428                let mut output_boxes = Vec::with_capacity(50);
3429
3430                if is_split {
3431                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3432                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3433                    let classes: Array<u8, _> =
3434                        quantize_ndarray(classes.view(), detect_quant.into());
3435                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3436                        boxes.view().into(),
3437                        scores.view().into(),
3438                        classes.view().into(),
3439                    ];
3440                    decoder
3441                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3442                        .unwrap();
3443                } else {
3444                    let detect =
3445                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3446                    assert_eq!(detect.shape(), &[1, 10, 6]);
3447                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3448                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3449                        vec![detect.view().into()];
3450                    decoder
3451                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3452                        .unwrap();
3453                }
3454
3455                assert_eq!(output_boxes.len(), 1);
3456                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3457            }
3458        };
3459        ($name:ident, float, $layout:ident) => {
3460            #[test]
3461            fn $name() {
3462                let is_split = matches!(stringify!($layout), "split");
3463
3464                let score_threshold = 0.45;
3465                let iou_threshold = 0.45;
3466
3467                let mut boxes = Array3::zeros((1, 10, 4));
3468                let mut scores = Array3::zeros((1, 10, 1));
3469                let mut classes = Array3::zeros((1, 10, 1));
3470
3471                boxes
3472                    .slice_mut(s![0, 0, ..])
3473                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3474                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3475                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3476
3477                let decoder = if is_split {
3478                    DecoderBuilder::default()
3479                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3480                        .with_score_threshold(score_threshold)
3481                        .with_iou_threshold(iou_threshold)
3482                        .build()
3483                        .unwrap()
3484                } else {
3485                    DecoderBuilder::default()
3486                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3487                        .with_score_threshold(score_threshold)
3488                        .with_iou_threshold(iou_threshold)
3489                        .build()
3490                        .unwrap()
3491                };
3492
3493                let expected = e2e_expected_boxes_float();
3494                let mut output_boxes = Vec::with_capacity(50);
3495
3496                if is_split {
3497                    let inputs = vec![
3498                        boxes.view().into_dyn(),
3499                        scores.view().into_dyn(),
3500                        classes.view().into_dyn(),
3501                    ];
3502                    decoder
3503                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3504                        .unwrap();
3505                } else {
3506                    let detect =
3507                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3508                    assert_eq!(detect.shape(), &[1, 10, 6]);
3509                    let inputs = vec![detect.view().into_dyn()];
3510                    decoder
3511                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3512                        .unwrap();
3513                }
3514
3515                assert_eq!(output_boxes.len(), 1);
3516                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3517            }
3518        };
3519    }
3520
3521    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3522    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3523    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3524    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3525
3526    #[test]
3527    fn test_decode_tensor() {
3528        let score_threshold = 0.45;
3529        let iou_threshold = 0.45;
3530
3531        let raw_boxes = include_bytes!(concat!(
3532            env!("CARGO_MANIFEST_DIR"),
3533            "/../../testdata/yolov8_boxes_116x8400.bin"
3534        ));
3535        let raw_boxes =
3536            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3537        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3538        boxes_i8
3539            .map()
3540            .unwrap()
3541            .as_mut_slice()
3542            .copy_from_slice(raw_boxes);
3543        let boxes_i8 = boxes_i8.into();
3544
3545        let raw_protos = include_bytes!(concat!(
3546            env!("CARGO_MANIFEST_DIR"),
3547            "/../../testdata/yolov8_protos_160x160x32.bin"
3548        ));
3549        let raw_protos = unsafe {
3550            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3551        };
3552        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3553        protos_i8
3554            .map()
3555            .unwrap()
3556            .as_mut_slice()
3557            .copy_from_slice(raw_protos);
3558        let protos_i8 = protos_i8.into();
3559
3560        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3561        let expected = real_data_expected_boxes();
3562        let mut output_boxes = Vec::with_capacity(50);
3563
3564        decoder
3565            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3566            .unwrap();
3567
3568        assert_eq!(output_boxes.len(), 2);
3569        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3570        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3571    }
3572
3573    #[test]
3574    fn test_decode_tensor_f32() {
3575        let score_threshold = 0.45;
3576        let iou_threshold = 0.45;
3577
3578        let quant_boxes = (0.021287762_f32, 31_i32);
3579        let quant_protos = (0.02491162_f32, -117_i32);
3580        let raw_boxes = include_bytes!(concat!(
3581            env!("CARGO_MANIFEST_DIR"),
3582            "/../../testdata/yolov8_boxes_116x8400.bin"
3583        ));
3584        let raw_boxes =
3585            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3586        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3587        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3588        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3589        boxes_f32
3590            .map()
3591            .unwrap()
3592            .as_mut_slice()
3593            .copy_from_slice(&raw_boxes_f32);
3594        let boxes_f32 = boxes_f32.into();
3595
3596        let raw_protos = include_bytes!(concat!(
3597            env!("CARGO_MANIFEST_DIR"),
3598            "/../../testdata/yolov8_protos_160x160x32.bin"
3599        ));
3600        let raw_protos = unsafe {
3601            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3602        };
3603        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3604        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3605        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3606        protos_f32
3607            .map()
3608            .unwrap()
3609            .as_mut_slice()
3610            .copy_from_slice(&raw_protos_f32);
3611        let protos_f32 = protos_f32.into();
3612
3613        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3614
3615        let expected = real_data_expected_boxes();
3616        let mut output_boxes = Vec::with_capacity(50);
3617
3618        decoder
3619            .decode(
3620                &[&boxes_f32, &protos_f32],
3621                &mut output_boxes,
3622                &mut Vec::new(),
3623            )
3624            .unwrap();
3625
3626        assert_eq!(output_boxes.len(), 2);
3627        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3628        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3629    }
3630
3631    #[test]
3632    fn test_decode_tensor_f64() {
3633        let score_threshold = 0.45;
3634        let iou_threshold = 0.45;
3635
3636        let quant_boxes = (0.021287762_f32, 31_i32);
3637        let quant_protos = (0.02491162_f32, -117_i32);
3638        let raw_boxes = include_bytes!(concat!(
3639            env!("CARGO_MANIFEST_DIR"),
3640            "/../../testdata/yolov8_boxes_116x8400.bin"
3641        ));
3642        let raw_boxes =
3643            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3644        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3645        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3646        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3647        boxes_f64
3648            .map()
3649            .unwrap()
3650            .as_mut_slice()
3651            .copy_from_slice(&raw_boxes_f64);
3652        let boxes_f64 = boxes_f64.into();
3653
3654        let raw_protos = include_bytes!(concat!(
3655            env!("CARGO_MANIFEST_DIR"),
3656            "/../../testdata/yolov8_protos_160x160x32.bin"
3657        ));
3658        let raw_protos = unsafe {
3659            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3660        };
3661        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3662        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3663        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3664        protos_f64
3665            .map()
3666            .unwrap()
3667            .as_mut_slice()
3668            .copy_from_slice(&raw_protos_f64);
3669        let protos_f64 = protos_f64.into();
3670
3671        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3672
3673        let expected = real_data_expected_boxes();
3674        let mut output_boxes = Vec::with_capacity(50);
3675
3676        decoder
3677            .decode(
3678                &[&boxes_f64, &protos_f64],
3679                &mut output_boxes,
3680                &mut Vec::new(),
3681            )
3682            .unwrap();
3683
3684        assert_eq!(output_boxes.len(), 2);
3685        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3686        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3687    }
3688
3689    #[test]
3690    fn test_decode_tensor_proto() {
3691        let score_threshold = 0.45;
3692        let iou_threshold = 0.45;
3693
3694        let raw_boxes = include_bytes!(concat!(
3695            env!("CARGO_MANIFEST_DIR"),
3696            "/../../testdata/yolov8_boxes_116x8400.bin"
3697        ));
3698        let raw_boxes =
3699            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3700        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3701        boxes_i8
3702            .map()
3703            .unwrap()
3704            .as_mut_slice()
3705            .copy_from_slice(raw_boxes);
3706        let boxes_i8 = boxes_i8.into();
3707
3708        let raw_protos = include_bytes!(concat!(
3709            env!("CARGO_MANIFEST_DIR"),
3710            "/../../testdata/yolov8_protos_160x160x32.bin"
3711        ));
3712        let raw_protos = unsafe {
3713            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3714        };
3715        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3716        protos_i8
3717            .map()
3718            .unwrap()
3719            .as_mut_slice()
3720            .copy_from_slice(raw_protos);
3721        let protos_i8 = protos_i8.into();
3722
3723        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3724
3725        let expected = real_data_expected_boxes();
3726        let mut output_boxes = Vec::with_capacity(50);
3727
3728        let proto_data = decoder
3729            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3730            .unwrap();
3731
3732        assert_eq!(output_boxes.len(), 2);
3733        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3734        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3735
3736        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3737        assert_eq!(
3738            proto_data.mask_coefficients.len(),
3739            output_boxes.len(),
3740            "mask_coefficients count must match detection count"
3741        );
3742        for coeff in &proto_data.mask_coefficients {
3743            assert_eq!(
3744                coeff.len(),
3745                32,
3746                "each detection should have 32 mask coefficients"
3747            );
3748        }
3749    }
3750}
3751
3752#[cfg(feature = "tracker")]
3753#[cfg(test)]
3754#[cfg_attr(coverage_nightly, coverage(off))]
3755mod decoder_tracked_tests {
3756
3757    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
3758    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
3759    use num_traits::{AsPrimitive, Float, PrimInt};
3760    use rand::{RngExt, SeedableRng};
3761    use rand_distr::StandardNormal;
3762
3763    use crate::{
3764        configs::{self, DimName},
3765        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
3766    };
3767
3768    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
3769        input: ArrayView<F, D>,
3770        quant: Quantization,
3771    ) -> Array<T, D>
3772    where
3773        i32: num_traits::AsPrimitive<F>,
3774        f32: num_traits::AsPrimitive<F>,
3775    {
3776        let zero_point = quant.zero_point.as_();
3777        let div_scale = F::one() / quant.scale.as_();
3778        if zero_point != F::zero() {
3779            input.mapv(|d| (d * div_scale + zero_point).round().as_())
3780        } else {
3781            input.mapv(|d| (d * div_scale).round().as_())
3782        }
3783    }
3784
3785    #[test]
3786    fn test_decoder_tracked_random_jitter() {
3787        use crate::configs::{DecoderType, Nms};
3788        use crate::DecoderBuilder;
3789
3790        let score_threshold = 0.25;
3791        let iou_threshold = 0.1;
3792        let out = include_bytes!(concat!(
3793            env!("CARGO_MANIFEST_DIR"),
3794            "/../../testdata/yolov8s_80_classes.bin"
3795        ));
3796        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
3797        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
3798        let quant = (0.0040811873, -123).into();
3799
3800        let decoder = DecoderBuilder::default()
3801            .with_config_yolo_det(
3802                crate::configs::Detection {
3803                    decoder: DecoderType::Ultralytics,
3804                    shape: vec![1, 84, 8400],
3805                    anchors: None,
3806                    quantization: Some(quant),
3807                    dshape: vec![
3808                        (crate::configs::DimName::Batch, 1),
3809                        (crate::configs::DimName::NumFeatures, 84),
3810                        (crate::configs::DimName::NumBoxes, 8400),
3811                    ],
3812                    normalized: Some(true),
3813                },
3814                None,
3815            )
3816            .with_score_threshold(score_threshold)
3817            .with_iou_threshold(iou_threshold)
3818            .with_nms(Some(Nms::ClassAgnostic))
3819            .build()
3820            .unwrap();
3821        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
3822
3823        let expected_boxes = [
3824            crate::DetectBox {
3825                bbox: crate::BoundingBox {
3826                    xmin: 0.5285137,
3827                    ymin: 0.05305544,
3828                    xmax: 0.87541467,
3829                    ymax: 0.9998909,
3830                },
3831                score: 0.5591227,
3832                label: 0,
3833            },
3834            crate::DetectBox {
3835                bbox: crate::BoundingBox {
3836                    xmin: 0.130598,
3837                    ymin: 0.43260583,
3838                    xmax: 0.35098213,
3839                    ymax: 0.9958097,
3840                },
3841                score: 0.33057618,
3842                label: 75,
3843            },
3844        ];
3845
3846        let mut tracker = ByteTrackBuilder::new()
3847            .track_update(0.1)
3848            .track_high_conf(0.3)
3849            .build();
3850
3851        let mut output_boxes = Vec::with_capacity(50);
3852        let mut output_masks = Vec::with_capacity(50);
3853        let mut output_tracks = Vec::with_capacity(50);
3854
3855        decoder
3856            .decode_tracked_quantized(
3857                &mut tracker,
3858                0,
3859                &[out.view().into()],
3860                &mut output_boxes,
3861                &mut output_masks,
3862                &mut output_tracks,
3863            )
3864            .unwrap();
3865
3866        assert_eq!(output_boxes.len(), 2);
3867        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3868        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3869
3870        let mut last_boxes = output_boxes.clone();
3871
3872        for i in 1..=100 {
3873            let mut out = out.clone();
3874            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
3875            let mut x_values = out.slice_mut(s![0, 0, ..]);
3876            for x in x_values.iter_mut() {
3877                let r: f32 = rng.sample(StandardNormal);
3878                let r = r.clamp(-2.0, 2.0) / 2.0;
3879                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
3880            }
3881
3882            let mut y_values = out.slice_mut(s![0, 1, ..]);
3883            for y in y_values.iter_mut() {
3884                let r: f32 = rng.sample(StandardNormal);
3885                let r = r.clamp(-2.0, 2.0) / 2.0;
3886                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
3887            }
3888
3889            decoder
3890                .decode_tracked_quantized(
3891                    &mut tracker,
3892                    100_000_000 * i / 3, // simulate 33.333ms between frames
3893                    &[out.view().into()],
3894                    &mut output_boxes,
3895                    &mut output_masks,
3896                    &mut output_tracks,
3897                )
3898                .unwrap();
3899
3900            assert_eq!(output_boxes.len(), 2);
3901            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
3902            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
3903
3904            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
3905            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
3906            last_boxes = output_boxes.clone();
3907        }
3908    }
3909
3910    // ─── Shared helpers for tracked decoder tests ────────────────────
3911
3912    fn real_data_expected_boxes() -> [DetectBox; 2] {
3913        [
3914            DetectBox {
3915                bbox: BoundingBox {
3916                    xmin: 0.08515105,
3917                    ymin: 0.7131401,
3918                    xmax: 0.29802868,
3919                    ymax: 0.8195788,
3920                },
3921                score: 0.91537374,
3922                label: 23,
3923            },
3924            DetectBox {
3925                bbox: BoundingBox {
3926                    xmin: 0.59605736,
3927                    ymin: 0.25545314,
3928                    xmax: 0.93666154,
3929                    ymax: 0.72378385,
3930                },
3931                score: 0.91537374,
3932                label: 23,
3933            },
3934        ]
3935    }
3936
3937    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3938        [DetectBox {
3939            bbox: BoundingBox {
3940                xmin: 0.12549022,
3941                ymin: 0.12549022,
3942                xmax: 0.23529413,
3943                ymax: 0.23529413,
3944            },
3945            score: 0.98823535,
3946            label: 2,
3947        }]
3948    }
3949
3950    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3951        [DetectBox {
3952            bbox: BoundingBox {
3953                xmin: 0.1234,
3954                ymin: 0.1234,
3955                xmax: 0.2345,
3956                ymax: 0.2345,
3957            },
3958            score: 0.9876,
3959            label: 2,
3960        }]
3961    }
3962
3963    fn build_yolo_split_segdet_decoder(
3964        score_threshold: f32,
3965        iou_threshold: f32,
3966        quant_boxes: (f32, i32),
3967        quant_protos: (f32, i32),
3968    ) -> crate::Decoder {
3969        DecoderBuilder::default()
3970            .with_config_yolo_split_segdet(
3971                configs::Boxes {
3972                    decoder: configs::DecoderType::Ultralytics,
3973                    quantization: Some(quant_boxes.into()),
3974                    shape: vec![1, 4, 8400],
3975                    dshape: vec![
3976                        (DimName::Batch, 1),
3977                        (DimName::BoxCoords, 4),
3978                        (DimName::NumBoxes, 8400),
3979                    ],
3980                    normalized: Some(true),
3981                },
3982                configs::Scores {
3983                    decoder: configs::DecoderType::Ultralytics,
3984                    quantization: Some(quant_boxes.into()),
3985                    shape: vec![1, 80, 8400],
3986                    dshape: vec![
3987                        (DimName::Batch, 1),
3988                        (DimName::NumClasses, 80),
3989                        (DimName::NumBoxes, 8400),
3990                    ],
3991                },
3992                configs::MaskCoefficients {
3993                    decoder: configs::DecoderType::Ultralytics,
3994                    quantization: Some(quant_boxes.into()),
3995                    shape: vec![1, 32, 8400],
3996                    dshape: vec![
3997                        (DimName::Batch, 1),
3998                        (DimName::NumProtos, 32),
3999                        (DimName::NumBoxes, 8400),
4000                    ],
4001                },
4002                configs::Protos {
4003                    decoder: configs::DecoderType::Ultralytics,
4004                    quantization: Some(quant_protos.into()),
4005                    shape: vec![1, 160, 160, 32],
4006                    dshape: vec![
4007                        (DimName::Batch, 1),
4008                        (DimName::Height, 160),
4009                        (DimName::Width, 160),
4010                        (DimName::NumProtos, 32),
4011                    ],
4012                },
4013            )
4014            .with_score_threshold(score_threshold)
4015            .with_iou_threshold(iou_threshold)
4016            .build()
4017            .unwrap()
4018    }
4019
4020    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4021        let config_yaml = include_str!(concat!(
4022            env!("CARGO_MANIFEST_DIR"),
4023            "/../../testdata/yolov8_seg.yaml"
4024        ));
4025        DecoderBuilder::default()
4026            .with_config_yaml_str(config_yaml.to_string())
4027            .with_score_threshold(score_threshold)
4028            .with_iou_threshold(iou_threshold)
4029            .build()
4030            .unwrap()
4031    }
4032
4033    // ─── Real-data tracked test macro ───────────────────────────────
4034    //
4035    // Generates tests that load i8 binary test data from testdata/ and
4036    // exercise all (quant/float) × (combined/split) × (masks/proto)
4037    // decoder paths.
4038
4039    macro_rules! real_data_tracked_test {
4040        ($name:ident, quantized, $layout:ident, $output:ident) => {
4041            #[test]
4042            fn $name() {
4043                let is_split = matches!(stringify!($layout), "split");
4044                let is_proto = matches!(stringify!($output), "proto");
4045
4046                let score_threshold = 0.45;
4047                let iou_threshold = 0.45;
4048                let quant_boxes = (0.021287762_f32, 31_i32);
4049                let quant_protos = (0.02491162_f32, -117_i32);
4050
4051                let raw_boxes = include_bytes!(concat!(
4052                    env!("CARGO_MANIFEST_DIR"),
4053                    "/../../testdata/yolov8_boxes_116x8400.bin"
4054                ));
4055                let raw_boxes = unsafe {
4056                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4057                };
4058                let boxes_i8 =
4059                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4060
4061                let raw_protos = include_bytes!(concat!(
4062                    env!("CARGO_MANIFEST_DIR"),
4063                    "/../../testdata/yolov8_protos_160x160x32.bin"
4064                ));
4065                let raw_protos = unsafe {
4066                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4067                };
4068                let protos_i8 =
4069                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4070                        .unwrap();
4071
4072                // Pre-split (unused for combined, but harmless)
4073                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4074                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4075                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4076                let mut boxes_combined = boxes_i8;
4077
4078                let decoder = if is_split {
4079                    build_yolo_split_segdet_decoder(
4080                        score_threshold,
4081                        iou_threshold,
4082                        quant_boxes,
4083                        quant_protos,
4084                    )
4085                } else {
4086                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4087                };
4088
4089                let expected = real_data_expected_boxes();
4090                let mut tracker = ByteTrackBuilder::new()
4091                    .track_update(0.1)
4092                    .track_high_conf(0.7)
4093                    .build();
4094                let mut output_boxes = Vec::with_capacity(50);
4095                let mut output_tracks = Vec::with_capacity(50);
4096
4097                // Frame 1: decode
4098                if is_proto {
4099                    {
4100                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4101                            vec![
4102                                boxes_split.view().into(),
4103                                scores_split.view().into(),
4104                                mask_split.view().into(),
4105                                protos_i8.view().into(),
4106                            ]
4107                        } else {
4108                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4109                        };
4110                        decoder
4111                            .decode_tracked_quantized_proto(
4112                                &mut tracker,
4113                                0,
4114                                &inputs,
4115                                &mut output_boxes,
4116                                &mut output_tracks,
4117                            )
4118                            .unwrap();
4119                    }
4120                    assert_eq!(output_boxes.len(), 2);
4121                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4122                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4123
4124                    // Zero scores for frame 2
4125                    if is_split {
4126                        for score in scores_split.iter_mut() {
4127                            *score = i8::MIN;
4128                        }
4129                    } else {
4130                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4131                            *score = i8::MIN;
4132                        }
4133                    }
4134
4135                    let proto_result = {
4136                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4137                            vec![
4138                                boxes_split.view().into(),
4139                                scores_split.view().into(),
4140                                mask_split.view().into(),
4141                                protos_i8.view().into(),
4142                            ]
4143                        } else {
4144                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4145                        };
4146                        decoder
4147                            .decode_tracked_quantized_proto(
4148                                &mut tracker,
4149                                100_000_000 / 3,
4150                                &inputs,
4151                                &mut output_boxes,
4152                                &mut output_tracks,
4153                            )
4154                            .unwrap()
4155                    };
4156                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4157                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4158                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4159                } else {
4160                    let mut output_masks = Vec::with_capacity(50);
4161                    {
4162                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4163                            vec![
4164                                boxes_split.view().into(),
4165                                scores_split.view().into(),
4166                                mask_split.view().into(),
4167                                protos_i8.view().into(),
4168                            ]
4169                        } else {
4170                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4171                        };
4172                        decoder
4173                            .decode_tracked_quantized(
4174                                &mut tracker,
4175                                0,
4176                                &inputs,
4177                                &mut output_boxes,
4178                                &mut output_masks,
4179                                &mut output_tracks,
4180                            )
4181                            .unwrap();
4182                    }
4183                    assert_eq!(output_boxes.len(), 2);
4184                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4185                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4186
4187                    if is_split {
4188                        for score in scores_split.iter_mut() {
4189                            *score = i8::MIN;
4190                        }
4191                    } else {
4192                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4193                            *score = i8::MIN;
4194                        }
4195                    }
4196
4197                    {
4198                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4199                            vec![
4200                                boxes_split.view().into(),
4201                                scores_split.view().into(),
4202                                mask_split.view().into(),
4203                                protos_i8.view().into(),
4204                            ]
4205                        } else {
4206                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4207                        };
4208                        decoder
4209                            .decode_tracked_quantized(
4210                                &mut tracker,
4211                                100_000_000 / 3,
4212                                &inputs,
4213                                &mut output_boxes,
4214                                &mut output_masks,
4215                                &mut output_tracks,
4216                            )
4217                            .unwrap();
4218                    }
4219                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4220                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4221                    assert!(output_masks.is_empty());
4222                }
4223            }
4224        };
4225        ($name:ident, float, $layout:ident, $output:ident) => {
4226            #[test]
4227            fn $name() {
4228                let is_split = matches!(stringify!($layout), "split");
4229                let is_proto = matches!(stringify!($output), "proto");
4230
4231                let score_threshold = 0.45;
4232                let iou_threshold = 0.45;
4233                let quant_boxes = (0.021287762_f32, 31_i32);
4234                let quant_protos = (0.02491162_f32, -117_i32);
4235
4236                let raw_boxes = include_bytes!(concat!(
4237                    env!("CARGO_MANIFEST_DIR"),
4238                    "/../../testdata/yolov8_boxes_116x8400.bin"
4239                ));
4240                let raw_boxes = unsafe {
4241                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4242                };
4243                let boxes_i8 =
4244                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4245                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4246
4247                let raw_protos = include_bytes!(concat!(
4248                    env!("CARGO_MANIFEST_DIR"),
4249                    "/../../testdata/yolov8_protos_160x160x32.bin"
4250                ));
4251                let raw_protos = unsafe {
4252                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4253                };
4254                let protos_i8 =
4255                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4256                        .unwrap();
4257                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4258
4259                // Pre-split from dequantized data
4260                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4261                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4262                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4263                let mut boxes_combined = boxes_f32;
4264
4265                let decoder = if is_split {
4266                    build_yolo_split_segdet_decoder(
4267                        score_threshold,
4268                        iou_threshold,
4269                        quant_boxes,
4270                        quant_protos,
4271                    )
4272                } else {
4273                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4274                };
4275
4276                let expected = real_data_expected_boxes();
4277                let mut tracker = ByteTrackBuilder::new()
4278                    .track_update(0.1)
4279                    .track_high_conf(0.7)
4280                    .build();
4281                let mut output_boxes = Vec::with_capacity(50);
4282                let mut output_tracks = Vec::with_capacity(50);
4283
4284                if is_proto {
4285                    {
4286                        let inputs = if is_split {
4287                            vec![
4288                                boxes_split.view().into_dyn(),
4289                                scores_split.view().into_dyn(),
4290                                mask_split.view().into_dyn(),
4291                                protos_f32.view().into_dyn(),
4292                            ]
4293                        } else {
4294                            vec![
4295                                boxes_combined.view().into_dyn(),
4296                                protos_f32.view().into_dyn(),
4297                            ]
4298                        };
4299                        decoder
4300                            .decode_tracked_float_proto(
4301                                &mut tracker,
4302                                0,
4303                                &inputs,
4304                                &mut output_boxes,
4305                                &mut output_tracks,
4306                            )
4307                            .unwrap();
4308                    }
4309                    assert_eq!(output_boxes.len(), 2);
4310                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4311                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4312
4313                    if is_split {
4314                        for score in scores_split.iter_mut() {
4315                            *score = 0.0;
4316                        }
4317                    } else {
4318                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4319                            *score = 0.0;
4320                        }
4321                    }
4322
4323                    let proto_result = {
4324                        let inputs = if is_split {
4325                            vec![
4326                                boxes_split.view().into_dyn(),
4327                                scores_split.view().into_dyn(),
4328                                mask_split.view().into_dyn(),
4329                                protos_f32.view().into_dyn(),
4330                            ]
4331                        } else {
4332                            vec![
4333                                boxes_combined.view().into_dyn(),
4334                                protos_f32.view().into_dyn(),
4335                            ]
4336                        };
4337                        decoder
4338                            .decode_tracked_float_proto(
4339                                &mut tracker,
4340                                100_000_000 / 3,
4341                                &inputs,
4342                                &mut output_boxes,
4343                                &mut output_tracks,
4344                            )
4345                            .unwrap()
4346                    };
4347                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4348                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4349                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4350                } else {
4351                    let mut output_masks = Vec::with_capacity(50);
4352                    {
4353                        let inputs = if is_split {
4354                            vec![
4355                                boxes_split.view().into_dyn(),
4356                                scores_split.view().into_dyn(),
4357                                mask_split.view().into_dyn(),
4358                                protos_f32.view().into_dyn(),
4359                            ]
4360                        } else {
4361                            vec![
4362                                boxes_combined.view().into_dyn(),
4363                                protos_f32.view().into_dyn(),
4364                            ]
4365                        };
4366                        decoder
4367                            .decode_tracked_float(
4368                                &mut tracker,
4369                                0,
4370                                &inputs,
4371                                &mut output_boxes,
4372                                &mut output_masks,
4373                                &mut output_tracks,
4374                            )
4375                            .unwrap();
4376                    }
4377                    assert_eq!(output_boxes.len(), 2);
4378                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4379                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4380
4381                    if is_split {
4382                        for score in scores_split.iter_mut() {
4383                            *score = 0.0;
4384                        }
4385                    } else {
4386                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4387                            *score = 0.0;
4388                        }
4389                    }
4390
4391                    {
4392                        let inputs = if is_split {
4393                            vec![
4394                                boxes_split.view().into_dyn(),
4395                                scores_split.view().into_dyn(),
4396                                mask_split.view().into_dyn(),
4397                                protos_f32.view().into_dyn(),
4398                            ]
4399                        } else {
4400                            vec![
4401                                boxes_combined.view().into_dyn(),
4402                                protos_f32.view().into_dyn(),
4403                            ]
4404                        };
4405                        decoder
4406                            .decode_tracked_float(
4407                                &mut tracker,
4408                                100_000_000 / 3,
4409                                &inputs,
4410                                &mut output_boxes,
4411                                &mut output_masks,
4412                                &mut output_tracks,
4413                            )
4414                            .unwrap();
4415                    }
4416                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4417                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4418                    assert!(output_masks.is_empty());
4419                }
4420            }
4421        };
4422    }
4423
4424    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4425    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4426    real_data_tracked_test!(
4427        test_decoder_tracked_segdet_proto,
4428        quantized,
4429        combined,
4430        proto
4431    );
4432    real_data_tracked_test!(
4433        test_decoder_tracked_segdet_proto_float,
4434        float,
4435        combined,
4436        proto
4437    );
4438    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4439    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4440    real_data_tracked_test!(
4441        test_decoder_tracked_segdet_split_proto,
4442        quantized,
4443        split,
4444        proto
4445    );
4446    real_data_tracked_test!(
4447        test_decoder_tracked_segdet_split_proto_float,
4448        float,
4449        split,
4450        proto
4451    );
4452
4453    // ─── End-to-end tracked test macro ──────────────────────────────
4454    //
4455    // Generates tests with synthetic data to exercise all tracked
4456    // decode paths without needing real model output files.
4457
    /// YAML decoder config (`decoder_version: yolo26`) for the synthetic end-to-end
    /// tests using the combined layout.
    ///
    /// - `detection`: `[1, 10, 38]`. Each of the 10 boxes carries 4 box coordinates,
    ///   1 score, 1 class and 32 mask coefficients, concatenated in that order by the
    ///   tests. Quantization scale is 2/255 with zero point 0, matching the tests'
    ///   `detect_quant`.
    /// - `protos`: `[1, 160, 160, 32]`.
    ///
    /// NOTE(review): the protos zero point here is 128, while the e2e tests quantize
    /// protos with `(1/255, 0)`. Confirm this mismatch is intentional (it only affects
    /// dequantized proto values, not boxes).
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
4475    - [batch, 1]
4476    - [height, 160]
4477    - [width, 160]
4478    - [num_protos, 32]
4479";
4480
4481    const E2E_SPLIT_CONFIG: &str = "
4482decoder_version: yolo26
4483outputs:
4484 - type: boxes
4485   decoder: ultralytics
4486   quantization: [0.00784313725490196, 0]
4487   shape: [1, 10, 4]
4488   dshape:
4489    - [batch, 1]
4490    - [num_boxes, 10]
4491    - [box_coords, 4]
4492   normalized: true
4493 - type: scores
4494   decoder: ultralytics
4495   quantization: [0.00784313725490196, 0]
4496   shape: [1, 10, 1]
4497   dshape:
4498    - [batch, 1]
4499    - [num_boxes, 10]
4500    - [num_classes, 1]
4501 - type: classes
4502   decoder: ultralytics
4503   quantization: [0.00784313725490196, 0]
4504   shape: [1, 10, 1]
4505   dshape:
4506    - [batch, 1]
4507    - [num_boxes, 10]
4508    - [num_classes, 1]
4509 - type: mask_coefficients
4510   decoder: ultralytics
4511   quantization: [0.00784313725490196, 0]
4512   shape: [1, 10, 32]
4513   dshape:
4514    - [batch, 1]
4515    - [num_boxes, 10]
4516    - [num_protos, 32]
4517 - type: protos
4518   decoder: ultralytics
4519   quantization: [0.0039215686274509803921568627451, 128]
4520   shape: [1, 160, 160, 32]
4521   dshape:
4522    - [batch, 1]
4523    - [height, 160]
4524    - [width, 160]
4525    - [num_protos, 32]
4526";
4527
4528    macro_rules! e2e_tracked_test {
4529        ($name:ident, quantized, $layout:ident, $output:ident) => {
4530            #[test]
4531            fn $name() {
4532                let is_split = matches!(stringify!($layout), "split");
4533                let is_proto = matches!(stringify!($output), "proto");
4534
4535                let score_threshold = 0.45;
4536                let iou_threshold = 0.45;
4537
4538                let mut boxes = Array2::zeros((10, 4));
4539                let mut scores = Array2::zeros((10, 1));
4540                let mut classes = Array2::zeros((10, 1));
4541                let mask = Array2::zeros((10, 32));
4542                let protos = Array3::<f64>::zeros((160, 160, 32));
4543                let protos = protos.insert_axis(Axis(0));
4544                let protos_quant = (1.0 / 255.0, 0.0);
4545                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4546
4547                boxes
4548                    .slice_mut(s![0, ..])
4549                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4550                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4551                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4552
4553                let detect_quant = (2.0 / 255.0, 0.0);
4554
4555                let decoder = if is_split {
4556                    DecoderBuilder::default()
4557                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4558                        .with_score_threshold(score_threshold)
4559                        .with_iou_threshold(iou_threshold)
4560                        .build()
4561                        .unwrap()
4562                } else {
4563                    DecoderBuilder::default()
4564                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4565                        .with_score_threshold(score_threshold)
4566                        .with_iou_threshold(iou_threshold)
4567                        .build()
4568                        .unwrap()
4569                };
4570
4571                let expected = e2e_expected_boxes_quant();
4572                let mut tracker = ByteTrackBuilder::new()
4573                    .track_update(0.1)
4574                    .track_high_conf(0.7)
4575                    .build();
4576                let mut output_boxes = Vec::with_capacity(50);
4577                let mut output_tracks = Vec::with_capacity(50);
4578
4579                if is_split {
4580                    let boxes = boxes.insert_axis(Axis(0));
4581                    let scores = scores.insert_axis(Axis(0));
4582                    let classes = classes.insert_axis(Axis(0));
4583                    let mask = mask.insert_axis(Axis(0));
4584
4585                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4586                    let mut scores: Array3<u8> =
4587                        quantize_ndarray(scores.view(), detect_quant.into());
4588                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4589                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4590
4591                    if is_proto {
4592                        {
4593                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4594                                boxes.view().into(),
4595                                scores.view().into(),
4596                                classes.view().into(),
4597                                mask.view().into(),
4598                                protos.view().into(),
4599                            ];
4600                            decoder
4601                                .decode_tracked_quantized_proto(
4602                                    &mut tracker,
4603                                    0,
4604                                    &inputs,
4605                                    &mut output_boxes,
4606                                    &mut output_tracks,
4607                                )
4608                                .unwrap();
4609                        }
4610                        assert_eq!(output_boxes.len(), 1);
4611                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4612
4613                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4614                            *score = u8::MIN;
4615                        }
4616                        let proto_result = {
4617                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4618                                boxes.view().into(),
4619                                scores.view().into(),
4620                                classes.view().into(),
4621                                mask.view().into(),
4622                                protos.view().into(),
4623                            ];
4624                            decoder
4625                                .decode_tracked_quantized_proto(
4626                                    &mut tracker,
4627                                    100_000_000 / 3,
4628                                    &inputs,
4629                                    &mut output_boxes,
4630                                    &mut output_tracks,
4631                                )
4632                                .unwrap()
4633                        };
4634                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4635                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4636                    } else {
4637                        let mut output_masks = Vec::with_capacity(50);
4638                        {
4639                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4640                                boxes.view().into(),
4641                                scores.view().into(),
4642                                classes.view().into(),
4643                                mask.view().into(),
4644                                protos.view().into(),
4645                            ];
4646                            decoder
4647                                .decode_tracked_quantized(
4648                                    &mut tracker,
4649                                    0,
4650                                    &inputs,
4651                                    &mut output_boxes,
4652                                    &mut output_masks,
4653                                    &mut output_tracks,
4654                                )
4655                                .unwrap();
4656                        }
4657                        assert_eq!(output_boxes.len(), 1);
4658                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4659
4660                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4661                            *score = u8::MIN;
4662                        }
4663                        {
4664                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4665                                boxes.view().into(),
4666                                scores.view().into(),
4667                                classes.view().into(),
4668                                mask.view().into(),
4669                                protos.view().into(),
4670                            ];
4671                            decoder
4672                                .decode_tracked_quantized(
4673                                    &mut tracker,
4674                                    100_000_000 / 3,
4675                                    &inputs,
4676                                    &mut output_boxes,
4677                                    &mut output_masks,
4678                                    &mut output_tracks,
4679                                )
4680                                .unwrap();
4681                        }
4682                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4683                        assert!(output_masks.is_empty());
4684                    }
4685                } else {
4686                    // Combined layout
4687                    let detect = ndarray::concatenate![
4688                        Axis(1),
4689                        boxes.view(),
4690                        scores.view(),
4691                        classes.view(),
4692                        mask.view()
4693                    ];
4694                    let detect = detect.insert_axis(Axis(0));
4695                    assert_eq!(detect.shape(), &[1, 10, 38]);
4696                    let mut detect: Array3<u8> =
4697                        quantize_ndarray(detect.view(), detect_quant.into());
4698
4699                    if is_proto {
4700                        {
4701                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4702                                vec![detect.view().into(), protos.view().into()];
4703                            decoder
4704                                .decode_tracked_quantized_proto(
4705                                    &mut tracker,
4706                                    0,
4707                                    &inputs,
4708                                    &mut output_boxes,
4709                                    &mut output_tracks,
4710                                )
4711                                .unwrap();
4712                        }
4713                        assert_eq!(output_boxes.len(), 1);
4714                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4715
4716                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4717                            *score = u8::MIN;
4718                        }
4719                        let proto_result = {
4720                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4721                                vec![detect.view().into(), protos.view().into()];
4722                            decoder
4723                                .decode_tracked_quantized_proto(
4724                                    &mut tracker,
4725                                    100_000_000 / 3,
4726                                    &inputs,
4727                                    &mut output_boxes,
4728                                    &mut output_tracks,
4729                                )
4730                                .unwrap()
4731                        };
4732                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4733                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4734                    } else {
4735                        let mut output_masks = Vec::with_capacity(50);
4736                        {
4737                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4738                                vec![detect.view().into(), protos.view().into()];
4739                            decoder
4740                                .decode_tracked_quantized(
4741                                    &mut tracker,
4742                                    0,
4743                                    &inputs,
4744                                    &mut output_boxes,
4745                                    &mut output_masks,
4746                                    &mut output_tracks,
4747                                )
4748                                .unwrap();
4749                        }
4750                        assert_eq!(output_boxes.len(), 1);
4751                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4752
4753                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4754                            *score = u8::MIN;
4755                        }
4756                        {
4757                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4758                                vec![detect.view().into(), protos.view().into()];
4759                            decoder
4760                                .decode_tracked_quantized(
4761                                    &mut tracker,
4762                                    100_000_000 / 3,
4763                                    &inputs,
4764                                    &mut output_boxes,
4765                                    &mut output_masks,
4766                                    &mut output_tracks,
4767                                )
4768                                .unwrap();
4769                        }
4770                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4771                        assert!(output_masks.is_empty());
4772                    }
4773                }
4774            }
4775        };
4776        ($name:ident, float, $layout:ident, $output:ident) => {
4777            #[test]
4778            fn $name() {
4779                let is_split = matches!(stringify!($layout), "split");
4780                let is_proto = matches!(stringify!($output), "proto");
4781
4782                let score_threshold = 0.45;
4783                let iou_threshold = 0.45;
4784
4785                let mut boxes = Array2::zeros((10, 4));
4786                let mut scores = Array2::zeros((10, 1));
4787                let mut classes = Array2::zeros((10, 1));
4788                let mask: Array2<f64> = Array2::zeros((10, 32));
4789                let protos = Array3::<f64>::zeros((160, 160, 32));
4790                let protos = protos.insert_axis(Axis(0));
4791
4792                boxes
4793                    .slice_mut(s![0, ..])
4794                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4795                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4796                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4797
4798                let decoder = if is_split {
4799                    DecoderBuilder::default()
4800                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4801                        .with_score_threshold(score_threshold)
4802                        .with_iou_threshold(iou_threshold)
4803                        .build()
4804                        .unwrap()
4805                } else {
4806                    DecoderBuilder::default()
4807                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4808                        .with_score_threshold(score_threshold)
4809                        .with_iou_threshold(iou_threshold)
4810                        .build()
4811                        .unwrap()
4812                };
4813
4814                let expected = e2e_expected_boxes_float();
4815                let mut tracker = ByteTrackBuilder::new()
4816                    .track_update(0.1)
4817                    .track_high_conf(0.7)
4818                    .build();
4819                let mut output_boxes = Vec::with_capacity(50);
4820                let mut output_tracks = Vec::with_capacity(50);
4821
4822                if is_split {
4823                    let boxes = boxes.insert_axis(Axis(0));
4824                    let mut scores = scores.insert_axis(Axis(0));
4825                    let classes = classes.insert_axis(Axis(0));
4826                    let mask = mask.insert_axis(Axis(0));
4827
4828                    if is_proto {
4829                        {
4830                            let inputs = vec![
4831                                boxes.view().into_dyn(),
4832                                scores.view().into_dyn(),
4833                                classes.view().into_dyn(),
4834                                mask.view().into_dyn(),
4835                                protos.view().into_dyn(),
4836                            ];
4837                            decoder
4838                                .decode_tracked_float_proto(
4839                                    &mut tracker,
4840                                    0,
4841                                    &inputs,
4842                                    &mut output_boxes,
4843                                    &mut output_tracks,
4844                                )
4845                                .unwrap();
4846                        }
4847                        assert_eq!(output_boxes.len(), 1);
4848                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4849
4850                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4851                            *score = 0.0;
4852                        }
4853                        let proto_result = {
4854                            let inputs = vec![
4855                                boxes.view().into_dyn(),
4856                                scores.view().into_dyn(),
4857                                classes.view().into_dyn(),
4858                                mask.view().into_dyn(),
4859                                protos.view().into_dyn(),
4860                            ];
4861                            decoder
4862                                .decode_tracked_float_proto(
4863                                    &mut tracker,
4864                                    100_000_000 / 3,
4865                                    &inputs,
4866                                    &mut output_boxes,
4867                                    &mut output_tracks,
4868                                )
4869                                .unwrap()
4870                        };
4871                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4872                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4873                    } else {
4874                        let mut output_masks = Vec::with_capacity(50);
4875                        {
4876                            let inputs = vec![
4877                                boxes.view().into_dyn(),
4878                                scores.view().into_dyn(),
4879                                classes.view().into_dyn(),
4880                                mask.view().into_dyn(),
4881                                protos.view().into_dyn(),
4882                            ];
4883                            decoder
4884                                .decode_tracked_float(
4885                                    &mut tracker,
4886                                    0,
4887                                    &inputs,
4888                                    &mut output_boxes,
4889                                    &mut output_masks,
4890                                    &mut output_tracks,
4891                                )
4892                                .unwrap();
4893                        }
4894                        assert_eq!(output_boxes.len(), 1);
4895                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4896
4897                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4898                            *score = 0.0;
4899                        }
4900                        {
4901                            let inputs = vec![
4902                                boxes.view().into_dyn(),
4903                                scores.view().into_dyn(),
4904                                classes.view().into_dyn(),
4905                                mask.view().into_dyn(),
4906                                protos.view().into_dyn(),
4907                            ];
4908                            decoder
4909                                .decode_tracked_float(
4910                                    &mut tracker,
4911                                    100_000_000 / 3,
4912                                    &inputs,
4913                                    &mut output_boxes,
4914                                    &mut output_masks,
4915                                    &mut output_tracks,
4916                                )
4917                                .unwrap();
4918                        }
4919                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4920                        assert!(output_masks.is_empty());
4921                    }
4922                } else {
4923                    // Combined layout
4924                    let detect = ndarray::concatenate![
4925                        Axis(1),
4926                        boxes.view(),
4927                        scores.view(),
4928                        classes.view(),
4929                        mask.view()
4930                    ];
4931                    let mut detect = detect.insert_axis(Axis(0));
4932                    assert_eq!(detect.shape(), &[1, 10, 38]);
4933
4934                    if is_proto {
4935                        {
4936                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4937                            decoder
4938                                .decode_tracked_float_proto(
4939                                    &mut tracker,
4940                                    0,
4941                                    &inputs,
4942                                    &mut output_boxes,
4943                                    &mut output_tracks,
4944                                )
4945                                .unwrap();
4946                        }
4947                        assert_eq!(output_boxes.len(), 1);
4948                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4949
4950                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4951                            *score = 0.0;
4952                        }
4953                        let proto_result = {
4954                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4955                            decoder
4956                                .decode_tracked_float_proto(
4957                                    &mut tracker,
4958                                    100_000_000 / 3,
4959                                    &inputs,
4960                                    &mut output_boxes,
4961                                    &mut output_tracks,
4962                                )
4963                                .unwrap()
4964                        };
4965                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4966                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4967                    } else {
4968                        let mut output_masks = Vec::with_capacity(50);
4969                        {
4970                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4971                            decoder
4972                                .decode_tracked_float(
4973                                    &mut tracker,
4974                                    0,
4975                                    &inputs,
4976                                    &mut output_boxes,
4977                                    &mut output_masks,
4978                                    &mut output_tracks,
4979                                )
4980                                .unwrap();
4981                        }
4982                        assert_eq!(output_boxes.len(), 1);
4983                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4984
4985                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4986                            *score = 0.0;
4987                        }
4988                        {
4989                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4990                            decoder
4991                                .decode_tracked_float(
4992                                    &mut tracker,
4993                                    100_000_000 / 3,
4994                                    &inputs,
4995                                    &mut output_boxes,
4996                                    &mut output_masks,
4997                                    &mut output_tracks,
4998                                )
4999                                .unwrap();
5000                        }
5001                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5002                        assert!(output_masks.is_empty());
5003                    }
5004                }
5005            }
5006        };
5007    }
5008
5009    e2e_tracked_test!(
5010        test_decoder_tracked_end_to_end_segdet,
5011        quantized,
5012        combined,
5013        masks
5014    );
5015    e2e_tracked_test!(
5016        test_decoder_tracked_end_to_end_segdet_float,
5017        float,
5018        combined,
5019        masks
5020    );
5021    e2e_tracked_test!(
5022        test_decoder_tracked_end_to_end_segdet_proto,
5023        quantized,
5024        combined,
5025        proto
5026    );
5027    e2e_tracked_test!(
5028        test_decoder_tracked_end_to_end_segdet_proto_float,
5029        float,
5030        combined,
5031        proto
5032    );
5033    e2e_tracked_test!(
5034        test_decoder_tracked_end_to_end_segdet_split,
5035        quantized,
5036        split,
5037        masks
5038    );
5039    e2e_tracked_test!(
5040        test_decoder_tracked_end_to_end_segdet_split_float,
5041        float,
5042        split,
5043        masks
5044    );
5045    e2e_tracked_test!(
5046        test_decoder_tracked_end_to_end_segdet_split_proto,
5047        quantized,
5048        split,
5049        proto
5050    );
5051    e2e_tracked_test!(
5052        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5053        float,
5054        split,
5055        proto
5056    );
5057
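    // ─── End-to-end tracked TensorDyn test macro ────────────────────
    //
    // Same scenarios as `e2e_tracked_test`, but each input array is wrapped in
    // a `TensorDyn` so the public `decode_tracked` / `decode_proto_tracked` API
    // is exercised, instead of the ndarray-view `decode_tracked_quantized*` /
    // `decode_tracked_float*` methods.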
5062
5063    macro_rules! e2e_tracked_tensor_test {
5064        ($name:ident, quantized, $layout:ident, $output:ident) => {
5065            #[test]
5066            fn $name() {
5067                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5068
5069                let is_split = matches!(stringify!($layout), "split");
5070                let is_proto = matches!(stringify!($output), "proto");
5071
5072                let score_threshold = 0.45;
5073                let iou_threshold = 0.45;
5074
5075                let mut boxes = Array2::zeros((10, 4));
5076                let mut scores = Array2::zeros((10, 1));
5077                let mut classes = Array2::zeros((10, 1));
5078                let mask = Array2::zeros((10, 32));
5079                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5080                let protos_f64 = protos_f64.insert_axis(Axis(0));
5081                let protos_quant = (1.0 / 255.0, 0.0);
5082                let protos_u8: Array4<u8> =
5083                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5084
5085                boxes
5086                    .slice_mut(s![0, ..])
5087                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5088                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5089                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5090
5091                let detect_quant = (2.0 / 255.0, 0.0);
5092
5093                let decoder = if is_split {
5094                    DecoderBuilder::default()
5095                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5096                        .with_score_threshold(score_threshold)
5097                        .with_iou_threshold(iou_threshold)
5098                        .build()
5099                        .unwrap()
5100                } else {
5101                    DecoderBuilder::default()
5102                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5103                        .with_score_threshold(score_threshold)
5104                        .with_iou_threshold(iou_threshold)
5105                        .build()
5106                        .unwrap()
5107                };
5108
5109                // Helper to wrap a u8 slice into a TensorDyn
5110                let make_u8_tensor =
5111                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5112                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5113                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5114                        t.into()
5115                    };
5116
5117                let expected = e2e_expected_boxes_quant();
5118                let mut tracker = ByteTrackBuilder::new()
5119                    .track_update(0.1)
5120                    .track_high_conf(0.7)
5121                    .build();
5122                let mut output_boxes = Vec::with_capacity(50);
5123                let mut output_tracks = Vec::with_capacity(50);
5124
5125                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5126
5127                if is_split {
5128                    let boxes = boxes.insert_axis(Axis(0));
5129                    let scores = scores.insert_axis(Axis(0));
5130                    let classes = classes.insert_axis(Axis(0));
5131                    let mask = mask.insert_axis(Axis(0));
5132
5133                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5134                    let mut scores_q: Array3<u8> =
5135                        quantize_ndarray(scores.view(), detect_quant.into());
5136                    let classes_q: Array3<u8> =
5137                        quantize_ndarray(classes.view(), detect_quant.into());
5138                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5139
5140                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5141                    let classes_td =
5142                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5143                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5144
5145                    if is_proto {
5146                        let scores_td =
5147                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5148                        decoder
5149                            .decode_proto_tracked(
5150                                &mut tracker,
5151                                0,
5152                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5153                                &mut output_boxes,
5154                                &mut output_tracks,
5155                            )
5156                            .unwrap();
5157
5158                        assert_eq!(output_boxes.len(), 1);
5159                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5160
5161                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5162                            *score = u8::MIN;
5163                        }
5164                        let scores_td =
5165                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5166                        let proto_result = decoder
5167                            .decode_proto_tracked(
5168                                &mut tracker,
5169                                100_000_000 / 3,
5170                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5171                                &mut output_boxes,
5172                                &mut output_tracks,
5173                            )
5174                            .unwrap();
5175                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5176                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5177                    } else {
5178                        let scores_td =
5179                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5180                        let mut output_masks = Vec::with_capacity(50);
5181                        decoder
5182                            .decode_tracked(
5183                                &mut tracker,
5184                                0,
5185                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5186                                &mut output_boxes,
5187                                &mut output_masks,
5188                                &mut output_tracks,
5189                            )
5190                            .unwrap();
5191
5192                        assert_eq!(output_boxes.len(), 1);
5193                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5194
5195                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5196                            *score = u8::MIN;
5197                        }
5198                        let scores_td =
5199                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5200                        decoder
5201                            .decode_tracked(
5202                                &mut tracker,
5203                                100_000_000 / 3,
5204                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5205                                &mut output_boxes,
5206                                &mut output_masks,
5207                                &mut output_tracks,
5208                            )
5209                            .unwrap();
5210                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5211                        assert!(output_masks.is_empty());
5212                    }
5213                } else {
5214                    // Combined layout
5215                    let detect = ndarray::concatenate![
5216                        Axis(1),
5217                        boxes.view(),
5218                        scores.view(),
5219                        classes.view(),
5220                        mask.view()
5221                    ];
5222                    let detect = detect.insert_axis(Axis(0));
5223                    assert_eq!(detect.shape(), &[1, 10, 38]);
5224                    // Ensure contiguous layout after concatenation for as_slice()
5225                    let detect =
5226                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5227                            .unwrap();
5228                    let mut detect_q: Array3<u8> =
5229                        quantize_ndarray(detect.view(), detect_quant.into());
5230
5231                    if is_proto {
5232                        let detect_td =
5233                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5234                        decoder
5235                            .decode_proto_tracked(
5236                                &mut tracker,
5237                                0,
5238                                &[&detect_td, &protos_td],
5239                                &mut output_boxes,
5240                                &mut output_tracks,
5241                            )
5242                            .unwrap();
5243
5244                        assert_eq!(output_boxes.len(), 1);
5245                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5246
5247                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5248                            *score = u8::MIN;
5249                        }
5250                        let detect_td =
5251                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5252                        let proto_result = decoder
5253                            .decode_proto_tracked(
5254                                &mut tracker,
5255                                100_000_000 / 3,
5256                                &[&detect_td, &protos_td],
5257                                &mut output_boxes,
5258                                &mut output_tracks,
5259                            )
5260                            .unwrap();
5261                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5262                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5263                    } else {
5264                        let detect_td =
5265                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5266                        let mut output_masks = Vec::with_capacity(50);
5267                        decoder
5268                            .decode_tracked(
5269                                &mut tracker,
5270                                0,
5271                                &[&detect_td, &protos_td],
5272                                &mut output_boxes,
5273                                &mut output_masks,
5274                                &mut output_tracks,
5275                            )
5276                            .unwrap();
5277
5278                        assert_eq!(output_boxes.len(), 1);
5279                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5280
5281                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5282                            *score = u8::MIN;
5283                        }
5284                        let detect_td =
5285                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5286                        decoder
5287                            .decode_tracked(
5288                                &mut tracker,
5289                                100_000_000 / 3,
5290                                &[&detect_td, &protos_td],
5291                                &mut output_boxes,
5292                                &mut output_masks,
5293                                &mut output_tracks,
5294                            )
5295                            .unwrap();
5296                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5297                        assert!(output_masks.is_empty());
5298                    }
5299                }
5300            }
5301        };
5302        ($name:ident, float, $layout:ident, $output:ident) => {
5303            #[test]
5304            fn $name() {
5305                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5306
5307                let is_split = matches!(stringify!($layout), "split");
5308                let is_proto = matches!(stringify!($output), "proto");
5309
5310                let score_threshold = 0.45;
5311                let iou_threshold = 0.45;
5312
5313                let mut boxes = Array2::zeros((10, 4));
5314                let mut scores = Array2::zeros((10, 1));
5315                let mut classes = Array2::zeros((10, 1));
5316                let mask: Array2<f64> = Array2::zeros((10, 32));
5317                let protos = Array3::<f64>::zeros((160, 160, 32));
5318                let protos = protos.insert_axis(Axis(0));
5319
5320                boxes
5321                    .slice_mut(s![0, ..])
5322                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5323                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5324                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5325
5326                let decoder = if is_split {
5327                    DecoderBuilder::default()
5328                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5329                        .with_score_threshold(score_threshold)
5330                        .with_iou_threshold(iou_threshold)
5331                        .build()
5332                        .unwrap()
5333                } else {
5334                    DecoderBuilder::default()
5335                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5336                        .with_score_threshold(score_threshold)
5337                        .with_iou_threshold(iou_threshold)
5338                        .build()
5339                        .unwrap()
5340                };
5341
5342                // Helper to wrap an f64 slice into a TensorDyn
5343                let make_f64_tensor =
5344                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5345                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
5346                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5347                        t.into()
5348                    };
5349
5350                let expected = e2e_expected_boxes_float();
5351                let mut tracker = ByteTrackBuilder::new()
5352                    .track_update(0.1)
5353                    .track_high_conf(0.7)
5354                    .build();
5355                let mut output_boxes = Vec::with_capacity(50);
5356                let mut output_tracks = Vec::with_capacity(50);
5357
5358                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5359
5360                if is_split {
5361                    let boxes = boxes.insert_axis(Axis(0));
5362                    let mut scores = scores.insert_axis(Axis(0));
5363                    let classes = classes.insert_axis(Axis(0));
5364                    let mask = mask.insert_axis(Axis(0));
5365
5366                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5367                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5368                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5369
5370                    if is_proto {
5371                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5372                        decoder
5373                            .decode_proto_tracked(
5374                                &mut tracker,
5375                                0,
5376                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5377                                &mut output_boxes,
5378                                &mut output_tracks,
5379                            )
5380                            .unwrap();
5381
5382                        assert_eq!(output_boxes.len(), 1);
5383                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5384
5385                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5386                            *score = 0.0;
5387                        }
5388                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5389                        let proto_result = decoder
5390                            .decode_proto_tracked(
5391                                &mut tracker,
5392                                100_000_000 / 3,
5393                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5394                                &mut output_boxes,
5395                                &mut output_tracks,
5396                            )
5397                            .unwrap();
5398                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5399                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5400                    } else {
5401                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5402                        let mut output_masks = Vec::with_capacity(50);
5403                        decoder
5404                            .decode_tracked(
5405                                &mut tracker,
5406                                0,
5407                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5408                                &mut output_boxes,
5409                                &mut output_masks,
5410                                &mut output_tracks,
5411                            )
5412                            .unwrap();
5413
5414                        assert_eq!(output_boxes.len(), 1);
5415                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5416
5417                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5418                            *score = 0.0;
5419                        }
5420                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5421                        decoder
5422                            .decode_tracked(
5423                                &mut tracker,
5424                                100_000_000 / 3,
5425                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5426                                &mut output_boxes,
5427                                &mut output_masks,
5428                                &mut output_tracks,
5429                            )
5430                            .unwrap();
5431                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5432                        assert!(output_masks.is_empty());
5433                    }
5434                } else {
5435                    // Combined layout
5436                    let detect = ndarray::concatenate![
5437                        Axis(1),
5438                        boxes.view(),
5439                        scores.view(),
5440                        classes.view(),
5441                        mask.view()
5442                    ];
5443                    let detect = detect.insert_axis(Axis(0));
5444                    assert_eq!(detect.shape(), &[1, 10, 38]);
5445                    // Ensure contiguous layout after concatenation for as_slice()
5446                    let mut detect =
5447                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5448                            .unwrap();
5449
5450                    if is_proto {
5451                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5452                        decoder
5453                            .decode_proto_tracked(
5454                                &mut tracker,
5455                                0,
5456                                &[&detect_td, &protos_td],
5457                                &mut output_boxes,
5458                                &mut output_tracks,
5459                            )
5460                            .unwrap();
5461
5462                        assert_eq!(output_boxes.len(), 1);
5463                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5464
5465                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5466                            *score = 0.0;
5467                        }
5468                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5469                        let proto_result = decoder
5470                            .decode_proto_tracked(
5471                                &mut tracker,
5472                                100_000_000 / 3,
5473                                &[&detect_td, &protos_td],
5474                                &mut output_boxes,
5475                                &mut output_tracks,
5476                            )
5477                            .unwrap();
5478                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5479                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5480                    } else {
5481                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5482                        let mut output_masks = Vec::with_capacity(50);
5483                        decoder
5484                            .decode_tracked(
5485                                &mut tracker,
5486                                0,
5487                                &[&detect_td, &protos_td],
5488                                &mut output_boxes,
5489                                &mut output_masks,
5490                                &mut output_tracks,
5491                            )
5492                            .unwrap();
5493
5494                        assert_eq!(output_boxes.len(), 1);
5495                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5496
5497                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5498                            *score = 0.0;
5499                        }
5500                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5501                        decoder
5502                            .decode_tracked(
5503                                &mut tracker,
5504                                100_000_000 / 3,
5505                                &[&detect_td, &protos_td],
5506                                &mut output_boxes,
5507                                &mut output_masks,
5508                                &mut output_tracks,
5509                            )
5510                            .unwrap();
5511                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5512                        assert!(output_masks.is_empty());
5513                    }
5514                }
5515            }
5516        };
5517    }
5518
5519    e2e_tracked_tensor_test!(
5520        test_decoder_tracked_tensor_end_to_end_segdet,
5521        quantized,
5522        combined,
5523        masks
5524    );
5525    e2e_tracked_tensor_test!(
5526        test_decoder_tracked_tensor_end_to_end_segdet_float,
5527        float,
5528        combined,
5529        masks
5530    );
5531    e2e_tracked_tensor_test!(
5532        test_decoder_tracked_tensor_end_to_end_segdet_proto,
5533        quantized,
5534        combined,
5535        proto
5536    );
5537    e2e_tracked_tensor_test!(
5538        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
5539        float,
5540        combined,
5541        proto
5542    );
5543    e2e_tracked_tensor_test!(
5544        test_decoder_tracked_tensor_end_to_end_segdet_split,
5545        quantized,
5546        split,
5547        masks
5548    );
5549    e2e_tracked_tensor_test!(
5550        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
5551        float,
5552        split,
5553        masks
5554    );
5555    e2e_tracked_tensor_test!(
5556        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
5557        quantized,
5558        split,
5559        proto
5560    );
5561    e2e_tracked_tensor_test!(
5562        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
5563        float,
5564        split,
5565        proto
5566    );
5567
5568    #[test]
5569    fn test_decoder_tracked_linear_motion() {
5570        use crate::configs::{DecoderType, Nms};
5571        use crate::DecoderBuilder;
5572
5573        let score_threshold = 0.25;
5574        let iou_threshold = 0.1;
5575        let out = include_bytes!(concat!(
5576            env!("CARGO_MANIFEST_DIR"),
5577            "/../../testdata/yolov8s_80_classes.bin"
5578        ));
5579        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5580        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5581        let quant = (0.0040811873, -123).into();
5582
5583        let decoder = DecoderBuilder::default()
5584            .with_config_yolo_det(
5585                crate::configs::Detection {
5586                    decoder: DecoderType::Ultralytics,
5587                    shape: vec![1, 84, 8400],
5588                    anchors: None,
5589                    quantization: Some(quant),
5590                    dshape: vec![
5591                        (crate::configs::DimName::Batch, 1),
5592                        (crate::configs::DimName::NumFeatures, 84),
5593                        (crate::configs::DimName::NumBoxes, 8400),
5594                    ],
5595                    normalized: Some(true),
5596                },
5597                None,
5598            )
5599            .with_score_threshold(score_threshold)
5600            .with_iou_threshold(iou_threshold)
5601            .with_nms(Some(Nms::ClassAgnostic))
5602            .build()
5603            .unwrap();
5604
5605        let mut expected_boxes = [
5606            DetectBox {
5607                bbox: BoundingBox {
5608                    xmin: 0.5285137,
5609                    ymin: 0.05305544,
5610                    xmax: 0.87541467,
5611                    ymax: 0.9998909,
5612                },
5613                score: 0.5591227,
5614                label: 0,
5615            },
5616            DetectBox {
5617                bbox: BoundingBox {
5618                    xmin: 0.130598,
5619                    ymin: 0.43260583,
5620                    xmax: 0.35098213,
5621                    ymax: 0.9958097,
5622                },
5623                score: 0.33057618,
5624                label: 75,
5625            },
5626        ];
5627
5628        let mut tracker = ByteTrackBuilder::new()
5629            .track_update(0.1)
5630            .track_high_conf(0.3)
5631            .build();
5632
5633        let mut output_boxes = Vec::with_capacity(50);
5634        let mut output_masks = Vec::with_capacity(50);
5635        let mut output_tracks = Vec::with_capacity(50);
5636
5637        decoder
5638            .decode_tracked_quantized(
5639                &mut tracker,
5640                0,
5641                &[out.view().into()],
5642                &mut output_boxes,
5643                &mut output_masks,
5644                &mut output_tracks,
5645            )
5646            .unwrap();
5647
5648        assert_eq!(output_boxes.len(), 2);
5649        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5650        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5651
5652        for i in 1..=100 {
5653            let mut out = out.clone();
5654            // introduce linear movement into the XY coordinates
5655            let mut x_values = out.slice_mut(s![0, 0, ..]);
5656            for x in x_values.iter_mut() {
5657                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5658            }
5659
5660            decoder
5661                .decode_tracked_quantized(
5662                    &mut tracker,
5663                    100_000_000 * i / 3, // simulate 33.333ms between frames
5664                    &[out.view().into()],
5665                    &mut output_boxes,
5666                    &mut output_masks,
5667                    &mut output_tracks,
5668                )
5669                .unwrap();
5670
5671            assert_eq!(output_boxes.len(), 2);
5672        }
5673        let tracks = tracker.get_active_tracks();
5674        let predicted_boxes: Vec<_> = tracks
5675            .iter()
5676            .map(|track| {
5677                let mut l = track.last_box;
5678                l.bbox = track.info.tracked_location.into();
5679                l
5680            })
5681            .collect();
5682        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
5683        expected_boxes[0].bbox.xmax += 0.1;
5684        expected_boxes[1].bbox.xmin += 0.1;
5685        expected_boxes[1].bbox.xmax += 0.1;
5686
5687        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5688        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5689
5690        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5691        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5692        for score in scores_values.iter_mut() {
5693            *score = i8::MIN; // set all scores to minimum to simulate no detections
5694        }
5695        decoder
5696            .decode_tracked_quantized(
5697                &mut tracker,
5698                100_000_000 * 101 / 3,
5699                &[out.view().into()],
5700                &mut output_boxes,
5701                &mut output_masks,
5702                &mut output_tracks,
5703            )
5704            .unwrap();
5705        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
5706        expected_boxes[0].bbox.xmax += 0.001;
5707        expected_boxes[1].bbox.xmin += 0.001;
5708        expected_boxes[1].bbox.xmax += 0.001;
5709
5710        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5711        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5712    }
5713
5714    #[test]
5715    fn test_decoder_tracked_end_to_end_float() {
5716        let score_threshold = 0.45;
5717        let iou_threshold = 0.45;
5718
5719        let mut boxes = Array2::zeros((10, 4));
5720        let mut scores = Array2::zeros((10, 1));
5721        let mut classes = Array2::zeros((10, 1));
5722
5723        boxes
5724            .slice_mut(s![0, ..,])
5725            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5726        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5727        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5728
5729        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5730        let mut detect = detect.insert_axis(Axis(0));
5731        assert_eq!(detect.shape(), &[1, 10, 6]);
5732        let config = "
5733decoder_version: yolo26
5734outputs:
5735 - type: detection
5736   decoder: ultralytics
5737   quantization: [0.00784313725490196, 0]
5738   shape: [1, 10, 6]
5739   dshape:
5740    - [batch, 1]
5741    - [num_boxes, 10]
5742    - [num_features, 6]
5743   normalized: true
5744";
5745
5746        let decoder = DecoderBuilder::default()
5747            .with_config_yaml_str(config.to_string())
5748            .with_score_threshold(score_threshold)
5749            .with_iou_threshold(iou_threshold)
5750            .build()
5751            .unwrap();
5752
5753        let expected_boxes = [DetectBox {
5754            bbox: BoundingBox {
5755                xmin: 0.1234,
5756                ymin: 0.1234,
5757                xmax: 0.2345,
5758                ymax: 0.2345,
5759            },
5760            score: 0.9876,
5761            label: 2,
5762        }];
5763
5764        let mut tracker = ByteTrackBuilder::new()
5765            .track_update(0.1)
5766            .track_high_conf(0.7)
5767            .build();
5768
5769        let mut output_boxes = Vec::with_capacity(50);
5770        let mut output_masks = Vec::with_capacity(50);
5771        let mut output_tracks = Vec::with_capacity(50);
5772
5773        decoder
5774            .decode_tracked_float(
5775                &mut tracker,
5776                0,
5777                &[detect.view().into_dyn()],
5778                &mut output_boxes,
5779                &mut output_masks,
5780                &mut output_tracks,
5781            )
5782            .unwrap();
5783
5784        assert_eq!(output_boxes.len(), 1);
5785        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5786
5787        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5788
5789        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5790            *score = 0.0; // set all scores to minimum to simulate no detections
5791        }
5792
5793        decoder
5794            .decode_tracked_float(
5795                &mut tracker,
5796                100_000_000 / 3,
5797                &[detect.view().into_dyn()],
5798                &mut output_boxes,
5799                &mut output_masks,
5800                &mut output_tracks,
5801            )
5802            .unwrap();
5803        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5804    }
5805}