Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
6This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
7It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
8for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
58If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
59the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
60not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod schema;
77pub mod yolo;
78
79mod decoder;
80pub use decoder::*;
81
82pub use configs::{DecoderVersion, Nms};
83pub use error::{DecoderError, DecoderResult};
84
85use crate::{
86    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
87    yolo::yolo_segmentation_to_mask,
88};
89
90/// Trait to convert bounding box formats to XYXY float format
91pub trait BBoxTypeTrait {
92    /// Converts the bbox into XYXY float format.
93    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
94
95    /// Converts the bbox into XYXY float format.
96    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
97        input: &[B; 4],
98        quant: Quantization,
99    ) -> [A; 4]
100    where
101        f32: AsPrimitive<A>,
102        i32: AsPrimitive<A>;
103
104    /// Converts the bbox into XYXY float format.
105    ///
106    /// # Examples
107    /// ```rust
108    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
109    /// # use ndarray::array;
110    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
111    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
112    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
113    /// ```
114    #[inline(always)]
115    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
116        input: ArrayView1<B>,
117    ) -> [A; 4] {
118        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
119    }
120
121    #[inline(always)]
122    /// Converts the bbox into XYXY float format.
123    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
124        input: ArrayView1<B>,
125        quant: Quantization,
126    ) -> [A; 4]
127    where
128        f32: AsPrimitive<A>,
129        i32: AsPrimitive<A>,
130    {
131        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
132    }
133}
134
135/// Converts XYXY bounding boxes to XYXY
136#[derive(Debug, Clone, Copy, PartialEq, Eq)]
137pub struct XYXY {}
138
139impl BBoxTypeTrait for XYXY {
140    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
141        input.map(|b| b.as_())
142    }
143
144    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
145        input: &[B; 4],
146        quant: Quantization,
147    ) -> [A; 4]
148    where
149        f32: AsPrimitive<A>,
150        i32: AsPrimitive<A>,
151    {
152        let scale = quant.scale.as_();
153        let zp = quant.zero_point.as_();
154        input.map(|b| (b.as_() - zp) * scale)
155    }
156
157    #[inline(always)]
158    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
159        input: ArrayView1<B>,
160    ) -> [A; 4] {
161        [
162            input[0].as_(),
163            input[1].as_(),
164            input[2].as_(),
165            input[3].as_(),
166        ]
167    }
168}
169
170/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
171/// box
172#[derive(Debug, Clone, Copy, PartialEq, Eq)]
173pub struct XYWH {}
174
175impl BBoxTypeTrait for XYWH {
176    #[inline(always)]
177    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
178        let half = A::one() / (A::one() + A::one());
179        [
180            (input[0].as_()) - (input[2].as_() * half),
181            (input[1].as_()) - (input[3].as_() * half),
182            (input[0].as_()) + (input[2].as_() * half),
183            (input[1].as_()) + (input[3].as_() * half),
184        ]
185    }
186
187    #[inline(always)]
188    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
189        input: &[B; 4],
190        quant: Quantization,
191    ) -> [A; 4]
192    where
193        f32: AsPrimitive<A>,
194        i32: AsPrimitive<A>,
195    {
196        let scale = quant.scale.as_();
197        let half_scale = (quant.scale * 0.5).as_();
198        let zp = quant.zero_point.as_();
199        let [x, y, w, h] = [
200            (input[0].as_() - zp) * scale,
201            (input[1].as_() - zp) * scale,
202            (input[2].as_() - zp) * half_scale,
203            (input[3].as_() - zp) * half_scale,
204        ];
205
206        [x - w, y - h, x + w, y + h]
207    }
208
209    #[inline(always)]
210    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
211        input: ArrayView1<B>,
212    ) -> [A; 4] {
213        let half = A::one() / (A::one() + A::one());
214        [
215            (input[0].as_()) - (input[2].as_() * half),
216            (input[1].as_()) - (input[3].as_() * half),
217            (input[0].as_()) + (input[2].as_() * half),
218            (input[1].as_()) + (input[3].as_() * half),
219        ]
220    }
221}
222
/// Quantization parameters (`scale` and `zero_point`) of a tensor.
///
/// The dequantization helpers in this crate compute
/// `(value - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    pub scale: f32,
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from a scale and a zero point.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
243
244impl From<QuantTuple> for Quantization {
245    /// Creates a new Quantization struct from a QuantTuple
246    /// # Examples
247    /// ```
248    /// # use edgefirst_decoder::Quantization;
249    /// # use edgefirst_decoder::configs::QuantTuple;
250    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
251    /// let quant = Quantization::from(quant_tuple);
252    /// assert_eq!(quant.scale, 0.1);
253    /// assert_eq!(quant.zero_point, -128);
254    /// ```
255    fn from(quant_tuple: QuantTuple) -> Quantization {
256        Quantization {
257            scale: quant_tuple.0,
258            zero_point: quant_tuple.1,
259        }
260    }
261}
262
263impl<S, Z> From<(S, Z)> for Quantization
264where
265    S: AsPrimitive<f32>,
266    Z: AsPrimitive<i32>,
267{
268    /// Creates a new Quantization struct from a tuple
269    /// # Examples
270    /// ```
271    /// # use edgefirst_decoder::Quantization;
272    /// let quant = Quantization::from((0.1_f64, -128_i64));
273    /// assert_eq!(quant.scale, 0.1);
274    /// assert_eq!(quant.zero_point, -128);
275    /// ```
276    fn from((scale, zp): (S, Z)) -> Quantization {
277        Self {
278            scale: scale.as_(),
279            zero_point: zp.as_(),
280        }
281    }
282}
283
284impl Default for Quantization {
285    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
286    /// # Examples
287    /// ```rust
288    /// # use edgefirst_decoder::Quantization;
289    /// let quant = Quantization::default();
290    /// assert_eq!(quant.scale, 1.0);
291    /// assert_eq!(quant.zero_point, 0);
292    /// ```
293    fn default() -> Self {
294        Self {
295            scale: 1.0,
296            zero_point: 0,
297        }
298    }
299}
300
301/// A detection box with f32 bbox and score
302#[derive(Debug, Clone, Copy, PartialEq, Default)]
303pub struct DetectBox {
304    pub bbox: BoundingBox,
305    /// model-specific score for this detection, higher implies more confidence
306    pub score: f32,
307    /// label index for this detection
308    pub label: usize,
309}
310
/// An axis-aligned bounding box stored as `f32` corners in XYXY order
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// normalized coordinate of the left edge
    pub xmin: f32,
    /// normalized coordinate of the top edge
    pub ymin: f32,
    /// normalized coordinate of the right edge
    pub xmax: f32,
    /// normalized coordinate of the bottom edge
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a BoundingBox from its four coordinates, stored as given
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (ymin, ymax) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(xmin, ymin, xmax, ymax)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens a BoundingBox into `[xmin, ymin, xmax, ymax]`
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a BoundingBox from an array laid out as
    /// `[xmin, ymin, xmax, ymax]`
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}
389
390impl DetectBox {
391    /// Returns true if one detect box is equal to another detect box, within
392    /// the given `eps`
393    ///
394    /// # Examples
395    /// ```
396    /// # use edgefirst_decoder::DetectBox;
397    /// let box1 = DetectBox {
398    ///     bbox: edgefirst_decoder::BoundingBox {
399    ///         xmin: 0.1,
400    ///         ymin: 0.2,
401    ///         xmax: 0.3,
402    ///         ymax: 0.4,
403    ///     },
404    ///     score: 0.5,
405    ///     label: 1,
406    /// };
407    /// let box2 = DetectBox {
408    ///     bbox: edgefirst_decoder::BoundingBox {
409    ///         xmin: 0.101,
410    ///         ymin: 0.199,
411    ///         xmax: 0.301,
412    ///         ymax: 0.399,
413    ///     },
414    ///     score: 0.510,
415    ///     label: 1,
416    /// };
417    /// assert!(box1.equal_within_delta(&box2, 0.011));
418    /// ```
419    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
420        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
421        self.label == rhs.label
422            && eq_delta(self.score, rhs.score)
423            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
424            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
425            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
426            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
427    }
428}
429
430/// A segmentation result with a segmentation mask, and a normalized bounding
431/// box representing the area that the segmentation mask covers
432#[derive(Debug, Clone, PartialEq, Default)]
433pub struct Segmentation {
434    /// left-most normalized coordinate of the segmentation box
435    pub xmin: f32,
436    /// top-most normalized coordinate of the segmentation box
437    pub ymin: f32,
438    /// right-most normalized coordinate of the segmentation box
439    pub xmax: f32,
440    /// bottom-most normalized coordinate of the segmentation box
441    pub ymax: f32,
442    /// 3D segmentation array of shape `(H, W, C)`.
443    ///
444    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
445    /// continuous sigmoid confidence values quantized to u8 (0 = background,
446    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
447    /// 0.5) or use smooth interpolation for anti-aliased edges.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
/// Memory layout of the prototype tensor carried by [`ProtoData`].
///
/// A model may emit its protos channel-last (NHWC) or channel-first (NCHW).
/// The mask materialisation kernels dispatch on this value so that no
/// per-frame transpose is needed.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoLayout {
    /// Channel-last: the tensor shape is `[H, W, K]` and K is the contiguous
    /// axis. This is the traditional layout that the `extract_proto_data`
    /// path produces after transposing NCHW model outputs.
    Nhwc,
    /// Channel-first: the tensor shape is `[K, H, W]` and every channel
    /// plane is contiguous. Skipping the NCHW→NHWC transpose saves ~3 ms per
    /// frame on Cortex-A53/A55 targets (819 KB for 32×160×160 protos).
    Nchw,
}
470
471/// Raw prototype data for fused decode+render pipelines.
472///
473/// Holds post-NMS intermediate state before mask materialization, allowing the
474/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
475/// shader) without materializing intermediate `Array3<u8>` masks.
476///
477/// Both fields are carried as [`TensorDyn`] so downstream consumers (Rust, C
478/// API, Python) get zero-copy typed access through the HAL's shared tensor
479/// infrastructure. Dtype policy:
480///
481/// | Source model | protos dtype | mask_coefficients dtype | protos.quantization |
482/// |---|---|---|---|
483/// | int8 quantized | [`TensorDyn::I8`] | [`TensorDyn::I8`] (raw + quantization) | `Some(q)` |
484/// | f32 | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
485/// | f16 (TensorRT fp16) | [`TensorDyn::F16`] | [`TensorDyn::F16`] | `None` |
486/// | f64 (narrowed) | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
487///
488/// Quantization metadata lives on the proto tensor itself via
489/// [`edgefirst_tensor::Tensor::quantization`] — float tensors cannot carry
490/// quantization (compile-time gated on the `IntegerType` sealed trait).
491///
492/// `TensorDyn` is not `Clone`, so neither is `ProtoData`. Consumers that need
493/// to share the proto buffer across threads should use `TensorDyn::clone_fd`
494/// / `dmabuf_clone` to dup the backing fd.
495#[derive(Debug)]
496pub struct ProtoData {
497    /// Per-detection mask coefficients, shape `[num_detections, num_protos]`.
498    pub mask_coefficients: edgefirst_tensor::TensorDyn,
499    /// Prototype tensor.
500    ///
501    /// - When `layout == ProtoLayout::Nhwc`: shape is `[proto_h, proto_w, num_protos]`.
502    /// - When `layout == ProtoLayout::Nchw`: shape is `[num_protos, proto_h, proto_w]`.
503    pub protos: edgefirst_tensor::TensorDyn,
504    /// Physical memory layout of the `protos` tensor.
505    pub layout: ProtoLayout,
506}
507
508/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
509///
510///  # Examples
511/// ```
512/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
513/// let quant = Quantization::new(0.1, -128);
514/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
515/// let detect_quant = DetectBoxQuantized {
516///     bbox,
517///     score: 100_i8,
518///     label: 1,
519/// };
520/// let detect = dequant_detect_box(&detect_quant, quant);
521/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
522/// assert_eq!(detect.label, 1);
523/// assert_eq!(detect.bbox, bbox);
524/// ```
525pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
526    detect: &DetectBoxQuantized<SCORE>,
527    quant_scores: Quantization,
528) -> DetectBox {
529    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
530    DetectBox {
531        bbox: detect.bbox,
532        score: quant_scores.scale * detect.score.as_() + scaled_zp,
533        label: detect.label,
534    }
535}
536/// A detection box with a f32 bbox and quantized score
537#[derive(Debug, Clone, Copy, PartialEq)]
538pub struct DetectBoxQuantized<
539    // BOX: Signed + PrimInt + AsPrimitive<f32>,
540    SCORE: PrimInt + AsPrimitive<f32>,
541> {
542    // pub bbox: BoundingBoxQuantized<BOX>,
543    pub bbox: BoundingBox,
544    /// model-specific score for this detection, higher implies more
545    /// confidence.
546    pub score: SCORE,
547    /// label index for this detect
548    pub label: usize,
549}
550
551/// Dequantizes an ndarray from quantized values to f32 values using the given
552/// quantization parameters
553///
554/// # Examples
555/// ```
556/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
557/// let quant = Quantization::new(0.1, -128);
558/// let input: Vec<i8> = vec![0, 127, -128, 64];
559/// let input_array = ndarray::Array1::from(input);
560/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
561/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
562/// ```
563pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
564    input: ArrayView<T, D>,
565    quant: Quantization,
566) -> Array<F, D>
567where
568    i32: num_traits::AsPrimitive<F>,
569    f32: num_traits::AsPrimitive<F>,
570{
571    let zero_point = quant.zero_point.as_();
572    let scale = quant.scale.as_();
573    if zero_point != F::zero() {
574        let scaled_zero = -zero_point * scale;
575        input.mapv(|d| d.as_() * scale + scaled_zero)
576    } else {
577        input.mapv(|d| d.as_() * scale)
578    }
579}
580
581/// Dequantizes a slice from quantized values to float values using the given
582/// quantization parameters
583///
584/// # Examples
585/// ```
586/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
587/// let quant = Quantization::new(0.1, -128);
588/// let input: Vec<i8> = vec![0, 127, -128, 64];
589/// let mut output: Vec<f32> = vec![0.0; input.len()];
590/// dequantize_cpu(&input, quant, &mut output);
591/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
592/// ```
593pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
594    input: &[T],
595    quant: Quantization,
596    output: &mut [F],
597) where
598    f32: num_traits::AsPrimitive<F>,
599    i32: num_traits::AsPrimitive<F>,
600{
601    assert!(input.len() == output.len());
602    let zero_point = quant.zero_point.as_();
603    let scale = quant.scale.as_();
604    if zero_point != F::zero() {
605        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
606        input
607            .iter()
608            .zip(output)
609            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
610    } else {
611        input
612            .iter()
613            .zip(output)
614            .for_each(|(d, deq)| *deq = d.as_() * scale);
615    }
616}
617
618/// Dequantizes a slice from quantized values to float values using the given
619/// quantization parameters, using chunked processing. This is around 5% faster
620/// than `dequantize_cpu` for large slices.
621///
622/// # Examples
623/// ```
624/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
625/// let quant = Quantization::new(0.1, -128);
626/// let input: Vec<i8> = vec![0, 127, -128, 64];
627/// let mut output: Vec<f32> = vec![0.0; input.len()];
628/// dequantize_cpu_chunked(&input, quant, &mut output);
629/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
630/// ```
631pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
632    input: &[T],
633    quant: Quantization,
634    output: &mut [F],
635) where
636    f32: num_traits::AsPrimitive<F>,
637    i32: num_traits::AsPrimitive<F>,
638{
639    assert!(input.len() == output.len());
640    let zero_point = quant.zero_point.as_();
641    let scale = quant.scale.as_();
642
643    let input = input.as_chunks::<4>();
644    let output = output.as_chunks_mut::<4>();
645
646    if zero_point != F::zero() {
647        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
648
649        input
650            .0
651            .iter()
652            .zip(output.0)
653            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
654        input
655            .1
656            .iter()
657            .zip(output.1)
658            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
659    } else {
660        input
661            .0
662            .iter()
663            .zip(output.0)
664            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
665        input
666            .1
667            .iter()
668            .zip(output.1)
669            .for_each(|(d, deq)| *deq = d.as_() * scale);
670    }
671}
672
673/// Converts a segmentation tensor into a 2D mask
674/// If the last dimension of the segmentation tensor is 1, values equal or
675/// above 128 are considered objects. Otherwise the object is the argmax index
676///
677/// # Errors
678///
679/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
680/// invalid shape.
681///
682/// # Examples
683/// ```
684/// # use edgefirst_decoder::segmentation_to_mask;
685/// let segmentation =
686///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
687/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
688/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
689/// ```
690pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
691    if segmentation.shape()[2] == 0 {
692        return Err(DecoderError::InvalidShape(
693            "Segmentation tensor must have non-zero depth".to_string(),
694        ));
695    }
696    if segmentation.shape()[2] == 1 {
697        yolo_segmentation_to_mask(segmentation, 128)
698    } else {
699        Ok(modelpack_segmentation_to_mask(segmentation))
700    }
701}
702
703/// Returns the maximum value and its index from a 1D array
704fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
705    score
706        .iter()
707        .enumerate()
708        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
709            if max > *s {
710                (max, arg_max)
711            } else {
712                (*s, ind)
713            }
714        })
715}
716
/// NEON-accelerated argmax for i8 slices on aarch64.
///
/// Finds the maximum value and its index (last index wins on ties, matching
/// the semantics of [`arg_max`]). Slices shorter than 16 elements are handled
/// by a scalar loop, since there is no full 16-lane vector to load.
///
/// This function is only compiled for `aarch64` targets.
///
/// # Panics
///
/// Panics if `scores` is empty.
#[cfg(target_arch = "aarch64")]
pub(crate) fn arg_max_i8(scores: &[i8]) -> (i8, usize) {
    use std::arch::aarch64::*;

    let n = scores.len();
    if n < 16 {
        // Scalar path for very short slices; `>=` makes the last index win.
        let mut max = scores[0];
        let mut idx = 0;
        for (i, &s) in scores.iter().enumerate().skip(1) {
            if s >= max {
                max = s;
                idx = i;
            }
        }
        return (max, idx);
    }

    // Phase 1: maximum over all full 16-element chunks, using NEON.
    let chunks = n / 16;
    // SAFETY: `n >= 16`, so `chunks >= 1`. Each load reads 16 bytes starting
    // at offset `i * 16` with `i < chunks`, and `i * 16 + 16 <= chunks * 16
    // <= n`, so every read stays inside `scores`. `vld1q_s8` has no alignment
    // requirement beyond that of `i8`.
    let vector_max = unsafe {
        let base = scores.as_ptr();
        let mut vmax = vld1q_s8(base);
        for i in 1..chunks {
            vmax = vmaxq_s8(vmax, vld1q_s8(base.add(i * 16)));
        }
        vmaxvq_s8(vmax)
    };

    // Fold in the elements that did not fill a whole chunk.
    let final_max = scores[chunks * 16..]
        .iter()
        .copied()
        .fold(vector_max, i8::max);

    // Phase 2: the LAST index holding `final_max` (preserves tie semantics).
    // The value was taken from the slice, so a match always exists; the
    // fallback of 0 mirrors the previous default.
    let idx = scores.iter().rposition(|&s| s == final_max).unwrap_or(0);
    (final_max, idx)
}
771#[cfg(test)]
772#[cfg_attr(coverage_nightly, coverage(off))]
773mod decoder_tests {
774    #![allow(clippy::excessive_precision)]
775    use crate::{
776        configs::{DecoderType, DimName, Protos},
777        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
778        yolo::{
779            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
780            decode_yolo_segdet_quant,
781        },
782        *,
783    };
784    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
785    use ndarray::Dimension;
786    use ndarray::{array, s, Array2, Array3, Array4, Axis};
787    use ndarray_stats::DeviationExt;
788    use num_traits::{AsPrimitive, PrimInt};
789
790    fn compare_outputs(
791        boxes: (&[DetectBox], &[DetectBox]),
792        masks: (&[Segmentation], &[Segmentation]),
793    ) {
794        let (boxes0, boxes1) = boxes;
795        let (masks0, masks1) = masks;
796
797        assert_eq!(boxes0.len(), boxes1.len());
798        assert_eq!(masks0.len(), masks1.len());
799
800        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
801            assert!(
802                b_i8.equal_within_delta(b_f32, 1e-6),
803                "{b_i8:?} is not equal to {b_f32:?}"
804            );
805        }
806
807        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
808            assert_eq!(
809                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
810                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
811            );
812            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
813            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
814            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
815            let diff = &mask_i8 - &mask_f32;
816            for x in 0..diff.shape()[0] {
817                for y in 0..diff.shape()[1] {
818                    for z in 0..diff.shape()[2] {
819                        let val = diff[[x, y, z]];
820                        assert!(
821                            val.abs() <= 1,
822                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
823                            x,
824                            y,
825                            z,
826                            val
827                        );
828                    }
829                }
830            }
831            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
832            assert!(
833                mean_sq_err < 1e-2,
834                "Mean Square Error between masks was greater than 1%: {:.2}%",
835                mean_sq_err * 100.0
836            );
837        }
838    }
839
840    // ─── Shared test data loaders ────────────────────────
841
842    fn load_yolov8_boxes() -> Array3<i8> {
843        let raw = include_bytes!(concat!(
844            env!("CARGO_MANIFEST_DIR"),
845            "/../../testdata/yolov8_boxes_116x8400.bin"
846        ));
847        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
848        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
849    }
850
851    fn load_yolov8_protos() -> Array4<i8> {
852        let raw = include_bytes!(concat!(
853            env!("CARGO_MANIFEST_DIR"),
854            "/../../testdata/yolov8_protos_160x160x32.bin"
855        ));
856        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
857        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
858    }
859
860    fn load_yolov8s_det() -> Array3<i8> {
861        let raw = include_bytes!(concat!(
862            env!("CARGO_MANIFEST_DIR"),
863            "/../../testdata/yolov8s_80_classes.bin"
864        ));
865        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
866        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
867    }
868
869    #[test]
870    fn test_decoder_modelpack() {
871        let score_threshold = 0.45;
872        let iou_threshold = 0.45;
873        let boxes = include_bytes!(concat!(
874            env!("CARGO_MANIFEST_DIR"),
875            "/../../testdata/modelpack_boxes_1935x1x4.bin"
876        ));
877        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
878
879        let scores = include_bytes!(concat!(
880            env!("CARGO_MANIFEST_DIR"),
881            "/../../testdata/modelpack_scores_1935x1.bin"
882        ));
883        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
884
885        let quant_boxes = (0.004656755365431309, 21).into();
886        let quant_scores = (0.0019603664986789227, 0).into();
887
888        let decoder = DecoderBuilder::default()
889            .with_config_modelpack_det(
890                configs::Boxes {
891                    decoder: DecoderType::ModelPack,
892                    quantization: Some(quant_boxes),
893                    shape: vec![1, 1935, 1, 4],
894                    dshape: vec![
895                        (DimName::Batch, 1),
896                        (DimName::NumBoxes, 1935),
897                        (DimName::Padding, 1),
898                        (DimName::BoxCoords, 4),
899                    ],
900                    normalized: Some(true),
901                },
902                configs::Scores {
903                    decoder: DecoderType::ModelPack,
904                    quantization: Some(quant_scores),
905                    shape: vec![1, 1935, 1],
906                    dshape: vec![
907                        (DimName::Batch, 1),
908                        (DimName::NumBoxes, 1935),
909                        (DimName::NumClasses, 1),
910                    ],
911                },
912            )
913            .with_score_threshold(score_threshold)
914            .with_iou_threshold(iou_threshold)
915            .build()
916            .unwrap();
917
918        let quant_boxes = quant_boxes.into();
919        let quant_scores = quant_scores.into();
920
921        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
922        decode_modelpack_det(
923            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
924            (scores.slice(s![0, .., ..]), quant_scores),
925            score_threshold,
926            iou_threshold,
927            &mut output_boxes,
928        );
929        assert!(output_boxes[0].equal_within_delta(
930            &DetectBox {
931                bbox: BoundingBox {
932                    xmin: 0.40513772,
933                    ymin: 0.6379755,
934                    xmax: 0.5122431,
935                    ymax: 0.7730214,
936                },
937                score: 0.4861709,
938                label: 0
939            },
940            1e-6
941        ));
942
943        let mut output_boxes1 = Vec::with_capacity(50);
944        let mut output_masks1 = Vec::with_capacity(50);
945
946        decoder
947            .decode_quantized(
948                &[boxes.view().into(), scores.view().into()],
949                &mut output_boxes1,
950                &mut output_masks1,
951            )
952            .unwrap();
953
954        let mut output_boxes_float = Vec::with_capacity(50);
955        let mut output_masks_float = Vec::with_capacity(50);
956
957        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
958        let scores = dequantize_ndarray(scores.view(), quant_scores);
959
960        decoder
961            .decode_float::<f32>(
962                &[boxes.view().into_dyn(), scores.view().into_dyn()],
963                &mut output_boxes_float,
964                &mut output_masks_float,
965            )
966            .unwrap();
967
968        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
969        compare_outputs(
970            (&output_boxes, &output_boxes_float),
971            (&[], &output_masks_float),
972        );
973    }
974
975    #[test]
976    fn test_decoder_modelpack_split_u8() {
977        let score_threshold = 0.45;
978        let iou_threshold = 0.45;
979        let detect0 = include_bytes!(concat!(
980            env!("CARGO_MANIFEST_DIR"),
981            "/../../testdata/modelpack_split_9x15x18.bin"
982        ));
983        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
984
985        let detect1 = include_bytes!(concat!(
986            env!("CARGO_MANIFEST_DIR"),
987            "/../../testdata/modelpack_split_17x30x18.bin"
988        ));
989        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
990
991        let quant0 = (0.08547406643629074, 174).into();
992        let quant1 = (0.09929127991199493, 183).into();
993        let anchors0 = vec![
994            [0.36666667461395264, 0.31481480598449707],
995            [0.38749998807907104, 0.4740740656852722],
996            [0.5333333611488342, 0.644444465637207],
997        ];
998        let anchors1 = vec![
999            [0.13750000298023224, 0.2074074000120163],
1000            [0.2541666626930237, 0.21481481194496155],
1001            [0.23125000298023224, 0.35185185074806213],
1002        ];
1003
1004        let detect_config0 = configs::Detection {
1005            decoder: DecoderType::ModelPack,
1006            shape: vec![1, 9, 15, 18],
1007            anchors: Some(anchors0.clone()),
1008            quantization: Some(quant0),
1009            dshape: vec![
1010                (DimName::Batch, 1),
1011                (DimName::Height, 9),
1012                (DimName::Width, 15),
1013                (DimName::NumAnchorsXFeatures, 18),
1014            ],
1015            normalized: Some(true),
1016        };
1017
1018        let detect_config1 = configs::Detection {
1019            decoder: DecoderType::ModelPack,
1020            shape: vec![1, 17, 30, 18],
1021            anchors: Some(anchors1.clone()),
1022            quantization: Some(quant1),
1023            dshape: vec![
1024                (DimName::Batch, 1),
1025                (DimName::Height, 17),
1026                (DimName::Width, 30),
1027                (DimName::NumAnchorsXFeatures, 18),
1028            ],
1029            normalized: Some(true),
1030        };
1031
1032        let config0 = (&detect_config0).try_into().unwrap();
1033        let config1 = (&detect_config1).try_into().unwrap();
1034
1035        let decoder = DecoderBuilder::default()
1036            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
1037            .with_score_threshold(score_threshold)
1038            .with_iou_threshold(iou_threshold)
1039            .build()
1040            .unwrap();
1041
1042        let quant0 = quant0.into();
1043        let quant1 = quant1.into();
1044
1045        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1046        decode_modelpack_split_quant(
1047            &[
1048                detect0.slice(s![0, .., .., ..]),
1049                detect1.slice(s![0, .., .., ..]),
1050            ],
1051            &[config0, config1],
1052            score_threshold,
1053            iou_threshold,
1054            &mut output_boxes,
1055        );
1056        assert!(output_boxes[0].equal_within_delta(
1057            &DetectBox {
1058                bbox: BoundingBox {
1059                    xmin: 0.43171933,
1060                    ymin: 0.68243736,
1061                    xmax: 0.5626645,
1062                    ymax: 0.808863,
1063                },
1064                score: 0.99240804,
1065                label: 0
1066            },
1067            1e-6
1068        ));
1069
1070        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1071        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1072        decoder
1073            .decode_quantized(
1074                &[detect0.view().into(), detect1.view().into()],
1075                &mut output_boxes1,
1076                &mut output_masks1,
1077            )
1078            .unwrap();
1079
1080        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1081        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1082
1083        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1084        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1085        decoder
1086            .decode_float::<f32>(
1087                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1088                &mut output_boxes1_f32,
1089                &mut output_masks1_f32,
1090            )
1091            .unwrap();
1092
1093        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1094        compare_outputs(
1095            (&output_boxes, &output_boxes1_f32),
1096            (&[], &output_masks1_f32),
1097        );
1098    }
1099
1100    #[test]
1101    fn test_decoder_parse_config_modelpack_split_u8() {
1102        let score_threshold = 0.45;
1103        let iou_threshold = 0.45;
1104        let detect0 = include_bytes!(concat!(
1105            env!("CARGO_MANIFEST_DIR"),
1106            "/../../testdata/modelpack_split_9x15x18.bin"
1107        ));
1108        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1109
1110        let detect1 = include_bytes!(concat!(
1111            env!("CARGO_MANIFEST_DIR"),
1112            "/../../testdata/modelpack_split_17x30x18.bin"
1113        ));
1114        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1115
1116        let decoder = DecoderBuilder::default()
1117            .with_config_yaml_str(
1118                include_str!(concat!(
1119                    env!("CARGO_MANIFEST_DIR"),
1120                    "/../../testdata/modelpack_split.yaml"
1121                ))
1122                .to_string(),
1123            )
1124            .with_score_threshold(score_threshold)
1125            .with_iou_threshold(iou_threshold)
1126            .build()
1127            .unwrap();
1128
1129        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1130        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1131        decoder
1132            .decode_quantized(
1133                &[
1134                    ArrayViewDQuantized::from(detect1.view()),
1135                    ArrayViewDQuantized::from(detect0.view()),
1136                ],
1137                &mut output_boxes,
1138                &mut output_masks,
1139            )
1140            .unwrap();
1141        assert!(output_boxes[0].equal_within_delta(
1142            &DetectBox {
1143                bbox: BoundingBox {
1144                    xmin: 0.43171933,
1145                    ymin: 0.68243736,
1146                    xmax: 0.5626645,
1147                    ymax: 0.808863,
1148                },
1149                score: 0.99240804,
1150                label: 0
1151            },
1152            1e-6
1153        ));
1154    }
1155
1156    #[test]
1157    fn test_modelpack_seg() {
1158        let out = include_bytes!(concat!(
1159            env!("CARGO_MANIFEST_DIR"),
1160            "/../../testdata/modelpack_seg_2x160x160.bin"
1161        ));
1162        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1163        let quant = (1.0 / 255.0, 0).into();
1164
1165        let decoder = DecoderBuilder::default()
1166            .with_config_modelpack_seg(configs::Segmentation {
1167                decoder: DecoderType::ModelPack,
1168                quantization: Some(quant),
1169                shape: vec![1, 2, 160, 160],
1170                dshape: vec![
1171                    (DimName::Batch, 1),
1172                    (DimName::NumClasses, 2),
1173                    (DimName::Height, 160),
1174                    (DimName::Width, 160),
1175                ],
1176            })
1177            .build()
1178            .unwrap();
1179        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1180        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1181        decoder
1182            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1183            .unwrap();
1184
1185        let mut mask = out.slice(s![0, .., .., ..]);
1186        mask.swap_axes(0, 1);
1187        mask.swap_axes(1, 2);
1188        let mask = [Segmentation {
1189            xmin: 0.0,
1190            ymin: 0.0,
1191            xmax: 1.0,
1192            ymax: 1.0,
1193            segmentation: mask.into_owned(),
1194        }];
1195        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1196
1197        decoder
1198            .decode_float::<f32>(
1199                &[dequantize_ndarray(out.view(), quant.into())
1200                    .view()
1201                    .into_dyn()],
1202                &mut output_boxes,
1203                &mut output_masks,
1204            )
1205            .unwrap();
1206
1207        // not expected for float decoder to have same values as quantized decoder, as
1208        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1209        // the model output. Thus the float output is the same as the quantized output
1210        // but scaled differently. However, it is expected that the mask after argmax
1211        // will be the same.
1212        compare_outputs((&[], &output_boxes), (&[], &[]));
1213        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1214        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1215
1216        assert_eq!(mask0, mask1);
1217    }
1218    #[test]
1219    fn test_modelpack_seg_quant() {
1220        let out = include_bytes!(concat!(
1221            env!("CARGO_MANIFEST_DIR"),
1222            "/../../testdata/modelpack_seg_2x160x160.bin"
1223        ));
1224        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1225        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1226        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1227        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1228        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1229        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1230
1231        let quant = (1.0 / 255.0, 0).into();
1232
1233        let decoder = DecoderBuilder::default()
1234            .with_config_modelpack_seg(configs::Segmentation {
1235                decoder: DecoderType::ModelPack,
1236                quantization: Some(quant),
1237                shape: vec![1, 2, 160, 160],
1238                dshape: vec![
1239                    (DimName::Batch, 1),
1240                    (DimName::NumClasses, 2),
1241                    (DimName::Height, 160),
1242                    (DimName::Width, 160),
1243                ],
1244            })
1245            .build()
1246            .unwrap();
1247        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1248        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1249        decoder
1250            .decode_quantized(
1251                &[out_u8.view().into()],
1252                &mut output_boxes,
1253                &mut output_masks_u8,
1254            )
1255            .unwrap();
1256
1257        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1258        decoder
1259            .decode_quantized(
1260                &[out_i8.view().into()],
1261                &mut output_boxes,
1262                &mut output_masks_i8,
1263            )
1264            .unwrap();
1265
1266        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1267        decoder
1268            .decode_quantized(
1269                &[out_u16.view().into()],
1270                &mut output_boxes,
1271                &mut output_masks_u16,
1272            )
1273            .unwrap();
1274
1275        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1276        decoder
1277            .decode_quantized(
1278                &[out_i16.view().into()],
1279                &mut output_boxes,
1280                &mut output_masks_i16,
1281            )
1282            .unwrap();
1283
1284        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1285        decoder
1286            .decode_quantized(
1287                &[out_u32.view().into()],
1288                &mut output_boxes,
1289                &mut output_masks_u32,
1290            )
1291            .unwrap();
1292
1293        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1294        decoder
1295            .decode_quantized(
1296                &[out_i32.view().into()],
1297                &mut output_boxes,
1298                &mut output_masks_i32,
1299            )
1300            .unwrap();
1301
1302        compare_outputs((&[], &output_boxes), (&[], &[]));
1303        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1304        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1305        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1306        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1307        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1308        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1309        assert_eq!(mask_u8, mask_i8);
1310        assert_eq!(mask_u8, mask_u16);
1311        assert_eq!(mask_u8, mask_i16);
1312        assert_eq!(mask_u8, mask_u32);
1313        assert_eq!(mask_u8, mask_i32);
1314    }
1315
1316    #[test]
1317    fn test_modelpack_segdet() {
1318        let score_threshold = 0.45;
1319        let iou_threshold = 0.45;
1320
1321        let boxes = include_bytes!(concat!(
1322            env!("CARGO_MANIFEST_DIR"),
1323            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1324        ));
1325        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1326
1327        let scores = include_bytes!(concat!(
1328            env!("CARGO_MANIFEST_DIR"),
1329            "/../../testdata/modelpack_scores_1935x1.bin"
1330        ));
1331        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1332
1333        let seg = include_bytes!(concat!(
1334            env!("CARGO_MANIFEST_DIR"),
1335            "/../../testdata/modelpack_seg_2x160x160.bin"
1336        ));
1337        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1338
1339        let quant_boxes = (0.004656755365431309, 21).into();
1340        let quant_scores = (0.0019603664986789227, 0).into();
1341        let quant_seg = (1.0 / 255.0, 0).into();
1342
1343        let decoder = DecoderBuilder::default()
1344            .with_config_modelpack_segdet(
1345                configs::Boxes {
1346                    decoder: DecoderType::ModelPack,
1347                    quantization: Some(quant_boxes),
1348                    shape: vec![1, 1935, 1, 4],
1349                    dshape: vec![
1350                        (DimName::Batch, 1),
1351                        (DimName::NumBoxes, 1935),
1352                        (DimName::Padding, 1),
1353                        (DimName::BoxCoords, 4),
1354                    ],
1355                    normalized: Some(true),
1356                },
1357                configs::Scores {
1358                    decoder: DecoderType::ModelPack,
1359                    quantization: Some(quant_scores),
1360                    shape: vec![1, 1935, 1],
1361                    dshape: vec![
1362                        (DimName::Batch, 1),
1363                        (DimName::NumBoxes, 1935),
1364                        (DimName::NumClasses, 1),
1365                    ],
1366                },
1367                configs::Segmentation {
1368                    decoder: DecoderType::ModelPack,
1369                    quantization: Some(quant_seg),
1370                    shape: vec![1, 2, 160, 160],
1371                    dshape: vec![
1372                        (DimName::Batch, 1),
1373                        (DimName::NumClasses, 2),
1374                        (DimName::Height, 160),
1375                        (DimName::Width, 160),
1376                    ],
1377                },
1378            )
1379            .with_iou_threshold(iou_threshold)
1380            .with_score_threshold(score_threshold)
1381            .build()
1382            .unwrap();
1383        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1384        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1385        decoder
1386            .decode_quantized(
1387                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1388                &mut output_boxes,
1389                &mut output_masks,
1390            )
1391            .unwrap();
1392
1393        let mut mask = seg.slice(s![0, .., .., ..]);
1394        mask.swap_axes(0, 1);
1395        mask.swap_axes(1, 2);
1396        let mask = [Segmentation {
1397            xmin: 0.0,
1398            ymin: 0.0,
1399            xmax: 1.0,
1400            ymax: 1.0,
1401            segmentation: mask.into_owned(),
1402        }];
1403        let correct_boxes = [DetectBox {
1404            bbox: BoundingBox {
1405                xmin: 0.40513772,
1406                ymin: 0.6379755,
1407                xmax: 0.5122431,
1408                ymax: 0.7730214,
1409            },
1410            score: 0.4861709,
1411            label: 0,
1412        }];
1413        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1414
1415        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1416        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1417        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1418        decoder
1419            .decode_float::<f32>(
1420                &[
1421                    scores.view().into_dyn(),
1422                    boxes.view().into_dyn(),
1423                    seg.view().into_dyn(),
1424                ],
1425                &mut output_boxes,
1426                &mut output_masks,
1427            )
1428            .unwrap();
1429
1430        // not expected for float segmentation decoder to have same values as quantized
1431        // segmentation decoder, as float decoder ensures the data fills 0-255,
1432        // quantized decoder uses whatever the model output. Thus the float
1433        // output is the same as the quantized output but scaled differently.
1434        // However, it is expected that the mask after argmax will be the same.
1435        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1436        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1437        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1438
1439        assert_eq!(mask0, mask1);
1440    }
1441
1442    #[test]
1443    fn test_modelpack_segdet_split() {
1444        let score_threshold = 0.8;
1445        let iou_threshold = 0.5;
1446
1447        let seg = include_bytes!(concat!(
1448            env!("CARGO_MANIFEST_DIR"),
1449            "/../../testdata/modelpack_seg_2x160x160.bin"
1450        ));
1451        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1452
1453        let detect0 = include_bytes!(concat!(
1454            env!("CARGO_MANIFEST_DIR"),
1455            "/../../testdata/modelpack_split_9x15x18.bin"
1456        ));
1457        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1458
1459        let detect1 = include_bytes!(concat!(
1460            env!("CARGO_MANIFEST_DIR"),
1461            "/../../testdata/modelpack_split_17x30x18.bin"
1462        ));
1463        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1464
1465        let quant0 = (0.08547406643629074, 174).into();
1466        let quant1 = (0.09929127991199493, 183).into();
1467        let quant_seg = (1.0 / 255.0, 0).into();
1468
1469        let anchors0 = vec![
1470            [0.36666667461395264, 0.31481480598449707],
1471            [0.38749998807907104, 0.4740740656852722],
1472            [0.5333333611488342, 0.644444465637207],
1473        ];
1474        let anchors1 = vec![
1475            [0.13750000298023224, 0.2074074000120163],
1476            [0.2541666626930237, 0.21481481194496155],
1477            [0.23125000298023224, 0.35185185074806213],
1478        ];
1479
1480        let decoder = DecoderBuilder::default()
1481            .with_config_modelpack_segdet_split(
1482                vec![
1483                    configs::Detection {
1484                        decoder: DecoderType::ModelPack,
1485                        shape: vec![1, 17, 30, 18],
1486                        anchors: Some(anchors1),
1487                        quantization: Some(quant1),
1488                        dshape: vec![
1489                            (DimName::Batch, 1),
1490                            (DimName::Height, 17),
1491                            (DimName::Width, 30),
1492                            (DimName::NumAnchorsXFeatures, 18),
1493                        ],
1494                        normalized: Some(true),
1495                    },
1496                    configs::Detection {
1497                        decoder: DecoderType::ModelPack,
1498                        shape: vec![1, 9, 15, 18],
1499                        anchors: Some(anchors0),
1500                        quantization: Some(quant0),
1501                        dshape: vec![
1502                            (DimName::Batch, 1),
1503                            (DimName::Height, 9),
1504                            (DimName::Width, 15),
1505                            (DimName::NumAnchorsXFeatures, 18),
1506                        ],
1507                        normalized: Some(true),
1508                    },
1509                ],
1510                configs::Segmentation {
1511                    decoder: DecoderType::ModelPack,
1512                    quantization: Some(quant_seg),
1513                    shape: vec![1, 2, 160, 160],
1514                    dshape: vec![
1515                        (DimName::Batch, 1),
1516                        (DimName::NumClasses, 2),
1517                        (DimName::Height, 160),
1518                        (DimName::Width, 160),
1519                    ],
1520                },
1521            )
1522            .with_score_threshold(score_threshold)
1523            .with_iou_threshold(iou_threshold)
1524            .build()
1525            .unwrap();
1526        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1527        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1528        decoder
1529            .decode_quantized(
1530                &[
1531                    detect0.view().into(),
1532                    detect1.view().into(),
1533                    seg.view().into(),
1534                ],
1535                &mut output_boxes,
1536                &mut output_masks,
1537            )
1538            .unwrap();
1539
1540        let mut mask = seg.slice(s![0, .., .., ..]);
1541        mask.swap_axes(0, 1);
1542        mask.swap_axes(1, 2);
1543        let mask = [Segmentation {
1544            xmin: 0.0,
1545            ymin: 0.0,
1546            xmax: 1.0,
1547            ymax: 1.0,
1548            segmentation: mask.into_owned(),
1549        }];
1550        let correct_boxes = [DetectBox {
1551            bbox: BoundingBox {
1552                xmin: 0.43171933,
1553                ymin: 0.68243736,
1554                xmax: 0.5626645,
1555                ymax: 0.808863,
1556            },
1557            score: 0.99240804,
1558            label: 0,
1559        }];
1560        println!("Output Boxes: {:?}", output_boxes);
1561        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1562
1563        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1564        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1565        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1566        decoder
1567            .decode_float::<f32>(
1568                &[
1569                    detect0.view().into_dyn(),
1570                    detect1.view().into_dyn(),
1571                    seg.view().into_dyn(),
1572                ],
1573                &mut output_boxes,
1574                &mut output_masks,
1575            )
1576            .unwrap();
1577
1578        // not expected for float segmentation decoder to have same values as quantized
1579        // segmentation decoder, as float decoder ensures the data fills 0-255,
1580        // quantized decoder uses whatever the model output. Thus the float
1581        // output is the same as the quantized output but scaled differently.
1582        // However, it is expected that the mask after argmax will be the same.
1583        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1584        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1585        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1586
1587        assert_eq!(mask0, mask1);
1588    }
1589
1590    #[test]
1591    fn test_dequant_chunked() {
1592        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1593        out.push(123); // make sure to test non multiple of 16 length
1594
1595        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1596        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1597        let quant = Quantization::new(0.0040811873, -123);
1598        dequantize_cpu(&out, quant, &mut out_dequant);
1599
1600        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1601        assert_eq!(out_dequant, out_dequant_simd);
1602
1603        let quant = Quantization::new(0.0040811873, 0);
1604        dequantize_cpu(&out, quant, &mut out_dequant);
1605
1606        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1607        assert_eq!(out_dequant, out_dequant_simd);
1608    }
1609
1610    #[test]
1611    fn test_dequant_ground_truth() {
1612        // Formula: output = (input - zero_point) * scale
1613        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1614
1615        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1616        let quant = Quantization::new(0.1, -128);
1617        let input: Vec<i8> = vec![0, 127, -128, 64];
1618        let mut output = vec![0.0f32; 4];
1619        let mut output_chunked = vec![0.0f32; 4];
1620        dequantize_cpu(&input, quant, &mut output);
1621        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1622        // (0 - (-128)) * 0.1 = 12.8
1623        // (127 - (-128)) * 0.1 = 25.5
1624        // (-128 - (-128)) * 0.1 = 0.0
1625        // (64 - (-128)) * 0.1 = 19.2
1626        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1627        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1628            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1629        }
1630        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1631            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1632        }
1633
1634        // Case 2: scale=1.0, zero_point=0 (identity-like)
1635        let quant = Quantization::new(1.0, 0);
1636        dequantize_cpu(&input, quant, &mut output);
1637        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1638        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1639        assert_eq!(output, expected);
1640        assert_eq!(output_chunked, expected);
1641
1642        // Case 3: scale=0.5, zero_point=0
1643        let quant = Quantization::new(0.5, 0);
1644        dequantize_cpu(&input, quant, &mut output);
1645        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1646        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1647        assert_eq!(output, expected);
1648        assert_eq!(output_chunked, expected);
1649
1650        // Case 4: i8 min/max boundaries with typical quantization params
1651        let quant = Quantization::new(0.021287762, 31);
1652        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1653        let mut output = vec![0.0f32; 6];
1654        let mut output_chunked = vec![0.0f32; 6];
1655        dequantize_cpu(&input, quant, &mut output);
1656        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1657        for i in 0..6 {
1658            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1659            assert!(
1660                (output[i] - expected).abs() < 1e-5,
1661                "cpu[{i}]: {} != {expected}",
1662                output[i]
1663            );
1664            assert!(
1665                (output_chunked[i] - expected).abs() < 1e-5,
1666                "chunked[{i}]: {} != {expected}",
1667                output_chunked[i]
1668            );
1669        }
1670    }
1671
1672    #[test]
1673    fn test_decoder_yolo_det() {
1674        let score_threshold = 0.25;
1675        let iou_threshold = 0.7;
1676        let out = load_yolov8s_det();
1677        let quant = (0.0040811873, -123).into();
1678
1679        let decoder = DecoderBuilder::default()
1680            .with_config_yolo_det(
1681                configs::Detection {
1682                    decoder: DecoderType::Ultralytics,
1683                    shape: vec![1, 84, 8400],
1684                    anchors: None,
1685                    quantization: Some(quant),
1686                    dshape: vec![
1687                        (DimName::Batch, 1),
1688                        (DimName::NumFeatures, 84),
1689                        (DimName::NumBoxes, 8400),
1690                    ],
1691                    normalized: Some(true),
1692                },
1693                Some(DecoderVersion::Yolo11),
1694            )
1695            .with_score_threshold(score_threshold)
1696            .with_iou_threshold(iou_threshold)
1697            .build()
1698            .unwrap();
1699
1700        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1701        decode_yolo_det(
1702            (out.slice(s![0, .., ..]), quant.into()),
1703            score_threshold,
1704            iou_threshold,
1705            Some(configs::Nms::ClassAgnostic),
1706            &mut output_boxes,
1707        );
1708        assert!(output_boxes[0].equal_within_delta(
1709            &DetectBox {
1710                bbox: BoundingBox {
1711                    xmin: 0.5285137,
1712                    ymin: 0.05305544,
1713                    xmax: 0.87541467,
1714                    ymax: 0.9998909,
1715                },
1716                score: 0.5591227,
1717                label: 0
1718            },
1719            1e-6
1720        ));
1721
1722        assert!(output_boxes[1].equal_within_delta(
1723            &DetectBox {
1724                bbox: BoundingBox {
1725                    xmin: 0.130598,
1726                    ymin: 0.43260583,
1727                    xmax: 0.35098213,
1728                    ymax: 0.9958097,
1729                },
1730                score: 0.33057618,
1731                label: 75
1732            },
1733            1e-6
1734        ));
1735
1736        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1737        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1738        decoder
1739            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1740            .unwrap();
1741
1742        let out = dequantize_ndarray(out.view(), quant.into());
1743        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1744        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1745        decoder
1746            .decode_float::<f32>(
1747                &[out.view().into_dyn()],
1748                &mut output_boxes_f32,
1749                &mut output_masks_f32,
1750            )
1751            .unwrap();
1752
1753        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1754        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1755    }
1756
1757    #[test]
1758    fn test_decoder_masks() {
1759        let score_threshold = 0.45;
1760        let iou_threshold = 0.45;
1761        let boxes = load_yolov8_boxes();
1762        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1763
1764        let protos = load_yolov8_protos();
1765        let quant_protos = Quantization::new(0.02491161972284317, -117);
1766        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1767        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1768        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1769        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1770        decode_yolo_segdet_float(
1771            seg.slice(s![0, .., ..]),
1772            protos.slice(s![0, .., .., ..]),
1773            score_threshold,
1774            iou_threshold,
1775            Some(configs::Nms::ClassAgnostic),
1776            &mut output_boxes,
1777            &mut output_masks,
1778        )
1779        .unwrap();
1780        assert_eq!(output_boxes.len(), 2);
1781        assert_eq!(output_boxes.len(), output_masks.len());
1782
1783        for (b, m) in output_boxes.iter().zip(&output_masks) {
1784            assert!(b.bbox.xmin >= m.xmin);
1785            assert!(b.bbox.ymin >= m.ymin);
1786            assert!(b.bbox.xmax >= m.xmax);
1787            assert!(b.bbox.ymax >= m.ymax);
1788        }
1789        assert!(output_boxes[0].equal_within_delta(
1790            &DetectBox {
1791                bbox: BoundingBox {
1792                    xmin: 0.08515105,
1793                    ymin: 0.7131401,
1794                    xmax: 0.29802868,
1795                    ymax: 0.8195788,
1796                },
1797                score: 0.91537374,
1798                label: 23
1799            },
1800            1.0 / 160.0, // wider range because mask will expand the box
1801        ));
1802
1803        assert!(output_boxes[1].equal_within_delta(
1804            &DetectBox {
1805                bbox: BoundingBox {
1806                    xmin: 0.59605736,
1807                    ymin: 0.25545314,
1808                    xmax: 0.93666154,
1809                    ymax: 0.72378385,
1810                },
1811                score: 0.91537374,
1812                label: 23
1813            },
1814            1.0 / 160.0, // wider range because mask will expand the box
1815        ));
1816
1817        let full_mask = include_bytes!(concat!(
1818            env!("CARGO_MANIFEST_DIR"),
1819            "/../../testdata/yolov8_mask_results.bin"
1820        ));
1821        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1822
1823        let cropped_mask = full_mask.slice(ndarray::s![
1824            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1825            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1826        ]);
1827
1828        assert_eq!(
1829            cropped_mask,
1830            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1831        );
1832    }
1833
1834    /// Regression test: config-driven path with physically-NCHW protos.
1835    /// Simulates YOLOv8-seg ONNX outputs where the producer emits protos
1836    /// as `(1, 32, 160, 160)` in CHW memory order — the caller declares
1837    /// shape and dshape matching that physical order and HAL permutes to
1838    /// canonical HWC via `swap_axes_if_needed`.
1839    ///
1840    /// This is the counterpart to the NHWC-producer case (TFLite /
1841    /// Ara-2) where shape+dshape are `(1, 160, 160, 32)` +
1842    /// `[batch, height, width, num_protos]` and no reordering is needed.
1843    #[test]
1844    fn test_decoder_masks_nchw_protos() {
1845        let score_threshold = 0.45;
1846        let iou_threshold = 0.45;
1847
1848        // Load test data — boxes as [116, 8400]
1849        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1850        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1851
1852        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1853        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1854        let quant_protos = Quantization::new(0.02491161972284317, -117);
1855        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1856
1857        // ---- Reference: direct call with HWC protos (known working) ----
1858        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1859        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1860        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1861        decode_yolo_segdet_float(
1862            seg.view(),
1863            protos_f32_hwc.view(),
1864            score_threshold,
1865            iou_threshold,
1866            Some(configs::Nms::ClassAgnostic),
1867            &mut ref_boxes,
1868            &mut ref_masks,
1869        )
1870        .unwrap();
1871        assert_eq!(ref_boxes.len(), 2);
1872
1873        // ---- Config-driven path: NCHW protos declared in physical order ----
1874        // Permute the HWC test data to CHW memory order — this is what an
1875        // ONNX-style producer would emit into the tensor buffer.
1876        // `to_owned` materialises a C-contiguous Array3<f32> with CHW
1877        // strides, matching a producer that writes channels-outer.
1878        let protos_f32_chw_view = protos_f32_hwc.view().permuted_axes([2, 0, 1]); // [32, 160, 160]
1879        let protos_f32_chw = protos_f32_chw_view.to_owned();
1880        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1881
1882        // Build boxes as [1, 116, 8400] f32
1883        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1884
1885        // Declare shape and dshape in physical memory order (outermost
1886        // first). `swap_axes_if_needed` uses dshape role lookup to
1887        // permute the stride tuple into canonical HWC.
1888        let decoder = DecoderBuilder::default()
1889            .with_config_yolo_segdet(
1890                configs::Detection {
1891                    decoder: configs::DecoderType::Ultralytics,
1892                    quantization: None,
1893                    shape: vec![1, 116, 8400],
1894                    dshape: vec![
1895                        (configs::DimName::Batch, 1),
1896                        (configs::DimName::NumFeatures, 116),
1897                        (configs::DimName::NumBoxes, 8400),
1898                    ],
1899                    normalized: Some(true),
1900                    anchors: None,
1901                },
1902                configs::Protos {
1903                    decoder: configs::DecoderType::Ultralytics,
1904                    quantization: None,
1905                    shape: vec![1, 32, 160, 160],
1906                    dshape: vec![
1907                        (configs::DimName::Batch, 1),
1908                        (configs::DimName::NumProtos, 32),
1909                        (configs::DimName::Height, 160),
1910                        (configs::DimName::Width, 160),
1911                    ],
1912                },
1913                None, // decoder version
1914            )
1915            .with_score_threshold(score_threshold)
1916            .with_iou_threshold(iou_threshold)
1917            .build()
1918            .unwrap();
1919
1920        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1921        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1922        decoder
1923            .decode_float(
1924                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1925                &mut cfg_boxes,
1926                &mut cfg_masks,
1927            )
1928            .unwrap();
1929
1930        // Must produce the same number of detections
1931        assert_eq!(
1932            cfg_boxes.len(),
1933            ref_boxes.len(),
1934            "config path produced {} boxes, reference produced {}",
1935            cfg_boxes.len(),
1936            ref_boxes.len()
1937        );
1938
1939        // Boxes must match
1940        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1941            assert!(
1942                cb.equal_within_delta(rb, 0.01),
1943                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1944            );
1945        }
1946
1947        // Masks must match pixel-for-pixel
1948        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1949            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1950            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1951            assert_eq!(
1952                cm_arr, rm_arr,
1953                "mask {i} pixel mismatch between config-driven and reference paths"
1954            );
1955        }
1956    }
1957
1958    #[test]
1959    fn test_decoder_masks_i8() {
1960        let score_threshold = 0.45;
1961        let iou_threshold = 0.45;
1962        let boxes = load_yolov8_boxes();
1963        let quant_boxes = (0.021287761628627777, 31).into();
1964
1965        let protos = load_yolov8_protos();
1966        let quant_protos = (0.02491161972284317, -117).into();
1967        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1968        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1969
1970        let decoder = DecoderBuilder::default()
1971            .with_config_yolo_segdet(
1972                configs::Detection {
1973                    decoder: configs::DecoderType::Ultralytics,
1974                    quantization: Some(quant_boxes),
1975                    shape: vec![1, 116, 8400],
1976                    anchors: None,
1977                    dshape: vec![
1978                        (DimName::Batch, 1),
1979                        (DimName::NumFeatures, 116),
1980                        (DimName::NumBoxes, 8400),
1981                    ],
1982                    normalized: Some(true),
1983                },
1984                Protos {
1985                    decoder: configs::DecoderType::Ultralytics,
1986                    quantization: Some(quant_protos),
1987                    shape: vec![1, 160, 160, 32],
1988                    dshape: vec![
1989                        (DimName::Batch, 1),
1990                        (DimName::Height, 160),
1991                        (DimName::Width, 160),
1992                        (DimName::NumProtos, 32),
1993                    ],
1994                },
1995                Some(DecoderVersion::Yolo11),
1996            )
1997            .with_score_threshold(score_threshold)
1998            .with_iou_threshold(iou_threshold)
1999            .build()
2000            .unwrap();
2001
2002        let quant_boxes = quant_boxes.into();
2003        let quant_protos = quant_protos.into();
2004
2005        decode_yolo_segdet_quant(
2006            (boxes.slice(s![0, .., ..]), quant_boxes),
2007            (protos.slice(s![0, .., .., ..]), quant_protos),
2008            score_threshold,
2009            iou_threshold,
2010            Some(configs::Nms::ClassAgnostic),
2011            &mut output_boxes,
2012            &mut output_masks,
2013        )
2014        .unwrap();
2015
2016        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2017        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2018
2019        decoder
2020            .decode_quantized(
2021                &[boxes.view().into(), protos.view().into()],
2022                &mut output_boxes1,
2023                &mut output_masks1,
2024            )
2025            .unwrap();
2026
2027        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2028        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2029
2030        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2031        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2032        decode_yolo_segdet_float(
2033            seg.slice(s![0, .., ..]),
2034            protos.slice(s![0, .., .., ..]),
2035            score_threshold,
2036            iou_threshold,
2037            Some(configs::Nms::ClassAgnostic),
2038            &mut output_boxes_f32,
2039            &mut output_masks_f32,
2040        )
2041        .unwrap();
2042
2043        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
2044        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
2045
2046        decoder
2047            .decode_float(
2048                &[seg.view().into_dyn(), protos.view().into_dyn()],
2049                &mut output_boxes1_f32,
2050                &mut output_masks1_f32,
2051            )
2052            .unwrap();
2053
2054        compare_outputs(
2055            (&output_boxes, &output_boxes1),
2056            (&output_masks, &output_masks1),
2057        );
2058
2059        compare_outputs(
2060            (&output_boxes, &output_boxes_f32),
2061            (&output_masks, &output_masks_f32),
2062        );
2063
2064        compare_outputs(
2065            (&output_boxes_f32, &output_boxes1_f32),
2066            (&output_masks_f32, &output_masks1_f32),
2067        );
2068    }
2069
2070    #[test]
2071    fn test_decoder_yolo_split() {
2072        let score_threshold = 0.45;
2073        let iou_threshold = 0.45;
2074        let boxes = load_yolov8_boxes();
2075        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2076        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2077
2078        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2079
2080        let decoder = DecoderBuilder::default()
2081            .with_config_yolo_split_det(
2082                configs::Boxes {
2083                    decoder: configs::DecoderType::Ultralytics,
2084                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2085                    shape: vec![1, 4, 8400],
2086                    dshape: vec![
2087                        (DimName::Batch, 1),
2088                        (DimName::BoxCoords, 4),
2089                        (DimName::NumBoxes, 8400),
2090                    ],
2091                    normalized: Some(true),
2092                },
2093                configs::Scores {
2094                    decoder: configs::DecoderType::Ultralytics,
2095                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2096                    shape: vec![1, 80, 8400],
2097                    dshape: vec![
2098                        (DimName::Batch, 1),
2099                        (DimName::NumClasses, 80),
2100                        (DimName::NumBoxes, 8400),
2101                    ],
2102                },
2103            )
2104            .with_score_threshold(score_threshold)
2105            .with_iou_threshold(iou_threshold)
2106            .build()
2107            .unwrap();
2108
2109        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2110        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2111
2112        decoder
2113            .decode_quantized(
2114                &[
2115                    boxes.slice(s![.., ..4, ..]).into(),
2116                    boxes.slice(s![.., 4..84, ..]).into(),
2117                ],
2118                &mut output_boxes,
2119                &mut output_masks,
2120            )
2121            .unwrap();
2122
2123        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2124        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2125        decode_yolo_det_float(
2126            seg.slice(s![0, ..84, ..]),
2127            score_threshold,
2128            iou_threshold,
2129            Some(configs::Nms::ClassAgnostic),
2130            &mut output_boxes_f32,
2131        );
2132
2133        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2134        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2135
2136        decoder
2137            .decode_float(
2138                &[
2139                    seg.slice(s![.., ..4, ..]).into_dyn(),
2140                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2141                ],
2142                &mut output_boxes1,
2143                &mut output_masks1,
2144            )
2145            .unwrap();
2146        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2147        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2148    }
2149
2150    #[test]
2151    fn test_decoder_masks_config_mixed() {
2152        let score_threshold = 0.45;
2153        let iou_threshold = 0.45;
2154        let boxes_raw = load_yolov8_boxes();
2155        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2156        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2157
2158        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2159
2160        let protos = load_yolov8_protos();
2161        let quant_protos = (0.02491161972284317, -117);
2162
2163        let decoder = build_yolo_split_segdet_decoder(
2164            score_threshold,
2165            iou_threshold,
2166            quant_boxes,
2167            quant_protos,
2168        );
2169        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2170        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2171
2172        decoder
2173            .decode_quantized(
2174                &[
2175                    boxes.slice(s![.., ..4, ..]).into(),
2176                    boxes.slice(s![.., 4..84, ..]).into(),
2177                    boxes.slice(s![.., 84.., ..]).into(),
2178                    protos.view().into(),
2179                ],
2180                &mut output_boxes,
2181                &mut output_masks,
2182            )
2183            .unwrap();
2184
2185        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2186        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2187        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2188        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2189        decode_yolo_segdet_float(
2190            seg.slice(s![0, .., ..]),
2191            protos.slice(s![0, .., .., ..]),
2192            score_threshold,
2193            iou_threshold,
2194            Some(configs::Nms::ClassAgnostic),
2195            &mut output_boxes_f32,
2196            &mut output_masks_f32,
2197        )
2198        .unwrap();
2199
2200        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2201        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2202
2203        decoder
2204            .decode_float(
2205                &[
2206                    seg.slice(s![.., ..4, ..]).into_dyn(),
2207                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2208                    seg.slice(s![.., 84.., ..]).into_dyn(),
2209                    protos.view().into_dyn(),
2210                ],
2211                &mut output_boxes1,
2212                &mut output_masks1,
2213            )
2214            .unwrap();
2215        compare_outputs(
2216            (&output_boxes, &output_boxes_f32),
2217            (&output_masks, &output_masks_f32),
2218        );
2219        compare_outputs(
2220            (&output_boxes_f32, &output_boxes1),
2221            (&output_masks_f32, &output_masks1),
2222        );
2223    }
2224
2225    fn build_yolo_split_segdet_decoder(
2226        score_threshold: f32,
2227        iou_threshold: f32,
2228        quant_boxes: (f32, i32),
2229        quant_protos: (f32, i32),
2230    ) -> crate::Decoder {
2231        DecoderBuilder::default()
2232            .with_config_yolo_split_segdet(
2233                configs::Boxes {
2234                    decoder: configs::DecoderType::Ultralytics,
2235                    quantization: Some(quant_boxes.into()),
2236                    shape: vec![1, 4, 8400],
2237                    dshape: vec![
2238                        (DimName::Batch, 1),
2239                        (DimName::BoxCoords, 4),
2240                        (DimName::NumBoxes, 8400),
2241                    ],
2242                    normalized: Some(true),
2243                },
2244                configs::Scores {
2245                    decoder: configs::DecoderType::Ultralytics,
2246                    quantization: Some(quant_boxes.into()),
2247                    shape: vec![1, 80, 8400],
2248                    dshape: vec![
2249                        (DimName::Batch, 1),
2250                        (DimName::NumClasses, 80),
2251                        (DimName::NumBoxes, 8400),
2252                    ],
2253                },
2254                configs::MaskCoefficients {
2255                    decoder: configs::DecoderType::Ultralytics,
2256                    quantization: Some(quant_boxes.into()),
2257                    shape: vec![1, 32, 8400],
2258                    dshape: vec![
2259                        (DimName::Batch, 1),
2260                        (DimName::NumProtos, 32),
2261                        (DimName::NumBoxes, 8400),
2262                    ],
2263                },
2264                configs::Protos {
2265                    decoder: configs::DecoderType::Ultralytics,
2266                    quantization: Some(quant_protos.into()),
2267                    shape: vec![1, 160, 160, 32],
2268                    dshape: vec![
2269                        (DimName::Batch, 1),
2270                        (DimName::Height, 160),
2271                        (DimName::Width, 160),
2272                        (DimName::NumProtos, 32),
2273                    ],
2274                },
2275            )
2276            .with_score_threshold(score_threshold)
2277            .with_iou_threshold(iou_threshold)
2278            .build()
2279            .unwrap()
2280    }
2281
2282    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2283        let config_yaml = include_str!(concat!(
2284            env!("CARGO_MANIFEST_DIR"),
2285            "/../../testdata/yolov8_seg.yaml"
2286        ));
2287        DecoderBuilder::default()
2288            .with_config_yaml_str(config_yaml.to_string())
2289            .with_score_threshold(score_threshold)
2290            .with_iou_threshold(iou_threshold)
2291            .build()
2292            .unwrap()
2293    }
2294    #[test]
2295    fn test_decoder_masks_config_i32() {
2296        let score_threshold = 0.45;
2297        let iou_threshold = 0.45;
2298        let boxes_raw = load_yolov8_boxes();
2299        let scale = 1 << 23;
2300        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2301        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2302
2303        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2304
2305        let protos_raw = load_yolov8_protos();
2306        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2307        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2308        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2309
2310        let decoder = build_yolo_split_segdet_decoder(
2311            score_threshold,
2312            iou_threshold,
2313            quant_boxes,
2314            quant_protos,
2315        );
2316
2317        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2318        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2319
2320        decoder
2321            .decode_quantized(
2322                &[
2323                    boxes.slice(s![.., ..4, ..]).into(),
2324                    boxes.slice(s![.., 4..84, ..]).into(),
2325                    boxes.slice(s![.., 84.., ..]).into(),
2326                    protos.view().into(),
2327                ],
2328                &mut output_boxes,
2329                &mut output_masks,
2330            )
2331            .unwrap();
2332
2333        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2334        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2335        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2336        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2337        decode_yolo_segdet_float(
2338            seg.slice(s![0, .., ..]),
2339            protos.slice(s![0, .., .., ..]),
2340            score_threshold,
2341            iou_threshold,
2342            Some(configs::Nms::ClassAgnostic),
2343            &mut output_boxes_f32,
2344            &mut output_masks_f32,
2345        )
2346        .unwrap();
2347
2348        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2349        assert_eq!(output_masks.len(), output_masks_f32.len());
2350
2351        compare_outputs(
2352            (&output_boxes, &output_boxes_f32),
2353            (&output_masks, &output_masks_f32),
2354        );
2355    }
2356
/// Runs several YOLO and ModelPack decoders on separate threads at the same
/// time and checks that each one keeps producing its expected output.
#[test]
fn test_context_switch() {
    // Worker: quantized YOLO detection, decoded 100 times from the same input.
    let yolo_det = || {
        let raw = load_yolov8s_det();
        let quant = (0.0040811873, -123).into();

        let det_config = configs::Detection {
            decoder: DecoderType::Ultralytics,
            shape: vec![1, 84, 8400],
            anchors: None,
            quantization: Some(quant),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumFeatures, 84),
                (DimName::NumBoxes, 8400),
            ],
            normalized: None,
        };
        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(det_config, None)
            .with_score_threshold(0.25)
            .with_iou_threshold(0.7)
            .build()
            .unwrap();

        // The first two detections every decode pass must reproduce.
        let expected = [
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75,
            },
        ];

        let mut boxes: Vec<_> = Vec::with_capacity(50);
        let mut masks: Vec<_> = Vec::with_capacity(50);
        for _ in 0..100 {
            decoder
                .decode_quantized(&[raw.view().into()], &mut boxes, &mut masks)
                .unwrap();

            assert!(boxes[0].equal_within_delta(&expected[0], 1e-6));
            assert!(boxes[1].equal_within_delta(&expected[1], 1e-6));
            assert!(masks.is_empty());
        }
    };

    // Worker: quantized ModelPack split detection plus segmentation, decoded
    // 100 times from the same inputs.
    let modelpack_det_split = || {
        let seg_bytes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg_bytes.to_vec()).unwrap();

        let detect0_bytes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 =
            ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0_bytes.to_vec()).unwrap();

        let detect1_bytes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 =
            ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1_bytes.to_vec()).unwrap();

        // Expected mask: the raw segmentation input with its leading axis
        // (after dropping the batch axis) moved to the end.
        let mut seg_view = seg.slice(s![0, .., .., ..]);
        seg_view.swap_axes(0, 1);
        seg_view.swap_axes(1, 2);
        let expected_masks = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: seg_view.into_owned(),
        }];
        let expected_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // Detection configs are listed 17x30 first, then 9x15.
        let detection_configs = vec![
            configs::Detection {
                decoder: DecoderType::ModelPack,
                shape: vec![1, 17, 30, 18],
                anchors: Some(anchors1),
                quantization: Some(quant1),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 17),
                    (DimName::Width, 30),
                    (DimName::NumAnchorsXFeatures, 18),
                ],
                normalized: None,
            },
            configs::Detection {
                decoder: DecoderType::ModelPack,
                shape: vec![1, 9, 15, 18],
                anchors: Some(anchors0),
                quantization: Some(quant0),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 9),
                    (DimName::Width, 15),
                    (DimName::NumAnchorsXFeatures, 18),
                ],
                normalized: None,
            },
        ];
        let segmentation_config = configs::Segmentation {
            decoder: DecoderType::ModelPack,
            quantization: Some(quant_seg),
            shape: vec![1, 2, 160, 160],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, 2),
                (DimName::Height, 160),
                (DimName::Width, 160),
            ],
        };

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(detection_configs, segmentation_config)
            .with_score_threshold(0.8)
            .with_iou_threshold(0.5)
            .build()
            .unwrap();

        let mut boxes: Vec<_> = Vec::with_capacity(10);
        let mut masks: Vec<_> = Vec::with_capacity(10);
        for _ in 0..100 {
            decoder
                .decode_quantized(
                    &[
                        detect0.view().into(),
                        detect1.view().into(),
                        seg.view().into(),
                    ],
                    &mut boxes,
                    &mut masks,
                )
                .unwrap();

            compare_outputs((&expected_boxes, &boxes), (&expected_masks, &masks));
        }
    };

    // Spawn four alternating pairs of workers; collecting first guarantees all
    // eight threads are started before any of them is joined.
    let handles: Vec<_> = (0..4)
        .flat_map(|_| {
            [
                std::thread::spawn(yolo_det),
                std::thread::spawn(modelpack_det_split),
            ]
        })
        .collect();
    for handle in handles {
        handle.join().unwrap();
    }
}
2566
/// Checks the float box conversion to `[xmin, ymin, xmax, ymax]` for both the
/// `XYWH` and the `XYXY` input formats.
#[test]
fn test_ndarray_to_xyxy_float() {
    let input = array![10.0_f32, 20.0, 20.0, 20.0];

    // As XYWH, (10, 20) with a 20x20 extent expands by half the extent on
    // each side.
    let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);

    // As XYXY, the values come back unchanged.
    let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(input.view());
    assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
}
2577
/// Class-aware float NMS keeps overlapping boxes of different classes and
/// suppresses the overlap when both boxes share a class.
#[test]
fn test_class_aware_nms_float() {
    use crate::float::nms_class_aware_float;

    // Two overlapping boxes (IoU ~0.47, above the 0.3 threshold used below);
    // the first is always class 0, the second takes the given label.
    let overlapping_pair = |second_label| {
        vec![
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.0,
                    ymin: 0.0,
                    xmax: 0.5,
                    ymax: 0.5,
                },
                score: 0.9,
                label: 0,
            },
            DetectBox {
                bbox: BoundingBox {
                    xmin: 0.1,
                    ymin: 0.1,
                    xmax: 0.6,
                    ymax: 0.6,
                },
                score: 0.8,
                label: second_label,
            },
        ]
    };

    // Different classes: both boxes survive despite the overlap.
    let kept = nms_class_aware_float(0.3, overlapping_pair(1));
    assert_eq!(
        kept.len(),
        2,
        "Class-aware NMS should keep both boxes with different classes"
    );

    // Same class: only the 0.9-score box remains.
    let kept = nms_class_aware_float(0.3, overlapping_pair(0));
    assert_eq!(
        kept.len(),
        1,
        "Class-aware NMS should suppress overlapping box with same class"
    );
    assert_eq!(kept[0].label, 0);
    assert!((kept[0].score - 0.9).abs() < 1e-6);
}
2648
/// Contrasts class-agnostic and class-aware float NMS on the same pair of
/// overlapping boxes that carry different labels.
#[test]
fn test_class_agnostic_vs_aware_nms() {
    use crate::float::{nms_class_aware_float, nms_float};

    // Overlapping pair (IoU ~0.47, above the 0.3 threshold) with labels 0 and 1.
    let class0_box = DetectBox {
        bbox: BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 0.5,
            ymax: 0.5,
        },
        score: 0.9,
        label: 0,
    };
    let class1_box = DetectBox {
        bbox: BoundingBox {
            xmin: 0.1,
            ymin: 0.1,
            xmax: 0.6,
            ymax: 0.6,
        },
        score: 0.8,
        label: 1,
    };
    let candidates = vec![class0_box, class1_box];

    // Labels ignored: one of the two boxes is suppressed.
    let agnostic_result = nms_float(0.3, candidates.clone());
    assert_eq!(
        agnostic_result.len(),
        1,
        "Class-agnostic NMS should suppress overlapping boxes"
    );

    // Labels respected: both boxes are kept.
    let aware_result = nms_class_aware_float(0.3, candidates);
    assert_eq!(
        aware_result.len(),
        2,
        "Class-aware NMS should keep boxes with different classes"
    );
}
2693
/// Class-aware NMS on quantized (`u8` score) boxes keeps overlapping boxes
/// whose labels differ.
#[test]
fn test_class_aware_nms_int() {
    use crate::byte::nms_class_aware_int;

    let class0_box = DetectBoxQuantized {
        bbox: BoundingBox {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 0.5,
            ymax: 0.5,
        },
        score: 200_u8,
        label: 0,
    };
    // Overlaps the first box but belongs to a different class.
    let class1_box = DetectBoxQuantized {
        bbox: BoundingBox {
            xmin: 0.1,
            ymin: 0.1,
            xmax: 0.6,
            ymax: 0.6,
        },
        score: 180_u8,
        label: 1,
    };

    let survivors = nms_class_aware_int(0.5, vec![class0_box, class1_box]);
    assert_eq!(
        survivors.len(),
        2,
        "Class-aware NMS (int) should keep boxes with different classes"
    );
}
2730
/// The default `configs::Nms` variant is `ClassAgnostic`.
#[test]
fn test_nms_enum_default() {
    assert_eq!(
        <configs::Nms as Default>::default(),
        configs::Nms::ClassAgnostic
    );
}
2737
/// A decoder built with `with_nms(Some(ClassAware))` exposes that mode in its
/// `nms` field.
#[test]
fn test_decoder_nms_mode() {
    // Unquantized YOLO detection output; only the NMS setting matters here.
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let builder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(Some(configs::Nms::ClassAware));
    let decoder = builder.build().unwrap();

    assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
}
2759
/// A decoder built with `with_nms(None)` (NMS bypass) stores `None` in its
/// `nms` field.
#[test]
fn test_decoder_nms_bypass() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let builder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .with_nms(None);
    let decoder = builder.build().unwrap();

    assert_eq!(decoder.nms, None);
}
2781
/// `normalized_boxes()` reports `Some(true)` when the detection config sets
/// `normalized: Some(true)`.
#[test]
fn test_decoder_normalized_boxes_true() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(true),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), Some(true));
}
2802
/// `normalized_boxes()` reports `Some(false)` when the detection config
/// declares unnormalized boxes.
#[test]
fn test_decoder_normalized_boxes_false() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: Some(false),
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, None)
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), Some(false));
}
2824
/// `normalized_boxes()` reports `None` when the detection config leaves
/// `normalized` unspecified.
#[test]
fn test_decoder_normalized_boxes_unknown() {
    let detection = configs::Detection {
        anchors: None,
        decoder: DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 84, 8400],
        dshape: Vec::new(),
        normalized: None,
    };

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(detection, Some(DecoderVersion::Yolo11))
        .build()
        .unwrap();

    assert_eq!(decoder.normalized_boxes(), None);
}
2845
2846    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2847        input: ArrayView<F, D>,
2848        quant: Quantization,
2849    ) -> Array<T, D>
2850    where
2851        i32: num_traits::AsPrimitive<F>,
2852        f32: num_traits::AsPrimitive<F>,
2853    {
2854        let zero_point = quant.zero_point.as_();
2855        let div_scale = F::one() / quant.scale.as_();
2856        if zero_point != F::zero() {
2857            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2858        } else {
2859            input.mapv(|d| (d * div_scale).round().as_())
2860        }
2861    }
2862
2863    fn real_data_expected_boxes() -> [DetectBox; 2] {
2864        [
2865            DetectBox {
2866                bbox: BoundingBox {
2867                    xmin: 0.08515105,
2868                    ymin: 0.7131401,
2869                    xmax: 0.29802868,
2870                    ymax: 0.8195788,
2871                },
2872                score: 0.91537374,
2873                label: 23,
2874            },
2875            DetectBox {
2876                bbox: BoundingBox {
2877                    xmin: 0.59605736,
2878                    ymin: 0.25545314,
2879                    xmax: 0.93666154,
2880                    ymax: 0.72378385,
2881                },
2882                score: 0.91537374,
2883                label: 23,
2884            },
2885        ]
2886    }
2887
2888    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2889        [DetectBox {
2890            bbox: BoundingBox {
2891                xmin: 0.12549022,
2892                ymin: 0.12549022,
2893                xmax: 0.23529413,
2894                ymax: 0.23529413,
2895            },
2896            score: 0.98823535,
2897            label: 2,
2898        }]
2899    }
2900
2901    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2902        [DetectBox {
2903            bbox: BoundingBox {
2904                xmin: 0.1234,
2905                ymin: 0.1234,
2906                xmax: 0.2345,
2907                ymax: 0.2345,
2908            },
2909            score: 0.9876,
2910            label: 2,
2911        }]
2912    }
2913
/// Generates a `#[test]` that runs the recorded YOLOv8 box and proto tensors
/// from `testdata/` through the decoder's proto path and compares the two
/// resulting boxes with `real_data_expected_boxes()`.
///
/// Arguments, in order:
/// * the name of the generated test function;
/// * `quantized` (feeds the int8 tensors to `decode_quantized_proto`) or
///   `float` (dequantizes first and calls `decode_float_proto`);
/// * the layout: `split` feeds boxes, scores, mask coefficients and protos as
///   separate tensors; any other identifier feeds the combined box tensor
///   plus protos.
macro_rules! real_data_proto_test {
    ($name:ident, quantized, $layout:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            // The fixture bytes are consumed as int8 values. `u8 as i8` keeps
            // the bit pattern, so no `unsafe` pointer cast is needed.
            let raw_boxes: Vec<i8> = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ))
            .iter()
            .map(|&byte| byte as i8)
            .collect();
            let boxes_i8 = ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes).unwrap();

            let raw_protos: Vec<i8> = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ))
            .iter()
            .map(|&byte| byte as i8)
            .collect();
            let protos_i8 =
                ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos).unwrap();

            // Pre-split (unused for combined, but harmless): rows 0..4 are
            // boxes, 4..84 scores, 84.. mask coefficients.
            let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
            let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
            let boxes_combined = boxes_i8;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            let expected = real_data_expected_boxes();
            let mut output_boxes = Vec::with_capacity(50);

            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                vec![
                    boxes_split.view().into(),
                    scores_split.view().into(),
                    mask_split.view().into(),
                    protos_i8.view().into(),
                ]
            } else {
                vec![boxes_combined.view().into(), protos_i8.view().into()]
            };
            decoder
                .decode_quantized_proto(&inputs, &mut output_boxes)
                .unwrap();

            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
        }
    };
    ($name:ident, float, $layout:ident) => {
        #[test]
        fn $name() {
            let is_split = matches!(stringify!($layout), "split");

            let score_threshold = 0.45;
            let iou_threshold = 0.45;
            let quant_boxes = (0.021287762_f32, 31_i32);
            let quant_protos = (0.02491162_f32, -117_i32);

            // The fixture bytes are consumed as int8 values. `u8 as i8` keeps
            // the bit pattern, so no `unsafe` pointer cast is needed.
            let raw_boxes: Vec<i8> = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_boxes_116x8400.bin"
            ))
            .iter()
            .map(|&byte| byte as i8)
            .collect();
            let boxes_i8 = ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes).unwrap();
            let boxes_f32: Array3<f32> = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

            let raw_protos: Vec<i8> = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/yolov8_protos_160x160x32.bin"
            ))
            .iter()
            .map(|&byte| byte as i8)
            .collect();
            let protos_i8 =
                ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos).unwrap();
            let protos_f32: Array4<f32> =
                dequantize_ndarray(protos_i8.view(), quant_protos.into());

            // Pre-split from dequantized data: rows 0..4 are boxes, 4..84
            // scores, 84.. mask coefficients.
            let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
            let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
            let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
            let boxes_combined = boxes_f32;

            let decoder = if is_split {
                build_yolo_split_segdet_decoder(
                    score_threshold,
                    iou_threshold,
                    quant_boxes,
                    quant_protos,
                )
            } else {
                build_yolov8_seg_decoder(score_threshold, iou_threshold)
            };

            let expected = real_data_expected_boxes();
            let mut output_boxes = Vec::with_capacity(50);

            let inputs = if is_split {
                vec![
                    boxes_split.view().into_dyn(),
                    scores_split.view().into_dyn(),
                    mask_split.view().into_dyn(),
                    protos_f32.view().into_dyn(),
                ]
            } else {
                vec![
                    boxes_combined.view().into_dyn(),
                    protos_f32.view().into_dyn(),
                ]
            };
            decoder
                .decode_float_proto(&inputs, &mut output_boxes)
                .unwrap();

            assert_eq!(output_boxes.len(), 2);
            assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
            assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
        }
    };
}

