Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod per_scale;
77pub mod schema;
78pub mod yolo;
79
80mod decoder;
81pub use decoder::*;
82
83pub use configs::{DecoderVersion, Nms};
84pub use error::{DecoderError, DecoderResult};
85pub use per_scale::DecodeDtype;
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Affine quantization parameters of a tensor.
///
/// A quantized value `q` represents the real value `scale * (q - zero_point)`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been removed.
    pub scale: f32,
    /// Quantized value that maps to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from an explicit scale and zero point.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }

    /// Returns the no-op quantization: `scale = 1.0`, `zero_point = 0`.
    ///
    /// The per-scale pipeline uses this as a sentinel for float-to-float
    /// passthrough cells, where no affine transform is needed.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::identity();
    /// assert_eq!(quant.scale, 1.0);
    /// assert_eq!(quant.zero_point, 0);
    /// ```
    pub fn identity() -> Self {
        Self::new(1.0, 0)
    }
}
263
264impl From<QuantTuple> for Quantization {
265    /// Creates a new Quantization struct from a QuantTuple
266    /// # Examples
267    /// ```
268    /// # use edgefirst_decoder::Quantization;
269    /// # use edgefirst_decoder::configs::QuantTuple;
270    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
271    /// let quant = Quantization::from(quant_tuple);
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from(quant_tuple: QuantTuple) -> Quantization {
276        Quantization {
277            scale: quant_tuple.0,
278            zero_point: quant_tuple.1,
279        }
280    }
281}
282
283impl<S, Z> From<(S, Z)> for Quantization
284where
285    S: AsPrimitive<f32>,
286    Z: AsPrimitive<i32>,
287{
288    /// Creates a new Quantization struct from a tuple
289    /// # Examples
290    /// ```
291    /// # use edgefirst_decoder::Quantization;
292    /// let quant = Quantization::from((0.1_f64, -128_i64));
293    /// assert_eq!(quant.scale, 0.1);
294    /// assert_eq!(quant.zero_point, -128);
295    /// ```
296    fn from((scale, zp): (S, Z)) -> Quantization {
297        Self {
298            scale: scale.as_(),
299            zero_point: zp.as_(),
300        }
301    }
302}
303
304impl Default for Quantization {
305    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
306    /// # Examples
307    /// ```rust
308    /// # use edgefirst_decoder::Quantization;
309    /// let quant = Quantization::default();
310    /// assert_eq!(quant.scale, 1.0);
311    /// assert_eq!(quant.zero_point, 0);
312    /// ```
313    fn default() -> Self {
314        Self {
315            scale: 1.0,
316            zero_point: 0,
317        }
318    }
319}
320
/// A decoded detection: an `f32` bounding box, a confidence score and a class
/// label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Bounding box of the detection, stored as XYXY corners (see
    /// [`BoundingBox`]).
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
330
/// An axis-aligned bounding box stored as `f32` corners in XYXY order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left edge (minimum x) of the box, as a normalized coordinate.
    pub xmin: f32,
    /// Top edge (minimum y) of the box, as a normalized coordinate.
    pub ymin: f32,
    /// Right edge (maximum x) of the box, as a normalized coordinate.
    pub xmax: f32,
    /// Bottom edge (maximum y) of the box, as a normalized coordinate.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its four corner coordinates, exactly as given (no
    /// reordering is performed; see [`BoundingBox::to_canonical`]).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`, swapping coordinates along an axis when needed.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (xmin, xmax) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (ymin, ymax) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(xmin, ymin, xmax, ymax)
    }
}
376
377impl From<BoundingBox> for [f32; 4] {
378    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
379    /// xmax, ymax order
380    /// # Examples
381    /// ```
382    /// # use edgefirst_decoder::BoundingBox;
383    /// let bbox = BoundingBox {
384    ///     xmin: 0.1,
385    ///     ymin: 0.2,
386    ///     xmax: 0.3,
387    ///     ymax: 0.4,
388    /// };
389    /// let arr: [f32; 4] = bbox.into();
390    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
391    /// ```
392    fn from(b: BoundingBox) -> Self {
393        [b.xmin, b.ymin, b.xmax, b.ymax]
394    }
395}
396
397impl From<[f32; 4]> for BoundingBox {
398    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
399    // BoundingBox
400    fn from(arr: [f32; 4]) -> Self {
401        BoundingBox {
402            xmin: arr[0],
403            ymin: arr[1],
404            xmax: arr[2],
405            ymax: arr[3],
406        }
407    }
408}
409
410impl DetectBox {
411    /// Returns true if one detect box is equal to another detect box, within
412    /// the given `eps`
413    ///
414    /// # Examples
415    /// ```
416    /// # use edgefirst_decoder::DetectBox;
417    /// let box1 = DetectBox {
418    ///     bbox: edgefirst_decoder::BoundingBox {
419    ///         xmin: 0.1,
420    ///         ymin: 0.2,
421    ///         xmax: 0.3,
422    ///         ymax: 0.4,
423    ///     },
424    ///     score: 0.5,
425    ///     label: 1,
426    /// };
427    /// let box2 = DetectBox {
428    ///     bbox: edgefirst_decoder::BoundingBox {
429    ///         xmin: 0.101,
430    ///         ymin: 0.199,
431    ///         xmax: 0.301,
432    ///         ymax: 0.399,
433    ///     },
434    ///     score: 0.510,
435    ///     label: 1,
436    /// };
437    /// assert!(box1.equal_within_delta(&box2, 0.011));
438    /// ```
439    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
440        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
441        self.label == rhs.label
442            && eq_delta(self.score, rhs.score)
443            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
444            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
445            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
446            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
447    }
448}
449
/// A segmentation result paired with the normalized bounding box that
/// describes the spatial extent of [`segmentation`](Self::segmentation).
///
/// The `xmin`/`ymin`/`xmax`/`ymax` fields describe the **mask region** —
/// the proto-grid-aligned crop the [`segmentation`](Self::segmentation)
/// tensor was sliced from. They are quantized to the proto grid step
/// (`1/proto_height` and `1/proto_width`, typically `1/160`), so the region's
/// origin floors and its extent ceils relative to the companion
/// [`DetectBox`]'s `bbox`. The detection bbox itself stays un-snapped (see
/// EDGEAI-1304); use `Segmentation` bounds for rendering masks, and
/// `DetectBox.bbox` for IoU evaluation, drawing detection rectangles, etc.
/// The mask region always encloses the detection bbox.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left-most normalized coordinate of the mask region (proto-grid
    /// floor of the companion `DetectBox.bbox.xmin`).
    pub xmin: f32,
    /// Top-most normalized coordinate of the mask region (proto-grid
    /// floor of the companion `DetectBox.bbox.ymin`).
    pub ymin: f32,
    /// Right-most normalized coordinate of the mask region (proto-grid
    /// ceil of the companion `DetectBox.bbox.xmax`).
    pub xmax: f32,
    /// Bottom-most normalized coordinate of the mask region (proto-grid
    /// ceil of the companion `DetectBox.bbox.ymax`).
    pub ymax: f32,
    /// 3D segmentation array of shape `(H, W, C)`.
    ///
    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
    /// continuous sigmoid confidence values quantized to u8 (0 = background,
    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
    /// 0.5) or use smooth interpolation for anti-aliased edges.
    ///
    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
    /// class scores where the object class is the argmax index.
    pub segmentation: Array3<u8>,
}
487
/// Memory layout of the prototype tensor within [`ProtoData`].
///
/// Models may output protos in either channel-last (NHWC) or channel-first
/// (NCHW) layout. The mask materialisation kernels dispatch on this field to
/// avoid a costly per-frame transpose. Despite the conventional names, the
/// stored proto tensor has no batch dimension (see the shapes below).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoLayout {
    /// Channel-last: tensor shape is `[H, W, K]`, contiguous along K
    /// (K = number of prototype channels).
    /// This is the traditional layout produced by the `extract_proto_data`
    /// path after transposing from NCHW model outputs.
    Nhwc,
    /// Channel-first: tensor shape is `[K, H, W]`, each channel plane is
    /// contiguous. Skipping the NCHW→NHWC transpose saves ~3 ms per frame
    /// on Cortex-A53/A55 targets (819 KB for 32×160×160 one-byte protos).
    Nchw,
}
504
/// Raw prototype data for fused decode+render pipelines.
///
/// Holds post-NMS intermediate state before mask materialization, allowing the
/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
/// shader) without materializing intermediate `Array3<u8>` masks.
///
/// Both fields are carried as [`TensorDyn`](edgefirst_tensor::TensorDyn) so
/// downstream consumers (Rust, C API, Python) get zero-copy typed access
/// through the HAL's shared tensor infrastructure. Dtype policy:
///
/// | Source model | protos dtype | mask_coefficients dtype | protos.quantization |
/// |---|---|---|---|
/// | int8 quantized | [`TensorDyn::I8`](edgefirst_tensor::TensorDyn::I8) | [`TensorDyn::I8`](edgefirst_tensor::TensorDyn::I8) (raw + quantization) | `Some(q)` |
/// | f32 | [`TensorDyn::F32`](edgefirst_tensor::TensorDyn::F32) | [`TensorDyn::F32`](edgefirst_tensor::TensorDyn::F32) | `None` |
/// | f16 (TensorRT fp16) | [`TensorDyn::F16`](edgefirst_tensor::TensorDyn::F16) | [`TensorDyn::F16`](edgefirst_tensor::TensorDyn::F16) | `None` |
/// | f64 (narrowed) | [`TensorDyn::F32`](edgefirst_tensor::TensorDyn::F32) | [`TensorDyn::F32`](edgefirst_tensor::TensorDyn::F32) | `None` |
///
/// Quantization metadata lives on the proto tensor itself via
/// [`edgefirst_tensor::Tensor::quantization`] — float tensors cannot carry
/// quantization (compile-time gated on the `IntegerType` sealed trait).
///
/// `TensorDyn` is not `Clone`, so neither is `ProtoData`. Consumers that need
/// to share the proto buffer across threads should use `TensorDyn::clone_fd`
/// / `dmabuf_clone` to dup the backing fd.
#[derive(Debug)]
pub struct ProtoData {
    /// Per-detection mask coefficients, shape `[num_detections, num_protos]`.
    /// Dtype follows the table above.
    pub mask_coefficients: edgefirst_tensor::TensorDyn,
    /// Prototype tensor.
    ///
    /// - When `layout == ProtoLayout::Nhwc`: shape is `[proto_h, proto_w, num_protos]`.
    /// - When `layout == ProtoLayout::Nchw`: shape is `[num_protos, proto_h, proto_w]`.
    pub protos: edgefirst_tensor::TensorDyn,
    /// Physical memory layout of the `protos` tensor; consumers must read
    /// this before interpreting `protos`' shape.
    pub layout: ProtoLayout,
}
541
542/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
543///
544///  # Examples
545/// ```
546/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
547/// let quant = Quantization::new(0.1, -128);
548/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
549/// let detect_quant = DetectBoxQuantized {
550///     bbox,
551///     score: 100_i8,
552///     label: 1,
553/// };
554/// let detect = dequant_detect_box(&detect_quant, quant);
555/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
556/// assert_eq!(detect.label, 1);
557/// assert_eq!(detect.bbox, bbox);
558/// ```
559pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
560    detect: &DetectBoxQuantized<SCORE>,
561    quant_scores: Quantization,
562) -> DetectBox {
563    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
564    DetectBox {
565        bbox: detect.bbox,
566        score: quant_scores.scale * detect.score.as_() + scaled_zp,
567        label: detect.label,
568    }
569}
570/// A detection box with a f32 bbox and quantized score
571#[derive(Debug, Clone, Copy, PartialEq)]
572pub struct DetectBoxQuantized<
573    // BOX: Signed + PrimInt + AsPrimitive<f32>,
574    SCORE: PrimInt + AsPrimitive<f32>,
575> {
576    // pub bbox: BoundingBoxQuantized<BOX>,
577    pub bbox: BoundingBox,
578    /// model-specific score for this detection, higher implies more
579    /// confidence.
580    pub score: SCORE,
581    /// label index for this detect
582    pub label: usize,
583}
584
585/// Dequantizes an ndarray from quantized values to f32 values using the given
586/// quantization parameters
587///
588/// # Examples
589/// ```
590/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
591/// let quant = Quantization::new(0.1, -128);
592/// let input: Vec<i8> = vec![0, 127, -128, 64];
593/// let input_array = ndarray::Array1::from(input);
594/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
595/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
596/// ```
597pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
598    input: ArrayView<T, D>,
599    quant: Quantization,
600) -> Array<F, D>
601where
602    i32: num_traits::AsPrimitive<F>,
603    f32: num_traits::AsPrimitive<F>,
604{
605    let zero_point = quant.zero_point.as_();
606    let scale = quant.scale.as_();
607    if zero_point != F::zero() {
608        let scaled_zero = -zero_point * scale;
609        input.mapv(|d| d.as_() * scale + scaled_zero)
610    } else {
611        input.mapv(|d| d.as_() * scale)
612    }
613}
614
615/// Dequantizes a slice from quantized values to float values using the given
616/// quantization parameters
617///
618/// # Examples
619/// ```
620/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
621/// let quant = Quantization::new(0.1, -128);
622/// let input: Vec<i8> = vec![0, 127, -128, 64];
623/// let mut output: Vec<f32> = vec![0.0; input.len()];
624/// dequantize_cpu(&input, quant, &mut output);
625/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
626/// ```
627pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
628    input: &[T],
629    quant: Quantization,
630    output: &mut [F],
631) where
632    f32: num_traits::AsPrimitive<F>,
633    i32: num_traits::AsPrimitive<F>,
634{
635    assert!(input.len() == output.len());
636    let zero_point = quant.zero_point.as_();
637    let scale = quant.scale.as_();
638    if zero_point != F::zero() {
639        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
640        input
641            .iter()
642            .zip(output)
643            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
644    } else {
645        input
646            .iter()
647            .zip(output)
648            .for_each(|(d, deq)| *deq = d.as_() * scale);
649    }
650}
651
652/// Dequantizes a slice from quantized values to float values using the given
653/// quantization parameters, using chunked processing. This is around 5% faster
654/// than `dequantize_cpu` for large slices.
655///
656/// # Examples
657/// ```
658/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
659/// let quant = Quantization::new(0.1, -128);
660/// let input: Vec<i8> = vec![0, 127, -128, 64];
661/// let mut output: Vec<f32> = vec![0.0; input.len()];
662/// dequantize_cpu_chunked(&input, quant, &mut output);
663/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
664/// ```
665pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
666    input: &[T],
667    quant: Quantization,
668    output: &mut [F],
669) where
670    f32: num_traits::AsPrimitive<F>,
671    i32: num_traits::AsPrimitive<F>,
672{
673    assert!(input.len() == output.len());
674    let zero_point = quant.zero_point.as_();
675    let scale = quant.scale.as_();
676
677    let input = input.as_chunks::<4>();
678    let output = output.as_chunks_mut::<4>();
679
680    if zero_point != F::zero() {
681        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
682
683        input
684            .0
685            .iter()
686            .zip(output.0)
687            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
688        input
689            .1
690            .iter()
691            .zip(output.1)
692            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
693    } else {
694        input
695            .0
696            .iter()
697            .zip(output.0)
698            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
699        input
700            .1
701            .iter()
702            .zip(output.1)
703            .for_each(|(d, deq)| *deq = d.as_() * scale);
704    }
705}
706
707/// Converts a segmentation tensor into a 2D mask
708/// If the last dimension of the segmentation tensor is 1, values equal or
709/// above 128 are considered objects. Otherwise the object is the argmax index
710///
711/// # Errors
712///
713/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
714/// invalid shape.
715///
716/// # Examples
717/// ```
718/// # use edgefirst_decoder::segmentation_to_mask;
719/// let segmentation =
720///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
721/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
722/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
723/// ```
724pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
725    if segmentation.shape()[2] == 0 {
726        return Err(DecoderError::InvalidShape(
727            "Segmentation tensor must have non-zero depth".to_string(),
728        ));
729    }
730    if segmentation.shape()[2] == 1 {
731        yolo_segmentation_to_mask(segmentation, 128)
732    } else {
733        Ok(modelpack_segmentation_to_mask(segmentation))
734    }
735}
736
737/// Returns the maximum value and its index from a 1D array
738fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
739    score
740        .iter()
741        .enumerate()
742        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
743            if max > *s {
744                (max, arg_max)
745            } else {
746                (*s, ind)
747            }
748        })
749}
750
/// NEON-accelerated argmax for `i8` slices on aarch64.
///
/// Returns the maximum value and its index. On ties the **last** index wins,
/// matching the semantics of [`arg_max`].
///
/// For slices of 16 or more elements, the maximum is reduced 16 lanes at a
/// time with NEON. The remaining tail is folded in with scalar code, and the
/// last occurrence is then located with a reverse scan. Shorter slices use a
/// plain scalar reduction.
///
/// This function is only compiled for `target_arch = "aarch64"`; other
/// targets must use a separate (non-NEON) code path.
///
/// # Panics
///
/// Panics if `scores` is empty.
#[cfg(target_arch = "aarch64")]
pub(crate) fn arg_max_i8(scores: &[i8]) -> (i8, usize) {
    use std::arch::aarch64::{vdupq_n_s8, vld1q_s8, vmaxq_s8, vmaxvq_s8};

    assert!(!scores.is_empty(), "arg_max_i8 requires a non-empty slice");

    let max = if scores.len() < 16 {
        // Too short for NEON to pay off.
        scores.iter().copied().fold(i8::MIN, i8::max)
    } else {
        let chunks = scores.chunks_exact(16);
        let tail = chunks.remainder();
        // SAFETY: every chunk yielded by `chunks_exact(16)` is exactly 16
        // contiguous, initialised `i8`s, which is precisely the 128-bit region
        // `vld1q_s8` reads. NEON is part of the aarch64 baseline, so these
        // intrinsics are always available on this target.
        let simd_max = unsafe {
            // i8::MIN is the identity for max, so the accumulator can start there.
            let mut vmax = vdupq_n_s8(i8::MIN);
            for chunk in chunks {
                vmax = vmaxq_s8(vmax, vld1q_s8(chunk.as_ptr()));
            }
            vmaxvq_s8(vmax)
        };
        tail.iter().copied().fold(simd_max, i8::max)
    };

    // Taking the last occurrence preserves the last-index-wins tie rule.
    // `max` is an element of `scores`, so the search always succeeds.
    let idx = scores.iter().rposition(|&s| s == max).unwrap_or(0);
    (max, idx)
}
805#[cfg(test)]
806#[cfg_attr(coverage_nightly, coverage(off))]
807mod decoder_tests {
808    #![allow(clippy::excessive_precision)]
809    use crate::{
810        configs::{DecoderType, DimName, Protos},
811        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
812        yolo::{
813            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
814            decode_yolo_segdet_quant,
815        },
816        *,
817    };
818    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
819    use ndarray::Dimension;
820    use ndarray::{array, s, Array2, Array3, Array4, Axis};
821    use ndarray_stats::DeviationExt;
822    use num_traits::{AsPrimitive, PrimInt};
823
824    fn compare_outputs(
825        boxes: (&[DetectBox], &[DetectBox]),
826        masks: (&[Segmentation], &[Segmentation]),
827    ) {
828        let (boxes0, boxes1) = boxes;
829        let (masks0, masks1) = masks;
830
831        assert_eq!(boxes0.len(), boxes1.len());
832        assert_eq!(masks0.len(), masks1.len());
833
834        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
835            assert!(
836                b_i8.equal_within_delta(b_f32, 1e-6),
837                "{b_i8:?} is not equal to {b_f32:?}"
838            );
839        }
840
841        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
842            assert_eq!(
843                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
844                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
845            );
846            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
847            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
848            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
849            let diff = &mask_i8 - &mask_f32;
850            for x in 0..diff.shape()[0] {
851                for y in 0..diff.shape()[1] {
852                    for z in 0..diff.shape()[2] {
853                        let val = diff[[x, y, z]];
854                        assert!(
855                            val.abs() <= 1,
856                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
857                            x,
858                            y,
859                            z,
860                            val
861                        );
862                    }
863                }
864            }
865            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
866            assert!(
867                mean_sq_err < 1e-2,
868                "Mean Square Error between masks was greater than 1%: {:.2}%",
869                mean_sq_err * 100.0
870            );
871        }
872    }
873
874    // ─── Shared test data loaders ────────────────────────
875
876    fn load_yolov8_boxes() -> Array3<i8> {
877        let raw = include_bytes!(concat!(
878            env!("CARGO_MANIFEST_DIR"),
879            "/../../testdata/yolov8_boxes_116x8400.bin"
880        ));
881        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
882        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
883    }
884
885    fn load_yolov8_protos() -> Array4<i8> {
886        let raw = include_bytes!(concat!(
887            env!("CARGO_MANIFEST_DIR"),
888            "/../../testdata/yolov8_protos_160x160x32.bin"
889        ));
890        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
891        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
892    }
893
894    fn load_yolov8s_det() -> Array3<i8> {
895        let raw = include_bytes!(concat!(
896            env!("CARGO_MANIFEST_DIR"),
897            "/../../testdata/yolov8s_80_classes.bin"
898        ));
899        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
900        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
901    }
902
    /// ModelPack detection with separate boxes and scores outputs.
    ///
    /// Decodes the same quantized fixtures three ways and asserts that all three
    /// yield the same boxes and no segmentation masks:
    /// - the low-level `decode_modelpack_det`;
    /// - `Decoder::decode_quantized`;
    /// - `Decoder::decode_float` on dequantized copies.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // (scale, zero_point) pairs. The target type is inferred from the
        // `quantization` fields of the configs passed to the builder below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Convert again, this time into the quantization type taken by
        // `decode_modelpack_det` and `dequantize_ndarray` below.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result from the low-level decoder. The slices drop the batch
        // axis, and for boxes also the size-1 padding axis.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        // Same inputs through the high-level quantized path.
        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Same inputs dequantized to f32 and decoded through the float path.
        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Detection-only model: the empty expected mask slices assert no masks
        // were produced.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
