Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and for ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
- `Decoder`/`DecoderBuilder` structs: Provide high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, when the integer types are
supported, it is recommended not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod schema;
77pub mod yolo;
78
79mod decoder;
80pub use decoder::*;
81
82pub use configs::{DecoderVersion, Nms};
83pub use error::{DecoderError, DecoderResult};
84
85use crate::{
86    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
87    yolo::yolo_segmentation_to_mask,
88};
89
90/// Trait to convert bounding box formats to XYXY float format
91pub trait BBoxTypeTrait {
92    /// Converts the bbox into XYXY float format.
93    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
94
95    /// Converts the bbox into XYXY float format.
96    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
97        input: &[B; 4],
98        quant: Quantization,
99    ) -> [A; 4]
100    where
101        f32: AsPrimitive<A>,
102        i32: AsPrimitive<A>;
103
104    /// Converts the bbox into XYXY float format.
105    ///
106    /// # Examples
107    /// ```rust
108    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
109    /// # use ndarray::array;
110    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
111    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
112    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
113    /// ```
114    #[inline(always)]
115    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
116        input: ArrayView1<B>,
117    ) -> [A; 4] {
118        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
119    }
120
121    #[inline(always)]
122    /// Converts the bbox into XYXY float format.
123    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
124        input: ArrayView1<B>,
125        quant: Quantization,
126    ) -> [A; 4]
127    where
128        f32: AsPrimitive<A>,
129        i32: AsPrimitive<A>,
130    {
131        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
132    }
133}
134
135/// Converts XYXY bounding boxes to XYXY
136#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137pub struct XYXY {}
138
139impl BBoxTypeTrait for XYXY {
140    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
141        input.map(|b| b.as_())
142    }
143
144    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
145        input: &[B; 4],
146        quant: Quantization,
147    ) -> [A; 4]
148    where
149        f32: AsPrimitive<A>,
150        i32: AsPrimitive<A>,
151    {
152        let scale = quant.scale.as_();
153        let zp = quant.zero_point.as_();
154        input.map(|b| (b.as_() - zp) * scale)
155    }
156
157    #[inline(always)]
158    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
159        input: ArrayView1<B>,
160    ) -> [A; 4] {
161        [
162            input[0].as_(),
163            input[1].as_(),
164            input[2].as_(),
165            input[3].as_(),
166        ]
167    }
168}
169
170/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
171/// box
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub struct XYWH {}
174
175impl BBoxTypeTrait for XYWH {
176    #[inline(always)]
177    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
178        let half = A::one() / (A::one() + A::one());
179        [
180            (input[0].as_()) - (input[2].as_() * half),
181            (input[1].as_()) - (input[3].as_() * half),
182            (input[0].as_()) + (input[2].as_() * half),
183            (input[1].as_()) + (input[3].as_() * half),
184        ]
185    }
186
187    #[inline(always)]
188    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
189        input: &[B; 4],
190        quant: Quantization,
191    ) -> [A; 4]
192    where
193        f32: AsPrimitive<A>,
194        i32: AsPrimitive<A>,
195    {
196        let scale = quant.scale.as_();
197        let half_scale = (quant.scale * 0.5).as_();
198        let zp = quant.zero_point.as_();
199        let [x, y, w, h] = [
200            (input[0].as_() - zp) * scale,
201            (input[1].as_() - zp) * scale,
202            (input[2].as_() - zp) * half_scale,
203            (input[3].as_() - zp) * half_scale,
204        ];
205
206        [x - w, y - h, x + w, y + h]
207    }
208
209    #[inline(always)]
210    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
211        input: ArrayView1<B>,
212    ) -> [A; 4] {
213        let half = A::one() / (A::one() + A::one());
214        [
215            (input[0].as_()) - (input[2].as_() * half),
216            (input[1].as_()) - (input[3].as_() * half),
217            (input[0].as_()) + (input[2].as_() * half),
218            (input[1].as_()) + (input[3].as_() * half),
219        ]
220    }
221}
222
223/// Describes the quantization parameters for a tensor
224#[derive(Debug, Clone, Copy, PartialEq)]
225pub struct Quantization {
226    pub scale: f32,
227    pub zero_point: i32,
228}
229
230impl Quantization {
231    /// Creates a new Quantization struct
232    /// # Examples
233    /// ```
234    /// # use edgefirst_decoder::Quantization;
235    /// let quant = Quantization::new(0.1, -128);
236    /// assert_eq!(quant.scale, 0.1);
237    /// assert_eq!(quant.zero_point, -128);
238    /// ```
239    pub fn new(scale: f32, zero_point: i32) -> Self {
240        Self { scale, zero_point }
241    }
242}
243
244impl From<QuantTuple> for Quantization {
245    /// Creates a new Quantization struct from a QuantTuple
246    /// # Examples
247    /// ```
248    /// # use edgefirst_decoder::Quantization;
249    /// # use edgefirst_decoder::configs::QuantTuple;
250    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
251    /// let quant = Quantization::from(quant_tuple);
252    /// assert_eq!(quant.scale, 0.1);
253    /// assert_eq!(quant.zero_point, -128);
254    /// ```
255    fn from(quant_tuple: QuantTuple) -> Quantization {
256        Quantization {
257            scale: quant_tuple.0,
258            zero_point: quant_tuple.1,
259        }
260    }
261}
262
263impl<S, Z> From<(S, Z)> for Quantization
264where
265    S: AsPrimitive<f32>,
266    Z: AsPrimitive<i32>,
267{
268    /// Creates a new Quantization struct from a tuple
269    /// # Examples
270    /// ```
271    /// # use edgefirst_decoder::Quantization;
272    /// let quant = Quantization::from((0.1_f64, -128_i64));
273    /// assert_eq!(quant.scale, 0.1);
274    /// assert_eq!(quant.zero_point, -128);
275    /// ```
276    fn from((scale, zp): (S, Z)) -> Quantization {
277        Self {
278            scale: scale.as_(),
279            zero_point: zp.as_(),
280        }
281    }
282}
283
284impl Default for Quantization {
285    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
286    /// # Examples
287    /// ```rust
288    /// # use edgefirst_decoder::Quantization;
289    /// let quant = Quantization::default();
290    /// assert_eq!(quant.scale, 1.0);
291    /// assert_eq!(quant.zero_point, 0);
292    /// ```
293    fn default() -> Self {
294        Self {
295            scale: 1.0,
296            zero_point: 0,
297        }
298    }
299}
300
301/// A detection box with f32 bbox and score
302#[derive(Debug, Clone, Copy, PartialEq, Default)]
303pub struct DetectBox {
304    pub bbox: BoundingBox,
305    /// model-specific score for this detection, higher implies more confidence
306    pub score: f32,
307    /// label index for this detection
308    pub label: usize,
309}
310
/// An axis-aligned bounding box with `f32` corner coordinates stored in
/// `[xmin, ymin, xmax, ymax]` (XYXY) order.
#[derive(Clone, Copy, Debug, Default, PartialEq)]
pub struct BoundingBox {
    /// Left edge, as a normalized coordinate.
    pub xmin: f32,
    /// Top edge, as a normalized coordinate.
    pub ymin: f32,
    /// Right edge, as a normalized coordinate.
    pub xmax: f32,
    /// Bottom edge, as a normalized coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its four corner coordinates, stored exactly as given
    /// (no reordering is performed).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        Self {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy with each axis ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (ymin, ymax) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(xmin, ymin, xmax, ymax)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens a box into `[xmin, ymin, xmax, ymax]`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a box from an array laid out as `[xmin, ymin, xmax, ymax]`.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
389
390impl DetectBox {
391    /// Returns true if one detect box is equal to another detect box, within
392    /// the given `eps`
393    ///
394    /// # Examples
395    /// ```
396    /// # use edgefirst_decoder::DetectBox;
397    /// let box1 = DetectBox {
398    ///     bbox: edgefirst_decoder::BoundingBox {
399    ///         xmin: 0.1,
400    ///         ymin: 0.2,
401    ///         xmax: 0.3,
402    ///         ymax: 0.4,
403    ///     },
404    ///     score: 0.5,
405    ///     label: 1,
406    /// };
407    /// let box2 = DetectBox {
408    ///     bbox: edgefirst_decoder::BoundingBox {
409    ///         xmin: 0.101,
410    ///         ymin: 0.199,
411    ///         xmax: 0.301,
412    ///         ymax: 0.399,
413    ///     },
414    ///     score: 0.510,
415    ///     label: 1,
416    /// };
417    /// assert!(box1.equal_within_delta(&box2, 0.011));
418    /// ```
419    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
420        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
421        self.label == rhs.label
422            && eq_delta(self.score, rhs.score)
423            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
424            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
425            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
426            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
427    }
428}
429
430/// A segmentation result with a segmentation mask, and a normalized bounding
431/// box representing the area that the segmentation mask covers
432#[derive(Debug, Clone, PartialEq, Default)]
433pub struct Segmentation {
434    /// left-most normalized coordinate of the segmentation box
435    pub xmin: f32,
436    /// top-most normalized coordinate of the segmentation box
437    pub ymin: f32,
438    /// right-most normalized coordinate of the segmentation box
439    pub xmax: f32,
440    /// bottom-most normalized coordinate of the segmentation box
441    pub ymax: f32,
442    /// 3D segmentation array of shape `(H, W, C)`.
443    ///
444    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
445    /// continuous sigmoid confidence values quantized to u8 (0 = background,
446    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
447    /// 0.5) or use smooth interpolation for anti-aliased edges.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
454/// Raw prototype data for fused decode+render pipelines.
455///
456/// Holds post-NMS intermediate state before mask materialization, allowing the
457/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
458/// shader) without materializing intermediate `Array3<u8>` masks.
459///
460/// Both fields are carried as [`TensorDyn`] so downstream consumers (Rust, C
461/// API, Python) get zero-copy typed access through the HAL's shared tensor
462/// infrastructure. Dtype policy:
463///
464/// | Source model | protos dtype | mask_coefficients dtype | protos.quantization |
465/// |---|---|---|---|
466/// | int8 quantized | [`TensorDyn::I8`] | [`TensorDyn::F32`] (dequantized) | `Some(q)` |
467/// | f32 | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
468/// | f16 (TensorRT fp16) | [`TensorDyn::F16`] | [`TensorDyn::F16`] | `None` |
469/// | f64 (narrowed) | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
470///
471/// Quantization metadata lives on the proto tensor itself via
472/// [`edgefirst_tensor::Tensor::quantization`] — float tensors cannot carry
473/// quantization (compile-time gated on the `IntegerType` sealed trait).
474///
475/// `TensorDyn` is not `Clone`, so neither is `ProtoData`. Consumers that need
476/// to share the proto buffer across threads should use `TensorDyn::clone_fd`
477/// / `dmabuf_clone` to dup the backing fd.
478#[derive(Debug)]
479pub struct ProtoData {
480    /// Per-detection mask coefficients, shape `[num_detections, num_protos]`.
481    pub mask_coefficients: edgefirst_tensor::TensorDyn,
482    /// Prototype tensor, shape `[proto_h, proto_w, num_protos]`.
483    pub protos: edgefirst_tensor::TensorDyn,
484}
485
486/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
487///
488///  # Examples
489/// ```
490/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
491/// let quant = Quantization::new(0.1, -128);
492/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
493/// let detect_quant = DetectBoxQuantized {
494///     bbox,
495///     score: 100_i8,
496///     label: 1,
497/// };
498/// let detect = dequant_detect_box(&detect_quant, quant);
499/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
500/// assert_eq!(detect.label, 1);
501/// assert_eq!(detect.bbox, bbox);
502/// ```
503pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
504    detect: &DetectBoxQuantized<SCORE>,
505    quant_scores: Quantization,
506) -> DetectBox {
507    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
508    DetectBox {
509        bbox: detect.bbox,
510        score: quant_scores.scale * detect.score.as_() + scaled_zp,
511        label: detect.label,
512    }
513}
514/// A detection box with a f32 bbox and quantized score
515#[derive(Debug, Clone, Copy, PartialEq)]
516pub struct DetectBoxQuantized<
517    // BOX: Signed + PrimInt + AsPrimitive<f32>,
518    SCORE: PrimInt + AsPrimitive<f32>,
519> {
520    // pub bbox: BoundingBoxQuantized<BOX>,
521    pub bbox: BoundingBox,
522    /// model-specific score for this detection, higher implies more
523    /// confidence.
524    pub score: SCORE,
525    /// label index for this detect
526    pub label: usize,
527}
528
529/// Dequantizes an ndarray from quantized values to f32 values using the given
530/// quantization parameters
531///
532/// # Examples
533/// ```
534/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
535/// let quant = Quantization::new(0.1, -128);
536/// let input: Vec<i8> = vec![0, 127, -128, 64];
537/// let input_array = ndarray::Array1::from(input);
538/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
539/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
540/// ```
541pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
542    input: ArrayView<T, D>,
543    quant: Quantization,
544) -> Array<F, D>
545where
546    i32: num_traits::AsPrimitive<F>,
547    f32: num_traits::AsPrimitive<F>,
548{
549    let zero_point = quant.zero_point.as_();
550    let scale = quant.scale.as_();
551    if zero_point != F::zero() {
552        let scaled_zero = -zero_point * scale;
553        input.mapv(|d| d.as_() * scale + scaled_zero)
554    } else {
555        input.mapv(|d| d.as_() * scale)
556    }
557}
558
559/// Dequantizes a slice from quantized values to float values using the given
560/// quantization parameters
561///
562/// # Examples
563/// ```
564/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
565/// let quant = Quantization::new(0.1, -128);
566/// let input: Vec<i8> = vec![0, 127, -128, 64];
567/// let mut output: Vec<f32> = vec![0.0; input.len()];
568/// dequantize_cpu(&input, quant, &mut output);
569/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
570/// ```
571pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
572    input: &[T],
573    quant: Quantization,
574    output: &mut [F],
575) where
576    f32: num_traits::AsPrimitive<F>,
577    i32: num_traits::AsPrimitive<F>,
578{
579    assert!(input.len() == output.len());
580    let zero_point = quant.zero_point.as_();
581    let scale = quant.scale.as_();
582    if zero_point != F::zero() {
583        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
584        input
585            .iter()
586            .zip(output)
587            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
588    } else {
589        input
590            .iter()
591            .zip(output)
592            .for_each(|(d, deq)| *deq = d.as_() * scale);
593    }
594}
595
596/// Dequantizes a slice from quantized values to float values using the given
597/// quantization parameters, using chunked processing. This is around 5% faster
598/// than `dequantize_cpu` for large slices.
599///
600/// # Examples
601/// ```
602/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
603/// let quant = Quantization::new(0.1, -128);
604/// let input: Vec<i8> = vec![0, 127, -128, 64];
605/// let mut output: Vec<f32> = vec![0.0; input.len()];
606/// dequantize_cpu_chunked(&input, quant, &mut output);
607/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
608/// ```
609pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
610    input: &[T],
611    quant: Quantization,
612    output: &mut [F],
613) where
614    f32: num_traits::AsPrimitive<F>,
615    i32: num_traits::AsPrimitive<F>,
616{
617    assert!(input.len() == output.len());
618    let zero_point = quant.zero_point.as_();
619    let scale = quant.scale.as_();
620
621    let input = input.as_chunks::<4>();
622    let output = output.as_chunks_mut::<4>();
623
624    if zero_point != F::zero() {
625        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
626
627        input
628            .0
629            .iter()
630            .zip(output.0)
631            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
632        input
633            .1
634            .iter()
635            .zip(output.1)
636            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
637    } else {
638        input
639            .0
640            .iter()
641            .zip(output.0)
642            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
643        input
644            .1
645            .iter()
646            .zip(output.1)
647            .for_each(|(d, deq)| *deq = d.as_() * scale);
648    }
649}
650
651/// Converts a segmentation tensor into a 2D mask
652/// If the last dimension of the segmentation tensor is 1, values equal or
653/// above 128 are considered objects. Otherwise the object is the argmax index
654///
655/// # Errors
656///
657/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
658/// invalid shape.
659///
660/// # Examples
661/// ```
662/// # use edgefirst_decoder::segmentation_to_mask;
663/// let segmentation =
664///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
665/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
666/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
667/// ```
668pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
669    if segmentation.shape()[2] == 0 {
670        return Err(DecoderError::InvalidShape(
671            "Segmentation tensor must have non-zero depth".to_string(),
672        ));
673    }
674    if segmentation.shape()[2] == 1 {
675        yolo_segmentation_to_mask(segmentation, 128)
676    } else {
677        Ok(modelpack_segmentation_to_mask(segmentation))
678    }
679}
680
681/// Returns the maximum value and its index from a 1D array
682fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
683    score
684        .iter()
685        .enumerate()
686        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
687            if max > *s {
688                (max, arg_max)
689            } else {
690                (*s, ind)
691            }
692        })
693}
694#[cfg(test)]
695#[cfg_attr(coverage_nightly, coverage(off))]
696mod decoder_tests {
697    #![allow(clippy::excessive_precision)]
698    use crate::{
699        configs::{DecoderType, DimName, Protos},
700        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
701        yolo::{
702            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
703            decode_yolo_segdet_quant,
704        },
705        *,
706    };
707    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
708    use ndarray::Dimension;
709    use ndarray::{array, s, Array2, Array3, Array4, Axis};
710    use ndarray_stats::DeviationExt;
711    use num_traits::{AsPrimitive, PrimInt};
712
713    fn compare_outputs(
714        boxes: (&[DetectBox], &[DetectBox]),
715        masks: (&[Segmentation], &[Segmentation]),
716    ) {
717        let (boxes0, boxes1) = boxes;
718        let (masks0, masks1) = masks;
719
720        assert_eq!(boxes0.len(), boxes1.len());
721        assert_eq!(masks0.len(), masks1.len());
722
723        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
724            assert!(
725                b_i8.equal_within_delta(b_f32, 1e-6),
726                "{b_i8:?} is not equal to {b_f32:?}"
727            );
728        }
729
730        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
731            assert_eq!(
732                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
733                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
734            );
735            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
736            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
737            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
738            let diff = &mask_i8 - &mask_f32;
739            for x in 0..diff.shape()[0] {
740                for y in 0..diff.shape()[1] {
741                    for z in 0..diff.shape()[2] {
742                        let val = diff[[x, y, z]];
743                        assert!(
744                            val.abs() <= 1,
745                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
746                            x,
747                            y,
748                            z,
749                            val
750                        );
751                    }
752                }
753            }
754            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
755            assert!(
756                mean_sq_err < 1e-2,
757                "Mean Square Error between masks was greater than 1%: {:.2}%",
758                mean_sq_err * 100.0
759            );
760        }
761    }
762
763    // ─── Shared test data loaders ────────────────────────
764
765    fn load_yolov8_boxes() -> Array3<i8> {
766        let raw = include_bytes!(concat!(
767            env!("CARGO_MANIFEST_DIR"),
768            "/../../testdata/yolov8_boxes_116x8400.bin"
769        ));
770        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
771        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
772    }
773
774    fn load_yolov8_protos() -> Array4<i8> {
775        let raw = include_bytes!(concat!(
776            env!("CARGO_MANIFEST_DIR"),
777            "/../../testdata/yolov8_protos_160x160x32.bin"
778        ));
779        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
780        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
781    }
782
783    fn load_yolov8s_det() -> Array3<i8> {
784        let raw = include_bytes!(concat!(
785            env!("CARGO_MANIFEST_DIR"),
786            "/../../testdata/yolov8s_80_classes.bin"
787        ));
788        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
789        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
790    }
791
792    #[test]
793    fn test_decoder_modelpack() {
794        let score_threshold = 0.45;
795        let iou_threshold = 0.45;
796        let boxes = include_bytes!(concat!(
797            env!("CARGO_MANIFEST_DIR"),
798            "/../../testdata/modelpack_boxes_1935x1x4.bin"
799        ));
800        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
801
802        let scores = include_bytes!(concat!(
803            env!("CARGO_MANIFEST_DIR"),
804            "/../../testdata/modelpack_scores_1935x1.bin"
805        ));
806        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
807
808        let quant_boxes = (0.004656755365431309, 21).into();
809        let quant_scores = (0.0019603664986789227, 0).into();
810
811        let decoder = DecoderBuilder::default()
812            .with_config_modelpack_det(
813                configs::Boxes {
814                    decoder: DecoderType::ModelPack,
815                    quantization: Some(quant_boxes),
816                    shape: vec![1, 1935, 1, 4],
817                    dshape: vec![
818                        (DimName::Batch, 1),
819                        (DimName::NumBoxes, 1935),
820                        (DimName::Padding, 1),
821                        (DimName::BoxCoords, 4),
822                    ],
823                    normalized: Some(true),
824                },
825                configs::Scores {
826                    decoder: DecoderType::ModelPack,
827                    quantization: Some(quant_scores),
828                    shape: vec![1, 1935, 1],
829                    dshape: vec![
830                        (DimName::Batch, 1),
831                        (DimName::NumBoxes, 1935),
832                        (DimName::NumClasses, 1),
833                    ],
834                },
835            )
836            .with_score_threshold(score_threshold)
837            .with_iou_threshold(iou_threshold)
838            .build()
839            .unwrap();
840
841        let quant_boxes = quant_boxes.into();
842        let quant_scores = quant_scores.into();
843
844        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
845        decode_modelpack_det(
846            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
847            (scores.slice(s![0, .., ..]), quant_scores),
848            score_threshold,
849            iou_threshold,
850            &mut output_boxes,
851        );
852        assert!(output_boxes[0].equal_within_delta(
853            &DetectBox {
854                bbox: BoundingBox {
855                    xmin: 0.40513772,
856                    ymin: 0.6379755,
857                    xmax: 0.5122431,
858                    ymax: 0.7730214,
859                },
860                score: 0.4861709,
861                label: 0
862            },
863            1e-6
864        ));
865
866        let mut output_boxes1 = Vec::with_capacity(50);
867        let mut output_masks1 = Vec::with_capacity(50);
868
869        decoder
870            .decode_quantized(
871                &[boxes.view().into(), scores.view().into()],
872                &mut output_boxes1,
873                &mut output_masks1,
874            )
875            .unwrap();
876
877        let mut output_boxes_float = Vec::with_capacity(50);
878        let mut output_masks_float = Vec::with_capacity(50);
879
880        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
881        let scores = dequantize_ndarray(scores.view(), quant_scores);
882
883        decoder
884            .decode_float::<f32>(
885                &[boxes.view().into_dyn(), scores.view().into_dyn()],
886                &mut output_boxes_float,
887                &mut output_masks_float,
888            )
889            .unwrap();
890
891        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
892        compare_outputs(
893            (&output_boxes, &output_boxes_float),
894            (&[], &output_masks_float),
895        );
896    }
897
898    #[test]
899    fn test_decoder_modelpack_split_u8() {
900        let score_threshold = 0.45;
901        let iou_threshold = 0.45;
902        let detect0 = include_bytes!(concat!(
903            env!("CARGO_MANIFEST_DIR"),
904            "/../../testdata/modelpack_split_9x15x18.bin"
905        ));
906        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
907
908        let detect1 = include_bytes!(concat!(
909            env!("CARGO_MANIFEST_DIR"),
910            "/../../testdata/modelpack_split_17x30x18.bin"
911        ));
912        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
913
914        let quant0 = (0.08547406643629074, 174).into();
915        let quant1 = (0.09929127991199493, 183).into();
916        let anchors0 = vec![
917            [0.36666667461395264, 0.31481480598449707],
918            [0.38749998807907104, 0.4740740656852722],
919            [0.5333333611488342, 0.644444465637207],
920        ];
921        let anchors1 = vec![
922            [0.13750000298023224, 0.2074074000120163],
923            [0.2541666626930237, 0.21481481194496155],
924            [0.23125000298023224, 0.35185185074806213],
925        ];
926
927        let detect_config0 = configs::Detection {
928            decoder: DecoderType::ModelPack,
929            shape: vec![1, 9, 15, 18],
930            anchors: Some(anchors0.clone()),
931            quantization: Some(quant0),
932            dshape: vec![
933                (DimName::Batch, 1),
934                (DimName::Height, 9),
935                (DimName::Width, 15),
936                (DimName::NumAnchorsXFeatures, 18),
937            ],
938            normalized: Some(true),
939        };
940
941        let detect_config1 = configs::Detection {
942            decoder: DecoderType::ModelPack,
943            shape: vec![1, 17, 30, 18],
944            anchors: Some(anchors1.clone()),
945            quantization: Some(quant1),
946            dshape: vec![
947                (DimName::Batch, 1),
948                (DimName::Height, 17),
949                (DimName::Width, 30),
950                (DimName::NumAnchorsXFeatures, 18),
951            ],
952            normalized: Some(true),
953        };
954
955        let config0 = (&detect_config0).try_into().unwrap();
956        let config1 = (&detect_config1).try_into().unwrap();
957
958        let decoder = DecoderBuilder::default()
959            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
960            .with_score_threshold(score_threshold)
961            .with_iou_threshold(iou_threshold)
962            .build()
963            .unwrap();
964
965        let quant0 = quant0.into();
966        let quant1 = quant1.into();
967
968        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
969        decode_modelpack_split_quant(
970            &[
971                detect0.slice(s![0, .., .., ..]),
972                detect1.slice(s![0, .., .., ..]),
973            ],
974            &[config0, config1],
975            score_threshold,
976            iou_threshold,
977            &mut output_boxes,
978        );
979        assert!(output_boxes[0].equal_within_delta(
980            &DetectBox {
981                bbox: BoundingBox {
982                    xmin: 0.43171933,
983                    ymin: 0.68243736,
984                    xmax: 0.5626645,
985                    ymax: 0.808863,
986                },
987                score: 0.99240804,
988                label: 0
989            },
990            1e-6
991        ));
992
993        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
994        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
995        decoder
996            .decode_quantized(
997                &[detect0.view().into(), detect1.view().into()],
998                &mut output_boxes1,
999                &mut output_masks1,
1000            )
1001            .unwrap();
1002
1003        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1004        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1005
1006        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1007        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1008        decoder
1009            .decode_float::<f32>(
1010                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1011                &mut output_boxes1_f32,
1012                &mut output_masks1_f32,
1013            )
1014            .unwrap();
1015
1016        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1017        compare_outputs(
1018            (&output_boxes, &output_boxes1_f32),
1019            (&[], &output_masks1_f32),
1020        );
1021    }
1022
1023    #[test]
1024    fn test_decoder_parse_config_modelpack_split_u8() {
1025        let score_threshold = 0.45;
1026        let iou_threshold = 0.45;
1027        let detect0 = include_bytes!(concat!(
1028            env!("CARGO_MANIFEST_DIR"),
1029            "/../../testdata/modelpack_split_9x15x18.bin"
1030        ));
1031        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1032
1033        let detect1 = include_bytes!(concat!(
1034            env!("CARGO_MANIFEST_DIR"),
1035            "/../../testdata/modelpack_split_17x30x18.bin"
1036        ));
1037        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1038
1039        let decoder = DecoderBuilder::default()
1040            .with_config_yaml_str(
1041                include_str!(concat!(
1042                    env!("CARGO_MANIFEST_DIR"),
1043                    "/../../testdata/modelpack_split.yaml"
1044                ))
1045                .to_string(),
1046            )
1047            .with_score_threshold(score_threshold)
1048            .with_iou_threshold(iou_threshold)
1049            .build()
1050            .unwrap();
1051
1052        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1053        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1054        decoder
1055            .decode_quantized(
1056                &[
1057                    ArrayViewDQuantized::from(detect1.view()),
1058                    ArrayViewDQuantized::from(detect0.view()),
1059                ],
1060                &mut output_boxes,
1061                &mut output_masks,
1062            )
1063            .unwrap();
1064        assert!(output_boxes[0].equal_within_delta(
1065            &DetectBox {
1066                bbox: BoundingBox {
1067                    xmin: 0.43171933,
1068                    ymin: 0.68243736,
1069                    xmax: 0.5626645,
1070                    ymax: 0.808863,
1071                },
1072                score: 0.99240804,
1073                label: 0
1074            },
1075            1e-6
1076        ));
1077    }
1078
1079    #[test]
1080    fn test_modelpack_seg() {
1081        let out = include_bytes!(concat!(
1082            env!("CARGO_MANIFEST_DIR"),
1083            "/../../testdata/modelpack_seg_2x160x160.bin"
1084        ));
1085        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1086        let quant = (1.0 / 255.0, 0).into();
1087
1088        let decoder = DecoderBuilder::default()
1089            .with_config_modelpack_seg(configs::Segmentation {
1090                decoder: DecoderType::ModelPack,
1091                quantization: Some(quant),
1092                shape: vec![1, 2, 160, 160],
1093                dshape: vec![
1094                    (DimName::Batch, 1),
1095                    (DimName::NumClasses, 2),
1096                    (DimName::Height, 160),
1097                    (DimName::Width, 160),
1098                ],
1099            })
1100            .build()
1101            .unwrap();
1102        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1103        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1104        decoder
1105            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1106            .unwrap();
1107
1108        let mut mask = out.slice(s![0, .., .., ..]);
1109        mask.swap_axes(0, 1);
1110        mask.swap_axes(1, 2);
1111        let mask = [Segmentation {
1112            xmin: 0.0,
1113            ymin: 0.0,
1114            xmax: 1.0,
1115            ymax: 1.0,
1116            segmentation: mask.into_owned(),
1117        }];
1118        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1119
1120        decoder
1121            .decode_float::<f32>(
1122                &[dequantize_ndarray(out.view(), quant.into())
1123                    .view()
1124                    .into_dyn()],
1125                &mut output_boxes,
1126                &mut output_masks,
1127            )
1128            .unwrap();
1129
1130        // not expected for float decoder to have same values as quantized decoder, as
1131        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1132        // the model output. Thus the float output is the same as the quantized output
1133        // but scaled differently. However, it is expected that the mask after argmax
1134        // will be the same.
1135        compare_outputs((&[], &output_boxes), (&[], &[]));
1136        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1137        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1138
1139        assert_eq!(mask0, mask1);
1140    }
1141    #[test]
1142    fn test_modelpack_seg_quant() {
1143        let out = include_bytes!(concat!(
1144            env!("CARGO_MANIFEST_DIR"),
1145            "/../../testdata/modelpack_seg_2x160x160.bin"
1146        ));
1147        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1148        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1149        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1150        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1151        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1152        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1153
1154        let quant = (1.0 / 255.0, 0).into();
1155
1156        let decoder = DecoderBuilder::default()
1157            .with_config_modelpack_seg(configs::Segmentation {
1158                decoder: DecoderType::ModelPack,
1159                quantization: Some(quant),
1160                shape: vec![1, 2, 160, 160],
1161                dshape: vec![
1162                    (DimName::Batch, 1),
1163                    (DimName::NumClasses, 2),
1164                    (DimName::Height, 160),
1165                    (DimName::Width, 160),
1166                ],
1167            })
1168            .build()
1169            .unwrap();
1170        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1171        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1172        decoder
1173            .decode_quantized(
1174                &[out_u8.view().into()],
1175                &mut output_boxes,
1176                &mut output_masks_u8,
1177            )
1178            .unwrap();
1179
1180        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1181        decoder
1182            .decode_quantized(
1183                &[out_i8.view().into()],
1184                &mut output_boxes,
1185                &mut output_masks_i8,
1186            )
1187            .unwrap();
1188
1189        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1190        decoder
1191            .decode_quantized(
1192                &[out_u16.view().into()],
1193                &mut output_boxes,
1194                &mut output_masks_u16,
1195            )
1196            .unwrap();
1197
1198        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1199        decoder
1200            .decode_quantized(
1201                &[out_i16.view().into()],
1202                &mut output_boxes,
1203                &mut output_masks_i16,
1204            )
1205            .unwrap();
1206
1207        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1208        decoder
1209            .decode_quantized(
1210                &[out_u32.view().into()],
1211                &mut output_boxes,
1212                &mut output_masks_u32,
1213            )
1214            .unwrap();
1215
1216        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1217        decoder
1218            .decode_quantized(
1219                &[out_i32.view().into()],
1220                &mut output_boxes,
1221                &mut output_masks_i32,
1222            )
1223            .unwrap();
1224
1225        compare_outputs((&[], &output_boxes), (&[], &[]));
1226        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1227        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1228        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1229        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1230        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1231        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1232        assert_eq!(mask_u8, mask_i8);
1233        assert_eq!(mask_u8, mask_u16);
1234        assert_eq!(mask_u8, mask_i16);
1235        assert_eq!(mask_u8, mask_u32);
1236        assert_eq!(mask_u8, mask_i32);
1237    }
1238
1239    #[test]
1240    fn test_modelpack_segdet() {
1241        let score_threshold = 0.45;
1242        let iou_threshold = 0.45;
1243
1244        let boxes = include_bytes!(concat!(
1245            env!("CARGO_MANIFEST_DIR"),
1246            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1247        ));
1248        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1249
1250        let scores = include_bytes!(concat!(
1251            env!("CARGO_MANIFEST_DIR"),
1252            "/../../testdata/modelpack_scores_1935x1.bin"
1253        ));
1254        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1255
1256        let seg = include_bytes!(concat!(
1257            env!("CARGO_MANIFEST_DIR"),
1258            "/../../testdata/modelpack_seg_2x160x160.bin"
1259        ));
1260        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1261
1262        let quant_boxes = (0.004656755365431309, 21).into();
1263        let quant_scores = (0.0019603664986789227, 0).into();
1264        let quant_seg = (1.0 / 255.0, 0).into();
1265
1266        let decoder = DecoderBuilder::default()
1267            .with_config_modelpack_segdet(
1268                configs::Boxes {
1269                    decoder: DecoderType::ModelPack,
1270                    quantization: Some(quant_boxes),
1271                    shape: vec![1, 1935, 1, 4],
1272                    dshape: vec![
1273                        (DimName::Batch, 1),
1274                        (DimName::NumBoxes, 1935),
1275                        (DimName::Padding, 1),
1276                        (DimName::BoxCoords, 4),
1277                    ],
1278                    normalized: Some(true),
1279                },
1280                configs::Scores {
1281                    decoder: DecoderType::ModelPack,
1282                    quantization: Some(quant_scores),
1283                    shape: vec![1, 1935, 1],
1284                    dshape: vec![
1285                        (DimName::Batch, 1),
1286                        (DimName::NumBoxes, 1935),
1287                        (DimName::NumClasses, 1),
1288                    ],
1289                },
1290                configs::Segmentation {
1291                    decoder: DecoderType::ModelPack,
1292                    quantization: Some(quant_seg),
1293                    shape: vec![1, 2, 160, 160],
1294                    dshape: vec![
1295                        (DimName::Batch, 1),
1296                        (DimName::NumClasses, 2),
1297                        (DimName::Height, 160),
1298                        (DimName::Width, 160),
1299                    ],
1300                },
1301            )
1302            .with_iou_threshold(iou_threshold)
1303            .with_score_threshold(score_threshold)
1304            .build()
1305            .unwrap();
1306        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1307        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1308        decoder
1309            .decode_quantized(
1310                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1311                &mut output_boxes,
1312                &mut output_masks,
1313            )
1314            .unwrap();
1315
1316        let mut mask = seg.slice(s![0, .., .., ..]);
1317        mask.swap_axes(0, 1);
1318        mask.swap_axes(1, 2);
1319        let mask = [Segmentation {
1320            xmin: 0.0,
1321            ymin: 0.0,
1322            xmax: 1.0,
1323            ymax: 1.0,
1324            segmentation: mask.into_owned(),
1325        }];
1326        let correct_boxes = [DetectBox {
1327            bbox: BoundingBox {
1328                xmin: 0.40513772,
1329                ymin: 0.6379755,
1330                xmax: 0.5122431,
1331                ymax: 0.7730214,
1332            },
1333            score: 0.4861709,
1334            label: 0,
1335        }];
1336        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1337
1338        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1339        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1340        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1341        decoder
1342            .decode_float::<f32>(
1343                &[
1344                    scores.view().into_dyn(),
1345                    boxes.view().into_dyn(),
1346                    seg.view().into_dyn(),
1347                ],
1348                &mut output_boxes,
1349                &mut output_masks,
1350            )
1351            .unwrap();
1352
1353        // not expected for float segmentation decoder to have same values as quantized
1354        // segmentation decoder, as float decoder ensures the data fills 0-255,
1355        // quantized decoder uses whatever the model output. Thus the float
1356        // output is the same as the quantized output but scaled differently.
1357        // However, it is expected that the mask after argmax will be the same.
1358        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1359        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1360        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1361
1362        assert_eq!(mask0, mask1);
1363    }
1364
1365    #[test]
1366    fn test_modelpack_segdet_split() {
1367        let score_threshold = 0.8;
1368        let iou_threshold = 0.5;
1369
1370        let seg = include_bytes!(concat!(
1371            env!("CARGO_MANIFEST_DIR"),
1372            "/../../testdata/modelpack_seg_2x160x160.bin"
1373        ));
1374        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1375
1376        let detect0 = include_bytes!(concat!(
1377            env!("CARGO_MANIFEST_DIR"),
1378            "/../../testdata/modelpack_split_9x15x18.bin"
1379        ));
1380        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1381
1382        let detect1 = include_bytes!(concat!(
1383            env!("CARGO_MANIFEST_DIR"),
1384            "/../../testdata/modelpack_split_17x30x18.bin"
1385        ));
1386        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1387
1388        let quant0 = (0.08547406643629074, 174).into();
1389        let quant1 = (0.09929127991199493, 183).into();
1390        let quant_seg = (1.0 / 255.0, 0).into();
1391
1392        let anchors0 = vec![
1393            [0.36666667461395264, 0.31481480598449707],
1394            [0.38749998807907104, 0.4740740656852722],
1395            [0.5333333611488342, 0.644444465637207],
1396        ];
1397        let anchors1 = vec![
1398            [0.13750000298023224, 0.2074074000120163],
1399            [0.2541666626930237, 0.21481481194496155],
1400            [0.23125000298023224, 0.35185185074806213],
1401        ];
1402
1403        let decoder = DecoderBuilder::default()
1404            .with_config_modelpack_segdet_split(
1405                vec![
1406                    configs::Detection {
1407                        decoder: DecoderType::ModelPack,
1408                        shape: vec![1, 17, 30, 18],
1409                        anchors: Some(anchors1),
1410                        quantization: Some(quant1),
1411                        dshape: vec![
1412                            (DimName::Batch, 1),
1413                            (DimName::Height, 17),
1414                            (DimName::Width, 30),
1415                            (DimName::NumAnchorsXFeatures, 18),
1416                        ],
1417                        normalized: Some(true),
1418                    },
1419                    configs::Detection {
1420                        decoder: DecoderType::ModelPack,
1421                        shape: vec![1, 9, 15, 18],
1422                        anchors: Some(anchors0),
1423                        quantization: Some(quant0),
1424                        dshape: vec![
1425                            (DimName::Batch, 1),
1426                            (DimName::Height, 9),
1427                            (DimName::Width, 15),
1428                            (DimName::NumAnchorsXFeatures, 18),
1429                        ],
1430                        normalized: Some(true),
1431                    },
1432                ],
1433                configs::Segmentation {
1434                    decoder: DecoderType::ModelPack,
1435                    quantization: Some(quant_seg),
1436                    shape: vec![1, 2, 160, 160],
1437                    dshape: vec![
1438                        (DimName::Batch, 1),
1439                        (DimName::NumClasses, 2),
1440                        (DimName::Height, 160),
1441                        (DimName::Width, 160),
1442                    ],
1443                },
1444            )
1445            .with_score_threshold(score_threshold)
1446            .with_iou_threshold(iou_threshold)
1447            .build()
1448            .unwrap();
1449        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1450        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1451        decoder
1452            .decode_quantized(
1453                &[
1454                    detect0.view().into(),
1455                    detect1.view().into(),
1456                    seg.view().into(),
1457                ],
1458                &mut output_boxes,
1459                &mut output_masks,
1460            )
1461            .unwrap();
1462
1463        let mut mask = seg.slice(s![0, .., .., ..]);
1464        mask.swap_axes(0, 1);
1465        mask.swap_axes(1, 2);
1466        let mask = [Segmentation {
1467            xmin: 0.0,
1468            ymin: 0.0,
1469            xmax: 1.0,
1470            ymax: 1.0,
1471            segmentation: mask.into_owned(),
1472        }];
1473        let correct_boxes = [DetectBox {
1474            bbox: BoundingBox {
1475                xmin: 0.43171933,
1476                ymin: 0.68243736,
1477                xmax: 0.5626645,
1478                ymax: 0.808863,
1479            },
1480            score: 0.99240804,
1481            label: 0,
1482        }];
1483        println!("Output Boxes: {:?}", output_boxes);
1484        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1485
1486        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1487        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1488        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1489        decoder
1490            .decode_float::<f32>(
1491                &[
1492                    detect0.view().into_dyn(),
1493                    detect1.view().into_dyn(),
1494                    seg.view().into_dyn(),
1495                ],
1496                &mut output_boxes,
1497                &mut output_masks,
1498            )
1499            .unwrap();
1500
1501        // not expected for float segmentation decoder to have same values as quantized
1502        // segmentation decoder, as float decoder ensures the data fills 0-255,
1503        // quantized decoder uses whatever the model output. Thus the float
1504        // output is the same as the quantized output but scaled differently.
1505        // However, it is expected that the mask after argmax will be the same.
1506        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1507        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1508        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1509
1510        assert_eq!(mask0, mask1);
1511    }
1512
1513    #[test]
1514    fn test_dequant_chunked() {
1515        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1516        out.push(123); // make sure to test non multiple of 16 length
1517
1518        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1519        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1520        let quant = Quantization::new(0.0040811873, -123);
1521        dequantize_cpu(&out, quant, &mut out_dequant);
1522
1523        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1524        assert_eq!(out_dequant, out_dequant_simd);
1525
1526        let quant = Quantization::new(0.0040811873, 0);
1527        dequantize_cpu(&out, quant, &mut out_dequant);
1528
1529        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1530        assert_eq!(out_dequant, out_dequant_simd);
1531    }
1532
1533    #[test]
1534    fn test_dequant_ground_truth() {
1535        // Formula: output = (input - zero_point) * scale
1536        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1537
1538        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1539        let quant = Quantization::new(0.1, -128);
1540        let input: Vec<i8> = vec![0, 127, -128, 64];
1541        let mut output = vec![0.0f32; 4];
1542        let mut output_chunked = vec![0.0f32; 4];
1543        dequantize_cpu(&input, quant, &mut output);
1544        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1545        // (0 - (-128)) * 0.1 = 12.8
1546        // (127 - (-128)) * 0.1 = 25.5
1547        // (-128 - (-128)) * 0.1 = 0.0
1548        // (64 - (-128)) * 0.1 = 19.2
1549        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1550        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1551            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1552        }
1553        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1554            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1555        }
1556
1557        // Case 2: scale=1.0, zero_point=0 (identity-like)
1558        let quant = Quantization::new(1.0, 0);
1559        dequantize_cpu(&input, quant, &mut output);
1560        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1561        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1562        assert_eq!(output, expected);
1563        assert_eq!(output_chunked, expected);
1564
1565        // Case 3: scale=0.5, zero_point=0
1566        let quant = Quantization::new(0.5, 0);
1567        dequantize_cpu(&input, quant, &mut output);
1568        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1569        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1570        assert_eq!(output, expected);
1571        assert_eq!(output_chunked, expected);
1572
1573        // Case 4: i8 min/max boundaries with typical quantization params
1574        let quant = Quantization::new(0.021287762, 31);
1575        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1576        let mut output = vec![0.0f32; 6];
1577        let mut output_chunked = vec![0.0f32; 6];
1578        dequantize_cpu(&input, quant, &mut output);
1579        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1580        for i in 0..6 {
1581            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1582            assert!(
1583                (output[i] - expected).abs() < 1e-5,
1584                "cpu[{i}]: {} != {expected}",
1585                output[i]
1586            );
1587            assert!(
1588                (output_chunked[i] - expected).abs() < 1e-5,
1589                "chunked[{i}]: {} != {expected}",
1590                output_chunked[i]
1591            );
1592        }
1593    }
1594
1595    #[test]
1596    fn test_decoder_yolo_det() {
1597        let score_threshold = 0.25;
1598        let iou_threshold = 0.7;
1599        let out = load_yolov8s_det();
1600        let quant = (0.0040811873, -123).into();
1601
1602        let decoder = DecoderBuilder::default()
1603            .with_config_yolo_det(
1604                configs::Detection {
1605                    decoder: DecoderType::Ultralytics,
1606                    shape: vec![1, 84, 8400],
1607                    anchors: None,
1608                    quantization: Some(quant),
1609                    dshape: vec![
1610                        (DimName::Batch, 1),
1611                        (DimName::NumFeatures, 84),
1612                        (DimName::NumBoxes, 8400),
1613                    ],
1614                    normalized: Some(true),
1615                },
1616                Some(DecoderVersion::Yolo11),
1617            )
1618            .with_score_threshold(score_threshold)
1619            .with_iou_threshold(iou_threshold)
1620            .build()
1621            .unwrap();
1622
1623        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1624        decode_yolo_det(
1625            (out.slice(s![0, .., ..]), quant.into()),
1626            score_threshold,
1627            iou_threshold,
1628            Some(configs::Nms::ClassAgnostic),
1629            &mut output_boxes,
1630        );
1631        assert!(output_boxes[0].equal_within_delta(
1632            &DetectBox {
1633                bbox: BoundingBox {
1634                    xmin: 0.5285137,
1635                    ymin: 0.05305544,
1636                    xmax: 0.87541467,
1637                    ymax: 0.9998909,
1638                },
1639                score: 0.5591227,
1640                label: 0
1641            },
1642            1e-6
1643        ));
1644
1645        assert!(output_boxes[1].equal_within_delta(
1646            &DetectBox {
1647                bbox: BoundingBox {
1648                    xmin: 0.130598,
1649                    ymin: 0.43260583,
1650                    xmax: 0.35098213,
1651                    ymax: 0.9958097,
1652                },
1653                score: 0.33057618,
1654                label: 75
1655            },
1656            1e-6
1657        ));
1658
1659        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1660        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1661        decoder
1662            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1663            .unwrap();
1664
1665        let out = dequantize_ndarray(out.view(), quant.into());
1666        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1667        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1668        decoder
1669            .decode_float::<f32>(
1670                &[out.view().into_dyn()],
1671                &mut output_boxes_f32,
1672                &mut output_masks_f32,
1673            )
1674            .unwrap();
1675
1676        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1677        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1678    }
1679
1680    #[test]
1681    fn test_decoder_masks() {
1682        let score_threshold = 0.45;
1683        let iou_threshold = 0.45;
1684        let boxes = load_yolov8_boxes();
1685        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1686
1687        let protos = load_yolov8_protos();
1688        let quant_protos = Quantization::new(0.02491161972284317, -117);
1689        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1690        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1691        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1692        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1693        decode_yolo_segdet_float(
1694            seg.slice(s![0, .., ..]),
1695            protos.slice(s![0, .., .., ..]),
1696            score_threshold,
1697            iou_threshold,
1698            Some(configs::Nms::ClassAgnostic),
1699            &mut output_boxes,
1700            &mut output_masks,
1701        )
1702        .unwrap();
1703        assert_eq!(output_boxes.len(), 2);
1704        assert_eq!(output_boxes.len(), output_masks.len());
1705
1706        for (b, m) in output_boxes.iter().zip(&output_masks) {
1707            assert!(b.bbox.xmin >= m.xmin);
1708            assert!(b.bbox.ymin >= m.ymin);
1709            assert!(b.bbox.xmax >= m.xmax);
1710            assert!(b.bbox.ymax >= m.ymax);
1711        }
1712        assert!(output_boxes[0].equal_within_delta(
1713            &DetectBox {
1714                bbox: BoundingBox {
1715                    xmin: 0.08515105,
1716                    ymin: 0.7131401,
1717                    xmax: 0.29802868,
1718                    ymax: 0.8195788,
1719                },
1720                score: 0.91537374,
1721                label: 23
1722            },
1723            1.0 / 160.0, // wider range because mask will expand the box
1724        ));
1725
1726        assert!(output_boxes[1].equal_within_delta(
1727            &DetectBox {
1728                bbox: BoundingBox {
1729                    xmin: 0.59605736,
1730                    ymin: 0.25545314,
1731                    xmax: 0.93666154,
1732                    ymax: 0.72378385,
1733                },
1734                score: 0.91537374,
1735                label: 23
1736            },
1737            1.0 / 160.0, // wider range because mask will expand the box
1738        ));
1739
1740        let full_mask = include_bytes!(concat!(
1741            env!("CARGO_MANIFEST_DIR"),
1742            "/../../testdata/yolov8_mask_results.bin"
1743        ));
1744        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1745
1746        let cropped_mask = full_mask.slice(ndarray::s![
1747            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1748            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1749        ]);
1750
1751        assert_eq!(
1752            cropped_mask,
1753            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1754        );
1755    }
1756
1757    /// Regression test: config-driven path with physically-NCHW protos.
1758    /// Simulates YOLOv8-seg ONNX outputs where the producer emits protos
1759    /// as `(1, 32, 160, 160)` in CHW memory order — the caller declares
1760    /// shape and dshape matching that physical order and HAL permutes to
1761    /// canonical HWC via `swap_axes_if_needed`.
1762    ///
1763    /// This is the counterpart to the NHWC-producer case (TFLite /
1764    /// Ara-2) where shape+dshape are `(1, 160, 160, 32)` +
1765    /// `[batch, height, width, num_protos]` and no reordering is needed.
1766    #[test]
1767    fn test_decoder_masks_nchw_protos() {
1768        let score_threshold = 0.45;
1769        let iou_threshold = 0.45;
1770
1771        // Load test data — boxes as [116, 8400]
1772        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1773        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1774
1775        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1776        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1777        let quant_protos = Quantization::new(0.02491161972284317, -117);
1778        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1779
1780        // ---- Reference: direct call with HWC protos (known working) ----
1781        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1782        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1783        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1784        decode_yolo_segdet_float(
1785            seg.view(),
1786            protos_f32_hwc.view(),
1787            score_threshold,
1788            iou_threshold,
1789            Some(configs::Nms::ClassAgnostic),
1790            &mut ref_boxes,
1791            &mut ref_masks,
1792        )
1793        .unwrap();
1794        assert_eq!(ref_boxes.len(), 2);
1795
1796        // ---- Config-driven path: NCHW protos declared in physical order ----
1797        // Permute the HWC test data to CHW memory order — this is what an
1798        // ONNX-style producer would emit into the tensor buffer.
1799        // `to_owned` materialises a C-contiguous Array3<f32> with CHW
1800        // strides, matching a producer that writes channels-outer.
1801        let protos_f32_chw_view = protos_f32_hwc.view().permuted_axes([2, 0, 1]); // [32, 160, 160]
1802        let protos_f32_chw = protos_f32_chw_view.to_owned();
1803        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1804
1805        // Build boxes as [1, 116, 8400] f32
1806        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1807
1808        // Declare shape and dshape in physical memory order (outermost
1809        // first). `swap_axes_if_needed` uses dshape role lookup to
1810        // permute the stride tuple into canonical HWC.
1811        let decoder = DecoderBuilder::default()
1812            .with_config_yolo_segdet(
1813                configs::Detection {
1814                    decoder: configs::DecoderType::Ultralytics,
1815                    quantization: None,
1816                    shape: vec![1, 116, 8400],
1817                    dshape: vec![
1818                        (configs::DimName::Batch, 1),
1819                        (configs::DimName::NumFeatures, 116),
1820                        (configs::DimName::NumBoxes, 8400),
1821                    ],
1822                    normalized: Some(true),
1823                    anchors: None,
1824                },
1825                configs::Protos {
1826                    decoder: configs::DecoderType::Ultralytics,
1827                    quantization: None,
1828                    shape: vec![1, 32, 160, 160],
1829                    dshape: vec![
1830                        (configs::DimName::Batch, 1),
1831                        (configs::DimName::NumProtos, 32),
1832                        (configs::DimName::Height, 160),
1833                        (configs::DimName::Width, 160),
1834                    ],
1835                },
1836                None, // decoder version
1837            )
1838            .with_score_threshold(score_threshold)
1839            .with_iou_threshold(iou_threshold)
1840            .build()
1841            .unwrap();
1842
1843        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1844        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1845        decoder
1846            .decode_float(
1847                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1848                &mut cfg_boxes,
1849                &mut cfg_masks,
1850            )
1851            .unwrap();
1852
1853        // Must produce the same number of detections
1854        assert_eq!(
1855            cfg_boxes.len(),
1856            ref_boxes.len(),
1857            "config path produced {} boxes, reference produced {}",
1858            cfg_boxes.len(),
1859            ref_boxes.len()
1860        );
1861
1862        // Boxes must match
1863        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1864            assert!(
1865                cb.equal_within_delta(rb, 0.01),
1866                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1867            );
1868        }
1869
1870        // Masks must match pixel-for-pixel
1871        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1872            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1873            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1874            assert_eq!(
1875                cm_arr, rm_arr,
1876                "mask {i} pixel mismatch between config-driven and reference paths"
1877            );
1878        }
1879    }
1880
1881    #[test]
1882    fn test_decoder_masks_i8() {
1883        let score_threshold = 0.45;
1884        let iou_threshold = 0.45;
1885        let boxes = load_yolov8_boxes();
1886        let quant_boxes = (0.021287761628627777, 31).into();
1887
1888        let protos = load_yolov8_protos();
1889        let quant_protos = (0.02491161972284317, -117).into();
1890        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1891        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1892
1893        let decoder = DecoderBuilder::default()
1894            .with_config_yolo_segdet(
1895                configs::Detection {
1896                    decoder: configs::DecoderType::Ultralytics,
1897                    quantization: Some(quant_boxes),
1898                    shape: vec![1, 116, 8400],
1899                    anchors: None,
1900                    dshape: vec![
1901                        (DimName::Batch, 1),
1902                        (DimName::NumFeatures, 116),
1903                        (DimName::NumBoxes, 8400),
1904                    ],
1905                    normalized: Some(true),
1906                },
1907                Protos {
1908                    decoder: configs::DecoderType::Ultralytics,
1909                    quantization: Some(quant_protos),
1910                    shape: vec![1, 160, 160, 32],
1911                    dshape: vec![
1912                        (DimName::Batch, 1),
1913                        (DimName::Height, 160),
1914                        (DimName::Width, 160),
1915                        (DimName::NumProtos, 32),
1916                    ],
1917                },
1918                Some(DecoderVersion::Yolo11),
1919            )
1920            .with_score_threshold(score_threshold)
1921            .with_iou_threshold(iou_threshold)
1922            .build()
1923            .unwrap();
1924
1925        let quant_boxes = quant_boxes.into();
1926        let quant_protos = quant_protos.into();
1927
1928        decode_yolo_segdet_quant(
1929            (boxes.slice(s![0, .., ..]), quant_boxes),
1930            (protos.slice(s![0, .., .., ..]), quant_protos),
1931            score_threshold,
1932            iou_threshold,
1933            Some(configs::Nms::ClassAgnostic),
1934            &mut output_boxes,
1935            &mut output_masks,
1936        )
1937        .unwrap();
1938
1939        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1940        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1941
1942        decoder
1943            .decode_quantized(
1944                &[boxes.view().into(), protos.view().into()],
1945                &mut output_boxes1,
1946                &mut output_masks1,
1947            )
1948            .unwrap();
1949
1950        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1951        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1952
1953        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1954        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1955        decode_yolo_segdet_float(
1956            seg.slice(s![0, .., ..]),
1957            protos.slice(s![0, .., .., ..]),
1958            score_threshold,
1959            iou_threshold,
1960            Some(configs::Nms::ClassAgnostic),
1961            &mut output_boxes_f32,
1962            &mut output_masks_f32,
1963        )
1964        .unwrap();
1965
1966        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1967        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1968
1969        decoder
1970            .decode_float(
1971                &[seg.view().into_dyn(), protos.view().into_dyn()],
1972                &mut output_boxes1_f32,
1973                &mut output_masks1_f32,
1974            )
1975            .unwrap();
1976
1977        compare_outputs(
1978            (&output_boxes, &output_boxes1),
1979            (&output_masks, &output_masks1),
1980        );
1981
1982        compare_outputs(
1983            (&output_boxes, &output_boxes_f32),
1984            (&output_masks, &output_masks_f32),
1985        );
1986
1987        compare_outputs(
1988            (&output_boxes_f32, &output_boxes1_f32),
1989            (&output_masks_f32, &output_masks1_f32),
1990        );
1991    }
1992
1993    #[test]
1994    fn test_decoder_yolo_split() {
1995        let score_threshold = 0.45;
1996        let iou_threshold = 0.45;
1997        let boxes = load_yolov8_boxes();
1998        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1999        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2000
2001        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2002
2003        let decoder = DecoderBuilder::default()
2004            .with_config_yolo_split_det(
2005                configs::Boxes {
2006                    decoder: configs::DecoderType::Ultralytics,
2007                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2008                    shape: vec![1, 4, 8400],
2009                    dshape: vec![
2010                        (DimName::Batch, 1),
2011                        (DimName::BoxCoords, 4),
2012                        (DimName::NumBoxes, 8400),
2013                    ],
2014                    normalized: Some(true),
2015                },
2016                configs::Scores {
2017                    decoder: configs::DecoderType::Ultralytics,
2018                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2019                    shape: vec![1, 80, 8400],
2020                    dshape: vec![
2021                        (DimName::Batch, 1),
2022                        (DimName::NumClasses, 80),
2023                        (DimName::NumBoxes, 8400),
2024                    ],
2025                },
2026            )
2027            .with_score_threshold(score_threshold)
2028            .with_iou_threshold(iou_threshold)
2029            .build()
2030            .unwrap();
2031
2032        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2033        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2034
2035        decoder
2036            .decode_quantized(
2037                &[
2038                    boxes.slice(s![.., ..4, ..]).into(),
2039                    boxes.slice(s![.., 4..84, ..]).into(),
2040                ],
2041                &mut output_boxes,
2042                &mut output_masks,
2043            )
2044            .unwrap();
2045
2046        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2047        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2048        decode_yolo_det_float(
2049            seg.slice(s![0, ..84, ..]),
2050            score_threshold,
2051            iou_threshold,
2052            Some(configs::Nms::ClassAgnostic),
2053            &mut output_boxes_f32,
2054        );
2055
2056        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2057        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2058
2059        decoder
2060            .decode_float(
2061                &[
2062                    seg.slice(s![.., ..4, ..]).into_dyn(),
2063                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2064                ],
2065                &mut output_boxes1,
2066                &mut output_masks1,
2067            )
2068            .unwrap();
2069        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2070        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2071    }
2072
2073    #[test]
2074    fn test_decoder_masks_config_mixed() {
2075        let score_threshold = 0.45;
2076        let iou_threshold = 0.45;
2077        let boxes_raw = load_yolov8_boxes();
2078        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2079        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2080
2081        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2082
2083        let protos = load_yolov8_protos();
2084        let quant_protos = (0.02491161972284317, -117);
2085
2086        let decoder = build_yolo_split_segdet_decoder(
2087            score_threshold,
2088            iou_threshold,
2089            quant_boxes,
2090            quant_protos,
2091        );
2092        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2093        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2094
2095        decoder
2096            .decode_quantized(
2097                &[
2098                    boxes.slice(s![.., ..4, ..]).into(),
2099                    boxes.slice(s![.., 4..84, ..]).into(),
2100                    boxes.slice(s![.., 84.., ..]).into(),
2101                    protos.view().into(),
2102                ],
2103                &mut output_boxes,
2104                &mut output_masks,
2105            )
2106            .unwrap();
2107
2108        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2109        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2110        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2111        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2112        decode_yolo_segdet_float(
2113            seg.slice(s![0, .., ..]),
2114            protos.slice(s![0, .., .., ..]),
2115            score_threshold,
2116            iou_threshold,
2117            Some(configs::Nms::ClassAgnostic),
2118            &mut output_boxes_f32,
2119            &mut output_masks_f32,
2120        )
2121        .unwrap();
2122
2123        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2124        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2125
2126        decoder
2127            .decode_float(
2128                &[
2129                    seg.slice(s![.., ..4, ..]).into_dyn(),
2130                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2131                    seg.slice(s![.., 84.., ..]).into_dyn(),
2132                    protos.view().into_dyn(),
2133                ],
2134                &mut output_boxes1,
2135                &mut output_masks1,
2136            )
2137            .unwrap();
2138        compare_outputs(
2139            (&output_boxes, &output_boxes_f32),
2140            (&output_masks, &output_masks_f32),
2141        );
2142        compare_outputs(
2143            (&output_boxes_f32, &output_boxes1),
2144            (&output_masks_f32, &output_masks1),
2145        );
2146    }
2147
2148    fn build_yolo_split_segdet_decoder(
2149        score_threshold: f32,
2150        iou_threshold: f32,
2151        quant_boxes: (f32, i32),
2152        quant_protos: (f32, i32),
2153    ) -> crate::Decoder {
2154        DecoderBuilder::default()
2155            .with_config_yolo_split_segdet(
2156                configs::Boxes {
2157                    decoder: configs::DecoderType::Ultralytics,
2158                    quantization: Some(quant_boxes.into()),
2159                    shape: vec![1, 4, 8400],
2160                    dshape: vec![
2161                        (DimName::Batch, 1),
2162                        (DimName::BoxCoords, 4),
2163                        (DimName::NumBoxes, 8400),
2164                    ],
2165                    normalized: Some(true),
2166                },
2167                configs::Scores {
2168                    decoder: configs::DecoderType::Ultralytics,
2169                    quantization: Some(quant_boxes.into()),
2170                    shape: vec![1, 80, 8400],
2171                    dshape: vec![
2172                        (DimName::Batch, 1),
2173                        (DimName::NumClasses, 80),
2174                        (DimName::NumBoxes, 8400),
2175                    ],
2176                },
2177                configs::MaskCoefficients {
2178                    decoder: configs::DecoderType::Ultralytics,
2179                    quantization: Some(quant_boxes.into()),
2180                    shape: vec![1, 32, 8400],
2181                    dshape: vec![
2182                        (DimName::Batch, 1),
2183                        (DimName::NumProtos, 32),
2184                        (DimName::NumBoxes, 8400),
2185                    ],
2186                },
2187                configs::Protos {
2188                    decoder: configs::DecoderType::Ultralytics,
2189                    quantization: Some(quant_protos.into()),
2190                    shape: vec![1, 160, 160, 32],
2191                    dshape: vec![
2192                        (DimName::Batch, 1),
2193                        (DimName::Height, 160),
2194                        (DimName::Width, 160),
2195                        (DimName::NumProtos, 32),
2196                    ],
2197                },
2198            )
2199            .with_score_threshold(score_threshold)
2200            .with_iou_threshold(iou_threshold)
2201            .build()
2202            .unwrap()
2203    }
2204
2205    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2206        let config_yaml = include_str!(concat!(
2207            env!("CARGO_MANIFEST_DIR"),
2208            "/../../testdata/yolov8_seg.yaml"
2209        ));
2210        DecoderBuilder::default()
2211            .with_config_yaml_str(config_yaml.to_string())
2212            .with_score_threshold(score_threshold)
2213            .with_iou_threshold(iou_threshold)
2214            .build()
2215            .unwrap()
2216    }
2217    #[test]
2218    fn test_decoder_masks_config_i32() {
2219        let score_threshold = 0.45;
2220        let iou_threshold = 0.45;
2221        let boxes_raw = load_yolov8_boxes();
2222        let scale = 1 << 23;
2223        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2224        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2225
2226        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2227
2228        let protos_raw = load_yolov8_protos();
2229        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2230        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2231        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2232
2233        let decoder = build_yolo_split_segdet_decoder(
2234            score_threshold,
2235            iou_threshold,
2236            quant_boxes,
2237            quant_protos,
2238        );
2239
2240        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2241        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2242
2243        decoder
2244            .decode_quantized(
2245                &[
2246                    boxes.slice(s![.., ..4, ..]).into(),
2247                    boxes.slice(s![.., 4..84, ..]).into(),
2248                    boxes.slice(s![.., 84.., ..]).into(),
2249                    protos.view().into(),
2250                ],
2251                &mut output_boxes,
2252                &mut output_masks,
2253            )
2254            .unwrap();
2255
2256        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2257        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2258        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2259        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2260        decode_yolo_segdet_float(
2261            seg.slice(s![0, .., ..]),
2262            protos.slice(s![0, .., .., ..]),
2263            score_threshold,
2264            iou_threshold,
2265            Some(configs::Nms::ClassAgnostic),
2266            &mut output_boxes_f32,
2267            &mut output_masks_f32,
2268        )
2269        .unwrap();
2270
2271        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2272        assert_eq!(output_masks.len(), output_masks_f32.len());
2273
2274        compare_outputs(
2275            (&output_boxes, &output_boxes_f32),
2276            (&output_masks, &output_masks_f32),
2277        );
2278    }
2279
2280    /// test running multiple decoders concurrently
2281    #[test]
2282    fn test_context_switch() {
2283        let yolo_det = || {
2284            let score_threshold = 0.25;
2285            let iou_threshold = 0.7;
2286            let out = load_yolov8s_det();
2287            let quant = (0.0040811873, -123).into();
2288
2289            let decoder = DecoderBuilder::default()
2290                .with_config_yolo_det(
2291                    configs::Detection {
2292                        decoder: DecoderType::Ultralytics,
2293                        shape: vec![1, 84, 8400],
2294                        anchors: None,
2295                        quantization: Some(quant),
2296                        dshape: vec![
2297                            (DimName::Batch, 1),
2298                            (DimName::NumFeatures, 84),
2299                            (DimName::NumBoxes, 8400),
2300                        ],
2301                        normalized: None,
2302                    },
2303                    None,
2304                )
2305                .with_score_threshold(score_threshold)
2306                .with_iou_threshold(iou_threshold)
2307                .build()
2308                .unwrap();
2309
2310            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2311            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2312
2313            for _ in 0..100 {
2314                decoder
2315                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2316                    .unwrap();
2317
2318                assert!(output_boxes[0].equal_within_delta(
2319                    &DetectBox {
2320                        bbox: BoundingBox {
2321                            xmin: 0.5285137,
2322                            ymin: 0.05305544,
2323                            xmax: 0.87541467,
2324                            ymax: 0.9998909,
2325                        },
2326                        score: 0.5591227,
2327                        label: 0
2328                    },
2329                    1e-6
2330                ));
2331
2332                assert!(output_boxes[1].equal_within_delta(
2333                    &DetectBox {
2334                        bbox: BoundingBox {
2335                            xmin: 0.130598,
2336                            ymin: 0.43260583,
2337                            xmax: 0.35098213,
2338                            ymax: 0.9958097,
2339                        },
2340                        score: 0.33057618,
2341                        label: 75
2342                    },
2343                    1e-6
2344                ));
2345                assert!(output_masks.is_empty());
2346            }
2347        };
2348
2349        let modelpack_det_split = || {
2350            let score_threshold = 0.8;
2351            let iou_threshold = 0.5;
2352
2353            let seg = include_bytes!(concat!(
2354                env!("CARGO_MANIFEST_DIR"),
2355                "/../../testdata/modelpack_seg_2x160x160.bin"
2356            ));
2357            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2358
2359            let detect0 = include_bytes!(concat!(
2360                env!("CARGO_MANIFEST_DIR"),
2361                "/../../testdata/modelpack_split_9x15x18.bin"
2362            ));
2363            let detect0 =
2364                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2365
2366            let detect1 = include_bytes!(concat!(
2367                env!("CARGO_MANIFEST_DIR"),
2368                "/../../testdata/modelpack_split_17x30x18.bin"
2369            ));
2370            let detect1 =
2371                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2372
2373            let mut mask = seg.slice(s![0, .., .., ..]);
2374            mask.swap_axes(0, 1);
2375            mask.swap_axes(1, 2);
2376            let mask = [Segmentation {
2377                xmin: 0.0,
2378                ymin: 0.0,
2379                xmax: 1.0,
2380                ymax: 1.0,
2381                segmentation: mask.into_owned(),
2382            }];
2383            let correct_boxes = [DetectBox {
2384                bbox: BoundingBox {
2385                    xmin: 0.43171933,
2386                    ymin: 0.68243736,
2387                    xmax: 0.5626645,
2388                    ymax: 0.808863,
2389                },
2390                score: 0.99240804,
2391                label: 0,
2392            }];
2393
2394            let quant0 = (0.08547406643629074, 174).into();
2395            let quant1 = (0.09929127991199493, 183).into();
2396            let quant_seg = (1.0 / 255.0, 0).into();
2397
2398            let anchors0 = vec![
2399                [0.36666667461395264, 0.31481480598449707],
2400                [0.38749998807907104, 0.4740740656852722],
2401                [0.5333333611488342, 0.644444465637207],
2402            ];
2403            let anchors1 = vec![
2404                [0.13750000298023224, 0.2074074000120163],
2405                [0.2541666626930237, 0.21481481194496155],
2406                [0.23125000298023224, 0.35185185074806213],
2407            ];
2408
2409            let decoder = DecoderBuilder::default()
2410                .with_config_modelpack_segdet_split(
2411                    vec![
2412                        configs::Detection {
2413                            decoder: DecoderType::ModelPack,
2414                            shape: vec![1, 17, 30, 18],
2415                            anchors: Some(anchors1),
2416                            quantization: Some(quant1),
2417                            dshape: vec![
2418                                (DimName::Batch, 1),
2419                                (DimName::Height, 17),
2420                                (DimName::Width, 30),
2421                                (DimName::NumAnchorsXFeatures, 18),
2422                            ],
2423                            normalized: None,
2424                        },
2425                        configs::Detection {
2426                            decoder: DecoderType::ModelPack,
2427                            shape: vec![1, 9, 15, 18],
2428                            anchors: Some(anchors0),
2429                            quantization: Some(quant0),
2430                            dshape: vec![
2431                                (DimName::Batch, 1),
2432                                (DimName::Height, 9),
2433                                (DimName::Width, 15),
2434                                (DimName::NumAnchorsXFeatures, 18),
2435                            ],
2436                            normalized: None,
2437                        },
2438                    ],
2439                    configs::Segmentation {
2440                        decoder: DecoderType::ModelPack,
2441                        quantization: Some(quant_seg),
2442                        shape: vec![1, 2, 160, 160],
2443                        dshape: vec![
2444                            (DimName::Batch, 1),
2445                            (DimName::NumClasses, 2),
2446                            (DimName::Height, 160),
2447                            (DimName::Width, 160),
2448                        ],
2449                    },
2450                )
2451                .with_score_threshold(score_threshold)
2452                .with_iou_threshold(iou_threshold)
2453                .build()
2454                .unwrap();
2455            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2456            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2457
2458            for _ in 0..100 {
2459                decoder
2460                    .decode_quantized(
2461                        &[
2462                            detect0.view().into(),
2463                            detect1.view().into(),
2464                            seg.view().into(),
2465                        ],
2466                        &mut output_boxes,
2467                        &mut output_masks,
2468                    )
2469                    .unwrap();
2470
2471                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2472            }
2473        };
2474
2475        let handles = vec![
2476            std::thread::spawn(yolo_det),
2477            std::thread::spawn(modelpack_det_split),
2478            std::thread::spawn(yolo_det),
2479            std::thread::spawn(modelpack_det_split),
2480            std::thread::spawn(yolo_det),
2481            std::thread::spawn(modelpack_det_split),
2482            std::thread::spawn(yolo_det),
2483            std::thread::spawn(modelpack_det_split),
2484        ];
2485        for handle in handles {
2486            handle.join().unwrap();
2487        }
2488    }
2489
2490    #[test]
2491    fn test_ndarray_to_xyxy_float() {
2492        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2493        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2494        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2495
2496        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2497        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2498        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2499    }
2500
2501    #[test]
2502    fn test_class_aware_nms_float() {
2503        use crate::float::nms_class_aware_float;
2504
2505        // Create two overlapping boxes with different classes
2506        let boxes = vec![
2507            DetectBox {
2508                bbox: BoundingBox {
2509                    xmin: 0.0,
2510                    ymin: 0.0,
2511                    xmax: 0.5,
2512                    ymax: 0.5,
2513                },
2514                score: 0.9,
2515                label: 0, // class 0
2516            },
2517            DetectBox {
2518                bbox: BoundingBox {
2519                    xmin: 0.1,
2520                    ymin: 0.1,
2521                    xmax: 0.6,
2522                    ymax: 0.6,
2523                },
2524                score: 0.8,
2525                label: 1, // class 1 - different class
2526            },
2527        ];
2528
2529        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2530        // threshold 0.3)
2531        let result = nms_class_aware_float(0.3, boxes.clone());
2532        assert_eq!(
2533            result.len(),
2534            2,
2535            "Class-aware NMS should keep both boxes with different classes"
2536        );
2537
2538        // Now test with same class - should suppress one
2539        let same_class_boxes = vec![
2540            DetectBox {
2541                bbox: BoundingBox {
2542                    xmin: 0.0,
2543                    ymin: 0.0,
2544                    xmax: 0.5,
2545                    ymax: 0.5,
2546                },
2547                score: 0.9,
2548                label: 0,
2549            },
2550            DetectBox {
2551                bbox: BoundingBox {
2552                    xmin: 0.1,
2553                    ymin: 0.1,
2554                    xmax: 0.6,
2555                    ymax: 0.6,
2556                },
2557                score: 0.8,
2558                label: 0, // same class
2559            },
2560        ];
2561
2562        let result = nms_class_aware_float(0.3, same_class_boxes);
2563        assert_eq!(
2564            result.len(),
2565            1,
2566            "Class-aware NMS should suppress overlapping box with same class"
2567        );
2568        assert_eq!(result[0].label, 0);
2569        assert!((result[0].score - 0.9).abs() < 1e-6);
2570    }
2571
2572    #[test]
2573    fn test_class_agnostic_vs_aware_nms() {
2574        use crate::float::{nms_class_aware_float, nms_float};
2575
2576        // Two overlapping boxes with different classes
2577        let boxes = vec![
2578            DetectBox {
2579                bbox: BoundingBox {
2580                    xmin: 0.0,
2581                    ymin: 0.0,
2582                    xmax: 0.5,
2583                    ymax: 0.5,
2584                },
2585                score: 0.9,
2586                label: 0,
2587            },
2588            DetectBox {
2589                bbox: BoundingBox {
2590                    xmin: 0.1,
2591                    ymin: 0.1,
2592                    xmax: 0.6,
2593                    ymax: 0.6,
2594                },
2595                score: 0.8,
2596                label: 1,
2597            },
2598        ];
2599
2600        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2601        let agnostic_result = nms_float(0.3, boxes.clone());
2602        assert_eq!(
2603            agnostic_result.len(),
2604            1,
2605            "Class-agnostic NMS should suppress overlapping boxes"
2606        );
2607
2608        // Class-aware should keep both (different classes)
2609        let aware_result = nms_class_aware_float(0.3, boxes);
2610        assert_eq!(
2611            aware_result.len(),
2612            2,
2613            "Class-aware NMS should keep boxes with different classes"
2614        );
2615    }
2616
2617    #[test]
2618    fn test_class_aware_nms_int() {
2619        use crate::byte::nms_class_aware_int;
2620
2621        // Create two overlapping boxes with different classes
2622        let boxes = vec![
2623            DetectBoxQuantized {
2624                bbox: BoundingBox {
2625                    xmin: 0.0,
2626                    ymin: 0.0,
2627                    xmax: 0.5,
2628                    ymax: 0.5,
2629                },
2630                score: 200_u8,
2631                label: 0,
2632            },
2633            DetectBoxQuantized {
2634                bbox: BoundingBox {
2635                    xmin: 0.1,
2636                    ymin: 0.1,
2637                    xmax: 0.6,
2638                    ymax: 0.6,
2639                },
2640                score: 180_u8,
2641                label: 1, // different class
2642            },
2643        ];
2644
2645        // Should keep both (different classes)
2646        let result = nms_class_aware_int(0.5, boxes);
2647        assert_eq!(
2648            result.len(),
2649            2,
2650            "Class-aware NMS (int) should keep boxes with different classes"
2651        );
2652    }
2653
2654    #[test]
2655    fn test_nms_enum_default() {
2656        // Test that Nms enum has the correct default
2657        let default_nms: configs::Nms = Default::default();
2658        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2659    }
2660
2661    #[test]
2662    fn test_decoder_nms_mode() {
2663        // Test that decoder properly stores NMS mode
2664        let decoder = DecoderBuilder::default()
2665            .with_config_yolo_det(
2666                configs::Detection {
2667                    anchors: None,
2668                    decoder: DecoderType::Ultralytics,
2669                    quantization: None,
2670                    shape: vec![1, 84, 8400],
2671                    dshape: Vec::new(),
2672                    normalized: Some(true),
2673                },
2674                None,
2675            )
2676            .with_nms(Some(configs::Nms::ClassAware))
2677            .build()
2678            .unwrap();
2679
2680        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2681    }
2682
2683    #[test]
2684    fn test_decoder_nms_bypass() {
2685        // Test that decoder can be configured with nms=None (bypass)
2686        let decoder = DecoderBuilder::default()
2687            .with_config_yolo_det(
2688                configs::Detection {
2689                    anchors: None,
2690                    decoder: DecoderType::Ultralytics,
2691                    quantization: None,
2692                    shape: vec![1, 84, 8400],
2693                    dshape: Vec::new(),
2694                    normalized: Some(true),
2695                },
2696                None,
2697            )
2698            .with_nms(None)
2699            .build()
2700            .unwrap();
2701
2702        assert_eq!(decoder.nms, None);
2703    }
2704
2705    #[test]
2706    fn test_decoder_normalized_boxes_true() {
2707        // Test that normalized_boxes returns Some(true) when explicitly set
2708        let decoder = DecoderBuilder::default()
2709            .with_config_yolo_det(
2710                configs::Detection {
2711                    anchors: None,
2712                    decoder: DecoderType::Ultralytics,
2713                    quantization: None,
2714                    shape: vec![1, 84, 8400],
2715                    dshape: Vec::new(),
2716                    normalized: Some(true),
2717                },
2718                None,
2719            )
2720            .build()
2721            .unwrap();
2722
2723        assert_eq!(decoder.normalized_boxes(), Some(true));
2724    }
2725
2726    #[test]
2727    fn test_decoder_normalized_boxes_false() {
2728        // Test that normalized_boxes returns Some(false) when config specifies
2729        // unnormalized
2730        let decoder = DecoderBuilder::default()
2731            .with_config_yolo_det(
2732                configs::Detection {
2733                    anchors: None,
2734                    decoder: DecoderType::Ultralytics,
2735                    quantization: None,
2736                    shape: vec![1, 84, 8400],
2737                    dshape: Vec::new(),
2738                    normalized: Some(false),
2739                },
2740                None,
2741            )
2742            .build()
2743            .unwrap();
2744
2745        assert_eq!(decoder.normalized_boxes(), Some(false));
2746    }
2747
2748    #[test]
2749    fn test_decoder_normalized_boxes_unknown() {
2750        // Test that normalized_boxes returns None when not specified in config
2751        let decoder = DecoderBuilder::default()
2752            .with_config_yolo_det(
2753                configs::Detection {
2754                    anchors: None,
2755                    decoder: DecoderType::Ultralytics,
2756                    quantization: None,
2757                    shape: vec![1, 84, 8400],
2758                    dshape: Vec::new(),
2759                    normalized: None,
2760                },
2761                Some(DecoderVersion::Yolo11),
2762            )
2763            .build()
2764            .unwrap();
2765
2766        assert_eq!(decoder.normalized_boxes(), None);
2767    }
2768
2769    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2770        input: ArrayView<F, D>,
2771        quant: Quantization,
2772    ) -> Array<T, D>
2773    where
2774        i32: num_traits::AsPrimitive<F>,
2775        f32: num_traits::AsPrimitive<F>,
2776    {
2777        let zero_point = quant.zero_point.as_();
2778        let div_scale = F::one() / quant.scale.as_();
2779        if zero_point != F::zero() {
2780            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2781        } else {
2782            input.mapv(|d| (d * div_scale).round().as_())
2783        }
2784    }
2785
2786    fn real_data_expected_boxes() -> [DetectBox; 2] {
2787        [
2788            DetectBox {
2789                bbox: BoundingBox {
2790                    xmin: 0.08515105,
2791                    ymin: 0.7131401,
2792                    xmax: 0.29802868,
2793                    ymax: 0.8195788,
2794                },
2795                score: 0.91537374,
2796                label: 23,
2797            },
2798            DetectBox {
2799                bbox: BoundingBox {
2800                    xmin: 0.59605736,
2801                    ymin: 0.25545314,
2802                    xmax: 0.93666154,
2803                    ymax: 0.72378385,
2804                },
2805                score: 0.91537374,
2806                label: 23,
2807            },
2808        ]
2809    }
2810
2811    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2812        [DetectBox {
2813            bbox: BoundingBox {
2814                xmin: 0.12549022,
2815                ymin: 0.12549022,
2816                xmax: 0.23529413,
2817                ymax: 0.23529413,
2818            },
2819            score: 0.98823535,
2820            label: 2,
2821        }]
2822    }
2823
2824    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2825        [DetectBox {
2826            bbox: BoundingBox {
2827                xmin: 0.1234,
2828                ymin: 0.1234,
2829                xmax: 0.2345,
2830                ymax: 0.2345,
2831            },
2832            score: 0.9876,
2833            label: 2,
2834        }]
2835    }
2836
2837    macro_rules! real_data_proto_test {
2838        ($name:ident, quantized, $layout:ident) => {
2839            #[test]
2840            fn $name() {
2841                let is_split = matches!(stringify!($layout), "split");
2842
2843                let score_threshold = 0.45;
2844                let iou_threshold = 0.45;
2845                let quant_boxes = (0.021287762_f32, 31_i32);
2846                let quant_protos = (0.02491162_f32, -117_i32);
2847
2848                let raw_boxes = include_bytes!(concat!(
2849                    env!("CARGO_MANIFEST_DIR"),
2850                    "/../../testdata/yolov8_boxes_116x8400.bin"
2851                ));
2852                let raw_boxes = unsafe {
2853                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2854                };
2855                let boxes_i8 =
2856                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2857
2858                let raw_protos = include_bytes!(concat!(
2859                    env!("CARGO_MANIFEST_DIR"),
2860                    "/../../testdata/yolov8_protos_160x160x32.bin"
2861                ));
2862                let raw_protos = unsafe {
2863                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2864                };
2865                let protos_i8 =
2866                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2867                        .unwrap();
2868
2869                // Pre-split (unused for combined, but harmless)
2870                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2871                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2872                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2873                let boxes_combined = boxes_i8;
2874
2875                let decoder = if is_split {
2876                    build_yolo_split_segdet_decoder(
2877                        score_threshold,
2878                        iou_threshold,
2879                        quant_boxes,
2880                        quant_protos,
2881                    )
2882                } else {
2883                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2884                };
2885
2886                let expected = real_data_expected_boxes();
2887                let mut output_boxes = Vec::with_capacity(50);
2888
2889                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
2890                    vec![
2891                        boxes_split.view().into(),
2892                        scores_split.view().into(),
2893                        mask_split.view().into(),
2894                        protos_i8.view().into(),
2895                    ]
2896                } else {
2897                    vec![boxes_combined.view().into(), protos_i8.view().into()]
2898                };
2899                decoder
2900                    .decode_quantized_proto(&inputs, &mut output_boxes)
2901                    .unwrap();
2902
2903                assert_eq!(output_boxes.len(), 2);
2904                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2905                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2906            }
2907        };
2908        ($name:ident, float, $layout:ident) => {
2909            #[test]
2910            fn $name() {
2911                let is_split = matches!(stringify!($layout), "split");
2912
2913                let score_threshold = 0.45;
2914                let iou_threshold = 0.45;
2915                let quant_boxes = (0.021287762_f32, 31_i32);
2916                let quant_protos = (0.02491162_f32, -117_i32);
2917
2918                let raw_boxes = include_bytes!(concat!(
2919                    env!("CARGO_MANIFEST_DIR"),
2920                    "/../../testdata/yolov8_boxes_116x8400.bin"
2921                ));
2922                let raw_boxes = unsafe {
2923                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2924                };
2925                let boxes_i8 =
2926                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2927                let boxes_f32: Array3<f32> =
2928                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
2929
2930                let raw_protos = include_bytes!(concat!(
2931                    env!("CARGO_MANIFEST_DIR"),
2932                    "/../../testdata/yolov8_protos_160x160x32.bin"
2933                ));
2934                let raw_protos = unsafe {
2935                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2936                };
2937                let protos_i8 =
2938                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2939                        .unwrap();
2940                let protos_f32: Array4<f32> =
2941                    dequantize_ndarray(protos_i8.view(), quant_protos.into());
2942
2943                // Pre-split from dequantized data
2944                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
2945                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
2946                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
2947                let boxes_combined = boxes_f32;
2948
2949                let decoder = if is_split {
2950                    build_yolo_split_segdet_decoder(
2951                        score_threshold,
2952                        iou_threshold,
2953                        quant_boxes,
2954                        quant_protos,
2955                    )
2956                } else {
2957                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2958                };
2959
2960                let expected = real_data_expected_boxes();
2961                let mut output_boxes = Vec::with_capacity(50);
2962
2963                let inputs = if is_split {
2964                    vec![
2965                        boxes_split.view().into_dyn(),
2966                        scores_split.view().into_dyn(),
2967                        mask_split.view().into_dyn(),
2968                        protos_f32.view().into_dyn(),
2969                    ]
2970                } else {
2971                    vec![
2972                        boxes_combined.view().into_dyn(),
2973                        protos_f32.view().into_dyn(),
2974                    ]
2975                };
2976                decoder
2977                    .decode_float_proto(&inputs, &mut output_boxes)
2978                    .unwrap();
2979
2980                assert_eq!(output_boxes.len(), 2);
2981                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2982                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2983            }
2984        };
2985    }
2986
2987    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
2988    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
2989    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
2990    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
2991
    /// YAML decoder config for a YOLO26 end-to-end model with one combined
    /// detection output of shape `[1, 10, 6]` (10 boxes x 6 features),
    /// quantized with scale 2/255 (0.00784313725490196) and zero point 0.
    /// Boxes are declared normalized.
    // NOTE(review): the 6 features presumably follow the segdet layout minus
    // mask coefficients (4 box coords + score + class). Confirm.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3005
    /// YAML decoder config for a YOLO26 end-to-end segmentation model with a
    /// combined detection output `[1, 10, 38]` (4 box coords + 1 score +
    /// 1 class + 32 mask coefficients per box, scale 2/255, zero point 0)
    /// and an NHWC protos output `[1, 160, 160, 32]` (scale 1/255,
    /// zero point 128).
    // NOTE(review): the `e2e_segdet_test!` cases quantize their all-zero
    // protos with zero point 0, not the 128 declared here. Confirm this
    // mismatch is intended (those cases appear to check boxes only).
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3028
    /// YAML decoder config for a YOLO26 end-to-end detection model whose
    /// outputs are split into three tensors: boxes `[1, 10, 4]` (normalized),
    /// scores `[1, 10, 1]` and classes `[1, 10, 1]`. All three are quantized
    /// with scale 2/255 and zero point 0.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3058
    /// YAML decoder config for a YOLO26 end-to-end segmentation model with
    /// five split outputs: boxes `[1, 10, 4]` (normalized), scores
    /// `[1, 10, 1]`, classes `[1, 10, 1]` and mask coefficients `[1, 10, 32]`,
    /// all at scale 2/255 and zero point 0, plus NHWC protos
    /// `[1, 160, 160, 32]` at scale 1/255 and zero point 128.
    // NOTE(review): the `e2e_segdet_test!` cases quantize their all-zero
    // protos with zero point 0, not the 128 declared here. Confirm this
    // mismatch is intended (those cases appear to check boxes only).
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3105
3106    macro_rules! e2e_segdet_test {
3107        ($name:ident, quantized, $layout:ident, $output:ident) => {
3108            #[test]
3109            fn $name() {
3110                let is_split = matches!(stringify!($layout), "split");
3111                let is_proto = matches!(stringify!($output), "proto");
3112
3113                let score_threshold = 0.45;
3114                let iou_threshold = 0.45;
3115
3116                let mut boxes = Array2::zeros((10, 4));
3117                let mut scores = Array2::zeros((10, 1));
3118                let mut classes = Array2::zeros((10, 1));
3119                let mask = Array2::zeros((10, 32));
3120                let protos = Array3::<f64>::zeros((160, 160, 32));
3121                let protos = protos.insert_axis(Axis(0));
3122                let protos_quant = (1.0 / 255.0, 0.0);
3123                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3124
3125                boxes
3126                    .slice_mut(s![0, ..])
3127                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3128                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3129                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3130
3131                let detect_quant = (2.0 / 255.0, 0.0);
3132
3133                let decoder = if is_split {
3134                    DecoderBuilder::default()
3135                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3136                        .with_score_threshold(score_threshold)
3137                        .with_iou_threshold(iou_threshold)
3138                        .build()
3139                        .unwrap()
3140                } else {
3141                    DecoderBuilder::default()
3142                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3143                        .with_score_threshold(score_threshold)
3144                        .with_iou_threshold(iou_threshold)
3145                        .build()
3146                        .unwrap()
3147                };
3148
3149                let expected = e2e_expected_boxes_quant();
3150                let mut output_boxes = Vec::with_capacity(50);
3151
3152                if is_split {
3153                    let boxes = boxes.insert_axis(Axis(0));
3154                    let scores = scores.insert_axis(Axis(0));
3155                    let classes = classes.insert_axis(Axis(0));
3156                    let mask = mask.insert_axis(Axis(0));
3157
3158                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3159                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3160                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3161                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3162
3163                    if is_proto {
3164                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3165                            boxes.view().into(),
3166                            scores.view().into(),
3167                            classes.view().into(),
3168                            mask.view().into(),
3169                            protos.view().into(),
3170                        ];
3171                        decoder
3172                            .decode_quantized_proto(&inputs, &mut output_boxes)
3173                            .unwrap();
3174
3175                        assert_eq!(output_boxes.len(), 1);
3176                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3177                    } else {
3178                        let mut output_masks = Vec::with_capacity(50);
3179                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3180                            boxes.view().into(),
3181                            scores.view().into(),
3182                            classes.view().into(),
3183                            mask.view().into(),
3184                            protos.view().into(),
3185                        ];
3186                        decoder
3187                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3188                            .unwrap();
3189
3190                        assert_eq!(output_boxes.len(), 1);
3191                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3192                    }
3193                } else {
3194                    // Combined layout
3195                    let detect = ndarray::concatenate![
3196                        Axis(1),
3197                        boxes.view(),
3198                        scores.view(),
3199                        classes.view(),
3200                        mask.view()
3201                    ];
3202                    let detect = detect.insert_axis(Axis(0));
3203                    assert_eq!(detect.shape(), &[1, 10, 38]);
3204                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3205
3206                    if is_proto {
3207                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3208                            vec![detect.view().into(), protos.view().into()];
3209                        decoder
3210                            .decode_quantized_proto(&inputs, &mut output_boxes)
3211                            .unwrap();
3212
3213                        assert_eq!(output_boxes.len(), 1);
3214                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3215                    } else {
3216                        let mut output_masks = Vec::with_capacity(50);
3217                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3218                            vec![detect.view().into(), protos.view().into()];
3219                        decoder
3220                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3221                            .unwrap();
3222
3223                        assert_eq!(output_boxes.len(), 1);
3224                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3225                    }
3226                }
3227            }
3228        };
3229        ($name:ident, float, $layout:ident, $output:ident) => {
3230            #[test]
3231            fn $name() {
3232                let is_split = matches!(stringify!($layout), "split");
3233                let is_proto = matches!(stringify!($output), "proto");
3234
3235                let score_threshold = 0.45;
3236                let iou_threshold = 0.45;
3237
3238                let mut boxes = Array2::zeros((10, 4));
3239                let mut scores = Array2::zeros((10, 1));
3240                let mut classes = Array2::zeros((10, 1));
3241                let mask: Array2<f64> = Array2::zeros((10, 32));
3242                let protos = Array3::<f64>::zeros((160, 160, 32));
3243                let protos = protos.insert_axis(Axis(0));
3244
3245                boxes
3246                    .slice_mut(s![0, ..])
3247                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3248                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3249                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3250
3251                let decoder = if is_split {
3252                    DecoderBuilder::default()
3253                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3254                        .with_score_threshold(score_threshold)
3255                        .with_iou_threshold(iou_threshold)
3256                        .build()
3257                        .unwrap()
3258                } else {
3259                    DecoderBuilder::default()
3260                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3261                        .with_score_threshold(score_threshold)
3262                        .with_iou_threshold(iou_threshold)
3263                        .build()
3264                        .unwrap()
3265                };
3266
3267                let expected = e2e_expected_boxes_float();
3268                let mut output_boxes = Vec::with_capacity(50);
3269
3270                if is_split {
3271                    let boxes = boxes.insert_axis(Axis(0));
3272                    let scores = scores.insert_axis(Axis(0));
3273                    let classes = classes.insert_axis(Axis(0));
3274                    let mask = mask.insert_axis(Axis(0));
3275
3276                    if is_proto {
3277                        let inputs = vec![
3278                            boxes.view().into_dyn(),
3279                            scores.view().into_dyn(),
3280                            classes.view().into_dyn(),
3281                            mask.view().into_dyn(),
3282                            protos.view().into_dyn(),
3283                        ];
3284                        decoder
3285                            .decode_float_proto(&inputs, &mut output_boxes)
3286                            .unwrap();
3287
3288                        assert_eq!(output_boxes.len(), 1);
3289                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3290                    } else {
3291                        let mut output_masks = Vec::with_capacity(50);
3292                        let inputs = vec![
3293                            boxes.view().into_dyn(),
3294                            scores.view().into_dyn(),
3295                            classes.view().into_dyn(),
3296                            mask.view().into_dyn(),
3297                            protos.view().into_dyn(),
3298                        ];
3299                        decoder
3300                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3301                            .unwrap();
3302
3303                        assert_eq!(output_boxes.len(), 1);
3304                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3305                    }
3306                } else {
3307                    // Combined layout
3308                    let detect = ndarray::concatenate![
3309                        Axis(1),
3310                        boxes.view(),
3311                        scores.view(),
3312                        classes.view(),
3313                        mask.view()
3314                    ];
3315                    let detect = detect.insert_axis(Axis(0));
3316                    assert_eq!(detect.shape(), &[1, 10, 38]);
3317
3318                    if is_proto {
3319                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3320                        decoder
3321                            .decode_float_proto(&inputs, &mut output_boxes)
3322                            .unwrap();
3323
3324                        assert_eq!(output_boxes.len(), 1);
3325                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3326                    } else {
3327                        let mut output_masks = Vec::with_capacity(50);
3328                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3329                        decoder
3330                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3331                            .unwrap();
3332
3333                        assert_eq!(output_boxes.len(), 1);
3334                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3335                    }
3336                }
3337            }
3338        };
3339    }
3340
3341    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3342    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3343    e2e_segdet_test!(
3344        test_decoder_end_to_end_segdet_proto,
3345        quantized,
3346        combined,
3347        proto
3348    );
3349    e2e_segdet_test!(
3350        test_decoder_end_to_end_segdet_proto_float,
3351        float,
3352        combined,
3353        proto
3354    );
3355    e2e_segdet_test!(
3356        test_decoder_end_to_end_segdet_split,
3357        quantized,
3358        split,
3359        masks
3360    );
3361    e2e_segdet_test!(
3362        test_decoder_end_to_end_segdet_split_float,
3363        float,
3364        split,
3365        masks
3366    );
3367    e2e_segdet_test!(
3368        test_decoder_end_to_end_segdet_split_proto,
3369        quantized,
3370        split,
3371        proto
3372    );
3373    e2e_segdet_test!(
3374        test_decoder_end_to_end_segdet_split_proto_float,
3375        float,
3376        split,
3377        proto
3378    );
3379
3380    macro_rules! e2e_det_test {
3381        ($name:ident, quantized, $layout:ident) => {
3382            #[test]
3383            fn $name() {
3384                let is_split = matches!(stringify!($layout), "split");
3385
3386                let score_threshold = 0.45;
3387                let iou_threshold = 0.45;
3388
3389                let mut boxes = Array3::zeros((1, 10, 4));
3390                let mut scores = Array3::zeros((1, 10, 1));
3391                let mut classes = Array3::zeros((1, 10, 1));
3392
3393                boxes
3394                    .slice_mut(s![0, 0, ..])
3395                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3396                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3397                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3398
3399                let detect_quant = (2.0 / 255.0, 0_i32);
3400
3401                let decoder = if is_split {
3402                    DecoderBuilder::default()
3403                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3404                        .with_score_threshold(score_threshold)
3405                        .with_iou_threshold(iou_threshold)
3406                        .build()
3407                        .unwrap()
3408                } else {
3409                    DecoderBuilder::default()
3410                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3411                        .with_score_threshold(score_threshold)
3412                        .with_iou_threshold(iou_threshold)
3413                        .build()
3414                        .unwrap()
3415                };
3416
3417                let expected = e2e_expected_boxes_quant();
3418                let mut output_boxes = Vec::with_capacity(50);
3419
3420                if is_split {
3421                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3422                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3423                    let classes: Array<u8, _> =
3424                        quantize_ndarray(classes.view(), detect_quant.into());
3425                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3426                        boxes.view().into(),
3427                        scores.view().into(),
3428                        classes.view().into(),
3429                    ];
3430                    decoder
3431                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3432                        .unwrap();
3433                } else {
3434                    let detect =
3435                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3436                    assert_eq!(detect.shape(), &[1, 10, 6]);
3437                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3438                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3439                        vec![detect.view().into()];
3440                    decoder
3441                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3442                        .unwrap();
3443                }
3444
3445                assert_eq!(output_boxes.len(), 1);
3446                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3447            }
3448        };
3449        ($name:ident, float, $layout:ident) => {
3450            #[test]
3451            fn $name() {
3452                let is_split = matches!(stringify!($layout), "split");
3453
3454                let score_threshold = 0.45;
3455                let iou_threshold = 0.45;
3456
3457                let mut boxes = Array3::zeros((1, 10, 4));
3458                let mut scores = Array3::zeros((1, 10, 1));
3459                let mut classes = Array3::zeros((1, 10, 1));
3460
3461                boxes
3462                    .slice_mut(s![0, 0, ..])
3463                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3464                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3465                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3466
3467                let decoder = if is_split {
3468                    DecoderBuilder::default()
3469                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3470                        .with_score_threshold(score_threshold)
3471                        .with_iou_threshold(iou_threshold)
3472                        .build()
3473                        .unwrap()
3474                } else {
3475                    DecoderBuilder::default()
3476                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3477                        .with_score_threshold(score_threshold)
3478                        .with_iou_threshold(iou_threshold)
3479                        .build()
3480                        .unwrap()
3481                };
3482
3483                let expected = e2e_expected_boxes_float();
3484                let mut output_boxes = Vec::with_capacity(50);
3485
3486                if is_split {
3487                    let inputs = vec![
3488                        boxes.view().into_dyn(),
3489                        scores.view().into_dyn(),
3490                        classes.view().into_dyn(),
3491                    ];
3492                    decoder
3493                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3494                        .unwrap();
3495                } else {
3496                    let detect =
3497                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3498                    assert_eq!(detect.shape(), &[1, 10, 6]);
3499                    let inputs = vec![detect.view().into_dyn()];
3500                    decoder
3501                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3502                        .unwrap();
3503                }
3504
3505                assert_eq!(output_boxes.len(), 1);
3506                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3507            }
3508        };
3509    }
3510
3511    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3512    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3513    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3514    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3515
3516    #[test]
3517    fn test_decode_tensor() {
3518        let score_threshold = 0.45;
3519        let iou_threshold = 0.45;
3520
3521        let raw_boxes = include_bytes!(concat!(
3522            env!("CARGO_MANIFEST_DIR"),
3523            "/../../testdata/yolov8_boxes_116x8400.bin"
3524        ));
3525        let raw_boxes =
3526            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3527        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3528        boxes_i8
3529            .map()
3530            .unwrap()
3531            .as_mut_slice()
3532            .copy_from_slice(raw_boxes);
3533        let boxes_i8 = boxes_i8.into();
3534
3535        let raw_protos = include_bytes!(concat!(
3536            env!("CARGO_MANIFEST_DIR"),
3537            "/../../testdata/yolov8_protos_160x160x32.bin"
3538        ));
3539        let raw_protos = unsafe {
3540            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3541        };
3542        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3543        protos_i8
3544            .map()
3545            .unwrap()
3546            .as_mut_slice()
3547            .copy_from_slice(raw_protos);
3548        let protos_i8 = protos_i8.into();
3549
3550        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3551        let expected = real_data_expected_boxes();
3552        let mut output_boxes = Vec::with_capacity(50);
3553
3554        decoder
3555            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3556            .unwrap();
3557
3558        assert_eq!(output_boxes.len(), 2);
3559        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3560        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3561    }
3562
3563    #[test]
3564    fn test_decode_tensor_f32() {
3565        let score_threshold = 0.45;
3566        let iou_threshold = 0.45;
3567
3568        let quant_boxes = (0.021287762_f32, 31_i32);
3569        let quant_protos = (0.02491162_f32, -117_i32);
3570        let raw_boxes = include_bytes!(concat!(
3571            env!("CARGO_MANIFEST_DIR"),
3572            "/../../testdata/yolov8_boxes_116x8400.bin"
3573        ));
3574        let raw_boxes =
3575            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3576        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3577        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3578        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3579        boxes_f32
3580            .map()
3581            .unwrap()
3582            .as_mut_slice()
3583            .copy_from_slice(&raw_boxes_f32);
3584        let boxes_f32 = boxes_f32.into();
3585
3586        let raw_protos = include_bytes!(concat!(
3587            env!("CARGO_MANIFEST_DIR"),
3588            "/../../testdata/yolov8_protos_160x160x32.bin"
3589        ));
3590        let raw_protos = unsafe {
3591            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3592        };
3593        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3594        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3595        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3596        protos_f32
3597            .map()
3598            .unwrap()
3599            .as_mut_slice()
3600            .copy_from_slice(&raw_protos_f32);
3601        let protos_f32 = protos_f32.into();
3602
3603        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3604
3605        let expected = real_data_expected_boxes();
3606        let mut output_boxes = Vec::with_capacity(50);
3607
3608        decoder
3609            .decode(
3610                &[&boxes_f32, &protos_f32],
3611                &mut output_boxes,
3612                &mut Vec::new(),
3613            )
3614            .unwrap();
3615
3616        assert_eq!(output_boxes.len(), 2);
3617        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3618        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3619    }
3620
3621    #[test]
3622    fn test_decode_tensor_f64() {
3623        let score_threshold = 0.45;
3624        let iou_threshold = 0.45;
3625
3626        let quant_boxes = (0.021287762_f32, 31_i32);
3627        let quant_protos = (0.02491162_f32, -117_i32);
3628        let raw_boxes = include_bytes!(concat!(
3629            env!("CARGO_MANIFEST_DIR"),
3630            "/../../testdata/yolov8_boxes_116x8400.bin"
3631        ));
3632        let raw_boxes =
3633            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3634        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3635        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3636        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3637        boxes_f64
3638            .map()
3639            .unwrap()
3640            .as_mut_slice()
3641            .copy_from_slice(&raw_boxes_f64);
3642        let boxes_f64 = boxes_f64.into();
3643
3644        let raw_protos = include_bytes!(concat!(
3645            env!("CARGO_MANIFEST_DIR"),
3646            "/../../testdata/yolov8_protos_160x160x32.bin"
3647        ));
3648        let raw_protos = unsafe {
3649            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3650        };
3651        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3652        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3653        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3654        protos_f64
3655            .map()
3656            .unwrap()
3657            .as_mut_slice()
3658            .copy_from_slice(&raw_protos_f64);
3659        let protos_f64 = protos_f64.into();
3660
3661        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3662
3663        let expected = real_data_expected_boxes();
3664        let mut output_boxes = Vec::with_capacity(50);
3665
3666        decoder
3667            .decode(
3668                &[&boxes_f64, &protos_f64],
3669                &mut output_boxes,
3670                &mut Vec::new(),
3671            )
3672            .unwrap();
3673
3674        assert_eq!(output_boxes.len(), 2);
3675        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3676        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3677    }
3678
3679    #[test]
3680    fn test_decode_tensor_proto() {
3681        let score_threshold = 0.45;
3682        let iou_threshold = 0.45;
3683
3684        let raw_boxes = include_bytes!(concat!(
3685            env!("CARGO_MANIFEST_DIR"),
3686            "/../../testdata/yolov8_boxes_116x8400.bin"
3687        ));
3688        let raw_boxes =
3689            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3690        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3691        boxes_i8
3692            .map()
3693            .unwrap()
3694            .as_mut_slice()
3695            .copy_from_slice(raw_boxes);
3696        let boxes_i8 = boxes_i8.into();
3697
3698        let raw_protos = include_bytes!(concat!(
3699            env!("CARGO_MANIFEST_DIR"),
3700            "/../../testdata/yolov8_protos_160x160x32.bin"
3701        ));
3702        let raw_protos = unsafe {
3703            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3704        };
3705        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3706        protos_i8
3707            .map()
3708            .unwrap()
3709            .as_mut_slice()
3710            .copy_from_slice(raw_protos);
3711        let protos_i8 = protos_i8.into();
3712
3713        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3714
3715        let expected = real_data_expected_boxes();
3716        let mut output_boxes = Vec::with_capacity(50);
3717
3718        let proto_data = decoder
3719            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3720            .unwrap();
3721
3722        assert_eq!(output_boxes.len(), 2);
3723        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3724        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3725
3726        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3727        let coeffs_shape = proto_data.mask_coefficients.shape();
3728        assert_eq!(
3729            coeffs_shape[0],
3730            output_boxes.len(),
3731            "mask_coefficients count must match detection count"
3732        );
3733        assert_eq!(
3734            coeffs_shape[1], 32,
3735            "each detection should have 32 mask coefficients"
3736        );
3737    }
3738
3739    // =========================================================================
3740    // Physical-order contract regression tests
3741    //
3742    // These cover the TFLite VX-delegate vertical-stripe mask bug and the
3743    // Ara-2-via-overlay anchor-first split-tensor bug. The core claim:
3744    // declaring shape+dshape in physical memory order produces correct
3745    // strides at wrap time, and `swap_axes_if_needed` permutes the stride
3746    // tuple into canonical logical order without touching bytes.
3747    // =========================================================================
3748
3749    /// TFLite YOLOv8-seg protos arrive as physically-NHWC bytes with
3750    /// NNStreamer dim string `"32:160:160:1"` (innermost-first). The
3751    /// overlay's 2026-04-22 workaround declares `[1, 160, 160, 32]` with
3752    /// dshape `[batch, height, width, num_protos]` — shape matches
3753    /// physical order, so no permutation is needed and the view is
3754    /// already in canonical HWC after the batch-dim slice. This test
3755    /// confirms that path matches the reference (direct-HWC) decode.
3756    #[test]
3757    fn test_physical_order_tflite_nhwc_protos() {
3758        let score_threshold = 0.45;
3759        let iou_threshold = 0.45;
3760
3761        // Load HWC protos and dequantize to f32.
3762        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
3763        let quant_protos = Quantization::new(0.02491161972284317, -117);
3764        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
3765
3766        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
3767        let quant_boxes = Quantization::new(0.021287761628627777, 31);
3768        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
3769
3770        // Reference: direct call with canonical HWC protos.
3771        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
3772        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
3773        decode_yolo_segdet_float(
3774            seg.view(),
3775            protos_f32_hwc.view(),
3776            score_threshold,
3777            iou_threshold,
3778            Some(configs::Nms::ClassAgnostic),
3779            &mut ref_boxes,
3780            &mut ref_masks,
3781        )
3782        .unwrap();
3783
3784        // Build the same protos as a 4D NHWC tensor and feed it through
3785        // the config-driven path with dshape in physical order.
3786        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0)); // [1, 160, 160, 32]
3787        let seg_3d = seg.insert_axis(Axis(0)); // [1, 116, 8400]
3788
3789        let decoder = DecoderBuilder::default()
3790            .with_config_yolo_segdet(
3791                configs::Detection {
3792                    decoder: configs::DecoderType::Ultralytics,
3793                    quantization: None,
3794                    shape: vec![1, 116, 8400],
3795                    dshape: vec![
3796                        (DimName::Batch, 1),
3797                        (DimName::NumFeatures, 116),
3798                        (DimName::NumBoxes, 8400),
3799                    ],
3800                    normalized: Some(true),
3801                    anchors: None,
3802                },
3803                configs::Protos {
3804                    decoder: configs::DecoderType::Ultralytics,
3805                    quantization: None,
3806                    shape: vec![1, 160, 160, 32],
3807                    // Physical NHWC — matches the TFLite buffer layout.
3808                    dshape: vec![
3809                        (DimName::Batch, 1),
3810                        (DimName::Height, 160),
3811                        (DimName::Width, 160),
3812                        (DimName::NumProtos, 32),
3813                    ],
3814                },
3815                None,
3816            )
3817            .with_score_threshold(score_threshold)
3818            .with_iou_threshold(iou_threshold)
3819            .build()
3820            .expect("config with NHWC protos dshape must build");
3821
3822        let mut cfg_boxes = Vec::with_capacity(10);
3823        let mut cfg_masks = Vec::with_capacity(10);
3824        decoder
3825            .decode_float(
3826                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
3827                &mut cfg_boxes,
3828                &mut cfg_masks,
3829            )
3830            .unwrap();
3831
3832        assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
3833        for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
3834            assert!(
3835                c.equal_within_delta(r, 0.01),
3836                "NHWC-declared box does not match reference: {c:?} vs {r:?}"
3837            );
3838        }
3839        for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
3840            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
3841            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
3842            assert_eq!(
3843                cm_arr, rm_arr,
3844                "NHWC-declared mask must match reference pixel-for-pixel"
3845            );
3846        }
3847    }
3848
3849    /// Ara-2 split-tensor boxes arrive physically anchor-first: NNStreamer
3850    /// dim string `"4:1:8400:1"` means innermost axis is 4 (coords), next
3851    /// is 1 (padding), next is 8400 (anchors), outermost is 1 (batch).
3852    /// The correct physical-order declaration is `[1, 8400, 1, 4]` with
3853    /// dshape `[batch, num_boxes, padding, box_coords]`. This test
3854    /// confirms that declaration produces the same decoded boxes as the
3855    /// canonical features-first TFLite declaration over the same logical
3856    /// data.
3857    #[test]
3858    fn test_physical_order_ara2_anchor_first_split_boxes() {
3859        use configs::{Boxes, Scores};
3860
3861        // Build synthetic boxes data in features-first canonical layout:
3862        // shape [1, 4, 8400] with one detection at anchor 42.
3863        const N: usize = 8400;
3864        let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
3865        let target_anchor = 42usize;
3866        boxes_canonical[[0, 0, target_anchor]] = 0.4; // xc
3867        boxes_canonical[[0, 1, target_anchor]] = 0.5; // yc
3868        boxes_canonical[[0, 2, target_anchor]] = 0.2; // w
3869        boxes_canonical[[0, 3, target_anchor]] = 0.2; // h
3870
3871        // Scores: one class at 0.9 for the same anchor.
3872        let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
3873        scores_canonical[[0, 0, target_anchor]] = 0.9;
3874
3875        // Reference: canonical (features-first) shape declaration.
3876        let ref_decoder = DecoderBuilder::default()
3877            .with_config_yolo_split_det(
3878                Boxes {
3879                    decoder: configs::DecoderType::Ultralytics,
3880                    quantization: None,
3881                    shape: vec![1, 4, N],
3882                    dshape: vec![
3883                        (DimName::Batch, 1),
3884                        (DimName::BoxCoords, 4),
3885                        (DimName::NumBoxes, N),
3886                    ],
3887                    normalized: Some(true),
3888                },
3889                Scores {
3890                    decoder: configs::DecoderType::Ultralytics,
3891                    quantization: None,
3892                    shape: vec![1, 80, N],
3893                    dshape: vec![
3894                        (DimName::Batch, 1),
3895                        (DimName::NumClasses, 80),
3896                        (DimName::NumBoxes, N),
3897                    ],
3898                },
3899            )
3900            .with_score_threshold(0.5)
3901            .with_iou_threshold(0.5)
3902            .with_nms(Some(configs::Nms::ClassAgnostic))
3903            .build()
3904            .expect("reference canonical split decoder must build");
3905
3906        let mut ref_boxes = Vec::with_capacity(4);
3907        let mut ref_masks = Vec::with_capacity(0);
3908        ref_decoder
3909            .decode_float(
3910                &[
3911                    boxes_canonical.view().into_dyn(),
3912                    scores_canonical.view().into_dyn(),
3913                ],
3914                &mut ref_boxes,
3915                &mut ref_masks,
3916            )
3917            .unwrap();
3918        assert_eq!(ref_boxes.len(), 1, "reference should produce one box");
3919
3920        // Ara-2 physical layout: transpose axes 1↔2 so the innermost
3921        // axis is BoxCoords / NumClasses. Materialise to get a
3922        // C-contiguous Array3 with strides matching the physical order
3923        // the Ara-2 backend would write into the tensor.
3924        let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 4]
3925        let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 80]
3926
3927        let ara2_decoder = DecoderBuilder::default()
3928            .with_config_yolo_split_det(
3929                Boxes {
3930                    decoder: configs::DecoderType::Ultralytics,
3931                    quantization: None,
3932                    shape: vec![1, N, 4],
3933                    dshape: vec![
3934                        (DimName::Batch, 1),
3935                        (DimName::NumBoxes, N),
3936                        (DimName::BoxCoords, 4),
3937                    ],
3938                    normalized: Some(true),
3939                },
3940                Scores {
3941                    decoder: configs::DecoderType::Ultralytics,
3942                    quantization: None,
3943                    shape: vec![1, N, 80],
3944                    dshape: vec![
3945                        (DimName::Batch, 1),
3946                        (DimName::NumBoxes, N),
3947                        (DimName::NumClasses, 80),
3948                    ],
3949                },
3950            )
3951            .with_score_threshold(0.5)
3952            .with_iou_threshold(0.5)
3953            .with_nms(Some(configs::Nms::ClassAgnostic))
3954            .build()
3955            .expect("Ara-2 anchor-first decoder must build");
3956
3957        let mut ara2_boxes = Vec::with_capacity(4);
3958        let mut ara2_masks = Vec::with_capacity(0);
3959        ara2_decoder
3960            .decode_float(
3961                &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
3962                &mut ara2_boxes,
3963                &mut ara2_masks,
3964            )
3965            .unwrap();
3966
3967        assert_eq!(
3968            ara2_boxes.len(),
3969            ref_boxes.len(),
3970            "Ara-2 anchor-first declaration must produce the same number \
3971             of boxes as the canonical features-first reference"
3972        );
3973        for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
3974            assert!(
3975                a.equal_within_delta(r, 1e-4),
3976                "Ara-2 box differs from reference: {a:?} vs {r:?}"
3977            );
3978        }
3979    }
3980
3981    /// Dshape whose per-axis sizes disagree with shape (the exact
3982    /// failure mode that caused the TFLite stripe bug — shape declared
3983    /// in one order, dshape in another) is rejected with a clear error.
3984    #[test]
3985    fn test_physical_order_rejects_shape_dshape_mismatch() {
3986        let result = DecoderBuilder::default()
3987            .with_config_yolo_segdet(
3988                configs::Detection {
3989                    decoder: configs::DecoderType::Ultralytics,
3990                    quantization: None,
3991                    shape: vec![1, 116, 8400],
3992                    dshape: vec![
3993                        (DimName::Batch, 1),
3994                        (DimName::NumFeatures, 116),
3995                        (DimName::NumBoxes, 8400),
3996                    ],
3997                    normalized: Some(true),
3998                    anchors: None,
3999                },
4000                configs::Protos {
4001                    decoder: configs::DecoderType::Ultralytics,
4002                    quantization: None,
4003                    // Shape in NCHW order...
4004                    shape: vec![1, 32, 160, 160],
4005                    // ...but dshape in NHWC order — the sizes don't line
4006                    // up positionally (dshape[1]=160 vs shape[1]=32).
4007                    dshape: vec![
4008                        (DimName::Batch, 1),
4009                        (DimName::Height, 160),
4010                        (DimName::Width, 160),
4011                        (DimName::NumProtos, 32),
4012                    ],
4013                },
4014                None,
4015            )
4016            .build();
4017
4018        match result {
4019            Err(DecoderError::InvalidConfig(msg)) => {
4020                assert!(
4021                    msg.contains("does not match shape"),
4022                    "expected shape/dshape size mismatch error, got: {msg}"
4023                );
4024            }
4025            other => panic!("expected InvalidConfig, got {other:?}"),
4026        }
4027    }
4028
4029    /// Duplicate dim name in dshape is rejected (two axes mapping to the
4030    /// same canonical slot would break the permutation).
4031    #[test]
4032    fn test_physical_order_rejects_duplicate_dshape_axis() {
4033        let result = DecoderBuilder::default()
4034            .with_config_yolo_split_det(
4035                configs::Boxes {
4036                    decoder: configs::DecoderType::Ultralytics,
4037                    quantization: None,
4038                    shape: vec![1, 4, 8400],
4039                    dshape: vec![
4040                        (DimName::Batch, 1),
4041                        (DimName::BoxCoords, 4),
4042                        (DimName::BoxCoords, 4), // duplicate (size matches shape[2]=8400? no)
4043                    ],
4044                    normalized: Some(true),
4045                },
4046                configs::Scores {
4047                    decoder: configs::DecoderType::Ultralytics,
4048                    quantization: None,
4049                    shape: vec![1, 80, 8400],
4050                    dshape: vec![
4051                        (DimName::Batch, 1),
4052                        (DimName::NumClasses, 80),
4053                        (DimName::NumBoxes, 8400),
4054                    ],
4055                },
4056            )
4057            .build();
4058
4059        // Shape[2] = 8400 ≠ dshape[2] size 4, so the positional-mismatch
4060        // check fires first — which is fine: any mis-shaped dshape
4061        // should be rejected. Check for either error as long as it's a
4062        // clear InvalidConfig.
4063        match result {
4064            Err(DecoderError::InvalidConfig(msg)) => {
4065                assert!(
4066                    msg.contains("appears at both index") || msg.contains("does not match shape"),
4067                    "expected positional or duplicate-axis error, got: {msg}"
4068                );
4069            }
4070            other => panic!("expected InvalidConfig, got {other:?}"),
4071        }
4072
4073        // Separate case: size-consistent duplicate to exercise the
4074        // duplicate-axis code path explicitly. Use `Batch` (no size
4075        // constraint, can repeat with size 1 against a shape that
4076        // carries two singleton dims).
4077        let result = DecoderBuilder::default()
4078            .with_config_yolo_split_det(
4079                configs::Boxes {
4080                    decoder: configs::DecoderType::Ultralytics,
4081                    quantization: None,
4082                    shape: vec![1, 1, 4, 8400],
4083                    dshape: vec![
4084                        (DimName::Batch, 1),
4085                        (DimName::Batch, 1), // duplicate, sizes match
4086                        (DimName::BoxCoords, 4),
4087                        (DimName::NumBoxes, 8400),
4088                    ],
4089                    normalized: Some(true),
4090                },
4091                configs::Scores {
4092                    decoder: configs::DecoderType::Ultralytics,
4093                    quantization: None,
4094                    shape: vec![1, 80, 8400],
4095                    dshape: vec![
4096                        (DimName::Batch, 1),
4097                        (DimName::NumClasses, 80),
4098                        (DimName::NumBoxes, 8400),
4099                    ],
4100                },
4101            )
4102            .build();
4103        match result {
4104            Err(DecoderError::InvalidConfig(msg)) => {
4105                assert!(
4106                    msg.contains("appears at both index"),
4107                    "expected duplicate-axis error, got: {msg}"
4108                );
4109            }
4110            other => panic!("expected InvalidConfig, got {other:?}"),
4111        }
4112    }
4113
4114    /// Canonical (dshape-omitted) high-level decode produces the same
4115    /// numeric result as the dshape-populated form. This closes a
4116    /// coverage gap flagged during review: dshape omission had been
4117    /// exercised only at the builder level, not through the full
4118    /// `decode_float` pipeline.
4119    #[test]
4120    fn test_physical_order_dshape_omitted_decodes_numerically() {
4121        let score_threshold = 0.45;
4122        let iou_threshold = 0.45;
4123
4124        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
4125        let quant_protos = Quantization::new(0.02491161972284317, -117);
4126        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
4127
4128        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
4129        let quant_boxes = Quantization::new(0.021287761628627777, 31);
4130        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
4131
4132        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0));
4133        let seg_3d = seg.insert_axis(Axis(0));
4134
4135        let build_decoder = |det_dshape: Vec<(DimName, usize)>,
4136                             proto_dshape: Vec<(DimName, usize)>| {
4137            DecoderBuilder::default()
4138                .with_config_yolo_segdet(
4139                    configs::Detection {
4140                        decoder: configs::DecoderType::Ultralytics,
4141                        quantization: None,
4142                        shape: vec![1, 116, 8400],
4143                        dshape: det_dshape,
4144                        normalized: Some(true),
4145                        anchors: None,
4146                    },
4147                    configs::Protos {
4148                        decoder: configs::DecoderType::Ultralytics,
4149                        quantization: None,
4150                        shape: vec![1, 160, 160, 32],
4151                        dshape: proto_dshape,
4152                    },
4153                    None,
4154                )
4155                .with_score_threshold(score_threshold)
4156                .with_iou_threshold(iou_threshold)
4157                .build()
4158                .unwrap()
4159        };
4160
4161        // Baseline: dshape populated in physical order.
4162        let dshaped = build_decoder(
4163            vec![
4164                (DimName::Batch, 1),
4165                (DimName::NumFeatures, 116),
4166                (DimName::NumBoxes, 8400),
4167            ],
4168            vec![
4169                (DimName::Batch, 1),
4170                (DimName::Height, 160),
4171                (DimName::Width, 160),
4172                (DimName::NumProtos, 32),
4173            ],
4174        );
4175        let mut dshaped_boxes = Vec::new();
4176        let mut dshaped_masks = Vec::new();
4177        dshaped
4178            .decode_float(
4179                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4180                &mut dshaped_boxes,
4181                &mut dshaped_masks,
4182            )
4183            .unwrap();
4184
4185        // Variant: dshape omitted — caller asserts shape is already in
4186        // the decoder's canonical order.
4187        let bare = build_decoder(vec![], vec![]);
4188        let mut bare_boxes = Vec::new();
4189        let mut bare_masks = Vec::new();
4190        bare.decode_float(
4191            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4192            &mut bare_boxes,
4193            &mut bare_masks,
4194        )
4195        .unwrap();
4196
4197        assert_eq!(bare_boxes.len(), dshaped_boxes.len());
4198        for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
4199            assert!(
4200                b.equal_within_delta(d, 1e-4),
4201                "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
4202            );
4203        }
4204        for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
4205            let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
4206            let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
4207            assert_eq!(
4208                bm_arr, dm_arr,
4209                "dshape-omitted mask must match dshape-populated pixel-for-pixel"
4210            );
4211        }
4212    }
4213
4214    /// 4D anchor-first boxes schema with a trailing `padding` axis — the
4215    /// exact Ara-2 DVM shape (`"4:1:8400:1"` innermost-first = physical
4216    /// `[1, 8400, 1, 4]`). Exercises the schema path's
4217    /// `squeeze_padding_dims`: the caller declares a 4D physical-order
4218    /// layout including `padding`, HAL squeezes the size-1 axis during
4219    /// `to_legacy_config_outputs`, and the decoder sees the resulting
4220    /// 3D shape. Callers feed HAL a 3D squeezed view of the same bytes.
4221    /// Complements the 3D `test_physical_order_ara2_anchor_first_split_boxes`
4222    /// which exercises the programmatic-builder path.
4223    #[test]
4224    fn test_physical_order_ara2_4d_anchor_first_with_padding() {
4225        // Build synthetic data in anchor-first 3D layout (the shape HAL
4226        // actually operates on after squeezing). The 4D-with-padding
4227        // declaration in the schema is a producer-side convention; the
4228        // caller is expected to present HAL with a squeezed view.
4229        const N: usize = 8400;
4230        let mut boxes = Array3::<f32>::zeros((1, N, 4));
4231        let target = 42usize;
4232        boxes[[0, target, 0]] = 0.4;
4233        boxes[[0, target, 1]] = 0.5;
4234        boxes[[0, target, 2]] = 0.2;
4235        boxes[[0, target, 3]] = 0.2;
4236        let mut scores = Array3::<f32>::zeros((1, N, 80));
4237        scores[[0, target, 0]] = 0.9;
4238
4239        // Schema JSON declares shape+dshape in physical anchor-first
4240        // order including padding — mirrors `ara2_int8_edgefirst.json`'s
4241        // feature-first declaration but with the axes in physical
4242        // anchor-first order. `to_legacy_config_outputs` squeezes
4243        // padding before the decoder's rank-3 verification.
4244        let json = r#"{
4245          "schema_version": 2,
4246          "decoder_version": "yolov8",
4247          "nms": "class_agnostic",
4248          "outputs": [
4249            {"name": "boxes", "type": "boxes",
4250             "shape": [1, 8400, 1, 4],
4251             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
4252             "encoding": "direct",
4253             "decoder": "ultralytics",
4254             "normalized": true},
4255            {"name": "scores", "type": "scores",
4256             "shape": [1, 8400, 1, 80],
4257             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
4258             "decoder": "ultralytics",
4259             "score_format": "per_class"}
4260          ]
4261        }"#;
4262        let decoder = DecoderBuilder::default()
4263            .with_config_json_str(json.to_string())
4264            .with_score_threshold(0.5)
4265            .with_iou_threshold(0.5)
4266            .build()
4267            .expect("4D anchor-first schema should build via squeeze_padding_dims");
4268
4269        let mut out_boxes = Vec::with_capacity(4);
4270        let mut out_masks = Vec::with_capacity(0);
4271        decoder
4272            .decode_float(
4273                &[boxes.view().into_dyn(), scores.view().into_dyn()],
4274                &mut out_boxes,
4275                &mut out_masks,
4276            )
4277            .unwrap();
4278
4279        assert_eq!(
4280            out_boxes.len(),
4281            1,
4282            "4D anchor-first with padding should decode exactly one box from the seeded anchor"
4283        );
4284        let b = &out_boxes[0];
4285        // xywh(0.4, 0.5, 0.2, 0.2) → xyxy(0.3, 0.4, 0.5, 0.6)
4286        assert!((b.bbox.xmin - 0.3).abs() < 1e-3, "xmin wrong: {b:?}");
4287        assert!((b.bbox.ymin - 0.4).abs() < 1e-3, "ymin wrong: {b:?}");
4288        assert!((b.bbox.xmax - 0.5).abs() < 1e-3, "xmax wrong: {b:?}");
4289        assert!((b.bbox.ymax - 0.6).abs() < 1e-3, "ymax wrong: {b:?}");
4290        assert_eq!(b.label, 0);
4291        assert!(b.score > 0.85, "score {}: {b:?}", b.score);
4292    }
4293}
4294
4295#[cfg(feature = "tracker")]
4296#[cfg(test)]
4297#[cfg_attr(coverage_nightly, coverage(off))]
4298mod decoder_tracked_tests {
4299
4300    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4301    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4302    use num_traits::{AsPrimitive, Float, PrimInt};
4303    use rand::{RngExt, SeedableRng};
4304    use rand_distr::StandardNormal;
4305
4306    use crate::{
4307        configs::{self, DimName},
4308        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4309    };
4310
4311    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4312        input: ArrayView<F, D>,
4313        quant: Quantization,
4314    ) -> Array<T, D>
4315    where
4316        i32: num_traits::AsPrimitive<F>,
4317        f32: num_traits::AsPrimitive<F>,
4318    {
4319        let zero_point = quant.zero_point.as_();
4320        let div_scale = F::one() / quant.scale.as_();
4321        if zero_point != F::zero() {
4322            input.mapv(|d| (d * div_scale + zero_point).round().as_())
4323        } else {
4324            input.mapv(|d| (d * div_scale).round().as_())
4325        }
4326    }
4327
4328    #[test]
4329    fn test_decoder_tracked_random_jitter() {
4330        use crate::configs::{DecoderType, Nms};
4331        use crate::DecoderBuilder;
4332
4333        let score_threshold = 0.25;
4334        let iou_threshold = 0.1;
4335        let out = include_bytes!(concat!(
4336            env!("CARGO_MANIFEST_DIR"),
4337            "/../../testdata/yolov8s_80_classes.bin"
4338        ));
4339        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4340        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4341        let quant = (0.0040811873, -123).into();
4342
4343        let decoder = DecoderBuilder::default()
4344            .with_config_yolo_det(
4345                crate::configs::Detection {
4346                    decoder: DecoderType::Ultralytics,
4347                    shape: vec![1, 84, 8400],
4348                    anchors: None,
4349                    quantization: Some(quant),
4350                    dshape: vec![
4351                        (crate::configs::DimName::Batch, 1),
4352                        (crate::configs::DimName::NumFeatures, 84),
4353                        (crate::configs::DimName::NumBoxes, 8400),
4354                    ],
4355                    normalized: Some(true),
4356                },
4357                None,
4358            )
4359            .with_score_threshold(score_threshold)
4360            .with_iou_threshold(iou_threshold)
4361            .with_nms(Some(Nms::ClassAgnostic))
4362            .build()
4363            .unwrap();
4364        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
4365
4366        let expected_boxes = [
4367            crate::DetectBox {
4368                bbox: crate::BoundingBox {
4369                    xmin: 0.5285137,
4370                    ymin: 0.05305544,
4371                    xmax: 0.87541467,
4372                    ymax: 0.9998909,
4373                },
4374                score: 0.5591227,
4375                label: 0,
4376            },
4377            crate::DetectBox {
4378                bbox: crate::BoundingBox {
4379                    xmin: 0.130598,
4380                    ymin: 0.43260583,
4381                    xmax: 0.35098213,
4382                    ymax: 0.9958097,
4383                },
4384                score: 0.33057618,
4385                label: 75,
4386            },
4387        ];
4388
4389        let mut tracker = ByteTrackBuilder::new()
4390            .track_update(0.1)
4391            .track_high_conf(0.3)
4392            .build();
4393
4394        let mut output_boxes = Vec::with_capacity(50);
4395        let mut output_masks = Vec::with_capacity(50);
4396        let mut output_tracks = Vec::with_capacity(50);
4397
4398        decoder
4399            .decode_tracked_quantized(
4400                &mut tracker,
4401                0,
4402                &[out.view().into()],
4403                &mut output_boxes,
4404                &mut output_masks,
4405                &mut output_tracks,
4406            )
4407            .unwrap();
4408
4409        assert_eq!(output_boxes.len(), 2);
4410        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4411        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
4412
4413        let mut last_boxes = output_boxes.clone();
4414
4415        for i in 1..=100 {
4416            let mut out = out.clone();
4417            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
4418            let mut x_values = out.slice_mut(s![0, 0, ..]);
4419            for x in x_values.iter_mut() {
4420                let r: f32 = rng.sample(StandardNormal);
4421                let r = r.clamp(-2.0, 2.0) / 2.0;
4422                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
4423            }
4424
4425            let mut y_values = out.slice_mut(s![0, 1, ..]);
4426            for y in y_values.iter_mut() {
4427                let r: f32 = rng.sample(StandardNormal);
4428                let r = r.clamp(-2.0, 2.0) / 2.0;
4429                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
4430            }
4431
4432            decoder
4433                .decode_tracked_quantized(
4434                    &mut tracker,
4435                    100_000_000 * i / 3, // simulate 33.333ms between frames
4436                    &[out.view().into()],
4437                    &mut output_boxes,
4438                    &mut output_masks,
4439                    &mut output_tracks,
4440                )
4441                .unwrap();
4442
4443            assert_eq!(output_boxes.len(), 2);
4444            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
4445            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
4446
4447            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
4448            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
4449            last_boxes = output_boxes.clone();
4450        }
4451    }
4452
4453    // ─── Shared helpers for tracked decoder tests ────────────────────
4454
4455    fn real_data_expected_boxes() -> [DetectBox; 2] {
4456        [
4457            DetectBox {
4458                bbox: BoundingBox {
4459                    xmin: 0.08515105,
4460                    ymin: 0.7131401,
4461                    xmax: 0.29802868,
4462                    ymax: 0.8195788,
4463                },
4464                score: 0.91537374,
4465                label: 23,
4466            },
4467            DetectBox {
4468                bbox: BoundingBox {
4469                    xmin: 0.59605736,
4470                    ymin: 0.25545314,
4471                    xmax: 0.93666154,
4472                    ymax: 0.72378385,
4473                },
4474                score: 0.91537374,
4475                label: 23,
4476            },
4477        ]
4478    }
4479
4480    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4481        [DetectBox {
4482            bbox: BoundingBox {
4483                xmin: 0.12549022,
4484                ymin: 0.12549022,
4485                xmax: 0.23529413,
4486                ymax: 0.23529413,
4487            },
4488            score: 0.98823535,
4489            label: 2,
4490        }]
4491    }
4492
4493    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4494        [DetectBox {
4495            bbox: BoundingBox {
4496                xmin: 0.1234,
4497                ymin: 0.1234,
4498                xmax: 0.2345,
4499                ymax: 0.2345,
4500            },
4501            score: 0.9876,
4502            label: 2,
4503        }]
4504    }
4505
4506    fn build_yolo_split_segdet_decoder(
4507        score_threshold: f32,
4508        iou_threshold: f32,
4509        quant_boxes: (f32, i32),
4510        quant_protos: (f32, i32),
4511    ) -> crate::Decoder {
4512        DecoderBuilder::default()
4513            .with_config_yolo_split_segdet(
4514                configs::Boxes {
4515                    decoder: configs::DecoderType::Ultralytics,
4516                    quantization: Some(quant_boxes.into()),
4517                    shape: vec![1, 4, 8400],
4518                    dshape: vec![
4519                        (DimName::Batch, 1),
4520                        (DimName::BoxCoords, 4),
4521                        (DimName::NumBoxes, 8400),
4522                    ],
4523                    normalized: Some(true),
4524                },
4525                configs::Scores {
4526                    decoder: configs::DecoderType::Ultralytics,
4527                    quantization: Some(quant_boxes.into()),
4528                    shape: vec![1, 80, 8400],
4529                    dshape: vec![
4530                        (DimName::Batch, 1),
4531                        (DimName::NumClasses, 80),
4532                        (DimName::NumBoxes, 8400),
4533                    ],
4534                },
4535                configs::MaskCoefficients {
4536                    decoder: configs::DecoderType::Ultralytics,
4537                    quantization: Some(quant_boxes.into()),
4538                    shape: vec![1, 32, 8400],
4539                    dshape: vec![
4540                        (DimName::Batch, 1),
4541                        (DimName::NumProtos, 32),
4542                        (DimName::NumBoxes, 8400),
4543                    ],
4544                },
4545                configs::Protos {
4546                    decoder: configs::DecoderType::Ultralytics,
4547                    quantization: Some(quant_protos.into()),
4548                    shape: vec![1, 160, 160, 32],
4549                    dshape: vec![
4550                        (DimName::Batch, 1),
4551                        (DimName::Height, 160),
4552                        (DimName::Width, 160),
4553                        (DimName::NumProtos, 32),
4554                    ],
4555                },
4556            )
4557            .with_score_threshold(score_threshold)
4558            .with_iou_threshold(iou_threshold)
4559            .build()
4560            .unwrap()
4561    }
4562
4563    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4564        let config_yaml = include_str!(concat!(
4565            env!("CARGO_MANIFEST_DIR"),
4566            "/../../testdata/yolov8_seg.yaml"
4567        ));
4568        DecoderBuilder::default()
4569            .with_config_yaml_str(config_yaml.to_string())
4570            .with_score_threshold(score_threshold)
4571            .with_iou_threshold(iou_threshold)
4572            .build()
4573            .unwrap()
4574    }
4575
4576    // ─── Real-data tracked test macro ───────────────────────────────
4577    //
4578    // Generates tests that load i8 binary test data from testdata/ and
4579    // exercise all (quant/float) × (combined/split) × (masks/proto)
4580    // decoder paths.
4581
4582    macro_rules! real_data_tracked_test {
4583        ($name:ident, quantized, $layout:ident, $output:ident) => {
4584            #[test]
4585            fn $name() {
4586                let is_split = matches!(stringify!($layout), "split");
4587                let is_proto = matches!(stringify!($output), "proto");
4588
4589                let score_threshold = 0.45;
4590                let iou_threshold = 0.45;
4591                let quant_boxes = (0.021287762_f32, 31_i32);
4592                let quant_protos = (0.02491162_f32, -117_i32);
4593
4594                let raw_boxes = include_bytes!(concat!(
4595                    env!("CARGO_MANIFEST_DIR"),
4596                    "/../../testdata/yolov8_boxes_116x8400.bin"
4597                ));
4598                let raw_boxes = unsafe {
4599                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4600                };
4601                let boxes_i8 =
4602                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4603
4604                let raw_protos = include_bytes!(concat!(
4605                    env!("CARGO_MANIFEST_DIR"),
4606                    "/../../testdata/yolov8_protos_160x160x32.bin"
4607                ));
4608                let raw_protos = unsafe {
4609                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4610                };
4611                let protos_i8 =
4612                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4613                        .unwrap();
4614
4615                // Pre-split (unused for combined, but harmless)
4616                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4617                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4618                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4619                let mut boxes_combined = boxes_i8;
4620
4621                let decoder = if is_split {
4622                    build_yolo_split_segdet_decoder(
4623                        score_threshold,
4624                        iou_threshold,
4625                        quant_boxes,
4626                        quant_protos,
4627                    )
4628                } else {
4629                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4630                };
4631
4632                let expected = real_data_expected_boxes();
4633                let mut tracker = ByteTrackBuilder::new()
4634                    .track_update(0.1)
4635                    .track_high_conf(0.7)
4636                    .build();
4637                let mut output_boxes = Vec::with_capacity(50);
4638                let mut output_tracks = Vec::with_capacity(50);
4639
4640                // Frame 1: decode
4641                if is_proto {
4642                    {
4643                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4644                            vec![
4645                                boxes_split.view().into(),
4646                                scores_split.view().into(),
4647                                mask_split.view().into(),
4648                                protos_i8.view().into(),
4649                            ]
4650                        } else {
4651                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4652                        };
4653                        decoder
4654                            .decode_tracked_quantized_proto(
4655                                &mut tracker,
4656                                0,
4657                                &inputs,
4658                                &mut output_boxes,
4659                                &mut output_tracks,
4660                            )
4661                            .unwrap();
4662                    }
4663                    assert_eq!(output_boxes.len(), 2);
4664                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4665                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4666
4667                    // Zero scores for frame 2
4668                    if is_split {
4669                        for score in scores_split.iter_mut() {
4670                            *score = i8::MIN;
4671                        }
4672                    } else {
4673                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4674                            *score = i8::MIN;
4675                        }
4676                    }
4677
4678                    let proto_result = {
4679                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4680                            vec![
4681                                boxes_split.view().into(),
4682                                scores_split.view().into(),
4683                                mask_split.view().into(),
4684                                protos_i8.view().into(),
4685                            ]
4686                        } else {
4687                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4688                        };
4689                        decoder
4690                            .decode_tracked_quantized_proto(
4691                                &mut tracker,
4692                                100_000_000 / 3,
4693                                &inputs,
4694                                &mut output_boxes,
4695                                &mut output_tracks,
4696                            )
4697                            .unwrap()
4698                    };
4699                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4700                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4701                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4702                } else {
4703                    let mut output_masks = Vec::with_capacity(50);
4704                    {
4705                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4706                            vec![
4707                                boxes_split.view().into(),
4708                                scores_split.view().into(),
4709                                mask_split.view().into(),
4710                                protos_i8.view().into(),
4711                            ]
4712                        } else {
4713                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4714                        };
4715                        decoder
4716                            .decode_tracked_quantized(
4717                                &mut tracker,
4718                                0,
4719                                &inputs,
4720                                &mut output_boxes,
4721                                &mut output_masks,
4722                                &mut output_tracks,
4723                            )
4724                            .unwrap();
4725                    }
4726                    assert_eq!(output_boxes.len(), 2);
4727                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4728                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4729
4730                    if is_split {
4731                        for score in scores_split.iter_mut() {
4732                            *score = i8::MIN;
4733                        }
4734                    } else {
4735                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4736                            *score = i8::MIN;
4737                        }
4738                    }
4739
4740                    {
4741                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4742                            vec![
4743                                boxes_split.view().into(),
4744                                scores_split.view().into(),
4745                                mask_split.view().into(),
4746                                protos_i8.view().into(),
4747                            ]
4748                        } else {
4749                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4750                        };
4751                        decoder
4752                            .decode_tracked_quantized(
4753                                &mut tracker,
4754                                100_000_000 / 3,
4755                                &inputs,
4756                                &mut output_boxes,
4757                                &mut output_masks,
4758                                &mut output_tracks,
4759                            )
4760                            .unwrap();
4761                    }
4762                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4763                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4764                    assert!(output_masks.is_empty());
4765                }
4766            }
4767        };
4768        ($name:ident, float, $layout:ident, $output:ident) => {
4769            #[test]
4770            fn $name() {
4771                let is_split = matches!(stringify!($layout), "split");
4772                let is_proto = matches!(stringify!($output), "proto");
4773
4774                let score_threshold = 0.45;
4775                let iou_threshold = 0.45;
4776                let quant_boxes = (0.021287762_f32, 31_i32);
4777                let quant_protos = (0.02491162_f32, -117_i32);
4778
4779                let raw_boxes = include_bytes!(concat!(
4780                    env!("CARGO_MANIFEST_DIR"),
4781                    "/../../testdata/yolov8_boxes_116x8400.bin"
4782                ));
4783                let raw_boxes = unsafe {
4784                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4785                };
4786                let boxes_i8 =
4787                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4788                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4789
4790                let raw_protos = include_bytes!(concat!(
4791                    env!("CARGO_MANIFEST_DIR"),
4792                    "/../../testdata/yolov8_protos_160x160x32.bin"
4793                ));
4794                let raw_protos = unsafe {
4795                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4796                };
4797                let protos_i8 =
4798                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4799                        .unwrap();
4800                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4801
4802                // Pre-split from dequantized data
4803                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4804                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4805                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4806                let mut boxes_combined = boxes_f32;
4807
4808                let decoder = if is_split {
4809                    build_yolo_split_segdet_decoder(
4810                        score_threshold,
4811                        iou_threshold,
4812                        quant_boxes,
4813                        quant_protos,
4814                    )
4815                } else {
4816                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4817                };
4818
4819                let expected = real_data_expected_boxes();
4820                let mut tracker = ByteTrackBuilder::new()
4821                    .track_update(0.1)
4822                    .track_high_conf(0.7)
4823                    .build();
4824                let mut output_boxes = Vec::with_capacity(50);
4825                let mut output_tracks = Vec::with_capacity(50);
4826
4827                if is_proto {
4828                    {
4829                        let inputs = if is_split {
4830                            vec![
4831                                boxes_split.view().into_dyn(),
4832                                scores_split.view().into_dyn(),
4833                                mask_split.view().into_dyn(),
4834                                protos_f32.view().into_dyn(),
4835                            ]
4836                        } else {
4837                            vec![
4838                                boxes_combined.view().into_dyn(),
4839                                protos_f32.view().into_dyn(),
4840                            ]
4841                        };
4842                        decoder
4843                            .decode_tracked_float_proto(
4844                                &mut tracker,
4845                                0,
4846                                &inputs,
4847                                &mut output_boxes,
4848                                &mut output_tracks,
4849                            )
4850                            .unwrap();
4851                    }
4852                    assert_eq!(output_boxes.len(), 2);
4853                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4854                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4855
4856                    if is_split {
4857                        for score in scores_split.iter_mut() {
4858                            *score = 0.0;
4859                        }
4860                    } else {
4861                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4862                            *score = 0.0;
4863                        }
4864                    }
4865
4866                    let proto_result = {
4867                        let inputs = if is_split {
4868                            vec![
4869                                boxes_split.view().into_dyn(),
4870                                scores_split.view().into_dyn(),
4871                                mask_split.view().into_dyn(),
4872                                protos_f32.view().into_dyn(),
4873                            ]
4874                        } else {
4875                            vec![
4876                                boxes_combined.view().into_dyn(),
4877                                protos_f32.view().into_dyn(),
4878                            ]
4879                        };
4880                        decoder
4881                            .decode_tracked_float_proto(
4882                                &mut tracker,
4883                                100_000_000 / 3,
4884                                &inputs,
4885                                &mut output_boxes,
4886                                &mut output_tracks,
4887                            )
4888                            .unwrap()
4889                    };
4890                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4891                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4892                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4893                } else {
4894                    let mut output_masks = Vec::with_capacity(50);
4895                    {
4896                        let inputs = if is_split {
4897                            vec![
4898                                boxes_split.view().into_dyn(),
4899                                scores_split.view().into_dyn(),
4900                                mask_split.view().into_dyn(),
4901                                protos_f32.view().into_dyn(),
4902                            ]
4903                        } else {
4904                            vec![
4905                                boxes_combined.view().into_dyn(),
4906                                protos_f32.view().into_dyn(),
4907                            ]
4908                        };
4909                        decoder
4910                            .decode_tracked_float(
4911                                &mut tracker,
4912                                0,
4913                                &inputs,
4914                                &mut output_boxes,
4915                                &mut output_masks,
4916                                &mut output_tracks,
4917                            )
4918                            .unwrap();
4919                    }
4920                    assert_eq!(output_boxes.len(), 2);
4921                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4922                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4923
4924                    if is_split {
4925                        for score in scores_split.iter_mut() {
4926                            *score = 0.0;
4927                        }
4928                    } else {
4929                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4930                            *score = 0.0;
4931                        }
4932                    }
4933
4934                    {
4935                        let inputs = if is_split {
4936                            vec![
4937                                boxes_split.view().into_dyn(),
4938                                scores_split.view().into_dyn(),
4939                                mask_split.view().into_dyn(),
4940                                protos_f32.view().into_dyn(),
4941                            ]
4942                        } else {
4943                            vec![
4944                                boxes_combined.view().into_dyn(),
4945                                protos_f32.view().into_dyn(),
4946                            ]
4947                        };
4948                        decoder
4949                            .decode_tracked_float(
4950                                &mut tracker,
4951                                100_000_000 / 3,
4952                                &inputs,
4953                                &mut output_boxes,
4954                                &mut output_masks,
4955                                &mut output_tracks,
4956                            )
4957                            .unwrap();
4958                    }
4959                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4960                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4961                    assert!(output_masks.is_empty());
4962                }
4963            }
4964        };
4965    }
4966
4967    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4968    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4969    real_data_tracked_test!(
4970        test_decoder_tracked_segdet_proto,
4971        quantized,
4972        combined,
4973        proto
4974    );
4975    real_data_tracked_test!(
4976        test_decoder_tracked_segdet_proto_float,
4977        float,
4978        combined,
4979        proto
4980    );
4981    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4982    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4983    real_data_tracked_test!(
4984        test_decoder_tracked_segdet_split_proto,
4985        quantized,
4986        split,
4987        proto
4988    );
4989    real_data_tracked_test!(
4990        test_decoder_tracked_segdet_split_proto_float,
4991        float,
4992        split,
4993        proto
4994    );
4995
4996    // ─── End-to-end tracked test macro ──────────────────────────────
4997    //
4998    // Generates tests with synthetic data to exercise all tracked
4999    // decode paths without needing real model output files.
5000
    /// Decoder configuration (YAML) for the synthetic end-to-end tracked tests
    /// in the combined layout: a single `detection` output of shape
    /// `[1, 10, 38]` followed by a `protos` output of shape `[1, 160, 160, 32]`.
    ///
    /// In the quantized e2e test, each box's 38 features are 4 box coordinates,
    /// 1 score, 1 class and 32 mask coefficients, concatenated in that order.
    /// The detection quantization `[0.00784313725490196, 0]` equals 2/255 with a
    /// second field of 0. This matches the `detect_quant` that test uses.
    ///
    /// NOTE(review): `protos` declares `[1/255, 128]` here (presumably zero
    /// point 128), while the quantized e2e test quantizes its protos with zero
    /// point 0. Confirm whether that mismatch is intentional.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5023
    /// Decoder configuration (YAML) for the synthetic end-to-end tracked tests
    /// in the split layout. It declares four separate per-box outputs and a
    /// protos output:
    /// - `boxes` `[1, 10, 4]`
    /// - `scores` `[1, 10, 1]`
    /// - `classes` `[1, 10, 1]`
    /// - `mask_coefficients` `[1, 10, 32]`
    /// - `protos` `[1, 160, 160, 32]`
    ///
    /// The quantized e2e test passes its tensors to the decoder in this same
    /// order. All four per-box outputs use quantization `[2/255, 0]`.
    ///
    /// NOTE(review): `protos` declares `[1/255, 128]` here (presumably zero
    /// point 128), while the quantized e2e test quantizes its protos with zero
    /// point 0. Confirm whether that mismatch is intentional.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5070