// One generated test per (numeric path, layout) combination.
real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3068
/// YAML decoder configuration (`decoder_version: yolo26`) declaring a single
/// quantized `detection` output of shape `[1, 10, 6]`
/// (batch, num_boxes, num_features) with normalized boxes.
3069    const E2E_COMBINED_DET_CONFIG: &str = "
3070decoder_version: yolo26
3071outputs:
3072 - type: detection
3073   decoder: ultralytics
3074   quantization: [0.00784313725490196, 0]
3075   shape: [1, 10, 6]
3076   dshape:
3077    - [batch, 1]
3078    - [num_boxes, 10]
3079    - [num_features, 6]
3080   normalized: true
3081";
3082
/// YAML decoder configuration (`decoder_version: yolo26`) declaring a combined
/// quantized `detection` output of shape `[1, 10, 38]`
/// (batch, num_boxes, num_features) with normalized boxes, plus a quantized
/// `protos` output of shape `[1, 160, 160, 32]`
/// (batch, height, width, num_protos).
3083    const E2E_COMBINED_SEGDET_CONFIG: &str = "
3084decoder_version: yolo26
3085outputs:
3086 - type: detection
3087   decoder: ultralytics
3088   quantization: [0.00784313725490196, 0]
3089   shape: [1, 10, 38]
3090   dshape:
3091    - [batch, 1]
3092    - [num_boxes, 10]
3093    - [num_features, 38]
3094   normalized: true
3095 - type: protos
3096   decoder: ultralytics
3097   quantization: [0.0039215686274509803921568627451, 128]
3098   shape: [1, 160, 160, 32]
3099   dshape:
3100    - [batch, 1]
3101    - [height, 160]
3102    - [width, 160]
3103    - [num_protos, 32]
3104";
3105
/// YAML decoder configuration (`decoder_version: yolo26`) declaring three
/// separate quantized outputs: `boxes` of shape `[1, 10, 4]` (normalized),
/// `scores` of shape `[1, 10, 1]`, and `classes` of shape `[1, 10, 1]`.
3106    const E2E_SPLIT_DET_CONFIG: &str = "
3107decoder_version: yolo26
3108outputs:
3109 - type: boxes
3110   decoder: ultralytics
3111   quantization: [0.00784313725490196, 0]
3112   shape: [1, 10, 4]
3113   dshape:
3114    - [batch, 1]
3115    - [num_boxes, 10]
3116    - [box_coords, 4]
3117   normalized: true
3118 - type: scores
3119   decoder: ultralytics
3120   quantization: [0.00784313725490196, 0]
3121   shape: [1, 10, 1]
3122   dshape:
3123    - [batch, 1]
3124    - [num_boxes, 10]
3125    - [num_classes, 1]
3126 - type: classes
3127   decoder: ultralytics
3128   quantization: [0.00784313725490196, 0]
3129   shape: [1, 10, 1]
3130   dshape:
3131    - [batch, 1]
3132    - [num_boxes, 10]
3133    - [num_classes, 1]
3134";
3135
/// YAML decoder configuration (`decoder_version: yolo26`) declaring five
/// separate quantized outputs: `boxes` `[1, 10, 4]` (normalized), `scores`
/// `[1, 10, 1]`, `classes` `[1, 10, 1]`, `mask_coefficients` `[1, 10, 32]`,
/// and `protos` `[1, 160, 160, 32]`.
3136    const E2E_SPLIT_SEGDET_CONFIG: &str = "
3137decoder_version: yolo26
3138outputs:
3139 - type: boxes
3140   decoder: ultralytics
3141   quantization: [0.00784313725490196, 0]
3142   shape: [1, 10, 4]
3143   dshape:
3144    - [batch, 1]
3145    - [num_boxes, 10]
3146    - [box_coords, 4]
3147   normalized: true
3148 - type: scores
3149   decoder: ultralytics
3150   quantization: [0.00784313725490196, 0]
3151   shape: [1, 10, 1]
3152   dshape:
3153    - [batch, 1]
3154    - [num_boxes, 10]
3155    - [num_classes, 1]
3156 - type: classes
3157   decoder: ultralytics
3158   quantization: [0.00784313725490196, 0]
3159   shape: [1, 10, 1]
3160   dshape:
3161    - [batch, 1]
3162    - [num_boxes, 10]
3163    - [num_classes, 1]
3164 - type: mask_coefficients
3165   decoder: ultralytics
3166   quantization: [0.00784313725490196, 0]
3167   shape: [1, 10, 32]
3168   dshape:
3169    - [batch, 1]
3170    - [num_boxes, 10]
3171    - [num_protos, 32]
3172 - type: protos
3173   decoder: ultralytics
3174   quantization: [0.0039215686274509803921568627451, 128]
3175   shape: [1, 160, 160, 32]
3176   dshape:
3177    - [batch, 1]
3178    - [height, 160]
3179    - [width, 160]
3180    - [num_protos, 32]
3181";
3182
3183    macro_rules! e2e_segdet_test {
3184        ($name:ident, quantized, $layout:ident, $output:ident) => {
3185            #[test]
3186            fn $name() {
3187                let is_split = matches!(stringify!($layout), "split");
3188                let is_proto = matches!(stringify!($output), "proto");
3189
3190                let score_threshold = 0.45;
3191                let iou_threshold = 0.45;
3192
3193                let mut boxes = Array2::zeros((10, 4));
3194                let mut scores = Array2::zeros((10, 1));
3195                let mut classes = Array2::zeros((10, 1));
3196                let mask = Array2::zeros((10, 32));
3197                let protos = Array3::<f64>::zeros((160, 160, 32));
3198                let protos = protos.insert_axis(Axis(0));
3199                let protos_quant = (1.0 / 255.0, 0.0);
3200                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3201
3202                boxes
3203                    .slice_mut(s![0, ..])
3204                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3205                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3206                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3207
3208                let detect_quant = (2.0 / 255.0, 0.0);
3209
3210                let decoder = if is_split {
3211                    DecoderBuilder::default()
3212                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3213                        .with_score_threshold(score_threshold)
3214                        .with_iou_threshold(iou_threshold)
3215                        .build()
3216                        .unwrap()
3217                } else {
3218                    DecoderBuilder::default()
3219                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3220                        .with_score_threshold(score_threshold)
3221                        .with_iou_threshold(iou_threshold)
3222                        .build()
3223                        .unwrap()
3224                };
3225
3226                let expected = e2e_expected_boxes_quant();
3227                let mut output_boxes = Vec::with_capacity(50);
3228
3229                if is_split {
3230                    let boxes = boxes.insert_axis(Axis(0));
3231                    let scores = scores.insert_axis(Axis(0));
3232                    let classes = classes.insert_axis(Axis(0));
3233                    let mask = mask.insert_axis(Axis(0));
3234
3235                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3236                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3237                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3238                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3239
3240                    if is_proto {
3241                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3242                            boxes.view().into(),
3243                            scores.view().into(),
3244                            classes.view().into(),
3245                            mask.view().into(),
3246                            protos.view().into(),
3247                        ];
3248                        decoder
3249                            .decode_quantized_proto(&inputs, &mut output_boxes)
3250                            .unwrap();
3251
3252                        assert_eq!(output_boxes.len(), 1);
3253                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3254                    } else {
3255                        let mut output_masks = Vec::with_capacity(50);
3256                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3257                            boxes.view().into(),
3258                            scores.view().into(),
3259                            classes.view().into(),
3260                            mask.view().into(),
3261                            protos.view().into(),
3262                        ];
3263                        decoder
3264                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3265                            .unwrap();
3266
3267                        assert_eq!(output_boxes.len(), 1);
3268                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3269                    }
3270                } else {
3271                    // Combined layout
3272                    let detect = ndarray::concatenate![
3273                        Axis(1),
3274                        boxes.view(),
3275                        scores.view(),
3276                        classes.view(),
3277                        mask.view()
3278                    ];
3279                    let detect = detect.insert_axis(Axis(0));
3280                    assert_eq!(detect.shape(), &[1, 10, 38]);
3281                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3282
3283                    if is_proto {
3284                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3285                            vec![detect.view().into(), protos.view().into()];
3286                        decoder
3287                            .decode_quantized_proto(&inputs, &mut output_boxes)
3288                            .unwrap();
3289
3290                        assert_eq!(output_boxes.len(), 1);
3291                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3292                    } else {
3293                        let mut output_masks = Vec::with_capacity(50);
3294                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3295                            vec![detect.view().into(), protos.view().into()];
3296                        decoder
3297                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3298                            .unwrap();
3299
3300                        assert_eq!(output_boxes.len(), 1);
3301                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3302                    }
3303                }
3304            }
3305        };
3306        ($name:ident, float, $layout:ident, $output:ident) => {
3307            #[test]
3308            fn $name() {
3309                let is_split = matches!(stringify!($layout), "split");
3310                let is_proto = matches!(stringify!($output), "proto");
3311
3312                let score_threshold = 0.45;
3313                let iou_threshold = 0.45;
3314
3315                let mut boxes = Array2::zeros((10, 4));
3316                let mut scores = Array2::zeros((10, 1));
3317                let mut classes = Array2::zeros((10, 1));
3318                let mask: Array2<f64> = Array2::zeros((10, 32));
3319                let protos = Array3::<f64>::zeros((160, 160, 32));
3320                let protos = protos.insert_axis(Axis(0));
3321
3322                boxes
3323                    .slice_mut(s![0, ..])
3324                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3325                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3326                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3327
3328                let decoder = if is_split {
3329                    DecoderBuilder::default()
3330                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3331                        .with_score_threshold(score_threshold)
3332                        .with_iou_threshold(iou_threshold)
3333                        .build()
3334                        .unwrap()
3335                } else {
3336                    DecoderBuilder::default()
3337                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3338                        .with_score_threshold(score_threshold)
3339                        .with_iou_threshold(iou_threshold)
3340                        .build()
3341                        .unwrap()
3342                };
3343
3344                let expected = e2e_expected_boxes_float();
3345                let mut output_boxes = Vec::with_capacity(50);
3346
3347                if is_split {
3348                    let boxes = boxes.insert_axis(Axis(0));
3349                    let scores = scores.insert_axis(Axis(0));
3350                    let classes = classes.insert_axis(Axis(0));
3351                    let mask = mask.insert_axis(Axis(0));
3352
3353                    if is_proto {
3354                        let inputs = vec![
3355                            boxes.view().into_dyn(),
3356                            scores.view().into_dyn(),
3357                            classes.view().into_dyn(),
3358                            mask.view().into_dyn(),
3359                            protos.view().into_dyn(),
3360                        ];
3361                        decoder
3362                            .decode_float_proto(&inputs, &mut output_boxes)
3363                            .unwrap();
3364
3365                        assert_eq!(output_boxes.len(), 1);
3366                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3367                    } else {
3368                        let mut output_masks = Vec::with_capacity(50);
3369                        let inputs = vec![
3370                            boxes.view().into_dyn(),
3371                            scores.view().into_dyn(),
3372                            classes.view().into_dyn(),
3373                            mask.view().into_dyn(),
3374                            protos.view().into_dyn(),
3375                        ];
3376                        decoder
3377                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3378                            .unwrap();
3379
3380                        assert_eq!(output_boxes.len(), 1);
3381                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3382                    }
3383                } else {
3384                    // Combined layout
3385                    let detect = ndarray::concatenate![
3386                        Axis(1),
3387                        boxes.view(),
3388                        scores.view(),
3389                        classes.view(),
3390                        mask.view()
3391                    ];
3392                    let detect = detect.insert_axis(Axis(0));
3393                    assert_eq!(detect.shape(), &[1, 10, 38]);
3394
3395                    if is_proto {
3396                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3397                        decoder
3398                            .decode_float_proto(&inputs, &mut output_boxes)
3399                            .unwrap();
3400
3401                        assert_eq!(output_boxes.len(), 1);
3402                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3403                    } else {
3404                        let mut output_masks = Vec::with_capacity(50);
3405                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3406                        decoder
3407                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3408                            .unwrap();
3409
3410                        assert_eq!(output_boxes.len(), 1);
3411                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3412                    }
3413                }
3414            }
3415        };
3416    }
3417
3418    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3419    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3420    e2e_segdet_test!(
3421        test_decoder_end_to_end_segdet_proto,
3422        quantized,
3423        combined,
3424        proto
3425    );
3426    e2e_segdet_test!(
3427        test_decoder_end_to_end_segdet_proto_float,
3428        float,
3429        combined,
3430        proto
3431    );
3432    e2e_segdet_test!(
3433        test_decoder_end_to_end_segdet_split,
3434        quantized,
3435        split,
3436        masks
3437    );
3438    e2e_segdet_test!(
3439        test_decoder_end_to_end_segdet_split_float,
3440        float,
3441        split,
3442        masks
3443    );
3444    e2e_segdet_test!(
3445        test_decoder_end_to_end_segdet_split_proto,
3446        quantized,
3447        split,
3448        proto
3449    );
3450    e2e_segdet_test!(
3451        test_decoder_end_to_end_segdet_split_proto_float,
3452        float,
3453        split,
3454        proto
3455    );
3456
// Generates one end-to-end detection-only decoder test.
//
// Arguments: the test function name, the numeric path (`quantized` or
// `float`) and a layout identifier. The identifier `split` selects
// `E2E_SPLIT_DET_CONFIG` and feeds boxes, scores and classes as separate
// inputs; any other identifier selects `E2E_COMBINED_DET_CONFIG` and feeds
// a single tensor concatenated along the last axis.
macro_rules! e2e_det_test {
    ($name:ident, quantized, $layout:ident) => {
        #[test]
        fn $name() {
            let split_layout = stringify!($layout) == "split";
            let (score_thr, iou_thr) = (0.45, 0.45);

            // Ten candidate rows; only row 0 is populated.
            let mut bbox = Array3::zeros((1, 10, 4));
            let mut score = Array3::zeros((1, 10, 1));
            let mut class_id = Array3::zeros((1, 10, 1));
            bbox.slice_mut(s![0, 0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            score.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
            class_id.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

            // Quantization parameters shared by every input tensor.
            let quant = (2.0 / 255.0, 0_i32);

            // Pick the configuration first so the builder chain is written once.
            let config_yaml = if split_layout {
                E2E_SPLIT_DET_CONFIG.to_string()
            } else {
                E2E_COMBINED_DET_CONFIG.to_string()
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config_yaml)
                .with_score_threshold(score_thr)
                .with_iou_threshold(iou_thr)
                .build()
                .unwrap();

            let expected = e2e_expected_boxes_quant();
            let mut detections = Vec::with_capacity(50);

            if split_layout {
                let q_bbox: Array<u8, _> = quantize_ndarray(bbox.view(), quant.into());
                let q_score: Array<u8, _> = quantize_ndarray(score.view(), quant.into());
                let q_class: Array<u8, _> = quantize_ndarray(class_id.view(), quant.into());
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
                    q_bbox.view().into(),
                    q_score.view().into(),
                    q_class.view().into(),
                ];
                decoder
                    .decode_quantized(&inputs, &mut detections, &mut Vec::new())
                    .unwrap();
            } else {
                let merged =
                    ndarray::concatenate![Axis(2), bbox.view(), score.view(), class_id.view()];
                assert_eq!(merged.shape(), &[1, 10, 6]);
                let q_merged: Array3<u8> = quantize_ndarray(merged.view(), quant.into());
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
                    vec![q_merged.view().into()];
                decoder
                    .decode_quantized(&inputs, &mut detections, &mut Vec::new())
                    .unwrap();
            }

            assert_eq!(detections.len(), 1);
            assert!(detections[0].equal_within_delta(&expected[0], 1e-6));
        }
    };
    ($name:ident, float, $layout:ident) => {
        #[test]
        fn $name() {
            let split_layout = stringify!($layout) == "split";
            let (score_thr, iou_thr) = (0.45, 0.45);

            // Ten candidate rows; only row 0 is populated.
            let mut bbox = Array3::zeros((1, 10, 4));
            let mut score = Array3::zeros((1, 10, 1));
            let mut class_id = Array3::zeros((1, 10, 1));
            bbox.slice_mut(s![0, 0, ..])
                .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
            score.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
            class_id.slice_mut(s![0, 0, ..]).assign(&array![2.0]);

            // Pick the configuration first so the builder chain is written once.
            let config_yaml = if split_layout {
                E2E_SPLIT_DET_CONFIG.to_string()
            } else {
                E2E_COMBINED_DET_CONFIG.to_string()
            };
            let decoder = DecoderBuilder::default()
                .with_config_yaml_str(config_yaml)
                .with_score_threshold(score_thr)
                .with_iou_threshold(iou_thr)
                .build()
                .unwrap();

            let expected = e2e_expected_boxes_float();
            let mut detections = Vec::with_capacity(50);

            if split_layout {
                let inputs = vec![
                    bbox.view().into_dyn(),
                    score.view().into_dyn(),
                    class_id.view().into_dyn(),
                ];
                decoder
                    .decode_float(&inputs, &mut detections, &mut Vec::new())
                    .unwrap();
            } else {
                let merged =
                    ndarray::concatenate![Axis(2), bbox.view(), score.view(), class_id.view()];
                assert_eq!(merged.shape(), &[1, 10, 6]);
                let inputs = vec![merged.view().into_dyn()];
                decoder
                    .decode_float(&inputs, &mut detections, &mut Vec::new())
                    .unwrap();
            }

            assert_eq!(detections.len(), 1);
            assert!(detections[0].equal_within_delta(&expected[0], 1e-6));
        }
    };
}

// One test per (numeric path, layout) pair.
e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3592
/// Feeds the int8 YOLOv8-seg test-data tensors through `decode` and checks
/// that both expected detections are produced.
#[test]
fn test_decode_tensor() {
    // Reinterprets every fixture byte as an `i8`. This is a safe replacement
    // for a raw-pointer cast: the bit pattern of each byte is preserved.
    fn bytes_as_i8(bytes: &[u8]) -> Vec<i8> {
        bytes.iter().map(|&b| i8::from_ne_bytes([b])).collect()
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let raw_boxes = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    )));
    let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics if the fixture size does not match the tensor.
    boxes_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes);
    let boxes_i8 = boxes_i8.into();

    let raw_protos = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    )));
    let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos);
    let protos_i8 = protos_i8.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    // Masks are decoded into a throwaway vector; only boxes are checked.
    decoder
        .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3639