1008
    /// ModelPack split (anchor-grid) detection with two u8 outputs.
    ///
    /// Compares the low-level `decode_modelpack_split_quant` against
    /// `Decoder::decode_quantized` and against `Decoder::decode_float` run on
    /// dequantized copies of the same tensors.
    #[test]
    fn test_decoder_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // (scale, zero_point) for each output, plus three anchors per output.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let detect_config0 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 15, 18],
            anchors: Some(anchors0.clone()),
            quantization: Some(quant0),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        let detect_config1 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 17, 30, 18],
            anchors: Some(anchors1.clone()),
            quantization: Some(quant1),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        // Per-output configs for the low-level `decode_modelpack_split_quant` call.
        let config0 = (&detect_config0).try_into().unwrap();
        let config1 = (&detect_config1).try_into().unwrap();

        // NOTE(review): configs are listed as (1, 0) while tensors are passed as
        // (0, 1) below. Presumably this exercises order-independent matching of
        // outputs to configs — confirm against the builder.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Convert into the quantization type taken by `dequantize_ndarray`.
        let quant0 = quant0.into();
        let quant1 = quant1.into();

        // Reference result from the low-level decoder (batch axis dropped).
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        decode_modelpack_split_quant(
            &[
                detect0.slice(s![0, .., .., ..]),
                detect1.slice(s![0, .., .., ..]),
            ],
            &[config0, config1],
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));

        // High-level quantized path.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[detect0.view().into(), detect1.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // High-level float path on dequantized copies.
        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

        let detect0 = dequantize_ndarray(detect0.view(), quant0);
        let detect1 = dequantize_ndarray(detect1.view(), quant1);
        decoder
            .decode_float::<f32>(
                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // Detection-only: both paths must match the reference and emit no masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes1_f32),
            (&[], &output_masks1_f32),
        );
    }
1133
1134    #[test]
1135    fn test_decoder_parse_config_modelpack_split_u8() {
1136        let score_threshold = 0.45;
1137        let iou_threshold = 0.45;
1138        let detect0 = include_bytes!(concat!(
1139            env!("CARGO_MANIFEST_DIR"),
1140            "/../../testdata/modelpack_split_9x15x18.bin"
1141        ));
1142        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1143
1144        let detect1 = include_bytes!(concat!(
1145            env!("CARGO_MANIFEST_DIR"),
1146            "/../../testdata/modelpack_split_17x30x18.bin"
1147        ));
1148        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1149
1150        let decoder = DecoderBuilder::default()
1151            .with_config_yaml_str(
1152                include_str!(concat!(
1153                    env!("CARGO_MANIFEST_DIR"),
1154                    "/../../testdata/modelpack_split.yaml"
1155                ))
1156                .to_string(),
1157            )
1158            .with_score_threshold(score_threshold)
1159            .with_iou_threshold(iou_threshold)
1160            .build()
1161            .unwrap();
1162
1163        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1164        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1165        decoder
1166            .decode_quantized(
1167                &[
1168                    ArrayViewDQuantized::from(detect1.view()),
1169                    ArrayViewDQuantized::from(detect0.view()),
1170                ],
1171                &mut output_boxes,
1172                &mut output_masks,
1173            )
1174            .unwrap();
1175        assert!(output_boxes[0].equal_within_delta(
1176            &DetectBox {
1177                bbox: BoundingBox {
1178                    xmin: 0.43171933,
1179                    ymin: 0.68243736,
1180                    xmax: 0.5626645,
1181                    ymax: 0.808863,
1182                },
1183                score: 0.99240804,
1184                label: 0
1185            },
1186            1e-6
1187        ));
1188    }
1189
    /// ModelPack semantic segmentation (2 classes, 160x160).
    ///
    /// - Quantized path: the mask must equal the raw model output transposed to
    ///   (height, width, classes).
    /// - Float path (on dequantized input): only the argmax mask is compared,
    ///   because the float decoder rescales values.
    #[test]
    fn test_modelpack_seg() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
            .unwrap();

        // Expected mask: drop the batch axis, then reorder
        // (classes, height, width) -> (height, classes, width) -> (height, width, classes).
        let mut mask = out.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        // Segmentation-only model: no boxes expected.
        compare_outputs((&[], &output_boxes), (&mask, &output_masks));

        // Reuses the same output vectors for the float path.
        decoder
            .decode_float::<f32>(
                &[dequantize_ndarray(out.view(), quant.into())
                    .view()
                    .into_dyn()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // not expected for float decoder to have same values as quantized decoder, as
        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
        // the model output. Thus the float output is the same as the quantized output
        // but scaled differently. However, it is expected that the mask after argmax
        // will be the same.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
    /// Segmentation decoding must yield the same argmax mask for every
    /// supported integer input dtype (u8, i8, u16, i16, u32, i32).
    ///
    /// Each non-u8 input is an order-preserving remap of the u8 fixture:
    /// - i8 shifts by -128;
    /// - u16/u32 move the byte into the high bits;
    /// - i16/i32 additionally subtract the type's midpoint.
    ///
    /// The same quantization config is used for all of them, so the test relies
    /// on the argmax depending only on relative ordering per pixel.
    #[test]
    fn test_modelpack_seg_quant() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        // Wider intermediates (i16/i32/i64) keep the shift-and-offset from overflowing.
        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        // `output_boxes` is shared by every call; it is checked to be empty at the end.
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u8.view().into()],
                &mut output_boxes,
                &mut output_masks_u8,
            )
            .unwrap();

        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i8.view().into()],
                &mut output_boxes,
                &mut output_masks_i8,
            )
            .unwrap();

        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u16.view().into()],
                &mut output_boxes,
                &mut output_masks_u16,
            )
            .unwrap();

        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i16.view().into()],
                &mut output_boxes,
                &mut output_masks_i16,
            )
            .unwrap();

        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u32.view().into()],
                &mut output_boxes,
                &mut output_masks_u32,
            )
            .unwrap();

        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i32.view().into()],
                &mut output_boxes,
                &mut output_masks_i32,
            )
            .unwrap();

        // Segmentation-only model: no boxes; every dtype's argmax mask must match u8's.
        compare_outputs((&[], &output_boxes), (&[], &[]));
        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
        assert_eq!(mask_u8, mask_i8);
        assert_eq!(mask_u8, mask_u16);
        assert_eq!(mask_u8, mask_i16);
        assert_eq!(mask_u8, mask_u32);
        assert_eq!(mask_u8, mask_i32);
    }
