Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize_*` functions provided in this crate (`dequantize_ndarray`, `dequantize_cpu`, `dequantize_cpu_chunked`), and then use the
floating-point decoding functions. However, it is recommended not to dequantize the model outputs manually before passing them to the decoder,
as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86    yolo::yolo_segmentation_to_mask,
87};
88
89/// Trait to convert bounding box formats to XYXY float format
90pub trait BBoxTypeTrait {
91    /// Converts the bbox into XYXY float format.
92    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96        input: &[B; 4],
97        quant: Quantization,
98    ) -> [A; 4]
99    where
100        f32: AsPrimitive<A>,
101        i32: AsPrimitive<A>;
102
103    /// Converts the bbox into XYXY float format.
104    ///
105    /// # Examples
106    /// ```rust
107    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
108    /// # use ndarray::array;
109    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
110    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
111    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
112    /// ```
113    #[inline(always)]
114    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115        input: ArrayView1<B>,
116    ) -> [A; 4] {
117        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118    }
119
120    #[inline(always)]
121    /// Converts the bbox into XYXY float format.
122    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123        input: ArrayView1<B>,
124        quant: Quantization,
125    ) -> [A; 4]
126    where
127        f32: AsPrimitive<A>,
128        i32: AsPrimitive<A>,
129    {
130        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131    }
132}
133
134/// Converts XYXY bounding boxes to XYXY
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140        input.map(|b| b.as_())
141    }
142
143    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144        input: &[B; 4],
145        quant: Quantization,
146    ) -> [A; 4]
147    where
148        f32: AsPrimitive<A>,
149        i32: AsPrimitive<A>,
150    {
151        let scale = quant.scale.as_();
152        let zp = quant.zero_point.as_();
153        input.map(|b| (b.as_() - zp) * scale)
154    }
155
156    #[inline(always)]
157    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158        input: ArrayView1<B>,
159    ) -> [A; 4] {
160        [
161            input[0].as_(),
162            input[1].as_(),
163            input[2].as_(),
164            input[3].as_(),
165        ]
166    }
167}
168
169/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
170/// box
171#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175    #[inline(always)]
176    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177        let half = A::one() / (A::one() + A::one());
178        [
179            (input[0].as_()) - (input[2].as_() * half),
180            (input[1].as_()) - (input[3].as_() * half),
181            (input[0].as_()) + (input[2].as_() * half),
182            (input[1].as_()) + (input[3].as_() * half),
183        ]
184    }
185
186    #[inline(always)]
187    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188        input: &[B; 4],
189        quant: Quantization,
190    ) -> [A; 4]
191    where
192        f32: AsPrimitive<A>,
193        i32: AsPrimitive<A>,
194    {
195        let scale = quant.scale.as_();
196        let half_scale = (quant.scale * 0.5).as_();
197        let zp = quant.zero_point.as_();
198        let [x, y, w, h] = [
199            (input[0].as_() - zp) * scale,
200            (input[1].as_() - zp) * scale,
201            (input[2].as_() - zp) * half_scale,
202            (input[3].as_() - zp) * half_scale,
203        ];
204
205        [x - w, y - h, x + w, y + h]
206    }
207
208    #[inline(always)]
209    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210        input: ArrayView1<B>,
211    ) -> [A; 4] {
212        let half = A::one() / (A::one() + A::one());
213        [
214            (input[0].as_()) - (input[2].as_() * half),
215            (input[1].as_()) - (input[3].as_() * half),
216            (input[0].as_()) + (input[2].as_() * half),
217            (input[1].as_()) + (input[3].as_() * half),
218        ]
219    }
220}
221
/// Affine quantization parameters of a tensor: a quantized value `q`
/// represents the real value `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been subtracted.
    pub scale: f32,
    /// Quantized value that maps to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds quantization parameters from an explicit `scale` and
    /// `zero_point`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
242
243impl From<QuantTuple> for Quantization {
244    /// Creates a new Quantization struct from a QuantTuple
245    /// # Examples
246    /// ```
247    /// # use edgefirst_decoder::Quantization;
248    /// # use edgefirst_decoder::configs::QuantTuple;
249    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
250    /// let quant = Quantization::from(quant_tuple);
251    /// assert_eq!(quant.scale, 0.1);
252    /// assert_eq!(quant.zero_point, -128);
253    /// ```
254    fn from(quant_tuple: QuantTuple) -> Quantization {
255        Quantization {
256            scale: quant_tuple.0,
257            zero_point: quant_tuple.1,
258        }
259    }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264    S: AsPrimitive<f32>,
265    Z: AsPrimitive<i32>,
266{
267    /// Creates a new Quantization struct from a tuple
268    /// # Examples
269    /// ```
270    /// # use edgefirst_decoder::Quantization;
271    /// let quant = Quantization::from((0.1_f64, -128_i64));
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from((scale, zp): (S, Z)) -> Quantization {
276        Self {
277            scale: scale.as_(),
278            zero_point: zp.as_(),
279        }
280    }
281}
282
283impl Default for Quantization {
284    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
285    /// # Examples
286    /// ```rust
287    /// # use edgefirst_decoder::Quantization;
288    /// let quant = Quantization::default();
289    /// assert_eq!(quant.scale, 1.0);
290    /// assert_eq!(quant.zero_point, 0);
291    /// ```
292    fn default() -> Self {
293        Self {
294            scale: 1.0,
295            zero_point: 0,
296        }
297    }
298}
299
/// A single decoded detection: an `f32` bounding box in XYXY format, a
/// dequantized `f32` score, and a label index.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// bounding box of this detection (see [`BoundingBox`])
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
309
/// An axis-aligned bounding box with `f32` corners in XYXY
/// (`xmin, ymin, xmax, ymax`) order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its four corner coordinates, stored as given (no
    /// reordering; see [`BoundingBox::to_canonical`]).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        Self { xmin, ymin, xmax, ymax }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`, swapping coordinates where needed.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (left, right) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (top, bottom) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(left, top, right, bottom)
    }
}
355
356impl From<BoundingBox> for [f32; 4] {
357    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
358    /// xmax, ymax order
359    /// # Examples
360    /// ```
361    /// # use edgefirst_decoder::BoundingBox;
362    /// let bbox = BoundingBox {
363    ///     xmin: 0.1,
364    ///     ymin: 0.2,
365    ///     xmax: 0.3,
366    ///     ymax: 0.4,
367    /// };
368    /// let arr: [f32; 4] = bbox.into();
369    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
370    /// ```
371    fn from(b: BoundingBox) -> Self {
372        [b.xmin, b.ymin, b.xmax, b.ymax]
373    }
374}
375
376impl From<[f32; 4]> for BoundingBox {
377    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
378    // BoundingBox
379    fn from(arr: [f32; 4]) -> Self {
380        BoundingBox {
381            xmin: arr[0],
382            ymin: arr[1],
383            xmax: arr[2],
384            ymax: arr[3],
385        }
386    }
387}
388
389impl DetectBox {
390    /// Returns true if one detect box is equal to another detect box, within
391    /// the given `eps`
392    ///
393    /// # Examples
394    /// ```
395    /// # use edgefirst_decoder::DetectBox;
396    /// let box1 = DetectBox {
397    ///     bbox: edgefirst_decoder::BoundingBox {
398    ///         xmin: 0.1,
399    ///         ymin: 0.2,
400    ///         xmax: 0.3,
401    ///         ymax: 0.4,
402    ///     },
403    ///     score: 0.5,
404    ///     label: 1,
405    /// };
406    /// let box2 = DetectBox {
407    ///     bbox: edgefirst_decoder::BoundingBox {
408    ///         xmin: 0.101,
409    ///         ymin: 0.199,
410    ///         xmax: 0.301,
411    ///         ymax: 0.399,
412    ///     },
413    ///     score: 0.510,
414    ///     label: 1,
415    /// };
416    /// assert!(box1.equal_within_delta(&box2, 0.011));
417    /// ```
418    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420        self.label == rhs.label
421            && eq_delta(self.score, rhs.score)
422            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426    }
427}
428
/// A segmentation result with a segmentation mask, and a normalized bounding
/// box representing the area that the segmentation mask covers.
///
/// The box is stored as four separate fields rather than a [`BoundingBox`].
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// left-most normalized coordinate of the segmentation box
    pub xmin: f32,
    /// top-most normalized coordinate of the segmentation box
    pub ymin: f32,
    /// right-most normalized coordinate of the segmentation box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the segmentation box
    pub ymax: f32,
    /// 3D segmentation array of shape `(H, W, C)`.
    ///
    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
    /// continuous sigmoid confidence values quantized to u8 (0 = background,
    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
    /// 0.5) or use smooth interpolation for anti-aliased edges.
    ///
    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
    /// class scores where the object class is the argmax index.
    ///
    /// `segmentation_to_mask` collapses this array into a 2D `(H, W)` mask.
    pub segmentation: Array3<u8>,
}
452
/// Prototype tensor variants for fused decode+render pipelines.
///
/// Carries either raw quantized data (to skip CPU dequantization and let the
/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
/// paths). Both variants hold a 3D array of shape
/// `(proto_h, proto_w, num_protos)`.
#[derive(Debug, Clone)]
pub enum ProtoTensor {
    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
    /// The GPU fragment shader will dequantize per-texel using the scale and
    /// zero_point.
    Quantized {
        /// raw int8 prototype values
        protos: Array3<i8>,
        /// parameters mapping each value `v` to `(v - zero_point) * scale`
        quantization: Quantization,
    },
    /// Dequantized f32 protos (from float models or legacy path).
    Float(Array3<f32>),
}
470
471impl ProtoTensor {
472    /// Returns `true` if this is the quantized variant.
473    pub fn is_quantized(&self) -> bool {
474        matches!(self, ProtoTensor::Quantized { .. })
475    }
476
477    /// Returns the spatial dimensions `(height, width, num_protos)`.
478    pub fn dim(&self) -> (usize, usize, usize) {
479        match self {
480            ProtoTensor::Quantized { protos, .. } => protos.dim(),
481            ProtoTensor::Float(arr) => arr.dim(),
482        }
483    }
484
485    /// Returns dequantized f32 protos. For the `Float` variant this is a
486    /// no-copy reference; for `Quantized` it allocates and dequantizes.
487    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488        match self {
489            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490            ProtoTensor::Quantized {
491                protos,
492                quantization,
493            } => {
494                let scale = quantization.scale;
495                let zp = quantization.zero_point as f32;
496                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497            }
498        }
499    }
500}
501
/// Raw prototype data for fused decode+render pipelines.
///
/// Holds post-NMS intermediate state before mask materialization, allowing the
/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
/// shader) without materializing intermediate `Array3<u8>` masks.
///
/// NOTE(review): presumably `mask_coefficients[i]` belongs to the i-th
/// surviving detection — confirm against the producer of this struct.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
    pub protos: ProtoTensor,
}
514
515/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
516///
517///  # Examples
518/// ```
519/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
520/// let quant = Quantization::new(0.1, -128);
521/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
522/// let detect_quant = DetectBoxQuantized {
523///     bbox,
524///     score: 100_i8,
525///     label: 1,
526/// };
527/// let detect = dequant_detect_box(&detect_quant, quant);
528/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
529/// assert_eq!(detect.label, 1);
530/// assert_eq!(detect.bbox, bbox);
531/// ```
532pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533    detect: &DetectBoxQuantized<SCORE>,
534    quant_scores: Quantization,
535) -> DetectBox {
536    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537    DetectBox {
538        bbox: detect.bbox,
539        score: quant_scores.scale * detect.score.as_() + scaled_zp,
540        label: detect.label,
541    }
542}
543/// A detection box with a f32 bbox and quantized score
544#[derive(Debug, Clone, Copy, PartialEq)]
545pub struct DetectBoxQuantized<
546    // BOX: Signed + PrimInt + AsPrimitive<f32>,
547    SCORE: PrimInt + AsPrimitive<f32>,
548> {
549    // pub bbox: BoundingBoxQuantized<BOX>,
550    pub bbox: BoundingBox,
551    /// model-specific score for this detection, higher implies more
552    /// confidence.
553    pub score: SCORE,
554    /// label index for this detect
555    pub label: usize,
556}
557
558/// Dequantizes an ndarray from quantized values to f32 values using the given
559/// quantization parameters
560///
561/// # Examples
562/// ```
563/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
564/// let quant = Quantization::new(0.1, -128);
565/// let input: Vec<i8> = vec![0, 127, -128, 64];
566/// let input_array = ndarray::Array1::from(input);
567/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
568/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
569/// ```
570pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571    input: ArrayView<T, D>,
572    quant: Quantization,
573) -> Array<F, D>
574where
575    i32: num_traits::AsPrimitive<F>,
576    f32: num_traits::AsPrimitive<F>,
577{
578    let zero_point = quant.zero_point.as_();
579    let scale = quant.scale.as_();
580    if zero_point != F::zero() {
581        let scaled_zero = -zero_point * scale;
582        input.mapv(|d| d.as_() * scale + scaled_zero)
583    } else {
584        input.mapv(|d| d.as_() * scale)
585    }
586}
587
588/// Dequantizes a slice from quantized values to float values using the given
589/// quantization parameters
590///
591/// # Examples
592/// ```
593/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
594/// let quant = Quantization::new(0.1, -128);
595/// let input: Vec<i8> = vec![0, 127, -128, 64];
596/// let mut output: Vec<f32> = vec![0.0; input.len()];
597/// dequantize_cpu(&input, quant, &mut output);
598/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
599/// ```
600pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601    input: &[T],
602    quant: Quantization,
603    output: &mut [F],
604) where
605    f32: num_traits::AsPrimitive<F>,
606    i32: num_traits::AsPrimitive<F>,
607{
608    assert!(input.len() == output.len());
609    let zero_point = quant.zero_point.as_();
610    let scale = quant.scale.as_();
611    if zero_point != F::zero() {
612        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
613        input
614            .iter()
615            .zip(output)
616            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617    } else {
618        input
619            .iter()
620            .zip(output)
621            .for_each(|(d, deq)| *deq = d.as_() * scale);
622    }
623}
624
625/// Dequantizes a slice from quantized values to float values using the given
626/// quantization parameters, using chunked processing. This is around 5% faster
627/// than `dequantize_cpu` for large slices.
628///
629/// # Examples
630/// ```
631/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
632/// let quant = Quantization::new(0.1, -128);
633/// let input: Vec<i8> = vec![0, 127, -128, 64];
634/// let mut output: Vec<f32> = vec![0.0; input.len()];
635/// dequantize_cpu_chunked(&input, quant, &mut output);
636/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
637/// ```
638pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639    input: &[T],
640    quant: Quantization,
641    output: &mut [F],
642) where
643    f32: num_traits::AsPrimitive<F>,
644    i32: num_traits::AsPrimitive<F>,
645{
646    assert!(input.len() == output.len());
647    let zero_point = quant.zero_point.as_();
648    let scale = quant.scale.as_();
649
650    let input = input.as_chunks::<4>();
651    let output = output.as_chunks_mut::<4>();
652
653    if zero_point != F::zero() {
654        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
655
656        input
657            .0
658            .iter()
659            .zip(output.0)
660            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661        input
662            .1
663            .iter()
664            .zip(output.1)
665            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666    } else {
667        input
668            .0
669            .iter()
670            .zip(output.0)
671            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672        input
673            .1
674            .iter()
675            .zip(output.1)
676            .for_each(|(d, deq)| *deq = d.as_() * scale);
677    }
678}
679
680/// Converts a segmentation tensor into a 2D mask
681/// If the last dimension of the segmentation tensor is 1, values equal or
682/// above 128 are considered objects. Otherwise the object is the argmax index
683///
684/// # Errors
685///
686/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
687/// invalid shape.
688///
689/// # Examples
690/// ```
691/// # use edgefirst_decoder::segmentation_to_mask;
692/// let segmentation =
693///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
694/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
695/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
696/// ```
697pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698    if segmentation.shape()[2] == 0 {
699        return Err(DecoderError::InvalidShape(
700            "Segmentation tensor must have non-zero depth".to_string(),
701        ));
702    }
703    if segmentation.shape()[2] == 1 {
704        yolo_segmentation_to_mask(segmentation, 128)
705    } else {
706        Ok(modelpack_segmentation_to_mask(segmentation))
707    }
708}
709
710/// Returns the maximum value and its index from a 1D array
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712    score
713        .iter()
714        .enumerate()
715        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716            if max > *s {
717                (max, arg_max)
718            } else {
719                (*s, ind)
720            }
721        })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726    #![allow(clippy::excessive_precision)]
727    use crate::{
728        configs::{DecoderType, DimName, Protos},
729        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730        yolo::{
731            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732            decode_yolo_segdet_quant,
733        },
734        *,
735    };
736    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
737    use ndarray::Dimension;
738    use ndarray::{array, s, Array2, Array3, Array4, Axis};
739    use ndarray_stats::DeviationExt;
740    use num_traits::{AsPrimitive, PrimInt};
741
742    fn compare_outputs(
743        boxes: (&[DetectBox], &[DetectBox]),
744        masks: (&[Segmentation], &[Segmentation]),
745    ) {
746        let (boxes0, boxes1) = boxes;
747        let (masks0, masks1) = masks;
748
749        assert_eq!(boxes0.len(), boxes1.len());
750        assert_eq!(masks0.len(), masks1.len());
751
752        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
753            assert!(
754                b_i8.equal_within_delta(b_f32, 1e-6),
755                "{b_i8:?} is not equal to {b_f32:?}"
756            );
757        }
758
759        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
760            assert_eq!(
761                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
762                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
763            );
764            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
765            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
766            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
767            let diff = &mask_i8 - &mask_f32;
768            for x in 0..diff.shape()[0] {
769                for y in 0..diff.shape()[1] {
770                    for z in 0..diff.shape()[2] {
771                        let val = diff[[x, y, z]];
772                        assert!(
773                            val.abs() <= 1,
774                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
775                            x,
776                            y,
777                            z,
778                            val
779                        );
780                    }
781                }
782            }
783            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
784            assert!(
785                mean_sq_err < 1e-2,
786                "Mean Square Error between masks was greater than 1%: {:.2}%",
787                mean_sq_err * 100.0
788            );
789        }
790    }
791
792    // ─── Shared test data loaders ────────────────────────
793
794    fn load_yolov8_boxes() -> Array3<i8> {
795        let raw = include_bytes!(concat!(
796            env!("CARGO_MANIFEST_DIR"),
797            "/../../testdata/yolov8_boxes_116x8400.bin"
798        ));
799        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
800        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
801    }
802
803    fn load_yolov8_protos() -> Array4<i8> {
804        let raw = include_bytes!(concat!(
805            env!("CARGO_MANIFEST_DIR"),
806            "/../../testdata/yolov8_protos_160x160x32.bin"
807        ));
808        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
809        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
810    }
811
812    fn load_yolov8s_det() -> Array3<i8> {
813        let raw = include_bytes!(concat!(
814            env!("CARGO_MANIFEST_DIR"),
815            "/../../testdata/yolov8s_80_classes.bin"
816        ));
817        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
818        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
819    }
820
    /// Checks that the high-level `Decoder` reproduces the low-level
    /// `decode_modelpack_det` result for a ModelPack model with separate boxes and
    /// scores outputs, on both the quantized and the dequantized (float) paths.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Raw quantized fixture bytes, reshaped to the model's boxes output shape.
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // Presumably (scale, zero_point) tuples; the target type is inferred from
        // the `quantization: Some(..)` config fields below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-convert the config tuples into the type the low-level decoders and
        // `dequantize_ndarray` take (inferred from their use below).
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result: call the low-level decoder directly on batch 0, dropping
        // the padding axis of the boxes tensor.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        // Quantized path through the high-level decoder.
        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Float path: dequantize the same tensors and decode again.
        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Both decoder paths must match the reference boxes. A detection-only model
        // is expected to produce no masks (compared against an empty slice).
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
926
927    #[test]
928    fn test_decoder_modelpack_split_u8() {
929        let score_threshold = 0.45;
930        let iou_threshold = 0.45;
931        let detect0 = include_bytes!(concat!(
932            env!("CARGO_MANIFEST_DIR"),
933            "/../../testdata/modelpack_split_9x15x18.bin"
934        ));
935        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
936
937        let detect1 = include_bytes!(concat!(
938            env!("CARGO_MANIFEST_DIR"),
939            "/../../testdata/modelpack_split_17x30x18.bin"
940        ));
941        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
942
943        let quant0 = (0.08547406643629074, 174).into();
944        let quant1 = (0.09929127991199493, 183).into();
945        let anchors0 = vec![
946            [0.36666667461395264, 0.31481480598449707],
947            [0.38749998807907104, 0.4740740656852722],
948            [0.5333333611488342, 0.644444465637207],
949        ];
950        let anchors1 = vec![
951            [0.13750000298023224, 0.2074074000120163],
952            [0.2541666626930237, 0.21481481194496155],
953            [0.23125000298023224, 0.35185185074806213],
954        ];
955
956        let detect_config0 = configs::Detection {
957            decoder: DecoderType::ModelPack,
958            shape: vec![1, 9, 15, 18],
959            anchors: Some(anchors0.clone()),
960            quantization: Some(quant0),
961            dshape: vec![
962                (DimName::Batch, 1),
963                (DimName::Height, 9),
964                (DimName::Width, 15),
965                (DimName::NumAnchorsXFeatures, 18),
966            ],
967            normalized: Some(true),
968        };
969
970        let detect_config1 = configs::Detection {
971            decoder: DecoderType::ModelPack,
972            shape: vec![1, 17, 30, 18],
973            anchors: Some(anchors1.clone()),
974            quantization: Some(quant1),
975            dshape: vec![
976                (DimName::Batch, 1),
977                (DimName::Height, 17),
978                (DimName::Width, 30),
979                (DimName::NumAnchorsXFeatures, 18),
980            ],
981            normalized: Some(true),
982        };
983
984        let config0 = (&detect_config0).try_into().unwrap();
985        let config1 = (&detect_config1).try_into().unwrap();
986
987        let decoder = DecoderBuilder::default()
988            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
989            .with_score_threshold(score_threshold)
990            .with_iou_threshold(iou_threshold)
991            .build()
992            .unwrap();
993
994        let quant0 = quant0.into();
995        let quant1 = quant1.into();
996
997        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
998        decode_modelpack_split_quant(
999            &[
1000                detect0.slice(s![0, .., .., ..]),
1001                detect1.slice(s![0, .., .., ..]),
1002            ],
1003            &[config0, config1],
1004            score_threshold,
1005            iou_threshold,
1006            &mut output_boxes,
1007        );
1008        assert!(output_boxes[0].equal_within_delta(
1009            &DetectBox {
1010                bbox: BoundingBox {
1011                    xmin: 0.43171933,
1012                    ymin: 0.68243736,
1013                    xmax: 0.5626645,
1014                    ymax: 0.808863,
1015                },
1016                score: 0.99240804,
1017                label: 0
1018            },
1019            1e-6
1020        ));
1021
1022        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1023        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1024        decoder
1025            .decode_quantized(
1026                &[detect0.view().into(), detect1.view().into()],
1027                &mut output_boxes1,
1028                &mut output_masks1,
1029            )
1030            .unwrap();
1031
1032        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1033        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1034
1035        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1036        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1037        decoder
1038            .decode_float::<f32>(
1039                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1040                &mut output_boxes1_f32,
1041                &mut output_masks1_f32,
1042            )
1043            .unwrap();
1044
1045        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1046        compare_outputs(
1047            (&output_boxes, &output_boxes1_f32),
1048            (&[], &output_masks1_f32),
1049        );
1050    }
1051
1052    #[test]
1053    fn test_decoder_parse_config_modelpack_split_u8() {
1054        let score_threshold = 0.45;
1055        let iou_threshold = 0.45;
1056        let detect0 = include_bytes!(concat!(
1057            env!("CARGO_MANIFEST_DIR"),
1058            "/../../testdata/modelpack_split_9x15x18.bin"
1059        ));
1060        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1061
1062        let detect1 = include_bytes!(concat!(
1063            env!("CARGO_MANIFEST_DIR"),
1064            "/../../testdata/modelpack_split_17x30x18.bin"
1065        ));
1066        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1067
1068        let decoder = DecoderBuilder::default()
1069            .with_config_yaml_str(
1070                include_str!(concat!(
1071                    env!("CARGO_MANIFEST_DIR"),
1072                    "/../../testdata/modelpack_split.yaml"
1073                ))
1074                .to_string(),
1075            )
1076            .with_score_threshold(score_threshold)
1077            .with_iou_threshold(iou_threshold)
1078            .build()
1079            .unwrap();
1080
1081        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1082        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1083        decoder
1084            .decode_quantized(
1085                &[
1086                    ArrayViewDQuantized::from(detect1.view()),
1087                    ArrayViewDQuantized::from(detect0.view()),
1088                ],
1089                &mut output_boxes,
1090                &mut output_masks,
1091            )
1092            .unwrap();
1093        assert!(output_boxes[0].equal_within_delta(
1094            &DetectBox {
1095                bbox: BoundingBox {
1096                    xmin: 0.43171933,
1097                    ymin: 0.68243736,
1098                    xmax: 0.5626645,
1099                    ymax: 0.808863,
1100                },
1101                score: 0.99240804,
1102                label: 0
1103            },
1104            1e-6
1105        ));
1106    }
1107
1108    #[test]
1109    fn test_modelpack_seg() {
1110        let out = include_bytes!(concat!(
1111            env!("CARGO_MANIFEST_DIR"),
1112            "/../../testdata/modelpack_seg_2x160x160.bin"
1113        ));
1114        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1115        let quant = (1.0 / 255.0, 0).into();
1116
1117        let decoder = DecoderBuilder::default()
1118            .with_config_modelpack_seg(configs::Segmentation {
1119                decoder: DecoderType::ModelPack,
1120                quantization: Some(quant),
1121                shape: vec![1, 2, 160, 160],
1122                dshape: vec![
1123                    (DimName::Batch, 1),
1124                    (DimName::NumClasses, 2),
1125                    (DimName::Height, 160),
1126                    (DimName::Width, 160),
1127                ],
1128            })
1129            .build()
1130            .unwrap();
1131        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1132        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1133        decoder
1134            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1135            .unwrap();
1136
1137        let mut mask = out.slice(s![0, .., .., ..]);
1138        mask.swap_axes(0, 1);
1139        mask.swap_axes(1, 2);
1140        let mask = [Segmentation {
1141            xmin: 0.0,
1142            ymin: 0.0,
1143            xmax: 1.0,
1144            ymax: 1.0,
1145            segmentation: mask.into_owned(),
1146        }];
1147        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1148
1149        decoder
1150            .decode_float::<f32>(
1151                &[dequantize_ndarray(out.view(), quant.into())
1152                    .view()
1153                    .into_dyn()],
1154                &mut output_boxes,
1155                &mut output_masks,
1156            )
1157            .unwrap();
1158
1159        // not expected for float decoder to have same values as quantized decoder, as
1160        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1161        // the model output. Thus the float output is the same as the quantized output
1162        // but scaled differently. However, it is expected that the mask after argmax
1163        // will be the same.
1164        compare_outputs((&[], &output_boxes), (&[], &[]));
1165        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1166        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1167
1168        assert_eq!(mask0, mask1);
1169    }
1170    #[test]
1171    fn test_modelpack_seg_quant() {
1172        let out = include_bytes!(concat!(
1173            env!("CARGO_MANIFEST_DIR"),
1174            "/../../testdata/modelpack_seg_2x160x160.bin"
1175        ));
1176        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1177        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1178        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1179        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1180        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1181        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1182
1183        let quant = (1.0 / 255.0, 0).into();
1184
1185        let decoder = DecoderBuilder::default()
1186            .with_config_modelpack_seg(configs::Segmentation {
1187                decoder: DecoderType::ModelPack,
1188                quantization: Some(quant),
1189                shape: vec![1, 2, 160, 160],
1190                dshape: vec![
1191                    (DimName::Batch, 1),
1192                    (DimName::NumClasses, 2),
1193                    (DimName::Height, 160),
1194                    (DimName::Width, 160),
1195                ],
1196            })
1197            .build()
1198            .unwrap();
1199        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1200        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1201        decoder
1202            .decode_quantized(
1203                &[out_u8.view().into()],
1204                &mut output_boxes,
1205                &mut output_masks_u8,
1206            )
1207            .unwrap();
1208
1209        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1210        decoder
1211            .decode_quantized(
1212                &[out_i8.view().into()],
1213                &mut output_boxes,
1214                &mut output_masks_i8,
1215            )
1216            .unwrap();
1217
1218        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1219        decoder
1220            .decode_quantized(
1221                &[out_u16.view().into()],
1222                &mut output_boxes,
1223                &mut output_masks_u16,
1224            )
1225            .unwrap();
1226
1227        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1228        decoder
1229            .decode_quantized(
1230                &[out_i16.view().into()],
1231                &mut output_boxes,
1232                &mut output_masks_i16,
1233            )
1234            .unwrap();
1235
1236        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1237        decoder
1238            .decode_quantized(
1239                &[out_u32.view().into()],
1240                &mut output_boxes,
1241                &mut output_masks_u32,
1242            )
1243            .unwrap();
1244
1245        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1246        decoder
1247            .decode_quantized(
1248                &[out_i32.view().into()],
1249                &mut output_boxes,
1250                &mut output_masks_i32,
1251            )
1252            .unwrap();
1253
1254        compare_outputs((&[], &output_boxes), (&[], &[]));
1255        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1256        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1257        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1258        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1259        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1260        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1261        assert_eq!(mask_u8, mask_i8);
1262        assert_eq!(mask_u8, mask_u16);
1263        assert_eq!(mask_u8, mask_i16);
1264        assert_eq!(mask_u8, mask_u32);
1265        assert_eq!(mask_u8, mask_i32);
1266    }
1267
1268    #[test]
1269    fn test_modelpack_segdet() {
1270        let score_threshold = 0.45;
1271        let iou_threshold = 0.45;
1272
1273        let boxes = include_bytes!(concat!(
1274            env!("CARGO_MANIFEST_DIR"),
1275            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1276        ));
1277        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1278
1279        let scores = include_bytes!(concat!(
1280            env!("CARGO_MANIFEST_DIR"),
1281            "/../../testdata/modelpack_scores_1935x1.bin"
1282        ));
1283        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1284
1285        let seg = include_bytes!(concat!(
1286            env!("CARGO_MANIFEST_DIR"),
1287            "/../../testdata/modelpack_seg_2x160x160.bin"
1288        ));
1289        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1290
1291        let quant_boxes = (0.004656755365431309, 21).into();
1292        let quant_scores = (0.0019603664986789227, 0).into();
1293        let quant_seg = (1.0 / 255.0, 0).into();
1294
1295        let decoder = DecoderBuilder::default()
1296            .with_config_modelpack_segdet(
1297                configs::Boxes {
1298                    decoder: DecoderType::ModelPack,
1299                    quantization: Some(quant_boxes),
1300                    shape: vec![1, 1935, 1, 4],
1301                    dshape: vec![
1302                        (DimName::Batch, 1),
1303                        (DimName::NumBoxes, 1935),
1304                        (DimName::Padding, 1),
1305                        (DimName::BoxCoords, 4),
1306                    ],
1307                    normalized: Some(true),
1308                },
1309                configs::Scores {
1310                    decoder: DecoderType::ModelPack,
1311                    quantization: Some(quant_scores),
1312                    shape: vec![1, 1935, 1],
1313                    dshape: vec![
1314                        (DimName::Batch, 1),
1315                        (DimName::NumBoxes, 1935),
1316                        (DimName::NumClasses, 1),
1317                    ],
1318                },
1319                configs::Segmentation {
1320                    decoder: DecoderType::ModelPack,
1321                    quantization: Some(quant_seg),
1322                    shape: vec![1, 2, 160, 160],
1323                    dshape: vec![
1324                        (DimName::Batch, 1),
1325                        (DimName::NumClasses, 2),
1326                        (DimName::Height, 160),
1327                        (DimName::Width, 160),
1328                    ],
1329                },
1330            )
1331            .with_iou_threshold(iou_threshold)
1332            .with_score_threshold(score_threshold)
1333            .build()
1334            .unwrap();
1335        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1336        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1337        decoder
1338            .decode_quantized(
1339                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1340                &mut output_boxes,
1341                &mut output_masks,
1342            )
1343            .unwrap();
1344
1345        let mut mask = seg.slice(s![0, .., .., ..]);
1346        mask.swap_axes(0, 1);
1347        mask.swap_axes(1, 2);
1348        let mask = [Segmentation {
1349            xmin: 0.0,
1350            ymin: 0.0,
1351            xmax: 1.0,
1352            ymax: 1.0,
1353            segmentation: mask.into_owned(),
1354        }];
1355        let correct_boxes = [DetectBox {
1356            bbox: BoundingBox {
1357                xmin: 0.40513772,
1358                ymin: 0.6379755,
1359                xmax: 0.5122431,
1360                ymax: 0.7730214,
1361            },
1362            score: 0.4861709,
1363            label: 0,
1364        }];
1365        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1366
1367        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1368        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1369        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1370        decoder
1371            .decode_float::<f32>(
1372                &[
1373                    scores.view().into_dyn(),
1374                    boxes.view().into_dyn(),
1375                    seg.view().into_dyn(),
1376                ],
1377                &mut output_boxes,
1378                &mut output_masks,
1379            )
1380            .unwrap();
1381
1382        // not expected for float segmentation decoder to have same values as quantized
1383        // segmentation decoder, as float decoder ensures the data fills 0-255,
1384        // quantized decoder uses whatever the model output. Thus the float
1385        // output is the same as the quantized output but scaled differently.
1386        // However, it is expected that the mask after argmax will be the same.
1387        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1388        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1389        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1390
1391        assert_eq!(mask0, mask1);
1392    }
1393
1394    #[test]
1395    fn test_modelpack_segdet_split() {
1396        let score_threshold = 0.8;
1397        let iou_threshold = 0.5;
1398
1399        let seg = include_bytes!(concat!(
1400            env!("CARGO_MANIFEST_DIR"),
1401            "/../../testdata/modelpack_seg_2x160x160.bin"
1402        ));
1403        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1404
1405        let detect0 = include_bytes!(concat!(
1406            env!("CARGO_MANIFEST_DIR"),
1407            "/../../testdata/modelpack_split_9x15x18.bin"
1408        ));
1409        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1410
1411        let detect1 = include_bytes!(concat!(
1412            env!("CARGO_MANIFEST_DIR"),
1413            "/../../testdata/modelpack_split_17x30x18.bin"
1414        ));
1415        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1416
1417        let quant0 = (0.08547406643629074, 174).into();
1418        let quant1 = (0.09929127991199493, 183).into();
1419        let quant_seg = (1.0 / 255.0, 0).into();
1420
1421        let anchors0 = vec![
1422            [0.36666667461395264, 0.31481480598449707],
1423            [0.38749998807907104, 0.4740740656852722],
1424            [0.5333333611488342, 0.644444465637207],
1425        ];
1426        let anchors1 = vec![
1427            [0.13750000298023224, 0.2074074000120163],
1428            [0.2541666626930237, 0.21481481194496155],
1429            [0.23125000298023224, 0.35185185074806213],
1430        ];
1431
1432        let decoder = DecoderBuilder::default()
1433            .with_config_modelpack_segdet_split(
1434                vec![
1435                    configs::Detection {
1436                        decoder: DecoderType::ModelPack,
1437                        shape: vec![1, 17, 30, 18],
1438                        anchors: Some(anchors1),
1439                        quantization: Some(quant1),
1440                        dshape: vec![
1441                            (DimName::Batch, 1),
1442                            (DimName::Height, 17),
1443                            (DimName::Width, 30),
1444                            (DimName::NumAnchorsXFeatures, 18),
1445                        ],
1446                        normalized: Some(true),
1447                    },
1448                    configs::Detection {
1449                        decoder: DecoderType::ModelPack,
1450                        shape: vec![1, 9, 15, 18],
1451                        anchors: Some(anchors0),
1452                        quantization: Some(quant0),
1453                        dshape: vec![
1454                            (DimName::Batch, 1),
1455                            (DimName::Height, 9),
1456                            (DimName::Width, 15),
1457                            (DimName::NumAnchorsXFeatures, 18),
1458                        ],
1459                        normalized: Some(true),
1460                    },
1461                ],
1462                configs::Segmentation {
1463                    decoder: DecoderType::ModelPack,
1464                    quantization: Some(quant_seg),
1465                    shape: vec![1, 2, 160, 160],
1466                    dshape: vec![
1467                        (DimName::Batch, 1),
1468                        (DimName::NumClasses, 2),
1469                        (DimName::Height, 160),
1470                        (DimName::Width, 160),
1471                    ],
1472                },
1473            )
1474            .with_score_threshold(score_threshold)
1475            .with_iou_threshold(iou_threshold)
1476            .build()
1477            .unwrap();
1478        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1479        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1480        decoder
1481            .decode_quantized(
1482                &[
1483                    detect0.view().into(),
1484                    detect1.view().into(),
1485                    seg.view().into(),
1486                ],
1487                &mut output_boxes,
1488                &mut output_masks,
1489            )
1490            .unwrap();
1491
1492        let mut mask = seg.slice(s![0, .., .., ..]);
1493        mask.swap_axes(0, 1);
1494        mask.swap_axes(1, 2);
1495        let mask = [Segmentation {
1496            xmin: 0.0,
1497            ymin: 0.0,
1498            xmax: 1.0,
1499            ymax: 1.0,
1500            segmentation: mask.into_owned(),
1501        }];
1502        let correct_boxes = [DetectBox {
1503            bbox: BoundingBox {
1504                xmin: 0.43171933,
1505                ymin: 0.68243736,
1506                xmax: 0.5626645,
1507                ymax: 0.808863,
1508            },
1509            score: 0.99240804,
1510            label: 0,
1511        }];
1512        println!("Output Boxes: {:?}", output_boxes);
1513        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1514
1515        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1516        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1517        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1518        decoder
1519            .decode_float::<f32>(
1520                &[
1521                    detect0.view().into_dyn(),
1522                    detect1.view().into_dyn(),
1523                    seg.view().into_dyn(),
1524                ],
1525                &mut output_boxes,
1526                &mut output_masks,
1527            )
1528            .unwrap();
1529
1530        // not expected for float segmentation decoder to have same values as quantized
1531        // segmentation decoder, as float decoder ensures the data fills 0-255,
1532        // quantized decoder uses whatever the model output. Thus the float
1533        // output is the same as the quantized output but scaled differently.
1534        // However, it is expected that the mask after argmax will be the same.
1535        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1536        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1537        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1538
1539        assert_eq!(mask0, mask1);
1540    }
1541
1542    #[test]
1543    fn test_dequant_chunked() {
1544        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1545        out.push(123); // make sure to test non multiple of 16 length
1546
1547        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1548        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1549        let quant = Quantization::new(0.0040811873, -123);
1550        dequantize_cpu(&out, quant, &mut out_dequant);
1551
1552        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1553        assert_eq!(out_dequant, out_dequant_simd);
1554
1555        let quant = Quantization::new(0.0040811873, 0);
1556        dequantize_cpu(&out, quant, &mut out_dequant);
1557
1558        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1559        assert_eq!(out_dequant, out_dequant_simd);
1560    }
1561
1562    #[test]
1563    fn test_dequant_ground_truth() {
1564        // Formula: output = (input - zero_point) * scale
1565        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1566
1567        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1568        let quant = Quantization::new(0.1, -128);
1569        let input: Vec<i8> = vec![0, 127, -128, 64];
1570        let mut output = vec![0.0f32; 4];
1571        let mut output_chunked = vec![0.0f32; 4];
1572        dequantize_cpu(&input, quant, &mut output);
1573        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1574        // (0 - (-128)) * 0.1 = 12.8
1575        // (127 - (-128)) * 0.1 = 25.5
1576        // (-128 - (-128)) * 0.1 = 0.0
1577        // (64 - (-128)) * 0.1 = 19.2
1578        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1579        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1580            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1581        }
1582        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1583            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1584        }
1585
1586        // Case 2: scale=1.0, zero_point=0 (identity-like)
1587        let quant = Quantization::new(1.0, 0);
1588        dequantize_cpu(&input, quant, &mut output);
1589        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1590        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1591        assert_eq!(output, expected);
1592        assert_eq!(output_chunked, expected);
1593
1594        // Case 3: scale=0.5, zero_point=0
1595        let quant = Quantization::new(0.5, 0);
1596        dequantize_cpu(&input, quant, &mut output);
1597        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1598        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1599        assert_eq!(output, expected);
1600        assert_eq!(output_chunked, expected);
1601
1602        // Case 4: i8 min/max boundaries with typical quantization params
1603        let quant = Quantization::new(0.021287762, 31);
1604        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1605        let mut output = vec![0.0f32; 6];
1606        let mut output_chunked = vec![0.0f32; 6];
1607        dequantize_cpu(&input, quant, &mut output);
1608        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1609        for i in 0..6 {
1610            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1611            assert!(
1612                (output[i] - expected).abs() < 1e-5,
1613                "cpu[{i}]: {} != {expected}",
1614                output[i]
1615            );
1616            assert!(
1617                (output_chunked[i] - expected).abs() < 1e-5,
1618                "chunked[{i}]: {} != {expected}",
1619                output_chunked[i]
1620            );
1621        }
1622    }
1623
1624    #[test]
1625    fn test_decoder_yolo_det() {
1626        let score_threshold = 0.25;
1627        let iou_threshold = 0.7;
1628        let out = load_yolov8s_det();
1629        let quant = (0.0040811873, -123).into();
1630
1631        let decoder = DecoderBuilder::default()
1632            .with_config_yolo_det(
1633                configs::Detection {
1634                    decoder: DecoderType::Ultralytics,
1635                    shape: vec![1, 84, 8400],
1636                    anchors: None,
1637                    quantization: Some(quant),
1638                    dshape: vec![
1639                        (DimName::Batch, 1),
1640                        (DimName::NumFeatures, 84),
1641                        (DimName::NumBoxes, 8400),
1642                    ],
1643                    normalized: Some(true),
1644                },
1645                Some(DecoderVersion::Yolo11),
1646            )
1647            .with_score_threshold(score_threshold)
1648            .with_iou_threshold(iou_threshold)
1649            .build()
1650            .unwrap();
1651
1652        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1653        decode_yolo_det(
1654            (out.slice(s![0, .., ..]), quant.into()),
1655            score_threshold,
1656            iou_threshold,
1657            Some(configs::Nms::ClassAgnostic),
1658            &mut output_boxes,
1659        );
1660        assert!(output_boxes[0].equal_within_delta(
1661            &DetectBox {
1662                bbox: BoundingBox {
1663                    xmin: 0.5285137,
1664                    ymin: 0.05305544,
1665                    xmax: 0.87541467,
1666                    ymax: 0.9998909,
1667                },
1668                score: 0.5591227,
1669                label: 0
1670            },
1671            1e-6
1672        ));
1673
1674        assert!(output_boxes[1].equal_within_delta(
1675            &DetectBox {
1676                bbox: BoundingBox {
1677                    xmin: 0.130598,
1678                    ymin: 0.43260583,
1679                    xmax: 0.35098213,
1680                    ymax: 0.9958097,
1681                },
1682                score: 0.33057618,
1683                label: 75
1684            },
1685            1e-6
1686        ));
1687
1688        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1689        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1690        decoder
1691            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1692            .unwrap();
1693
1694        let out = dequantize_ndarray(out.view(), quant.into());
1695        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1696        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1697        decoder
1698            .decode_float::<f32>(
1699                &[out.view().into_dyn()],
1700                &mut output_boxes_f32,
1701                &mut output_masks_f32,
1702            )
1703            .unwrap();
1704
1705        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1706        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1707    }
1708
1709    #[test]
1710    fn test_decoder_masks() {
1711        let score_threshold = 0.45;
1712        let iou_threshold = 0.45;
1713        let boxes = load_yolov8_boxes();
1714        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1715
1716        let protos = load_yolov8_protos();
1717        let quant_protos = Quantization::new(0.02491161972284317, -117);
1718        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1719        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1720        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1721        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1722        decode_yolo_segdet_float(
1723            seg.slice(s![0, .., ..]),
1724            protos.slice(s![0, .., .., ..]),
1725            score_threshold,
1726            iou_threshold,
1727            Some(configs::Nms::ClassAgnostic),
1728            &mut output_boxes,
1729            &mut output_masks,
1730        )
1731        .unwrap();
1732        assert_eq!(output_boxes.len(), 2);
1733        assert_eq!(output_boxes.len(), output_masks.len());
1734
1735        for (b, m) in output_boxes.iter().zip(&output_masks) {
1736            assert!(b.bbox.xmin >= m.xmin);
1737            assert!(b.bbox.ymin >= m.ymin);
1738            assert!(b.bbox.xmax >= m.xmax);
1739            assert!(b.bbox.ymax >= m.ymax);
1740        }
1741        assert!(output_boxes[0].equal_within_delta(
1742            &DetectBox {
1743                bbox: BoundingBox {
1744                    xmin: 0.08515105,
1745                    ymin: 0.7131401,
1746                    xmax: 0.29802868,
1747                    ymax: 0.8195788,
1748                },
1749                score: 0.91537374,
1750                label: 23
1751            },
1752            1.0 / 160.0, // wider range because mask will expand the box
1753        ));
1754
1755        assert!(output_boxes[1].equal_within_delta(
1756            &DetectBox {
1757                bbox: BoundingBox {
1758                    xmin: 0.59605736,
1759                    ymin: 0.25545314,
1760                    xmax: 0.93666154,
1761                    ymax: 0.72378385,
1762                },
1763                score: 0.91537374,
1764                label: 23
1765            },
1766            1.0 / 160.0, // wider range because mask will expand the box
1767        ));
1768
1769        let full_mask = include_bytes!(concat!(
1770            env!("CARGO_MANIFEST_DIR"),
1771            "/../../testdata/yolov8_mask_results.bin"
1772        ));
1773        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1774
1775        let cropped_mask = full_mask.slice(ndarray::s![
1776            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1777            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1778        ]);
1779
1780        assert_eq!(
1781            cropped_mask,
1782            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1783        );
1784    }
1785
1786    /// Regression test: config-driven path with NCHW protos (no dshape).
1787    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1788    /// and the YAML config has no dshape field — the exact scenario from
1789    /// hal_mask_matmul_bug.md.
1790    #[test]
1791    fn test_decoder_masks_nchw_protos() {
1792        let score_threshold = 0.45;
1793        let iou_threshold = 0.45;
1794
1795        // Load test data — boxes as [116, 8400]
1796        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1797        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1798
1799        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1800        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1801        let quant_protos = Quantization::new(0.02491161972284317, -117);
1802        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1803
1804        // ---- Reference: direct call with HWC protos (known working) ----
1805        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1806        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1807        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1808        decode_yolo_segdet_float(
1809            seg.view(),
1810            protos_f32_hwc.view(),
1811            score_threshold,
1812            iou_threshold,
1813            Some(configs::Nms::ClassAgnostic),
1814            &mut ref_boxes,
1815            &mut ref_masks,
1816        )
1817        .unwrap();
1818        assert_eq!(ref_boxes.len(), 2);
1819
1820        // ---- Config-driven path: NCHW protos, no dshape ----
1821        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1822        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1823        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1824
1825        // Build boxes as [1, 116, 8400] f32
1826        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1827
1828        // Build decoder from config with no dshape on protos
1829        let decoder = DecoderBuilder::default()
1830            .with_config_yolo_segdet(
1831                configs::Detection {
1832                    decoder: configs::DecoderType::Ultralytics,
1833                    quantization: None,
1834                    shape: vec![1, 116, 8400],
1835                    dshape: vec![],
1836                    normalized: Some(true),
1837                    anchors: None,
1838                },
1839                configs::Protos {
1840                    decoder: configs::DecoderType::Ultralytics,
1841                    quantization: None,
1842                    shape: vec![1, 32, 160, 160],
1843                    dshape: vec![], // No dshape — simulates YAML without dshape
1844                },
1845                None, // decoder version
1846            )
1847            .with_score_threshold(score_threshold)
1848            .with_iou_threshold(iou_threshold)
1849            .build()
1850            .unwrap();
1851
1852        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1853        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1854        decoder
1855            .decode_float(
1856                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1857                &mut cfg_boxes,
1858                &mut cfg_masks,
1859            )
1860            .unwrap();
1861
1862        // Must produce the same number of detections
1863        assert_eq!(
1864            cfg_boxes.len(),
1865            ref_boxes.len(),
1866            "config path produced {} boxes, reference produced {}",
1867            cfg_boxes.len(),
1868            ref_boxes.len()
1869        );
1870
1871        // Boxes must match
1872        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1873            assert!(
1874                cb.equal_within_delta(rb, 0.01),
1875                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1876            );
1877        }
1878
1879        // Masks must match pixel-for-pixel
1880        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1881            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1882            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1883            assert_eq!(
1884                cm_arr, rm_arr,
1885                "mask {i} pixel mismatch between config-driven and reference paths"
1886            );
1887        }
1888    }
1889
1890    #[test]
1891    fn test_decoder_masks_i8() {
1892        let score_threshold = 0.45;
1893        let iou_threshold = 0.45;
1894        let boxes = load_yolov8_boxes();
1895        let quant_boxes = (0.021287761628627777, 31).into();
1896
1897        let protos = load_yolov8_protos();
1898        let quant_protos = (0.02491161972284317, -117).into();
1899        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1900        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1901
1902        let decoder = DecoderBuilder::default()
1903            .with_config_yolo_segdet(
1904                configs::Detection {
1905                    decoder: configs::DecoderType::Ultralytics,
1906                    quantization: Some(quant_boxes),
1907                    shape: vec![1, 116, 8400],
1908                    anchors: None,
1909                    dshape: vec![
1910                        (DimName::Batch, 1),
1911                        (DimName::NumFeatures, 116),
1912                        (DimName::NumBoxes, 8400),
1913                    ],
1914                    normalized: Some(true),
1915                },
1916                Protos {
1917                    decoder: configs::DecoderType::Ultralytics,
1918                    quantization: Some(quant_protos),
1919                    shape: vec![1, 160, 160, 32],
1920                    dshape: vec![
1921                        (DimName::Batch, 1),
1922                        (DimName::Height, 160),
1923                        (DimName::Width, 160),
1924                        (DimName::NumProtos, 32),
1925                    ],
1926                },
1927                Some(DecoderVersion::Yolo11),
1928            )
1929            .with_score_threshold(score_threshold)
1930            .with_iou_threshold(iou_threshold)
1931            .build()
1932            .unwrap();
1933
1934        let quant_boxes = quant_boxes.into();
1935        let quant_protos = quant_protos.into();
1936
1937        decode_yolo_segdet_quant(
1938            (boxes.slice(s![0, .., ..]), quant_boxes),
1939            (protos.slice(s![0, .., .., ..]), quant_protos),
1940            score_threshold,
1941            iou_threshold,
1942            Some(configs::Nms::ClassAgnostic),
1943            &mut output_boxes,
1944            &mut output_masks,
1945        )
1946        .unwrap();
1947
1948        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1949        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1950
1951        decoder
1952            .decode_quantized(
1953                &[boxes.view().into(), protos.view().into()],
1954                &mut output_boxes1,
1955                &mut output_masks1,
1956            )
1957            .unwrap();
1958
1959        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1960        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1961
1962        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1963        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1964        decode_yolo_segdet_float(
1965            seg.slice(s![0, .., ..]),
1966            protos.slice(s![0, .., .., ..]),
1967            score_threshold,
1968            iou_threshold,
1969            Some(configs::Nms::ClassAgnostic),
1970            &mut output_boxes_f32,
1971            &mut output_masks_f32,
1972        )
1973        .unwrap();
1974
1975        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1976        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1977
1978        decoder
1979            .decode_float(
1980                &[seg.view().into_dyn(), protos.view().into_dyn()],
1981                &mut output_boxes1_f32,
1982                &mut output_masks1_f32,
1983            )
1984            .unwrap();
1985
1986        compare_outputs(
1987            (&output_boxes, &output_boxes1),
1988            (&output_masks, &output_masks1),
1989        );
1990
1991        compare_outputs(
1992            (&output_boxes, &output_boxes_f32),
1993            (&output_masks, &output_masks_f32),
1994        );
1995
1996        compare_outputs(
1997            (&output_boxes_f32, &output_boxes1_f32),
1998            (&output_masks_f32, &output_masks1_f32),
1999        );
2000    }
2001
2002    #[test]
2003    fn test_decoder_yolo_split() {
2004        let score_threshold = 0.45;
2005        let iou_threshold = 0.45;
2006        let boxes = load_yolov8_boxes();
2007        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2008        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2009
2010        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2011
2012        let decoder = DecoderBuilder::default()
2013            .with_config_yolo_split_det(
2014                configs::Boxes {
2015                    decoder: configs::DecoderType::Ultralytics,
2016                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2017                    shape: vec![1, 4, 8400],
2018                    dshape: vec![
2019                        (DimName::Batch, 1),
2020                        (DimName::BoxCoords, 4),
2021                        (DimName::NumBoxes, 8400),
2022                    ],
2023                    normalized: Some(true),
2024                },
2025                configs::Scores {
2026                    decoder: configs::DecoderType::Ultralytics,
2027                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2028                    shape: vec![1, 80, 8400],
2029                    dshape: vec![
2030                        (DimName::Batch, 1),
2031                        (DimName::NumClasses, 80),
2032                        (DimName::NumBoxes, 8400),
2033                    ],
2034                },
2035            )
2036            .with_score_threshold(score_threshold)
2037            .with_iou_threshold(iou_threshold)
2038            .build()
2039            .unwrap();
2040
2041        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2042        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2043
2044        decoder
2045            .decode_quantized(
2046                &[
2047                    boxes.slice(s![.., ..4, ..]).into(),
2048                    boxes.slice(s![.., 4..84, ..]).into(),
2049                ],
2050                &mut output_boxes,
2051                &mut output_masks,
2052            )
2053            .unwrap();
2054
2055        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2056        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2057        decode_yolo_det_float(
2058            seg.slice(s![0, ..84, ..]),
2059            score_threshold,
2060            iou_threshold,
2061            Some(configs::Nms::ClassAgnostic),
2062            &mut output_boxes_f32,
2063        );
2064
2065        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2066        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2067
2068        decoder
2069            .decode_float(
2070                &[
2071                    seg.slice(s![.., ..4, ..]).into_dyn(),
2072                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2073                ],
2074                &mut output_boxes1,
2075                &mut output_masks1,
2076            )
2077            .unwrap();
2078        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2079        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2080    }
2081
2082    #[test]
2083    fn test_decoder_masks_config_mixed() {
2084        let score_threshold = 0.45;
2085        let iou_threshold = 0.45;
2086        let boxes_raw = load_yolov8_boxes();
2087        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2088        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2089
2090        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2091
2092        let protos = load_yolov8_protos();
2093        let quant_protos = (0.02491161972284317, -117);
2094
2095        let decoder = build_yolo_split_segdet_decoder(
2096            score_threshold,
2097            iou_threshold,
2098            quant_boxes,
2099            quant_protos,
2100        );
2101        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2102        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2103
2104        decoder
2105            .decode_quantized(
2106                &[
2107                    boxes.slice(s![.., ..4, ..]).into(),
2108                    boxes.slice(s![.., 4..84, ..]).into(),
2109                    boxes.slice(s![.., 84.., ..]).into(),
2110                    protos.view().into(),
2111                ],
2112                &mut output_boxes,
2113                &mut output_masks,
2114            )
2115            .unwrap();
2116
2117        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2118        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2119        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2120        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2121        decode_yolo_segdet_float(
2122            seg.slice(s![0, .., ..]),
2123            protos.slice(s![0, .., .., ..]),
2124            score_threshold,
2125            iou_threshold,
2126            Some(configs::Nms::ClassAgnostic),
2127            &mut output_boxes_f32,
2128            &mut output_masks_f32,
2129        )
2130        .unwrap();
2131
2132        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2133        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2134
2135        decoder
2136            .decode_float(
2137                &[
2138                    seg.slice(s![.., ..4, ..]).into_dyn(),
2139                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2140                    seg.slice(s![.., 84.., ..]).into_dyn(),
2141                    protos.view().into_dyn(),
2142                ],
2143                &mut output_boxes1,
2144                &mut output_masks1,
2145            )
2146            .unwrap();
2147        compare_outputs(
2148            (&output_boxes, &output_boxes_f32),
2149            (&output_masks, &output_masks_f32),
2150        );
2151        compare_outputs(
2152            (&output_boxes_f32, &output_boxes1),
2153            (&output_masks_f32, &output_masks1),
2154        );
2155    }
2156
2157    fn build_yolo_split_segdet_decoder(
2158        score_threshold: f32,
2159        iou_threshold: f32,
2160        quant_boxes: (f32, i32),
2161        quant_protos: (f32, i32),
2162    ) -> crate::Decoder {
2163        DecoderBuilder::default()
2164            .with_config_yolo_split_segdet(
2165                configs::Boxes {
2166                    decoder: configs::DecoderType::Ultralytics,
2167                    quantization: Some(quant_boxes.into()),
2168                    shape: vec![1, 4, 8400],
2169                    dshape: vec![
2170                        (DimName::Batch, 1),
2171                        (DimName::BoxCoords, 4),
2172                        (DimName::NumBoxes, 8400),
2173                    ],
2174                    normalized: Some(true),
2175                },
2176                configs::Scores {
2177                    decoder: configs::DecoderType::Ultralytics,
2178                    quantization: Some(quant_boxes.into()),
2179                    shape: vec![1, 80, 8400],
2180                    dshape: vec![
2181                        (DimName::Batch, 1),
2182                        (DimName::NumClasses, 80),
2183                        (DimName::NumBoxes, 8400),
2184                    ],
2185                },
2186                configs::MaskCoefficients {
2187                    decoder: configs::DecoderType::Ultralytics,
2188                    quantization: Some(quant_boxes.into()),
2189                    shape: vec![1, 32, 8400],
2190                    dshape: vec![
2191                        (DimName::Batch, 1),
2192                        (DimName::NumProtos, 32),
2193                        (DimName::NumBoxes, 8400),
2194                    ],
2195                },
2196                configs::Protos {
2197                    decoder: configs::DecoderType::Ultralytics,
2198                    quantization: Some(quant_protos.into()),
2199                    shape: vec![1, 160, 160, 32],
2200                    dshape: vec![
2201                        (DimName::Batch, 1),
2202                        (DimName::Height, 160),
2203                        (DimName::Width, 160),
2204                        (DimName::NumProtos, 32),
2205                    ],
2206                },
2207            )
2208            .with_score_threshold(score_threshold)
2209            .with_iou_threshold(iou_threshold)
2210            .build()
2211            .unwrap()
2212    }
2213
2214    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2215        let config_yaml = include_str!(concat!(
2216            env!("CARGO_MANIFEST_DIR"),
2217            "/../../testdata/yolov8_seg.yaml"
2218        ));
2219        DecoderBuilder::default()
2220            .with_config_yaml_str(config_yaml.to_string())
2221            .with_score_threshold(score_threshold)
2222            .with_iou_threshold(iou_threshold)
2223            .build()
2224            .unwrap()
2225    }
2226    #[test]
2227    fn test_decoder_masks_config_i32() {
2228        let score_threshold = 0.45;
2229        let iou_threshold = 0.45;
2230        let boxes_raw = load_yolov8_boxes();
2231        let scale = 1 << 23;
2232        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2233        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2234
2235        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2236
2237        let protos_raw = load_yolov8_protos();
2238        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2239        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2240        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2241
2242        let decoder = build_yolo_split_segdet_decoder(
2243            score_threshold,
2244            iou_threshold,
2245            quant_boxes,
2246            quant_protos,
2247        );
2248
2249        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2250        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2251
2252        decoder
2253            .decode_quantized(
2254                &[
2255                    boxes.slice(s![.., ..4, ..]).into(),
2256                    boxes.slice(s![.., 4..84, ..]).into(),
2257                    boxes.slice(s![.., 84.., ..]).into(),
2258                    protos.view().into(),
2259                ],
2260                &mut output_boxes,
2261                &mut output_masks,
2262            )
2263            .unwrap();
2264
2265        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2266        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2267        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2268        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2269        decode_yolo_segdet_float(
2270            seg.slice(s![0, .., ..]),
2271            protos.slice(s![0, .., .., ..]),
2272            score_threshold,
2273            iou_threshold,
2274            Some(configs::Nms::ClassAgnostic),
2275            &mut output_boxes_f32,
2276            &mut output_masks_f32,
2277        )
2278        .unwrap();
2279
2280        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2281        assert_eq!(output_masks.len(), output_masks_f32.len());
2282
2283        compare_outputs(
2284            (&output_boxes, &output_boxes_f32),
2285            (&output_masks, &output_masks_f32),
2286        );
2287    }
2288
2289    /// test running multiple decoders concurrently
2290    #[test]
2291    fn test_context_switch() {
2292        let yolo_det = || {
2293            let score_threshold = 0.25;
2294            let iou_threshold = 0.7;
2295            let out = load_yolov8s_det();
2296            let quant = (0.0040811873, -123).into();
2297
2298            let decoder = DecoderBuilder::default()
2299                .with_config_yolo_det(
2300                    configs::Detection {
2301                        decoder: DecoderType::Ultralytics,
2302                        shape: vec![1, 84, 8400],
2303                        anchors: None,
2304                        quantization: Some(quant),
2305                        dshape: vec![
2306                            (DimName::Batch, 1),
2307                            (DimName::NumFeatures, 84),
2308                            (DimName::NumBoxes, 8400),
2309                        ],
2310                        normalized: None,
2311                    },
2312                    None,
2313                )
2314                .with_score_threshold(score_threshold)
2315                .with_iou_threshold(iou_threshold)
2316                .build()
2317                .unwrap();
2318
2319            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2320            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2321
2322            for _ in 0..100 {
2323                decoder
2324                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2325                    .unwrap();
2326
2327                assert!(output_boxes[0].equal_within_delta(
2328                    &DetectBox {
2329                        bbox: BoundingBox {
2330                            xmin: 0.5285137,
2331                            ymin: 0.05305544,
2332                            xmax: 0.87541467,
2333                            ymax: 0.9998909,
2334                        },
2335                        score: 0.5591227,
2336                        label: 0
2337                    },
2338                    1e-6
2339                ));
2340
2341                assert!(output_boxes[1].equal_within_delta(
2342                    &DetectBox {
2343                        bbox: BoundingBox {
2344                            xmin: 0.130598,
2345                            ymin: 0.43260583,
2346                            xmax: 0.35098213,
2347                            ymax: 0.9958097,
2348                        },
2349                        score: 0.33057618,
2350                        label: 75
2351                    },
2352                    1e-6
2353                ));
2354                assert!(output_masks.is_empty());
2355            }
2356        };
2357
2358        let modelpack_det_split = || {
2359            let score_threshold = 0.8;
2360            let iou_threshold = 0.5;
2361
2362            let seg = include_bytes!(concat!(
2363                env!("CARGO_MANIFEST_DIR"),
2364                "/../../testdata/modelpack_seg_2x160x160.bin"
2365            ));
2366            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2367
2368            let detect0 = include_bytes!(concat!(
2369                env!("CARGO_MANIFEST_DIR"),
2370                "/../../testdata/modelpack_split_9x15x18.bin"
2371            ));
2372            let detect0 =
2373                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2374
2375            let detect1 = include_bytes!(concat!(
2376                env!("CARGO_MANIFEST_DIR"),
2377                "/../../testdata/modelpack_split_17x30x18.bin"
2378            ));
2379            let detect1 =
2380                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2381
2382            let mut mask = seg.slice(s![0, .., .., ..]);
2383            mask.swap_axes(0, 1);
2384            mask.swap_axes(1, 2);
2385            let mask = [Segmentation {
2386                xmin: 0.0,
2387                ymin: 0.0,
2388                xmax: 1.0,
2389                ymax: 1.0,
2390                segmentation: mask.into_owned(),
2391            }];
2392            let correct_boxes = [DetectBox {
2393                bbox: BoundingBox {
2394                    xmin: 0.43171933,
2395                    ymin: 0.68243736,
2396                    xmax: 0.5626645,
2397                    ymax: 0.808863,
2398                },
2399                score: 0.99240804,
2400                label: 0,
2401            }];
2402
2403            let quant0 = (0.08547406643629074, 174).into();
2404            let quant1 = (0.09929127991199493, 183).into();
2405            let quant_seg = (1.0 / 255.0, 0).into();
2406
2407            let anchors0 = vec![
2408                [0.36666667461395264, 0.31481480598449707],
2409                [0.38749998807907104, 0.4740740656852722],
2410                [0.5333333611488342, 0.644444465637207],
2411            ];
2412            let anchors1 = vec![
2413                [0.13750000298023224, 0.2074074000120163],
2414                [0.2541666626930237, 0.21481481194496155],
2415                [0.23125000298023224, 0.35185185074806213],
2416            ];
2417
2418            let decoder = DecoderBuilder::default()
2419                .with_config_modelpack_segdet_split(
2420                    vec![
2421                        configs::Detection {
2422                            decoder: DecoderType::ModelPack,
2423                            shape: vec![1, 17, 30, 18],
2424                            anchors: Some(anchors1),
2425                            quantization: Some(quant1),
2426                            dshape: vec![
2427                                (DimName::Batch, 1),
2428                                (DimName::Height, 17),
2429                                (DimName::Width, 30),
2430                                (DimName::NumAnchorsXFeatures, 18),
2431                            ],
2432                            normalized: None,
2433                        },
2434                        configs::Detection {
2435                            decoder: DecoderType::ModelPack,
2436                            shape: vec![1, 9, 15, 18],
2437                            anchors: Some(anchors0),
2438                            quantization: Some(quant0),
2439                            dshape: vec![
2440                                (DimName::Batch, 1),
2441                                (DimName::Height, 9),
2442                                (DimName::Width, 15),
2443                                (DimName::NumAnchorsXFeatures, 18),
2444                            ],
2445                            normalized: None,
2446                        },
2447                    ],
2448                    configs::Segmentation {
2449                        decoder: DecoderType::ModelPack,
2450                        quantization: Some(quant_seg),
2451                        shape: vec![1, 2, 160, 160],
2452                        dshape: vec![
2453                            (DimName::Batch, 1),
2454                            (DimName::NumClasses, 2),
2455                            (DimName::Height, 160),
2456                            (DimName::Width, 160),
2457                        ],
2458                    },
2459                )
2460                .with_score_threshold(score_threshold)
2461                .with_iou_threshold(iou_threshold)
2462                .build()
2463                .unwrap();
2464            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2465            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2466
2467            for _ in 0..100 {
2468                decoder
2469                    .decode_quantized(
2470                        &[
2471                            detect0.view().into(),
2472                            detect1.view().into(),
2473                            seg.view().into(),
2474                        ],
2475                        &mut output_boxes,
2476                        &mut output_masks,
2477                    )
2478                    .unwrap();
2479
2480                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2481            }
2482        };
2483
2484        let handles = vec![
2485            std::thread::spawn(yolo_det),
2486            std::thread::spawn(modelpack_det_split),
2487            std::thread::spawn(yolo_det),
2488            std::thread::spawn(modelpack_det_split),
2489            std::thread::spawn(yolo_det),
2490            std::thread::spawn(modelpack_det_split),
2491            std::thread::spawn(yolo_det),
2492            std::thread::spawn(modelpack_det_split),
2493        ];
2494        for handle in handles {
2495            handle.join().unwrap();
2496        }
2497    }
2498
2499    #[test]
2500    fn test_ndarray_to_xyxy_float() {
2501        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2502        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2503        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2504
2505        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2506        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2507        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2508    }
2509
2510    #[test]
2511    fn test_class_aware_nms_float() {
2512        use crate::float::nms_class_aware_float;
2513
2514        // Create two overlapping boxes with different classes
2515        let boxes = vec![
2516            DetectBox {
2517                bbox: BoundingBox {
2518                    xmin: 0.0,
2519                    ymin: 0.0,
2520                    xmax: 0.5,
2521                    ymax: 0.5,
2522                },
2523                score: 0.9,
2524                label: 0, // class 0
2525            },
2526            DetectBox {
2527                bbox: BoundingBox {
2528                    xmin: 0.1,
2529                    ymin: 0.1,
2530                    xmax: 0.6,
2531                    ymax: 0.6,
2532                },
2533                score: 0.8,
2534                label: 1, // class 1 - different class
2535            },
2536        ];
2537
2538        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2539        // threshold 0.3)
2540        let result = nms_class_aware_float(0.3, boxes.clone());
2541        assert_eq!(
2542            result.len(),
2543            2,
2544            "Class-aware NMS should keep both boxes with different classes"
2545        );
2546
2547        // Now test with same class - should suppress one
2548        let same_class_boxes = vec![
2549            DetectBox {
2550                bbox: BoundingBox {
2551                    xmin: 0.0,
2552                    ymin: 0.0,
2553                    xmax: 0.5,
2554                    ymax: 0.5,
2555                },
2556                score: 0.9,
2557                label: 0,
2558            },
2559            DetectBox {
2560                bbox: BoundingBox {
2561                    xmin: 0.1,
2562                    ymin: 0.1,
2563                    xmax: 0.6,
2564                    ymax: 0.6,
2565                },
2566                score: 0.8,
2567                label: 0, // same class
2568            },
2569        ];
2570
2571        let result = nms_class_aware_float(0.3, same_class_boxes);
2572        assert_eq!(
2573            result.len(),
2574            1,
2575            "Class-aware NMS should suppress overlapping box with same class"
2576        );
2577        assert_eq!(result[0].label, 0);
2578        assert!((result[0].score - 0.9).abs() < 1e-6);
2579    }
2580
2581    #[test]
2582    fn test_class_agnostic_vs_aware_nms() {
2583        use crate::float::{nms_class_aware_float, nms_float};
2584
2585        // Two overlapping boxes with different classes
2586        let boxes = vec![
2587            DetectBox {
2588                bbox: BoundingBox {
2589                    xmin: 0.0,
2590                    ymin: 0.0,
2591                    xmax: 0.5,
2592                    ymax: 0.5,
2593                },
2594                score: 0.9,
2595                label: 0,
2596            },
2597            DetectBox {
2598                bbox: BoundingBox {
2599                    xmin: 0.1,
2600                    ymin: 0.1,
2601                    xmax: 0.6,
2602                    ymax: 0.6,
2603                },
2604                score: 0.8,
2605                label: 1,
2606            },
2607        ];
2608
2609        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2610        let agnostic_result = nms_float(0.3, boxes.clone());
2611        assert_eq!(
2612            agnostic_result.len(),
2613            1,
2614            "Class-agnostic NMS should suppress overlapping boxes"
2615        );
2616
2617        // Class-aware should keep both (different classes)
2618        let aware_result = nms_class_aware_float(0.3, boxes);
2619        assert_eq!(
2620            aware_result.len(),
2621            2,
2622            "Class-aware NMS should keep boxes with different classes"
2623        );
2624    }
2625
2626    #[test]
2627    fn test_class_aware_nms_int() {
2628        use crate::byte::nms_class_aware_int;
2629
2630        // Create two overlapping boxes with different classes
2631        let boxes = vec![
2632            DetectBoxQuantized {
2633                bbox: BoundingBox {
2634                    xmin: 0.0,
2635                    ymin: 0.0,
2636                    xmax: 0.5,
2637                    ymax: 0.5,
2638                },
2639                score: 200_u8,
2640                label: 0,
2641            },
2642            DetectBoxQuantized {
2643                bbox: BoundingBox {
2644                    xmin: 0.1,
2645                    ymin: 0.1,
2646                    xmax: 0.6,
2647                    ymax: 0.6,
2648                },
2649                score: 180_u8,
2650                label: 1, // different class
2651            },
2652        ];
2653
2654        // Should keep both (different classes)
2655        let result = nms_class_aware_int(0.5, boxes);
2656        assert_eq!(
2657            result.len(),
2658            2,
2659            "Class-aware NMS (int) should keep boxes with different classes"
2660        );
2661    }
2662
2663    #[test]
2664    fn test_nms_enum_default() {
2665        // Test that Nms enum has the correct default
2666        let default_nms: configs::Nms = Default::default();
2667        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2668    }
2669
2670    #[test]
2671    fn test_decoder_nms_mode() {
2672        // Test that decoder properly stores NMS mode
2673        let decoder = DecoderBuilder::default()
2674            .with_config_yolo_det(
2675                configs::Detection {
2676                    anchors: None,
2677                    decoder: DecoderType::Ultralytics,
2678                    quantization: None,
2679                    shape: vec![1, 84, 8400],
2680                    dshape: Vec::new(),
2681                    normalized: Some(true),
2682                },
2683                None,
2684            )
2685            .with_nms(Some(configs::Nms::ClassAware))
2686            .build()
2687            .unwrap();
2688
2689        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2690    }
2691
2692    #[test]
2693    fn test_decoder_nms_bypass() {
2694        // Test that decoder can be configured with nms=None (bypass)
2695        let decoder = DecoderBuilder::default()
2696            .with_config_yolo_det(
2697                configs::Detection {
2698                    anchors: None,
2699                    decoder: DecoderType::Ultralytics,
2700                    quantization: None,
2701                    shape: vec![1, 84, 8400],
2702                    dshape: Vec::new(),
2703                    normalized: Some(true),
2704                },
2705                None,
2706            )
2707            .with_nms(None)
2708            .build()
2709            .unwrap();
2710
2711        assert_eq!(decoder.nms, None);
2712    }
2713
2714    #[test]
2715    fn test_decoder_normalized_boxes_true() {
2716        // Test that normalized_boxes returns Some(true) when explicitly set
2717        let decoder = DecoderBuilder::default()
2718            .with_config_yolo_det(
2719                configs::Detection {
2720                    anchors: None,
2721                    decoder: DecoderType::Ultralytics,
2722                    quantization: None,
2723                    shape: vec![1, 84, 8400],
2724                    dshape: Vec::new(),
2725                    normalized: Some(true),
2726                },
2727                None,
2728            )
2729            .build()
2730            .unwrap();
2731
2732        assert_eq!(decoder.normalized_boxes(), Some(true));
2733    }
2734
2735    #[test]
2736    fn test_decoder_normalized_boxes_false() {
2737        // Test that normalized_boxes returns Some(false) when config specifies
2738        // unnormalized
2739        let decoder = DecoderBuilder::default()
2740            .with_config_yolo_det(
2741                configs::Detection {
2742                    anchors: None,
2743                    decoder: DecoderType::Ultralytics,
2744                    quantization: None,
2745                    shape: vec![1, 84, 8400],
2746                    dshape: Vec::new(),
2747                    normalized: Some(false),
2748                },
2749                None,
2750            )
2751            .build()
2752            .unwrap();
2753
2754        assert_eq!(decoder.normalized_boxes(), Some(false));
2755    }
2756
2757    #[test]
2758    fn test_decoder_normalized_boxes_unknown() {
2759        // Test that normalized_boxes returns None when not specified in config
2760        let decoder = DecoderBuilder::default()
2761            .with_config_yolo_det(
2762                configs::Detection {
2763                    anchors: None,
2764                    decoder: DecoderType::Ultralytics,
2765                    quantization: None,
2766                    shape: vec![1, 84, 8400],
2767                    dshape: Vec::new(),
2768                    normalized: None,
2769                },
2770                Some(DecoderVersion::Yolo11),
2771            )
2772            .build()
2773            .unwrap();
2774
2775        assert_eq!(decoder.normalized_boxes(), None);
2776    }
2777
2778    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2779        input: ArrayView<F, D>,
2780        quant: Quantization,
2781    ) -> Array<T, D>
2782    where
2783        i32: num_traits::AsPrimitive<F>,
2784        f32: num_traits::AsPrimitive<F>,
2785    {
2786        let zero_point = quant.zero_point.as_();
2787        let div_scale = F::one() / quant.scale.as_();
2788        if zero_point != F::zero() {
2789            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2790        } else {
2791            input.mapv(|d| (d * div_scale).round().as_())
2792        }
2793    }
2794
2795    fn real_data_expected_boxes() -> [DetectBox; 2] {
2796        [
2797            DetectBox {
2798                bbox: BoundingBox {
2799                    xmin: 0.08515105,
2800                    ymin: 0.7131401,
2801                    xmax: 0.29802868,
2802                    ymax: 0.8195788,
2803                },
2804                score: 0.91537374,
2805                label: 23,
2806            },
2807            DetectBox {
2808                bbox: BoundingBox {
2809                    xmin: 0.59605736,
2810                    ymin: 0.25545314,
2811                    xmax: 0.93666154,
2812                    ymax: 0.72378385,
2813                },
2814                score: 0.91537374,
2815                label: 23,
2816            },
2817        ]
2818    }
2819
2820    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2821        [DetectBox {
2822            bbox: BoundingBox {
2823                xmin: 0.12549022,
2824                ymin: 0.12549022,
2825                xmax: 0.23529413,
2826                ymax: 0.23529413,
2827            },
2828            score: 0.98823535,
2829            label: 2,
2830        }]
2831    }
2832
2833    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2834        [DetectBox {
2835            bbox: BoundingBox {
2836                xmin: 0.1234,
2837                ymin: 0.1234,
2838                xmax: 0.2345,
2839                ymax: 0.2345,
2840            },
2841            score: 0.9876,
2842            label: 2,
2843        }]
2844    }
2845
2846    macro_rules! real_data_proto_test {
2847        ($name:ident, quantized, $layout:ident) => {
2848            #[test]
2849            fn $name() {
2850                let is_split = matches!(stringify!($layout), "split");
2851
2852                let score_threshold = 0.45;
2853                let iou_threshold = 0.45;
2854                let quant_boxes = (0.021287762_f32, 31_i32);
2855                let quant_protos = (0.02491162_f32, -117_i32);
2856
2857                let raw_boxes = include_bytes!(concat!(
2858                    env!("CARGO_MANIFEST_DIR"),
2859                    "/../../testdata/yolov8_boxes_116x8400.bin"
2860                ));
2861                let raw_boxes = unsafe {
2862                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2863                };
2864                let boxes_i8 =
2865                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2866
2867                let raw_protos = include_bytes!(concat!(
2868                    env!("CARGO_MANIFEST_DIR"),
2869                    "/../../testdata/yolov8_protos_160x160x32.bin"
2870                ));
2871                let raw_protos = unsafe {
2872                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2873                };
2874                let protos_i8 =
2875                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2876                        .unwrap();
2877
2878                // Pre-split (unused for combined, but harmless)
2879                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2880                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2881                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2882                let boxes_combined = boxes_i8;
2883
2884                let decoder = if is_split {
2885                    build_yolo_split_segdet_decoder(
2886                        score_threshold,
2887                        iou_threshold,
2888                        quant_boxes,
2889                        quant_protos,
2890                    )
2891                } else {
2892                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2893                };
2894
2895                let expected = real_data_expected_boxes();
2896                let mut output_boxes = Vec::with_capacity(50);
2897
2898                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
2899                    vec![
2900                        boxes_split.view().into(),
2901                        scores_split.view().into(),
2902                        mask_split.view().into(),
2903                        protos_i8.view().into(),
2904                    ]
2905                } else {
2906                    vec![boxes_combined.view().into(), protos_i8.view().into()]
2907                };
2908                decoder
2909                    .decode_quantized_proto(&inputs, &mut output_boxes)
2910                    .unwrap();
2911
2912                assert_eq!(output_boxes.len(), 2);
2913                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2914                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2915            }
2916        };
2917        ($name:ident, float, $layout:ident) => {
2918            #[test]
2919            fn $name() {
2920                let is_split = matches!(stringify!($layout), "split");
2921
2922                let score_threshold = 0.45;
2923                let iou_threshold = 0.45;
2924                let quant_boxes = (0.021287762_f32, 31_i32);
2925                let quant_protos = (0.02491162_f32, -117_i32);
2926
2927                let raw_boxes = include_bytes!(concat!(
2928                    env!("CARGO_MANIFEST_DIR"),
2929                    "/../../testdata/yolov8_boxes_116x8400.bin"
2930                ));
2931                let raw_boxes = unsafe {
2932                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2933                };
2934                let boxes_i8 =
2935                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2936                let boxes_f32: Array3<f32> =
2937                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
2938
2939                let raw_protos = include_bytes!(concat!(
2940                    env!("CARGO_MANIFEST_DIR"),
2941                    "/../../testdata/yolov8_protos_160x160x32.bin"
2942                ));
2943                let raw_protos = unsafe {
2944                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2945                };
2946                let protos_i8 =
2947                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2948                        .unwrap();
2949                let protos_f32: Array4<f32> =
2950                    dequantize_ndarray(protos_i8.view(), quant_protos.into());
2951
2952                // Pre-split from dequantized data
2953                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
2954                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
2955                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
2956                let boxes_combined = boxes_f32;
2957
2958                let decoder = if is_split {
2959                    build_yolo_split_segdet_decoder(
2960                        score_threshold,
2961                        iou_threshold,
2962                        quant_boxes,
2963                        quant_protos,
2964                    )
2965                } else {
2966                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2967                };
2968
2969                let expected = real_data_expected_boxes();
2970                let mut output_boxes = Vec::with_capacity(50);
2971
2972                let inputs = if is_split {
2973                    vec![
2974                        boxes_split.view().into_dyn(),
2975                        scores_split.view().into_dyn(),
2976                        mask_split.view().into_dyn(),
2977                        protos_f32.view().into_dyn(),
2978                    ]
2979                } else {
2980                    vec![
2981                        boxes_combined.view().into_dyn(),
2982                        protos_f32.view().into_dyn(),
2983                    ]
2984                };
2985                decoder
2986                    .decode_float_proto(&inputs, &mut output_boxes)
2987                    .unwrap();
2988
2989                assert_eq!(output_boxes.len(), 2);
2990                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
2991                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
2992            }
2993        };
2994    }
2995
2996    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
2997    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
2998    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
2999    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3000
    /// YAML decoder config for a YOLO26 end-to-end detection model with a single
    /// combined output of shape `[1, 10, 6]`: 10 candidate boxes x 6 features.
    ///
    /// The output is quantized with scale 2/255 (0.00784...) and zero point 0, and
    /// the boxes are declared normalized.
    /// NOTE(review): the 6 features are presumably 4 box coordinates + score + class,
    /// i.e. the `E2E_COMBINED_SEGDET_CONFIG` layout without the 32 mask
    /// coefficients. Confirm against the yolo26 decoder.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3014
    /// YAML decoder config for a YOLO26 end-to-end segmentation model with two
    /// outputs.
    ///
    /// * `detection` – `[1, 10, 38]`, quantized with scale 2/255 and zero point 0,
    ///   normalized boxes. `e2e_segdet_test!` builds it by concatenating 4 box
    ///   coords, 1 score, 1 class and 32 mask coefficients per box.
    /// * `protos` – `[1, 160, 160, 32]` (NHWC), quantized with scale 1/255 and
    ///   zero point 128.
    ///
    /// NOTE(review): the quantized arm of `e2e_segdet_test!` quantizes its all-zero
    /// protos with zero point 0, not the 128 declared here. Masks are not asserted
    /// there, but confirm the mismatch is intentional.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3037
    /// YAML decoder config for a YOLO26 end-to-end detection model whose outputs
    /// are split into separate tensors, each covering 10 candidate boxes.
    ///
    /// * `boxes` – `[1, 10, 4]`, normalized box coordinates.
    /// * `scores` – `[1, 10, 1]`.
    /// * `classes` – `[1, 10, 1]`. It reuses the `num_classes` dim name (size 1);
    ///   presumably it holds one class index per box, so confirm with the decoder.
    ///
    /// All three tensors are quantized with scale 2/255 and zero point 0.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3067
    /// YAML decoder config for a YOLO26 end-to-end segmentation model with split
    /// outputs. `e2e_segdet_test!` feeds these tensors in this order.
    ///
    /// * `boxes` – `[1, 10, 4]`, normalized.
    /// * `scores` – `[1, 10, 1]`.
    /// * `classes` – `[1, 10, 1]`.
    /// * `mask_coefficients` – `[1, 10, 32]`.
    ///
    /// All four are quantized with scale 2/255 and zero point 0.
    ///
    /// * `protos` – `[1, 160, 160, 32]` (NHWC), quantized with scale 1/255 and
    ///   zero point 128.
    ///
    /// NOTE(review): the quantized arm of `e2e_segdet_test!` quantizes its all-zero
    /// protos with zero point 0, not the 128 declared here. Masks are not asserted
    /// there, but confirm the mismatch is intentional.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3114
3115    macro_rules! e2e_segdet_test {
3116        ($name:ident, quantized, $layout:ident, $output:ident) => {
3117            #[test]
3118            fn $name() {
3119                let is_split = matches!(stringify!($layout), "split");
3120                let is_proto = matches!(stringify!($output), "proto");
3121
3122                let score_threshold = 0.45;
3123                let iou_threshold = 0.45;
3124
3125                let mut boxes = Array2::zeros((10, 4));
3126                let mut scores = Array2::zeros((10, 1));
3127                let mut classes = Array2::zeros((10, 1));
3128                let mask = Array2::zeros((10, 32));
3129                let protos = Array3::<f64>::zeros((160, 160, 32));
3130                let protos = protos.insert_axis(Axis(0));
3131                let protos_quant = (1.0 / 255.0, 0.0);
3132                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3133
3134                boxes
3135                    .slice_mut(s![0, ..])
3136                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3137                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3138                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3139
3140                let detect_quant = (2.0 / 255.0, 0.0);
3141
3142                let decoder = if is_split {
3143                    DecoderBuilder::default()
3144                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3145                        .with_score_threshold(score_threshold)
3146                        .with_iou_threshold(iou_threshold)
3147                        .build()
3148                        .unwrap()
3149                } else {
3150                    DecoderBuilder::default()
3151                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3152                        .with_score_threshold(score_threshold)
3153                        .with_iou_threshold(iou_threshold)
3154                        .build()
3155                        .unwrap()
3156                };
3157
3158                let expected = e2e_expected_boxes_quant();
3159                let mut output_boxes = Vec::with_capacity(50);
3160
3161                if is_split {
3162                    let boxes = boxes.insert_axis(Axis(0));
3163                    let scores = scores.insert_axis(Axis(0));
3164                    let classes = classes.insert_axis(Axis(0));
3165                    let mask = mask.insert_axis(Axis(0));
3166
3167                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3168                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3169                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3170                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3171
3172                    if is_proto {
3173                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3174                            boxes.view().into(),
3175                            scores.view().into(),
3176                            classes.view().into(),
3177                            mask.view().into(),
3178                            protos.view().into(),
3179                        ];
3180                        decoder
3181                            .decode_quantized_proto(&inputs, &mut output_boxes)
3182                            .unwrap();
3183
3184                        assert_eq!(output_boxes.len(), 1);
3185                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3186                    } else {
3187                        let mut output_masks = Vec::with_capacity(50);
3188                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3189                            boxes.view().into(),
3190                            scores.view().into(),
3191                            classes.view().into(),
3192                            mask.view().into(),
3193                            protos.view().into(),
3194                        ];
3195                        decoder
3196                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3197                            .unwrap();
3198
3199                        assert_eq!(output_boxes.len(), 1);
3200                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3201                    }
3202                } else {
3203                    // Combined layout
3204                    let detect = ndarray::concatenate![
3205                        Axis(1),
3206                        boxes.view(),
3207                        scores.view(),
3208                        classes.view(),
3209                        mask.view()
3210                    ];
3211                    let detect = detect.insert_axis(Axis(0));
3212                    assert_eq!(detect.shape(), &[1, 10, 38]);
3213                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3214
3215                    if is_proto {
3216                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3217                            vec![detect.view().into(), protos.view().into()];
3218                        decoder
3219                            .decode_quantized_proto(&inputs, &mut output_boxes)
3220                            .unwrap();
3221
3222                        assert_eq!(output_boxes.len(), 1);
3223                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3224                    } else {
3225                        let mut output_masks = Vec::with_capacity(50);
3226                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3227                            vec![detect.view().into(), protos.view().into()];
3228                        decoder
3229                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3230                            .unwrap();
3231
3232                        assert_eq!(output_boxes.len(), 1);
3233                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3234                    }
3235                }
3236            }
3237        };
3238        ($name:ident, float, $layout:ident, $output:ident) => {
3239            #[test]
3240            fn $name() {
3241                let is_split = matches!(stringify!($layout), "split");
3242                let is_proto = matches!(stringify!($output), "proto");
3243
3244                let score_threshold = 0.45;
3245                let iou_threshold = 0.45;
3246
3247                let mut boxes = Array2::zeros((10, 4));
3248                let mut scores = Array2::zeros((10, 1));
3249                let mut classes = Array2::zeros((10, 1));
3250                let mask: Array2<f64> = Array2::zeros((10, 32));
3251                let protos = Array3::<f64>::zeros((160, 160, 32));
3252                let protos = protos.insert_axis(Axis(0));
3253
3254                boxes
3255                    .slice_mut(s![0, ..])
3256                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3257                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3258                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3259
3260                let decoder = if is_split {
3261                    DecoderBuilder::default()
3262                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3263                        .with_score_threshold(score_threshold)
3264                        .with_iou_threshold(iou_threshold)
3265                        .build()
3266                        .unwrap()
3267                } else {
3268                    DecoderBuilder::default()
3269                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3270                        .with_score_threshold(score_threshold)
3271                        .with_iou_threshold(iou_threshold)
3272                        .build()
3273                        .unwrap()
3274                };
3275
3276                let expected = e2e_expected_boxes_float();
3277                let mut output_boxes = Vec::with_capacity(50);
3278
3279                if is_split {
3280                    let boxes = boxes.insert_axis(Axis(0));
3281                    let scores = scores.insert_axis(Axis(0));
3282                    let classes = classes.insert_axis(Axis(0));
3283                    let mask = mask.insert_axis(Axis(0));
3284
3285                    if is_proto {
3286                        let inputs = vec![
3287                            boxes.view().into_dyn(),
3288                            scores.view().into_dyn(),
3289                            classes.view().into_dyn(),
3290                            mask.view().into_dyn(),
3291                            protos.view().into_dyn(),
3292                        ];
3293                        decoder
3294                            .decode_float_proto(&inputs, &mut output_boxes)
3295                            .unwrap();
3296
3297                        assert_eq!(output_boxes.len(), 1);
3298                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3299                    } else {
3300                        let mut output_masks = Vec::with_capacity(50);
3301                        let inputs = vec![
3302                            boxes.view().into_dyn(),
3303                            scores.view().into_dyn(),
3304                            classes.view().into_dyn(),
3305                            mask.view().into_dyn(),
3306                            protos.view().into_dyn(),
3307                        ];
3308                        decoder
3309                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3310                            .unwrap();
3311
3312                        assert_eq!(output_boxes.len(), 1);
3313                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3314                    }
3315                } else {
3316                    // Combined layout
3317                    let detect = ndarray::concatenate![
3318                        Axis(1),
3319                        boxes.view(),
3320                        scores.view(),
3321                        classes.view(),
3322                        mask.view()
3323                    ];
3324                    let detect = detect.insert_axis(Axis(0));
3325                    assert_eq!(detect.shape(), &[1, 10, 38]);
3326
3327                    if is_proto {
3328                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3329                        decoder
3330                            .decode_float_proto(&inputs, &mut output_boxes)
3331                            .unwrap();
3332
3333                        assert_eq!(output_boxes.len(), 1);
3334                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3335                    } else {
3336                        let mut output_masks = Vec::with_capacity(50);
3337                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3338                        decoder
3339                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3340                            .unwrap();
3341
3342                        assert_eq!(output_boxes.len(), 1);
3343                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3344                    }
3345                }
3346            }
3347        };
3348    }
3349
    // End-to-end segmentation+detection matrix. Each invocation picks:
    //   * input numeric path: `quantized` or `float`,
    //   * model layout: `combined` (one concatenated detect tensor) or `split`
    //     (separate boxes / scores / classes / mask-coefficient tensors),
    //   * decode path: `masks` (full decode that also fills output masks) or
    //     `proto` (decode that stops at the prototype stage).
    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
3388
3389    macro_rules! e2e_det_test {
3390        ($name:ident, quantized, $layout:ident) => {
3391            #[test]
3392            fn $name() {
3393                let is_split = matches!(stringify!($layout), "split");
3394
3395                let score_threshold = 0.45;
3396                let iou_threshold = 0.45;
3397
3398                let mut boxes = Array3::zeros((1, 10, 4));
3399                let mut scores = Array3::zeros((1, 10, 1));
3400                let mut classes = Array3::zeros((1, 10, 1));
3401
3402                boxes
3403                    .slice_mut(s![0, 0, ..])
3404                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3405                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3406                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3407
3408                let detect_quant = (2.0 / 255.0, 0_i32);
3409
3410                let decoder = if is_split {
3411                    DecoderBuilder::default()
3412                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3413                        .with_score_threshold(score_threshold)
3414                        .with_iou_threshold(iou_threshold)
3415                        .build()
3416                        .unwrap()
3417                } else {
3418                    DecoderBuilder::default()
3419                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3420                        .with_score_threshold(score_threshold)
3421                        .with_iou_threshold(iou_threshold)
3422                        .build()
3423                        .unwrap()
3424                };
3425
3426                let expected = e2e_expected_boxes_quant();
3427                let mut output_boxes = Vec::with_capacity(50);
3428
3429                if is_split {
3430                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3431                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3432                    let classes: Array<u8, _> =
3433                        quantize_ndarray(classes.view(), detect_quant.into());
3434                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3435                        boxes.view().into(),
3436                        scores.view().into(),
3437                        classes.view().into(),
3438                    ];
3439                    decoder
3440                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3441                        .unwrap();
3442                } else {
3443                    let detect =
3444                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3445                    assert_eq!(detect.shape(), &[1, 10, 6]);
3446                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3447                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3448                        vec![detect.view().into()];
3449                    decoder
3450                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3451                        .unwrap();
3452                }
3453
3454                assert_eq!(output_boxes.len(), 1);
3455                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3456            }
3457        };
3458        ($name:ident, float, $layout:ident) => {
3459            #[test]
3460            fn $name() {
3461                let is_split = matches!(stringify!($layout), "split");
3462
3463                let score_threshold = 0.45;
3464                let iou_threshold = 0.45;
3465
3466                let mut boxes = Array3::zeros((1, 10, 4));
3467                let mut scores = Array3::zeros((1, 10, 1));
3468                let mut classes = Array3::zeros((1, 10, 1));
3469
3470                boxes
3471                    .slice_mut(s![0, 0, ..])
3472                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3473                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3474                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3475
3476                let decoder = if is_split {
3477                    DecoderBuilder::default()
3478                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3479                        .with_score_threshold(score_threshold)
3480                        .with_iou_threshold(iou_threshold)
3481                        .build()
3482                        .unwrap()
3483                } else {
3484                    DecoderBuilder::default()
3485                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3486                        .with_score_threshold(score_threshold)
3487                        .with_iou_threshold(iou_threshold)
3488                        .build()
3489                        .unwrap()
3490                };
3491
3492                let expected = e2e_expected_boxes_float();
3493                let mut output_boxes = Vec::with_capacity(50);
3494
3495                if is_split {
3496                    let inputs = vec![
3497                        boxes.view().into_dyn(),
3498                        scores.view().into_dyn(),
3499                        classes.view().into_dyn(),
3500                    ];
3501                    decoder
3502                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3503                        .unwrap();
3504                } else {
3505                    let detect =
3506                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3507                    assert_eq!(detect.shape(), &[1, 10, 6]);
3508                    let inputs = vec![detect.view().into_dyn()];
3509                    decoder
3510                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3511                        .unwrap();
3512                }
3513
3514                assert_eq!(output_boxes.len(), 1);
3515                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3516            }
3517        };
3518    }
3519
3520    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3521    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3522    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3523    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3524
3525    #[test]
3526    fn test_decode_tensor() {
3527        let score_threshold = 0.45;
3528        let iou_threshold = 0.45;
3529
3530        let raw_boxes = include_bytes!(concat!(
3531            env!("CARGO_MANIFEST_DIR"),
3532            "/../../testdata/yolov8_boxes_116x8400.bin"
3533        ));
3534        let raw_boxes =
3535            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3536        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3537        boxes_i8
3538            .map()
3539            .unwrap()
3540            .as_mut_slice()
3541            .copy_from_slice(raw_boxes);
3542        let boxes_i8 = boxes_i8.into();
3543
3544        let raw_protos = include_bytes!(concat!(
3545            env!("CARGO_MANIFEST_DIR"),
3546            "/../../testdata/yolov8_protos_160x160x32.bin"
3547        ));
3548        let raw_protos = unsafe {
3549            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3550        };
3551        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3552        protos_i8
3553            .map()
3554            .unwrap()
3555            .as_mut_slice()
3556            .copy_from_slice(raw_protos);
3557        let protos_i8 = protos_i8.into();
3558
3559        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3560        let expected = real_data_expected_boxes();
3561        let mut output_boxes = Vec::with_capacity(50);
3562
3563        decoder
3564            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3565            .unwrap();
3566
3567        assert_eq!(output_boxes.len(), 2);
3568        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3569        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3570    }
3571
3572    #[test]
3573    fn test_decode_tensor_f32() {
3574        let score_threshold = 0.45;
3575        let iou_threshold = 0.45;
3576
3577        let quant_boxes = (0.021287762_f32, 31_i32);
3578        let quant_protos = (0.02491162_f32, -117_i32);
3579        let raw_boxes = include_bytes!(concat!(
3580            env!("CARGO_MANIFEST_DIR"),
3581            "/../../testdata/yolov8_boxes_116x8400.bin"
3582        ));
3583        let raw_boxes =
3584            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3585        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3586        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3587        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3588        boxes_f32
3589            .map()
3590            .unwrap()
3591            .as_mut_slice()
3592            .copy_from_slice(&raw_boxes_f32);
3593        let boxes_f32 = boxes_f32.into();
3594
3595        let raw_protos = include_bytes!(concat!(
3596            env!("CARGO_MANIFEST_DIR"),
3597            "/../../testdata/yolov8_protos_160x160x32.bin"
3598        ));
3599        let raw_protos = unsafe {
3600            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3601        };
3602        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3603        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3604        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3605        protos_f32
3606            .map()
3607            .unwrap()
3608            .as_mut_slice()
3609            .copy_from_slice(&raw_protos_f32);
3610        let protos_f32 = protos_f32.into();
3611
3612        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3613
3614        let expected = real_data_expected_boxes();
3615        let mut output_boxes = Vec::with_capacity(50);
3616
3617        decoder
3618            .decode(
3619                &[&boxes_f32, &protos_f32],
3620                &mut output_boxes,
3621                &mut Vec::new(),
3622            )
3623            .unwrap();
3624
3625        assert_eq!(output_boxes.len(), 2);
3626        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3627        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3628    }
3629
3630    #[test]
3631    fn test_decode_tensor_f64() {
3632        let score_threshold = 0.45;
3633        let iou_threshold = 0.45;
3634
3635        let quant_boxes = (0.021287762_f32, 31_i32);
3636        let quant_protos = (0.02491162_f32, -117_i32);
3637        let raw_boxes = include_bytes!(concat!(
3638            env!("CARGO_MANIFEST_DIR"),
3639            "/../../testdata/yolov8_boxes_116x8400.bin"
3640        ));
3641        let raw_boxes =
3642            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3643        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3644        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3645        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3646        boxes_f64
3647            .map()
3648            .unwrap()
3649            .as_mut_slice()
3650            .copy_from_slice(&raw_boxes_f64);
3651        let boxes_f64 = boxes_f64.into();
3652
3653        let raw_protos = include_bytes!(concat!(
3654            env!("CARGO_MANIFEST_DIR"),
3655            "/../../testdata/yolov8_protos_160x160x32.bin"
3656        ));
3657        let raw_protos = unsafe {
3658            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3659        };
3660        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3661        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3662        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3663        protos_f64
3664            .map()
3665            .unwrap()
3666            .as_mut_slice()
3667            .copy_from_slice(&raw_protos_f64);
3668        let protos_f64 = protos_f64.into();
3669
3670        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3671
3672        let expected = real_data_expected_boxes();
3673        let mut output_boxes = Vec::with_capacity(50);
3674
3675        decoder
3676            .decode(
3677                &[&boxes_f64, &protos_f64],
3678                &mut output_boxes,
3679                &mut Vec::new(),
3680            )
3681            .unwrap();
3682
3683        assert_eq!(output_boxes.len(), 2);
3684        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3685        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3686    }
3687
3688    #[test]
3689    fn test_decode_tensor_proto() {
3690        let score_threshold = 0.45;
3691        let iou_threshold = 0.45;
3692
3693        let raw_boxes = include_bytes!(concat!(
3694            env!("CARGO_MANIFEST_DIR"),
3695            "/../../testdata/yolov8_boxes_116x8400.bin"
3696        ));
3697        let raw_boxes =
3698            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3699        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3700        boxes_i8
3701            .map()
3702            .unwrap()
3703            .as_mut_slice()
3704            .copy_from_slice(raw_boxes);
3705        let boxes_i8 = boxes_i8.into();
3706
3707        let raw_protos = include_bytes!(concat!(
3708            env!("CARGO_MANIFEST_DIR"),
3709            "/../../testdata/yolov8_protos_160x160x32.bin"
3710        ));
3711        let raw_protos = unsafe {
3712            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3713        };
3714        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3715        protos_i8
3716            .map()
3717            .unwrap()
3718            .as_mut_slice()
3719            .copy_from_slice(raw_protos);
3720        let protos_i8 = protos_i8.into();
3721
3722        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3723
3724        let expected = real_data_expected_boxes();
3725        let mut output_boxes = Vec::with_capacity(50);
3726
3727        let proto_data = decoder
3728            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3729            .unwrap();
3730
3731        assert_eq!(output_boxes.len(), 2);
3732        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3733        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3734
3735        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3736        assert_eq!(
3737            proto_data.mask_coefficients.len(),
3738            output_boxes.len(),
3739            "mask_coefficients count must match detection count"
3740        );
3741        for coeff in &proto_data.mask_coefficients {
3742            assert_eq!(
3743                coeff.len(),
3744                32,
3745                "each detection should have 32 mask coefficients"
3746            );
3747        }
3748    }
3749}
3750
3751#[cfg(feature = "tracker")]
3752#[cfg(test)]
3753#[cfg_attr(coverage_nightly, coverage(off))]
3754mod decoder_tracked_tests {
3755
3756    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
3757    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
3758    use num_traits::{AsPrimitive, Float, PrimInt};
3759    use rand::{RngExt, SeedableRng};
3760    use rand_distr::StandardNormal;
3761
3762    use crate::{
3763        configs::{self, DimName},
3764        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
3765    };
3766
3767    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
3768        input: ArrayView<F, D>,
3769        quant: Quantization,
3770    ) -> Array<T, D>
3771    where
3772        i32: num_traits::AsPrimitive<F>,
3773        f32: num_traits::AsPrimitive<F>,
3774    {
3775        let zero_point = quant.zero_point.as_();
3776        let div_scale = F::one() / quant.scale.as_();
3777        if zero_point != F::zero() {
3778            input.mapv(|d| (d * div_scale + zero_point).round().as_())
3779        } else {
3780            input.mapv(|d| (d * div_scale).round().as_())
3781        }
3782    }
3783
3784    #[test]
3785    fn test_decoder_tracked_random_jitter() {
3786        use crate::configs::{DecoderType, Nms};
3787        use crate::DecoderBuilder;
3788
3789        let score_threshold = 0.25;
3790        let iou_threshold = 0.1;
3791        let out = include_bytes!(concat!(
3792            env!("CARGO_MANIFEST_DIR"),
3793            "/../../testdata/yolov8s_80_classes.bin"
3794        ));
3795        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
3796        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
3797        let quant = (0.0040811873, -123).into();
3798
3799        let decoder = DecoderBuilder::default()
3800            .with_config_yolo_det(
3801                crate::configs::Detection {
3802                    decoder: DecoderType::Ultralytics,
3803                    shape: vec![1, 84, 8400],
3804                    anchors: None,
3805                    quantization: Some(quant),
3806                    dshape: vec![
3807                        (crate::configs::DimName::Batch, 1),
3808                        (crate::configs::DimName::NumFeatures, 84),
3809                        (crate::configs::DimName::NumBoxes, 8400),
3810                    ],
3811                    normalized: Some(true),
3812                },
3813                None,
3814            )
3815            .with_score_threshold(score_threshold)
3816            .with_iou_threshold(iou_threshold)
3817            .with_nms(Some(Nms::ClassAgnostic))
3818            .build()
3819            .unwrap();
3820        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
3821
3822        let expected_boxes = [
3823            crate::DetectBox {
3824                bbox: crate::BoundingBox {
3825                    xmin: 0.5285137,
3826                    ymin: 0.05305544,
3827                    xmax: 0.87541467,
3828                    ymax: 0.9998909,
3829                },
3830                score: 0.5591227,
3831                label: 0,
3832            },
3833            crate::DetectBox {
3834                bbox: crate::BoundingBox {
3835                    xmin: 0.130598,
3836                    ymin: 0.43260583,
3837                    xmax: 0.35098213,
3838                    ymax: 0.9958097,
3839                },
3840                score: 0.33057618,
3841                label: 75,
3842            },
3843        ];
3844
3845        let mut tracker = ByteTrackBuilder::new()
3846            .track_update(0.1)
3847            .track_high_conf(0.3)
3848            .build();
3849
3850        let mut output_boxes = Vec::with_capacity(50);
3851        let mut output_masks = Vec::with_capacity(50);
3852        let mut output_tracks = Vec::with_capacity(50);
3853
3854        decoder
3855            .decode_tracked_quantized(
3856                &mut tracker,
3857                0,
3858                &[out.view().into()],
3859                &mut output_boxes,
3860                &mut output_masks,
3861                &mut output_tracks,
3862            )
3863            .unwrap();
3864
3865        assert_eq!(output_boxes.len(), 2);
3866        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3867        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3868
3869        let mut last_boxes = output_boxes.clone();
3870
3871        for i in 1..=100 {
3872            let mut out = out.clone();
3873            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
3874            let mut x_values = out.slice_mut(s![0, 0, ..]);
3875            for x in x_values.iter_mut() {
3876                let r: f32 = rng.sample(StandardNormal);
3877                let r = r.clamp(-2.0, 2.0) / 2.0;
3878                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
3879            }
3880
3881            let mut y_values = out.slice_mut(s![0, 1, ..]);
3882            for y in y_values.iter_mut() {
3883                let r: f32 = rng.sample(StandardNormal);
3884                let r = r.clamp(-2.0, 2.0) / 2.0;
3885                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
3886            }
3887
3888            decoder
3889                .decode_tracked_quantized(
3890                    &mut tracker,
3891                    100_000_000 * i / 3, // simulate 33.333ms between frames
3892                    &[out.view().into()],
3893                    &mut output_boxes,
3894                    &mut output_masks,
3895                    &mut output_tracks,
3896                )
3897                .unwrap();
3898
3899            assert_eq!(output_boxes.len(), 2);
3900            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
3901            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
3902
3903            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
3904            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
3905            last_boxes = output_boxes.clone();
3906        }
3907    }
3908
3909    // ─── Shared helpers for tracked decoder tests ────────────────────
3910
3911    fn real_data_expected_boxes() -> [DetectBox; 2] {
3912        [
3913            DetectBox {
3914                bbox: BoundingBox {
3915                    xmin: 0.08515105,
3916                    ymin: 0.7131401,
3917                    xmax: 0.29802868,
3918                    ymax: 0.8195788,
3919                },
3920                score: 0.91537374,
3921                label: 23,
3922            },
3923            DetectBox {
3924                bbox: BoundingBox {
3925                    xmin: 0.59605736,
3926                    ymin: 0.25545314,
3927                    xmax: 0.93666154,
3928                    ymax: 0.72378385,
3929                },
3930                score: 0.91537374,
3931                label: 23,
3932            },
3933        ]
3934    }
3935
3936    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
3937        [DetectBox {
3938            bbox: BoundingBox {
3939                xmin: 0.12549022,
3940                ymin: 0.12549022,
3941                xmax: 0.23529413,
3942                ymax: 0.23529413,
3943            },
3944            score: 0.98823535,
3945            label: 2,
3946        }]
3947    }
3948
3949    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
3950        [DetectBox {
3951            bbox: BoundingBox {
3952                xmin: 0.1234,
3953                ymin: 0.1234,
3954                xmax: 0.2345,
3955                ymax: 0.2345,
3956            },
3957            score: 0.9876,
3958            label: 2,
3959        }]
3960    }
3961
3962    fn build_yolo_split_segdet_decoder(
3963        score_threshold: f32,
3964        iou_threshold: f32,
3965        quant_boxes: (f32, i32),
3966        quant_protos: (f32, i32),
3967    ) -> crate::Decoder {
3968        DecoderBuilder::default()
3969            .with_config_yolo_split_segdet(
3970                configs::Boxes {
3971                    decoder: configs::DecoderType::Ultralytics,
3972                    quantization: Some(quant_boxes.into()),
3973                    shape: vec![1, 4, 8400],
3974                    dshape: vec![
3975                        (DimName::Batch, 1),
3976                        (DimName::BoxCoords, 4),
3977                        (DimName::NumBoxes, 8400),
3978                    ],
3979                    normalized: Some(true),
3980                },
3981                configs::Scores {
3982                    decoder: configs::DecoderType::Ultralytics,
3983                    quantization: Some(quant_boxes.into()),
3984                    shape: vec![1, 80, 8400],
3985                    dshape: vec![
3986                        (DimName::Batch, 1),
3987                        (DimName::NumClasses, 80),
3988                        (DimName::NumBoxes, 8400),
3989                    ],
3990                },
3991                configs::MaskCoefficients {
3992                    decoder: configs::DecoderType::Ultralytics,
3993                    quantization: Some(quant_boxes.into()),
3994                    shape: vec![1, 32, 8400],
3995                    dshape: vec![
3996                        (DimName::Batch, 1),
3997                        (DimName::NumProtos, 32),
3998                        (DimName::NumBoxes, 8400),
3999                    ],
4000                },
4001                configs::Protos {
4002                    decoder: configs::DecoderType::Ultralytics,
4003                    quantization: Some(quant_protos.into()),
4004                    shape: vec![1, 160, 160, 32],
4005                    dshape: vec![
4006                        (DimName::Batch, 1),
4007                        (DimName::Height, 160),
4008                        (DimName::Width, 160),
4009                        (DimName::NumProtos, 32),
4010                    ],
4011                },
4012            )
4013            .with_score_threshold(score_threshold)
4014            .with_iou_threshold(iou_threshold)
4015            .build()
4016            .unwrap()
4017    }
4018
4019    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4020        let config_yaml = include_str!(concat!(
4021            env!("CARGO_MANIFEST_DIR"),
4022            "/../../testdata/yolov8_seg.yaml"
4023        ));
4024        DecoderBuilder::default()
4025            .with_config_yaml_str(config_yaml.to_string())
4026            .with_score_threshold(score_threshold)
4027            .with_iou_threshold(iou_threshold)
4028            .build()
4029            .unwrap()
4030    }
4031
4032    // ─── Real-data tracked test macro ───────────────────────────────
4033    //
4034    // Generates tests that load i8 binary test data from testdata/ and
4035    // exercise all (quant/float) × (combined/split) × (masks/proto)
4036    // decoder paths.
4037
4038    macro_rules! real_data_tracked_test {
4039        ($name:ident, quantized, $layout:ident, $output:ident) => {
4040            #[test]
4041            fn $name() {
4042                let is_split = matches!(stringify!($layout), "split");
4043                let is_proto = matches!(stringify!($output), "proto");
4044
4045                let score_threshold = 0.45;
4046                let iou_threshold = 0.45;
4047                let quant_boxes = (0.021287762_f32, 31_i32);
4048                let quant_protos = (0.02491162_f32, -117_i32);
4049
4050                let raw_boxes = include_bytes!(concat!(
4051                    env!("CARGO_MANIFEST_DIR"),
4052                    "/../../testdata/yolov8_boxes_116x8400.bin"
4053                ));
4054                let raw_boxes = unsafe {
4055                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4056                };
4057                let boxes_i8 =
4058                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4059
4060                let raw_protos = include_bytes!(concat!(
4061                    env!("CARGO_MANIFEST_DIR"),
4062                    "/../../testdata/yolov8_protos_160x160x32.bin"
4063                ));
4064                let raw_protos = unsafe {
4065                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4066                };
4067                let protos_i8 =
4068                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4069                        .unwrap();
4070
4071                // Pre-split (unused for combined, but harmless)
4072                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4073                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4074                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4075                let mut boxes_combined = boxes_i8;
4076
4077                let decoder = if is_split {
4078                    build_yolo_split_segdet_decoder(
4079                        score_threshold,
4080                        iou_threshold,
4081                        quant_boxes,
4082                        quant_protos,
4083                    )
4084                } else {
4085                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4086                };
4087
4088                let expected = real_data_expected_boxes();
4089                let mut tracker = ByteTrackBuilder::new()
4090                    .track_update(0.1)
4091                    .track_high_conf(0.7)
4092                    .build();
4093                let mut output_boxes = Vec::with_capacity(50);
4094                let mut output_tracks = Vec::with_capacity(50);
4095
4096                // Frame 1: decode
4097                if is_proto {
4098                    {
4099                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4100                            vec![
4101                                boxes_split.view().into(),
4102                                scores_split.view().into(),
4103                                mask_split.view().into(),
4104                                protos_i8.view().into(),
4105                            ]
4106                        } else {
4107                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4108                        };
4109                        decoder
4110                            .decode_tracked_quantized_proto(
4111                                &mut tracker,
4112                                0,
4113                                &inputs,
4114                                &mut output_boxes,
4115                                &mut output_tracks,
4116                            )
4117                            .unwrap();
4118                    }
4119                    assert_eq!(output_boxes.len(), 2);
4120                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4121                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4122
4123                    // Zero scores for frame 2
4124                    if is_split {
4125                        for score in scores_split.iter_mut() {
4126                            *score = i8::MIN;
4127                        }
4128                    } else {
4129                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4130                            *score = i8::MIN;
4131                        }
4132                    }
4133
4134                    let proto_result = {
4135                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4136                            vec![
4137                                boxes_split.view().into(),
4138                                scores_split.view().into(),
4139                                mask_split.view().into(),
4140                                protos_i8.view().into(),
4141                            ]
4142                        } else {
4143                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4144                        };
4145                        decoder
4146                            .decode_tracked_quantized_proto(
4147                                &mut tracker,
4148                                100_000_000 / 3,
4149                                &inputs,
4150                                &mut output_boxes,
4151                                &mut output_tracks,
4152                            )
4153                            .unwrap()
4154                    };
4155                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4156                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4157                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4158                } else {
4159                    let mut output_masks = Vec::with_capacity(50);
4160                    {
4161                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4162                            vec![
4163                                boxes_split.view().into(),
4164                                scores_split.view().into(),
4165                                mask_split.view().into(),
4166                                protos_i8.view().into(),
4167                            ]
4168                        } else {
4169                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4170                        };
4171                        decoder
4172                            .decode_tracked_quantized(
4173                                &mut tracker,
4174                                0,
4175                                &inputs,
4176                                &mut output_boxes,
4177                                &mut output_masks,
4178                                &mut output_tracks,
4179                            )
4180                            .unwrap();
4181                    }
4182                    assert_eq!(output_boxes.len(), 2);
4183                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4184                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4185
4186                    if is_split {
4187                        for score in scores_split.iter_mut() {
4188                            *score = i8::MIN;
4189                        }
4190                    } else {
4191                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4192                            *score = i8::MIN;
4193                        }
4194                    }
4195
4196                    {
4197                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4198                            vec![
4199                                boxes_split.view().into(),
4200                                scores_split.view().into(),
4201                                mask_split.view().into(),
4202                                protos_i8.view().into(),
4203                            ]
4204                        } else {
4205                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4206                        };
4207                        decoder
4208                            .decode_tracked_quantized(
4209                                &mut tracker,
4210                                100_000_000 / 3,
4211                                &inputs,
4212                                &mut output_boxes,
4213                                &mut output_masks,
4214                                &mut output_tracks,
4215                            )
4216                            .unwrap();
4217                    }
4218                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4219                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4220                    assert!(output_masks.is_empty());
4221                }
4222            }
4223        };
4224        ($name:ident, float, $layout:ident, $output:ident) => {
4225            #[test]
4226            fn $name() {
4227                let is_split = matches!(stringify!($layout), "split");
4228                let is_proto = matches!(stringify!($output), "proto");
4229
4230                let score_threshold = 0.45;
4231                let iou_threshold = 0.45;
4232                let quant_boxes = (0.021287762_f32, 31_i32);
4233                let quant_protos = (0.02491162_f32, -117_i32);
4234
4235                let raw_boxes = include_bytes!(concat!(
4236                    env!("CARGO_MANIFEST_DIR"),
4237                    "/../../testdata/yolov8_boxes_116x8400.bin"
4238                ));
4239                let raw_boxes = unsafe {
4240                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4241                };
4242                let boxes_i8 =
4243                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4244                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4245
4246                let raw_protos = include_bytes!(concat!(
4247                    env!("CARGO_MANIFEST_DIR"),
4248                    "/../../testdata/yolov8_protos_160x160x32.bin"
4249                ));
4250                let raw_protos = unsafe {
4251                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4252                };
4253                let protos_i8 =
4254                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4255                        .unwrap();
4256                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4257
4258                // Pre-split from dequantized data
4259                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4260                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4261                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4262                let mut boxes_combined = boxes_f32;
4263
4264                let decoder = if is_split {
4265                    build_yolo_split_segdet_decoder(
4266                        score_threshold,
4267                        iou_threshold,
4268                        quant_boxes,
4269                        quant_protos,
4270                    )
4271                } else {
4272                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4273                };
4274
4275                let expected = real_data_expected_boxes();
4276                let mut tracker = ByteTrackBuilder::new()
4277                    .track_update(0.1)
4278                    .track_high_conf(0.7)
4279                    .build();
4280                let mut output_boxes = Vec::with_capacity(50);
4281                let mut output_tracks = Vec::with_capacity(50);
4282
4283                if is_proto {
4284                    {
4285                        let inputs = if is_split {
4286                            vec![
4287                                boxes_split.view().into_dyn(),
4288                                scores_split.view().into_dyn(),
4289                                mask_split.view().into_dyn(),
4290                                protos_f32.view().into_dyn(),
4291                            ]
4292                        } else {
4293                            vec![
4294                                boxes_combined.view().into_dyn(),
4295                                protos_f32.view().into_dyn(),
4296                            ]
4297                        };
4298                        decoder
4299                            .decode_tracked_float_proto(
4300                                &mut tracker,
4301                                0,
4302                                &inputs,
4303                                &mut output_boxes,
4304                                &mut output_tracks,
4305                            )
4306                            .unwrap();
4307                    }
4308                    assert_eq!(output_boxes.len(), 2);
4309                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4310                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4311
4312                    if is_split {
4313                        for score in scores_split.iter_mut() {
4314                            *score = 0.0;
4315                        }
4316                    } else {
4317                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4318                            *score = 0.0;
4319                        }
4320                    }
4321
4322                    let proto_result = {
4323                        let inputs = if is_split {
4324                            vec![
4325                                boxes_split.view().into_dyn(),
4326                                scores_split.view().into_dyn(),
4327                                mask_split.view().into_dyn(),
4328                                protos_f32.view().into_dyn(),
4329                            ]
4330                        } else {
4331                            vec![
4332                                boxes_combined.view().into_dyn(),
4333                                protos_f32.view().into_dyn(),
4334                            ]
4335                        };
4336                        decoder
4337                            .decode_tracked_float_proto(
4338                                &mut tracker,
4339                                100_000_000 / 3,
4340                                &inputs,
4341                                &mut output_boxes,
4342                                &mut output_tracks,
4343                            )
4344                            .unwrap()
4345                    };
4346                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4347                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4348                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4349                } else {
4350                    let mut output_masks = Vec::with_capacity(50);
4351                    {
4352                        let inputs = if is_split {
4353                            vec![
4354                                boxes_split.view().into_dyn(),
4355                                scores_split.view().into_dyn(),
4356                                mask_split.view().into_dyn(),
4357                                protos_f32.view().into_dyn(),
4358                            ]
4359                        } else {
4360                            vec![
4361                                boxes_combined.view().into_dyn(),
4362                                protos_f32.view().into_dyn(),
4363                            ]
4364                        };
4365                        decoder
4366                            .decode_tracked_float(
4367                                &mut tracker,
4368                                0,
4369                                &inputs,
4370                                &mut output_boxes,
4371                                &mut output_masks,
4372                                &mut output_tracks,
4373                            )
4374                            .unwrap();
4375                    }
4376                    assert_eq!(output_boxes.len(), 2);
4377                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4378                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4379
4380                    if is_split {
4381                        for score in scores_split.iter_mut() {
4382                            *score = 0.0;
4383                        }
4384                    } else {
4385                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4386                            *score = 0.0;
4387                        }
4388                    }
4389
4390                    {
4391                        let inputs = if is_split {
4392                            vec![
4393                                boxes_split.view().into_dyn(),
4394                                scores_split.view().into_dyn(),
4395                                mask_split.view().into_dyn(),
4396                                protos_f32.view().into_dyn(),
4397                            ]
4398                        } else {
4399                            vec![
4400                                boxes_combined.view().into_dyn(),
4401                                protos_f32.view().into_dyn(),
4402                            ]
4403                        };
4404                        decoder
4405                            .decode_tracked_float(
4406                                &mut tracker,
4407                                100_000_000 / 3,
4408                                &inputs,
4409                                &mut output_boxes,
4410                                &mut output_masks,
4411                                &mut output_tracks,
4412                            )
4413                            .unwrap();
4414                    }
4415                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4416                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4417                    assert!(output_masks.is_empty());
4418                }
4419            }
4420        };
4421    }
4422
4423    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4424    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4425    real_data_tracked_test!(
4426        test_decoder_tracked_segdet_proto,
4427        quantized,
4428        combined,
4429        proto
4430    );
4431    real_data_tracked_test!(
4432        test_decoder_tracked_segdet_proto_float,
4433        float,
4434        combined,
4435        proto
4436    );
4437    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4438    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4439    real_data_tracked_test!(
4440        test_decoder_tracked_segdet_split_proto,
4441        quantized,
4442        split,
4443        proto
4444    );
4445    real_data_tracked_test!(
4446        test_decoder_tracked_segdet_split_proto_float,
4447        float,
4448        split,
4449        proto
4450    );
4451
    // ─── End-to-end tracked test macro ──────────────────────────────
    //
    // Generates tests with synthetic data to exercise all tracked
    // decode paths without needing real model output files.

    /// Decoder config (YAML) for the combined-layout end-to-end tests.
    ///
    /// It declares two outputs:
    /// - `detection` `[1, 10, 38]`: 10 boxes × 38 features. The features are
    ///   4 box coordinates, 1 score, 1 class and 32 mask coefficients, matching
    ///   the `concatenate!` built in `e2e_tracked_test`. The scale is 2/255 with
    ///   zero point 0, the same as that test's `detect_quant`.
    /// - `protos` `[1, 160, 160, 32]`, with scale 1/255 and zero point 128.
    ///
    /// NOTE(review): `e2e_tracked_test` quantizes its all-zero synthetic protos
    /// with zero point 0, not 128. This is presumably harmless because its
    /// mask coefficients are also all zero. Confirm it is intended.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
4479
    /// Decoder config (YAML) for the split-layout end-to-end tests.
    ///
    /// It declares five outputs, in the same order `e2e_tracked_test` passes
    /// its inputs:
    /// - `boxes` `[1, 10, 4]`
    /// - `scores` `[1, 10, 1]`
    /// - `classes` `[1, 10, 1]`
    /// - `mask_coefficients` `[1, 10, 32]`
    /// - `protos` `[1, 160, 160, 32]`
    ///
    /// All non-proto outputs share scale 2/255 with zero point 0. Protos use
    /// scale 1/255 with zero point 128, the same as `E2E_COMBINED_CONFIG`.
    ///
    /// NOTE(review): the `classes` output labels its last axis `num_classes`,
    /// but the test stores a class index there (value 2.0) rather than
    /// per-class data. Confirm the decoder interprets it as intended.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
4526
4527    macro_rules! e2e_tracked_test {
4528        ($name:ident, quantized, $layout:ident, $output:ident) => {
4529            #[test]
4530            fn $name() {
4531                let is_split = matches!(stringify!($layout), "split");
4532                let is_proto = matches!(stringify!($output), "proto");
4533
4534                let score_threshold = 0.45;
4535                let iou_threshold = 0.45;
4536
4537                let mut boxes = Array2::zeros((10, 4));
4538                let mut scores = Array2::zeros((10, 1));
4539                let mut classes = Array2::zeros((10, 1));
4540                let mask = Array2::zeros((10, 32));
4541                let protos = Array3::<f64>::zeros((160, 160, 32));
4542                let protos = protos.insert_axis(Axis(0));
4543                let protos_quant = (1.0 / 255.0, 0.0);
4544                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4545
4546                boxes
4547                    .slice_mut(s![0, ..])
4548                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4549                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4550                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4551
4552                let detect_quant = (2.0 / 255.0, 0.0);
4553
4554                let decoder = if is_split {
4555                    DecoderBuilder::default()
4556                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4557                        .with_score_threshold(score_threshold)
4558                        .with_iou_threshold(iou_threshold)
4559                        .build()
4560                        .unwrap()
4561                } else {
4562                    DecoderBuilder::default()
4563                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4564                        .with_score_threshold(score_threshold)
4565                        .with_iou_threshold(iou_threshold)
4566                        .build()
4567                        .unwrap()
4568                };
4569
4570                let expected = e2e_expected_boxes_quant();
4571                let mut tracker = ByteTrackBuilder::new()
4572                    .track_update(0.1)
4573                    .track_high_conf(0.7)
4574                    .build();
4575                let mut output_boxes = Vec::with_capacity(50);
4576                let mut output_tracks = Vec::with_capacity(50);
4577
4578                if is_split {
4579                    let boxes = boxes.insert_axis(Axis(0));
4580                    let scores = scores.insert_axis(Axis(0));
4581                    let classes = classes.insert_axis(Axis(0));
4582                    let mask = mask.insert_axis(Axis(0));
4583
4584                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4585                    let mut scores: Array3<u8> =
4586                        quantize_ndarray(scores.view(), detect_quant.into());
4587                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4588                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4589
4590                    if is_proto {
4591                        {
4592                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4593                                boxes.view().into(),
4594                                scores.view().into(),
4595                                classes.view().into(),
4596                                mask.view().into(),
4597                                protos.view().into(),
4598                            ];
4599                            decoder
4600                                .decode_tracked_quantized_proto(
4601                                    &mut tracker,
4602                                    0,
4603                                    &inputs,
4604                                    &mut output_boxes,
4605                                    &mut output_tracks,
4606                                )
4607                                .unwrap();
4608                        }
4609                        assert_eq!(output_boxes.len(), 1);
4610                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4611
4612                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4613                            *score = u8::MIN;
4614                        }
4615                        let proto_result = {
4616                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4617                                boxes.view().into(),
4618                                scores.view().into(),
4619                                classes.view().into(),
4620                                mask.view().into(),
4621                                protos.view().into(),
4622                            ];
4623                            decoder
4624                                .decode_tracked_quantized_proto(
4625                                    &mut tracker,
4626                                    100_000_000 / 3,
4627                                    &inputs,
4628                                    &mut output_boxes,
4629                                    &mut output_tracks,
4630                                )
4631                                .unwrap()
4632                        };
4633                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4634                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4635                    } else {
4636                        let mut output_masks = Vec::with_capacity(50);
4637                        {
4638                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4639                                boxes.view().into(),
4640                                scores.view().into(),
4641                                classes.view().into(),
4642                                mask.view().into(),
4643                                protos.view().into(),
4644                            ];
4645                            decoder
4646                                .decode_tracked_quantized(
4647                                    &mut tracker,
4648                                    0,
4649                                    &inputs,
4650                                    &mut output_boxes,
4651                                    &mut output_masks,
4652                                    &mut output_tracks,
4653                                )
4654                                .unwrap();
4655                        }
4656                        assert_eq!(output_boxes.len(), 1);
4657                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4658
4659                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4660                            *score = u8::MIN;
4661                        }
4662                        {
4663                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
4664                                boxes.view().into(),
4665                                scores.view().into(),
4666                                classes.view().into(),
4667                                mask.view().into(),
4668                                protos.view().into(),
4669                            ];
4670                            decoder
4671                                .decode_tracked_quantized(
4672                                    &mut tracker,
4673                                    100_000_000 / 3,
4674                                    &inputs,
4675                                    &mut output_boxes,
4676                                    &mut output_masks,
4677                                    &mut output_tracks,
4678                                )
4679                                .unwrap();
4680                        }
4681                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4682                        assert!(output_masks.is_empty());
4683                    }
4684                } else {
4685                    // Combined layout
4686                    let detect = ndarray::concatenate![
4687                        Axis(1),
4688                        boxes.view(),
4689                        scores.view(),
4690                        classes.view(),
4691                        mask.view()
4692                    ];
4693                    let detect = detect.insert_axis(Axis(0));
4694                    assert_eq!(detect.shape(), &[1, 10, 38]);
4695                    let mut detect: Array3<u8> =
4696                        quantize_ndarray(detect.view(), detect_quant.into());
4697
4698                    if is_proto {
4699                        {
4700                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4701                                vec![detect.view().into(), protos.view().into()];
4702                            decoder
4703                                .decode_tracked_quantized_proto(
4704                                    &mut tracker,
4705                                    0,
4706                                    &inputs,
4707                                    &mut output_boxes,
4708                                    &mut output_tracks,
4709                                )
4710                                .unwrap();
4711                        }
4712                        assert_eq!(output_boxes.len(), 1);
4713                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4714
4715                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4716                            *score = u8::MIN;
4717                        }
4718                        let proto_result = {
4719                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4720                                vec![detect.view().into(), protos.view().into()];
4721                            decoder
4722                                .decode_tracked_quantized_proto(
4723                                    &mut tracker,
4724                                    100_000_000 / 3,
4725                                    &inputs,
4726                                    &mut output_boxes,
4727                                    &mut output_tracks,
4728                                )
4729                                .unwrap()
4730                        };
4731                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4732                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4733                    } else {
4734                        let mut output_masks = Vec::with_capacity(50);
4735                        {
4736                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4737                                vec![detect.view().into(), protos.view().into()];
4738                            decoder
4739                                .decode_tracked_quantized(
4740                                    &mut tracker,
4741                                    0,
4742                                    &inputs,
4743                                    &mut output_boxes,
4744                                    &mut output_masks,
4745                                    &mut output_tracks,
4746                                )
4747                                .unwrap();
4748                        }
4749                        assert_eq!(output_boxes.len(), 1);
4750                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4751
4752                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4753                            *score = u8::MIN;
4754                        }
4755                        {
4756                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
4757                                vec![detect.view().into(), protos.view().into()];
4758                            decoder
4759                                .decode_tracked_quantized(
4760                                    &mut tracker,
4761                                    100_000_000 / 3,
4762                                    &inputs,
4763                                    &mut output_boxes,
4764                                    &mut output_masks,
4765                                    &mut output_tracks,
4766                                )
4767                                .unwrap();
4768                        }
4769                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4770                        assert!(output_masks.is_empty());
4771                    }
4772                }
4773            }
4774        };
4775        ($name:ident, float, $layout:ident, $output:ident) => {
4776            #[test]
4777            fn $name() {
4778                let is_split = matches!(stringify!($layout), "split");
4779                let is_proto = matches!(stringify!($output), "proto");
4780
4781                let score_threshold = 0.45;
4782                let iou_threshold = 0.45;
4783
4784                let mut boxes = Array2::zeros((10, 4));
4785                let mut scores = Array2::zeros((10, 1));
4786                let mut classes = Array2::zeros((10, 1));
4787                let mask: Array2<f64> = Array2::zeros((10, 32));
4788                let protos = Array3::<f64>::zeros((160, 160, 32));
4789                let protos = protos.insert_axis(Axis(0));
4790
4791                boxes
4792                    .slice_mut(s![0, ..])
4793                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4794                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4795                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4796
4797                let decoder = if is_split {
4798                    DecoderBuilder::default()
4799                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
4800                        .with_score_threshold(score_threshold)
4801                        .with_iou_threshold(iou_threshold)
4802                        .build()
4803                        .unwrap()
4804                } else {
4805                    DecoderBuilder::default()
4806                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
4807                        .with_score_threshold(score_threshold)
4808                        .with_iou_threshold(iou_threshold)
4809                        .build()
4810                        .unwrap()
4811                };
4812
4813                let expected = e2e_expected_boxes_float();
4814                let mut tracker = ByteTrackBuilder::new()
4815                    .track_update(0.1)
4816                    .track_high_conf(0.7)
4817                    .build();
4818                let mut output_boxes = Vec::with_capacity(50);
4819                let mut output_tracks = Vec::with_capacity(50);
4820
4821                if is_split {
4822                    let boxes = boxes.insert_axis(Axis(0));
4823                    let mut scores = scores.insert_axis(Axis(0));
4824                    let classes = classes.insert_axis(Axis(0));
4825                    let mask = mask.insert_axis(Axis(0));
4826
4827                    if is_proto {
4828                        {
4829                            let inputs = vec![
4830                                boxes.view().into_dyn(),
4831                                scores.view().into_dyn(),
4832                                classes.view().into_dyn(),
4833                                mask.view().into_dyn(),
4834                                protos.view().into_dyn(),
4835                            ];
4836                            decoder
4837                                .decode_tracked_float_proto(
4838                                    &mut tracker,
4839                                    0,
4840                                    &inputs,
4841                                    &mut output_boxes,
4842                                    &mut output_tracks,
4843                                )
4844                                .unwrap();
4845                        }
4846                        assert_eq!(output_boxes.len(), 1);
4847                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4848
4849                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4850                            *score = 0.0;
4851                        }
4852                        let proto_result = {
4853                            let inputs = vec![
4854                                boxes.view().into_dyn(),
4855                                scores.view().into_dyn(),
4856                                classes.view().into_dyn(),
4857                                mask.view().into_dyn(),
4858                                protos.view().into_dyn(),
4859                            ];
4860                            decoder
4861                                .decode_tracked_float_proto(
4862                                    &mut tracker,
4863                                    100_000_000 / 3,
4864                                    &inputs,
4865                                    &mut output_boxes,
4866                                    &mut output_tracks,
4867                                )
4868                                .unwrap()
4869                        };
4870                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4871                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4872                    } else {
4873                        let mut output_masks = Vec::with_capacity(50);
4874                        {
4875                            let inputs = vec![
4876                                boxes.view().into_dyn(),
4877                                scores.view().into_dyn(),
4878                                classes.view().into_dyn(),
4879                                mask.view().into_dyn(),
4880                                protos.view().into_dyn(),
4881                            ];
4882                            decoder
4883                                .decode_tracked_float(
4884                                    &mut tracker,
4885                                    0,
4886                                    &inputs,
4887                                    &mut output_boxes,
4888                                    &mut output_masks,
4889                                    &mut output_tracks,
4890                                )
4891                                .unwrap();
4892                        }
4893                        assert_eq!(output_boxes.len(), 1);
4894                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4895
4896                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4897                            *score = 0.0;
4898                        }
4899                        {
4900                            let inputs = vec![
4901                                boxes.view().into_dyn(),
4902                                scores.view().into_dyn(),
4903                                classes.view().into_dyn(),
4904                                mask.view().into_dyn(),
4905                                protos.view().into_dyn(),
4906                            ];
4907                            decoder
4908                                .decode_tracked_float(
4909                                    &mut tracker,
4910                                    100_000_000 / 3,
4911                                    &inputs,
4912                                    &mut output_boxes,
4913                                    &mut output_masks,
4914                                    &mut output_tracks,
4915                                )
4916                                .unwrap();
4917                        }
4918                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4919                        assert!(output_masks.is_empty());
4920                    }
4921                } else {
4922                    // Combined layout
4923                    let detect = ndarray::concatenate![
4924                        Axis(1),
4925                        boxes.view(),
4926                        scores.view(),
4927                        classes.view(),
4928                        mask.view()
4929                    ];
4930                    let mut detect = detect.insert_axis(Axis(0));
4931                    assert_eq!(detect.shape(), &[1, 10, 38]);
4932
4933                    if is_proto {
4934                        {
4935                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4936                            decoder
4937                                .decode_tracked_float_proto(
4938                                    &mut tracker,
4939                                    0,
4940                                    &inputs,
4941                                    &mut output_boxes,
4942                                    &mut output_tracks,
4943                                )
4944                                .unwrap();
4945                        }
4946                        assert_eq!(output_boxes.len(), 1);
4947                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4948
4949                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4950                            *score = 0.0;
4951                        }
4952                        let proto_result = {
4953                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4954                            decoder
4955                                .decode_tracked_float_proto(
4956                                    &mut tracker,
4957                                    100_000_000 / 3,
4958                                    &inputs,
4959                                    &mut output_boxes,
4960                                    &mut output_tracks,
4961                                )
4962                                .unwrap()
4963                        };
4964                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4965                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
4966                    } else {
4967                        let mut output_masks = Vec::with_capacity(50);
4968                        {
4969                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4970                            decoder
4971                                .decode_tracked_float(
4972                                    &mut tracker,
4973                                    0,
4974                                    &inputs,
4975                                    &mut output_boxes,
4976                                    &mut output_masks,
4977                                    &mut output_tracks,
4978                                )
4979                                .unwrap();
4980                        }
4981                        assert_eq!(output_boxes.len(), 1);
4982                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4983
4984                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4985                            *score = 0.0;
4986                        }
4987                        {
4988                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
4989                            decoder
4990                                .decode_tracked_float(
4991                                    &mut tracker,
4992                                    100_000_000 / 3,
4993                                    &inputs,
4994                                    &mut output_boxes,
4995                                    &mut output_masks,
4996                                    &mut output_tracks,
4997                                )
4998                                .unwrap();
4999                        }
5000                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5001                        assert!(output_masks.is_empty());
5002                    }
5003                }
5004            }
5005        };
5006    }
5007
5008    e2e_tracked_test!(
5009        test_decoder_tracked_end_to_end_segdet,
5010        quantized,
5011        combined,
5012        masks
5013    );
5014    e2e_tracked_test!(
5015        test_decoder_tracked_end_to_end_segdet_float,
5016        float,
5017        combined,
5018        masks
5019    );
5020    e2e_tracked_test!(
5021        test_decoder_tracked_end_to_end_segdet_proto,
5022        quantized,
5023        combined,
5024        proto
5025    );
5026    e2e_tracked_test!(
5027        test_decoder_tracked_end_to_end_segdet_proto_float,
5028        float,
5029        combined,
5030        proto
5031    );
5032    e2e_tracked_test!(
5033        test_decoder_tracked_end_to_end_segdet_split,
5034        quantized,
5035        split,
5036        masks
5037    );
5038    e2e_tracked_test!(
5039        test_decoder_tracked_end_to_end_segdet_split_float,
5040        float,
5041        split,
5042        masks
5043    );
5044    e2e_tracked_test!(
5045        test_decoder_tracked_end_to_end_segdet_split_proto,
5046        quantized,
5047        split,
5048        proto
5049    );
5050    e2e_tracked_test!(
5051        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5052        float,
5053        split,
5054        proto
5055    );
5056
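    // ─── End-to-end tracked TensorDyn test macro ────────────────────
    //
    // Same scenarios as `e2e_tracked_test`, but the inputs are wrapped in
    // `TensorDyn` and decoded through the public `decode_tracked` /
    // `decode_proto_tracked` API. The other macro instead calls the
    // ndarray-view based `decode_tracked_quantized*` / `decode_tracked_float*`
    // methods.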
5061
5062    macro_rules! e2e_tracked_tensor_test {
5063        ($name:ident, quantized, $layout:ident, $output:ident) => {
5064            #[test]
5065            fn $name() {
5066                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5067
5068                let is_split = matches!(stringify!($layout), "split");
5069                let is_proto = matches!(stringify!($output), "proto");
5070
5071                let score_threshold = 0.45;
5072                let iou_threshold = 0.45;
5073
5074                let mut boxes = Array2::zeros((10, 4));
5075                let mut scores = Array2::zeros((10, 1));
5076                let mut classes = Array2::zeros((10, 1));
5077                let mask = Array2::zeros((10, 32));
5078                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5079                let protos_f64 = protos_f64.insert_axis(Axis(0));
5080                let protos_quant = (1.0 / 255.0, 0.0);
5081                let protos_u8: Array4<u8> =
5082                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5083
5084                boxes
5085                    .slice_mut(s![0, ..])
5086                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5087                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5088                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5089
5090                let detect_quant = (2.0 / 255.0, 0.0);
5091
5092                let decoder = if is_split {
5093                    DecoderBuilder::default()
5094                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5095                        .with_score_threshold(score_threshold)
5096                        .with_iou_threshold(iou_threshold)
5097                        .build()
5098                        .unwrap()
5099                } else {
5100                    DecoderBuilder::default()
5101                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5102                        .with_score_threshold(score_threshold)
5103                        .with_iou_threshold(iou_threshold)
5104                        .build()
5105                        .unwrap()
5106                };
5107
5108                // Helper to wrap a u8 slice into a TensorDyn
5109                let make_u8_tensor =
5110                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5111                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5112                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5113                        t.into()
5114                    };
5115
5116                let expected = e2e_expected_boxes_quant();
5117                let mut tracker = ByteTrackBuilder::new()
5118                    .track_update(0.1)
5119                    .track_high_conf(0.7)
5120                    .build();
5121                let mut output_boxes = Vec::with_capacity(50);
5122                let mut output_tracks = Vec::with_capacity(50);
5123
5124                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5125
5126                if is_split {
5127                    let boxes = boxes.insert_axis(Axis(0));
5128                    let scores = scores.insert_axis(Axis(0));
5129                    let classes = classes.insert_axis(Axis(0));
5130                    let mask = mask.insert_axis(Axis(0));
5131
5132                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5133                    let mut scores_q: Array3<u8> =
5134                        quantize_ndarray(scores.view(), detect_quant.into());
5135                    let classes_q: Array3<u8> =
5136                        quantize_ndarray(classes.view(), detect_quant.into());
5137                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5138
5139                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5140                    let classes_td =
5141                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5142                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5143
5144                    if is_proto {
5145                        let scores_td =
5146                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5147                        decoder
5148                            .decode_proto_tracked(
5149                                &mut tracker,
5150                                0,
5151                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5152                                &mut output_boxes,
5153                                &mut output_tracks,
5154                            )
5155                            .unwrap();
5156
5157                        assert_eq!(output_boxes.len(), 1);
5158                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5159
5160                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5161                            *score = u8::MIN;
5162                        }
5163                        let scores_td =
5164                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5165                        let proto_result = decoder
5166                            .decode_proto_tracked(
5167                                &mut tracker,
5168                                100_000_000 / 3,
5169                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5170                                &mut output_boxes,
5171                                &mut output_tracks,
5172                            )
5173                            .unwrap();
5174                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5175                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5176                    } else {
5177                        let scores_td =
5178                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5179                        let mut output_masks = Vec::with_capacity(50);
5180                        decoder
5181                            .decode_tracked(
5182                                &mut tracker,
5183                                0,
5184                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5185                                &mut output_boxes,
5186                                &mut output_masks,
5187                                &mut output_tracks,
5188                            )
5189                            .unwrap();
5190
5191                        assert_eq!(output_boxes.len(), 1);
5192                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5193
5194                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5195                            *score = u8::MIN;
5196                        }
5197                        let scores_td =
5198                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5199                        decoder
5200                            .decode_tracked(
5201                                &mut tracker,
5202                                100_000_000 / 3,
5203                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5204                                &mut output_boxes,
5205                                &mut output_masks,
5206                                &mut output_tracks,
5207                            )
5208                            .unwrap();
5209                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5210                        assert!(output_masks.is_empty());
5211                    }
5212                } else {
5213                    // Combined layout
5214                    let detect = ndarray::concatenate![
5215                        Axis(1),
5216                        boxes.view(),
5217                        scores.view(),
5218                        classes.view(),
5219                        mask.view()
5220                    ];
5221                    let detect = detect.insert_axis(Axis(0));
5222                    assert_eq!(detect.shape(), &[1, 10, 38]);
5223                    // Ensure contiguous layout after concatenation for as_slice()
5224                    let detect =
5225                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5226                            .unwrap();
5227                    let mut detect_q: Array3<u8> =
5228                        quantize_ndarray(detect.view(), detect_quant.into());
5229
5230                    if is_proto {
5231                        let detect_td =
5232                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5233                        decoder
5234                            .decode_proto_tracked(
5235                                &mut tracker,
5236                                0,
5237                                &[&detect_td, &protos_td],
5238                                &mut output_boxes,
5239                                &mut output_tracks,
5240                            )
5241                            .unwrap();
5242
5243                        assert_eq!(output_boxes.len(), 1);
5244                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5245
5246                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5247                            *score = u8::MIN;
5248                        }
5249                        let detect_td =
5250                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5251                        let proto_result = decoder
5252                            .decode_proto_tracked(
5253                                &mut tracker,
5254                                100_000_000 / 3,
5255                                &[&detect_td, &protos_td],
5256                                &mut output_boxes,
5257                                &mut output_tracks,
5258                            )
5259                            .unwrap();
5260                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5261                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5262                    } else {
5263                        let detect_td =
5264                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5265                        let mut output_masks = Vec::with_capacity(50);
5266                        decoder
5267                            .decode_tracked(
5268                                &mut tracker,
5269                                0,
5270                                &[&detect_td, &protos_td],
5271                                &mut output_boxes,
5272                                &mut output_masks,
5273                                &mut output_tracks,
5274                            )
5275                            .unwrap();
5276
5277                        assert_eq!(output_boxes.len(), 1);
5278                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5279
5280                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5281                            *score = u8::MIN;
5282                        }
5283                        let detect_td =
5284                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5285                        decoder
5286                            .decode_tracked(
5287                                &mut tracker,
5288                                100_000_000 / 3,
5289                                &[&detect_td, &protos_td],
5290                                &mut output_boxes,
5291                                &mut output_masks,
5292                                &mut output_tracks,
5293                            )
5294                            .unwrap();
5295                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5296                        assert!(output_masks.is_empty());
5297                    }
5298                }
5299            }
5300        };
5301        ($name:ident, float, $layout:ident, $output:ident) => {
5302            #[test]
5303            fn $name() {
5304                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5305
5306                let is_split = matches!(stringify!($layout), "split");
5307                let is_proto = matches!(stringify!($output), "proto");
5308
5309                let score_threshold = 0.45;
5310                let iou_threshold = 0.45;
5311
5312                let mut boxes = Array2::zeros((10, 4));
5313                let mut scores = Array2::zeros((10, 1));
5314                let mut classes = Array2::zeros((10, 1));
5315                let mask: Array2<f64> = Array2::zeros((10, 32));
5316                let protos = Array3::<f64>::zeros((160, 160, 32));
5317                let protos = protos.insert_axis(Axis(0));
5318
5319                boxes
5320                    .slice_mut(s![0, ..])
5321                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5322                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5323                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5324
5325                let decoder = if is_split {
5326                    DecoderBuilder::default()
5327                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5328                        .with_score_threshold(score_threshold)
5329                        .with_iou_threshold(iou_threshold)
5330                        .build()
5331                        .unwrap()
5332                } else {
5333                    DecoderBuilder::default()
5334                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5335                        .with_score_threshold(score_threshold)
5336                        .with_iou_threshold(iou_threshold)
5337                        .build()
5338                        .unwrap()
5339                };
5340
5341                // Helper to wrap an f64 slice into a TensorDyn
5342                let make_f64_tensor =
5343                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5344                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
5345                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5346                        t.into()
5347                    };
5348
5349                let expected = e2e_expected_boxes_float();
5350                let mut tracker = ByteTrackBuilder::new()
5351                    .track_update(0.1)
5352                    .track_high_conf(0.7)
5353                    .build();
5354                let mut output_boxes = Vec::with_capacity(50);
5355                let mut output_tracks = Vec::with_capacity(50);
5356
5357                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5358
5359                if is_split {
5360                    let boxes = boxes.insert_axis(Axis(0));
5361                    let mut scores = scores.insert_axis(Axis(0));
5362                    let classes = classes.insert_axis(Axis(0));
5363                    let mask = mask.insert_axis(Axis(0));
5364
5365                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5366                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5367                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5368
5369                    if is_proto {
5370                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5371                        decoder
5372                            .decode_proto_tracked(
5373                                &mut tracker,
5374                                0,
5375                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5376                                &mut output_boxes,
5377                                &mut output_tracks,
5378                            )
5379                            .unwrap();
5380
5381                        assert_eq!(output_boxes.len(), 1);
5382                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5383
5384                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5385                            *score = 0.0;
5386                        }
5387                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5388                        let proto_result = decoder
5389                            .decode_proto_tracked(
5390                                &mut tracker,
5391                                100_000_000 / 3,
5392                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5393                                &mut output_boxes,
5394                                &mut output_tracks,
5395                            )
5396                            .unwrap();
5397                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5398                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5399                    } else {
5400                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5401                        let mut output_masks = Vec::with_capacity(50);
5402                        decoder
5403                            .decode_tracked(
5404                                &mut tracker,
5405                                0,
5406                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5407                                &mut output_boxes,
5408                                &mut output_masks,
5409                                &mut output_tracks,
5410                            )
5411                            .unwrap();
5412
5413                        assert_eq!(output_boxes.len(), 1);
5414                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5415
5416                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5417                            *score = 0.0;
5418                        }
5419                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5420                        decoder
5421                            .decode_tracked(
5422                                &mut tracker,
5423                                100_000_000 / 3,
5424                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5425                                &mut output_boxes,
5426                                &mut output_masks,
5427                                &mut output_tracks,
5428                            )
5429                            .unwrap();
5430                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5431                        assert!(output_masks.is_empty());
5432                    }
5433                } else {
5434                    // Combined layout
5435                    let detect = ndarray::concatenate![
5436                        Axis(1),
5437                        boxes.view(),
5438                        scores.view(),
5439                        classes.view(),
5440                        mask.view()
5441                    ];
5442                    let detect = detect.insert_axis(Axis(0));
5443                    assert_eq!(detect.shape(), &[1, 10, 38]);
5444                    // Ensure contiguous layout after concatenation for as_slice()
5445                    let mut detect =
5446                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5447                            .unwrap();
5448
5449                    if is_proto {
5450                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5451                        decoder
5452                            .decode_proto_tracked(
5453                                &mut tracker,
5454                                0,
5455                                &[&detect_td, &protos_td],
5456                                &mut output_boxes,
5457                                &mut output_tracks,
5458                            )
5459                            .unwrap();
5460
5461                        assert_eq!(output_boxes.len(), 1);
5462                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5463
5464                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5465                            *score = 0.0;
5466                        }
5467                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5468                        let proto_result = decoder
5469                            .decode_proto_tracked(
5470                                &mut tracker,
5471                                100_000_000 / 3,
5472                                &[&detect_td, &protos_td],
5473                                &mut output_boxes,
5474                                &mut output_tracks,
5475                            )
5476                            .unwrap();
5477                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5478                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.is_empty()));
5479                    } else {
5480                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5481                        let mut output_masks = Vec::with_capacity(50);
5482                        decoder
5483                            .decode_tracked(
5484                                &mut tracker,
5485                                0,
5486                                &[&detect_td, &protos_td],
5487                                &mut output_boxes,
5488                                &mut output_masks,
5489                                &mut output_tracks,
5490                            )
5491                            .unwrap();
5492
5493                        assert_eq!(output_boxes.len(), 1);
5494                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5495
5496                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5497                            *score = 0.0;
5498                        }
5499                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5500                        decoder
5501                            .decode_tracked(
5502                                &mut tracker,
5503                                100_000_000 / 3,
5504                                &[&detect_td, &protos_td],
5505                                &mut output_boxes,
5506                                &mut output_masks,
5507                                &mut output_tracks,
5508                            )
5509                            .unwrap();
5510                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5511                        assert!(output_masks.is_empty());
5512                    }
5513                }
5514            }
5515        };
5516    }
5517
5518    e2e_tracked_tensor_test!(
5519        test_decoder_tracked_tensor_end_to_end_segdet,
5520        quantized,
5521        combined,
5522        masks
5523    );
5524    e2e_tracked_tensor_test!(
5525        test_decoder_tracked_tensor_end_to_end_segdet_float,
5526        float,
5527        combined,
5528        masks
5529    );
5530    e2e_tracked_tensor_test!(
5531        test_decoder_tracked_tensor_end_to_end_segdet_proto,
5532        quantized,
5533        combined,
5534        proto
5535    );
5536    e2e_tracked_tensor_test!(
5537        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
5538        float,
5539        combined,
5540        proto
5541    );
5542    e2e_tracked_tensor_test!(
5543        test_decoder_tracked_tensor_end_to_end_segdet_split,
5544        quantized,
5545        split,
5546        masks
5547    );
5548    e2e_tracked_tensor_test!(
5549        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
5550        float,
5551        split,
5552        masks
5553    );
5554    e2e_tracked_tensor_test!(
5555        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
5556        quantized,
5557        split,
5558        proto
5559    );
5560    e2e_tracked_tensor_test!(
5561        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
5562        float,
5563        split,
5564        proto
5565    );
5566
5567    #[test]
5568    fn test_decoder_tracked_linear_motion() {
5569        use crate::configs::{DecoderType, Nms};
5570        use crate::DecoderBuilder;
5571
5572        let score_threshold = 0.25;
5573        let iou_threshold = 0.1;
5574        let out = include_bytes!(concat!(
5575            env!("CARGO_MANIFEST_DIR"),
5576            "/../../testdata/yolov8s_80_classes.bin"
5577        ));
5578        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5579        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5580        let quant = (0.0040811873, -123).into();
5581
5582        let decoder = DecoderBuilder::default()
5583            .with_config_yolo_det(
5584                crate::configs::Detection {
5585                    decoder: DecoderType::Ultralytics,
5586                    shape: vec![1, 84, 8400],
5587                    anchors: None,
5588                    quantization: Some(quant),
5589                    dshape: vec![
5590                        (crate::configs::DimName::Batch, 1),
5591                        (crate::configs::DimName::NumFeatures, 84),
5592                        (crate::configs::DimName::NumBoxes, 8400),
5593                    ],
5594                    normalized: Some(true),
5595                },
5596                None,
5597            )
5598            .with_score_threshold(score_threshold)
5599            .with_iou_threshold(iou_threshold)
5600            .with_nms(Some(Nms::ClassAgnostic))
5601            .build()
5602            .unwrap();
5603
5604        let mut expected_boxes = [
5605            DetectBox {
5606                bbox: BoundingBox {
5607                    xmin: 0.5285137,
5608                    ymin: 0.05305544,
5609                    xmax: 0.87541467,
5610                    ymax: 0.9998909,
5611                },
5612                score: 0.5591227,
5613                label: 0,
5614            },
5615            DetectBox {
5616                bbox: BoundingBox {
5617                    xmin: 0.130598,
5618                    ymin: 0.43260583,
5619                    xmax: 0.35098213,
5620                    ymax: 0.9958097,
5621                },
5622                score: 0.33057618,
5623                label: 75,
5624            },
5625        ];
5626
5627        let mut tracker = ByteTrackBuilder::new()
5628            .track_update(0.1)
5629            .track_high_conf(0.3)
5630            .build();
5631
5632        let mut output_boxes = Vec::with_capacity(50);
5633        let mut output_masks = Vec::with_capacity(50);
5634        let mut output_tracks = Vec::with_capacity(50);
5635
5636        decoder
5637            .decode_tracked_quantized(
5638                &mut tracker,
5639                0,
5640                &[out.view().into()],
5641                &mut output_boxes,
5642                &mut output_masks,
5643                &mut output_tracks,
5644            )
5645            .unwrap();
5646
5647        assert_eq!(output_boxes.len(), 2);
5648        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5649        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5650
5651        for i in 1..=100 {
5652            let mut out = out.clone();
5653            // introduce linear movement into the XY coordinates
5654            let mut x_values = out.slice_mut(s![0, 0, ..]);
5655            for x in x_values.iter_mut() {
5656                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5657            }
5658
5659            decoder
5660                .decode_tracked_quantized(
5661                    &mut tracker,
5662                    100_000_000 * i / 3, // simulate 33.333ms between frames
5663                    &[out.view().into()],
5664                    &mut output_boxes,
5665                    &mut output_masks,
5666                    &mut output_tracks,
5667                )
5668                .unwrap();
5669
5670            assert_eq!(output_boxes.len(), 2);
5671        }
5672        let tracks = tracker.get_active_tracks();
5673        let predicted_boxes: Vec<_> = tracks
5674            .iter()
5675            .map(|track| {
5676                let mut l = track.last_box;
5677                l.bbox = track.info.tracked_location.into();
5678                l
5679            })
5680            .collect();
5681        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
5682        expected_boxes[0].bbox.xmax += 0.1;
5683        expected_boxes[1].bbox.xmin += 0.1;
5684        expected_boxes[1].bbox.xmax += 0.1;
5685
5686        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5687        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5688
5689        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5690        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5691        for score in scores_values.iter_mut() {
5692            *score = i8::MIN; // set all scores to minimum to simulate no detections
5693        }
5694        decoder
5695            .decode_tracked_quantized(
5696                &mut tracker,
5697                100_000_000 * 101 / 3,
5698                &[out.view().into()],
5699                &mut output_boxes,
5700                &mut output_masks,
5701                &mut output_tracks,
5702            )
5703            .unwrap();
5704        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
5705        expected_boxes[0].bbox.xmax += 0.001;
5706        expected_boxes[1].bbox.xmin += 0.001;
5707        expected_boxes[1].bbox.xmax += 0.001;
5708
5709        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5710        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5711    }
5712
5713    #[test]
5714    fn test_decoder_tracked_end_to_end_float() {
5715        let score_threshold = 0.45;
5716        let iou_threshold = 0.45;
5717
5718        let mut boxes = Array2::zeros((10, 4));
5719        let mut scores = Array2::zeros((10, 1));
5720        let mut classes = Array2::zeros((10, 1));
5721
5722        boxes
5723            .slice_mut(s![0, ..,])
5724            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5725        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5726        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5727
5728        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5729        let mut detect = detect.insert_axis(Axis(0));
5730        assert_eq!(detect.shape(), &[1, 10, 6]);
5731        let config = "
5732decoder_version: yolo26
5733outputs:
5734 - type: detection
5735   decoder: ultralytics
5736   quantization: [0.00784313725490196, 0]
5737   shape: [1, 10, 6]
5738   dshape:
5739    - [batch, 1]
5740    - [num_boxes, 10]
5741    - [num_features, 6]
5742   normalized: true
5743";
5744
5745        let decoder = DecoderBuilder::default()
5746            .with_config_yaml_str(config.to_string())
5747            .with_score_threshold(score_threshold)
5748            .with_iou_threshold(iou_threshold)
5749            .build()
5750            .unwrap();
5751
5752        let expected_boxes = [DetectBox {
5753            bbox: BoundingBox {
5754                xmin: 0.1234,
5755                ymin: 0.1234,
5756                xmax: 0.2345,
5757                ymax: 0.2345,
5758            },
5759            score: 0.9876,
5760            label: 2,
5761        }];
5762
5763        let mut tracker = ByteTrackBuilder::new()
5764            .track_update(0.1)
5765            .track_high_conf(0.7)
5766            .build();
5767
5768        let mut output_boxes = Vec::with_capacity(50);
5769        let mut output_masks = Vec::with_capacity(50);
5770        let mut output_tracks = Vec::with_capacity(50);
5771
5772        decoder
5773            .decode_tracked_float(
5774                &mut tracker,
5775                0,
5776                &[detect.view().into_dyn()],
5777                &mut output_boxes,
5778                &mut output_masks,
5779                &mut output_tracks,
5780            )
5781            .unwrap();
5782
5783        assert_eq!(output_boxes.len(), 1);
5784        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5785
5786        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5787
5788        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5789            *score = 0.0; // set all scores to minimum to simulate no detections
5790        }
5791
5792        decoder
5793            .decode_tracked_float(
5794                &mut tracker,
5795                100_000_000 / 3,
5796                &[detect.view().into_dyn()],
5797                &mut output_boxes,
5798                &mut output_masks,
5799                &mut output_tracks,
5800            )
5801            .unwrap();
5802        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5803    }
5804}