/// Dequantizes the int8 YOLOv8-seg test-data tensors to `f32`, feeds them
/// through `decode` and checks that both expected detections are produced.
#[test]
fn test_decode_tensor_f32() {
    // Reinterprets every fixture byte as an `i8`. This is a safe replacement
    // for a raw-pointer cast: the bit pattern of each byte is preserved.
    fn bytes_as_i8(bytes: &[u8]) -> Vec<i8> {
        bytes.iter().map(|&b| i8::from_ne_bytes([b])).collect()
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let quant_boxes = (0.021287762_f32, 31_i32);
    let quant_protos = (0.02491162_f32, -117_i32);

    let raw_boxes = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    )));
    let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
    dequantize_cpu(raw_boxes.as_slice(), quant_boxes.into(), &mut raw_boxes_f32);
    let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics if the fixture size does not match the tensor.
    boxes_f32
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes_f32);
    let boxes_f32 = boxes_f32.into();

    let raw_protos = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    )));
    let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
    dequantize_cpu(
        raw_protos.as_slice(),
        quant_protos.into(),
        &mut raw_protos_f32,
    );
    let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_f32
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos_f32);
    let protos_f32 = protos_f32.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    // Masks are decoded into a throwaway vector; only boxes are checked.
    decoder
        .decode(
            &[&boxes_f32, &protos_f32],
            &mut output_boxes,
            &mut Vec::new(),
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3697
/// Dequantizes the int8 YOLOv8-seg test-data tensors to `f64`, feeds them
/// through `decode` and checks that both expected detections are produced.
#[test]
fn test_decode_tensor_f64() {
    // Reinterprets every fixture byte as an `i8`. This is a safe replacement
    // for a raw-pointer cast: the bit pattern of each byte is preserved.
    fn bytes_as_i8(bytes: &[u8]) -> Vec<i8> {
        bytes.iter().map(|&b| i8::from_ne_bytes([b])).collect()
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let quant_boxes = (0.021287762_f32, 31_i32);
    let quant_protos = (0.02491162_f32, -117_i32);

    let raw_boxes = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    )));
    let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
    dequantize_cpu(raw_boxes.as_slice(), quant_boxes.into(), &mut raw_boxes_f64);
    let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics if the fixture size does not match the tensor.
    boxes_f64
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes_f64);
    let boxes_f64 = boxes_f64.into();

    let raw_protos = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    )));
    let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
    dequantize_cpu(
        raw_protos.as_slice(),
        quant_protos.into(),
        &mut raw_protos_f64,
    );
    let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_f64
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos_f64);
    let protos_f64 = protos_f64.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    // Masks are decoded into a throwaway vector; only boxes are checked.
    decoder
        .decode(
            &[&boxes_f64, &protos_f64],
            &mut output_boxes,
            &mut Vec::new(),
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
}
3755
/// Feeds the int8 YOLOv8-seg test-data tensors through `decode_proto`,
/// checks both expected detections, and verifies that the returned proto
/// data carries one row of 32 mask coefficients per detection.
#[test]
fn test_decode_tensor_proto() {
    // Reinterprets every fixture byte as an `i8`. This is a safe replacement
    // for a raw-pointer cast: the bit pattern of each byte is preserved.
    fn bytes_as_i8(bytes: &[u8]) -> Vec<i8> {
        bytes.iter().map(|&b| i8::from_ne_bytes([b])).collect()
    }

    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let raw_boxes = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_boxes_116x8400.bin"
    )));
    let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
    // `copy_from_slice` panics if the fixture size does not match the tensor.
    boxes_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_boxes);
    let boxes_i8 = boxes_i8.into();

    let raw_protos = bytes_as_i8(include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8_protos_160x160x32.bin"
    )));
    let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
    protos_i8
        .map()
        .unwrap()
        .as_mut_slice()
        .copy_from_slice(&raw_protos);
    let protos_i8 = protos_i8.into();

    let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);

    let expected = real_data_expected_boxes();
    let mut output_boxes = Vec::with_capacity(50);

    let proto_data = decoder
        .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));

    let proto_data = proto_data.expect("segmentation model should return ProtoData");
    let coeffs_shape = proto_data.mask_coefficients.shape();
    assert_eq!(
        coeffs_shape[0],
        output_boxes.len(),
        "mask_coefficients count must match detection count"
    );
    assert_eq!(
        coeffs_shape[1], 32,
        "each detection should have 32 mask coefficients"
    );
}
3815
3816    // =========================================================================
3817    // Physical-order contract regression tests
3818    //
3819    // These cover the TFLite VX-delegate vertical-stripe mask bug and the
3820    // Ara-2-via-overlay anchor-first split-tensor bug. The core claim:
3821    // declaring shape+dshape in physical memory order produces correct
3822    // strides at wrap time, and `swap_axes_if_needed` permutes the stride
3823    // tuple into canonical logical order without touching bytes.
3824    // =========================================================================
3825
/// TFLite YOLOv8-seg protos arrive as physically-NHWC bytes with
/// NNStreamer dim string `"32:160:160:1"` (innermost-first). The
/// overlay's 2026-04-22 workaround declares `[1, 160, 160, 32]` with
/// dshape `[batch, height, width, num_protos]` — shape matches
/// physical order, so no permutation is needed and the view is
/// already in canonical HWC after the batch-dim slice. This test
/// confirms that path matches the reference (direct-HWC) decode:
/// same number of boxes and masks, boxes equal within 0.01, and masks
/// equal pixel-for-pixel.
#[test]
fn test_physical_order_tflite_nhwc_protos() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    // Load HWC protos and dequantize to f32.
    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Reference: direct call with canonical HWC protos.
    let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
    let mut ref_masks: Vec<_> = Vec::with_capacity(10);
    decode_yolo_segdet_float(
        seg.view(),
        protos_f32_hwc.view(),
        score_threshold,
        iou_threshold,
        Some(configs::Nms::ClassAgnostic),
        &mut ref_boxes,
        &mut ref_masks,
    )
    .unwrap();

    // Guard against a vacuous pass: with no reference detections every
    // comparison loop below would run zero times.
    assert!(
        !ref_boxes.is_empty(),
        "reference decode produced no detections"
    );

    // Build the same protos as a 4D NHWC tensor and feed it through
    // the config-driven path with dshape in physical order.
    let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0)); // [1, 160, 160, 32]
    let seg_3d = seg.insert_axis(Axis(0)); // [1, 116, 8400]

    let decoder = DecoderBuilder::default()
        .with_config_yolo_segdet(
            configs::Detection {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 116, 8400],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 116),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
                anchors: None,
            },
            configs::Protos {
                decoder: configs::DecoderType::Ultralytics,
                quantization: None,
                shape: vec![1, 160, 160, 32],
                // Physical NHWC — matches the TFLite buffer layout.
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                    (DimName::NumProtos, 32),
                ],
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .expect("config with NHWC protos dshape must build");

    let mut cfg_boxes = Vec::with_capacity(10);
    let mut cfg_masks = Vec::with_capacity(10);
    decoder
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut cfg_boxes,
            &mut cfg_masks,
        )
        .unwrap();

    assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
    // `zip` stops at the shorter side, so the lengths must be checked
    // explicitly or missing masks would go unnoticed.
    assert_eq!(cfg_masks.len(), ref_masks.len(), "mask count mismatch");
    for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
        assert!(
            c.equal_within_delta(r, 0.01),
            "NHWC-declared box does not match reference: {c:?} vs {r:?}"
        );
    }
    for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
        let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
        let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
        assert_eq!(
            cm_arr, rm_arr,
            "NHWC-declared mask must match reference pixel-for-pixel"
        );
    }
}
3925
3926    /// Ara-2 split-tensor boxes arrive physically anchor-first: NNStreamer
3927    /// dim string `"4:1:8400:1"` means innermost axis is 4 (coords), next
3928    /// is 1 (padding), next is 8400 (anchors), outermost is 1 (batch).
3929    /// The correct physical-order declaration is `[1, 8400, 1, 4]` with
3930    /// dshape `[batch, num_boxes, padding, box_coords]`. This test
3931    /// confirms that declaration produces the same decoded boxes as the
3932    /// canonical features-first TFLite declaration over the same logical
3933    /// data.
3934    #[test]
3935    fn test_physical_order_ara2_anchor_first_split_boxes() {
3936        use configs::{Boxes, Scores};
3937
3938        // Build synthetic boxes data in features-first canonical layout:
3939        // shape [1, 4, 8400] with one detection at anchor 42.
3940        const N: usize = 8400;
3941        let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
3942        let target_anchor = 42usize;
3943        boxes_canonical[[0, 0, target_anchor]] = 0.4; // xc
3944        boxes_canonical[[0, 1, target_anchor]] = 0.5; // yc
3945        boxes_canonical[[0, 2, target_anchor]] = 0.2; // w
3946        boxes_canonical[[0, 3, target_anchor]] = 0.2; // h
3947
3948        // Scores: one class at 0.9 for the same anchor.
3949        let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
3950        scores_canonical[[0, 0, target_anchor]] = 0.9;
3951
3952        // Reference: canonical (features-first) shape declaration.
3953        let ref_decoder = DecoderBuilder::default()
3954            .with_config_yolo_split_det(
3955                Boxes {
3956                    decoder: configs::DecoderType::Ultralytics,
3957                    quantization: None,
3958                    shape: vec![1, 4, N],
3959                    dshape: vec![
3960                        (DimName::Batch, 1),
3961                        (DimName::BoxCoords, 4),
3962                        (DimName::NumBoxes, N),
3963                    ],
3964                    normalized: Some(true),
3965                },
3966                Scores {
3967                    decoder: configs::DecoderType::Ultralytics,
3968                    quantization: None,
3969                    shape: vec![1, 80, N],
3970                    dshape: vec![
3971                        (DimName::Batch, 1),
3972                        (DimName::NumClasses, 80),
3973                        (DimName::NumBoxes, N),
3974                    ],
3975                },
3976            )
3977            .with_score_threshold(0.5)
3978            .with_iou_threshold(0.5)
3979            .with_nms(Some(configs::Nms::ClassAgnostic))
3980            .build()
3981            .expect("reference canonical split decoder must build");
3982
3983        let mut ref_boxes = Vec::with_capacity(4);
3984        let mut ref_masks = Vec::with_capacity(0);
3985        ref_decoder
3986            .decode_float(
3987                &[
3988                    boxes_canonical.view().into_dyn(),
3989                    scores_canonical.view().into_dyn(),
3990                ],
3991                &mut ref_boxes,
3992                &mut ref_masks,
3993            )
3994            .unwrap();
3995        assert_eq!(ref_boxes.len(), 1, "reference should produce one box");
3996
3997        // Ara-2 physical layout: transpose axes 1↔2 so the innermost
3998        // axis is BoxCoords / NumClasses. Materialise to get a
3999        // C-contiguous Array3 with strides matching the physical order
4000        // the Ara-2 backend would write into the tensor.
4001        let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 4]
4002        let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 80]
4003
4004        let ara2_decoder = DecoderBuilder::default()
4005            .with_config_yolo_split_det(
4006                Boxes {
4007                    decoder: configs::DecoderType::Ultralytics,
4008                    quantization: None,
4009                    shape: vec![1, N, 4],
4010                    dshape: vec![
4011                        (DimName::Batch, 1),
4012                        (DimName::NumBoxes, N),
4013                        (DimName::BoxCoords, 4),
4014                    ],
4015                    normalized: Some(true),
4016                },
4017                Scores {
4018                    decoder: configs::DecoderType::Ultralytics,
4019                    quantization: None,
4020                    shape: vec![1, N, 80],
4021                    dshape: vec![
4022                        (DimName::Batch, 1),
4023                        (DimName::NumBoxes, N),
4024                        (DimName::NumClasses, 80),
4025                    ],
4026                },
4027            )
4028            .with_score_threshold(0.5)
4029            .with_iou_threshold(0.5)
4030            .with_nms(Some(configs::Nms::ClassAgnostic))
4031            .build()
4032            .expect("Ara-2 anchor-first decoder must build");
4033
4034        let mut ara2_boxes = Vec::with_capacity(4);
4035        let mut ara2_masks = Vec::with_capacity(0);
4036        ara2_decoder
4037            .decode_float(
4038                &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
4039                &mut ara2_boxes,
4040                &mut ara2_masks,
4041            )
4042            .unwrap();
4043
4044        assert_eq!(
4045            ara2_boxes.len(),
4046            ref_boxes.len(),
4047            "Ara-2 anchor-first declaration must produce the same number \
4048             of boxes as the canonical features-first reference"
4049        );
4050        for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
4051            assert!(
4052                a.equal_within_delta(r, 1e-4),
4053                "Ara-2 box differs from reference: {a:?} vs {r:?}"
4054            );
4055        }
4056    }
4057
4058    /// Dshape whose per-axis sizes disagree with shape (the exact
4059    /// failure mode that caused the TFLite stripe bug — shape declared
4060    /// in one order, dshape in another) is rejected with a clear error.
4061    #[test]
4062    fn test_physical_order_rejects_shape_dshape_mismatch() {
4063        let result = DecoderBuilder::default()
4064            .with_config_yolo_segdet(
4065                configs::Detection {
4066                    decoder: configs::DecoderType::Ultralytics,
4067                    quantization: None,
4068                    shape: vec![1, 116, 8400],
4069                    dshape: vec![
4070                        (DimName::Batch, 1),
4071                        (DimName::NumFeatures, 116),
4072                        (DimName::NumBoxes, 8400),
4073                    ],
4074                    normalized: Some(true),
4075                    anchors: None,
4076                },
4077                configs::Protos {
4078                    decoder: configs::DecoderType::Ultralytics,
4079                    quantization: None,
4080                    // Shape in NCHW order...
4081                    shape: vec![1, 32, 160, 160],
4082                    // ...but dshape in NHWC order — the sizes don't line
4083                    // up positionally (dshape[1]=160 vs shape[1]=32).
4084                    dshape: vec![
4085                        (DimName::Batch, 1),
4086                        (DimName::Height, 160),
4087                        (DimName::Width, 160),
4088                        (DimName::NumProtos, 32),
4089                    ],
4090                },
4091                None,
4092            )
4093            .build();
4094
4095        match result {
4096            Err(DecoderError::InvalidConfig(msg)) => {
4097                assert!(
4098                    msg.contains("does not match shape"),
4099                    "expected shape/dshape size mismatch error, got: {msg}"
4100                );
4101            }
4102            other => panic!("expected InvalidConfig, got {other:?}"),
4103        }
4104    }
4105
4106    /// Duplicate dim name in dshape is rejected (two axes mapping to the
4107    /// same canonical slot would break the permutation).
4108    #[test]
4109    fn test_physical_order_rejects_duplicate_dshape_axis() {
4110        let result = DecoderBuilder::default()
4111            .with_config_yolo_split_det(
4112                configs::Boxes {
4113                    decoder: configs::DecoderType::Ultralytics,
4114                    quantization: None,
4115                    shape: vec![1, 4, 8400],
4116                    dshape: vec![
4117                        (DimName::Batch, 1),
4118                        (DimName::BoxCoords, 4),
4119                        (DimName::BoxCoords, 4), // duplicate (size matches shape[2]=8400? no)
4120                    ],
4121                    normalized: Some(true),
4122                },
4123                configs::Scores {
4124                    decoder: configs::DecoderType::Ultralytics,
4125                    quantization: None,
4126                    shape: vec![1, 80, 8400],
4127                    dshape: vec![
4128                        (DimName::Batch, 1),
4129                        (DimName::NumClasses, 80),
4130                        (DimName::NumBoxes, 8400),
4131                    ],
4132                },
4133            )
4134            .build();
4135
4136        // Shape[2] = 8400 ≠ dshape[2] size 4, so the positional-mismatch
4137        // check fires first — which is fine: any mis-shaped dshape
4138        // should be rejected. Check for either error as long as it's a
4139        // clear InvalidConfig.
4140        match result {
4141            Err(DecoderError::InvalidConfig(msg)) => {
4142                assert!(
4143                    msg.contains("appears at both index") || msg.contains("does not match shape"),
4144                    "expected positional or duplicate-axis error, got: {msg}"
4145                );
4146            }
4147            other => panic!("expected InvalidConfig, got {other:?}"),
4148        }
4149
4150        // Separate case: size-consistent duplicate to exercise the
4151        // duplicate-axis code path explicitly. Use `Batch` (no size
4152        // constraint, can repeat with size 1 against a shape that
4153        // carries two singleton dims).
4154        let result = DecoderBuilder::default()
4155            .with_config_yolo_split_det(
4156                configs::Boxes {
4157                    decoder: configs::DecoderType::Ultralytics,
4158                    quantization: None,
4159                    shape: vec![1, 1, 4, 8400],
4160                    dshape: vec![
4161                        (DimName::Batch, 1),
4162                        (DimName::Batch, 1), // duplicate, sizes match
4163                        (DimName::BoxCoords, 4),
4164                        (DimName::NumBoxes, 8400),
4165                    ],
4166                    normalized: Some(true),
4167                },
4168                configs::Scores {
4169                    decoder: configs::DecoderType::Ultralytics,
4170                    quantization: None,
4171                    shape: vec![1, 80, 8400],
4172                    dshape: vec![
4173                        (DimName::Batch, 1),
4174                        (DimName::NumClasses, 80),
4175                        (DimName::NumBoxes, 8400),
4176                    ],
4177                },
4178            )
4179            .build();
4180        match result {
4181            Err(DecoderError::InvalidConfig(msg)) => {
4182                assert!(
4183                    msg.contains("appears at both index"),
4184                    "expected duplicate-axis error, got: {msg}"
4185                );
4186            }
4187            other => panic!("expected InvalidConfig, got {other:?}"),
4188        }
4189    }
4190
4191    /// Canonical (dshape-omitted) high-level decode produces the same
4192    /// numeric result as the dshape-populated form. This closes a
4193    /// coverage gap flagged during review: dshape omission had been
4194    /// exercised only at the builder level, not through the full
4195    /// `decode_float` pipeline.
4196    #[test]
4197    fn test_physical_order_dshape_omitted_decodes_numerically() {
4198        let score_threshold = 0.45;
4199        let iou_threshold = 0.45;
4200
4201        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
4202        let quant_protos = Quantization::new(0.02491161972284317, -117);
4203        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
4204
4205        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
4206        let quant_boxes = Quantization::new(0.021287761628627777, 31);
4207        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
4208
4209        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0));
4210        let seg_3d = seg.insert_axis(Axis(0));
4211
4212        let build_decoder = |det_dshape: Vec<(DimName, usize)>,
4213                             proto_dshape: Vec<(DimName, usize)>| {
4214            DecoderBuilder::default()
4215                .with_config_yolo_segdet(
4216                    configs::Detection {
4217                        decoder: configs::DecoderType::Ultralytics,
4218                        quantization: None,
4219                        shape: vec![1, 116, 8400],
4220                        dshape: det_dshape,
4221                        normalized: Some(true),
4222                        anchors: None,
4223                    },
4224                    configs::Protos {
4225                        decoder: configs::DecoderType::Ultralytics,
4226                        quantization: None,
4227                        shape: vec![1, 160, 160, 32],
4228                        dshape: proto_dshape,
4229                    },
4230                    None,
4231                )
4232                .with_score_threshold(score_threshold)
4233                .with_iou_threshold(iou_threshold)
4234                .build()
4235                .unwrap()
4236        };
4237
4238        // Baseline: dshape populated in physical order.
4239        let dshaped = build_decoder(
4240            vec![
4241                (DimName::Batch, 1),
4242                (DimName::NumFeatures, 116),
4243                (DimName::NumBoxes, 8400),
4244            ],
4245            vec![
4246                (DimName::Batch, 1),
4247                (DimName::Height, 160),
4248                (DimName::Width, 160),
4249                (DimName::NumProtos, 32),
4250            ],
4251        );
4252        let mut dshaped_boxes = Vec::new();
4253        let mut dshaped_masks = Vec::new();
4254        dshaped
4255            .decode_float(
4256                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4257                &mut dshaped_boxes,
4258                &mut dshaped_masks,
4259            )
4260            .unwrap();
4261
4262        // Variant: dshape omitted — caller asserts shape is already in
4263        // the decoder's canonical order.
4264        let bare = build_decoder(vec![], vec![]);
4265        let mut bare_boxes = Vec::new();
4266        let mut bare_masks = Vec::new();
4267        bare.decode_float(
4268            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4269            &mut bare_boxes,
4270            &mut bare_masks,
4271        )
4272        .unwrap();
4273
4274        assert_eq!(bare_boxes.len(), dshaped_boxes.len());
4275        for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
4276            assert!(
4277                b.equal_within_delta(d, 1e-4),
4278                "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
4279            );
4280        }
4281        for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
4282            let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
4283            let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
4284            assert_eq!(
4285                bm_arr, dm_arr,
4286                "dshape-omitted mask must match dshape-populated pixel-for-pixel"
4287            );
4288        }
4289    }
4290
4291    /// 4D anchor-first boxes schema with a trailing `padding` axis — the
4292    /// exact Ara-2 DVM shape (`"4:1:8400:1"` innermost-first = physical
4293    /// `[1, 8400, 1, 4]`). Exercises the schema path's
4294    /// `squeeze_padding_dims`: the caller declares a 4D physical-order
4295    /// layout including `padding`, HAL squeezes the size-1 axis during
4296    /// `to_legacy_config_outputs`, and the decoder sees the resulting
4297    /// 3D shape. Callers feed HAL a 3D squeezed view of the same bytes.
4298    /// Complements the 3D `test_physical_order_ara2_anchor_first_split_boxes`
4299    /// which exercises the programmatic-builder path.
4300    #[test]
4301    fn test_physical_order_ara2_4d_anchor_first_with_padding() {
4302        // Build synthetic data in anchor-first 3D layout (the shape HAL
4303        // actually operates on after squeezing). The 4D-with-padding
4304        // declaration in the schema is a producer-side convention; the
4305        // caller is expected to present HAL with a squeezed view.
4306        const N: usize = 8400;
4307        let mut boxes = Array3::<f32>::zeros((1, N, 4));
4308        let target = 42usize;
4309        boxes[[0, target, 0]] = 0.4;
4310        boxes[[0, target, 1]] = 0.5;
4311        boxes[[0, target, 2]] = 0.2;
4312        boxes[[0, target, 3]] = 0.2;
4313        let mut scores = Array3::<f32>::zeros((1, N, 80));
4314        scores[[0, target, 0]] = 0.9;
4315
4316        // Schema JSON declares shape+dshape in physical anchor-first
4317        // order including padding — mirrors `ara2_int8_edgefirst.json`'s
4318        // feature-first declaration but with the axes in physical
4319        // anchor-first order. `to_legacy_config_outputs` squeezes
4320        // padding before the decoder's rank-3 verification.
4321        let json = r#"{
4322          "schema_version": 2,
4323          "decoder_version": "yolov8",
4324          "nms": "class_agnostic",
4325          "outputs": [
4326            {"name": "boxes", "type": "boxes",
4327             "shape": [1, 8400, 1, 4],
4328             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
4329             "encoding": "direct",
4330             "decoder": "ultralytics",
4331             "normalized": true},
4332            {"name": "scores", "type": "scores",
4333             "shape": [1, 8400, 1, 80],
4334             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
4335             "decoder": "ultralytics",
4336             "score_format": "per_class"}
4337          ]
4338        }"#;
4339        let decoder = DecoderBuilder::default()
4340            .with_config_json_str(json.to_string())
4341            .with_score_threshold(0.5)
4342            .with_iou_threshold(0.5)
4343            .build()
4344            .expect("4D anchor-first schema should build via squeeze_padding_dims");
4345
4346        let mut out_boxes = Vec::with_capacity(4);
4347        let mut out_masks = Vec::with_capacity(0);
4348        decoder
4349            .decode_float(
4350                &[boxes.view().into_dyn(), scores.view().into_dyn()],
4351                &mut out_boxes,
4352                &mut out_masks,
4353            )
4354            .unwrap();
4355
4356        assert_eq!(
4357            out_boxes.len(),
4358            1,
4359            "4D anchor-first with padding should decode exactly one box from the seeded anchor"
4360        );
4361        let b = &out_boxes[0];
4362        // xywh(0.4, 0.5, 0.2, 0.2) → xyxy(0.3, 0.4, 0.5, 0.6)
4363        assert!((b.bbox.xmin - 0.3).abs() < 1e-3, "xmin wrong: {b:?}");
4364        assert!((b.bbox.ymin - 0.4).abs() < 1e-3, "ymin wrong: {b:?}");
4365        assert!((b.bbox.xmax - 0.5).abs() < 1e-3, "xmax wrong: {b:?}");
4366        assert!((b.bbox.ymax - 0.6).abs() < 1e-3, "ymax wrong: {b:?}");
4367        assert_eq!(b.label, 0);
4368        assert!(b.score > 0.85, "score {}: {b:?}", b.score);
4369    }
4370}
4371
4372#[cfg(feature = "tracker")]
4373#[cfg(test)]
4374#[cfg_attr(coverage_nightly, coverage(off))]
4375mod decoder_tracked_tests {
4376
4377    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4378    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4379    use num_traits::{AsPrimitive, Float, PrimInt};
4380    use rand::{RngExt, SeedableRng};
4381    use rand_distr::StandardNormal;
4382
4383    use crate::{
4384        configs::{self, DimName},
4385        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4386    };
4387
4388    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4389        input: ArrayView<F, D>,
4390        quant: Quantization,
4391    ) -> Array<T, D>
4392    where
4393        i32: num_traits::AsPrimitive<F>,
4394        f32: num_traits::AsPrimitive<F>,
4395    {
4396        let zero_point = quant.zero_point.as_();
4397        let div_scale = F::one() / quant.scale.as_();
4398        if zero_point != F::zero() {
4399            input.mapv(|d| (d * div_scale + zero_point).round().as_())
4400        } else {
4401            input.mapv(|d| (d * div_scale).round().as_())
4402        }
4403    }
4404
4405    #[test]
4406    fn test_decoder_tracked_random_jitter() {
4407        use crate::configs::{DecoderType, Nms};
4408        use crate::DecoderBuilder;
4409
4410        let score_threshold = 0.25;
4411        let iou_threshold = 0.1;
4412        let out = include_bytes!(concat!(
4413            env!("CARGO_MANIFEST_DIR"),
4414            "/../../testdata/yolov8s_80_classes.bin"
4415        ));
4416        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4417        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4418        let quant = (0.0040811873, -123).into();
4419
4420        let decoder = DecoderBuilder::default()
4421            .with_config_yolo_det(
4422                crate::configs::Detection {
4423                    decoder: DecoderType::Ultralytics,
4424                    shape: vec![1, 84, 8400],
4425                    anchors: None,
4426                    quantization: Some(quant),
4427                    dshape: vec![
4428                        (crate::configs::DimName::Batch, 1),
4429                        (crate::configs::DimName::NumFeatures, 84),
4430                        (crate::configs::DimName::NumBoxes, 8400),
4431                    ],
4432                    normalized: Some(true),
4433                },
4434                None,
4435            )
4436            .with_score_threshold(score_threshold)
4437            .with_iou_threshold(iou_threshold)
4438            .with_nms(Some(Nms::ClassAgnostic))
4439            .build()
4440            .unwrap();
4441        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
4442
4443        let expected_boxes = [
4444            crate::DetectBox {
4445                bbox: crate::BoundingBox {
4446                    xmin: 0.5285137,
4447                    ymin: 0.05305544,
4448                    xmax: 0.87541467,
4449                    ymax: 0.9998909,
4450                },
4451                score: 0.5591227,
4452                label: 0,
4453            },
4454            crate::DetectBox {
4455                bbox: crate::BoundingBox {
4456                    xmin: 0.130598,
4457                    ymin: 0.43260583,
4458                    xmax: 0.35098213,
4459                    ymax: 0.9958097,
4460                },
4461                score: 0.33057618,
4462                label: 75,
4463            },
4464        ];
4465
4466        let mut tracker = ByteTrackBuilder::new()
4467            .track_update(0.1)
4468            .track_high_conf(0.3)
4469            .build();
4470
4471        let mut output_boxes = Vec::with_capacity(50);
4472        let mut output_masks = Vec::with_capacity(50);
4473        let mut output_tracks = Vec::with_capacity(50);
4474
4475        decoder
4476            .decode_tracked_quantized(
4477                &mut tracker,
4478                0,
4479                &[out.view().into()],
4480                &mut output_boxes,
4481                &mut output_masks,
4482                &mut output_tracks,
4483            )
4484            .unwrap();
4485
4486        assert_eq!(output_boxes.len(), 2);
4487        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4488        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
4489
4490        let mut last_boxes = output_boxes.clone();
4491
4492        for i in 1..=100 {
4493            let mut out = out.clone();
4494            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
4495            let mut x_values = out.slice_mut(s![0, 0, ..]);
4496            for x in x_values.iter_mut() {
4497                let r: f32 = rng.sample(StandardNormal);
4498                let r = r.clamp(-2.0, 2.0) / 2.0;
4499                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
4500            }
4501
4502            let mut y_values = out.slice_mut(s![0, 1, ..]);
4503            for y in y_values.iter_mut() {
4504                let r: f32 = rng.sample(StandardNormal);
4505                let r = r.clamp(-2.0, 2.0) / 2.0;
4506                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
4507            }
4508
4509            decoder
4510                .decode_tracked_quantized(
4511                    &mut tracker,
4512                    100_000_000 * i / 3, // simulate 33.333ms between frames
4513                    &[out.view().into()],
4514                    &mut output_boxes,
4515                    &mut output_masks,
4516                    &mut output_tracks,
4517                )
4518                .unwrap();
4519
4520            assert_eq!(output_boxes.len(), 2);
4521            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
4522            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
4523
4524            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
4525            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
4526            last_boxes = output_boxes.clone();
4527        }
4528    }
4529
4530    // ─── Shared helpers for tracked decoder tests ────────────────────
4531
4532    fn real_data_expected_boxes() -> [DetectBox; 2] {
4533        [
4534            DetectBox {
4535                bbox: BoundingBox {
4536                    xmin: 0.08515105,
4537                    ymin: 0.7131401,
4538                    xmax: 0.29802868,
4539                    ymax: 0.8195788,
4540                },
4541                score: 0.91537374,
4542                label: 23,
4543            },
4544            DetectBox {
4545                bbox: BoundingBox {
4546                    xmin: 0.59605736,
4547                    ymin: 0.25545314,
4548                    xmax: 0.93666154,
4549                    ymax: 0.72378385,
4550                },
4551                score: 0.91537374,
4552                label: 23,
4553            },
4554        ]
4555    }
4556
4557    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4558        [DetectBox {
4559            bbox: BoundingBox {
4560                xmin: 0.12549022,
4561                ymin: 0.12549022,
4562                xmax: 0.23529413,
4563                ymax: 0.23529413,
4564            },
4565            score: 0.98823535,
4566            label: 2,
4567        }]
4568    }
4569
4570    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4571        [DetectBox {
4572            bbox: BoundingBox {
4573                xmin: 0.1234,
4574                ymin: 0.1234,
4575                xmax: 0.2345,
4576                ymax: 0.2345,
4577            },
4578            score: 0.9876,
4579            label: 2,
4580        }]
4581    }
4582
4583    fn build_yolo_split_segdet_decoder(
4584        score_threshold: f32,
4585        iou_threshold: f32,
4586        quant_boxes: (f32, i32),
4587        quant_protos: (f32, i32),
4588    ) -> crate::Decoder {
4589        DecoderBuilder::default()
4590            .with_config_yolo_split_segdet(
4591                configs::Boxes {
4592                    decoder: configs::DecoderType::Ultralytics,
4593                    quantization: Some(quant_boxes.into()),
4594                    shape: vec![1, 4, 8400],
4595                    dshape: vec![
4596                        (DimName::Batch, 1),
4597                        (DimName::BoxCoords, 4),
4598                        (DimName::NumBoxes, 8400),
4599                    ],
4600                    normalized: Some(true),
4601                },
4602                configs::Scores {
4603                    decoder: configs::DecoderType::Ultralytics,
4604                    quantization: Some(quant_boxes.into()),
4605                    shape: vec![1, 80, 8400],
4606                    dshape: vec![
4607                        (DimName::Batch, 1),
4608                        (DimName::NumClasses, 80),
4609                        (DimName::NumBoxes, 8400),
4610                    ],
4611                },
4612                configs::MaskCoefficients {
4613                    decoder: configs::DecoderType::Ultralytics,
4614                    quantization: Some(quant_boxes.into()),
4615                    shape: vec![1, 32, 8400],
4616                    dshape: vec![
4617                        (DimName::Batch, 1),
4618                        (DimName::NumProtos, 32),
4619                        (DimName::NumBoxes, 8400),
4620                    ],
4621                },
4622                configs::Protos {
4623                    decoder: configs::DecoderType::Ultralytics,
4624                    quantization: Some(quant_protos.into()),
4625                    shape: vec![1, 160, 160, 32],
4626                    dshape: vec![
4627                        (DimName::Batch, 1),
4628                        (DimName::Height, 160),
4629                        (DimName::Width, 160),
4630                        (DimName::NumProtos, 32),
4631                    ],
4632                },
4633            )
4634            .with_score_threshold(score_threshold)
4635            .with_iou_threshold(iou_threshold)
4636            .build()
4637            .unwrap()
4638    }
4639
4640    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4641        let config_yaml = include_str!(concat!(
4642            env!("CARGO_MANIFEST_DIR"),
4643            "/../../testdata/yolov8_seg.yaml"
4644        ));
4645        DecoderBuilder::default()
4646            .with_config_yaml_str(config_yaml.to_string())
4647            .with_score_threshold(score_threshold)
4648            .with_iou_threshold(iou_threshold)
4649            .build()
4650            .unwrap()
4651    }
4652
4653    // ─── Real-data tracked test macro ───────────────────────────────
4654    //
4655    // Generates tests that load i8 binary test data from testdata/ and
4656    // exercise all (quant/float) × (combined/split) × (masks/proto)
4657    // decoder paths.
4658
4659    macro_rules! real_data_tracked_test {
4660        ($name:ident, quantized, $layout:ident, $output:ident) => {
4661            #[test]
4662            fn $name() {
4663                let is_split = matches!(stringify!($layout), "split");
4664                let is_proto = matches!(stringify!($output), "proto");
4665
4666                let score_threshold = 0.45;
4667                let iou_threshold = 0.45;
4668                let quant_boxes = (0.021287762_f32, 31_i32);
4669                let quant_protos = (0.02491162_f32, -117_i32);
4670
4671                let raw_boxes = include_bytes!(concat!(
4672                    env!("CARGO_MANIFEST_DIR"),
4673                    "/../../testdata/yolov8_boxes_116x8400.bin"
4674                ));
4675                let raw_boxes = unsafe {
4676                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4677                };
4678                let boxes_i8 =
4679                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4680
4681                let raw_protos = include_bytes!(concat!(
4682                    env!("CARGO_MANIFEST_DIR"),
4683                    "/../../testdata/yolov8_protos_160x160x32.bin"
4684                ));
4685                let raw_protos = unsafe {
4686                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4687                };
4688                let protos_i8 =
4689                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4690                        .unwrap();
4691
4692                // Pre-split (unused for combined, but harmless)
4693                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4694                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4695                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4696                let mut boxes_combined = boxes_i8;
4697
4698                let decoder = if is_split {
4699                    build_yolo_split_segdet_decoder(
4700                        score_threshold,
4701                        iou_threshold,
4702                        quant_boxes,
4703                        quant_protos,
4704                    )
4705                } else {
4706                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4707                };
4708
4709                let expected = real_data_expected_boxes();
4710                let mut tracker = ByteTrackBuilder::new()
4711                    .track_update(0.1)
4712                    .track_high_conf(0.7)
4713                    .build();
4714                let mut output_boxes = Vec::with_capacity(50);
4715                let mut output_tracks = Vec::with_capacity(50);
4716
4717                // Frame 1: decode
4718                if is_proto {
4719                    {
4720                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4721                            vec![
4722                                boxes_split.view().into(),
4723                                scores_split.view().into(),
4724                                mask_split.view().into(),
4725                                protos_i8.view().into(),
4726                            ]
4727                        } else {
4728                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4729                        };
4730                        decoder
4731                            .decode_tracked_quantized_proto(
4732                                &mut tracker,
4733                                0,
4734                                &inputs,
4735                                &mut output_boxes,
4736                                &mut output_tracks,
4737                            )
4738                            .unwrap();
4739                    }
4740                    assert_eq!(output_boxes.len(), 2);
4741                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4742                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4743
4744                    // Zero scores for frame 2
4745                    if is_split {
4746                        for score in scores_split.iter_mut() {
4747                            *score = i8::MIN;
4748                        }
4749                    } else {
4750                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4751                            *score = i8::MIN;
4752                        }
4753                    }
4754
4755                    let proto_result = {
4756                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4757                            vec![
4758                                boxes_split.view().into(),
4759                                scores_split.view().into(),
4760                                mask_split.view().into(),
4761                                protos_i8.view().into(),
4762                            ]
4763                        } else {
4764                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4765                        };
4766                        decoder
4767                            .decode_tracked_quantized_proto(
4768                                &mut tracker,
4769                                100_000_000 / 3,
4770                                &inputs,
4771                                &mut output_boxes,
4772                                &mut output_tracks,
4773                            )
4774                            .unwrap()
4775                    };
4776                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4777                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4778                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4779                } else {
4780                    let mut output_masks = Vec::with_capacity(50);
4781                    {
4782                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4783                            vec![
4784                                boxes_split.view().into(),
4785                                scores_split.view().into(),
4786                                mask_split.view().into(),
4787                                protos_i8.view().into(),
4788                            ]
4789                        } else {
4790                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4791                        };
4792                        decoder
4793                            .decode_tracked_quantized(
4794                                &mut tracker,
4795                                0,
4796                                &inputs,
4797                                &mut output_boxes,
4798                                &mut output_masks,
4799                                &mut output_tracks,
4800                            )
4801                            .unwrap();
4802                    }
4803                    assert_eq!(output_boxes.len(), 2);
4804                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4805                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4806
4807                    if is_split {
4808                        for score in scores_split.iter_mut() {
4809                            *score = i8::MIN;
4810                        }
4811                    } else {
4812                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4813                            *score = i8::MIN;
4814                        }
4815                    }
4816
4817                    {
4818                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4819                            vec![
4820                                boxes_split.view().into(),
4821                                scores_split.view().into(),
4822                                mask_split.view().into(),
4823                                protos_i8.view().into(),
4824                            ]
4825                        } else {
4826                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4827                        };
4828                        decoder
4829                            .decode_tracked_quantized(
4830                                &mut tracker,
4831                                100_000_000 / 3,
4832                                &inputs,
4833                                &mut output_boxes,
4834                                &mut output_masks,
4835                                &mut output_tracks,
4836                            )
4837                            .unwrap();
4838                    }
4839                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4840                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4841                    assert!(output_masks.is_empty());
4842                }
4843            }
4844        };
4845        ($name:ident, float, $layout:ident, $output:ident) => {
4846            #[test]
4847            fn $name() {
4848                let is_split = matches!(stringify!($layout), "split");
4849                let is_proto = matches!(stringify!($output), "proto");
4850
4851                let score_threshold = 0.45;
4852                let iou_threshold = 0.45;
4853                let quant_boxes = (0.021287762_f32, 31_i32);
4854                let quant_protos = (0.02491162_f32, -117_i32);
4855
4856                let raw_boxes = include_bytes!(concat!(
4857                    env!("CARGO_MANIFEST_DIR"),
4858                    "/../../testdata/yolov8_boxes_116x8400.bin"
4859                ));
4860                let raw_boxes = unsafe {
4861                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4862                };
4863                let boxes_i8 =
4864                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4865                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4866
4867                let raw_protos = include_bytes!(concat!(
4868                    env!("CARGO_MANIFEST_DIR"),
4869                    "/../../testdata/yolov8_protos_160x160x32.bin"
4870                ));
4871                let raw_protos = unsafe {
4872                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4873                };
4874                let protos_i8 =
4875                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4876                        .unwrap();
4877                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4878
4879                // Pre-split from dequantized data
4880                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4881                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4882                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4883                let mut boxes_combined = boxes_f32;
4884
4885                let decoder = if is_split {
4886                    build_yolo_split_segdet_decoder(
4887                        score_threshold,
4888                        iou_threshold,
4889                        quant_boxes,
4890                        quant_protos,
4891                    )
4892                } else {
4893                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4894                };
4895
4896                let expected = real_data_expected_boxes();
4897                let mut tracker = ByteTrackBuilder::new()
4898                    .track_update(0.1)
4899                    .track_high_conf(0.7)
4900                    .build();
4901                let mut output_boxes = Vec::with_capacity(50);
4902                let mut output_tracks = Vec::with_capacity(50);
4903
4904                if is_proto {
4905                    {
4906                        let inputs = if is_split {
4907                            vec![
4908                                boxes_split.view().into_dyn(),
4909                                scores_split.view().into_dyn(),
4910                                mask_split.view().into_dyn(),
4911                                protos_f32.view().into_dyn(),
4912                            ]
4913                        } else {
4914                            vec![
4915                                boxes_combined.view().into_dyn(),
4916                                protos_f32.view().into_dyn(),
4917                            ]
4918                        };
4919                        decoder
4920                            .decode_tracked_float_proto(
4921                                &mut tracker,
4922                                0,
4923                                &inputs,
4924                                &mut output_boxes,
4925                                &mut output_tracks,
4926                            )
4927                            .unwrap();
4928                    }
4929                    assert_eq!(output_boxes.len(), 2);
4930                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4931                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4932
4933                    if is_split {
4934                        for score in scores_split.iter_mut() {
4935                            *score = 0.0;
4936                        }
4937                    } else {
4938                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4939                            *score = 0.0;
4940                        }
4941                    }
4942
4943                    let proto_result = {
4944                        let inputs = if is_split {
4945                            vec![
4946                                boxes_split.view().into_dyn(),
4947                                scores_split.view().into_dyn(),
4948                                mask_split.view().into_dyn(),
4949                                protos_f32.view().into_dyn(),
4950                            ]
4951                        } else {
4952                            vec![
4953                                boxes_combined.view().into_dyn(),
4954                                protos_f32.view().into_dyn(),
4955                            ]
4956                        };
4957                        decoder
4958                            .decode_tracked_float_proto(
4959                                &mut tracker,
4960                                100_000_000 / 3,
4961                                &inputs,
4962                                &mut output_boxes,
4963                                &mut output_tracks,
4964                            )
4965                            .unwrap()
4966                    };
4967                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4968                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4969                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4970                } else {
4971                    let mut output_masks = Vec::with_capacity(50);
4972                    {
4973                        let inputs = if is_split {
4974                            vec![
4975                                boxes_split.view().into_dyn(),
4976                                scores_split.view().into_dyn(),
4977                                mask_split.view().into_dyn(),
4978                                protos_f32.view().into_dyn(),
4979                            ]
4980                        } else {
4981                            vec![
4982                                boxes_combined.view().into_dyn(),
4983                                protos_f32.view().into_dyn(),
4984                            ]
4985                        };
4986                        decoder
4987                            .decode_tracked_float(
4988                                &mut tracker,
4989                                0,
4990                                &inputs,
4991                                &mut output_boxes,
4992                                &mut output_masks,
4993                                &mut output_tracks,
4994                            )
4995                            .unwrap();
4996                    }
4997                    assert_eq!(output_boxes.len(), 2);
4998                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4999                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
5000
5001                    if is_split {
5002                        for score in scores_split.iter_mut() {
5003                            *score = 0.0;
5004                        }
5005                    } else {
5006                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
5007                            *score = 0.0;
5008                        }
5009                    }
5010
5011                    {
5012                        let inputs = if is_split {
5013                            vec![
5014                                boxes_split.view().into_dyn(),
5015                                scores_split.view().into_dyn(),
5016                                mask_split.view().into_dyn(),
5017                                protos_f32.view().into_dyn(),
5018                            ]
5019                        } else {
5020                            vec![
5021                                boxes_combined.view().into_dyn(),
5022                                protos_f32.view().into_dyn(),
5023                            ]
5024                        };
5025                        decoder
5026                            .decode_tracked_float(
5027                                &mut tracker,
5028                                100_000_000 / 3,
5029                                &inputs,
5030                                &mut output_boxes,
5031                                &mut output_masks,
5032                                &mut output_tracks,
5033                            )
5034                            .unwrap();
5035                    }
5036                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5037                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5038                    assert!(output_masks.is_empty());
5039                }
5040            }
5041        };
5042    }
5043
5044    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
5045    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
5046    real_data_tracked_test!(
5047        test_decoder_tracked_segdet_proto,
5048        quantized,
5049        combined,
5050        proto
5051    );
5052    real_data_tracked_test!(
5053        test_decoder_tracked_segdet_proto_float,
5054        float,
5055        combined,
5056        proto
5057    );
5058    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
5059    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
5060    real_data_tracked_test!(
5061        test_decoder_tracked_segdet_split_proto,
5062        quantized,
5063        split,
5064        proto
5065    );
5066    real_data_tracked_test!(
5067        test_decoder_tracked_segdet_split_proto_float,
5068        float,
5069        split,
5070        proto
5071    );
5072
5073    // ─── End-to-end tracked test macro ──────────────────────────────
5074    //
5075    // Generates tests with synthetic data to exercise all tracked
5076    // decode paths without needing real model output files.
5077
5078    const E2E_COMBINED_CONFIG: &str = "
5079decoder_version: yolo26
5080outputs:
5081 - type: detection
5082   decoder: ultralytics
5083   quantization: [0.00784313725490196, 0]
5084   shape: [1, 10, 38]
5085   dshape:
5086    - [batch, 1]
5087    - [num_boxes, 10]
5088    - [num_features, 38]
5089   normalized: true
5090 - type: protos
5091   decoder: ultralytics
5092   quantization: [0.0039215686274509803921568627451, 128]
5093   shape: [1, 160, 160, 32]
5094   dshape:
5095    - [batch, 1]
5096    - [height, 160]
5097    - [width, 160]
5098    - [num_protos, 32]
5099";
5100
5101    const E2E_SPLIT_CONFIG: &str = "
5102decoder_version: yolo26
5103outputs:
5104 - type: boxes
5105   decoder: ultralytics
5106   quantization: [0.00784313725490196, 0]
5107   shape: [1, 10, 4]
5108   dshape:
5109    - [batch, 1]
5110    - [num_boxes, 10]
5111    - [box_coords, 4]
5112   normalized: true
5113 - type: scores
5114   decoder: ultralytics
5115   quantization: [0.00784313725490196, 0]
5116   shape: [1, 10, 1]
5117   dshape:
5118    - [batch, 1]
5119    - [num_boxes, 10]
5120    - [num_classes, 1]
5121 - type: classes
5122   decoder: ultralytics
5123   quantization: [0.00784313725490196, 0]
5124   shape: [1, 10, 1]
5125   dshape:
5126    - [batch, 1]
5127    - [num_boxes, 10]
5128    - [num_classes, 1]
5129 - type: mask_coefficients
5130   decoder: ultralytics
5131   quantization: [0.00784313725490196, 0]
5132   shape: [1, 10, 32]
5133   dshape:
5134    - [batch, 1]
5135    - [num_boxes, 10]
5136    - [num_protos, 32]
5137 - type: protos
5138   decoder: ultralytics
5139   quantization: [0.0039215686274509803921568627451, 128]
5140   shape: [1, 160, 160, 32]
5141   dshape:
5142    - [batch, 1]
5143    - [height, 160]
5144    - [width, 160]
5145    - [num_protos, 32]
5146";
5147
5148    macro_rules! e2e_tracked_test {
5149        ($name:ident, quantized, $layout:ident, $output:ident) => {
5150            #[test]
5151            fn $name() {
5152                let is_split = matches!(stringify!($layout), "split");
5153                let is_proto = matches!(stringify!($output), "proto");
5154
5155                let score_threshold = 0.45;
5156                let iou_threshold = 0.45;
5157
5158                let mut boxes = Array2::zeros((10, 4));
5159                let mut scores = Array2::zeros((10, 1));
5160                let mut classes = Array2::zeros((10, 1));
5161                let mask = Array2::zeros((10, 32));
5162                let protos = Array3::<f64>::zeros((160, 160, 32));
5163                let protos = protos.insert_axis(Axis(0));
5164                let protos_quant = (1.0 / 255.0, 0.0);
5165                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5166
5167                boxes
5168                    .slice_mut(s![0, ..])
5169                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5170                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5171                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5172
5173                let detect_quant = (2.0 / 255.0, 0.0);
5174
5175                let decoder = if is_split {
5176                    DecoderBuilder::default()
5177                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5178                        .with_score_threshold(score_threshold)
5179                        .with_iou_threshold(iou_threshold)
5180                        .build()
5181                        .unwrap()
5182                } else {
5183                    DecoderBuilder::default()
5184                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5185                        .with_score_threshold(score_threshold)
5186                        .with_iou_threshold(iou_threshold)
5187                        .build()
5188                        .unwrap()
5189                };
5190
5191                let expected = e2e_expected_boxes_quant();
5192                let mut tracker = ByteTrackBuilder::new()
5193                    .track_update(0.1)
5194                    .track_high_conf(0.7)
5195                    .build();
5196                let mut output_boxes = Vec::with_capacity(50);
5197                let mut output_tracks = Vec::with_capacity(50);
5198
5199                if is_split {
5200                    let boxes = boxes.insert_axis(Axis(0));
5201                    let scores = scores.insert_axis(Axis(0));
5202                    let classes = classes.insert_axis(Axis(0));
5203                    let mask = mask.insert_axis(Axis(0));
5204
5205                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5206                    let mut scores: Array3<u8> =
5207                        quantize_ndarray(scores.view(), detect_quant.into());
5208                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5209                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5210
5211                    if is_proto {
5212                        {
5213                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5214                                boxes.view().into(),
5215                                scores.view().into(),
5216                                classes.view().into(),
5217                                mask.view().into(),
5218                                protos.view().into(),
5219                            ];
5220                            decoder
5221                                .decode_tracked_quantized_proto(
5222                                    &mut tracker,
5223                                    0,
5224                                    &inputs,
5225                                    &mut output_boxes,
5226                                    &mut output_tracks,
5227                                )
5228                                .unwrap();
5229                        }
5230                        assert_eq!(output_boxes.len(), 1);
5231                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5232
5233                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5234                            *score = u8::MIN;
5235                        }
5236                        let proto_result = {
5237                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5238                                boxes.view().into(),
5239                                scores.view().into(),
5240                                classes.view().into(),
5241                                mask.view().into(),
5242                                protos.view().into(),
5243                            ];
5244                            decoder
5245                                .decode_tracked_quantized_proto(
5246                                    &mut tracker,
5247                                    100_000_000 / 3,
5248                                    &inputs,
5249                                    &mut output_boxes,
5250                                    &mut output_tracks,
5251                                )
5252                                .unwrap()
5253                        };
5254                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5255                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5256                    } else {
5257                        let mut output_masks = Vec::with_capacity(50);
5258                        {
5259                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5260                                boxes.view().into(),
5261                                scores.view().into(),
5262                                classes.view().into(),
5263                                mask.view().into(),
5264                                protos.view().into(),
5265                            ];
5266                            decoder
5267                                .decode_tracked_quantized(
5268                                    &mut tracker,
5269                                    0,
5270                                    &inputs,
5271                                    &mut output_boxes,
5272                                    &mut output_masks,
5273                                    &mut output_tracks,
5274                                )
5275                                .unwrap();
5276                        }
5277                        assert_eq!(output_boxes.len(), 1);
5278                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5279
5280                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5281                            *score = u8::MIN;
5282                        }
5283                        {
5284                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5285                                boxes.view().into(),
5286                                scores.view().into(),
5287                                classes.view().into(),
5288                                mask.view().into(),
5289                                protos.view().into(),
5290                            ];
5291                            decoder
5292                                .decode_tracked_quantized(
5293                                    &mut tracker,
5294                                    100_000_000 / 3,
5295                                    &inputs,
5296                                    &mut output_boxes,
5297                                    &mut output_masks,
5298                                    &mut output_tracks,
5299                                )
5300                                .unwrap();
5301                        }
5302                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5303                        assert!(output_masks.is_empty());
5304                    }
5305                } else {
5306                    // Combined layout
5307                    let detect = ndarray::concatenate![
5308                        Axis(1),
5309                        boxes.view(),
5310                        scores.view(),
5311                        classes.view(),
5312                        mask.view()
5313                    ];
5314                    let detect = detect.insert_axis(Axis(0));
5315                    assert_eq!(detect.shape(), &[1, 10, 38]);
5316                    let mut detect: Array3<u8> =
5317                        quantize_ndarray(detect.view(), detect_quant.into());
5318
5319                    if is_proto {
5320                        {
5321                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5322                                vec![detect.view().into(), protos.view().into()];
5323                            decoder
5324                                .decode_tracked_quantized_proto(
5325                                    &mut tracker,
5326                                    0,
5327                                    &inputs,
5328                                    &mut output_boxes,
5329                                    &mut output_tracks,
5330                                )
5331                                .unwrap();
5332                        }
5333                        assert_eq!(output_boxes.len(), 1);
5334                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5335
5336                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5337                            *score = u8::MIN;
5338                        }
5339                        let proto_result = {
5340                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5341                                vec![detect.view().into(), protos.view().into()];
5342                            decoder
5343                                .decode_tracked_quantized_proto(
5344                                    &mut tracker,
5345                                    100_000_000 / 3,
5346                                    &inputs,
5347                                    &mut output_boxes,
5348                                    &mut output_tracks,
5349                                )
5350                                .unwrap()
5351                        };
5352                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5353                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5354                    } else {
5355                        let mut output_masks = Vec::with_capacity(50);
5356                        {
5357                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5358                                vec![detect.view().into(), protos.view().into()];
5359                            decoder
5360                                .decode_tracked_quantized(
5361                                    &mut tracker,
5362                                    0,
5363                                    &inputs,
5364                                    &mut output_boxes,
5365                                    &mut output_masks,
5366                                    &mut output_tracks,
5367                                )
5368                                .unwrap();
5369                        }
5370                        assert_eq!(output_boxes.len(), 1);
5371                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5372
5373                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5374                            *score = u8::MIN;
5375                        }
5376                        {
5377                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5378                                vec![detect.view().into(), protos.view().into()];
5379                            decoder
5380                                .decode_tracked_quantized(
5381                                    &mut tracker,
5382                                    100_000_000 / 3,
5383                                    &inputs,
5384                                    &mut output_boxes,
5385                                    &mut output_masks,
5386                                    &mut output_tracks,
5387                                )
5388                                .unwrap();
5389                        }
5390                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5391                        assert!(output_masks.is_empty());
5392                    }
5393                }
5394            }
5395        };
5396        ($name:ident, float, $layout:ident, $output:ident) => {
                // Float arm: expands to a `#[test]` named `$name` that feeds f64
                // ndarray views to `decode_tracked_float` / `decode_tracked_float_proto`
                // for two consecutive frames that share one tracker.
                //   * `$layout`: `split` selects E2E_SPLIT_CONFIG and five separate
                //     inputs; any other identifier selects E2E_COMBINED_CONFIG and a
                //     single concatenated detection input plus the protos.
                //   * `$output`: `proto` exercises the `*_proto` entry point; any other
                //     identifier exercises the mask-producing entry point.
5397            #[test]
5398            fn $name() {
                    // The selectors arrive as identifiers, so they are compared through
                    // their stringified form.
5399                let is_split = matches!(stringify!($layout), "split");
5400                let is_proto = matches!(stringify!($output), "proto");
5401
5402                let score_threshold = 0.45;
5403                let iou_threshold = 0.45;
5404
                    // Synthetic model output: 10 candidate rows of which only row 0 is
                    // populated below; mask coefficients and prototypes stay all-zero.
5405                let mut boxes = Array2::zeros((10, 4));
5406                let mut scores = Array2::zeros((10, 1));
5407                let mut classes = Array2::zeros((10, 1));
5408                let mask: Array2<f64> = Array2::zeros((10, 32));
5409                let protos = Array3::<f64>::zeros((160, 160, 32));
5410                let protos = protos.insert_axis(Axis(0));
5411
                    // Row 0 is the only detection: score 0.9876 (above both the 0.45
                    // score threshold and the 0.7 `track_high_conf` value), class 2.0.
5412                boxes
5413                    .slice_mut(s![0, ..])
5414                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5415                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5416                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5417
                    // Both branches configure the builder identically; only the YAML
                    // config string differs.
5418                let decoder = if is_split {
5419                    DecoderBuilder::default()
5420                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5421                        .with_score_threshold(score_threshold)
5422                        .with_iou_threshold(iou_threshold)
5423                        .build()
5424                        .unwrap()
5425                } else {
5426                    DecoderBuilder::default()
5427                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5428                        .with_score_threshold(score_threshold)
5429                        .with_iou_threshold(iou_threshold)
5430                        .build()
5431                        .unwrap()
5432                };
5433
5434                let expected = e2e_expected_boxes_float();
                    // One tracker instance is reused for both decode calls below.
5435                let mut tracker = ByteTrackBuilder::new()
5436                    .track_update(0.1)
5437                    .track_high_conf(0.7)
5438                    .build();
5439                let mut output_boxes = Vec::with_capacity(50);
5440                let mut output_tracks = Vec::with_capacity(50);
5441
5442                if is_split {
                        // Split layout: each array gets a leading axis of length 1 and is
                        // passed to the decoder as its own input.
5443                    let boxes = boxes.insert_axis(Axis(0));
5444                    let mut scores = scores.insert_axis(Axis(0));
5445                    let classes = classes.insert_axis(Axis(0));
5446                    let mask = mask.insert_axis(Axis(0));
5447
5448                    if is_proto {
                            // Frame 1 (timestamp 0): the single confident detection must be
                            // reported. The inner scope ends the borrow of `scores` held by
                            // `inputs` so it can be mutated afterwards.
5449                        {
5450                            let inputs = vec![
5451                                boxes.view().into_dyn(),
5452                                scores.view().into_dyn(),
5453                                classes.view().into_dyn(),
5454                                mask.view().into_dyn(),
5455                                protos.view().into_dyn(),
5456                            ];
5457                            decoder
5458                                .decode_tracked_float_proto(
5459                                    &mut tracker,
5460                                    0,
5461                                    &inputs,
5462                                    &mut output_boxes,
5463                                    &mut output_tracks,
5464                                )
5465                                .unwrap();
5466                        }
5467                        assert_eq!(output_boxes.len(), 1);
5468                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5469
                            // Frame 2: zero every score so the model output contains no
                            // detection. The test still expects `output_boxes[0]` to match
                            // `expected[0]`, presumably because the tracker carries the
                            // track forward (confirm against the tracker implementation).
                            // NOTE(review): timestamp units are not visible here;
                            // 100_000_000 / 3 looks like ~1/30 s in nanoseconds (confirm).
5470                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5471                            *score = 0.0;
5472                        }
5473                        let proto_result = {
5474                            let inputs = vec![
5475                                boxes.view().into_dyn(),
5476                                scores.view().into_dyn(),
5477                                classes.view().into_dyn(),
5478                                mask.view().into_dyn(),
5479                                protos.view().into_dyn(),
5480                            ];
5481                            decoder
5482                                .decode_tracked_float_proto(
5483                                    &mut tracker,
5484                                    100_000_000 / 3,
5485                                    &inputs,
5486                                    &mut output_boxes,
5487                                    &mut output_tracks,
5488                                )
5489                                .unwrap()
5490                        };
5491                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                            // Proto data must be returned, but with zero rows of mask
                            // coefficients.
5492                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5493                    } else {
5494                        let mut output_masks = Vec::with_capacity(50);
                            // Frame 1 (timestamp 0): expect exactly one box.
5495                        {
5496                            let inputs = vec![
5497                                boxes.view().into_dyn(),
5498                                scores.view().into_dyn(),
5499                                classes.view().into_dyn(),
5500                                mask.view().into_dyn(),
5501                                protos.view().into_dyn(),
5502                            ];
5503                            decoder
5504                                .decode_tracked_float(
5505                                    &mut tracker,
5506                                    0,
5507                                    &inputs,
5508                                    &mut output_boxes,
5509                                    &mut output_masks,
5510                                    &mut output_tracks,
5511                                )
5512                                .unwrap();
5513                        }
5514                        assert_eq!(output_boxes.len(), 1);
5515                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5516
                            // Frame 2: all scores zeroed; a box matching `expected[0]` is
                            // still expected, but no masks.
5517                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5518                            *score = 0.0;
5519                        }
5520                        {
5521                            let inputs = vec![
5522                                boxes.view().into_dyn(),
5523                                scores.view().into_dyn(),
5524                                classes.view().into_dyn(),
5525                                mask.view().into_dyn(),
5526                                protos.view().into_dyn(),
5527                            ];
5528                            decoder
5529                                .decode_tracked_float(
5530                                    &mut tracker,
5531                                    100_000_000 / 3,
5532                                    &inputs,
5533                                    &mut output_boxes,
5534                                    &mut output_masks,
5535                                    &mut output_tracks,
5536                                )
5537                                .unwrap();
5538                        }
5539                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5540                        assert!(output_masks.is_empty());
5541                    }
5542                } else {
5543                    // Combined layout
                        // Columns after concatenation along axis 1: 4 box + 1 score +
                        // 1 class + 32 mask = 38, so the score lives in column 4 (the
                        // column zeroed before frame 2 below).
5544                    let detect = ndarray::concatenate![
5545                        Axis(1),
5546                        boxes.view(),
5547                        scores.view(),
5548                        classes.view(),
5549                        mask.view()
5550                    ];
5551                    let mut detect = detect.insert_axis(Axis(0));
5552                    assert_eq!(detect.shape(), &[1, 10, 38]);
5553
5554                    if is_proto {
                            // Frame 1 (timestamp 0): expect exactly one box.
5555                        {
5556                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5557                            decoder
5558                                .decode_tracked_float_proto(
5559                                    &mut tracker,
5560                                    0,
5561                                    &inputs,
5562                                    &mut output_boxes,
5563                                    &mut output_tracks,
5564                                )
5565                                .unwrap();
5566                        }
5567                        assert_eq!(output_boxes.len(), 1);
5568                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5569
                            // Frame 2: zero the score column; a box matching `expected[0]`
                            // is still expected.
5570                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5571                            *score = 0.0;
5572                        }
5573                        let proto_result = {
5574                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5575                            decoder
5576                                .decode_tracked_float_proto(
5577                                    &mut tracker,
5578                                    100_000_000 / 3,
5579                                    &inputs,
5580                                    &mut output_boxes,
5581                                    &mut output_tracks,
5582                                )
5583                                .unwrap()
5584                        };
5585                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                            // Proto data must be returned, but with zero rows of mask
                            // coefficients.
5586                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5587                    } else {
5588                        let mut output_masks = Vec::with_capacity(50);
                            // Frame 1 (timestamp 0): expect exactly one box.
5589                        {
5590                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5591                            decoder
5592                                .decode_tracked_float(
5593                                    &mut tracker,
5594                                    0,
5595                                    &inputs,
5596                                    &mut output_boxes,
5597                                    &mut output_masks,
5598                                    &mut output_tracks,
5599                                )
5600                                .unwrap();
5601                        }
5602                        assert_eq!(output_boxes.len(), 1);
5603                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5604
                            // Frame 2: zero the score column; a box matching `expected[0]`
                            // is still expected, but no masks.
5605                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5606                            *score = 0.0;
5607                        }
5608                        {
5609                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5610                            decoder
5611                                .decode_tracked_float(
5612                                    &mut tracker,
5613                                    100_000_000 / 3,
5614                                    &inputs,
5615                                    &mut output_boxes,
5616                                    &mut output_masks,
5617                                    &mut output_tracks,
5618                                )
5619                                .unwrap();
5620                        }
5621                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5622                        assert!(output_masks.is_empty());
5623                    }
5624                }
5625            }
5626        };
5627    }
5628
        // Instantiate the tracked end-to-end tests for every combination of
        // input type (quantized / float), layout (combined / split) and output
        // (masks / proto): eight tests in total. The second argument selects the
        // macro arm; the third and fourth are matched by name inside the arm.