1349
    /// ModelPack combined detection + segmentation (boxes, scores, seg outputs).
    ///
    /// - Quantized path: boxes must match the reference box, and the mask must
    ///   equal the transposed raw seg output.
    /// - Float path (on dequantized inputs): boxes must match again, and only
    ///   the argmax mask is compared.
    #[test]
    fn test_modelpack_segdet() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;

        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        // (scale, zero_point) for each output.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_iou_threshold(iou_threshold)
            .with_score_threshold(score_threshold)
            .build()
            .unwrap();
        // NOTE(review): inputs are passed as (scores, boxes, seg), not in the
        // builder's (boxes, scores, seg) order. Presumably this checks that
        // outputs are matched to configs independently of position — confirm.
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[scores.view().into(), boxes.view().into(), seg.view().into()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: drop batch axis, reorder (classes, height, width) to
        // (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0,
        }];
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        // Float path on dequantized copies; output vectors are reused.
        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    scores.view().into_dyn(),
                    boxes.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // not expected for float segmentation decoder to have same values as quantized
        // segmentation decoder, as float decoder ensures the data fills 0-255,
        // quantized decoder uses whatever the model output. Thus the float
        // output is the same as the quantized output but scaled differently.
        // However, it is expected that the mask after argmax will be the same.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1475
    /// ModelPack split (anchor-grid) detection combined with segmentation.
    ///
    /// Uses two u8 detection grids (9x15 and 17x30) plus a 2-class seg output,
    /// and checks both the quantized and the float path:
    /// - boxes must match the reference box in both paths;
    /// - the quantized mask must equal the transposed raw seg output;
    /// - the float path is compared only after argmax.
    #[test]
    fn test_modelpack_segdet_split() {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // (scale, zero_point) for each output.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // NOTE(review): the 17x30 config is listed first while tensors are passed
        // 9x15 first below. Presumably this exercises order-independent matching
        // of outputs to configs — confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[
                    detect0.view().into(),
                    detect1.view().into(),
                    seg.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: drop batch axis, reorder (classes, height, width) to
        // (height, width, classes).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];
        // Debug aid. `compare_outputs` checks the box count before printing any
        // box, so this shows the boxes when only the length differs.
        println!("Output Boxes: {:?}", output_boxes);
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        // Float path on dequantized copies; output vectors are reused.
        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        decoder
            .decode_float::<f32>(
                &[
                    detect0.view().into_dyn(),
                    detect1.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // not expected for float segmentation decoder to have same values as quantized
        // segmentation decoder, as float decoder ensures the data fills 0-255,
        // quantized decoder uses whatever the model output. Thus the float
        // output is the same as the quantized output but scaled differently.
        // However, it is expected that the mask after argmax will be the same.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1623
1624    #[test]
1625    fn test_dequant_chunked() {
1626        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1627        out.push(123); // make sure to test non multiple of 16 length
1628
1629        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1630        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1631        let quant = Quantization::new(0.0040811873, -123);
1632        dequantize_cpu(&out, quant, &mut out_dequant);
1633
1634        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1635        assert_eq!(out_dequant, out_dequant_simd);
1636
1637        let quant = Quantization::new(0.0040811873, 0);
1638        dequantize_cpu(&out, quant, &mut out_dequant);
1639
1640        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1641        assert_eq!(out_dequant, out_dequant_simd);
1642    }
1643
1644    #[test]
1645    fn test_dequant_ground_truth() {
1646        // Formula: output = (input - zero_point) * scale
1647        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1648
1649        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1650        let quant = Quantization::new(0.1, -128);
1651        let input: Vec<i8> = vec![0, 127, -128, 64];
1652        let mut output = vec![0.0f32; 4];
1653        let mut output_chunked = vec![0.0f32; 4];
1654        dequantize_cpu(&input, quant, &mut output);
1655        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1656        // (0 - (-128)) * 0.1 = 12.8
1657        // (127 - (-128)) * 0.1 = 25.5
1658        // (-128 - (-128)) * 0.1 = 0.0
1659        // (64 - (-128)) * 0.1 = 19.2
1660        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1661        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1662            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1663        }
1664        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1665            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1666        }
1667
1668        // Case 2: scale=1.0, zero_point=0 (identity-like)
1669        let quant = Quantization::new(1.0, 0);
1670        dequantize_cpu(&input, quant, &mut output);
1671        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1672        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1673        assert_eq!(output, expected);
1674        assert_eq!(output_chunked, expected);
1675
1676        // Case 3: scale=0.5, zero_point=0
1677        let quant = Quantization::new(0.5, 0);
1678        dequantize_cpu(&input, quant, &mut output);
1679        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1680        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1681        assert_eq!(output, expected);
1682        assert_eq!(output_chunked, expected);
1683
1684        // Case 4: i8 min/max boundaries with typical quantization params
1685        let quant = Quantization::new(0.021287762, 31);
1686        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1687        let mut output = vec![0.0f32; 6];
1688        let mut output_chunked = vec![0.0f32; 6];
1689        dequantize_cpu(&input, quant, &mut output);
1690        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1691        for i in 0..6 {
1692            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1693            assert!(
1694                (output[i] - expected).abs() < 1e-5,
1695                "cpu[{i}]: {} != {expected}",
1696                output[i]
1697            );
1698            assert!(
1699                (output_chunked[i] - expected).abs() < 1e-5,
1700                "chunked[{i}]: {} != {expected}",
1701                output_chunked[i]
1702            );
1703        }
1704    }
1705
1706    #[test]
1707    fn test_decoder_yolo_det() {
1708        let score_threshold = 0.25;
1709        let iou_threshold = 0.7;
1710        let out = load_yolov8s_det();
1711        let quant = (0.0040811873, -123).into();
1712
1713        let decoder = DecoderBuilder::default()
1714            .with_config_yolo_det(
1715                configs::Detection {
1716                    decoder: DecoderType::Ultralytics,
1717                    shape: vec![1, 84, 8400],
1718                    anchors: None,
1719                    quantization: Some(quant),
1720                    dshape: vec![
1721                        (DimName::Batch, 1),
1722                        (DimName::NumFeatures, 84),
1723                        (DimName::NumBoxes, 8400),
1724                    ],
1725                    normalized: Some(true),
1726                },
1727                Some(DecoderVersion::Yolo11),
1728            )
1729            .with_score_threshold(score_threshold)
1730            .with_iou_threshold(iou_threshold)
1731            .build()
1732            .unwrap();
1733
1734        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1735        decode_yolo_det(
1736            (out.slice(s![0, .., ..]), quant.into()),
1737            score_threshold,
1738            iou_threshold,
1739            Some(configs::Nms::ClassAgnostic),
1740            &mut output_boxes,
1741        );
1742        assert!(output_boxes[0].equal_within_delta(
1743            &DetectBox {
1744                bbox: BoundingBox {
1745                    xmin: 0.5285137,
1746                    ymin: 0.05305544,
1747                    xmax: 0.87541467,
1748                    ymax: 0.9998909,
1749                },
1750                score: 0.5591227,
1751                label: 0
1752            },
1753            1e-6
1754        ));
1755
1756        assert!(output_boxes[1].equal_within_delta(
1757            &DetectBox {
1758                bbox: BoundingBox {
1759                    xmin: 0.130598,
1760                    ymin: 0.43260583,
1761                    xmax: 0.35098213,
1762                    ymax: 0.9958097,
1763                },
1764                score: 0.33057618,
1765                label: 75
1766            },
1767            1e-6
1768        ));
1769
1770        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1771        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1772        decoder
1773            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1774            .unwrap();
1775
1776        let out = dequantize_ndarray(out.view(), quant.into());
1777        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1778        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1779        decoder
1780            .decode_float::<f32>(
1781                &[out.view().into_dyn()],
1782                &mut output_boxes_f32,
1783                &mut output_masks_f32,
1784            )
1785            .unwrap();
1786
1787        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1788        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1789    }
1790
1791    #[test]
1792    fn test_decoder_masks() {
1793        let score_threshold = 0.45;
1794        let iou_threshold = 0.45;
1795        let boxes = load_yolov8_boxes();
1796        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1797
1798        let protos = load_yolov8_protos();
1799        let quant_protos = Quantization::new(0.02491161972284317, -117);
1800        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1801        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1802        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1803        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1804        decode_yolo_segdet_float(
1805            seg.slice(s![0, .., ..]),
1806            protos.slice(s![0, .., .., ..]),
1807            score_threshold,
1808            iou_threshold,
1809            Some(configs::Nms::ClassAgnostic),
1810            &mut output_boxes,
1811            &mut output_masks,
1812        )
1813        .unwrap();
1814        assert_eq!(output_boxes.len(), 2);
1815        assert_eq!(output_boxes.len(), output_masks.len());
1816
1817        for (b, m) in output_boxes.iter().zip(&output_masks) {
1818            // Mask region is the proto-grid-aligned crop (floor for min,
1819            // ceil for max), so it encloses the post-NMS bbox (EDGEAI-1304).
1820            assert!(b.bbox.xmin >= m.xmin);
1821            assert!(b.bbox.ymin >= m.ymin);
1822            assert!(b.bbox.xmax <= m.xmax);
1823            assert!(b.bbox.ymax <= m.ymax);
1824        }
1825        assert!(output_boxes[0].equal_within_delta(
1826            &DetectBox {
1827                bbox: BoundingBox {
1828                    xmin: 0.08515105,
1829                    ymin: 0.7131401,
1830                    xmax: 0.29802868,
1831                    ymax: 0.8195788,
1832                },
1833                score: 0.91537374,
1834                label: 23
1835            },
1836            1.0 / 160.0, // wider range because mask will expand the box
1837        ));
1838
1839        assert!(output_boxes[1].equal_within_delta(
1840            &DetectBox {
1841                bbox: BoundingBox {
1842                    xmin: 0.59605736,
1843                    ymin: 0.25545314,
1844                    xmax: 0.93666154,
1845                    ymax: 0.72378385,
1846                },
1847                score: 0.91537374,
1848                label: 23
1849            },
1850            1.0 / 160.0, // wider range because mask will expand the box
1851        ));
1852
1853        let full_mask = include_bytes!(concat!(
1854            env!("CARGO_MANIFEST_DIR"),
1855            "/../../testdata/yolov8_mask_results.bin"
1856        ));
1857        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1858
1859        let cropped_mask = full_mask.slice(ndarray::s![
1860            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1861            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1862        ]);
1863
1864        assert_eq!(
1865            cropped_mask,
1866            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1867        );
1868    }
1869
1870    /// Regression test: config-driven path with physically-NCHW protos.
1871    /// Simulates YOLOv8-seg ONNX outputs where the producer emits protos
1872    /// as `(1, 32, 160, 160)` in CHW memory order — the caller declares
1873    /// shape and dshape matching that physical order and HAL permutes to
1874    /// canonical HWC via `swap_axes_if_needed`.
1875    ///
1876    /// This is the counterpart to the NHWC-producer case (TFLite /
1877    /// Ara-2) where shape+dshape are `(1, 160, 160, 32)` +
1878    /// `[batch, height, width, num_protos]` and no reordering is needed.
1879    #[test]
1880    fn test_decoder_masks_nchw_protos() {
1881        let score_threshold = 0.45;
1882        let iou_threshold = 0.45;
1883
1884        // Load test data — boxes as [116, 8400]
1885        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1886        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1887
1888        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1889        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1890        let quant_protos = Quantization::new(0.02491161972284317, -117);
1891        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1892
1893        // ---- Reference: direct call with HWC protos (known working) ----
1894        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1895        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1896        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1897        decode_yolo_segdet_float(
1898            seg.view(),
1899            protos_f32_hwc.view(),
1900            score_threshold,
1901            iou_threshold,
1902            Some(configs::Nms::ClassAgnostic),
1903            &mut ref_boxes,
1904            &mut ref_masks,
1905        )
1906        .unwrap();
1907        assert_eq!(ref_boxes.len(), 2);
1908
1909        // ---- Config-driven path: NCHW protos declared in physical order ----
1910        // Permute the HWC test data to CHW memory order — this is what an
1911        // ONNX-style producer would emit into the tensor buffer.
1912        // `to_owned` materialises a C-contiguous Array3<f32> with CHW
1913        // strides, matching a producer that writes channels-outer.
1914        let protos_f32_chw_view = protos_f32_hwc.view().permuted_axes([2, 0, 1]); // [32, 160, 160]
1915        let protos_f32_chw = protos_f32_chw_view.to_owned();
1916        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1917
1918        // Build boxes as [1, 116, 8400] f32
1919        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1920
1921        // Declare shape and dshape in physical memory order (outermost
1922        // first). `swap_axes_if_needed` uses dshape role lookup to
1923        // permute the stride tuple into canonical HWC.
1924        let decoder = DecoderBuilder::default()
1925            .with_config_yolo_segdet(
1926                configs::Detection {
1927                    decoder: configs::DecoderType::Ultralytics,
1928                    quantization: None,
1929                    shape: vec![1, 116, 8400],
1930                    dshape: vec![
1931                        (configs::DimName::Batch, 1),
1932                        (configs::DimName::NumFeatures, 116),
1933                        (configs::DimName::NumBoxes, 8400),
1934                    ],
1935                    normalized: Some(true),
1936                    anchors: None,
1937                },
1938                configs::Protos {
1939                    decoder: configs::DecoderType::Ultralytics,
1940                    quantization: None,
1941                    shape: vec![1, 32, 160, 160],
1942                    dshape: vec![
1943                        (configs::DimName::Batch, 1),
1944                        (configs::DimName::NumProtos, 32),
1945                        (configs::DimName::Height, 160),
1946                        (configs::DimName::Width, 160),
1947                    ],
1948                },
1949                None, // decoder version
1950            )
1951            .with_score_threshold(score_threshold)
1952            .with_iou_threshold(iou_threshold)
1953            .build()
1954            .unwrap();
1955
1956        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1957        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1958        decoder
1959            .decode_float(
1960                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1961                &mut cfg_boxes,
1962                &mut cfg_masks,
1963            )
1964            .unwrap();
1965
1966        // Must produce the same number of detections
1967        assert_eq!(
1968            cfg_boxes.len(),
1969            ref_boxes.len(),
1970            "config path produced {} boxes, reference produced {}",
1971            cfg_boxes.len(),
1972            ref_boxes.len()
1973        );
1974
1975        // Boxes must match
1976        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1977            assert!(
1978                cb.equal_within_delta(rb, 0.01),
1979                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1980            );
1981        }
1982
1983        // Masks must match pixel-for-pixel
1984        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1985            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1986            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1987            assert_eq!(
1988                cm_arr, rm_arr,
1989                "mask {i} pixel mismatch between config-driven and reference paths"
1990            );
1991        }
1992    }
1993
1994    #[test]
1995    fn test_decoder_masks_i8() {
1996        let score_threshold = 0.45;
1997        let iou_threshold = 0.45;
1998        let boxes = load_yolov8_boxes();
1999        let quant_boxes = (0.021287761628627777, 31).into();
2000
2001        let protos = load_yolov8_protos();
2002        let quant_protos = (0.02491161972284317, -117).into();
2003        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2004        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2005
2006        let decoder = DecoderBuilder::default()
2007            .with_config_yolo_segdet(
2008                configs::Detection {
2009                    decoder: configs::DecoderType::Ultralytics,
2010                    quantization: Some(quant_boxes),
2011                    shape: vec![1, 116, 8400],
2012                    anchors: None,
2013                    dshape: vec![
2014                        (DimName::Batch, 1),
2015                        (DimName::NumFeatures, 116),
2016                        (DimName::NumBoxes, 8400),
2017                    ],
2018                    normalized: Some(true),
2019                },
2020                Protos {
2021                    decoder: configs::DecoderType::Ultralytics,
2022                    quantization: Some(quant_protos),
2023                    shape: vec![1, 160, 160, 32],
2024                    dshape: vec![
2025                        (DimName::Batch, 1),
2026                        (DimName::Height, 160),
2027                        (DimName::Width, 160),
2028                        (DimName::NumProtos, 32),
2029                    ],
2030                },
2031                Some(DecoderVersion::Yolo11),
2032            )
2033            .with_score_threshold(score_threshold)
2034            .with_iou_threshold(iou_threshold)
2035            .build()
2036            .unwrap();
2037
2038        let quant_boxes = quant_boxes.into();
2039        let quant_protos = quant_protos.into();
2040
2041        decode_yolo_segdet_quant(
2042            (boxes.slice(s![0, .., ..]), quant_boxes),
2043            (protos.slice(s![0, .., .., ..]), quant_protos),
2044            score_threshold,
2045            iou_threshold,
2046            Some(configs::Nms::ClassAgnostic),
2047            &mut output_boxes,
2048            &mut output_masks,
2049        )
2050        .unwrap();
2051
2052        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2053        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2054
2055        decoder
2056            .decode_quantized(
2057                &[boxes.view().into(), protos.view().into()],
2058                &mut output_boxes1,
2059                &mut output_masks1,
2060            )
2061            .unwrap();
2062
2063        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2064        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2065
2066        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2067        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2068        decode_yolo_segdet_float(
2069            seg.slice(s![0, .., ..]),
2070            protos.slice(s![0, .., .., ..]),
2071            score_threshold,
2072            iou_threshold,
2073            Some(configs::Nms::ClassAgnostic),
2074            &mut output_boxes_f32,
2075            &mut output_masks_f32,
2076        )
2077        .unwrap();
2078
2079        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
2080        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
2081
2082        decoder
2083            .decode_float(
2084                &[seg.view().into_dyn(), protos.view().into_dyn()],
2085                &mut output_boxes1_f32,
2086                &mut output_masks1_f32,
2087            )
2088            .unwrap();
2089
2090        compare_outputs(
2091            (&output_boxes, &output_boxes1),
2092            (&output_masks, &output_masks1),
2093        );
2094
2095        compare_outputs(
2096            (&output_boxes, &output_boxes_f32),
2097            (&output_masks, &output_masks_f32),
2098        );
2099
2100        compare_outputs(
2101            (&output_boxes_f32, &output_boxes1_f32),
2102            (&output_masks_f32, &output_masks1_f32),
2103        );
2104    }
2105
2106    #[test]
2107    fn test_decoder_yolo_split() {
2108        let score_threshold = 0.45;
2109        let iou_threshold = 0.45;
2110        let boxes = load_yolov8_boxes();
2111        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2112        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2113
2114        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2115
2116        let decoder = DecoderBuilder::default()
2117            .with_config_yolo_split_det(
2118                configs::Boxes {
2119                    decoder: configs::DecoderType::Ultralytics,
2120                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2121                    shape: vec![1, 4, 8400],
2122                    dshape: vec![
2123                        (DimName::Batch, 1),
2124                        (DimName::BoxCoords, 4),
2125                        (DimName::NumBoxes, 8400),
2126                    ],
2127                    normalized: Some(true),
2128                },
2129                configs::Scores {
2130                    decoder: configs::DecoderType::Ultralytics,
2131                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2132                    shape: vec![1, 80, 8400],
2133                    dshape: vec![
2134                        (DimName::Batch, 1),
2135                        (DimName::NumClasses, 80),
2136                        (DimName::NumBoxes, 8400),
2137                    ],
2138                },
2139            )
2140            .with_score_threshold(score_threshold)
2141            .with_iou_threshold(iou_threshold)
2142            .build()
2143            .unwrap();
2144
2145        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2146        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2147
2148        decoder
2149            .decode_quantized(
2150                &[
2151                    boxes.slice(s![.., ..4, ..]).into(),
2152                    boxes.slice(s![.., 4..84, ..]).into(),
2153                ],
2154                &mut output_boxes,
2155                &mut output_masks,
2156            )
2157            .unwrap();
2158
2159        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2160        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2161        decode_yolo_det_float(
2162            seg.slice(s![0, ..84, ..]),
2163            score_threshold,
2164            iou_threshold,
2165            Some(configs::Nms::ClassAgnostic),
2166            &mut output_boxes_f32,
2167        );
2168
2169        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2170        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2171
2172        decoder
2173            .decode_float(
2174                &[
2175                    seg.slice(s![.., ..4, ..]).into_dyn(),
2176                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2177                ],
2178                &mut output_boxes1,
2179                &mut output_masks1,
2180            )
2181            .unwrap();
2182        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2183        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2184    }
2185
2186    #[test]
2187    fn test_decoder_masks_config_mixed() {
2188        let score_threshold = 0.45;
2189        let iou_threshold = 0.45;
2190        let boxes_raw = load_yolov8_boxes();
2191        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2192        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2193
2194        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2195
2196        let protos = load_yolov8_protos();
2197        let quant_protos = (0.02491161972284317, -117);
2198
2199        let decoder = build_yolo_split_segdet_decoder(
2200            score_threshold,
2201            iou_threshold,
2202            quant_boxes,
2203            quant_protos,
2204        );
2205        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2206        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2207
2208        decoder
2209            .decode_quantized(
2210                &[
2211                    boxes.slice(s![.., ..4, ..]).into(),
2212                    boxes.slice(s![.., 4..84, ..]).into(),
2213                    boxes.slice(s![.., 84.., ..]).into(),
2214                    protos.view().into(),
2215                ],
2216                &mut output_boxes,
2217                &mut output_masks,
2218            )
2219            .unwrap();
2220
2221        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2222        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2223        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2224        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2225        decode_yolo_segdet_float(
2226            seg.slice(s![0, .., ..]),
2227            protos.slice(s![0, .., .., ..]),
2228            score_threshold,
2229            iou_threshold,
2230            Some(configs::Nms::ClassAgnostic),
2231            &mut output_boxes_f32,
2232            &mut output_masks_f32,
2233        )
2234        .unwrap();
2235
2236        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2237        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2238
2239        decoder
2240            .decode_float(
2241                &[
2242                    seg.slice(s![.., ..4, ..]).into_dyn(),
2243                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2244                    seg.slice(s![.., 84.., ..]).into_dyn(),
2245                    protos.view().into_dyn(),
2246                ],
2247                &mut output_boxes1,
2248                &mut output_masks1,
2249            )
2250            .unwrap();
2251        compare_outputs(
2252            (&output_boxes, &output_boxes_f32),
2253            (&output_masks, &output_masks_f32),
2254        );
2255        compare_outputs(
2256            (&output_boxes_f32, &output_boxes1),
2257            (&output_masks_f32, &output_masks1),
2258        );
2259    }
2260
2261    fn build_yolo_split_segdet_decoder(
2262        score_threshold: f32,
2263        iou_threshold: f32,
2264        quant_boxes: (f32, i32),
2265        quant_protos: (f32, i32),
2266    ) -> crate::Decoder {
2267        DecoderBuilder::default()
2268            .with_config_yolo_split_segdet(
2269                configs::Boxes {
2270                    decoder: configs::DecoderType::Ultralytics,
2271                    quantization: Some(quant_boxes.into()),
2272                    shape: vec![1, 4, 8400],
2273                    dshape: vec![
2274                        (DimName::Batch, 1),
2275                        (DimName::BoxCoords, 4),
2276                        (DimName::NumBoxes, 8400),
2277                    ],
2278                    normalized: Some(true),
2279                },
2280                configs::Scores {
2281                    decoder: configs::DecoderType::Ultralytics,
2282                    quantization: Some(quant_boxes.into()),
2283                    shape: vec![1, 80, 8400],
2284                    dshape: vec![
2285                        (DimName::Batch, 1),
2286                        (DimName::NumClasses, 80),
2287                        (DimName::NumBoxes, 8400),
2288                    ],
2289                },
2290                configs::MaskCoefficients {
2291                    decoder: configs::DecoderType::Ultralytics,
2292                    quantization: Some(quant_boxes.into()),
2293                    shape: vec![1, 32, 8400],
2294                    dshape: vec![
2295                        (DimName::Batch, 1),
2296                        (DimName::NumProtos, 32),
2297                        (DimName::NumBoxes, 8400),
2298                    ],
2299                },
2300                configs::Protos {
2301                    decoder: configs::DecoderType::Ultralytics,
2302                    quantization: Some(quant_protos.into()),
2303                    shape: vec![1, 160, 160, 32],
2304                    dshape: vec![
2305                        (DimName::Batch, 1),
2306                        (DimName::Height, 160),
2307                        (DimName::Width, 160),
2308                        (DimName::NumProtos, 32),
2309                    ],
2310                },
2311            )
2312            .with_score_threshold(score_threshold)
2313            .with_iou_threshold(iou_threshold)
2314            .build()
2315            .unwrap()
2316    }
2317
2318    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2319        let config_yaml = include_str!(concat!(
2320            env!("CARGO_MANIFEST_DIR"),
2321            "/../../testdata/yolov8_seg.yaml"
2322        ));
2323        DecoderBuilder::default()
2324            .with_config_yaml_str(config_yaml.to_string())
2325            .with_score_threshold(score_threshold)
2326            .with_iou_threshold(iou_threshold)
2327            .build()
2328            .unwrap()
2329    }
2330    #[test]
2331    fn test_decoder_masks_config_i32() {
2332        let score_threshold = 0.45;
2333        let iou_threshold = 0.45;
2334        let boxes_raw = load_yolov8_boxes();
2335        let scale = 1 << 23;
2336        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2337        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2338
2339        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2340
2341        let protos_raw = load_yolov8_protos();
2342        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2343        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2344        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2345
2346        let decoder = build_yolo_split_segdet_decoder(
2347            score_threshold,
2348            iou_threshold,
2349            quant_boxes,
2350            quant_protos,
2351        );
2352
2353        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2354        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2355
2356        decoder
2357            .decode_quantized(
2358                &[
2359                    boxes.slice(s![.., ..4, ..]).into(),
2360                    boxes.slice(s![.., 4..84, ..]).into(),
2361                    boxes.slice(s![.., 84.., ..]).into(),
2362                    protos.view().into(),
2363                ],
2364                &mut output_boxes,
2365                &mut output_masks,
2366            )
2367            .unwrap();
2368
2369        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2370        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2371        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2372        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2373        decode_yolo_segdet_float(
2374            seg.slice(s![0, .., ..]),
2375            protos.slice(s![0, .., .., ..]),
2376            score_threshold,
2377            iou_threshold,
2378            Some(configs::Nms::ClassAgnostic),
2379            &mut output_boxes_f32,
2380            &mut output_masks_f32,
2381        )
2382        .unwrap();
2383
2384        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2385        assert_eq!(output_masks.len(), output_masks_f32.len());
2386
2387        compare_outputs(
2388            (&output_boxes, &output_boxes_f32),
2389            (&output_masks, &output_masks_f32),
2390        );
2391    }
2392
2393    /// test running multiple decoders concurrently
2394    #[test]
2395    fn test_context_switch() {
2396        let yolo_det = || {
2397            let score_threshold = 0.25;
2398            let iou_threshold = 0.7;
2399            let out = load_yolov8s_det();
2400            let quant = (0.0040811873, -123).into();
2401
2402            let decoder = DecoderBuilder::default()
2403                .with_config_yolo_det(
2404                    configs::Detection {
2405                        decoder: DecoderType::Ultralytics,
2406                        shape: vec![1, 84, 8400],
2407                        anchors: None,
2408                        quantization: Some(quant),
2409                        dshape: vec![
2410                            (DimName::Batch, 1),
2411                            (DimName::NumFeatures, 84),
2412                            (DimName::NumBoxes, 8400),
2413                        ],
2414                        normalized: None,
2415                    },
2416                    None,
2417                )
2418                .with_score_threshold(score_threshold)
2419                .with_iou_threshold(iou_threshold)
2420                .build()
2421                .unwrap();
2422
2423            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2424            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2425
2426            for _ in 0..100 {
2427                decoder
2428                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2429                    .unwrap();
2430
2431                assert!(output_boxes[0].equal_within_delta(
2432                    &DetectBox {
2433                        bbox: BoundingBox {
2434                            xmin: 0.5285137,
2435                            ymin: 0.05305544,
2436                            xmax: 0.87541467,
2437                            ymax: 0.9998909,
2438                        },
2439                        score: 0.5591227,
2440                        label: 0
2441                    },
2442                    1e-6
2443                ));
2444
2445                assert!(output_boxes[1].equal_within_delta(
2446                    &DetectBox {
2447                        bbox: BoundingBox {
2448                            xmin: 0.130598,
2449                            ymin: 0.43260583,
2450                            xmax: 0.35098213,
2451                            ymax: 0.9958097,
2452                        },
2453                        score: 0.33057618,
2454                        label: 75
2455                    },
2456                    1e-6
2457                ));
2458                assert!(output_masks.is_empty());
2459            }
2460        };
2461
2462        let modelpack_det_split = || {
2463            let score_threshold = 0.8;
2464            let iou_threshold = 0.5;
2465
2466            let seg = include_bytes!(concat!(
2467                env!("CARGO_MANIFEST_DIR"),
2468                "/../../testdata/modelpack_seg_2x160x160.bin"
2469            ));
2470            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2471
2472            let detect0 = include_bytes!(concat!(
2473                env!("CARGO_MANIFEST_DIR"),
2474                "/../../testdata/modelpack_split_9x15x18.bin"
2475            ));
2476            let detect0 =
2477                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2478
2479            let detect1 = include_bytes!(concat!(
2480                env!("CARGO_MANIFEST_DIR"),
2481                "/../../testdata/modelpack_split_17x30x18.bin"
2482            ));
2483            let detect1 =
2484                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2485
2486            let mut mask = seg.slice(s![0, .., .., ..]);
2487            mask.swap_axes(0, 1);
2488            mask.swap_axes(1, 2);
2489            let mask = [Segmentation {
2490                xmin: 0.0,
2491                ymin: 0.0,
2492                xmax: 1.0,
2493                ymax: 1.0,
2494                segmentation: mask.into_owned(),
2495            }];
2496            let correct_boxes = [DetectBox {
2497                bbox: BoundingBox {
2498                    xmin: 0.43171933,
2499                    ymin: 0.68243736,
2500                    xmax: 0.5626645,
2501                    ymax: 0.808863,
2502                },
2503                score: 0.99240804,
2504                label: 0,
2505            }];
2506
2507            let quant0 = (0.08547406643629074, 174).into();
2508            let quant1 = (0.09929127991199493, 183).into();
2509            let quant_seg = (1.0 / 255.0, 0).into();
2510
2511            let anchors0 = vec![
2512                [0.36666667461395264, 0.31481480598449707],
2513                [0.38749998807907104, 0.4740740656852722],
2514                [0.5333333611488342, 0.644444465637207],
2515            ];
2516            let anchors1 = vec![
2517                [0.13750000298023224, 0.2074074000120163],
2518                [0.2541666626930237, 0.21481481194496155],
2519                [0.23125000298023224, 0.35185185074806213],
2520            ];
2521
2522            let decoder = DecoderBuilder::default()
2523                .with_config_modelpack_segdet_split(
2524                    vec![
2525                        configs::Detection {
2526                            decoder: DecoderType::ModelPack,
2527                            shape: vec![1, 17, 30, 18],
2528                            anchors: Some(anchors1),
2529                            quantization: Some(quant1),
2530                            dshape: vec![
2531                                (DimName::Batch, 1),
2532                                (DimName::Height, 17),
2533                                (DimName::Width, 30),
2534                                (DimName::NumAnchorsXFeatures, 18),
2535                            ],
2536                            normalized: None,
2537                        },
2538                        configs::Detection {
2539                            decoder: DecoderType::ModelPack,
2540                            shape: vec![1, 9, 15, 18],
2541                            anchors: Some(anchors0),
2542                            quantization: Some(quant0),
2543                            dshape: vec![
2544                                (DimName::Batch, 1),
2545                                (DimName::Height, 9),
2546                                (DimName::Width, 15),
2547                                (DimName::NumAnchorsXFeatures, 18),
2548                            ],
2549                            normalized: None,
2550                        },
2551                    ],
2552                    configs::Segmentation {
2553                        decoder: DecoderType::ModelPack,
2554                        quantization: Some(quant_seg),
2555                        shape: vec![1, 2, 160, 160],
2556                        dshape: vec![
2557                            (DimName::Batch, 1),
2558                            (DimName::NumClasses, 2),
2559                            (DimName::Height, 160),
2560                            (DimName::Width, 160),
2561                        ],
2562                    },
2563                )
2564                .with_score_threshold(score_threshold)
2565                .with_iou_threshold(iou_threshold)
2566                .build()
2567                .unwrap();
2568            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2569            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2570
2571            for _ in 0..100 {
2572                decoder
2573                    .decode_quantized(
2574                        &[
2575                            detect0.view().into(),
2576                            detect1.view().into(),
2577                            seg.view().into(),
2578                        ],
2579                        &mut output_boxes,
2580                        &mut output_masks,
2581                    )
2582                    .unwrap();
2583
2584                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2585            }
2586        };
2587
2588        let handles = vec![
2589            std::thread::spawn(yolo_det),
2590            std::thread::spawn(modelpack_det_split),
2591            std::thread::spawn(yolo_det),
2592            std::thread::spawn(modelpack_det_split),
2593            std::thread::spawn(yolo_det),
2594            std::thread::spawn(modelpack_det_split),
2595            std::thread::spawn(yolo_det),
2596            std::thread::spawn(modelpack_det_split),
2597        ];
2598        for handle in handles {
2599            handle.join().unwrap();
2600        }
2601    }
2602
2603    #[test]
2604    fn test_ndarray_to_xyxy_float() {
2605        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2606        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2607        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2608
2609        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2610        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2611        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2612    }
2613
2614    #[test]
2615    fn test_class_aware_nms_float() {
2616        use crate::float::nms_class_aware_float;
2617
2618        // Create two overlapping boxes with different classes
2619        let boxes = vec![
2620            DetectBox {
2621                bbox: BoundingBox {
2622                    xmin: 0.0,
2623                    ymin: 0.0,
2624                    xmax: 0.5,
2625                    ymax: 0.5,
2626                },
2627                score: 0.9,
2628                label: 0, // class 0
2629            },
2630            DetectBox {
2631                bbox: BoundingBox {
2632                    xmin: 0.1,
2633                    ymin: 0.1,
2634                    xmax: 0.6,
2635                    ymax: 0.6,
2636                },
2637                score: 0.8,
2638                label: 1, // class 1 - different class
2639            },
2640        ];
2641
2642        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2643        // threshold 0.3)
2644        let result = nms_class_aware_float(0.3, boxes.clone());
2645        assert_eq!(
2646            result.len(),
2647            2,
2648            "Class-aware NMS should keep both boxes with different classes"
2649        );
2650
2651        // Now test with same class - should suppress one
2652        let same_class_boxes = vec![
2653            DetectBox {
2654                bbox: BoundingBox {
2655                    xmin: 0.0,
2656                    ymin: 0.0,
2657                    xmax: 0.5,
2658                    ymax: 0.5,
2659                },
2660                score: 0.9,
2661                label: 0,
2662            },
2663            DetectBox {
2664                bbox: BoundingBox {
2665                    xmin: 0.1,
2666                    ymin: 0.1,
2667                    xmax: 0.6,
2668                    ymax: 0.6,
2669                },
2670                score: 0.8,
2671                label: 0, // same class
2672            },
2673        ];
2674
2675        let result = nms_class_aware_float(0.3, same_class_boxes);
2676        assert_eq!(
2677            result.len(),
2678            1,
2679            "Class-aware NMS should suppress overlapping box with same class"
2680        );
2681        assert_eq!(result[0].label, 0);
2682        assert!((result[0].score - 0.9).abs() < 1e-6);
2683    }
2684
2685    #[test]
2686    fn test_class_agnostic_vs_aware_nms() {
2687        use crate::float::{nms_class_aware_float, nms_float};
2688
2689        // Two overlapping boxes with different classes
2690        let boxes = vec![
2691            DetectBox {
2692                bbox: BoundingBox {
2693                    xmin: 0.0,
2694                    ymin: 0.0,
2695                    xmax: 0.5,
2696                    ymax: 0.5,
2697                },
2698                score: 0.9,
2699                label: 0,
2700            },
2701            DetectBox {
2702                bbox: BoundingBox {
2703                    xmin: 0.1,
2704                    ymin: 0.1,
2705                    xmax: 0.6,
2706                    ymax: 0.6,
2707                },
2708                score: 0.8,
2709                label: 1,
2710            },
2711        ];
2712
2713        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2714        let agnostic_result = nms_float(0.3, boxes.clone());
2715        assert_eq!(
2716            agnostic_result.len(),
2717            1,
2718            "Class-agnostic NMS should suppress overlapping boxes"
2719        );
2720
2721        // Class-aware should keep both (different classes)
2722        let aware_result = nms_class_aware_float(0.3, boxes);
2723        assert_eq!(
2724            aware_result.len(),
2725            2,
2726            "Class-aware NMS should keep boxes with different classes"
2727        );
2728    }
2729
2730    #[test]
2731    fn test_class_aware_nms_int() {
2732        use crate::byte::nms_class_aware_int;
2733
2734        // Create two overlapping boxes with different classes
2735        let boxes = vec![
2736            DetectBoxQuantized {
2737                bbox: BoundingBox {
2738                    xmin: 0.0,
2739                    ymin: 0.0,
2740                    xmax: 0.5,
2741                    ymax: 0.5,
2742                },
2743                score: 200_u8,
2744                label: 0,
2745            },
2746            DetectBoxQuantized {
2747                bbox: BoundingBox {
2748                    xmin: 0.1,
2749                    ymin: 0.1,
2750                    xmax: 0.6,
2751                    ymax: 0.6,
2752                },
2753                score: 180_u8,
2754                label: 1, // different class
2755            },
2756        ];
2757
2758        // Should keep both (different classes)
2759        let result = nms_class_aware_int(0.5, boxes);
2760        assert_eq!(
2761            result.len(),
2762            2,
2763            "Class-aware NMS (int) should keep boxes with different classes"
2764        );
2765    }
2766
2767    #[test]
2768    fn test_nms_enum_default() {
2769        // Test that Nms enum has the correct default
2770        let default_nms: configs::Nms = Default::default();
2771        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2772    }
2773
2774    #[test]
2775    fn test_decoder_nms_mode() {
2776        // Test that decoder properly stores NMS mode
2777        let decoder = DecoderBuilder::default()
2778            .with_config_yolo_det(
2779                configs::Detection {
2780                    anchors: None,
2781                    decoder: DecoderType::Ultralytics,
2782                    quantization: None,
2783                    shape: vec![1, 84, 8400],
2784                    dshape: Vec::new(),
2785                    normalized: Some(true),
2786                },
2787                None,
2788            )
2789            .with_nms(Some(configs::Nms::ClassAware))
2790            .build()
2791            .unwrap();
2792
2793        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2794    }
2795
2796    #[test]
2797    fn test_decoder_nms_bypass() {
2798        // Test that decoder can be configured with nms=None (bypass)
2799        let decoder = DecoderBuilder::default()
2800            .with_config_yolo_det(
2801                configs::Detection {
2802                    anchors: None,
2803                    decoder: DecoderType::Ultralytics,
2804                    quantization: None,
2805                    shape: vec![1, 84, 8400],
2806                    dshape: Vec::new(),
2807                    normalized: Some(true),
2808                },
2809                None,
2810            )
2811            .with_nms(None)
2812            .build()
2813            .unwrap();
2814
2815        assert_eq!(decoder.nms, None);
2816    }
2817
2818    #[test]
2819    fn test_decoder_normalized_boxes_true() {
2820        // Test that normalized_boxes returns Some(true) when explicitly set
2821        let decoder = DecoderBuilder::default()
2822            .with_config_yolo_det(
2823                configs::Detection {
2824                    anchors: None,
2825                    decoder: DecoderType::Ultralytics,
2826                    quantization: None,
2827                    shape: vec![1, 84, 8400],
2828                    dshape: Vec::new(),
2829                    normalized: Some(true),
2830                },
2831                None,
2832            )
2833            .build()
2834            .unwrap();
2835
2836        assert_eq!(decoder.normalized_boxes(), Some(true));
2837    }
2838
2839    #[test]
2840    fn test_decoder_normalized_boxes_false() {
2841        // Test that normalized_boxes returns Some(false) when config specifies
2842        // unnormalized
2843        let decoder = DecoderBuilder::default()
2844            .with_config_yolo_det(
2845                configs::Detection {
2846                    anchors: None,
2847                    decoder: DecoderType::Ultralytics,
2848                    quantization: None,
2849                    shape: vec![1, 84, 8400],
2850                    dshape: Vec::new(),
2851                    normalized: Some(false),
2852                },
2853                None,
2854            )
2855            .build()
2856            .unwrap();
2857
2858        assert_eq!(decoder.normalized_boxes(), Some(false));
2859    }
2860
2861    #[test]
2862    fn test_decoder_normalized_boxes_unknown() {
2863        // Test that normalized_boxes returns None when not specified in config
2864        let decoder = DecoderBuilder::default()
2865            .with_config_yolo_det(
2866                configs::Detection {
2867                    anchors: None,
2868                    decoder: DecoderType::Ultralytics,
2869                    quantization: None,
2870                    shape: vec![1, 84, 8400],
2871                    dshape: Vec::new(),
2872                    normalized: None,
2873                },
2874                Some(DecoderVersion::Yolo11),
2875            )
2876            .build()
2877            .unwrap();
2878
2879        assert_eq!(decoder.normalized_boxes(), None);
2880    }
2881
2882    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2883        input: ArrayView<F, D>,
2884        quant: Quantization,
2885    ) -> Array<T, D>
2886    where
2887        i32: num_traits::AsPrimitive<F>,
2888        f32: num_traits::AsPrimitive<F>,
2889    {
2890        let zero_point = quant.zero_point.as_();
2891        let div_scale = F::one() / quant.scale.as_();
2892        if zero_point != F::zero() {
2893            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2894        } else {
2895            input.mapv(|d| (d * div_scale).round().as_())
2896        }
2897    }
2898
2899    fn real_data_expected_boxes() -> [DetectBox; 2] {
2900        [
2901            DetectBox {
2902                bbox: BoundingBox {
2903                    xmin: 0.08515105,
2904                    ymin: 0.7131401,
2905                    xmax: 0.29802868,
2906                    ymax: 0.8195788,
2907                },
2908                score: 0.91537374,
2909                label: 23,
2910            },
2911            DetectBox {
2912                bbox: BoundingBox {
2913                    xmin: 0.59605736,
2914                    ymin: 0.25545314,
2915                    xmax: 0.93666154,
2916                    ymax: 0.72378385,
2917                },
2918                score: 0.91537374,
2919                label: 23,
2920            },
2921        ]
2922    }
2923
2924    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2925        [DetectBox {
2926            bbox: BoundingBox {
2927                xmin: 0.12549022,
2928                ymin: 0.12549022,
2929                xmax: 0.23529413,
2930                ymax: 0.23529413,
2931            },
2932            score: 0.98823535,
2933            label: 2,
2934        }]
2935    }
2936
2937    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2938        [DetectBox {
2939            bbox: BoundingBox {
2940                xmin: 0.1234,
2941                ymin: 0.1234,
2942                xmax: 0.2345,
2943                ymax: 0.2345,
2944            },
2945            score: 0.9876,
2946            label: 2,
2947        }]
2948    }
2949
2950    macro_rules! real_data_proto_test {
2951        ($name:ident, quantized, $layout:ident) => {
2952            #[test]
2953            fn $name() {
2954                let is_split = matches!(stringify!($layout), "split");
2955
2956                let score_threshold = 0.45;
2957                let iou_threshold = 0.45;
2958                let quant_boxes = (0.021287762_f32, 31_i32);
2959                let quant_protos = (0.02491162_f32, -117_i32);
2960
2961                let raw_boxes = include_bytes!(concat!(
2962                    env!("CARGO_MANIFEST_DIR"),
2963                    "/../../testdata/yolov8_boxes_116x8400.bin"
2964                ));
2965                let raw_boxes = unsafe {
2966                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
2967                };
2968                let boxes_i8 =
2969                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
2970
2971                let raw_protos = include_bytes!(concat!(
2972                    env!("CARGO_MANIFEST_DIR"),
2973                    "/../../testdata/yolov8_protos_160x160x32.bin"
2974                ));
2975                let raw_protos = unsafe {
2976                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
2977                };
2978                let protos_i8 =
2979                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
2980                        .unwrap();
2981
2982                // Pre-split (unused for combined, but harmless)
2983                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
2984                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
2985                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
2986                let boxes_combined = boxes_i8;
2987
2988                let decoder = if is_split {
2989                    build_yolo_split_segdet_decoder(
2990                        score_threshold,
2991                        iou_threshold,
2992                        quant_boxes,
2993                        quant_protos,
2994                    )
2995                } else {
2996                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
2997                };
2998
2999                let expected = real_data_expected_boxes();
3000                let mut output_boxes = Vec::with_capacity(50);
3001
3002                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
3003                    vec![
3004                        boxes_split.view().into(),
3005                        scores_split.view().into(),
3006                        mask_split.view().into(),
3007                        protos_i8.view().into(),
3008                    ]
3009                } else {
3010                    vec![boxes_combined.view().into(), protos_i8.view().into()]
3011                };
3012                decoder
3013                    .decode_quantized_proto(&inputs, &mut output_boxes)
3014                    .unwrap();
3015
3016                assert_eq!(output_boxes.len(), 2);
3017                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3018                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3019            }
3020        };
3021        ($name:ident, float, $layout:ident) => {
3022            #[test]
3023            fn $name() {
3024                let is_split = matches!(stringify!($layout), "split");
3025
3026                let score_threshold = 0.45;
3027                let iou_threshold = 0.45;
3028                let quant_boxes = (0.021287762_f32, 31_i32);
3029                let quant_protos = (0.02491162_f32, -117_i32);
3030
3031                let raw_boxes = include_bytes!(concat!(
3032                    env!("CARGO_MANIFEST_DIR"),
3033                    "/../../testdata/yolov8_boxes_116x8400.bin"
3034                ));
3035                let raw_boxes = unsafe {
3036                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
3037                };
3038                let boxes_i8 =
3039                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
3040                let boxes_f32: Array3<f32> =
3041                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
3042
3043                let raw_protos = include_bytes!(concat!(
3044                    env!("CARGO_MANIFEST_DIR"),
3045                    "/../../testdata/yolov8_protos_160x160x32.bin"
3046                ));
3047                let raw_protos = unsafe {
3048                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3049                };
3050                let protos_i8 =
3051                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
3052                        .unwrap();
3053                let protos_f32: Array4<f32> =
3054                    dequantize_ndarray(protos_i8.view(), quant_protos.into());
3055
3056                // Pre-split from dequantized data
3057                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
3058                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
3059                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
3060                let boxes_combined = boxes_f32;
3061
3062                let decoder = if is_split {
3063                    build_yolo_split_segdet_decoder(
3064                        score_threshold,
3065                        iou_threshold,
3066                        quant_boxes,
3067                        quant_protos,
3068                    )
3069                } else {
3070                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
3071                };
3072
3073                let expected = real_data_expected_boxes();
3074                let mut output_boxes = Vec::with_capacity(50);
3075
3076                let inputs = if is_split {
3077                    vec![
3078                        boxes_split.view().into_dyn(),
3079                        scores_split.view().into_dyn(),
3080                        mask_split.view().into_dyn(),
3081                        protos_f32.view().into_dyn(),
3082                    ]
3083                } else {
3084                    vec![
3085                        boxes_combined.view().into_dyn(),
3086                        protos_f32.view().into_dyn(),
3087                    ]
3088                };
3089                decoder
3090                    .decode_float_proto(&inputs, &mut output_boxes)
3091                    .unwrap();
3092
3093                assert_eq!(output_boxes.len(), 2);
3094                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3095                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3096            }
3097        };
3098    }
3099
3100    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
3101    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
3102    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
3103    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3104
    /// YAML model config for a `yolo26` end-to-end detection model with a single
    /// combined, quantized detection output of shape `[1, 10, 6]`
    /// (10 boxes × 6 features) and normalized box coordinates.
    ///
    /// The quantization scale `0.00784313725490196` is 2/255 with zero point 0.
    /// NOTE(review): the 6 features are presumably `[xmin, ymin, xmax, ymax, score,
    /// class]` — confirm against the yolo26 decoder.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3118
    /// YAML model config for a `yolo26` end-to-end segmentation-detection model.
    /// It has two outputs:
    /// - a combined quantized detection output `[1, 10, 38]` (scale 2/255, zero point 0);
    /// - a quantized prototype tensor `[1, 160, 160, 32]` (scale 1/255, zero point 128).
    ///
    /// NOTE(review): 38 features is presumably 6 box/score/class values plus 32 mask
    /// coefficients (matching `num_protos: 32`) — confirm.
    /// NOTE(review): the quantized `e2e_segdet_test!` arm quantizes its protos with
    /// zero point 0 (`protos_quant = (1.0 / 255.0, 0.0)`), while this config
    /// declares 128. Confirm that the mismatch is intended.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3141
    /// YAML model config for a `yolo26` end-to-end detection model with split
    /// outputs, all quantized with scale 2/255 and zero point 0:
    /// - `boxes`   `[1, 10, 4]` (normalized box coordinates);
    /// - `scores`  `[1, 10, 1]`;
    /// - `classes` `[1, 10, 1]`.
    ///
    /// NOTE(review): `classes` carries one value per box under the `num_classes`
    /// dimension name, presumably the class index rather than per-class scores —
    /// confirm.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3171
    /// YAML model config for a `yolo26` end-to-end segmentation-detection model
    /// with split outputs:
    /// - `boxes` `[1, 10, 4]`, `scores` `[1, 10, 1]`, `classes` `[1, 10, 1]` and
    ///   `mask_coefficients` `[1, 10, 32]`, all quantized with scale 2/255 and zero point 0;
    /// - `protos` `[1, 160, 160, 32]`, quantized with scale 1/255 and zero point 128.
    ///
    /// NOTE(review): the quantized `e2e_segdet_test!` arm quantizes its protos with
    /// zero point 0 (`protos_quant = (1.0 / 255.0, 0.0)`), while this config
    /// declares 128. Confirm that the mismatch is intended.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3218