5071    macro_rules! e2e_tracked_test {
5072        ($name:ident, quantized, $layout:ident, $output:ident) => {
5073            #[test]
5074            fn $name() {
5075                let is_split = matches!(stringify!($layout), "split");
5076                let is_proto = matches!(stringify!($output), "proto");
5077
5078                let score_threshold = 0.45;
5079                let iou_threshold = 0.45;
5080
5081                let mut boxes = Array2::zeros((10, 4));
5082                let mut scores = Array2::zeros((10, 1));
5083                let mut classes = Array2::zeros((10, 1));
5084                let mask = Array2::zeros((10, 32));
5085                let protos = Array3::<f64>::zeros((160, 160, 32));
5086                let protos = protos.insert_axis(Axis(0));
5087                let protos_quant = (1.0 / 255.0, 0.0);
5088                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5089
5090                boxes
5091                    .slice_mut(s![0, ..])
5092                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5093                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5094                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5095
5096                let detect_quant = (2.0 / 255.0, 0.0);
5097
5098                let decoder = if is_split {
5099                    DecoderBuilder::default()
5100                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5101                        .with_score_threshold(score_threshold)
5102                        .with_iou_threshold(iou_threshold)
5103                        .build()
5104                        .unwrap()
5105                } else {
5106                    DecoderBuilder::default()
5107                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5108                        .with_score_threshold(score_threshold)
5109                        .with_iou_threshold(iou_threshold)
5110                        .build()
5111                        .unwrap()
5112                };
5113
5114                let expected = e2e_expected_boxes_quant();
5115                let mut tracker = ByteTrackBuilder::new()
5116                    .track_update(0.1)
5117                    .track_high_conf(0.7)
5118                    .build();
5119                let mut output_boxes = Vec::with_capacity(50);
5120                let mut output_tracks = Vec::with_capacity(50);
5121
5122                if is_split {
5123                    let boxes = boxes.insert_axis(Axis(0));
5124                    let scores = scores.insert_axis(Axis(0));
5125                    let classes = classes.insert_axis(Axis(0));
5126                    let mask = mask.insert_axis(Axis(0));
5127
5128                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5129                    let mut scores: Array3<u8> =
5130                        quantize_ndarray(scores.view(), detect_quant.into());
5131                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5132                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5133
5134                    if is_proto {
5135                        {
5136                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5137                                boxes.view().into(),
5138                                scores.view().into(),
5139                                classes.view().into(),
5140                                mask.view().into(),
5141                                protos.view().into(),
5142                            ];
5143                            decoder
5144                                .decode_tracked_quantized_proto(
5145                                    &mut tracker,
5146                                    0,
5147                                    &inputs,
5148                                    &mut output_boxes,
5149                                    &mut output_tracks,
5150                                )
5151                                .unwrap();
5152                        }
5153                        assert_eq!(output_boxes.len(), 1);
5154                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5155
5156                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5157                            *score = u8::MIN;
5158                        }
5159                        let proto_result = {
5160                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5161                                boxes.view().into(),
5162                                scores.view().into(),
5163                                classes.view().into(),
5164                                mask.view().into(),
5165                                protos.view().into(),
5166                            ];
5167                            decoder
5168                                .decode_tracked_quantized_proto(
5169                                    &mut tracker,
5170                                    100_000_000 / 3,
5171                                    &inputs,
5172                                    &mut output_boxes,
5173                                    &mut output_tracks,
5174                                )
5175                                .unwrap()
5176                        };
5177                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5178                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5179                    } else {
5180                        let mut output_masks = Vec::with_capacity(50);
5181                        {
5182                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5183                                boxes.view().into(),
5184                                scores.view().into(),
5185                                classes.view().into(),
5186                                mask.view().into(),
5187                                protos.view().into(),
5188                            ];
5189                            decoder
5190                                .decode_tracked_quantized(
5191                                    &mut tracker,
5192                                    0,
5193                                    &inputs,
5194                                    &mut output_boxes,
5195                                    &mut output_masks,
5196                                    &mut output_tracks,
5197                                )
5198                                .unwrap();
5199                        }
5200                        assert_eq!(output_boxes.len(), 1);
5201                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5202
5203                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5204                            *score = u8::MIN;
5205                        }
5206                        {
5207                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5208                                boxes.view().into(),
5209                                scores.view().into(),
5210                                classes.view().into(),
5211                                mask.view().into(),
5212                                protos.view().into(),
5213                            ];
5214                            decoder
5215                                .decode_tracked_quantized(
5216                                    &mut tracker,
5217                                    100_000_000 / 3,
5218                                    &inputs,
5219                                    &mut output_boxes,
5220                                    &mut output_masks,
5221                                    &mut output_tracks,
5222                                )
5223                                .unwrap();
5224                        }
5225                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5226                        assert!(output_masks.is_empty());
5227                    }
5228                } else {
5229                    // Combined layout
5230                    let detect = ndarray::concatenate![
5231                        Axis(1),
5232                        boxes.view(),
5233                        scores.view(),
5234                        classes.view(),
5235                        mask.view()
5236                    ];
5237                    let detect = detect.insert_axis(Axis(0));
5238                    assert_eq!(detect.shape(), &[1, 10, 38]);
5239                    let mut detect: Array3<u8> =
5240                        quantize_ndarray(detect.view(), detect_quant.into());
5241
5242                    if is_proto {
5243                        {
5244                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5245                                vec![detect.view().into(), protos.view().into()];
5246                            decoder
5247                                .decode_tracked_quantized_proto(
5248                                    &mut tracker,
5249                                    0,
5250                                    &inputs,
5251                                    &mut output_boxes,
5252                                    &mut output_tracks,
5253                                )
5254                                .unwrap();
5255                        }
5256                        assert_eq!(output_boxes.len(), 1);
5257                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5258
5259                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5260                            *score = u8::MIN;
5261                        }
5262                        let proto_result = {
5263                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5264                                vec![detect.view().into(), protos.view().into()];
5265                            decoder
5266                                .decode_tracked_quantized_proto(
5267                                    &mut tracker,
5268                                    100_000_000 / 3,
5269                                    &inputs,
5270                                    &mut output_boxes,
5271                                    &mut output_tracks,
5272                                )
5273                                .unwrap()
5274                        };
5275                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5276                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5277                    } else {
5278                        let mut output_masks = Vec::with_capacity(50);
5279                        {
5280                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5281                                vec![detect.view().into(), protos.view().into()];
5282                            decoder
5283                                .decode_tracked_quantized(
5284                                    &mut tracker,
5285                                    0,
5286                                    &inputs,
5287                                    &mut output_boxes,
5288                                    &mut output_masks,
5289                                    &mut output_tracks,
5290                                )
5291                                .unwrap();
5292                        }
5293                        assert_eq!(output_boxes.len(), 1);
5294                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5295
5296                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5297                            *score = u8::MIN;
5298                        }
5299                        {
5300                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5301                                vec![detect.view().into(), protos.view().into()];
5302                            decoder
5303                                .decode_tracked_quantized(
5304                                    &mut tracker,
5305                                    100_000_000 / 3,
5306                                    &inputs,
5307                                    &mut output_boxes,
5308                                    &mut output_masks,
5309                                    &mut output_tracks,
5310                                )
5311                                .unwrap();
5312                        }
5313                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5314                        assert!(output_masks.is_empty());
5315                    }
5316                }
5317            }
5318        };
5319        ($name:ident, float, $layout:ident, $output:ident) => {
5320            #[test]
5321            fn $name() {
5322                let is_split = matches!(stringify!($layout), "split");
5323                let is_proto = matches!(stringify!($output), "proto");
5324
5325                let score_threshold = 0.45;
5326                let iou_threshold = 0.45;
5327
5328                let mut boxes = Array2::zeros((10, 4));
5329                let mut scores = Array2::zeros((10, 1));
5330                let mut classes = Array2::zeros((10, 1));
5331                let mask: Array2<f64> = Array2::zeros((10, 32));
5332                let protos = Array3::<f64>::zeros((160, 160, 32));
5333                let protos = protos.insert_axis(Axis(0));
5334
5335                boxes
5336                    .slice_mut(s![0, ..])
5337                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5338                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5339                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5340
5341                let decoder = if is_split {
5342                    DecoderBuilder::default()
5343                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5344                        .with_score_threshold(score_threshold)
5345                        .with_iou_threshold(iou_threshold)
5346                        .build()
5347                        .unwrap()
5348                } else {
5349                    DecoderBuilder::default()
5350                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5351                        .with_score_threshold(score_threshold)
5352                        .with_iou_threshold(iou_threshold)
5353                        .build()
5354                        .unwrap()
5355                };
5356
5357                let expected = e2e_expected_boxes_float();
5358                let mut tracker = ByteTrackBuilder::new()
5359                    .track_update(0.1)
5360                    .track_high_conf(0.7)
5361                    .build();
5362                let mut output_boxes = Vec::with_capacity(50);
5363                let mut output_tracks = Vec::with_capacity(50);
5364
5365                if is_split {
5366                    let boxes = boxes.insert_axis(Axis(0));
5367                    let mut scores = scores.insert_axis(Axis(0));
5368                    let classes = classes.insert_axis(Axis(0));
5369                    let mask = mask.insert_axis(Axis(0));
5370
5371                    if is_proto {
5372                        {
5373                            let inputs = vec![
5374                                boxes.view().into_dyn(),
5375                                scores.view().into_dyn(),
5376                                classes.view().into_dyn(),
5377                                mask.view().into_dyn(),
5378                                protos.view().into_dyn(),
5379                            ];
5380                            decoder
5381                                .decode_tracked_float_proto(
5382                                    &mut tracker,
5383                                    0,
5384                                    &inputs,
5385                                    &mut output_boxes,
5386                                    &mut output_tracks,
5387                                )
5388                                .unwrap();
5389                        }
5390                        assert_eq!(output_boxes.len(), 1);
5391                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5392
5393                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5394                            *score = 0.0;
5395                        }
5396                        let proto_result = {
5397                            let inputs = vec![
5398                                boxes.view().into_dyn(),
5399                                scores.view().into_dyn(),
5400                                classes.view().into_dyn(),
5401                                mask.view().into_dyn(),
5402                                protos.view().into_dyn(),
5403                            ];
5404                            decoder
5405                                .decode_tracked_float_proto(
5406                                    &mut tracker,
5407                                    100_000_000 / 3,
5408                                    &inputs,
5409                                    &mut output_boxes,
5410                                    &mut output_tracks,
5411                                )
5412                                .unwrap()
5413                        };
5414                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5415                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5416                    } else {
5417                        let mut output_masks = Vec::with_capacity(50);
5418                        {
5419                            let inputs = vec![
5420                                boxes.view().into_dyn(),
5421                                scores.view().into_dyn(),
5422                                classes.view().into_dyn(),
5423                                mask.view().into_dyn(),
5424                                protos.view().into_dyn(),
5425                            ];
5426                            decoder
5427                                .decode_tracked_float(
5428                                    &mut tracker,
5429                                    0,
5430                                    &inputs,
5431                                    &mut output_boxes,
5432                                    &mut output_masks,
5433                                    &mut output_tracks,
5434                                )
5435                                .unwrap();
5436                        }
5437                        assert_eq!(output_boxes.len(), 1);
5438                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5439
5440                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5441                            *score = 0.0;
5442                        }
5443                        {
5444                            let inputs = vec![
5445                                boxes.view().into_dyn(),
5446                                scores.view().into_dyn(),
5447                                classes.view().into_dyn(),
5448                                mask.view().into_dyn(),
5449                                protos.view().into_dyn(),
5450                            ];
5451                            decoder
5452                                .decode_tracked_float(
5453                                    &mut tracker,
5454                                    100_000_000 / 3,
5455                                    &inputs,
5456                                    &mut output_boxes,
5457                                    &mut output_masks,
5458                                    &mut output_tracks,
5459                                )
5460                                .unwrap();
5461                        }
5462                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5463                        assert!(output_masks.is_empty());
5464                    }
5465                } else {
5466                    // Combined layout
5467                    let detect = ndarray::concatenate![
5468                        Axis(1),
5469                        boxes.view(),
5470                        scores.view(),
5471                        classes.view(),
5472                        mask.view()
5473                    ];
5474                    let mut detect = detect.insert_axis(Axis(0));
5475                    assert_eq!(detect.shape(), &[1, 10, 38]);
5476
5477                    if is_proto {
5478                        {
5479                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5480                            decoder
5481                                .decode_tracked_float_proto(
5482                                    &mut tracker,
5483                                    0,
5484                                    &inputs,
5485                                    &mut output_boxes,
5486                                    &mut output_tracks,
5487                                )
5488                                .unwrap();
5489                        }
5490                        assert_eq!(output_boxes.len(), 1);
5491                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5492
5493                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5494                            *score = 0.0;
5495                        }
5496                        let proto_result = {
5497                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5498                            decoder
5499                                .decode_tracked_float_proto(
5500                                    &mut tracker,
5501                                    100_000_000 / 3,
5502                                    &inputs,
5503                                    &mut output_boxes,
5504                                    &mut output_tracks,
5505                                )
5506                                .unwrap()
5507                        };
5508                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5509                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5510                    } else {
5511                        let mut output_masks = Vec::with_capacity(50);
5512                        {
5513                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5514                            decoder
5515                                .decode_tracked_float(
5516                                    &mut tracker,
5517                                    0,
5518                                    &inputs,
5519                                    &mut output_boxes,
5520                                    &mut output_masks,
5521                                    &mut output_tracks,
5522                                )
5523                                .unwrap();
5524                        }
5525                        assert_eq!(output_boxes.len(), 1);
5526                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5527
5528                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5529                            *score = 0.0;
5530                        }
5531                        {
5532                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5533                            decoder
5534                                .decode_tracked_float(
5535                                    &mut tracker,
5536                                    100_000_000 / 3,
5537                                    &inputs,
5538                                    &mut output_boxes,
5539                                    &mut output_masks,
5540                                    &mut output_tracks,
5541                                )
5542                                .unwrap();
5543                        }
5544                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5545                        assert!(output_masks.is_empty());
5546                    }
5547                }
5548            }
5549        };
5550    }
5551
5552    e2e_tracked_test!(
5553        test_decoder_tracked_end_to_end_segdet,
5554        quantized,
5555        combined,
5556        masks
5557    );
5558    e2e_tracked_test!(
5559        test_decoder_tracked_end_to_end_segdet_float,
5560        float,
5561        combined,
5562        masks
5563    );
5564    e2e_tracked_test!(
5565        test_decoder_tracked_end_to_end_segdet_proto,
5566        quantized,
5567        combined,
5568        proto
5569    );
5570    e2e_tracked_test!(
5571        test_decoder_tracked_end_to_end_segdet_proto_float,
5572        float,
5573        combined,
5574        proto
5575    );
5576    e2e_tracked_test!(
5577        test_decoder_tracked_end_to_end_segdet_split,
5578        quantized,
5579        split,
5580        masks
5581    );
5582    e2e_tracked_test!(
5583        test_decoder_tracked_end_to_end_segdet_split_float,
5584        float,
5585        split,
5586        masks
5587    );
5588    e2e_tracked_test!(
5589        test_decoder_tracked_end_to_end_segdet_split_proto,
5590        quantized,
5591        split,
5592        proto
5593    );
5594    e2e_tracked_test!(
5595        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5596        float,
5597        split,
5598        proto
5599    );
5600
5601    // ─── End-to-end tracked TensorDyn test macro ────────────────────
5602    //
5603    // Same as e2e_tracked_test but wraps data in TensorDyn and exercises
5604    // the public decode_tracked / decode_proto_tracked API.
5605
5606    macro_rules! e2e_tracked_tensor_test {
5607        ($name:ident, quantized, $layout:ident, $output:ident) => {
5608            #[test]
5609            fn $name() {
5610                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5611
5612                let is_split = matches!(stringify!($layout), "split");
5613                let is_proto = matches!(stringify!($output), "proto");
5614
5615                let score_threshold = 0.45;
5616                let iou_threshold = 0.45;
5617
5618                let mut boxes = Array2::zeros((10, 4));
5619                let mut scores = Array2::zeros((10, 1));
5620                let mut classes = Array2::zeros((10, 1));
5621                let mask = Array2::zeros((10, 32));
5622                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5623                let protos_f64 = protos_f64.insert_axis(Axis(0));
5624                let protos_quant = (1.0 / 255.0, 0.0);
5625                let protos_u8: Array4<u8> =
5626                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5627
5628                boxes
5629                    .slice_mut(s![0, ..])
5630                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5631                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5632                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5633
5634                let detect_quant = (2.0 / 255.0, 0.0);
5635
5636                let decoder = if is_split {
5637                    DecoderBuilder::default()
5638                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5639                        .with_score_threshold(score_threshold)
5640                        .with_iou_threshold(iou_threshold)
5641                        .build()
5642                        .unwrap()
5643                } else {
5644                    DecoderBuilder::default()
5645                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5646                        .with_score_threshold(score_threshold)
5647                        .with_iou_threshold(iou_threshold)
5648                        .build()
5649                        .unwrap()
5650                };
5651
5652                // Helper to wrap a u8 slice into a TensorDyn
5653                let make_u8_tensor =
5654                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5655                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5656                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5657                        t.into()
5658                    };
5659
5660                let expected = e2e_expected_boxes_quant();
5661                let mut tracker = ByteTrackBuilder::new()
5662                    .track_update(0.1)
5663                    .track_high_conf(0.7)
5664                    .build();
5665                let mut output_boxes = Vec::with_capacity(50);
5666                let mut output_tracks = Vec::with_capacity(50);
5667
5668                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5669
5670                if is_split {
5671                    let boxes = boxes.insert_axis(Axis(0));
5672                    let scores = scores.insert_axis(Axis(0));
5673                    let classes = classes.insert_axis(Axis(0));
5674                    let mask = mask.insert_axis(Axis(0));
5675
5676                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5677                    let mut scores_q: Array3<u8> =
5678                        quantize_ndarray(scores.view(), detect_quant.into());
5679                    let classes_q: Array3<u8> =
5680                        quantize_ndarray(classes.view(), detect_quant.into());
5681                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5682
5683                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5684                    let classes_td =
5685                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5686                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5687
5688                    if is_proto {
5689                        let scores_td =
5690                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5691                        decoder
5692                            .decode_proto_tracked(
5693                                &mut tracker,
5694                                0,
5695                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5696                                &mut output_boxes,
5697                                &mut output_tracks,
5698                            )
5699                            .unwrap();
5700
5701                        assert_eq!(output_boxes.len(), 1);
5702                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5703
5704                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5705                            *score = u8::MIN;
5706                        }
5707                        let scores_td =
5708                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5709                        let proto_result = decoder
5710                            .decode_proto_tracked(
5711                                &mut tracker,
5712                                100_000_000 / 3,
5713                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5714                                &mut output_boxes,
5715                                &mut output_tracks,
5716                            )
5717                            .unwrap();
5718                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5719                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5720                    } else {
5721                        let scores_td =
5722                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5723                        let mut output_masks = Vec::with_capacity(50);
5724                        decoder
5725                            .decode_tracked(
5726                                &mut tracker,
5727                                0,
5728                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5729                                &mut output_boxes,
5730                                &mut output_masks,
5731                                &mut output_tracks,
5732                            )
5733                            .unwrap();
5734
5735                        assert_eq!(output_boxes.len(), 1);
5736                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5737
5738                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5739                            *score = u8::MIN;
5740                        }
5741                        let scores_td =
5742                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5743                        decoder
5744                            .decode_tracked(
5745                                &mut tracker,
5746                                100_000_000 / 3,
5747                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5748                                &mut output_boxes,
5749                                &mut output_masks,
5750                                &mut output_tracks,
5751                            )
5752                            .unwrap();
5753                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5754                        assert!(output_masks.is_empty());
5755                    }
5756                } else {
5757                    // Combined layout
5758                    let detect = ndarray::concatenate![
5759                        Axis(1),
5760                        boxes.view(),
5761                        scores.view(),
5762                        classes.view(),
5763                        mask.view()
5764                    ];
5765                    let detect = detect.insert_axis(Axis(0));
5766                    assert_eq!(detect.shape(), &[1, 10, 38]);
5767                    // Ensure contiguous layout after concatenation for as_slice()
5768                    let detect =
5769                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5770                            .unwrap();
5771                    let mut detect_q: Array3<u8> =
5772                        quantize_ndarray(detect.view(), detect_quant.into());
5773
5774                    if is_proto {
5775                        let detect_td =
5776                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5777                        decoder
5778                            .decode_proto_tracked(
5779                                &mut tracker,
5780                                0,
5781                                &[&detect_td, &protos_td],
5782                                &mut output_boxes,
5783                                &mut output_tracks,
5784                            )
5785                            .unwrap();
5786
5787                        assert_eq!(output_boxes.len(), 1);
5788                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5789
5790                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5791                            *score = u8::MIN;
5792                        }
5793                        let detect_td =
5794                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5795                        let proto_result = decoder
5796                            .decode_proto_tracked(
5797                                &mut tracker,
5798                                100_000_000 / 3,
5799                                &[&detect_td, &protos_td],
5800                                &mut output_boxes,
5801                                &mut output_tracks,
5802                            )
5803                            .unwrap();
5804                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5805                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5806                    } else {
5807                        let detect_td =
5808                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5809                        let mut output_masks = Vec::with_capacity(50);
5810                        decoder
5811                            .decode_tracked(
5812                                &mut tracker,
5813                                0,
5814                                &[&detect_td, &protos_td],
5815                                &mut output_boxes,
5816                                &mut output_masks,
5817                                &mut output_tracks,
5818                            )
5819                            .unwrap();
5820
5821                        assert_eq!(output_boxes.len(), 1);
5822                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5823
5824                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5825                            *score = u8::MIN;
5826                        }
5827                        let detect_td =
5828                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5829                        decoder
5830                            .decode_tracked(
5831                                &mut tracker,
5832                                100_000_000 / 3,
5833                                &[&detect_td, &protos_td],
5834                                &mut output_boxes,
5835                                &mut output_masks,
5836                                &mut output_tracks,
5837                            )
5838                            .unwrap();
5839                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5840                        assert!(output_masks.is_empty());
5841                    }
5842                }
5843            }
5844        };
5845        ($name:ident, float, $layout:ident, $output:ident) => {
5846            #[test]
5847            fn $name() {
5848                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5849
5850                let is_split = matches!(stringify!($layout), "split");
5851                let is_proto = matches!(stringify!($output), "proto");
5852
5853                let score_threshold = 0.45;
5854                let iou_threshold = 0.45;
5855
5856                let mut boxes = Array2::zeros((10, 4));
5857                let mut scores = Array2::zeros((10, 1));
5858                let mut classes = Array2::zeros((10, 1));
5859                let mask: Array2<f64> = Array2::zeros((10, 32));
5860                let protos = Array3::<f64>::zeros((160, 160, 32));
5861                let protos = protos.insert_axis(Axis(0));
5862
5863                boxes
5864                    .slice_mut(s![0, ..])
5865                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5866                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5867                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5868
5869                let decoder = if is_split {
5870                    DecoderBuilder::default()
5871                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5872                        .with_score_threshold(score_threshold)
5873                        .with_iou_threshold(iou_threshold)
5874                        .build()
5875                        .unwrap()
5876                } else {
5877                    DecoderBuilder::default()
5878                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5879                        .with_score_threshold(score_threshold)
5880                        .with_iou_threshold(iou_threshold)
5881                        .build()
5882                        .unwrap()
5883                };
5884
5885                // Helper to wrap an f64 slice into a TensorDyn
5886                let make_f64_tensor =
5887                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5888                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
5889                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5890                        t.into()
5891                    };
5892
5893                let expected = e2e_expected_boxes_float();
5894                let mut tracker = ByteTrackBuilder::new()
5895                    .track_update(0.1)
5896                    .track_high_conf(0.7)
5897                    .build();
5898                let mut output_boxes = Vec::with_capacity(50);
5899                let mut output_tracks = Vec::with_capacity(50);
5900
5901                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5902
5903                if is_split {
5904                    let boxes = boxes.insert_axis(Axis(0));
5905                    let mut scores = scores.insert_axis(Axis(0));
5906                    let classes = classes.insert_axis(Axis(0));
5907                    let mask = mask.insert_axis(Axis(0));
5908
5909                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5910                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5911                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5912
5913                    if is_proto {
5914                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5915                        decoder
5916                            .decode_proto_tracked(
5917                                &mut tracker,
5918                                0,
5919                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5920                                &mut output_boxes,
5921                                &mut output_tracks,
5922                            )
5923                            .unwrap();
5924
5925                        assert_eq!(output_boxes.len(), 1);
5926                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5927
5928                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5929                            *score = 0.0;
5930                        }
5931                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5932                        let proto_result = decoder
5933                            .decode_proto_tracked(
5934                                &mut tracker,
5935                                100_000_000 / 3,
5936                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5937                                &mut output_boxes,
5938                                &mut output_tracks,
5939                            )
5940                            .unwrap();
5941                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5942                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5943                    } else {
5944                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5945                        let mut output_masks = Vec::with_capacity(50);
5946                        decoder
5947                            .decode_tracked(
5948                                &mut tracker,
5949                                0,
5950                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5951                                &mut output_boxes,
5952                                &mut output_masks,
5953                                &mut output_tracks,
5954                            )
5955                            .unwrap();
5956
5957                        assert_eq!(output_boxes.len(), 1);
5958                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5959
5960                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5961                            *score = 0.0;
5962                        }
5963                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5964                        decoder
5965                            .decode_tracked(
5966                                &mut tracker,
5967                                100_000_000 / 3,
5968                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5969                                &mut output_boxes,
5970                                &mut output_masks,
5971                                &mut output_tracks,
5972                            )
5973                            .unwrap();
5974                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5975                        assert!(output_masks.is_empty());
5976                    }
5977                } else {
5978                    // Combined layout
5979                    let detect = ndarray::concatenate![
5980                        Axis(1),
5981                        boxes.view(),
5982                        scores.view(),
5983                        classes.view(),
5984                        mask.view()
5985                    ];
5986                    let detect = detect.insert_axis(Axis(0));
5987                    assert_eq!(detect.shape(), &[1, 10, 38]);
5988                    // Ensure contiguous layout after concatenation for as_slice()
5989                    let mut detect =
5990                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5991                            .unwrap();
5992
5993                    if is_proto {
5994                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5995                        decoder
5996                            .decode_proto_tracked(
5997                                &mut tracker,
5998                                0,
5999                                &[&detect_td, &protos_td],
6000                                &mut output_boxes,
6001                                &mut output_tracks,
6002                            )
6003                            .unwrap();
6004
6005                        assert_eq!(output_boxes.len(), 1);
6006                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6007
6008                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6009                            *score = 0.0;
6010                        }
6011                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6012                        let proto_result = decoder
6013                            .decode_proto_tracked(
6014                                &mut tracker,
6015                                100_000_000 / 3,
6016                                &[&detect_td, &protos_td],
6017                                &mut output_boxes,
6018                                &mut output_tracks,
6019                            )
6020                            .unwrap();
6021                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6022                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6023                    } else {
6024                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6025                        let mut output_masks = Vec::with_capacity(50);
6026                        decoder
6027                            .decode_tracked(
6028                                &mut tracker,
6029                                0,
6030                                &[&detect_td, &protos_td],
6031                                &mut output_boxes,
6032                                &mut output_masks,
6033                                &mut output_tracks,
6034                            )
6035                            .unwrap();
6036
6037                        assert_eq!(output_boxes.len(), 1);
6038                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6039
6040                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6041                            *score = 0.0;
6042                        }
6043                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6044                        decoder
6045                            .decode_tracked(
6046                                &mut tracker,
6047                                100_000_000 / 3,
6048                                &[&detect_td, &protos_td],
6049                                &mut output_boxes,
6050                                &mut output_masks,
6051                                &mut output_tracks,
6052                            )
6053                            .unwrap();
6054                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6055                        assert!(output_masks.is_empty());
6056                    }
6057                }
6058            }
6059        };
6060    }
6061
6062    e2e_tracked_tensor_test!(
6063        test_decoder_tracked_tensor_end_to_end_segdet,
6064        quantized,
6065        combined,
6066        masks
6067    );
6068    e2e_tracked_tensor_test!(
6069        test_decoder_tracked_tensor_end_to_end_segdet_float,
6070        float,
6071        combined,
6072        masks
6073    );
6074    e2e_tracked_tensor_test!(
6075        test_decoder_tracked_tensor_end_to_end_segdet_proto,
6076        quantized,
6077        combined,
6078        proto
6079    );
6080    e2e_tracked_tensor_test!(
6081        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6082        float,
6083        combined,
6084        proto
6085    );
6086    e2e_tracked_tensor_test!(
6087        test_decoder_tracked_tensor_end_to_end_segdet_split,
6088        quantized,
6089        split,
6090        masks
6091    );
6092    e2e_tracked_tensor_test!(
6093        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6094        float,
6095        split,
6096        masks
6097    );
6098    e2e_tracked_tensor_test!(
6099        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6100        quantized,
6101        split,
6102        proto
6103    );
6104    e2e_tracked_tensor_test!(
6105        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6106        float,
6107        split,
6108        proto
6109    );
6110
6111    #[test]
6112    fn test_decoder_tracked_linear_motion() {
6113        use crate::configs::{DecoderType, Nms};
6114        use crate::DecoderBuilder;
6115
6116        let score_threshold = 0.25;
6117        let iou_threshold = 0.1;
6118        let out = include_bytes!(concat!(
6119            env!("CARGO_MANIFEST_DIR"),
6120            "/../../testdata/yolov8s_80_classes.bin"
6121        ));
6122        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6123        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6124        let quant = (0.0040811873, -123).into();
6125
6126        let decoder = DecoderBuilder::default()
6127            .with_config_yolo_det(
6128                crate::configs::Detection {
6129                    decoder: DecoderType::Ultralytics,
6130                    shape: vec![1, 84, 8400],
6131                    anchors: None,
6132                    quantization: Some(quant),
6133                    dshape: vec![
6134                        (crate::configs::DimName::Batch, 1),
6135                        (crate::configs::DimName::NumFeatures, 84),
6136                        (crate::configs::DimName::NumBoxes, 8400),
6137                    ],
6138                    normalized: Some(true),
6139                },
6140                None,
6141            )
6142            .with_score_threshold(score_threshold)
6143            .with_iou_threshold(iou_threshold)
6144            .with_nms(Some(Nms::ClassAgnostic))
6145            .build()
6146            .unwrap();
6147
6148        let mut expected_boxes = [
6149            DetectBox {
6150                bbox: BoundingBox {
6151                    xmin: 0.5285137,
6152                    ymin: 0.05305544,
6153                    xmax: 0.87541467,
6154                    ymax: 0.9998909,
6155                },
6156                score: 0.5591227,
6157                label: 0,
6158            },
6159            DetectBox {
6160                bbox: BoundingBox {
6161                    xmin: 0.130598,
6162                    ymin: 0.43260583,
6163                    xmax: 0.35098213,
6164                    ymax: 0.9958097,
6165                },
6166                score: 0.33057618,
6167                label: 75,
6168            },
6169        ];
6170
6171        let mut tracker = ByteTrackBuilder::new()
6172            .track_update(0.1)
6173            .track_high_conf(0.3)
6174            .build();
6175
6176        let mut output_boxes = Vec::with_capacity(50);
6177        let mut output_masks = Vec::with_capacity(50);
6178        let mut output_tracks = Vec::with_capacity(50);
6179
6180        decoder
6181            .decode_tracked_quantized(
6182                &mut tracker,
6183                0,
6184                &[out.view().into()],
6185                &mut output_boxes,
6186                &mut output_masks,
6187                &mut output_tracks,
6188            )
6189            .unwrap();
6190
6191        assert_eq!(output_boxes.len(), 2);
6192        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6193        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6194
6195        for i in 1..=100 {
6196            let mut out = out.clone();
6197            // introduce linear movement into the XY coordinates
6198            let mut x_values = out.slice_mut(s![0, 0, ..]);
6199            for x in x_values.iter_mut() {
6200                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6201            }
6202
6203            decoder
6204                .decode_tracked_quantized(
6205                    &mut tracker,
6206                    100_000_000 * i / 3, // simulate 33.333ms between frames
6207                    &[out.view().into()],
6208                    &mut output_boxes,
6209                    &mut output_masks,
6210                    &mut output_tracks,
6211                )
6212                .unwrap();
6213
6214            assert_eq!(output_boxes.len(), 2);
6215        }
6216        let tracks = tracker.get_active_tracks();
6217        let predicted_boxes: Vec<_> = tracks
6218            .iter()
6219            .map(|track| {
6220                let mut l = track.last_box;
6221                l.bbox = track.info.tracked_location.into();
6222                l
6223            })
6224            .collect();
6225        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
6226        expected_boxes[0].bbox.xmax += 0.1;
6227        expected_boxes[1].bbox.xmin += 0.1;
6228        expected_boxes[1].bbox.xmax += 0.1;
6229
6230        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6231        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6232
6233        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6234        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6235        for score in scores_values.iter_mut() {
6236            *score = i8::MIN; // set all scores to minimum to simulate no detections
6237        }
6238        decoder
6239            .decode_tracked_quantized(
6240                &mut tracker,
6241                100_000_000 * 101 / 3,
6242                &[out.view().into()],
6243                &mut output_boxes,
6244                &mut output_masks,
6245                &mut output_tracks,
6246            )
6247            .unwrap();
6248        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
6249        expected_boxes[0].bbox.xmax += 0.001;
6250        expected_boxes[1].bbox.xmin += 0.001;
6251        expected_boxes[1].bbox.xmax += 0.001;
6252
6253        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6254        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6255    }
6256
6257    #[test]
6258    fn test_decoder_tracked_end_to_end_float() {
6259        let score_threshold = 0.45;
6260        let iou_threshold = 0.45;
6261
6262        let mut boxes = Array2::zeros((10, 4));
6263        let mut scores = Array2::zeros((10, 1));
6264        let mut classes = Array2::zeros((10, 1));
6265
6266        boxes
6267            .slice_mut(s![0, ..,])
6268            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6269        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6270        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6271
6272        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6273        let mut detect = detect.insert_axis(Axis(0));
6274        assert_eq!(detect.shape(), &[1, 10, 6]);
6275        let config = "
6276decoder_version: yolo26
6277outputs:
6278 - type: detection
6279   decoder: ultralytics
6280   quantization: [0.00784313725490196, 0]
6281   shape: [1, 10, 6]
6282   dshape:
6283    - [batch, 1]
6284    - [num_boxes, 10]
6285    - [num_features, 6]
6286   normalized: true
6287";
6288
6289        let decoder = DecoderBuilder::default()
6290            .with_config_yaml_str(config.to_string())
6291            .with_score_threshold(score_threshold)
6292            .with_iou_threshold(iou_threshold)
6293            .build()
6294            .unwrap();
6295
6296        let expected_boxes = [DetectBox {
6297            bbox: BoundingBox {
6298                xmin: 0.1234,
6299                ymin: 0.1234,
6300                xmax: 0.2345,
6301                ymax: 0.2345,
6302            },
6303            score: 0.9876,
6304            label: 2,
6305        }];
6306
6307        let mut tracker = ByteTrackBuilder::new()
6308            .track_update(0.1)
6309            .track_high_conf(0.7)
6310            .build();
6311
6312        let mut output_boxes = Vec::with_capacity(50);
6313        let mut output_masks = Vec::with_capacity(50);
6314        let mut output_tracks = Vec::with_capacity(50);
6315
6316        decoder
6317            .decode_tracked_float(
6318                &mut tracker,
6319                0,
6320                &[detect.view().into_dyn()],
6321                &mut output_boxes,
6322                &mut output_masks,
6323                &mut output_tracks,
6324            )
6325            .unwrap();
6326
6327        assert_eq!(output_boxes.len(), 1);
6328        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6329
6330        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6331
6332        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6333            *score = 0.0; // set all scores to minimum to simulate no detections
6334        }
6335
6336        decoder
6337            .decode_tracked_float(
6338                &mut tracker,
6339                100_000_000 / 3,
6340                &[detect.view().into_dyn()],
6341                &mut output_boxes,
6342                &mut output_masks,
6343                &mut output_tracks,
6344            )
6345            .unwrap();
6346        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6347    }
6348}