5629    e2e_tracked_test!(
5630        test_decoder_tracked_end_to_end_segdet,
5631        quantized,
5632        combined,
5633        masks
5634    );
5635    e2e_tracked_test!(
5636        test_decoder_tracked_end_to_end_segdet_float,
5637        float,
5638        combined,
5639        masks
5640    );
5641    e2e_tracked_test!(
5642        test_decoder_tracked_end_to_end_segdet_proto,
5643        quantized,
5644        combined,
5645        proto
5646    );
5647    e2e_tracked_test!(
5648        test_decoder_tracked_end_to_end_segdet_proto_float,
5649        float,
5650        combined,
5651        proto
5652    );
5653    e2e_tracked_test!(
5654        test_decoder_tracked_end_to_end_segdet_split,
5655        quantized,
5656        split,
5657        masks
5658    );
5659    e2e_tracked_test!(
5660        test_decoder_tracked_end_to_end_segdet_split_float,
5661        float,
5662        split,
5663        masks
5664    );
5665    e2e_tracked_test!(
5666        test_decoder_tracked_end_to_end_segdet_split_proto,
5667        quantized,
5668        split,
5669        proto
5670    );
5671    e2e_tracked_test!(
5672        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5673        float,
5674        split,
5675        proto
5676    );
5677
5678    // ─── End-to-end tracked TensorDyn test macro ────────────────────
5679    //
5680    // Same as e2e_tracked_test but wraps data in TensorDyn and exercises
5681    // the public decode_tracked / decode_proto_tracked API.
5682
5683    macro_rules! e2e_tracked_tensor_test {
5684        ($name:ident, quantized, $layout:ident, $output:ident) => {
                // Quantized arm: expands to a `#[test]` named `$name` that quantizes
                // the synthetic model output to u8, wraps it in `TensorDyn` values and
                // calls `decode_tracked` / `decode_proto_tracked` for two consecutive
                // frames that share one tracker.
                //   * `$layout`: `split` selects E2E_SPLIT_CONFIG and five separate
                //     tensors; any other identifier selects E2E_COMBINED_CONFIG and a
                //     single concatenated detection tensor plus the protos.
                //   * `$output`: `proto` exercises `decode_proto_tracked`; any other
                //     identifier exercises `decode_tracked`.
5685            #[test]
5686            fn $name() {
5687                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5688
                    // The selectors arrive as identifiers, so they are compared through
                    // their stringified form.
5689                let is_split = matches!(stringify!($layout), "split");
5690                let is_proto = matches!(stringify!($output), "proto");
5691
5692                let score_threshold = 0.45;
5693                let iou_threshold = 0.45;
5694
                    // Synthetic model output: 10 candidate rows of which only row 0 is
                    // populated below; mask coefficients and prototypes stay all-zero.
5695                let mut boxes = Array2::zeros((10, 4));
5696                let mut scores = Array2::zeros((10, 1));
5697                let mut classes = Array2::zeros((10, 1));
5698                let mask = Array2::zeros((10, 32));
5699                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5700                let protos_f64 = protos_f64.insert_axis(Axis(0));
                    // Parameters handed to `quantize_ndarray`; presumably
                    // (scale, zero-point), to be confirmed against the `.into()` target.
5701                let protos_quant = (1.0 / 255.0, 0.0);
5702                let protos_u8: Array4<u8> =
5703                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5704
5705                boxes
5706                    .slice_mut(s![0, ..])
5707                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5708                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5709                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5710
                    // Shared by every detection-side array (boxes, scores, classes,
                    // mask coefficients, or the combined tensor).
5711                let detect_quant = (2.0 / 255.0, 0.0);
5712
                    // Both branches configure the builder identically; only the YAML
                    // config string differs.
5713                let decoder = if is_split {
5714                    DecoderBuilder::default()
5715                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5716                        .with_score_threshold(score_threshold)
5717                        .with_iou_threshold(iou_threshold)
5718                        .build()
5719                        .unwrap()
5720                } else {
5721                    DecoderBuilder::default()
5722                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5723                        .with_score_threshold(score_threshold)
5724                        .with_iou_threshold(iou_threshold)
5725                        .build()
5726                        .unwrap()
5727                };
5728
5729                // Helper to wrap a u8 slice into a TensorDyn
                    // It allocates a new tensor of `shape` and copies `data` into the
                    // start of its mapped buffer, so the tensor holds its own copy:
                    // later edits to the source ndarray require building a new tensor.
5730                let make_u8_tensor =
5731                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5732                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5733                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5734                        t.into()
5735                    };
5736
5737                let expected = e2e_expected_boxes_quant();
                    // One tracker instance is reused for both decode calls below.
5738                let mut tracker = ByteTrackBuilder::new()
5739                    .track_update(0.1)
5740                    .track_high_conf(0.7)
5741                    .build();
5742                let mut output_boxes = Vec::with_capacity(50);
5743                let mut output_tracks = Vec::with_capacity(50);
5744
5745                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5746
5747                if is_split {
                        // Split layout: each array gets a leading axis of length 1, is
                        // quantized on its own and becomes a separate tensor.
5748                    let boxes = boxes.insert_axis(Axis(0));
5749                    let scores = scores.insert_axis(Axis(0));
5750                    let classes = classes.insert_axis(Axis(0));
5751                    let mask = mask.insert_axis(Axis(0));
5752
5753                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5754                    let mut scores_q: Array3<u8> =
5755                        quantize_ndarray(scores.view(), detect_quant.into());
5756                    let classes_q: Array3<u8> =
5757                        quantize_ndarray(classes.view(), detect_quant.into());
5758                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5759
                        // These three tensors are built once and reused for both frames;
                        // only the scores tensor is rebuilt after the scores are cleared.
5760                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5761                    let classes_td =
5762                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5763                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5764
5765                    if is_proto {
                            // Frame 1 (timestamp 0): expect exactly one box.
5766                        let scores_td =
5767                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5768                        decoder
5769                            .decode_proto_tracked(
5770                                &mut tracker,
5771                                0,
5772                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5773                                &mut output_boxes,
5774                                &mut output_tracks,
5775                            )
5776                            .unwrap();
5777
5778                        assert_eq!(output_boxes.len(), 1);
5779                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5780
                            // Frame 2: clear every quantized score and rebuild the scores
                            // tensor. The test still expects `output_boxes[0]` to match
                            // `expected[0]`, presumably because the tracker carries the
                            // track forward (confirm against the tracker implementation).
                            // NOTE(review): timestamp units are not visible here;
                            // 100_000_000 / 3 looks like ~1/30 s in nanoseconds (confirm).
5781                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5782                            *score = u8::MIN;
5783                        }
5784                        let scores_td =
5785                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5786                        let proto_result = decoder
5787                            .decode_proto_tracked(
5788                                &mut tracker,
5789                                100_000_000 / 3,
5790                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5791                                &mut output_boxes,
5792                                &mut output_tracks,
5793                            )
5794                            .unwrap();
5795                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                            // Proto data must be returned, but with zero rows of mask
                            // coefficients.
5796                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5797                    } else {
                            // Frame 1 (timestamp 0): expect exactly one box.
5798                        let scores_td =
5799                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5800                        let mut output_masks = Vec::with_capacity(50);
5801                        decoder
5802                            .decode_tracked(
5803                                &mut tracker,
5804                                0,
5805                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5806                                &mut output_boxes,
5807                                &mut output_masks,
5808                                &mut output_tracks,
5809                            )
5810                            .unwrap();
5811
5812                        assert_eq!(output_boxes.len(), 1);
5813                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5814
                            // Frame 2: scores cleared and the scores tensor rebuilt; a box
                            // matching `expected[0]` is still expected, but no masks.
5815                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5816                            *score = u8::MIN;
5817                        }
5818                        let scores_td =
5819                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5820                        decoder
5821                            .decode_tracked(
5822                                &mut tracker,
5823                                100_000_000 / 3,
5824                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5825                                &mut output_boxes,
5826                                &mut output_masks,
5827                                &mut output_tracks,
5828                            )
5829                            .unwrap();
5830                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5831                        assert!(output_masks.is_empty());
5832                    }
5833                } else {
5834                    // Combined layout
                        // Columns after concatenation along axis 1: 4 box + 1 score +
                        // 1 class + 32 mask = 38, so the score lives in column 4 (the
                        // column cleared before frame 2 below).
5835                    let detect = ndarray::concatenate![
5836                        Axis(1),
5837                        boxes.view(),
5838                        scores.view(),
5839                        classes.view(),
5840                        mask.view()
5841                    ];
5842                    let detect = detect.insert_axis(Axis(0));
5843                    assert_eq!(detect.shape(), &[1, 10, 38]);
5844                    // Ensure contiguous layout after concatenation for as_slice()
5845                    let detect =
5846                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5847                            .unwrap();
5848                    let mut detect_q: Array3<u8> =
5849                        quantize_ndarray(detect.view(), detect_quant.into());
5850
5851                    if is_proto {
                            // Frame 1 (timestamp 0): expect exactly one box.
5852                        let detect_td =
5853                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5854                        decoder
5855                            .decode_proto_tracked(
5856                                &mut tracker,
5857                                0,
5858                                &[&detect_td, &protos_td],
5859                                &mut output_boxes,
5860                                &mut output_tracks,
5861                            )
5862                            .unwrap();
5863
5864                        assert_eq!(output_boxes.len(), 1);
5865                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5866
                            // Frame 2: clear the score column and rebuild the detection
                            // tensor; a box matching `expected[0]` is still expected.
5867                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5868                            *score = u8::MIN;
5869                        }
5870                        let detect_td =
5871                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5872                        let proto_result = decoder
5873                            .decode_proto_tracked(
5874                                &mut tracker,
5875                                100_000_000 / 3,
5876                                &[&detect_td, &protos_td],
5877                                &mut output_boxes,
5878                                &mut output_tracks,
5879                            )
5880                            .unwrap();
5881                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
                            // Proto data must be returned, but with zero rows of mask
                            // coefficients.
5882                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5883                    } else {
                            // Frame 1 (timestamp 0): expect exactly one box.
5884                        let detect_td =
5885                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5886                        let mut output_masks = Vec::with_capacity(50);
5887                        decoder
5888                            .decode_tracked(
5889                                &mut tracker,
5890                                0,
5891                                &[&detect_td, &protos_td],
5892                                &mut output_boxes,
5893                                &mut output_masks,
5894                                &mut output_tracks,
5895                            )
5896                            .unwrap();
5897
5898                        assert_eq!(output_boxes.len(), 1);
5899                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5900
                            // Frame 2: score column cleared and the detection tensor
                            // rebuilt; a box matching `expected[0]` is still expected, but
                            // no masks.
5901                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5902                            *score = u8::MIN;
5903                        }
5904                        let detect_td =
5905                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5906                        decoder
5907                            .decode_tracked(
5908                                &mut tracker,
5909                                100_000_000 / 3,
5910                                &[&detect_td, &protos_td],
5911                                &mut output_boxes,
5912                                &mut output_masks,
5913                                &mut output_tracks,
5914                            )
5915                            .unwrap();
5916                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5917                        assert!(output_masks.is_empty());
5918                    }
5919                }
5920            }
5921        };
5922        ($name:ident, float, $layout:ident, $output:ident) => {
5923            #[test]
5924            fn $name() {
5925                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5926
5927                let is_split = matches!(stringify!($layout), "split");
5928                let is_proto = matches!(stringify!($output), "proto");
5929
5930                let score_threshold = 0.45;
5931                let iou_threshold = 0.45;
5932
5933                let mut boxes = Array2::zeros((10, 4));
5934                let mut scores = Array2::zeros((10, 1));
5935                let mut classes = Array2::zeros((10, 1));
5936                let mask: Array2<f64> = Array2::zeros((10, 32));
5937                let protos = Array3::<f64>::zeros((160, 160, 32));
5938                let protos = protos.insert_axis(Axis(0));
5939
5940                boxes
5941                    .slice_mut(s![0, ..])
5942                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5943                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5944                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5945
5946                let decoder = if is_split {
5947                    DecoderBuilder::default()
5948                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5949                        .with_score_threshold(score_threshold)
5950                        .with_iou_threshold(iou_threshold)
5951                        .build()
5952                        .unwrap()
5953                } else {
5954                    DecoderBuilder::default()
5955                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5956                        .with_score_threshold(score_threshold)
5957                        .with_iou_threshold(iou_threshold)
5958                        .build()
5959                        .unwrap()
5960                };
5961
5962                // Helper to wrap an f64 slice into a TensorDyn
5963                let make_f64_tensor =
5964                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5965                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
5966                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5967                        t.into()
5968                    };
5969
5970                let expected = e2e_expected_boxes_float();
5971                let mut tracker = ByteTrackBuilder::new()
5972                    .track_update(0.1)
5973                    .track_high_conf(0.7)
5974                    .build();
5975                let mut output_boxes = Vec::with_capacity(50);
5976                let mut output_tracks = Vec::with_capacity(50);
5977
5978                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5979
5980                if is_split {
5981                    let boxes = boxes.insert_axis(Axis(0));
5982                    let mut scores = scores.insert_axis(Axis(0));
5983                    let classes = classes.insert_axis(Axis(0));
5984                    let mask = mask.insert_axis(Axis(0));
5985
5986                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5987                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5988                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5989
5990                    if is_proto {
5991                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5992                        decoder
5993                            .decode_proto_tracked(
5994                                &mut tracker,
5995                                0,
5996                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5997                                &mut output_boxes,
5998                                &mut output_tracks,
5999                            )
6000                            .unwrap();
6001
6002                        assert_eq!(output_boxes.len(), 1);
6003                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6004
6005                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6006                            *score = 0.0;
6007                        }
6008                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6009                        let proto_result = decoder
6010                            .decode_proto_tracked(
6011                                &mut tracker,
6012                                100_000_000 / 3,
6013                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6014                                &mut output_boxes,
6015                                &mut output_tracks,
6016                            )
6017                            .unwrap();
6018                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6019                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6020                    } else {
6021                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6022                        let mut output_masks = Vec::with_capacity(50);
6023                        decoder
6024                            .decode_tracked(
6025                                &mut tracker,
6026                                0,
6027                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6028                                &mut output_boxes,
6029                                &mut output_masks,
6030                                &mut output_tracks,
6031                            )
6032                            .unwrap();
6033
6034                        assert_eq!(output_boxes.len(), 1);
6035                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6036
6037                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6038                            *score = 0.0;
6039                        }
6040                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6041                        decoder
6042                            .decode_tracked(
6043                                &mut tracker,
6044                                100_000_000 / 3,
6045                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6046                                &mut output_boxes,
6047                                &mut output_masks,
6048                                &mut output_tracks,
6049                            )
6050                            .unwrap();
6051                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6052                        assert!(output_masks.is_empty());
6053                    }
6054                } else {
6055                    // Combined layout
6056                    let detect = ndarray::concatenate![
6057                        Axis(1),
6058                        boxes.view(),
6059                        scores.view(),
6060                        classes.view(),
6061                        mask.view()
6062                    ];
6063                    let detect = detect.insert_axis(Axis(0));
6064                    assert_eq!(detect.shape(), &[1, 10, 38]);
6065                    // Ensure contiguous layout after concatenation for as_slice()
6066                    let mut detect =
6067                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
6068                            .unwrap();
6069
6070                    if is_proto {
6071                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6072                        decoder
6073                            .decode_proto_tracked(
6074                                &mut tracker,
6075                                0,
6076                                &[&detect_td, &protos_td],
6077                                &mut output_boxes,
6078                                &mut output_tracks,
6079                            )
6080                            .unwrap();
6081
6082                        assert_eq!(output_boxes.len(), 1);
6083                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6084
6085                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6086                            *score = 0.0;
6087                        }
6088                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6089                        let proto_result = decoder
6090                            .decode_proto_tracked(
6091                                &mut tracker,
6092                                100_000_000 / 3,
6093                                &[&detect_td, &protos_td],
6094                                &mut output_boxes,
6095                                &mut output_tracks,
6096                            )
6097                            .unwrap();
6098                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6099                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6100                    } else {
6101                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6102                        let mut output_masks = Vec::with_capacity(50);
6103                        decoder
6104                            .decode_tracked(
6105                                &mut tracker,
6106                                0,
6107                                &[&detect_td, &protos_td],
6108                                &mut output_boxes,
6109                                &mut output_masks,
6110                                &mut output_tracks,
6111                            )
6112                            .unwrap();
6113
6114                        assert_eq!(output_boxes.len(), 1);
6115                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6116
6117                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6118                            *score = 0.0;
6119                        }
6120                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6121                        decoder
6122                            .decode_tracked(
6123                                &mut tracker,
6124                                100_000_000 / 3,
6125                                &[&detect_td, &protos_td],
6126                                &mut output_boxes,
6127                                &mut output_masks,
6128                                &mut output_tracks,
6129                            )
6130                            .unwrap();
6131                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6132                        assert!(output_masks.is_empty());
6133                    }
6134                }
6135            }
6136        };
6137    }
6138
// Instantiate `e2e_tracked_tensor_test!` once for every combination of
// {quantized, float} x {combined, split} x {masks, proto} (eight variants).
// NOTE(review): the first argument is presumably the name of the generated
// test function; confirm against the macro header.
6139    e2e_tracked_tensor_test!(
6140        test_decoder_tracked_tensor_end_to_end_segdet,
6141        quantized,
6142        combined,
6143        masks
6144    );
6145    e2e_tracked_tensor_test!(
6146        test_decoder_tracked_tensor_end_to_end_segdet_float,
6147        float,
6148        combined,
6149        masks
6150    );
6151    e2e_tracked_tensor_test!(
6152        test_decoder_tracked_tensor_end_to_end_segdet_proto,
6153        quantized,
6154        combined,
6155        proto
6156    );
6157    e2e_tracked_tensor_test!(
6158        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6159        float,
6160        combined,
6161        proto
6162    );
6163    e2e_tracked_tensor_test!(
6164        test_decoder_tracked_tensor_end_to_end_segdet_split,
6165        quantized,
6166        split,
6167        masks
6168    );
6169    e2e_tracked_tensor_test!(
6170        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6171        float,
6172        split,
6173        masks
6174    );
6175    e2e_tracked_tensor_test!(
6176        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6177        quantized,
6178        split,
6179        proto
6180    );
6181    e2e_tracked_tensor_test!(
6182        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6183        float,
6184        split,
6185        proto
6186    );
6187
/// Checks that tracked decoding follows boxes that move linearly along X.
///
/// The quantized YOLOv8 fixture is decoded once unmodified, then for 100 more
/// frames with feature row 0 shifted by an additional 1e-3 (dequantized units)
/// per frame. Afterwards the tracker's predicted locations must match the
/// initial boxes shifted by 0.1, and one final frame without any detections
/// must still yield the boxes advanced by one more step.
#[test]
fn test_decoder_tracked_linear_motion() {
    use crate::configs::{DecoderType, Nms};
    use crate::DecoderBuilder;

    let score_threshold = 0.25;
    let iou_threshold = 0.1;
    let raw = include_bytes!(concat!(
        env!("CARGO_MANIFEST_DIR"),
        "/../../testdata/yolov8s_80_classes.bin"
    ));
    // The fixture stores signed 8-bit values; `u8 as i8` keeps the bit pattern,
    // so a plain cast replaces the former unsafe pointer reinterpretation.
    let raw: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    let mut out = Array3::from_shape_vec((1, 84, 8400), raw).unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            crate::configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (crate::configs::DimName::Batch, 1),
                    (crate::configs::DimName::NumFeatures, 84),
                    (crate::configs::DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .with_nms(Some(Nms::ClassAgnostic))
        .build()
        .unwrap();

    let mut expected_boxes = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.3)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // Frame 0: unmodified model output.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));

    for i in 1..=100 {
        let mut frame = out.clone();
        // introduce linear movement into the X coordinates; the shift is the
        // same for every element of the frame, so compute it once
        let shift = (i as f32 * 1e-3 / quant.0).round() as i8;
        for x in frame.slice_mut(s![0, 0, ..]).iter_mut() {
            *x = x.saturating_add(shift);
        }

        decoder
            .decode_tracked_quantized(
                &mut tracker,
                100_000_000 * i / 3, // simulate 33.333ms between frames
                &[frame.view().into()],
                &mut output_boxes,
                &mut output_masks,
                &mut output_tracks,
            )
            .unwrap();

        assert_eq!(output_boxes.len(), 2);
    }

    let tracks = tracker.get_active_tracks();
    let predicted_boxes: Vec<_> = tracks
        .iter()
        .map(|track| {
            // Last detection of the track, relocated to the tracker's estimate.
            let mut predicted = track.last_box;
            predicted.bbox = track.info.tracked_location.into();
            predicted
        })
        .collect();
    assert!(
        predicted_boxes.len() >= 2,
        "expected at least 2 active tracks, got {}",
        predicted_boxes.len()
    );

    // compensate for linear movement: 100 frames at 1e-3 per frame
    for expected in expected_boxes.iter_mut() {
        expected.bbox.xmin += 0.1;
        expected.bbox.xmax += 0.1;
    }

    assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
    assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));

    // give the decoder a final frame with no detections to ensure tracks are
    // properly predicting forward when detection is missing
    for score in out.slice_mut(s![0, 4.., ..]).iter_mut() {
        *score = i8::MIN; // set all scores to minimum to simulate no detections
    }
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            100_000_000 * 101 / 3,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    assert!(
        output_boxes.len() >= 2,
        "expected at least 2 predicted boxes on the detection-free frame, got {}",
        output_boxes.len()
    );

    // compensate for expected movement of one more frame
    for expected in expected_boxes.iter_mut() {
        expected.bbox.xmin += 0.001;
        expected.bbox.xmax += 0.001;
    }

    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
}
6333
/// Tracked float decoding of an end-to-end detection tensor of shape
/// `[1, 10, 6]`.
///
/// The first frame carries a single scored row, which must be reported as-is.
/// The second frame has column 4 (the score) zeroed for every row, and the
/// first output box must still match the original detection.
#[test]
fn test_decoder_tracked_end_to_end_float() {
    let (score_threshold, iou_threshold) = (0.45, 0.45);

    // Only row 0 is populated; the remaining nine rows stay all-zero.
    let mut bbox_rows = Array2::zeros((10, 4));
    let mut score_col = Array2::zeros((10, 1));
    let mut class_col = Array2::zeros((10, 1));
    bbox_rows
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    score_col.slice_mut(s![0, ..]).assign(&array![0.9876]);
    class_col.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Join the columns per row, then add the leading batch axis.
    let mut detect = ndarray::concatenate![
        Axis(1),
        bbox_rows.view(),
        score_col.view(),
        class_col.view()
    ]
    .insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 6]);

    let config = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    let expected = DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    };

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // First frame at timestamp 0 with the single detection present.
    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[detect.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected, 1e-6));

    // Second frame: zero every score so nothing is detected, to ensure the
    // track keeps predicting forward when the detection is missing.
    detect.slice_mut(s![.., .., 4]).fill(0.0);

    let next_frame_timestamp = 100_000_000 / 3;
    decoder
        .decode_tracked_float(
            &mut tracker,
            next_frame_timestamp,
            &[detect.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    assert!(output_boxes[0].equal_within_delta(&expected, 1e-6));
}
6425}