3219    macro_rules! e2e_segdet_test {
3220        ($name:ident, quantized, $layout:ident, $output:ident) => {
3221            #[test]
3222            fn $name() {
3223                let is_split = matches!(stringify!($layout), "split");
3224                let is_proto = matches!(stringify!($output), "proto");
3225
3226                let score_threshold = 0.45;
3227                let iou_threshold = 0.45;
3228
3229                let mut boxes = Array2::zeros((10, 4));
3230                let mut scores = Array2::zeros((10, 1));
3231                let mut classes = Array2::zeros((10, 1));
3232                let mask = Array2::zeros((10, 32));
3233                let protos = Array3::<f64>::zeros((160, 160, 32));
3234                let protos = protos.insert_axis(Axis(0));
3235                let protos_quant = (1.0 / 255.0, 0.0);
3236                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3237
3238                boxes
3239                    .slice_mut(s![0, ..])
3240                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3241                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3242                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3243
3244                let detect_quant = (2.0 / 255.0, 0.0);
3245
3246                let decoder = if is_split {
3247                    DecoderBuilder::default()
3248                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3249                        .with_score_threshold(score_threshold)
3250                        .with_iou_threshold(iou_threshold)
3251                        .build()
3252                        .unwrap()
3253                } else {
3254                    DecoderBuilder::default()
3255                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3256                        .with_score_threshold(score_threshold)
3257                        .with_iou_threshold(iou_threshold)
3258                        .build()
3259                        .unwrap()
3260                };
3261
3262                let expected = e2e_expected_boxes_quant();
3263                let mut output_boxes = Vec::with_capacity(50);
3264
3265                if is_split {
3266                    let boxes = boxes.insert_axis(Axis(0));
3267                    let scores = scores.insert_axis(Axis(0));
3268                    let classes = classes.insert_axis(Axis(0));
3269                    let mask = mask.insert_axis(Axis(0));
3270
3271                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3272                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3273                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3274                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3275
3276                    if is_proto {
3277                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3278                            boxes.view().into(),
3279                            scores.view().into(),
3280                            classes.view().into(),
3281                            mask.view().into(),
3282                            protos.view().into(),
3283                        ];
3284                        decoder
3285                            .decode_quantized_proto(&inputs, &mut output_boxes)
3286                            .unwrap();
3287
3288                        assert_eq!(output_boxes.len(), 1);
3289                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3290                    } else {
3291                        let mut output_masks = Vec::with_capacity(50);
3292                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3293                            boxes.view().into(),
3294                            scores.view().into(),
3295                            classes.view().into(),
3296                            mask.view().into(),
3297                            protos.view().into(),
3298                        ];
3299                        decoder
3300                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3301                            .unwrap();
3302
3303                        assert_eq!(output_boxes.len(), 1);
3304                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3305                    }
3306                } else {
3307                    // Combined layout
3308                    let detect = ndarray::concatenate![
3309                        Axis(1),
3310                        boxes.view(),
3311                        scores.view(),
3312                        classes.view(),
3313                        mask.view()
3314                    ];
3315                    let detect = detect.insert_axis(Axis(0));
3316                    assert_eq!(detect.shape(), &[1, 10, 38]);
3317                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3318
3319                    if is_proto {
3320                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3321                            vec![detect.view().into(), protos.view().into()];
3322                        decoder
3323                            .decode_quantized_proto(&inputs, &mut output_boxes)
3324                            .unwrap();
3325
3326                        assert_eq!(output_boxes.len(), 1);
3327                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3328                    } else {
3329                        let mut output_masks = Vec::with_capacity(50);
3330                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3331                            vec![detect.view().into(), protos.view().into()];
3332                        decoder
3333                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3334                            .unwrap();
3335
3336                        assert_eq!(output_boxes.len(), 1);
3337                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3338                    }
3339                }
3340            }
3341        };
3342        ($name:ident, float, $layout:ident, $output:ident) => {
3343            #[test]
3344            fn $name() {
3345                let is_split = matches!(stringify!($layout), "split");
3346                let is_proto = matches!(stringify!($output), "proto");
3347
3348                let score_threshold = 0.45;
3349                let iou_threshold = 0.45;
3350
3351                let mut boxes = Array2::zeros((10, 4));
3352                let mut scores = Array2::zeros((10, 1));
3353                let mut classes = Array2::zeros((10, 1));
3354                let mask: Array2<f64> = Array2::zeros((10, 32));
3355                let protos = Array3::<f64>::zeros((160, 160, 32));
3356                let protos = protos.insert_axis(Axis(0));
3357
3358                boxes
3359                    .slice_mut(s![0, ..])
3360                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3361                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3362                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3363
3364                let decoder = if is_split {
3365                    DecoderBuilder::default()
3366                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3367                        .with_score_threshold(score_threshold)
3368                        .with_iou_threshold(iou_threshold)
3369                        .build()
3370                        .unwrap()
3371                } else {
3372                    DecoderBuilder::default()
3373                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3374                        .with_score_threshold(score_threshold)
3375                        .with_iou_threshold(iou_threshold)
3376                        .build()
3377                        .unwrap()
3378                };
3379
3380                let expected = e2e_expected_boxes_float();
3381                let mut output_boxes = Vec::with_capacity(50);
3382
3383                if is_split {
3384                    let boxes = boxes.insert_axis(Axis(0));
3385                    let scores = scores.insert_axis(Axis(0));
3386                    let classes = classes.insert_axis(Axis(0));
3387                    let mask = mask.insert_axis(Axis(0));
3388
3389                    if is_proto {
3390                        let inputs = vec![
3391                            boxes.view().into_dyn(),
3392                            scores.view().into_dyn(),
3393                            classes.view().into_dyn(),
3394                            mask.view().into_dyn(),
3395                            protos.view().into_dyn(),
3396                        ];
3397                        decoder
3398                            .decode_float_proto(&inputs, &mut output_boxes)
3399                            .unwrap();
3400
3401                        assert_eq!(output_boxes.len(), 1);
3402                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3403                    } else {
3404                        let mut output_masks = Vec::with_capacity(50);
3405                        let inputs = vec![
3406                            boxes.view().into_dyn(),
3407                            scores.view().into_dyn(),
3408                            classes.view().into_dyn(),
3409                            mask.view().into_dyn(),
3410                            protos.view().into_dyn(),
3411                        ];
3412                        decoder
3413                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3414                            .unwrap();
3415
3416                        assert_eq!(output_boxes.len(), 1);
3417                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3418                    }
3419                } else {
3420                    // Combined layout
3421                    let detect = ndarray::concatenate![
3422                        Axis(1),
3423                        boxes.view(),
3424                        scores.view(),
3425                        classes.view(),
3426                        mask.view()
3427                    ];
3428                    let detect = detect.insert_axis(Axis(0));
3429                    assert_eq!(detect.shape(), &[1, 10, 38]);
3430
3431                    if is_proto {
3432                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3433                        decoder
3434                            .decode_float_proto(&inputs, &mut output_boxes)
3435                            .unwrap();
3436
3437                        assert_eq!(output_boxes.len(), 1);
3438                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3439                    } else {
3440                        let mut output_masks = Vec::with_capacity(50);
3441                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3442                        decoder
3443                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3444                            .unwrap();
3445
3446                        assert_eq!(output_boxes.len(), 1);
3447                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3448                    }
3449                }
3450            }
3451        };
3452    }
3453
    // End-to-end segdet matrix: {quantized, float} x {combined, split} x {masks, proto}.
    // Combined layout, full mask path and proto path.
    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    // Split layout (separate boxes/scores/classes/mask-coefficient tensors).
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
3492
3493    macro_rules! e2e_det_test {
3494        ($name:ident, quantized, $layout:ident) => {
3495            #[test]
3496            fn $name() {
3497                let is_split = matches!(stringify!($layout), "split");
3498
3499                let score_threshold = 0.45;
3500                let iou_threshold = 0.45;
3501
3502                let mut boxes = Array3::zeros((1, 10, 4));
3503                let mut scores = Array3::zeros((1, 10, 1));
3504                let mut classes = Array3::zeros((1, 10, 1));
3505
3506                boxes
3507                    .slice_mut(s![0, 0, ..])
3508                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3509                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3510                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3511
3512                let detect_quant = (2.0 / 255.0, 0_i32);
3513
3514                let decoder = if is_split {
3515                    DecoderBuilder::default()
3516                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3517                        .with_score_threshold(score_threshold)
3518                        .with_iou_threshold(iou_threshold)
3519                        .build()
3520                        .unwrap()
3521                } else {
3522                    DecoderBuilder::default()
3523                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3524                        .with_score_threshold(score_threshold)
3525                        .with_iou_threshold(iou_threshold)
3526                        .build()
3527                        .unwrap()
3528                };
3529
3530                let expected = e2e_expected_boxes_quant();
3531                let mut output_boxes = Vec::with_capacity(50);
3532
3533                if is_split {
3534                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3535                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3536                    let classes: Array<u8, _> =
3537                        quantize_ndarray(classes.view(), detect_quant.into());
3538                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3539                        boxes.view().into(),
3540                        scores.view().into(),
3541                        classes.view().into(),
3542                    ];
3543                    decoder
3544                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3545                        .unwrap();
3546                } else {
3547                    let detect =
3548                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3549                    assert_eq!(detect.shape(), &[1, 10, 6]);
3550                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3551                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3552                        vec![detect.view().into()];
3553                    decoder
3554                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3555                        .unwrap();
3556                }
3557
3558                assert_eq!(output_boxes.len(), 1);
3559                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3560            }
3561        };
3562        ($name:ident, float, $layout:ident) => {
3563            #[test]
3564            fn $name() {
3565                let is_split = matches!(stringify!($layout), "split");
3566
3567                let score_threshold = 0.45;
3568                let iou_threshold = 0.45;
3569
3570                let mut boxes = Array3::zeros((1, 10, 4));
3571                let mut scores = Array3::zeros((1, 10, 1));
3572                let mut classes = Array3::zeros((1, 10, 1));
3573
3574                boxes
3575                    .slice_mut(s![0, 0, ..])
3576                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3577                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3578                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3579
3580                let decoder = if is_split {
3581                    DecoderBuilder::default()
3582                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3583                        .with_score_threshold(score_threshold)
3584                        .with_iou_threshold(iou_threshold)
3585                        .build()
3586                        .unwrap()
3587                } else {
3588                    DecoderBuilder::default()
3589                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3590                        .with_score_threshold(score_threshold)
3591                        .with_iou_threshold(iou_threshold)
3592                        .build()
3593                        .unwrap()
3594                };
3595
3596                let expected = e2e_expected_boxes_float();
3597                let mut output_boxes = Vec::with_capacity(50);
3598
3599                if is_split {
3600                    let inputs = vec![
3601                        boxes.view().into_dyn(),
3602                        scores.view().into_dyn(),
3603                        classes.view().into_dyn(),
3604                    ];
3605                    decoder
3606                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3607                        .unwrap();
3608                } else {
3609                    let detect =
3610                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3611                    assert_eq!(detect.shape(), &[1, 10, 6]);
3612                    let inputs = vec![detect.view().into_dyn()];
3613                    decoder
3614                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3615                        .unwrap();
3616                }
3617
3618                assert_eq!(output_boxes.len(), 1);
3619                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3620            }
3621        };
3622    }
3623
    // End-to-end detection-only matrix: {quantized, float} x {combined, split}.
    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3628
3629    #[test]
3630    fn test_decode_tensor() {
3631        let score_threshold = 0.45;
3632        let iou_threshold = 0.45;
3633
3634        let raw_boxes = include_bytes!(concat!(
3635            env!("CARGO_MANIFEST_DIR"),
3636            "/../../testdata/yolov8_boxes_116x8400.bin"
3637        ));
3638        let raw_boxes =
3639            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3640        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3641        boxes_i8
3642            .map()
3643            .unwrap()
3644            .as_mut_slice()
3645            .copy_from_slice(raw_boxes);
3646        let boxes_i8 = boxes_i8.into();
3647
3648        let raw_protos = include_bytes!(concat!(
3649            env!("CARGO_MANIFEST_DIR"),
3650            "/../../testdata/yolov8_protos_160x160x32.bin"
3651        ));
3652        let raw_protos = unsafe {
3653            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3654        };
3655        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3656        protos_i8
3657            .map()
3658            .unwrap()
3659            .as_mut_slice()
3660            .copy_from_slice(raw_protos);
3661        let protos_i8 = protos_i8.into();
3662
3663        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3664        let expected = real_data_expected_boxes();
3665        let mut output_boxes = Vec::with_capacity(50);
3666
3667        decoder
3668            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3669            .unwrap();
3670
3671        assert_eq!(output_boxes.len(), 2);
3672        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3673        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3674    }
3675
3676    #[test]
3677    fn test_decode_tensor_f32() {
3678        let score_threshold = 0.45;
3679        let iou_threshold = 0.45;
3680
3681        let quant_boxes = (0.021287762_f32, 31_i32);
3682        let quant_protos = (0.02491162_f32, -117_i32);
3683        let raw_boxes = include_bytes!(concat!(
3684            env!("CARGO_MANIFEST_DIR"),
3685            "/../../testdata/yolov8_boxes_116x8400.bin"
3686        ));
3687        let raw_boxes =
3688            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3689        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3690        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3691        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3692        boxes_f32
3693            .map()
3694            .unwrap()
3695            .as_mut_slice()
3696            .copy_from_slice(&raw_boxes_f32);
3697        let boxes_f32 = boxes_f32.into();
3698
3699        let raw_protos = include_bytes!(concat!(
3700            env!("CARGO_MANIFEST_DIR"),
3701            "/../../testdata/yolov8_protos_160x160x32.bin"
3702        ));
3703        let raw_protos = unsafe {
3704            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3705        };
3706        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3707        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3708        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3709        protos_f32
3710            .map()
3711            .unwrap()
3712            .as_mut_slice()
3713            .copy_from_slice(&raw_protos_f32);
3714        let protos_f32 = protos_f32.into();
3715
3716        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3717
3718        let expected = real_data_expected_boxes();
3719        let mut output_boxes = Vec::with_capacity(50);
3720
3721        decoder
3722            .decode(
3723                &[&boxes_f32, &protos_f32],
3724                &mut output_boxes,
3725                &mut Vec::new(),
3726            )
3727            .unwrap();
3728
3729        assert_eq!(output_boxes.len(), 2);
3730        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3731        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3732    }
3733
3734    #[test]
3735    fn test_decode_tensor_f64() {
3736        let score_threshold = 0.45;
3737        let iou_threshold = 0.45;
3738
3739        let quant_boxes = (0.021287762_f32, 31_i32);
3740        let quant_protos = (0.02491162_f32, -117_i32);
3741        let raw_boxes = include_bytes!(concat!(
3742            env!("CARGO_MANIFEST_DIR"),
3743            "/../../testdata/yolov8_boxes_116x8400.bin"
3744        ));
3745        let raw_boxes =
3746            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3747        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3748        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3749        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3750        boxes_f64
3751            .map()
3752            .unwrap()
3753            .as_mut_slice()
3754            .copy_from_slice(&raw_boxes_f64);
3755        let boxes_f64 = boxes_f64.into();
3756
3757        let raw_protos = include_bytes!(concat!(
3758            env!("CARGO_MANIFEST_DIR"),
3759            "/../../testdata/yolov8_protos_160x160x32.bin"
3760        ));
3761        let raw_protos = unsafe {
3762            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3763        };
3764        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3765        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3766        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3767        protos_f64
3768            .map()
3769            .unwrap()
3770            .as_mut_slice()
3771            .copy_from_slice(&raw_protos_f64);
3772        let protos_f64 = protos_f64.into();
3773
3774        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3775
3776        let expected = real_data_expected_boxes();
3777        let mut output_boxes = Vec::with_capacity(50);
3778
3779        decoder
3780            .decode(
3781                &[&boxes_f64, &protos_f64],
3782                &mut output_boxes,
3783                &mut Vec::new(),
3784            )
3785            .unwrap();
3786
3787        assert_eq!(output_boxes.len(), 2);
3788        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3789        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3790    }
3791
3792    #[test]
3793    fn test_decode_tensor_proto() {
3794        let score_threshold = 0.45;
3795        let iou_threshold = 0.45;
3796
3797        let raw_boxes = include_bytes!(concat!(
3798            env!("CARGO_MANIFEST_DIR"),
3799            "/../../testdata/yolov8_boxes_116x8400.bin"
3800        ));
3801        let raw_boxes =
3802            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3803        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3804        boxes_i8
3805            .map()
3806            .unwrap()
3807            .as_mut_slice()
3808            .copy_from_slice(raw_boxes);
3809        let boxes_i8 = boxes_i8.into();
3810
3811        let raw_protos = include_bytes!(concat!(
3812            env!("CARGO_MANIFEST_DIR"),
3813            "/../../testdata/yolov8_protos_160x160x32.bin"
3814        ));
3815        let raw_protos = unsafe {
3816            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3817        };
3818        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3819        protos_i8
3820            .map()
3821            .unwrap()
3822            .as_mut_slice()
3823            .copy_from_slice(raw_protos);
3824        let protos_i8 = protos_i8.into();
3825
3826        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3827
3828        let expected = real_data_expected_boxes();
3829        let mut output_boxes = Vec::with_capacity(50);
3830
3831        let proto_data = decoder
3832            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3833            .unwrap();
3834
3835        assert_eq!(output_boxes.len(), 2);
3836        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3837        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3838
3839        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3840        let coeffs_shape = proto_data.mask_coefficients.shape();
3841        assert_eq!(
3842            coeffs_shape[0],
3843            output_boxes.len(),
3844            "mask_coefficients count must match detection count"
3845        );
3846        assert_eq!(
3847            coeffs_shape[1], 32,
3848            "each detection should have 32 mask coefficients"
3849        );
3850    }
3851
3852    // =========================================================================
3853    // Physical-order contract regression tests
3854    //
3855    // These cover the TFLite VX-delegate vertical-stripe mask bug and the
3856    // Ara-2-via-overlay anchor-first split-tensor bug. The core claim:
3857    // declaring shape+dshape in physical memory order produces correct
3858    // strides at wrap time, and `swap_axes_if_needed` permutes the stride
3859    // tuple into canonical logical order without touching bytes.
3860    // =========================================================================
3861
3862    /// TFLite YOLOv8-seg protos arrive as physically-NHWC bytes with
3863    /// NNStreamer dim string `"32:160:160:1"` (innermost-first). The
3864    /// overlay's 2026-04-22 workaround declares `[1, 160, 160, 32]` with
3865    /// dshape `[batch, height, width, num_protos]` — shape matches
3866    /// physical order, so no permutation is needed and the view is
3867    /// already in canonical HWC after the batch-dim slice. This test
3868    /// confirms that path matches the reference (direct-HWC) decode.
3869    #[test]
3870    fn test_physical_order_tflite_nhwc_protos() {
3871        let score_threshold = 0.45;
3872        let iou_threshold = 0.45;
3873
3874        // Load HWC protos and dequantize to f32.
3875        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
3876        let quant_protos = Quantization::new(0.02491161972284317, -117);
3877        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
3878
3879        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
3880        let quant_boxes = Quantization::new(0.021287761628627777, 31);
3881        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
3882
3883        // Reference: direct call with canonical HWC protos.
3884        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
3885        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
3886        decode_yolo_segdet_float(
3887            seg.view(),
3888            protos_f32_hwc.view(),
3889            score_threshold,
3890            iou_threshold,
3891            Some(configs::Nms::ClassAgnostic),
3892            &mut ref_boxes,
3893            &mut ref_masks,
3894        )
3895        .unwrap();
3896
3897        // Build the same protos as a 4D NHWC tensor and feed it through
3898        // the config-driven path with dshape in physical order.
3899        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0)); // [1, 160, 160, 32]
3900        let seg_3d = seg.insert_axis(Axis(0)); // [1, 116, 8400]
3901
3902        let decoder = DecoderBuilder::default()
3903            .with_config_yolo_segdet(
3904                configs::Detection {
3905                    decoder: configs::DecoderType::Ultralytics,
3906                    quantization: None,
3907                    shape: vec![1, 116, 8400],
3908                    dshape: vec![
3909                        (DimName::Batch, 1),
3910                        (DimName::NumFeatures, 116),
3911                        (DimName::NumBoxes, 8400),
3912                    ],
3913                    normalized: Some(true),
3914                    anchors: None,
3915                },
3916                configs::Protos {
3917                    decoder: configs::DecoderType::Ultralytics,
3918                    quantization: None,
3919                    shape: vec![1, 160, 160, 32],
3920                    // Physical NHWC — matches the TFLite buffer layout.
3921                    dshape: vec![
3922                        (DimName::Batch, 1),
3923                        (DimName::Height, 160),
3924                        (DimName::Width, 160),
3925                        (DimName::NumProtos, 32),
3926                    ],
3927                },
3928                None,
3929            )
3930            .with_score_threshold(score_threshold)
3931            .with_iou_threshold(iou_threshold)
3932            .build()
3933            .expect("config with NHWC protos dshape must build");
3934
3935        let mut cfg_boxes = Vec::with_capacity(10);
3936        let mut cfg_masks = Vec::with_capacity(10);
3937        decoder
3938            .decode_float(
3939                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
3940                &mut cfg_boxes,
3941                &mut cfg_masks,
3942            )
3943            .unwrap();
3944
3945        assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
3946        for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
3947            assert!(
3948                c.equal_within_delta(r, 0.01),
3949                "NHWC-declared box does not match reference: {c:?} vs {r:?}"
3950            );
3951        }
3952        for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
3953            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
3954            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
3955            assert_eq!(
3956                cm_arr, rm_arr,
3957                "NHWC-declared mask must match reference pixel-for-pixel"
3958            );
3959        }
3960    }
3961
3962    /// Ara-2 split-tensor boxes arrive physically anchor-first: NNStreamer
3963    /// dim string `"4:1:8400:1"` means innermost axis is 4 (coords), next
3964    /// is 1 (padding), next is 8400 (anchors), outermost is 1 (batch).
3965    /// The correct physical-order declaration is `[1, 8400, 1, 4]` with
3966    /// dshape `[batch, num_boxes, padding, box_coords]`. This test
3967    /// confirms that declaration produces the same decoded boxes as the
3968    /// canonical features-first TFLite declaration over the same logical
3969    /// data.
3970    #[test]
3971    fn test_physical_order_ara2_anchor_first_split_boxes() {
3972        use configs::{Boxes, Scores};
3973
3974        // Build synthetic boxes data in features-first canonical layout:
3975        // shape [1, 4, 8400] with one detection at anchor 42.
3976        const N: usize = 8400;
3977        let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
3978        let target_anchor = 42usize;
3979        boxes_canonical[[0, 0, target_anchor]] = 0.4; // xc
3980        boxes_canonical[[0, 1, target_anchor]] = 0.5; // yc
3981        boxes_canonical[[0, 2, target_anchor]] = 0.2; // w
3982        boxes_canonical[[0, 3, target_anchor]] = 0.2; // h
3983
3984        // Scores: one class at 0.9 for the same anchor.
3985        let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
3986        scores_canonical[[0, 0, target_anchor]] = 0.9;
3987
3988        // Reference: canonical (features-first) shape declaration.
3989        let ref_decoder = DecoderBuilder::default()
3990            .with_config_yolo_split_det(
3991                Boxes {
3992                    decoder: configs::DecoderType::Ultralytics,
3993                    quantization: None,
3994                    shape: vec![1, 4, N],
3995                    dshape: vec![
3996                        (DimName::Batch, 1),
3997                        (DimName::BoxCoords, 4),
3998                        (DimName::NumBoxes, N),
3999                    ],
4000                    normalized: Some(true),
4001                },
4002                Scores {
4003                    decoder: configs::DecoderType::Ultralytics,
4004                    quantization: None,
4005                    shape: vec![1, 80, N],
4006                    dshape: vec![
4007                        (DimName::Batch, 1),
4008                        (DimName::NumClasses, 80),
4009                        (DimName::NumBoxes, N),
4010                    ],
4011                },
4012            )
4013            .with_score_threshold(0.5)
4014            .with_iou_threshold(0.5)
4015            .with_nms(Some(configs::Nms::ClassAgnostic))
4016            .build()
4017            .expect("reference canonical split decoder must build");
4018
4019        let mut ref_boxes = Vec::with_capacity(4);
4020        let mut ref_masks = Vec::with_capacity(0);
4021        ref_decoder
4022            .decode_float(
4023                &[
4024                    boxes_canonical.view().into_dyn(),
4025                    scores_canonical.view().into_dyn(),
4026                ],
4027                &mut ref_boxes,
4028                &mut ref_masks,
4029            )
4030            .unwrap();
4031        assert_eq!(ref_boxes.len(), 1, "reference should produce one box");
4032
4033        // Ara-2 physical layout: transpose axes 1↔2 so the innermost
4034        // axis is BoxCoords / NumClasses. Materialise to get a
4035        // C-contiguous Array3 with strides matching the physical order
4036        // the Ara-2 backend would write into the tensor.
4037        let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 4]
4038        let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 80]
4039
4040        let ara2_decoder = DecoderBuilder::default()
4041            .with_config_yolo_split_det(
4042                Boxes {
4043                    decoder: configs::DecoderType::Ultralytics,
4044                    quantization: None,
4045                    shape: vec![1, N, 4],
4046                    dshape: vec![
4047                        (DimName::Batch, 1),
4048                        (DimName::NumBoxes, N),
4049                        (DimName::BoxCoords, 4),
4050                    ],
4051                    normalized: Some(true),
4052                },
4053                Scores {
4054                    decoder: configs::DecoderType::Ultralytics,
4055                    quantization: None,
4056                    shape: vec![1, N, 80],
4057                    dshape: vec![
4058                        (DimName::Batch, 1),
4059                        (DimName::NumBoxes, N),
4060                        (DimName::NumClasses, 80),
4061                    ],
4062                },
4063            )
4064            .with_score_threshold(0.5)
4065            .with_iou_threshold(0.5)
4066            .with_nms(Some(configs::Nms::ClassAgnostic))
4067            .build()
4068            .expect("Ara-2 anchor-first decoder must build");
4069
4070        let mut ara2_boxes = Vec::with_capacity(4);
4071        let mut ara2_masks = Vec::with_capacity(0);
4072        ara2_decoder
4073            .decode_float(
4074                &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
4075                &mut ara2_boxes,
4076                &mut ara2_masks,
4077            )
4078            .unwrap();
4079
4080        assert_eq!(
4081            ara2_boxes.len(),
4082            ref_boxes.len(),
4083            "Ara-2 anchor-first declaration must produce the same number \
4084             of boxes as the canonical features-first reference"
4085        );
4086        for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
4087            assert!(
4088                a.equal_within_delta(r, 1e-4),
4089                "Ara-2 box differs from reference: {a:?} vs {r:?}"
4090            );
4091        }
4092    }
4093
4094    /// Dshape whose per-axis sizes disagree with shape (the exact
4095    /// failure mode that caused the TFLite stripe bug — shape declared
4096    /// in one order, dshape in another) is rejected with a clear error.
4097    #[test]
4098    fn test_physical_order_rejects_shape_dshape_mismatch() {
4099        let result = DecoderBuilder::default()
4100            .with_config_yolo_segdet(
4101                configs::Detection {
4102                    decoder: configs::DecoderType::Ultralytics,
4103                    quantization: None,
4104                    shape: vec![1, 116, 8400],
4105                    dshape: vec![
4106                        (DimName::Batch, 1),
4107                        (DimName::NumFeatures, 116),
4108                        (DimName::NumBoxes, 8400),
4109                    ],
4110                    normalized: Some(true),
4111                    anchors: None,
4112                },
4113                configs::Protos {
4114                    decoder: configs::DecoderType::Ultralytics,
4115                    quantization: None,
4116                    // Shape in NCHW order...
4117                    shape: vec![1, 32, 160, 160],
4118                    // ...but dshape in NHWC order — the sizes don't line
4119                    // up positionally (dshape[1]=160 vs shape[1]=32).
4120                    dshape: vec![
4121                        (DimName::Batch, 1),
4122                        (DimName::Height, 160),
4123                        (DimName::Width, 160),
4124                        (DimName::NumProtos, 32),
4125                    ],
4126                },
4127                None,
4128            )
4129            .build();
4130
4131        match result {
4132            Err(DecoderError::InvalidConfig(msg)) => {
4133                assert!(
4134                    msg.contains("does not match shape"),
4135                    "expected shape/dshape size mismatch error, got: {msg}"
4136                );
4137            }
4138            other => panic!("expected InvalidConfig, got {other:?}"),
4139        }
4140    }
4141
4142    /// Duplicate dim name in dshape is rejected (two axes mapping to the
4143    /// same canonical slot would break the permutation).
4144    #[test]
4145    fn test_physical_order_rejects_duplicate_dshape_axis() {
4146        let result = DecoderBuilder::default()
4147            .with_config_yolo_split_det(
4148                configs::Boxes {
4149                    decoder: configs::DecoderType::Ultralytics,
4150                    quantization: None,
4151                    shape: vec![1, 4, 8400],
4152                    dshape: vec![
4153                        (DimName::Batch, 1),
4154                        (DimName::BoxCoords, 4),
4155                        (DimName::BoxCoords, 4), // duplicate (size matches shape[2]=8400? no)
4156                    ],
4157                    normalized: Some(true),
4158                },
4159                configs::Scores {
4160                    decoder: configs::DecoderType::Ultralytics,
4161                    quantization: None,
4162                    shape: vec![1, 80, 8400],
4163                    dshape: vec![
4164                        (DimName::Batch, 1),
4165                        (DimName::NumClasses, 80),
4166                        (DimName::NumBoxes, 8400),
4167                    ],
4168                },
4169            )
4170            .build();
4171
4172        // Shape[2] = 8400 ≠ dshape[2] size 4, so the positional-mismatch
4173        // check fires first — which is fine: any mis-shaped dshape
4174        // should be rejected. Check for either error as long as it's a
4175        // clear InvalidConfig.
4176        match result {
4177            Err(DecoderError::InvalidConfig(msg)) => {
4178                assert!(
4179                    msg.contains("appears at both index") || msg.contains("does not match shape"),
4180                    "expected positional or duplicate-axis error, got: {msg}"
4181                );
4182            }
4183            other => panic!("expected InvalidConfig, got {other:?}"),
4184        }
4185
4186        // Separate case: size-consistent duplicate to exercise the
4187        // duplicate-axis code path explicitly. Use `Batch` (no size
4188        // constraint, can repeat with size 1 against a shape that
4189        // carries two singleton dims).
4190        let result = DecoderBuilder::default()
4191            .with_config_yolo_split_det(
4192                configs::Boxes {
4193                    decoder: configs::DecoderType::Ultralytics,
4194                    quantization: None,
4195                    shape: vec![1, 1, 4, 8400],
4196                    dshape: vec![
4197                        (DimName::Batch, 1),
4198                        (DimName::Batch, 1), // duplicate, sizes match
4199                        (DimName::BoxCoords, 4),
4200                        (DimName::NumBoxes, 8400),
4201                    ],
4202                    normalized: Some(true),
4203                },
4204                configs::Scores {
4205                    decoder: configs::DecoderType::Ultralytics,
4206                    quantization: None,
4207                    shape: vec![1, 80, 8400],
4208                    dshape: vec![
4209                        (DimName::Batch, 1),
4210                        (DimName::NumClasses, 80),
4211                        (DimName::NumBoxes, 8400),
4212                    ],
4213                },
4214            )
4215            .build();
4216        match result {
4217            Err(DecoderError::InvalidConfig(msg)) => {
4218                assert!(
4219                    msg.contains("appears at both index"),
4220                    "expected duplicate-axis error, got: {msg}"
4221                );
4222            }
4223            other => panic!("expected InvalidConfig, got {other:?}"),
4224        }
4225    }
4226
4227    /// Canonical (dshape-omitted) high-level decode produces the same
4228    /// numeric result as the dshape-populated form. This closes a
4229    /// coverage gap flagged during review: dshape omission had been
4230    /// exercised only at the builder level, not through the full
4231    /// `decode_float` pipeline.
4232    #[test]
4233    fn test_physical_order_dshape_omitted_decodes_numerically() {
4234        let score_threshold = 0.45;
4235        let iou_threshold = 0.45;
4236
4237        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
4238        let quant_protos = Quantization::new(0.02491161972284317, -117);
4239        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
4240
4241        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
4242        let quant_boxes = Quantization::new(0.021287761628627777, 31);
4243        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
4244
4245        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0));
4246        let seg_3d = seg.insert_axis(Axis(0));
4247
4248        let build_decoder = |det_dshape: Vec<(DimName, usize)>,
4249                             proto_dshape: Vec<(DimName, usize)>| {
4250            DecoderBuilder::default()
4251                .with_config_yolo_segdet(
4252                    configs::Detection {
4253                        decoder: configs::DecoderType::Ultralytics,
4254                        quantization: None,
4255                        shape: vec![1, 116, 8400],
4256                        dshape: det_dshape,
4257                        normalized: Some(true),
4258                        anchors: None,
4259                    },
4260                    configs::Protos {
4261                        decoder: configs::DecoderType::Ultralytics,
4262                        quantization: None,
4263                        shape: vec![1, 160, 160, 32],
4264                        dshape: proto_dshape,
4265                    },
4266                    None,
4267                )
4268                .with_score_threshold(score_threshold)
4269                .with_iou_threshold(iou_threshold)
4270                .build()
4271                .unwrap()
4272        };
4273
4274        // Baseline: dshape populated in physical order.
4275        let dshaped = build_decoder(
4276            vec![
4277                (DimName::Batch, 1),
4278                (DimName::NumFeatures, 116),
4279                (DimName::NumBoxes, 8400),
4280            ],
4281            vec![
4282                (DimName::Batch, 1),
4283                (DimName::Height, 160),
4284                (DimName::Width, 160),
4285                (DimName::NumProtos, 32),
4286            ],
4287        );
4288        let mut dshaped_boxes = Vec::new();
4289        let mut dshaped_masks = Vec::new();
4290        dshaped
4291            .decode_float(
4292                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4293                &mut dshaped_boxes,
4294                &mut dshaped_masks,
4295            )
4296            .unwrap();
4297
4298        // Variant: dshape omitted — caller asserts shape is already in
4299        // the decoder's canonical order.
4300        let bare = build_decoder(vec![], vec![]);
4301        let mut bare_boxes = Vec::new();
4302        let mut bare_masks = Vec::new();
4303        bare.decode_float(
4304            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4305            &mut bare_boxes,
4306            &mut bare_masks,
4307        )
4308        .unwrap();
4309
4310        assert_eq!(bare_boxes.len(), dshaped_boxes.len());
4311        for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
4312            assert!(
4313                b.equal_within_delta(d, 1e-4),
4314                "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
4315            );
4316        }
4317        for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
4318            let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
4319            let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
4320            assert_eq!(
4321                bm_arr, dm_arr,
4322                "dshape-omitted mask must match dshape-populated pixel-for-pixel"
4323            );
4324        }
4325    }
4326
4327    /// 4D anchor-first boxes schema with a trailing `padding` axis — the
4328    /// exact Ara-2 DVM shape (`"4:1:8400:1"` innermost-first = physical
4329    /// `[1, 8400, 1, 4]`). Exercises the schema path's
4330    /// `squeeze_padding_dims`: the caller declares a 4D physical-order
4331    /// layout including `padding`, HAL squeezes the size-1 axis during
4332    /// `to_legacy_config_outputs`, and the decoder sees the resulting
4333    /// 3D shape. Callers feed HAL a 3D squeezed view of the same bytes.
4334    /// Complements the 3D `test_physical_order_ara2_anchor_first_split_boxes`
4335    /// which exercises the programmatic-builder path.
4336    #[test]
4337    fn test_physical_order_ara2_4d_anchor_first_with_padding() {
4338        // Build synthetic data in anchor-first 3D layout (the shape HAL
4339        // actually operates on after squeezing). The 4D-with-padding
4340        // declaration in the schema is a producer-side convention; the
4341        // caller is expected to present HAL with a squeezed view.
4342        const N: usize = 8400;
4343        let mut boxes = Array3::<f32>::zeros((1, N, 4));
4344        let target = 42usize;
4345        boxes[[0, target, 0]] = 0.4;
4346        boxes[[0, target, 1]] = 0.5;
4347        boxes[[0, target, 2]] = 0.2;
4348        boxes[[0, target, 3]] = 0.2;
4349        let mut scores = Array3::<f32>::zeros((1, N, 80));
4350        scores[[0, target, 0]] = 0.9;
4351
4352        // Schema JSON declares shape+dshape in physical anchor-first
4353        // order including padding — mirrors `ara2_int8_edgefirst.json`'s
4354        // feature-first declaration but with the axes in physical
4355        // anchor-first order. `to_legacy_config_outputs` squeezes
4356        // padding before the decoder's rank-3 verification.
4357        let json = r#"{
4358          "schema_version": 2,
4359          "decoder_version": "yolov8",
4360          "nms": "class_agnostic",
4361          "outputs": [
4362            {"name": "boxes", "type": "boxes",
4363             "shape": [1, 8400, 1, 4],
4364             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
4365             "encoding": "direct",
4366             "decoder": "ultralytics",
4367             "normalized": true},
4368            {"name": "scores", "type": "scores",
4369             "shape": [1, 8400, 1, 80],
4370             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
4371             "decoder": "ultralytics",
4372             "score_format": "per_class"}
4373          ]
4374        }"#;
4375        let decoder = DecoderBuilder::default()
4376            .with_config_json_str(json.to_string())
4377            .with_score_threshold(0.5)
4378            .with_iou_threshold(0.5)
4379            .build()
4380            .expect("4D anchor-first schema should build via squeeze_padding_dims");
4381
4382        let mut out_boxes = Vec::with_capacity(4);
4383        let mut out_masks = Vec::with_capacity(0);
4384        decoder
4385            .decode_float(
4386                &[boxes.view().into_dyn(), scores.view().into_dyn()],
4387                &mut out_boxes,
4388                &mut out_masks,
4389            )
4390            .unwrap();
4391
4392        assert_eq!(
4393            out_boxes.len(),
4394            1,
4395            "4D anchor-first with padding should decode exactly one box from the seeded anchor"
4396        );
4397        let b = &out_boxes[0];
4398        // xywh(0.4, 0.5, 0.2, 0.2) → xyxy(0.3, 0.4, 0.5, 0.6)
4399        assert!((b.bbox.xmin - 0.3).abs() < 1e-3, "xmin wrong: {b:?}");
4400        assert!((b.bbox.ymin - 0.4).abs() < 1e-3, "ymin wrong: {b:?}");
4401        assert!((b.bbox.xmax - 0.5).abs() < 1e-3, "xmax wrong: {b:?}");
4402        assert!((b.bbox.ymax - 0.6).abs() < 1e-3, "ymax wrong: {b:?}");
4403        assert_eq!(b.label, 0);
4404        assert!(b.score > 0.85, "score {}: {b:?}", b.score);
4405    }
4406}
4407
4408#[cfg(feature = "tracker")]
4409#[cfg(test)]
4410#[cfg_attr(coverage_nightly, coverage(off))]
4411mod decoder_tracked_tests {
4412
4413    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4414    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4415    use num_traits::{AsPrimitive, Float, PrimInt};
4416    use rand::{RngExt, SeedableRng};
4417    use rand_distr::StandardNormal;
4418
4419    use crate::{
4420        configs::{self, DimName},
4421        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4422    };
4423
4424    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4425        input: ArrayView<F, D>,
4426        quant: Quantization,
4427    ) -> Array<T, D>
4428    where
4429        i32: num_traits::AsPrimitive<F>,
4430        f32: num_traits::AsPrimitive<F>,
4431    {
4432        let zero_point = quant.zero_point.as_();
4433        let div_scale = F::one() / quant.scale.as_();
4434        if zero_point != F::zero() {
4435            input.mapv(|d| (d * div_scale + zero_point).round().as_())
4436        } else {
4437            input.mapv(|d| (d * div_scale).round().as_())
4438        }
4439    }
4440
4441    #[test]
4442    fn test_decoder_tracked_random_jitter() {
4443        use crate::configs::{DecoderType, Nms};
4444        use crate::DecoderBuilder;
4445
4446        let score_threshold = 0.25;
4447        let iou_threshold = 0.1;
4448        let out = include_bytes!(concat!(
4449            env!("CARGO_MANIFEST_DIR"),
4450            "/../../testdata/yolov8s_80_classes.bin"
4451        ));
4452        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4453        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4454        let quant = (0.0040811873, -123).into();
4455
4456        let decoder = DecoderBuilder::default()
4457            .with_config_yolo_det(
4458                crate::configs::Detection {
4459                    decoder: DecoderType::Ultralytics,
4460                    shape: vec![1, 84, 8400],
4461                    anchors: None,
4462                    quantization: Some(quant),
4463                    dshape: vec![
4464                        (crate::configs::DimName::Batch, 1),
4465                        (crate::configs::DimName::NumFeatures, 84),
4466                        (crate::configs::DimName::NumBoxes, 8400),
4467                    ],
4468                    normalized: Some(true),
4469                },
4470                None,
4471            )
4472            .with_score_threshold(score_threshold)
4473            .with_iou_threshold(iou_threshold)
4474            .with_nms(Some(Nms::ClassAgnostic))
4475            .build()
4476            .unwrap();
4477        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
4478
4479        let expected_boxes = [
4480            crate::DetectBox {
4481                bbox: crate::BoundingBox {
4482                    xmin: 0.5285137,
4483                    ymin: 0.05305544,
4484                    xmax: 0.87541467,
4485                    ymax: 0.9998909,
4486                },
4487                score: 0.5591227,
4488                label: 0,
4489            },
4490            crate::DetectBox {
4491                bbox: crate::BoundingBox {
4492                    xmin: 0.130598,
4493                    ymin: 0.43260583,
4494                    xmax: 0.35098213,
4495                    ymax: 0.9958097,
4496                },
4497                score: 0.33057618,
4498                label: 75,
4499            },
4500        ];
4501
4502        let mut tracker = ByteTrackBuilder::new()
4503            .track_update(0.1)
4504            .track_high_conf(0.3)
4505            .build();
4506
4507        let mut output_boxes = Vec::with_capacity(50);
4508        let mut output_masks = Vec::with_capacity(50);
4509        let mut output_tracks = Vec::with_capacity(50);
4510
4511        decoder
4512            .decode_tracked_quantized(
4513                &mut tracker,
4514                0,
4515                &[out.view().into()],
4516                &mut output_boxes,
4517                &mut output_masks,
4518                &mut output_tracks,
4519            )
4520            .unwrap();
4521
4522        assert_eq!(output_boxes.len(), 2);
4523        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4524        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
4525
4526        let mut last_boxes = output_boxes.clone();
4527
4528        for i in 1..=100 {
4529            let mut out = out.clone();
4530            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
4531            let mut x_values = out.slice_mut(s![0, 0, ..]);
4532            for x in x_values.iter_mut() {
4533                let r: f32 = rng.sample(StandardNormal);
4534                let r = r.clamp(-2.0, 2.0) / 2.0;
4535                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
4536            }
4537
4538            let mut y_values = out.slice_mut(s![0, 1, ..]);
4539            for y in y_values.iter_mut() {
4540                let r: f32 = rng.sample(StandardNormal);
4541                let r = r.clamp(-2.0, 2.0) / 2.0;
4542                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
4543            }
4544
4545            decoder
4546                .decode_tracked_quantized(
4547                    &mut tracker,
4548                    100_000_000 * i / 3, // simulate 33.333ms between frames
4549                    &[out.view().into()],
4550                    &mut output_boxes,
4551                    &mut output_masks,
4552                    &mut output_tracks,
4553                )
4554                .unwrap();
4555
4556            assert_eq!(output_boxes.len(), 2);
4557            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
4558            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
4559
4560            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
4561            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
4562            last_boxes = output_boxes.clone();
4563        }
4564    }
4565
4566    // ─── Shared helpers for tracked decoder tests ────────────────────
4567
4568    fn real_data_expected_boxes() -> [DetectBox; 2] {
4569        [
4570            DetectBox {
4571                bbox: BoundingBox {
4572                    xmin: 0.08515105,
4573                    ymin: 0.7131401,
4574                    xmax: 0.29802868,
4575                    ymax: 0.8195788,
4576                },
4577                score: 0.91537374,
4578                label: 23,
4579            },
4580            DetectBox {
4581                bbox: BoundingBox {
4582                    xmin: 0.59605736,
4583                    ymin: 0.25545314,
4584                    xmax: 0.93666154,
4585                    ymax: 0.72378385,
4586                },
4587                score: 0.91537374,
4588                label: 23,
4589            },
4590        ]
4591    }
4592
4593    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4594        [DetectBox {
4595            bbox: BoundingBox {
4596                xmin: 0.12549022,
4597                ymin: 0.12549022,
4598                xmax: 0.23529413,
4599                ymax: 0.23529413,
4600            },
4601            score: 0.98823535,
4602            label: 2,
4603        }]
4604    }
4605
4606    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4607        [DetectBox {
4608            bbox: BoundingBox {
4609                xmin: 0.1234,
4610                ymin: 0.1234,
4611                xmax: 0.2345,
4612                ymax: 0.2345,
4613            },
4614            score: 0.9876,
4615            label: 2,
4616        }]
4617    }
4618
4619    fn build_yolo_split_segdet_decoder(
4620        score_threshold: f32,
4621        iou_threshold: f32,
4622        quant_boxes: (f32, i32),
4623        quant_protos: (f32, i32),
4624    ) -> crate::Decoder {
4625        DecoderBuilder::default()
4626            .with_config_yolo_split_segdet(
4627                configs::Boxes {
4628                    decoder: configs::DecoderType::Ultralytics,
4629                    quantization: Some(quant_boxes.into()),
4630                    shape: vec![1, 4, 8400],
4631                    dshape: vec![
4632                        (DimName::Batch, 1),
4633                        (DimName::BoxCoords, 4),
4634                        (DimName::NumBoxes, 8400),
4635                    ],
4636                    normalized: Some(true),
4637                },
4638                configs::Scores {
4639                    decoder: configs::DecoderType::Ultralytics,
4640                    quantization: Some(quant_boxes.into()),
4641                    shape: vec![1, 80, 8400],
4642                    dshape: vec![
4643                        (DimName::Batch, 1),
4644                        (DimName::NumClasses, 80),
4645                        (DimName::NumBoxes, 8400),
4646                    ],
4647                },
4648                configs::MaskCoefficients {
4649                    decoder: configs::DecoderType::Ultralytics,
4650                    quantization: Some(quant_boxes.into()),
4651                    shape: vec![1, 32, 8400],
4652                    dshape: vec![
4653                        (DimName::Batch, 1),
4654                        (DimName::NumProtos, 32),
4655                        (DimName::NumBoxes, 8400),
4656                    ],
4657                },
4658                configs::Protos {
4659                    decoder: configs::DecoderType::Ultralytics,
4660                    quantization: Some(quant_protos.into()),
4661                    shape: vec![1, 160, 160, 32],
4662                    dshape: vec![
4663                        (DimName::Batch, 1),
4664                        (DimName::Height, 160),
4665                        (DimName::Width, 160),
4666                        (DimName::NumProtos, 32),
4667                    ],
4668                },
4669            )
4670            .with_score_threshold(score_threshold)
4671            .with_iou_threshold(iou_threshold)
4672            .build()
4673            .unwrap()
4674    }
4675
4676    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4677        let config_yaml = include_str!(concat!(
4678            env!("CARGO_MANIFEST_DIR"),
4679            "/../../testdata/yolov8_seg.yaml"
4680        ));
4681        DecoderBuilder::default()
4682            .with_config_yaml_str(config_yaml.to_string())
4683            .with_score_threshold(score_threshold)
4684            .with_iou_threshold(iou_threshold)
4685            .build()
4686            .unwrap()
4687    }
4688
4689    // ─── Real-data tracked test macro ───────────────────────────────
4690    //
4691    // Generates tests that load i8 binary test data from testdata/ and
4692    // exercise all (quant/float) × (combined/split) × (masks/proto)
4693    // decoder paths.
4694
4695    macro_rules! real_data_tracked_test {
4696        ($name:ident, quantized, $layout:ident, $output:ident) => {
4697            #[test]
4698            fn $name() {
4699                let is_split = matches!(stringify!($layout), "split");
4700                let is_proto = matches!(stringify!($output), "proto");
4701
4702                let score_threshold = 0.45;
4703                let iou_threshold = 0.45;
4704                let quant_boxes = (0.021287762_f32, 31_i32);
4705                let quant_protos = (0.02491162_f32, -117_i32);
4706
4707                let raw_boxes = include_bytes!(concat!(
4708                    env!("CARGO_MANIFEST_DIR"),
4709                    "/../../testdata/yolov8_boxes_116x8400.bin"
4710                ));
4711                let raw_boxes = unsafe {
4712                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4713                };
4714                let boxes_i8 =
4715                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4716
4717                let raw_protos = include_bytes!(concat!(
4718                    env!("CARGO_MANIFEST_DIR"),
4719                    "/../../testdata/yolov8_protos_160x160x32.bin"
4720                ));
4721                let raw_protos = unsafe {
4722                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4723                };
4724                let protos_i8 =
4725                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4726                        .unwrap();
4727
4728                // Pre-split (unused for combined, but harmless)
4729                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4730                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4731                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4732                let mut boxes_combined = boxes_i8;
4733
4734                let decoder = if is_split {
4735                    build_yolo_split_segdet_decoder(
4736                        score_threshold,
4737                        iou_threshold,
4738                        quant_boxes,
4739                        quant_protos,
4740                    )
4741                } else {
4742                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4743                };
4744
4745                let expected = real_data_expected_boxes();
4746                let mut tracker = ByteTrackBuilder::new()
4747                    .track_update(0.1)
4748                    .track_high_conf(0.7)
4749                    .build();
4750                let mut output_boxes = Vec::with_capacity(50);
4751                let mut output_tracks = Vec::with_capacity(50);
4752
4753                // Frame 1: decode
4754                if is_proto {
4755                    {
4756                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4757                            vec![
4758                                boxes_split.view().into(),
4759                                scores_split.view().into(),
4760                                mask_split.view().into(),
4761                                protos_i8.view().into(),
4762                            ]
4763                        } else {
4764                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4765                        };
4766                        decoder
4767                            .decode_tracked_quantized_proto(
4768                                &mut tracker,
4769                                0,
4770                                &inputs,
4771                                &mut output_boxes,
4772                                &mut output_tracks,
4773                            )
4774                            .unwrap();
4775                    }
4776                    assert_eq!(output_boxes.len(), 2);
4777                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4778                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4779
4780                    // Zero scores for frame 2
4781                    if is_split {
4782                        for score in scores_split.iter_mut() {
4783                            *score = i8::MIN;
4784                        }
4785                    } else {
4786                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4787                            *score = i8::MIN;
4788                        }
4789                    }
4790
4791                    let proto_result = {
4792                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4793                            vec![
4794                                boxes_split.view().into(),
4795                                scores_split.view().into(),
4796                                mask_split.view().into(),
4797                                protos_i8.view().into(),
4798                            ]
4799                        } else {
4800                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4801                        };
4802                        decoder
4803                            .decode_tracked_quantized_proto(
4804                                &mut tracker,
4805                                100_000_000 / 3,
4806                                &inputs,
4807                                &mut output_boxes,
4808                                &mut output_tracks,
4809                            )
4810                            .unwrap()
4811                    };
4812                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4813                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4814                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4815                } else {
4816                    let mut output_masks = Vec::with_capacity(50);
4817                    {
4818                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4819                            vec![
4820                                boxes_split.view().into(),
4821                                scores_split.view().into(),
4822                                mask_split.view().into(),
4823                                protos_i8.view().into(),
4824                            ]
4825                        } else {
4826                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4827                        };
4828                        decoder
4829                            .decode_tracked_quantized(
4830                                &mut tracker,
4831                                0,
4832                                &inputs,
4833                                &mut output_boxes,
4834                                &mut output_masks,
4835                                &mut output_tracks,
4836                            )
4837                            .unwrap();
4838                    }
4839                    assert_eq!(output_boxes.len(), 2);
4840                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4841                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4842
4843                    if is_split {
4844                        for score in scores_split.iter_mut() {
4845                            *score = i8::MIN;
4846                        }
4847                    } else {
4848                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4849                            *score = i8::MIN;
4850                        }
4851                    }
4852
4853                    {
4854                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4855                            vec![
4856                                boxes_split.view().into(),
4857                                scores_split.view().into(),
4858                                mask_split.view().into(),
4859                                protos_i8.view().into(),
4860                            ]
4861                        } else {
4862                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4863                        };
4864                        decoder
4865                            .decode_tracked_quantized(
4866                                &mut tracker,
4867                                100_000_000 / 3,
4868                                &inputs,
4869                                &mut output_boxes,
4870                                &mut output_masks,
4871                                &mut output_tracks,
4872                            )
4873                            .unwrap();
4874                    }
4875                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4876                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4877                    assert!(output_masks.is_empty());
4878                }
4879            }
4880        };
4881        ($name:ident, float, $layout:ident, $output:ident) => {
4882            #[test]
4883            fn $name() {
4884                let is_split = matches!(stringify!($layout), "split");
4885                let is_proto = matches!(stringify!($output), "proto");
4886
4887                let score_threshold = 0.45;
4888                let iou_threshold = 0.45;
4889                let quant_boxes = (0.021287762_f32, 31_i32);
4890                let quant_protos = (0.02491162_f32, -117_i32);
4891
4892                let raw_boxes = include_bytes!(concat!(
4893                    env!("CARGO_MANIFEST_DIR"),
4894                    "/../../testdata/yolov8_boxes_116x8400.bin"
4895                ));
4896                let raw_boxes = unsafe {
4897                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4898                };
4899                let boxes_i8 =
4900                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4901                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4902
4903                let raw_protos = include_bytes!(concat!(
4904                    env!("CARGO_MANIFEST_DIR"),
4905                    "/../../testdata/yolov8_protos_160x160x32.bin"
4906                ));
4907                let raw_protos = unsafe {
4908                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4909                };
4910                let protos_i8 =
4911                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4912                        .unwrap();
4913                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4914
4915                // Pre-split from dequantized data
4916                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4917                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4918                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4919                let mut boxes_combined = boxes_f32;
4920
4921                let decoder = if is_split {
4922                    build_yolo_split_segdet_decoder(
4923                        score_threshold,
4924                        iou_threshold,
4925                        quant_boxes,
4926                        quant_protos,
4927                    )
4928                } else {
4929                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4930                };
4931
4932                let expected = real_data_expected_boxes();
4933                let mut tracker = ByteTrackBuilder::new()
4934                    .track_update(0.1)
4935                    .track_high_conf(0.7)
4936                    .build();
4937                let mut output_boxes = Vec::with_capacity(50);
4938                let mut output_tracks = Vec::with_capacity(50);
4939
4940                if is_proto {
4941                    {
4942                        let inputs = if is_split {
4943                            vec![
4944                                boxes_split.view().into_dyn(),
4945                                scores_split.view().into_dyn(),
4946                                mask_split.view().into_dyn(),
4947                                protos_f32.view().into_dyn(),
4948                            ]
4949                        } else {
4950                            vec![
4951                                boxes_combined.view().into_dyn(),
4952                                protos_f32.view().into_dyn(),
4953                            ]
4954                        };
4955                        decoder
4956                            .decode_tracked_float_proto(
4957                                &mut tracker,
4958                                0,
4959                                &inputs,
4960                                &mut output_boxes,
4961                                &mut output_tracks,
4962                            )
4963                            .unwrap();
4964                    }
4965                    assert_eq!(output_boxes.len(), 2);
4966                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4967                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4968
4969                    if is_split {
4970                        for score in scores_split.iter_mut() {
4971                            *score = 0.0;
4972                        }
4973                    } else {
4974                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4975                            *score = 0.0;
4976                        }
4977                    }
4978
4979                    let proto_result = {
4980                        let inputs = if is_split {
4981                            vec![
4982                                boxes_split.view().into_dyn(),
4983                                scores_split.view().into_dyn(),
4984                                mask_split.view().into_dyn(),
4985                                protos_f32.view().into_dyn(),
4986                            ]
4987                        } else {
4988                            vec![
4989                                boxes_combined.view().into_dyn(),
4990                                protos_f32.view().into_dyn(),
4991                            ]
4992                        };
4993                        decoder
4994                            .decode_tracked_float_proto(
4995                                &mut tracker,
4996                                100_000_000 / 3,
4997                                &inputs,
4998                                &mut output_boxes,
4999                                &mut output_tracks,
5000                            )
5001                            .unwrap()
5002                    };
5003                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5004                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5005                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5006                } else {
5007                    let mut output_masks = Vec::with_capacity(50);
5008                    {
5009                        let inputs = if is_split {
5010                            vec![
5011                                boxes_split.view().into_dyn(),
5012                                scores_split.view().into_dyn(),
5013                                mask_split.view().into_dyn(),
5014                                protos_f32.view().into_dyn(),
5015                            ]
5016                        } else {
5017                            vec![
5018                                boxes_combined.view().into_dyn(),
5019                                protos_f32.view().into_dyn(),
5020                            ]
5021                        };
5022                        decoder
5023                            .decode_tracked_float(
5024                                &mut tracker,
5025                                0,
5026                                &inputs,
5027                                &mut output_boxes,
5028                                &mut output_masks,
5029                                &mut output_tracks,
5030                            )
5031                            .unwrap();
5032                    }
5033                    assert_eq!(output_boxes.len(), 2);
5034                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5035                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
5036
5037                    if is_split {
5038                        for score in scores_split.iter_mut() {
5039                            *score = 0.0;
5040                        }
5041                    } else {
5042                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
5043                            *score = 0.0;
5044                        }
5045                    }
5046
5047                    {
5048                        let inputs = if is_split {
5049                            vec![
5050                                boxes_split.view().into_dyn(),
5051                                scores_split.view().into_dyn(),
5052                                mask_split.view().into_dyn(),
5053                                protos_f32.view().into_dyn(),
5054                            ]
5055                        } else {
5056                            vec![
5057                                boxes_combined.view().into_dyn(),
5058                                protos_f32.view().into_dyn(),
5059                            ]
5060                        };
5061                        decoder
5062                            .decode_tracked_float(
5063                                &mut tracker,
5064                                100_000_000 / 3,
5065                                &inputs,
5066                                &mut output_boxes,
5067                                &mut output_masks,
5068                                &mut output_tracks,
5069                            )
5070                            .unwrap();
5071                    }
5072                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5073                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5074                    assert!(output_masks.is_empty());
5075                }
5076            }
5077        };
5078    }
5079
5080    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
5081    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
5082    real_data_tracked_test!(
5083        test_decoder_tracked_segdet_proto,
5084        quantized,
5085        combined,
5086        proto
5087    );
5088    real_data_tracked_test!(
5089        test_decoder_tracked_segdet_proto_float,
5090        float,
5091        combined,
5092        proto
5093    );
5094    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
5095    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
5096    real_data_tracked_test!(
5097        test_decoder_tracked_segdet_split_proto,
5098        quantized,
5099        split,
5100        proto
5101    );
5102    real_data_tracked_test!(
5103        test_decoder_tracked_segdet_split_proto_float,
5104        float,
5105        split,
5106        proto
5107    );
5108
5109    // ─── End-to-end tracked test macro ──────────────────────────────
5110    //
5111    // Generates tests with synthetic data to exercise all tracked
5112    // decode paths without needing real model output files.
5113
    /// YAML decoder configuration for the end-to-end tracked tests using the
    /// combined layout (`decoder_version: yolo26`).
    ///
    /// Declares two outputs:
    /// - `detection`: shape `[1, 10, 38]` (10 boxes × 38 features). The e2e tests
    ///   build these 38 features by concatenating 4 box coordinates, 1 score,
    ///   1 class and 32 mask coefficients. The quantization `[2/255, 0]` matches the
    ///   tests' `detect_quant`, and boxes are declared `normalized`.
    /// - `protos`: shape `[1, 160, 160, 32]`, with quantization `[1/255, 128]`.
    ///
    /// NOTE(review): the e2e tests quantize protos with zero point 0
    /// (`protos_quant = (1.0 / 255.0, 0.0)`), but this config declares 128. This is
    /// harmless only if no test asserts mask contents — confirm.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5136
    /// YAML decoder configuration for the end-to-end tracked tests using the
    /// split layout (`decoder_version: yolo26`).
    ///
    /// Declares five outputs, each over 10 boxes:
    /// - `boxes` `[1, 10, 4]`, declared `normalized`;
    /// - `scores` `[1, 10, 1]`;
    /// - `classes` `[1, 10, 1]`;
    /// - `mask_coefficients` `[1, 10, 32]`;
    /// - `protos` `[1, 160, 160, 32]`.
    ///
    /// All non-proto outputs use quantization `[2/255, 0]`, matching the tests'
    /// `detect_quant`. Protos use `[1/255, 128]`.
    ///
    /// NOTE(review): the e2e tests quantize protos with zero point 0
    /// (`protos_quant = (1.0 / 255.0, 0.0)`), but this config declares 128. This is
    /// harmless only if no test asserts mask contents — confirm.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5183
5184    macro_rules! e2e_tracked_test {
5185        ($name:ident, quantized, $layout:ident, $output:ident) => {
5186            #[test]
5187            fn $name() {
5188                let is_split = matches!(stringify!($layout), "split");
5189                let is_proto = matches!(stringify!($output), "proto");
5190
5191                let score_threshold = 0.45;
5192                let iou_threshold = 0.45;
5193
5194                let mut boxes = Array2::zeros((10, 4));
5195                let mut scores = Array2::zeros((10, 1));
5196                let mut classes = Array2::zeros((10, 1));
5197                let mask = Array2::zeros((10, 32));
5198                let protos = Array3::<f64>::zeros((160, 160, 32));
5199                let protos = protos.insert_axis(Axis(0));
5200                let protos_quant = (1.0 / 255.0, 0.0);
5201                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5202
5203                boxes
5204                    .slice_mut(s![0, ..])
5205                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5206                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5207                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5208
5209                let detect_quant = (2.0 / 255.0, 0.0);
5210
5211                let decoder = if is_split {
5212                    DecoderBuilder::default()
5213                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5214                        .with_score_threshold(score_threshold)
5215                        .with_iou_threshold(iou_threshold)
5216                        .build()
5217                        .unwrap()
5218                } else {
5219                    DecoderBuilder::default()
5220                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5221                        .with_score_threshold(score_threshold)
5222                        .with_iou_threshold(iou_threshold)
5223                        .build()
5224                        .unwrap()
5225                };
5226
5227                let expected = e2e_expected_boxes_quant();
5228                let mut tracker = ByteTrackBuilder::new()
5229                    .track_update(0.1)
5230                    .track_high_conf(0.7)
5231                    .build();
5232                let mut output_boxes = Vec::with_capacity(50);
5233                let mut output_tracks = Vec::with_capacity(50);
5234
5235                if is_split {
5236                    let boxes = boxes.insert_axis(Axis(0));
5237                    let scores = scores.insert_axis(Axis(0));
5238                    let classes = classes.insert_axis(Axis(0));
5239                    let mask = mask.insert_axis(Axis(0));
5240
5241                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5242                    let mut scores: Array3<u8> =
5243                        quantize_ndarray(scores.view(), detect_quant.into());
5244                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5245                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5246
5247                    if is_proto {
5248                        {
5249                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5250                                boxes.view().into(),
5251                                scores.view().into(),
5252                                classes.view().into(),
5253                                mask.view().into(),
5254                                protos.view().into(),
5255                            ];
5256                            decoder
5257                                .decode_tracked_quantized_proto(
5258                                    &mut tracker,
5259                                    0,
5260                                    &inputs,
5261                                    &mut output_boxes,
5262                                    &mut output_tracks,
5263                                )
5264                                .unwrap();
5265                        }
5266                        assert_eq!(output_boxes.len(), 1);
5267                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5268
5269                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5270                            *score = u8::MIN;
5271                        }
5272                        let proto_result = {
5273                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5274                                boxes.view().into(),
5275                                scores.view().into(),
5276                                classes.view().into(),
5277                                mask.view().into(),
5278                                protos.view().into(),
5279                            ];
5280                            decoder
5281                                .decode_tracked_quantized_proto(
5282                                    &mut tracker,
5283                                    100_000_000 / 3,
5284                                    &inputs,
5285                                    &mut output_boxes,
5286                                    &mut output_tracks,
5287                                )
5288                                .unwrap()
5289                        };
5290                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5291                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5292                    } else {
5293                        let mut output_masks = Vec::with_capacity(50);
5294                        {
5295                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5296                                boxes.view().into(),
5297                                scores.view().into(),
5298                                classes.view().into(),
5299                                mask.view().into(),
5300                                protos.view().into(),
5301                            ];
5302                            decoder
5303                                .decode_tracked_quantized(
5304                                    &mut tracker,
5305                                    0,
5306                                    &inputs,
5307                                    &mut output_boxes,
5308                                    &mut output_masks,
5309                                    &mut output_tracks,
5310                                )
5311                                .unwrap();
5312                        }
5313                        assert_eq!(output_boxes.len(), 1);
5314                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5315
5316                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5317                            *score = u8::MIN;
5318                        }
5319                        {
5320                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5321                                boxes.view().into(),
5322                                scores.view().into(),
5323                                classes.view().into(),
5324                                mask.view().into(),
5325                                protos.view().into(),
5326                            ];
5327                            decoder
5328                                .decode_tracked_quantized(
5329                                    &mut tracker,
5330                                    100_000_000 / 3,
5331                                    &inputs,
5332                                    &mut output_boxes,
5333                                    &mut output_masks,
5334                                    &mut output_tracks,
5335                                )
5336                                .unwrap();
5337                        }
5338                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5339                        assert!(output_masks.is_empty());
5340                    }
5341                } else {
5342                    // Combined layout
5343                    let detect = ndarray::concatenate![
5344                        Axis(1),
5345                        boxes.view(),
5346                        scores.view(),
5347                        classes.view(),
5348                        mask.view()
5349                    ];
5350                    let detect = detect.insert_axis(Axis(0));
5351                    assert_eq!(detect.shape(), &[1, 10, 38]);
5352                    let mut detect: Array3<u8> =
5353                        quantize_ndarray(detect.view(), detect_quant.into());
5354
5355                    if is_proto {
5356                        {
5357                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5358                                vec![detect.view().into(), protos.view().into()];
5359                            decoder
5360                                .decode_tracked_quantized_proto(
5361                                    &mut tracker,
5362                                    0,
5363                                    &inputs,
5364                                    &mut output_boxes,
5365                                    &mut output_tracks,
5366                                )
5367                                .unwrap();
5368                        }
5369                        assert_eq!(output_boxes.len(), 1);
5370                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5371
5372                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5373                            *score = u8::MIN;
5374                        }
5375                        let proto_result = {
5376                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5377                                vec![detect.view().into(), protos.view().into()];
5378                            decoder
5379                                .decode_tracked_quantized_proto(
5380                                    &mut tracker,
5381                                    100_000_000 / 3,
5382                                    &inputs,
5383                                    &mut output_boxes,
5384                                    &mut output_tracks,
5385                                )
5386                                .unwrap()
5387                        };
5388                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5389                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5390                    } else {
5391                        let mut output_masks = Vec::with_capacity(50);
5392                        {
5393                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5394                                vec![detect.view().into(), protos.view().into()];
5395                            decoder
5396                                .decode_tracked_quantized(
5397                                    &mut tracker,
5398                                    0,
5399                                    &inputs,
5400                                    &mut output_boxes,
5401                                    &mut output_masks,
5402                                    &mut output_tracks,
5403                                )
5404                                .unwrap();
5405                        }
5406                        assert_eq!(output_boxes.len(), 1);
5407                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5408
5409                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5410                            *score = u8::MIN;
5411                        }
5412                        {
5413                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5414                                vec![detect.view().into(), protos.view().into()];
5415                            decoder
5416                                .decode_tracked_quantized(
5417                                    &mut tracker,
5418                                    100_000_000 / 3,
5419                                    &inputs,
5420                                    &mut output_boxes,
5421                                    &mut output_masks,
5422                                    &mut output_tracks,
5423                                )
5424                                .unwrap();
5425                        }
5426                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5427                        assert!(output_masks.is_empty());
5428                    }
5429                }
5430            }
5431        };
5432        ($name:ident, float, $layout:ident, $output:ident) => {
5433            #[test]
5434            fn $name() {
5435                let is_split = matches!(stringify!($layout), "split");
5436                let is_proto = matches!(stringify!($output), "proto");
5437
5438                let score_threshold = 0.45;
5439                let iou_threshold = 0.45;
5440
5441                let mut boxes = Array2::zeros((10, 4));
5442                let mut scores = Array2::zeros((10, 1));
5443                let mut classes = Array2::zeros((10, 1));
5444                let mask: Array2<f64> = Array2::zeros((10, 32));
5445                let protos = Array3::<f64>::zeros((160, 160, 32));
5446                let protos = protos.insert_axis(Axis(0));
5447
5448                boxes
5449                    .slice_mut(s![0, ..])
5450                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5451                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5452                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5453
5454                let decoder = if is_split {
5455                    DecoderBuilder::default()
5456                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5457                        .with_score_threshold(score_threshold)
5458                        .with_iou_threshold(iou_threshold)
5459                        .build()
5460                        .unwrap()
5461                } else {
5462                    DecoderBuilder::default()
5463                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5464                        .with_score_threshold(score_threshold)
5465                        .with_iou_threshold(iou_threshold)
5466                        .build()
5467                        .unwrap()
5468                };
5469
5470                let expected = e2e_expected_boxes_float();
5471                let mut tracker = ByteTrackBuilder::new()
5472                    .track_update(0.1)
5473                    .track_high_conf(0.7)
5474                    .build();
5475                let mut output_boxes = Vec::with_capacity(50);
5476                let mut output_tracks = Vec::with_capacity(50);
5477
5478                if is_split {
5479                    let boxes = boxes.insert_axis(Axis(0));
5480                    let mut scores = scores.insert_axis(Axis(0));
5481                    let classes = classes.insert_axis(Axis(0));
5482                    let mask = mask.insert_axis(Axis(0));
5483
5484                    if is_proto {
5485                        {
5486                            let inputs = vec![
5487                                boxes.view().into_dyn(),
5488                                scores.view().into_dyn(),
5489                                classes.view().into_dyn(),
5490                                mask.view().into_dyn(),
5491                                protos.view().into_dyn(),
5492                            ];
5493                            decoder
5494                                .decode_tracked_float_proto(
5495                                    &mut tracker,
5496                                    0,
5497                                    &inputs,
5498                                    &mut output_boxes,
5499                                    &mut output_tracks,
5500                                )
5501                                .unwrap();
5502                        }
5503                        assert_eq!(output_boxes.len(), 1);
5504                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5505
5506                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5507                            *score = 0.0;
5508                        }
5509                        let proto_result = {
5510                            let inputs = vec![
5511                                boxes.view().into_dyn(),
5512                                scores.view().into_dyn(),
5513                                classes.view().into_dyn(),
5514                                mask.view().into_dyn(),
5515                                protos.view().into_dyn(),
5516                            ];
5517                            decoder
5518                                .decode_tracked_float_proto(
5519                                    &mut tracker,
5520                                    100_000_000 / 3,
5521                                    &inputs,
5522                                    &mut output_boxes,
5523                                    &mut output_tracks,
5524                                )
5525                                .unwrap()
5526                        };
5527                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5528                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5529                    } else {
5530                        let mut output_masks = Vec::with_capacity(50);
5531                        {
5532                            let inputs = vec![
5533                                boxes.view().into_dyn(),
5534                                scores.view().into_dyn(),
5535                                classes.view().into_dyn(),
5536                                mask.view().into_dyn(),
5537                                protos.view().into_dyn(),
5538                            ];
5539                            decoder
5540                                .decode_tracked_float(
5541                                    &mut tracker,
5542                                    0,
5543                                    &inputs,
5544                                    &mut output_boxes,
5545                                    &mut output_masks,
5546                                    &mut output_tracks,
5547                                )
5548                                .unwrap();
5549                        }
5550                        assert_eq!(output_boxes.len(), 1);
5551                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5552
5553                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5554                            *score = 0.0;
5555                        }
5556                        {
5557                            let inputs = vec![
5558                                boxes.view().into_dyn(),
5559                                scores.view().into_dyn(),
5560                                classes.view().into_dyn(),
5561                                mask.view().into_dyn(),
5562                                protos.view().into_dyn(),
5563                            ];
5564                            decoder
5565                                .decode_tracked_float(
5566                                    &mut tracker,
5567                                    100_000_000 / 3,
5568                                    &inputs,
5569                                    &mut output_boxes,
5570                                    &mut output_masks,
5571                                    &mut output_tracks,
5572                                )
5573                                .unwrap();
5574                        }
5575                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5576                        assert!(output_masks.is_empty());
5577                    }
5578                } else {
5579                    // Combined layout
5580                    let detect = ndarray::concatenate![
5581                        Axis(1),
5582                        boxes.view(),
5583                        scores.view(),
5584                        classes.view(),
5585                        mask.view()
5586                    ];
5587                    let mut detect = detect.insert_axis(Axis(0));
5588                    assert_eq!(detect.shape(), &[1, 10, 38]);
5589
5590                    if is_proto {
5591                        {
5592                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5593                            decoder
5594                                .decode_tracked_float_proto(
5595                                    &mut tracker,
5596                                    0,
5597                                    &inputs,
5598                                    &mut output_boxes,
5599                                    &mut output_tracks,
5600                                )
5601                                .unwrap();
5602                        }
5603                        assert_eq!(output_boxes.len(), 1);
5604                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5605
5606                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5607                            *score = 0.0;
5608                        }
5609                        let proto_result = {
5610                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5611                            decoder
5612                                .decode_tracked_float_proto(
5613                                    &mut tracker,
5614                                    100_000_000 / 3,
5615                                    &inputs,
5616                                    &mut output_boxes,
5617                                    &mut output_tracks,
5618                                )
5619                                .unwrap()
5620                        };
5621                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5622                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5623                    } else {
5624                        let mut output_masks = Vec::with_capacity(50);
5625                        {
5626                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5627                            decoder
5628                                .decode_tracked_float(
5629                                    &mut tracker,
5630                                    0,
5631                                    &inputs,
5632                                    &mut output_boxes,
5633                                    &mut output_masks,
5634                                    &mut output_tracks,
5635                                )
5636                                .unwrap();
5637                        }
5638                        assert_eq!(output_boxes.len(), 1);
5639                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5640
5641                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5642                            *score = 0.0;
5643                        }
5644                        {
5645                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5646                            decoder
5647                                .decode_tracked_float(
5648                                    &mut tracker,
5649                                    100_000_000 / 3,
5650                                    &inputs,
5651                                    &mut output_boxes,
5652                                    &mut output_masks,
5653                                    &mut output_tracks,
5654                                )
5655                                .unwrap();
5656                        }
5657                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5658                        assert!(output_masks.is_empty());
5659                    }
5660                }
5661            }
5662        };
5663    }
5664
5665    e2e_tracked_test!(
5666        test_decoder_tracked_end_to_end_segdet,
5667        quantized,
5668        combined,
5669        masks
5670    );
5671    e2e_tracked_test!(
5672        test_decoder_tracked_end_to_end_segdet_float,
5673        float,
5674        combined,
5675        masks
5676    );
5677    e2e_tracked_test!(
5678        test_decoder_tracked_end_to_end_segdet_proto,
5679        quantized,
5680        combined,
5681        proto
5682    );
5683    e2e_tracked_test!(
5684        test_decoder_tracked_end_to_end_segdet_proto_float,
5685        float,
5686        combined,
5687        proto
5688    );
5689    e2e_tracked_test!(
5690        test_decoder_tracked_end_to_end_segdet_split,
5691        quantized,
5692        split,
5693        masks
5694    );
5695    e2e_tracked_test!(
5696        test_decoder_tracked_end_to_end_segdet_split_float,
5697        float,
5698        split,
5699        masks
5700    );
5701    e2e_tracked_test!(
5702        test_decoder_tracked_end_to_end_segdet_split_proto,
5703        quantized,
5704        split,
5705        proto
5706    );
5707    e2e_tracked_test!(
5708        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5709        float,
5710        split,
5711        proto
5712    );
5713
5714    // ─── End-to-end tracked TensorDyn test macro ────────────────────
5715    //
5716    // Same as e2e_tracked_test but wraps data in TensorDyn and exercises
5717    // the public decode_tracked / decode_proto_tracked API.
5718
5719    macro_rules! e2e_tracked_tensor_test {
5720        ($name:ident, quantized, $layout:ident, $output:ident) => {
5721            #[test]
5722            fn $name() {
5723                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5724
5725                let is_split = matches!(stringify!($layout), "split");
5726                let is_proto = matches!(stringify!($output), "proto");
5727
5728                let score_threshold = 0.45;
5729                let iou_threshold = 0.45;
5730
5731                let mut boxes = Array2::zeros((10, 4));
5732                let mut scores = Array2::zeros((10, 1));
5733                let mut classes = Array2::zeros((10, 1));
5734                let mask = Array2::zeros((10, 32));
5735                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5736                let protos_f64 = protos_f64.insert_axis(Axis(0));
5737                let protos_quant = (1.0 / 255.0, 0.0);
5738                let protos_u8: Array4<u8> =
5739                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5740
5741                boxes
5742                    .slice_mut(s![0, ..])
5743                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5744                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5745                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5746
5747                let detect_quant = (2.0 / 255.0, 0.0);
5748
5749                let decoder = if is_split {
5750                    DecoderBuilder::default()
5751                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5752                        .with_score_threshold(score_threshold)
5753                        .with_iou_threshold(iou_threshold)
5754                        .build()
5755                        .unwrap()
5756                } else {
5757                    DecoderBuilder::default()
5758                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5759                        .with_score_threshold(score_threshold)
5760                        .with_iou_threshold(iou_threshold)
5761                        .build()
5762                        .unwrap()
5763                };
5764
5765                // Helper to wrap a u8 slice into a TensorDyn
5766                let make_u8_tensor =
5767                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5768                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5769                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5770                        t.into()
5771                    };
5772
5773                let expected = e2e_expected_boxes_quant();
5774                let mut tracker = ByteTrackBuilder::new()
5775                    .track_update(0.1)
5776                    .track_high_conf(0.7)
5777                    .build();
5778                let mut output_boxes = Vec::with_capacity(50);
5779                let mut output_tracks = Vec::with_capacity(50);
5780
5781                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5782
5783                if is_split {
5784                    let boxes = boxes.insert_axis(Axis(0));
5785                    let scores = scores.insert_axis(Axis(0));
5786                    let classes = classes.insert_axis(Axis(0));
5787                    let mask = mask.insert_axis(Axis(0));
5788
5789                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5790                    let mut scores_q: Array3<u8> =
5791                        quantize_ndarray(scores.view(), detect_quant.into());
5792                    let classes_q: Array3<u8> =
5793                        quantize_ndarray(classes.view(), detect_quant.into());
5794                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5795
5796                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5797                    let classes_td =
5798                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5799                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5800
5801                    if is_proto {
5802                        let scores_td =
5803                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5804                        decoder
5805                            .decode_proto_tracked(
5806                                &mut tracker,
5807                                0,
5808                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5809                                &mut output_boxes,
5810                                &mut output_tracks,
5811                            )
5812                            .unwrap();
5813
5814                        assert_eq!(output_boxes.len(), 1);
5815                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5816
5817                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5818                            *score = u8::MIN;
5819                        }
5820                        let scores_td =
5821                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5822                        let proto_result = decoder
5823                            .decode_proto_tracked(
5824                                &mut tracker,
5825                                100_000_000 / 3,
5826                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5827                                &mut output_boxes,
5828                                &mut output_tracks,
5829                            )
5830                            .unwrap();
5831                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5832                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5833                    } else {
5834                        let scores_td =
5835                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5836                        let mut output_masks = Vec::with_capacity(50);
5837                        decoder
5838                            .decode_tracked(
5839                                &mut tracker,
5840                                0,
5841                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5842                                &mut output_boxes,
5843                                &mut output_masks,
5844                                &mut output_tracks,
5845                            )
5846                            .unwrap();
5847
5848                        assert_eq!(output_boxes.len(), 1);
5849                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5850
5851                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5852                            *score = u8::MIN;
5853                        }
5854                        let scores_td =
5855                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5856                        decoder
5857                            .decode_tracked(
5858                                &mut tracker,
5859                                100_000_000 / 3,
5860                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5861                                &mut output_boxes,
5862                                &mut output_masks,
5863                                &mut output_tracks,
5864                            )
5865                            .unwrap();
5866                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5867                        assert!(output_masks.is_empty());
5868                    }
5869                } else {
5870                    // Combined layout
5871                    let detect = ndarray::concatenate![
5872                        Axis(1),
5873                        boxes.view(),
5874                        scores.view(),
5875                        classes.view(),
5876                        mask.view()
5877                    ];
5878                    let detect = detect.insert_axis(Axis(0));
5879                    assert_eq!(detect.shape(), &[1, 10, 38]);
5880                    // Ensure contiguous layout after concatenation for as_slice()
5881                    let detect =
5882                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5883                            .unwrap();
5884                    let mut detect_q: Array3<u8> =
5885                        quantize_ndarray(detect.view(), detect_quant.into());
5886
5887                    if is_proto {
5888                        let detect_td =
5889                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5890                        decoder
5891                            .decode_proto_tracked(
5892                                &mut tracker,
5893                                0,
5894                                &[&detect_td, &protos_td],
5895                                &mut output_boxes,
5896                                &mut output_tracks,
5897                            )
5898                            .unwrap();
5899
5900                        assert_eq!(output_boxes.len(), 1);
5901                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5902
5903                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5904                            *score = u8::MIN;
5905                        }
5906                        let detect_td =
5907                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5908                        let proto_result = decoder
5909                            .decode_proto_tracked(
5910                                &mut tracker,
5911                                100_000_000 / 3,
5912                                &[&detect_td, &protos_td],
5913                                &mut output_boxes,
5914                                &mut output_tracks,
5915                            )
5916                            .unwrap();
5917                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5918                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5919                    } else {
5920                        let detect_td =
5921                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5922                        let mut output_masks = Vec::with_capacity(50);
5923                        decoder
5924                            .decode_tracked(
5925                                &mut tracker,
5926                                0,
5927                                &[&detect_td, &protos_td],
5928                                &mut output_boxes,
5929                                &mut output_masks,
5930                                &mut output_tracks,
5931                            )
5932                            .unwrap();
5933
5934                        assert_eq!(output_boxes.len(), 1);
5935                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5936
5937                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5938                            *score = u8::MIN;
5939                        }
5940                        let detect_td =
5941                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5942                        decoder
5943                            .decode_tracked(
5944                                &mut tracker,
5945                                100_000_000 / 3,
5946                                &[&detect_td, &protos_td],
5947                                &mut output_boxes,
5948                                &mut output_masks,
5949                                &mut output_tracks,
5950                            )
5951                            .unwrap();
5952                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5953                        assert!(output_masks.is_empty());
5954                    }
5955                }
5956            }
5957        };
5958        ($name:ident, float, $layout:ident, $output:ident) => {
5959            #[test]
5960            fn $name() {
5961                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5962
5963                let is_split = matches!(stringify!($layout), "split");
5964                let is_proto = matches!(stringify!($output), "proto");
5965
5966                let score_threshold = 0.45;
5967                let iou_threshold = 0.45;
5968
5969                let mut boxes = Array2::zeros((10, 4));
5970                let mut scores = Array2::zeros((10, 1));
5971                let mut classes = Array2::zeros((10, 1));
5972                let mask: Array2<f64> = Array2::zeros((10, 32));
5973                let protos = Array3::<f64>::zeros((160, 160, 32));
5974                let protos = protos.insert_axis(Axis(0));
5975
5976                boxes
5977                    .slice_mut(s![0, ..])
5978                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5979                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5980                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5981
5982                let decoder = if is_split {
5983                    DecoderBuilder::default()
5984                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5985                        .with_score_threshold(score_threshold)
5986                        .with_iou_threshold(iou_threshold)
5987                        .build()
5988                        .unwrap()
5989                } else {
5990                    DecoderBuilder::default()
5991                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5992                        .with_score_threshold(score_threshold)
5993                        .with_iou_threshold(iou_threshold)
5994                        .build()
5995                        .unwrap()
5996                };
5997
5998                // Helper to wrap an f64 slice into a TensorDyn
5999                let make_f64_tensor =
6000                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
6001                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
6002                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
6003                        t.into()
6004                    };
6005
6006                let expected = e2e_expected_boxes_float();
6007                let mut tracker = ByteTrackBuilder::new()
6008                    .track_update(0.1)
6009                    .track_high_conf(0.7)
6010                    .build();
6011                let mut output_boxes = Vec::with_capacity(50);
6012                let mut output_tracks = Vec::with_capacity(50);
6013
6014                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
6015
6016                if is_split {
6017                    let boxes = boxes.insert_axis(Axis(0));
6018                    let mut scores = scores.insert_axis(Axis(0));
6019                    let classes = classes.insert_axis(Axis(0));
6020                    let mask = mask.insert_axis(Axis(0));
6021
6022                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
6023                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
6024                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
6025
6026                    if is_proto {
6027                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6028                        decoder
6029                            .decode_proto_tracked(
6030                                &mut tracker,
6031                                0,
6032                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6033                                &mut output_boxes,
6034                                &mut output_tracks,
6035                            )
6036                            .unwrap();
6037
6038                        assert_eq!(output_boxes.len(), 1);
6039                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6040
6041                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6042                            *score = 0.0;
6043                        }
6044                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6045                        let proto_result = decoder
6046                            .decode_proto_tracked(
6047                                &mut tracker,
6048                                100_000_000 / 3,
6049                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6050                                &mut output_boxes,
6051                                &mut output_tracks,
6052                            )
6053                            .unwrap();
6054                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6055                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6056                    } else {
6057                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6058                        let mut output_masks = Vec::with_capacity(50);
6059                        decoder
6060                            .decode_tracked(
6061                                &mut tracker,
6062                                0,
6063                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6064                                &mut output_boxes,
6065                                &mut output_masks,
6066                                &mut output_tracks,
6067                            )
6068                            .unwrap();
6069
6070                        assert_eq!(output_boxes.len(), 1);
6071                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6072
6073                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6074                            *score = 0.0;
6075                        }
6076                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6077                        decoder
6078                            .decode_tracked(
6079                                &mut tracker,
6080                                100_000_000 / 3,
6081                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6082                                &mut output_boxes,
6083                                &mut output_masks,
6084                                &mut output_tracks,
6085                            )
6086                            .unwrap();
6087                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6088                        assert!(output_masks.is_empty());
6089                    }
6090                } else {
6091                    // Combined layout
6092                    let detect = ndarray::concatenate![
6093                        Axis(1),
6094                        boxes.view(),
6095                        scores.view(),
6096                        classes.view(),
6097                        mask.view()
6098                    ];
6099                    let detect = detect.insert_axis(Axis(0));
6100                    assert_eq!(detect.shape(), &[1, 10, 38]);
6101                    // Ensure contiguous layout after concatenation for as_slice()
6102                    let mut detect =
6103                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
6104                            .unwrap();
6105
6106                    if is_proto {
6107                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6108                        decoder
6109                            .decode_proto_tracked(
6110                                &mut tracker,
6111                                0,
6112                                &[&detect_td, &protos_td],
6113                                &mut output_boxes,
6114                                &mut output_tracks,
6115                            )
6116                            .unwrap();
6117
6118                        assert_eq!(output_boxes.len(), 1);
6119                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6120
6121                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6122                            *score = 0.0;
6123                        }
6124                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6125                        let proto_result = decoder
6126                            .decode_proto_tracked(
6127                                &mut tracker,
6128                                100_000_000 / 3,
6129                                &[&detect_td, &protos_td],
6130                                &mut output_boxes,
6131                                &mut output_tracks,
6132                            )
6133                            .unwrap();
6134                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6135                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6136                    } else {
6137                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6138                        let mut output_masks = Vec::with_capacity(50);
6139                        decoder
6140                            .decode_tracked(
6141                                &mut tracker,
6142                                0,
6143                                &[&detect_td, &protos_td],
6144                                &mut output_boxes,
6145                                &mut output_masks,
6146                                &mut output_tracks,
6147                            )
6148                            .unwrap();
6149
6150                        assert_eq!(output_boxes.len(), 1);
6151                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6152
6153                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6154                            *score = 0.0;
6155                        }
6156                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6157                        decoder
6158                            .decode_tracked(
6159                                &mut tracker,
6160                                100_000_000 / 3,
6161                                &[&detect_td, &protos_td],
6162                                &mut output_boxes,
6163                                &mut output_masks,
6164                                &mut output_tracks,
6165                            )
6166                            .unwrap();
6167                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6168                        assert!(output_masks.is_empty());
6169                    }
6170                }
6171            }
6172        };
6173    }
6174
6175    e2e_tracked_tensor_test!(
6176        test_decoder_tracked_tensor_end_to_end_segdet,
6177        quantized,
6178        combined,
6179        masks
6180    );
6181    e2e_tracked_tensor_test!(
6182        test_decoder_tracked_tensor_end_to_end_segdet_float,
6183        float,
6184        combined,
6185        masks
6186    );
6187    e2e_tracked_tensor_test!(
6188        test_decoder_tracked_tensor_end_to_end_segdet_proto,
6189        quantized,
6190        combined,
6191        proto
6192    );
6193    e2e_tracked_tensor_test!(
6194        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6195        float,
6196        combined,
6197        proto
6198    );
6199    e2e_tracked_tensor_test!(
6200        test_decoder_tracked_tensor_end_to_end_segdet_split,
6201        quantized,
6202        split,
6203        masks
6204    );
6205    e2e_tracked_tensor_test!(
6206        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6207        float,
6208        split,
6209        masks
6210    );
6211    e2e_tracked_tensor_test!(
6212        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6213        quantized,
6214        split,
6215        proto
6216    );
6217    e2e_tracked_tensor_test!(
6218        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6219        float,
6220        split,
6221        proto
6222    );
6223
6224    #[test]
6225    fn test_decoder_tracked_linear_motion() {
6226        use crate::configs::{DecoderType, Nms};
6227        use crate::DecoderBuilder;
6228
6229        let score_threshold = 0.25;
6230        let iou_threshold = 0.1;
6231        let out = include_bytes!(concat!(
6232            env!("CARGO_MANIFEST_DIR"),
6233            "/../../testdata/yolov8s_80_classes.bin"
6234        ));
6235        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6236        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6237        let quant = (0.0040811873, -123).into();
6238
6239        let decoder = DecoderBuilder::default()
6240            .with_config_yolo_det(
6241                crate::configs::Detection {
6242                    decoder: DecoderType::Ultralytics,
6243                    shape: vec![1, 84, 8400],
6244                    anchors: None,
6245                    quantization: Some(quant),
6246                    dshape: vec![
6247                        (crate::configs::DimName::Batch, 1),
6248                        (crate::configs::DimName::NumFeatures, 84),
6249                        (crate::configs::DimName::NumBoxes, 8400),
6250                    ],
6251                    normalized: Some(true),
6252                },
6253                None,
6254            )
6255            .with_score_threshold(score_threshold)
6256            .with_iou_threshold(iou_threshold)
6257            .with_nms(Some(Nms::ClassAgnostic))
6258            .build()
6259            .unwrap();
6260
6261        let mut expected_boxes = [
6262            DetectBox {
6263                bbox: BoundingBox {
6264                    xmin: 0.5285137,
6265                    ymin: 0.05305544,
6266                    xmax: 0.87541467,
6267                    ymax: 0.9998909,
6268                },
6269                score: 0.5591227,
6270                label: 0,
6271            },
6272            DetectBox {
6273                bbox: BoundingBox {
6274                    xmin: 0.130598,
6275                    ymin: 0.43260583,
6276                    xmax: 0.35098213,
6277                    ymax: 0.9958097,
6278                },
6279                score: 0.33057618,
6280                label: 75,
6281            },
6282        ];
6283
6284        let mut tracker = ByteTrackBuilder::new()
6285            .track_update(0.1)
6286            .track_high_conf(0.3)
6287            .build();
6288
6289        let mut output_boxes = Vec::with_capacity(50);
6290        let mut output_masks = Vec::with_capacity(50);
6291        let mut output_tracks = Vec::with_capacity(50);
6292
6293        decoder
6294            .decode_tracked_quantized(
6295                &mut tracker,
6296                0,
6297                &[out.view().into()],
6298                &mut output_boxes,
6299                &mut output_masks,
6300                &mut output_tracks,
6301            )
6302            .unwrap();
6303
6304        assert_eq!(output_boxes.len(), 2);
6305        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6306        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6307
6308        for i in 1..=100 {
6309            let mut out = out.clone();
6310            // introduce linear movement into the XY coordinates
6311            let mut x_values = out.slice_mut(s![0, 0, ..]);
6312            for x in x_values.iter_mut() {
6313                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6314            }
6315
6316            decoder
6317                .decode_tracked_quantized(
6318                    &mut tracker,
6319                    100_000_000 * i / 3, // simulate 33.333ms between frames
6320                    &[out.view().into()],
6321                    &mut output_boxes,
6322                    &mut output_masks,
6323                    &mut output_tracks,
6324                )
6325                .unwrap();
6326
6327            assert_eq!(output_boxes.len(), 2);
6328        }
6329        let tracks = tracker.get_active_tracks();
6330        let predicted_boxes: Vec<_> = tracks
6331            .iter()
6332            .map(|track| {
6333                let mut l = track.last_box;
6334                l.bbox = track.info.tracked_location.into();
6335                l
6336            })
6337            .collect();
6338        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
6339        expected_boxes[0].bbox.xmax += 0.1;
6340        expected_boxes[1].bbox.xmin += 0.1;
6341        expected_boxes[1].bbox.xmax += 0.1;
6342
6343        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6344        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6345
6346        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6347        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6348        for score in scores_values.iter_mut() {
6349            *score = i8::MIN; // set all scores to minimum to simulate no detections
6350        }
6351        decoder
6352            .decode_tracked_quantized(
6353                &mut tracker,
6354                100_000_000 * 101 / 3,
6355                &[out.view().into()],
6356                &mut output_boxes,
6357                &mut output_masks,
6358                &mut output_tracks,
6359            )
6360            .unwrap();
6361        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
6362        expected_boxes[0].bbox.xmax += 0.001;
6363        expected_boxes[1].bbox.xmin += 0.001;
6364        expected_boxes[1].bbox.xmax += 0.001;
6365
6366        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6367        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6368    }
6369
6370    #[test]
6371    fn test_decoder_tracked_end_to_end_float() {
6372        let score_threshold = 0.45;
6373        let iou_threshold = 0.45;
6374
6375        let mut boxes = Array2::zeros((10, 4));
6376        let mut scores = Array2::zeros((10, 1));
6377        let mut classes = Array2::zeros((10, 1));
6378
6379        boxes
6380            .slice_mut(s![0, ..,])
6381            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6382        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6383        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6384
6385        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6386        let mut detect = detect.insert_axis(Axis(0));
6387        assert_eq!(detect.shape(), &[1, 10, 6]);
6388        let config = "
6389decoder_version: yolo26
6390outputs:
6391 - type: detection
6392   decoder: ultralytics
6393   quantization: [0.00784313725490196, 0]
6394   shape: [1, 10, 6]
6395   dshape:
6396    - [batch, 1]
6397    - [num_boxes, 10]
6398    - [num_features, 6]
6399   normalized: true
6400";
6401
6402        let decoder = DecoderBuilder::default()
6403            .with_config_yaml_str(config.to_string())
6404            .with_score_threshold(score_threshold)
6405            .with_iou_threshold(iou_threshold)
6406            .build()
6407            .unwrap();
6408
6409        let expected_boxes = [DetectBox {
6410            bbox: BoundingBox {
6411                xmin: 0.1234,
6412                ymin: 0.1234,
6413                xmax: 0.2345,
6414                ymax: 0.2345,
6415            },
6416            score: 0.9876,
6417            label: 2,
6418        }];
6419
6420        let mut tracker = ByteTrackBuilder::new()
6421            .track_update(0.1)
6422            .track_high_conf(0.7)
6423            .build();
6424
6425        let mut output_boxes = Vec::with_capacity(50);
6426        let mut output_masks = Vec::with_capacity(50);
6427        let mut output_tracks = Vec::with_capacity(50);
6428
6429        decoder
6430            .decode_tracked_float(
6431                &mut tracker,
6432                0,
6433                &[detect.view().into_dyn()],
6434                &mut output_boxes,
6435                &mut output_masks,
6436                &mut output_tracks,
6437            )
6438            .unwrap();
6439
6440        assert_eq!(output_boxes.len(), 1);
6441        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6442
6443        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6444
6445        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6446            *score = 0.0; // set all scores to minimum to simulate no detections
6447        }
6448
6449        decoder
6450            .decode_tracked_float(
6451                &mut tracker,
6452                100_000_000 / 3,
6453                &[detect.view().into_dyn()],
6454                &mut output_boxes,
6455                &mut output_masks,
6456                &mut output_tracks,
6457            )
6458            .unwrap();
6459        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6460    }
6461}