Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize_*` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
not to dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod per_scale;
77pub mod schema;
78pub mod yolo;
79
80mod decoder;
81pub use decoder::*;
82
83pub use configs::{DecoderVersion, Nms};
84pub use error::{DecoderError, DecoderResult};
85pub use per_scale::DecodeDtype;
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Affine quantization parameters for a tensor: a quantized value `q`
/// stands for the real value `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Real-valued step between adjacent quantized values.
    pub scale: f32,
    /// Quantized value that represents real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds quantization parameters from an explicit scale and zero point.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }

    /// Returns the no-op quantization (`scale = 1.0`, `zero_point = 0`).
    ///
    /// Used by the per-scale pipeline as a sentinel for float-to-float
    /// passthrough cells where no affine transform is required.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::identity();
    /// assert_eq!(quant.scale, 1.0);
    /// assert_eq!(quant.zero_point, 0);
    /// ```
    pub fn identity() -> Self {
        Self::new(1.0, 0)
    }
}
263
264impl From<QuantTuple> for Quantization {
265    /// Creates a new Quantization struct from a QuantTuple
266    /// # Examples
267    /// ```
268    /// # use edgefirst_decoder::Quantization;
269    /// # use edgefirst_decoder::configs::QuantTuple;
270    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
271    /// let quant = Quantization::from(quant_tuple);
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from(quant_tuple: QuantTuple) -> Quantization {
276        Quantization {
277            scale: quant_tuple.0,
278            zero_point: quant_tuple.1,
279        }
280    }
281}
282
283impl<S, Z> From<(S, Z)> for Quantization
284where
285    S: AsPrimitive<f32>,
286    Z: AsPrimitive<i32>,
287{
288    /// Creates a new Quantization struct from a tuple
289    /// # Examples
290    /// ```
291    /// # use edgefirst_decoder::Quantization;
292    /// let quant = Quantization::from((0.1_f64, -128_i64));
293    /// assert_eq!(quant.scale, 0.1);
294    /// assert_eq!(quant.zero_point, -128);
295    /// ```
296    fn from((scale, zp): (S, Z)) -> Quantization {
297        Self {
298            scale: scale.as_(),
299            zero_point: zp.as_(),
300        }
301    }
302}
303
304impl Default for Quantization {
305    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
306    /// # Examples
307    /// ```rust
308    /// # use edgefirst_decoder::Quantization;
309    /// let quant = Quantization::default();
310    /// assert_eq!(quant.scale, 1.0);
311    /// assert_eq!(quant.zero_point, 0);
312    /// ```
313    fn default() -> Self {
314        Self {
315            scale: 1.0,
316            zero_point: 0,
317        }
318    }
319}
320
/// A single detection: an f32 bounding box together with its f32 score and
/// class label.
///
/// Compare two detections with a tolerance via
/// [`DetectBox::equal_within_delta`].
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// Bounding box of the detection in XYXY (corner) form.
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
330
/// An axis-aligned bounding box in XYXY (corner) form with f32 coordinates.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its corner coordinates, stored exactly as given
    /// (no reordering; see [`BoundingBox::to_canonical`]).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (lo_x, hi_x) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (lo_y, hi_y) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(lo_x, lo_y, hi_x, hi_y)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens a `BoundingBox` into `[xmin, ymin, xmax, ymax]`.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a `BoundingBox` from `[xmin, ymin, xmax, ymax]`, without
    /// reordering the corners.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}
409
410impl DetectBox {
411    /// Returns true if one detect box is equal to another detect box, within
412    /// the given `eps`
413    ///
414    /// # Examples
415    /// ```
416    /// # use edgefirst_decoder::DetectBox;
417    /// let box1 = DetectBox {
418    ///     bbox: edgefirst_decoder::BoundingBox {
419    ///         xmin: 0.1,
420    ///         ymin: 0.2,
421    ///         xmax: 0.3,
422    ///         ymax: 0.4,
423    ///     },
424    ///     score: 0.5,
425    ///     label: 1,
426    /// };
427    /// let box2 = DetectBox {
428    ///     bbox: edgefirst_decoder::BoundingBox {
429    ///         xmin: 0.101,
430    ///         ymin: 0.199,
431    ///         xmax: 0.301,
432    ///         ymax: 0.399,
433    ///     },
434    ///     score: 0.510,
435    ///     label: 1,
436    /// };
437    /// assert!(box1.equal_within_delta(&box2, 0.011));
438    /// ```
439    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
440        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
441        self.label == rhs.label
442            && eq_delta(self.score, rhs.score)
443            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
444            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
445            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
446            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
447    }
448}
449
/// A segmentation result paired with the normalized bounding box that
/// describes the spatial extent of `segmentation`.
///
/// The `xmin`/`ymin`/`xmax`/`ymax` fields describe the **mask region** —
/// the proto-grid-aligned crop the [`segmentation`](Self::segmentation)
/// tensor was sliced from. They are quantized to the proto grid step
/// (`1/proto_height` and `1/proto_width`, typically `1/160`), so the
/// region's origin floors and its extent ceils relative to the companion
/// [`DetectBox`]'s `bbox`. The detection bbox itself stays un-snapped (see
/// EDGEAI-1304); use `Segmentation` bounds for rendering masks, and
/// `DetectBox.bbox` for IoU evaluation, drawing detection rectangles, etc.
/// The mask region always encloses the detection bbox.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// Left-most normalized coordinate of the mask region (proto-grid
    /// floor of the companion `DetectBox.bbox.xmin`).
    pub xmin: f32,
    /// Top-most normalized coordinate of the mask region (proto-grid
    /// floor of the companion `DetectBox.bbox.ymin`).
    pub ymin: f32,
    /// Right-most normalized coordinate of the mask region (proto-grid
    /// ceil of the companion `DetectBox.bbox.xmax`).
    pub xmax: f32,
    /// Bottom-most normalized coordinate of the mask region (proto-grid
    /// ceil of the companion `DetectBox.bbox.ymax`).
    pub ymax: f32,
    /// 3D segmentation array of shape `(H, W, C)`.
    ///
    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
    /// continuous sigmoid confidence values quantized to u8 (0 = background,
    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
    /// 0.5) or use smooth interpolation for anti-aliased edges.
    ///
    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
    /// class scores where the object class is the argmax index.
    pub segmentation: Array3<u8>,
}
487
/// Memory layout of the prototype tensor within [`ProtoData`].
///
/// Models may output protos in either channel-last (NHWC) or channel-first
/// (NCHW) layout. The mask materialisation kernels dispatch on
/// [`ProtoData::layout`] to avoid a costly per-frame transpose.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoLayout {
    /// Channel-last: tensor shape is `[H, W, K]`, contiguous along K.
    /// This is the traditional layout produced by the `extract_proto_data`
    /// path after transposing from NCHW model outputs.
    Nhwc,
    /// Channel-first: tensor shape is `[K, H, W]`, each channel plane is
    /// contiguous. Skipping the NCHW→NHWC transpose saves ~3 ms per frame
    /// on Cortex-A53/A55 targets (819 KB for 32×160×160 protos).
    Nchw,
}
504
/// Raw prototype data for fused decode+render pipelines.
///
/// Holds post-NMS intermediate state before mask materialization, allowing the
/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
/// shader) without materializing intermediate `Array3<u8>` masks.
///
/// The `mask_coefficients` and `protos` fields are carried as [`TensorDyn`]
/// so downstream consumers (Rust, C API, Python) get zero-copy typed access
/// through the HAL's shared tensor infrastructure; `layout` records how
/// `protos` is arranged in memory. Dtype policy:
///
/// | Source model | protos dtype | mask_coefficients dtype | protos.quantization |
/// |---|---|---|---|
/// | int8 quantized | [`TensorDyn::I8`] | [`TensorDyn::I8`] (raw + quantization) | `Some(q)` |
/// | f32 | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
/// | f16 (TensorRT fp16) | [`TensorDyn::F16`] | [`TensorDyn::F16`] | `None` |
/// | f64 (narrowed) | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
///
/// Quantization metadata lives on the proto tensor itself via
/// [`edgefirst_tensor::Tensor::quantization`] — float tensors cannot carry
/// quantization (compile-time gated on the `IntegerType` sealed trait).
///
/// `TensorDyn` is not `Clone`, so neither is `ProtoData`. Consumers that need
/// to share the proto buffer across threads should use `TensorDyn::clone_fd`
/// / `dmabuf_clone` to dup the backing fd.
#[derive(Debug)]
pub struct ProtoData {
    /// Per-detection mask coefficients, shape `[num_detections, num_protos]`.
    pub mask_coefficients: edgefirst_tensor::TensorDyn,
    /// Prototype tensor.
    ///
    /// - When `layout == ProtoLayout::Nhwc`: shape is `[proto_h, proto_w, num_protos]`.
    /// - When `layout == ProtoLayout::Nchw`: shape is `[num_protos, proto_h, proto_w]`.
    pub protos: edgefirst_tensor::TensorDyn,
    /// Physical memory layout of the `protos` tensor.
    pub layout: ProtoLayout,
}
541
542/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
543///
544///  # Examples
545/// ```
546/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
547/// let quant = Quantization::new(0.1, -128);
548/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
549/// let detect_quant = DetectBoxQuantized {
550///     bbox,
551///     score: 100_i8,
552///     label: 1,
553/// };
554/// let detect = dequant_detect_box(&detect_quant, quant);
555/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
556/// assert_eq!(detect.label, 1);
557/// assert_eq!(detect.bbox, bbox);
558/// ```
559pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
560    detect: &DetectBoxQuantized<SCORE>,
561    quant_scores: Quantization,
562) -> DetectBox {
563    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
564    DetectBox {
565        bbox: detect.bbox,
566        score: quant_scores.scale * detect.score.as_() + scaled_zp,
567        label: detect.label,
568    }
569}
/// A detection box with a f32 bbox and quantized score
///
/// Use [`dequant_detect_box`] to turn it into a [`DetectBox`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// Bounding box of the detection, already in f32; only the score is
    /// kept in its quantized integer form.
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more
    /// confidence.
    pub score: SCORE,
    /// label index for this detect
    pub label: usize,
}
584
585/// Dequantizes an ndarray from quantized values to f32 values using the given
586/// quantization parameters
587///
588/// # Examples
589/// ```
590/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
591/// let quant = Quantization::new(0.1, -128);
592/// let input: Vec<i8> = vec![0, 127, -128, 64];
593/// let input_array = ndarray::Array1::from(input);
594/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
595/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
596/// ```
597pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
598    input: ArrayView<T, D>,
599    quant: Quantization,
600) -> Array<F, D>
601where
602    i32: num_traits::AsPrimitive<F>,
603    f32: num_traits::AsPrimitive<F>,
604{
605    let zero_point = quant.zero_point.as_();
606    let scale = quant.scale.as_();
607    if zero_point != F::zero() {
608        let scaled_zero = -zero_point * scale;
609        input.mapv(|d| d.as_() * scale + scaled_zero)
610    } else {
611        input.mapv(|d| d.as_() * scale)
612    }
613}
614
615/// Dequantizes a slice from quantized values to float values using the given
616/// quantization parameters
617///
618/// # Examples
619/// ```
620/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
621/// let quant = Quantization::new(0.1, -128);
622/// let input: Vec<i8> = vec![0, 127, -128, 64];
623/// let mut output: Vec<f32> = vec![0.0; input.len()];
624/// dequantize_cpu(&input, quant, &mut output);
625/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
626/// ```
627pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
628    input: &[T],
629    quant: Quantization,
630    output: &mut [F],
631) where
632    f32: num_traits::AsPrimitive<F>,
633    i32: num_traits::AsPrimitive<F>,
634{
635    assert!(input.len() == output.len());
636    let zero_point = quant.zero_point.as_();
637    let scale = quant.scale.as_();
638    if zero_point != F::zero() {
639        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
640        input
641            .iter()
642            .zip(output)
643            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
644    } else {
645        input
646            .iter()
647            .zip(output)
648            .for_each(|(d, deq)| *deq = d.as_() * scale);
649    }
650}
651
652/// Dequantizes a slice from quantized values to float values using the given
653/// quantization parameters, using chunked processing. This is around 5% faster
654/// than `dequantize_cpu` for large slices.
655///
656/// # Examples
657/// ```
658/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
659/// let quant = Quantization::new(0.1, -128);
660/// let input: Vec<i8> = vec![0, 127, -128, 64];
661/// let mut output: Vec<f32> = vec![0.0; input.len()];
662/// dequantize_cpu_chunked(&input, quant, &mut output);
663/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
664/// ```
665pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
666    input: &[T],
667    quant: Quantization,
668    output: &mut [F],
669) where
670    f32: num_traits::AsPrimitive<F>,
671    i32: num_traits::AsPrimitive<F>,
672{
673    assert!(input.len() == output.len());
674    let zero_point = quant.zero_point.as_();
675    let scale = quant.scale.as_();
676
677    let input = input.as_chunks::<4>();
678    let output = output.as_chunks_mut::<4>();
679
680    if zero_point != F::zero() {
681        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
682
683        input
684            .0
685            .iter()
686            .zip(output.0)
687            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
688        input
689            .1
690            .iter()
691            .zip(output.1)
692            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
693    } else {
694        input
695            .0
696            .iter()
697            .zip(output.0)
698            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
699        input
700            .1
701            .iter()
702            .zip(output.1)
703            .for_each(|(d, deq)| *deq = d.as_() * scale);
704    }
705}
706
707/// Converts a segmentation tensor into a 2D mask
708/// If the last dimension of the segmentation tensor is 1, values equal or
709/// above 128 are considered objects. Otherwise the object is the argmax index
710///
711/// # Errors
712///
713/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
714/// invalid shape.
715///
716/// # Examples
717/// ```
718/// # use edgefirst_decoder::segmentation_to_mask;
719/// let segmentation =
720///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
721/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
722/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
723/// ```
724pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
725    if segmentation.shape()[2] == 0 {
726        return Err(DecoderError::InvalidShape(
727            "Segmentation tensor must have non-zero depth".to_string(),
728        ));
729    }
730    if segmentation.shape()[2] == 1 {
731        yolo_segmentation_to_mask(segmentation, 128)
732    } else {
733        Ok(modelpack_segmentation_to_mask(segmentation))
734    }
735}
736
737/// Returns the maximum value and its index from a 1D array
738fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
739    score
740        .iter()
741        .enumerate()
742        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
743            if max > *s {
744                (max, arg_max)
745            } else {
746                (*s, ind)
747            }
748        })
749}
750
/// NEON-accelerated argmax for `i8` slices on aarch64.
///
/// Returns the maximum value and the index of its **last** occurrence,
/// matching the tie-breaking semantics of [`arg_max`]. Slices shorter than
/// one NEON register (16 lanes) use a scalar loop instead.
///
/// This function is only compiled on aarch64; callers on other targets need
/// their own fallback.
///
/// # Panics
///
/// Panics if `scores` is empty.
#[cfg(target_arch = "aarch64")]
pub(crate) fn arg_max_i8(scores: &[i8]) -> (i8, usize) {
    use std::arch::aarch64::{vld1q_s8, vmaxq_s8, vmaxvq_s8};

    const LANES: usize = 16;

    if scores.len() < LANES {
        // Scalar fallback for very short slices; `>=` moves the index
        // forward on ties so the last occurrence wins.
        let mut max = scores[0];
        let mut idx = 0;
        for (i, &s) in scores.iter().enumerate().skip(1) {
            if s >= max {
                max = s;
                idx = i;
            }
        }
        return (max, idx);
    }

    // Phase 1: global max. Full 16-lane chunks go through NEON; the tail
    // that does not fill a register is folded in with scalar comparisons.
    let chunks = scores.chunks_exact(LANES);
    let tail = chunks.remainder();
    // SAFETY: `scores.len() >= LANES`, so 16 bytes are readable from the
    // start of the slice, and every chunk yielded by `chunks_exact(LANES)`
    // holds exactly 16 in-bounds `i8`s, which is what `vld1q_s8` reads.
    // NEON is a mandatory feature on aarch64, so the intrinsics are
    // available.
    let simd_max = unsafe {
        let mut vmax = vld1q_s8(scores.as_ptr());
        for chunk in chunks.skip(1) {
            vmax = vmaxq_s8(vmax, vld1q_s8(chunk.as_ptr()));
        }
        vmaxvq_s8(vmax)
    };
    let max = tail.iter().copied().fold(simd_max, i8::max);

    // Phase 2: scan backwards for the last occurrence (preserves tie
    // semantics). `max` came from `scores`, so it is always found.
    let idx = scores.iter().rposition(|&s| s == max).unwrap_or(0);
    (max, idx)
}
805#[cfg(test)]
806#[cfg_attr(coverage_nightly, coverage(off))]
807mod decoder_tests {
808    #![allow(clippy::excessive_precision)]
809    use crate::{
810        configs::{DecoderType, DimName, Protos},
811        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
812        yolo::{
813            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
814            decode_yolo_segdet_quant,
815        },
816        *,
817    };
818    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
819    use ndarray::Dimension;
820    use ndarray::{array, s, Array2, Array3, Array4, Axis};
821    use ndarray_stats::DeviationExt;
822    use num_traits::{AsPrimitive, PrimInt};
823
824    fn compare_outputs(
825        boxes: (&[DetectBox], &[DetectBox]),
826        masks: (&[Segmentation], &[Segmentation]),
827    ) {
828        let (boxes0, boxes1) = boxes;
829        let (masks0, masks1) = masks;
830
831        assert_eq!(boxes0.len(), boxes1.len());
832        assert_eq!(masks0.len(), masks1.len());
833
834        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
835            assert!(
836                b_i8.equal_within_delta(b_f32, 1e-6),
837                "{b_i8:?} is not equal to {b_f32:?}"
838            );
839        }
840
841        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
842            assert_eq!(
843                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
844                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
845            );
846            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
847            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
848            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
849            let diff = &mask_i8 - &mask_f32;
850            for x in 0..diff.shape()[0] {
851                for y in 0..diff.shape()[1] {
852                    for z in 0..diff.shape()[2] {
853                        let val = diff[[x, y, z]];
854                        assert!(
855                            val.abs() <= 1,
856                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
857                            x,
858                            y,
859                            z,
860                            val
861                        );
862                    }
863                }
864            }
865            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
866            assert!(
867                mean_sq_err < 1e-2,
868                "Mean Square Error between masks was greater than 1%: {:.2}%",
869                mean_sq_err * 100.0
870            );
871        }
872    }
873
874    // ─── Shared test data loaders ────────────────────────
875
876    fn load_yolov8_boxes() -> Array3<i8> {
877        let raw = include_bytes!(concat!(
878            env!("CARGO_MANIFEST_DIR"),
879            "/../../testdata/yolov8_boxes_116x8400.bin"
880        ));
881        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
882        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
883    }
884
885    fn load_yolov8_protos() -> Array4<i8> {
886        let raw = include_bytes!(concat!(
887            env!("CARGO_MANIFEST_DIR"),
888            "/../../testdata/yolov8_protos_160x160x32.bin"
889        ));
890        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
891        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
892    }
893
894    fn load_yolov8s_det() -> Array3<i8> {
895        let raw = include_bytes!(concat!(
896            env!("CARGO_MANIFEST_DIR"),
897            "/../../testdata/yolov8s_80_classes.bin"
898        ));
899        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
900        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
901    }
902
    /// ModelPack detection with separate boxes and scores outputs (u8 data).
    ///
    /// Decodes the same fixture three ways and checks that they agree:
    /// 1. the low-level `decode_modelpack_det` on the quantized views;
    /// 2. `Decoder::decode_quantized` on the raw quantized tensors;
    /// 3. `Decoder::decode_float::<f32>` on the `dequantize_ndarray` output.
    ///
    /// The first box from (1) is also pinned to a known reference value.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // Quantization parameters; they appear to be (scale, zero_point). The
        // concrete type is inferred from the `quantization: Some(..)` fields below.
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind the config tuples as the quantization type that
        // `decode_modelpack_det` and `dequantize_ndarray` take.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result from the low-level decoder. The slices drop the batch
        // axis (and, for boxes, the padding axis).
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            // presumably the maximum number of detections kept — TODO confirm
            300,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        // Shadow the raw u8 arrays with their dequantized float versions.
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Detection-only model: neither decoder path is expected to emit masks.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
1009
    /// ModelPack split (per-level, anchor-based) detection with u8 outputs.
    ///
    /// Pins the first box produced by the low-level `decode_modelpack_split_quant`,
    /// then checks that `Decoder::decode_quantized` on the raw tensors and
    /// `Decoder::decode_float::<f32>` on the dequantized tensors agree with it.
    #[test]
    fn test_decoder_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        // Per-level quantization; the type is inferred from the configs below.
        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        // Three anchors per level; the last axis (NumAnchorsXFeatures = 18)
        // packs all anchors' features together.
        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        let detect_config0 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 9, 15, 18],
            anchors: Some(anchors0.clone()),
            quantization: Some(quant0),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 9),
                (DimName::Width, 15),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        let detect_config1 = configs::Detection {
            decoder: DecoderType::ModelPack,
            shape: vec![1, 17, 30, 18],
            anchors: Some(anchors1.clone()),
            quantization: Some(quant1),
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::Height, 17),
                (DimName::Width, 30),
                (DimName::NumAnchorsXFeatures, 18),
            ],
            normalized: Some(true),
        };

        // Converted configs for the low-level `decode_modelpack_split_quant` call.
        let config0 = (&detect_config0).try_into().unwrap();
        let config1 = (&detect_config1).try_into().unwrap();

        // NOTE(review): configs are registered as [1, 0] while the tensors are
        // passed below as [0, 1]; presumably this checks that the decoder matches
        // outputs by shape rather than by position — TODO confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind as the quantization type taken by `dequantize_ndarray`.
        let quant0 = quant0.into();
        let quant1 = quant1.into();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        decode_modelpack_split_quant(
            &[
                detect0.slice(s![0, .., .., ..]),
                detect1.slice(s![0, .., .., ..]),
            ],
            &[config0, config1],
            score_threshold,
            iou_threshold,
            // presumably the maximum number of detections kept — TODO confirm
            300,
            &mut output_boxes,
        );
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));

        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[detect0.view().into(), detect1.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);

        let detect0 = dequantize_ndarray(detect0.view(), quant0);
        let detect1 = dequantize_ndarray(detect1.view(), quant1);
        decoder
            .decode_float::<f32>(
                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // Detection-only model: no masks are expected from either path.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes1_f32),
            (&[], &output_masks1_f32),
        );
    }
1135
    /// Same split ModelPack fixtures as `test_decoder_modelpack_split_u8`, but
    /// the decoder is configured from `testdata/modelpack_split.yaml` instead of
    /// programmatically. The expected first box is the same reference value.
    #[test]
    fn test_decoder_parse_config_modelpack_split_u8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let decoder = DecoderBuilder::default()
            .with_config_yaml_str(
                include_str!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/modelpack_split.yaml"
                ))
                .to_string(),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // NOTE(review): tensors are passed as [detect1, detect0]; the order the
        // YAML declares them in is not visible here — presumably this exercises
        // shape-based output matching. TODO confirm.
        decoder
            .decode_quantized(
                &[
                    ArrayViewDQuantized::from(detect1.view()),
                    ArrayViewDQuantized::from(detect0.view()),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0
            },
            1e-6
        ));
    }
1191
1192    #[test]
1193    fn test_modelpack_seg() {
1194        let out = include_bytes!(concat!(
1195            env!("CARGO_MANIFEST_DIR"),
1196            "/../../testdata/modelpack_seg_2x160x160.bin"
1197        ));
1198        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1199        let quant = (1.0 / 255.0, 0).into();
1200
1201        let decoder = DecoderBuilder::default()
1202            .with_config_modelpack_seg(configs::Segmentation {
1203                decoder: DecoderType::ModelPack,
1204                quantization: Some(quant),
1205                shape: vec![1, 2, 160, 160],
1206                dshape: vec![
1207                    (DimName::Batch, 1),
1208                    (DimName::NumClasses, 2),
1209                    (DimName::Height, 160),
1210                    (DimName::Width, 160),
1211                ],
1212            })
1213            .build()
1214            .unwrap();
1215        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1216        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1217        decoder
1218            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1219            .unwrap();
1220
1221        let mut mask = out.slice(s![0, .., .., ..]);
1222        mask.swap_axes(0, 1);
1223        mask.swap_axes(1, 2);
1224        let mask = [Segmentation {
1225            xmin: 0.0,
1226            ymin: 0.0,
1227            xmax: 1.0,
1228            ymax: 1.0,
1229            segmentation: mask.into_owned(),
1230        }];
1231        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1232
1233        decoder
1234            .decode_float::<f32>(
1235                &[dequantize_ndarray(out.view(), quant.into())
1236                    .view()
1237                    .into_dyn()],
1238                &mut output_boxes,
1239                &mut output_masks,
1240            )
1241            .unwrap();
1242
1243        // not expected for float decoder to have same values as quantized decoder, as
1244        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1245        // the model output. Thus the float output is the same as the quantized output
1246        // but scaled differently. However, it is expected that the mask after argmax
1247        // will be the same.
1248        compare_outputs((&[], &output_boxes), (&[], &[]));
1249        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1250        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1251
1252        assert_eq!(mask0, mask1);
1253    }
    /// Quantized ModelPack segmentation must give the same argmax mask for every
    /// supported integer input type (u8, i8, u16, i16, u32, i32).
    ///
    /// The u8 fixture is re-encoded into each type with an order-preserving map.
    /// The same `(1/255, 0)` quantization config is used for all of them, so
    /// only the argmax masks, not the raw segmentation values, are compared.
    #[test]
    fn test_modelpack_seg_quant() {
        let out = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
        // Each map below is monotonic in x: signed types subtract the type's
        // midpoint, wider types shift left. The per-pixel class argmax is
        // therefore unchanged. The i64/i32 intermediates keep the arithmetic
        // from overflowing before the final cast.
        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);

        let quant = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_seg(configs::Segmentation {
                decoder: DecoderType::ModelPack,
                quantization: Some(quant),
                shape: vec![1, 2, 160, 160],
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumClasses, 2),
                    (DimName::Height, 160),
                    (DimName::Width, 160),
                ],
            })
            .build()
            .unwrap();
        // Shared across all decodes; it must still be empty at the end
        // (checked by compare_outputs below), since this model emits no boxes.
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u8.view().into()],
                &mut output_boxes,
                &mut output_masks_u8,
            )
            .unwrap();

        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i8.view().into()],
                &mut output_boxes,
                &mut output_masks_i8,
            )
            .unwrap();

        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u16.view().into()],
                &mut output_boxes,
                &mut output_masks_u16,
            )
            .unwrap();

        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i16.view().into()],
                &mut output_boxes,
                &mut output_masks_i16,
            )
            .unwrap();

        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_u32.view().into()],
                &mut output_boxes,
                &mut output_masks_u32,
            )
            .unwrap();

        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[out_i32.view().into()],
                &mut output_boxes,
                &mut output_masks_i32,
            )
            .unwrap();

        compare_outputs((&[], &output_boxes), (&[], &[]));
        // Collapse each segmentation to its argmax mask and compare against u8.
        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
        assert_eq!(mask_u8, mask_i8);
        assert_eq!(mask_u8, mask_u16);
        assert_eq!(mask_u8, mask_i16);
        assert_eq!(mask_u8, mask_u32);
        assert_eq!(mask_u8, mask_i32);
    }
1351
    /// ModelPack combined detection and segmentation: boxes, scores and a
    /// segmentation output.
    ///
    /// The quantized path must reproduce the reference box and the raw
    /// segmentation (class axis moved last). The float path must reproduce the
    /// reference box and the same argmax mask.
    #[test]
    fn test_modelpack_segdet() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;

        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_iou_threshold(iou_threshold)
            .with_score_threshold(score_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        // NOTE(review): inputs are passed as (scores, boxes, seg) while the
        // builder received (boxes, scores, seg); presumably this checks
        // shape-based output matching — TODO confirm.
        decoder
            .decode_quantized(
                &[scores.view().into(), boxes.view().into(), seg.view().into()],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: drop the batch axis and reorder (C, H, W) -> (H, W, C).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.40513772,
                ymin: 0.6379755,
                xmax: 0.5122431,
                ymax: 0.7730214,
            },
            score: 0.4861709,
            label: 0,
        }];
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        // NOTE(review): the float decode reuses the output vectors from the
        // quantized run. The box comparison below only passes if
        // `output_boxes` is replaced rather than appended to. Whether
        // `output_masks` is also cleared is not visible here — TODO confirm.
        decoder
            .decode_float::<f32>(
                &[
                    scores.view().into_dyn(),
                    boxes.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // not expected for float segmentation decoder to have same values as quantized
        // segmentation decoder, as float decoder ensures the data fills 0-255,
        // quantized decoder uses whatever the model output. Thus the float
        // output is the same as the quantized output but scaled differently.
        // However, it is expected that the mask after argmax will be the same.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1477
    /// ModelPack split (two anchor levels) detection plus a segmentation output.
    ///
    /// The quantized path must reproduce the reference box and the raw
    /// segmentation (class axis moved last). The float path must reproduce the
    /// reference box and the same argmax mask. This test uses a stricter score
    /// threshold (0.8) than the other split tests.
    #[test]
    fn test_modelpack_segdet_split() {
        let score_threshold = 0.8;
        let iou_threshold = 0.5;

        let seg = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_seg_2x160x160.bin"
        ));
        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

        let detect0 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_9x15x18.bin"
        ));
        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

        let detect1 = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_split_17x30x18.bin"
        ));
        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

        let quant0 = (0.08547406643629074, 174).into();
        let quant1 = (0.09929127991199493, 183).into();
        let quant_seg = (1.0 / 255.0, 0).into();

        let anchors0 = vec![
            [0.36666667461395264, 0.31481480598449707],
            [0.38749998807907104, 0.4740740656852722],
            [0.5333333611488342, 0.644444465637207],
        ];
        let anchors1 = vec![
            [0.13750000298023224, 0.2074074000120163],
            [0.2541666626930237, 0.21481481194496155],
            [0.23125000298023224, 0.35185185074806213],
        ];

        // NOTE(review): the detection configs are listed as [17x30, 9x15] while
        // the tensors are passed as [9x15, 17x30]; presumably this exercises
        // shape-based output matching — TODO confirm.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_segdet_split(
                vec![
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 17, 30, 18],
                        anchors: Some(anchors1),
                        quantization: Some(quant1),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 17),
                            (DimName::Width, 30),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                    configs::Detection {
                        decoder: DecoderType::ModelPack,
                        shape: vec![1, 9, 15, 18],
                        anchors: Some(anchors0),
                        quantization: Some(quant0),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::Height, 9),
                            (DimName::Width, 15),
                            (DimName::NumAnchorsXFeatures, 18),
                        ],
                        normalized: Some(true),
                    },
                ],
                configs::Segmentation {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_seg),
                    shape: vec![1, 2, 160, 160],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 2),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();
        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
        let mut output_masks: Vec<_> = Vec::with_capacity(10);
        decoder
            .decode_quantized(
                &[
                    detect0.view().into(),
                    detect1.view().into(),
                    seg.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Expected mask: drop the batch axis and reorder (C, H, W) -> (H, W, C).
        let mut mask = seg.slice(s![0, .., .., ..]);
        mask.swap_axes(0, 1);
        mask.swap_axes(1, 2);
        let mask = [Segmentation {
            xmin: 0.0,
            ymin: 0.0,
            xmax: 1.0,
            ymax: 1.0,
            segmentation: mask.into_owned(),
        }];
        let correct_boxes = [DetectBox {
            bbox: BoundingBox {
                xmin: 0.43171933,
                ymin: 0.68243736,
                xmax: 0.5626645,
                ymax: 0.808863,
            },
            score: 0.99240804,
            label: 0,
        }];
        // Debug aid; only shown when the test fails or runs with --nocapture.
        println!("Output Boxes: {:?}", output_boxes);
        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));

        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
        // NOTE(review): the float decode reuses the output vectors from the
        // quantized run. The box comparison below only passes if
        // `output_boxes` is replaced rather than appended to. Whether
        // `output_masks` is also cleared is not visible here — TODO confirm.
        decoder
            .decode_float::<f32>(
                &[
                    detect0.view().into_dyn(),
                    detect1.view().into_dyn(),
                    seg.view().into_dyn(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // not expected for float segmentation decoder to have same values as quantized
        // segmentation decoder, as float decoder ensures the data fills 0-255,
        // quantized decoder uses whatever the model output. Thus the float
        // output is the same as the quantized output but scaled differently.
        // However, it is expected that the mask after argmax will be the same.
        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();

        assert_eq!(mask0, mask1);
    }
1625
1626    #[test]
1627    fn test_dequant_chunked() {
1628        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1629        out.push(123); // make sure to test non multiple of 16 length
1630
1631        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1632        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1633        let quant = Quantization::new(0.0040811873, -123);
1634        dequantize_cpu(&out, quant, &mut out_dequant);
1635
1636        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1637        assert_eq!(out_dequant, out_dequant_simd);
1638
1639        let quant = Quantization::new(0.0040811873, 0);
1640        dequantize_cpu(&out, quant, &mut out_dequant);
1641
1642        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1643        assert_eq!(out_dequant, out_dequant_simd);
1644    }
1645
1646    #[test]
1647    fn test_dequant_ground_truth() {
1648        // Formula: output = (input - zero_point) * scale
1649        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1650
1651        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1652        let quant = Quantization::new(0.1, -128);
1653        let input: Vec<i8> = vec![0, 127, -128, 64];
1654        let mut output = vec![0.0f32; 4];
1655        let mut output_chunked = vec![0.0f32; 4];
1656        dequantize_cpu(&input, quant, &mut output);
1657        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1658        // (0 - (-128)) * 0.1 = 12.8
1659        // (127 - (-128)) * 0.1 = 25.5
1660        // (-128 - (-128)) * 0.1 = 0.0
1661        // (64 - (-128)) * 0.1 = 19.2
1662        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1663        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1664            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1665        }
1666        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1667            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1668        }
1669
1670        // Case 2: scale=1.0, zero_point=0 (identity-like)
1671        let quant = Quantization::new(1.0, 0);
1672        dequantize_cpu(&input, quant, &mut output);
1673        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1674        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1675        assert_eq!(output, expected);
1676        assert_eq!(output_chunked, expected);
1677
1678        // Case 3: scale=0.5, zero_point=0
1679        let quant = Quantization::new(0.5, 0);
1680        dequantize_cpu(&input, quant, &mut output);
1681        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1682        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1683        assert_eq!(output, expected);
1684        assert_eq!(output_chunked, expected);
1685
1686        // Case 4: i8 min/max boundaries with typical quantization params
1687        let quant = Quantization::new(0.021287762, 31);
1688        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1689        let mut output = vec![0.0f32; 6];
1690        let mut output_chunked = vec![0.0f32; 6];
1691        dequantize_cpu(&input, quant, &mut output);
1692        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1693        for i in 0..6 {
1694            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1695            assert!(
1696                (output[i] - expected).abs() < 1e-5,
1697                "cpu[{i}]: {} != {expected}",
1698                output[i]
1699            );
1700            assert!(
1701                (output_chunked[i] - expected).abs() < 1e-5,
1702                "chunked[{i}]: {} != {expected}",
1703                output_chunked[i]
1704            );
1705        }
1706    }
1707
    /// Decodes the quantized YOLOv8s detection fixture three ways and checks
    /// that they agree:
    /// 1. low-level `decode_yolo_det` on batch 0 of the quantized tensor, also
    ///    checked against two recorded boxes;
    /// 2. `Decoder::decode_quantized` on the same tensor;
    /// 3. `Decoder::decode_float` on the dequantized tensor.
    #[test]
    fn test_decoder_yolo_det() {
        let score_threshold = 0.25;
        let iou_threshold = 0.7;
        let out = load_yolov8s_det();
        // Target type is inferred from `quantization: Some(quant)` below; the
        // low-level calls convert it again with `.into()`.
        let quant = (0.0040811873, -123).into();

        // Single-output detection config laid out as
        // [batch, num_features, num_boxes].
        let decoder = DecoderBuilder::default()
            .with_config_yolo_det(
                configs::Detection {
                    decoder: DecoderType::Ultralytics,
                    shape: vec![1, 84, 8400],
                    anchors: None,
                    quantization: Some(quant),
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 84),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Reference path: low-level decode of batch 0 with class-agnostic NMS.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_yolo_det(
            (out.slice(s![0, .., ..]), quant.into()),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
        );
        // Recorded values for the first two detections. Indexing panics if
        // fewer than two boxes were produced.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.5285137,
                    ymin: 0.05305544,
                    xmax: 0.87541467,
                    ymax: 0.9998909,
                },
                score: 0.5591227,
                label: 0
            },
            1e-6
        ));

        assert!(output_boxes[1].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.130598,
                    ymin: 0.43260583,
                    xmax: 0.35098213,
                    ymax: 0.9958097,
                },
                score: 0.33057618,
                label: 75
            },
            1e-6
        ));

        // Config-driven quantized path on the full (batched) tensor.
        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
            .unwrap();

        // Config-driven float path on a dequantized copy (shadows `out`).
        let out = dequantize_ndarray(out.view(), quant.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
        decoder
            .decode_float::<f32>(
                &[out.view().into_dyn()],
                &mut output_boxes_f32,
                &mut output_masks_f32,
            )
            .unwrap();

        // The low-level detection call produces no masks, so `&[]` stands in
        // for the reference mask list.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
    }
1792
1793    #[test]
1794    fn test_decoder_masks() {
1795        let score_threshold = 0.45;
1796        let iou_threshold = 0.45;
1797        let boxes = load_yolov8_boxes();
1798        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1799
1800        let protos = load_yolov8_protos();
1801        let quant_protos = Quantization::new(0.02491161972284317, -117);
1802        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1803        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1804        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1805        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1806        decode_yolo_segdet_float(
1807            seg.slice(s![0, .., ..]),
1808            protos.slice(s![0, .., .., ..]),
1809            score_threshold,
1810            iou_threshold,
1811            Some(configs::Nms::ClassAgnostic),
1812            &mut output_boxes,
1813            &mut output_masks,
1814        )
1815        .unwrap();
1816        assert_eq!(output_boxes.len(), 2);
1817        assert_eq!(output_boxes.len(), output_masks.len());
1818
1819        for (b, m) in output_boxes.iter().zip(&output_masks) {
1820            // Mask region is the proto-grid-aligned crop (floor for min,
1821            // ceil for max), so it encloses the post-NMS bbox (EDGEAI-1304).
1822            assert!(b.bbox.xmin >= m.xmin);
1823            assert!(b.bbox.ymin >= m.ymin);
1824            assert!(b.bbox.xmax <= m.xmax);
1825            assert!(b.bbox.ymax <= m.ymax);
1826        }
1827        assert!(output_boxes[0].equal_within_delta(
1828            &DetectBox {
1829                bbox: BoundingBox {
1830                    xmin: 0.08515105,
1831                    ymin: 0.7131401,
1832                    xmax: 0.29802868,
1833                    ymax: 0.8195788,
1834                },
1835                score: 0.91537374,
1836                label: 23
1837            },
1838            1.0 / 160.0, // wider range because mask will expand the box
1839        ));
1840
1841        assert!(output_boxes[1].equal_within_delta(
1842            &DetectBox {
1843                bbox: BoundingBox {
1844                    xmin: 0.59605736,
1845                    ymin: 0.25545314,
1846                    xmax: 0.93666154,
1847                    ymax: 0.72378385,
1848                },
1849                score: 0.91537374,
1850                label: 23
1851            },
1852            1.0 / 160.0, // wider range because mask will expand the box
1853        ));
1854
1855        let full_mask = include_bytes!(concat!(
1856            env!("CARGO_MANIFEST_DIR"),
1857            "/../../testdata/yolov8_mask_results.bin"
1858        ));
1859        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1860
1861        let cropped_mask = full_mask.slice(ndarray::s![
1862            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1863            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1864        ]);
1865
1866        assert_eq!(
1867            cropped_mask,
1868            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1869        );
1870    }
1871
1872    /// Regression test: config-driven path with physically-NCHW protos.
1873    /// Simulates YOLOv8-seg ONNX outputs where the producer emits protos
1874    /// as `(1, 32, 160, 160)` in CHW memory order — the caller declares
1875    /// shape and dshape matching that physical order and HAL permutes to
1876    /// canonical HWC via `swap_axes_if_needed`.
1877    ///
1878    /// This is the counterpart to the NHWC-producer case (TFLite /
1879    /// Ara-2) where shape+dshape are `(1, 160, 160, 32)` +
1880    /// `[batch, height, width, num_protos]` and no reordering is needed.
1881    #[test]
1882    fn test_decoder_masks_nchw_protos() {
1883        let score_threshold = 0.45;
1884        let iou_threshold = 0.45;
1885
1886        // Load test data — boxes as [116, 8400]
1887        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1888        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1889
1890        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1891        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1892        let quant_protos = Quantization::new(0.02491161972284317, -117);
1893        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1894
1895        // ---- Reference: direct call with HWC protos (known working) ----
1896        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1897        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1898        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1899        decode_yolo_segdet_float(
1900            seg.view(),
1901            protos_f32_hwc.view(),
1902            score_threshold,
1903            iou_threshold,
1904            Some(configs::Nms::ClassAgnostic),
1905            &mut ref_boxes,
1906            &mut ref_masks,
1907        )
1908        .unwrap();
1909        assert_eq!(ref_boxes.len(), 2);
1910
1911        // ---- Config-driven path: NCHW protos declared in physical order ----
1912        // Permute the HWC test data to CHW memory order — this is what an
1913        // ONNX-style producer would emit into the tensor buffer.
1914        // `to_owned` materialises a C-contiguous Array3<f32> with CHW
1915        // strides, matching a producer that writes channels-outer.
1916        let protos_f32_chw_view = protos_f32_hwc.view().permuted_axes([2, 0, 1]); // [32, 160, 160]
1917        let protos_f32_chw = protos_f32_chw_view.to_owned();
1918        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1919
1920        // Build boxes as [1, 116, 8400] f32
1921        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1922
1923        // Declare shape and dshape in physical memory order (outermost
1924        // first). `swap_axes_if_needed` uses dshape role lookup to
1925        // permute the stride tuple into canonical HWC.
1926        let decoder = DecoderBuilder::default()
1927            .with_config_yolo_segdet(
1928                configs::Detection {
1929                    decoder: configs::DecoderType::Ultralytics,
1930                    quantization: None,
1931                    shape: vec![1, 116, 8400],
1932                    dshape: vec![
1933                        (configs::DimName::Batch, 1),
1934                        (configs::DimName::NumFeatures, 116),
1935                        (configs::DimName::NumBoxes, 8400),
1936                    ],
1937                    normalized: Some(true),
1938                    anchors: None,
1939                },
1940                configs::Protos {
1941                    decoder: configs::DecoderType::Ultralytics,
1942                    quantization: None,
1943                    shape: vec![1, 32, 160, 160],
1944                    dshape: vec![
1945                        (configs::DimName::Batch, 1),
1946                        (configs::DimName::NumProtos, 32),
1947                        (configs::DimName::Height, 160),
1948                        (configs::DimName::Width, 160),
1949                    ],
1950                },
1951                None, // decoder version
1952            )
1953            .with_score_threshold(score_threshold)
1954            .with_iou_threshold(iou_threshold)
1955            .build()
1956            .unwrap();
1957
1958        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1959        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1960        decoder
1961            .decode_float(
1962                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1963                &mut cfg_boxes,
1964                &mut cfg_masks,
1965            )
1966            .unwrap();
1967
1968        // Must produce the same number of detections
1969        assert_eq!(
1970            cfg_boxes.len(),
1971            ref_boxes.len(),
1972            "config path produced {} boxes, reference produced {}",
1973            cfg_boxes.len(),
1974            ref_boxes.len()
1975        );
1976
1977        // Boxes must match
1978        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1979            assert!(
1980                cb.equal_within_delta(rb, 0.01),
1981                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1982            );
1983        }
1984
1985        // Masks must match pixel-for-pixel
1986        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1987            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1988            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1989            assert_eq!(
1990                cm_arr, rm_arr,
1991                "mask {i} pixel mismatch between config-driven and reference paths"
1992            );
1993        }
1994    }
1995
    /// Cross-checks four decode paths on the quantized YOLOv8 segmentation
    /// fixture:
    /// - low-level `decode_yolo_segdet_quant` (the reference),
    /// - `Decoder::decode_quantized` with the matching config,
    /// - low-level `decode_yolo_segdet_float` on dequantized tensors,
    /// - `Decoder::decode_float` on the same dequantized tensors.
    #[test]
    fn test_decoder_masks_i8() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = load_yolov8_boxes();
        // Config-level quantization values; the target type is inferred from
        // the `quantization: Some(..)` fields of the configs below.
        let quant_boxes = (0.021287761628627777, 31).into();

        let protos = load_yolov8_protos();
        let quant_protos = (0.02491161972284317, -117).into();
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Detection output is [batch, num_features, num_boxes]; protos are
        // declared as [batch, height, width, num_protos].
        let decoder = DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 116, 8400],
                    anchors: None,
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumFeatures, 116),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(quant_protos),
                    shape: vec![1, 160, 160, 32],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::Height, 160),
                        (DimName::Width, 160),
                        (DimName::NumProtos, 32),
                    ],
                },
                Some(DecoderVersion::Yolo11),
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-bind as the quantization type taken by the low-level decode and
        // dequantize functions (again inferred from use below).
        let quant_boxes = quant_boxes.into();
        let quant_protos = quant_protos.into();

        // Reference: low-level quantized decode of batch 0.
        decode_yolo_segdet_quant(
            (boxes.slice(s![0, .., ..]), quant_boxes),
            (protos.slice(s![0, .., .., ..]), quant_protos),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes,
            &mut output_masks,
        )
        .unwrap();

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        // Config-driven quantized decode on the full (batched) tensors.
        decoder
            .decode_quantized(
                &[boxes.view().into(), protos.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Dequantize both tensors for the two float paths.
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);

        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);

        decoder
            .decode_float(
                &[seg.view().into_dyn(), protos.view().into_dyn()],
                &mut output_boxes1_f32,
                &mut output_masks1_f32,
            )
            .unwrap();

        // Both config paths must agree with the low-level quantized reference,
        // and the two float paths must agree with each other.
        compare_outputs(
            (&output_boxes, &output_boxes1),
            (&output_masks, &output_masks1),
        );

        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );

        compare_outputs(
            (&output_boxes_f32, &output_boxes1_f32),
            (&output_masks_f32, &output_masks1_f32),
        );
    }
2107
    /// Split detection decoder (separate box and score tensors) fed with
    /// i16-widened data. It is cross-checked against the low-level float
    /// detection decoder, on both its quantized and its float entry points.
    #[test]
    fn test_decoder_yolo_split() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes = load_yolov8_boxes();
        // Widen to i16 by multiplying by 256. The quantization below divides
        // the scale and multiplies the zero point by the same 256, so
        // (x * 256 - 31 * 256) * (scale / 256) == (x - 31) * scale, up to
        // float rounding.
        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);

        // Box coordinates (4 rows) and class scores (80 rows) are slices of the
        // same tensor, so both configs share one quantization.
        let decoder = DecoderBuilder::default()
            .with_config_yolo_split_det(
                configs::Boxes {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 4, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::BoxCoords, 4),
                        (DimName::NumBoxes, 8400),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
                    shape: vec![1, 80, 8400],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumClasses, 80),
                        (DimName::NumBoxes, 8400),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Quantized split decode: rows 0..4 are boxes, rows 4..84 are scores.
        // Rows 84.. are not passed to this detection-only decoder.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: low-level float detection decode over rows 0..84 of batch 0.
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_det_float(
            seg.slice(s![0, ..84, ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
        );

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        // Float split decode on the same slices of the dequantized tensor.
        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        // The float reference produces no masks, hence `&[]` on its side.
        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
    }
2187
    /// Split segmentation-detection decoder with mixed input types: the box,
    /// score and mask-coefficient slices are widened to i16, while the protos
    /// are passed in the fixture's original type. The quantized and float
    /// config paths are checked against the low-level float decoder.
    #[test]
    fn test_decoder_masks_config_mixed() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes_raw = load_yolov8_boxes();
        // Widen by 256 and compensate in the quantization below so the
        // dequantized values equal the unwidened ones (up to float rounding).
        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);

        let protos = load_yolov8_protos();
        let quant_protos = (0.02491161972284317, -117);

        let decoder = build_yolo_split_segdet_decoder(
            score_threshold,
            iou_threshold,
            quant_boxes,
            quant_protos,
        );
        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Quantized split decode: rows 0..4 boxes, 4..84 scores, 84.. mask
        // coefficients, followed by the protos tensor.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                    boxes.slice(s![.., 84.., ..]).into(),
                    protos.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: low-level float segdet decode of batch 0.
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
        let mut output_masks1: Vec<_> = Vec::with_capacity(500);

        // Float split decode on the same slices of the dequantized tensors.
        decoder
            .decode_float(
                &[
                    seg.slice(s![.., ..4, ..]).into_dyn(),
                    seg.slice(s![.., 4..84, ..]).into_dyn(),
                    seg.slice(s![.., 84.., ..]).into_dyn(),
                    protos.view().into_dyn(),
                ],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();
        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );
        compare_outputs(
            (&output_boxes_f32, &output_boxes1),
            (&output_masks_f32, &output_masks1),
        );
    }
2262
2263    fn build_yolo_split_segdet_decoder(
2264        score_threshold: f32,
2265        iou_threshold: f32,
2266        quant_boxes: (f32, i32),
2267        quant_protos: (f32, i32),
2268    ) -> crate::Decoder {
2269        DecoderBuilder::default()
2270            .with_config_yolo_split_segdet(
2271                configs::Boxes {
2272                    decoder: configs::DecoderType::Ultralytics,
2273                    quantization: Some(quant_boxes.into()),
2274                    shape: vec![1, 4, 8400],
2275                    dshape: vec![
2276                        (DimName::Batch, 1),
2277                        (DimName::BoxCoords, 4),
2278                        (DimName::NumBoxes, 8400),
2279                    ],
2280                    normalized: Some(true),
2281                },
2282                configs::Scores {
2283                    decoder: configs::DecoderType::Ultralytics,
2284                    quantization: Some(quant_boxes.into()),
2285                    shape: vec![1, 80, 8400],
2286                    dshape: vec![
2287                        (DimName::Batch, 1),
2288                        (DimName::NumClasses, 80),
2289                        (DimName::NumBoxes, 8400),
2290                    ],
2291                },
2292                configs::MaskCoefficients {
2293                    decoder: configs::DecoderType::Ultralytics,
2294                    quantization: Some(quant_boxes.into()),
2295                    shape: vec![1, 32, 8400],
2296                    dshape: vec![
2297                        (DimName::Batch, 1),
2298                        (DimName::NumProtos, 32),
2299                        (DimName::NumBoxes, 8400),
2300                    ],
2301                },
2302                configs::Protos {
2303                    decoder: configs::DecoderType::Ultralytics,
2304                    quantization: Some(quant_protos.into()),
2305                    shape: vec![1, 160, 160, 32],
2306                    dshape: vec![
2307                        (DimName::Batch, 1),
2308                        (DimName::Height, 160),
2309                        (DimName::Width, 160),
2310                        (DimName::NumProtos, 32),
2311                    ],
2312                },
2313            )
2314            .with_score_threshold(score_threshold)
2315            .with_iou_threshold(iou_threshold)
2316            .build()
2317            .unwrap()
2318    }
2319
2320    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2321        let config_yaml = include_str!(concat!(
2322            env!("CARGO_MANIFEST_DIR"),
2323            "/../../testdata/yolov8_seg.yaml"
2324        ));
2325        DecoderBuilder::default()
2326            .with_config_yaml_str(config_yaml.to_string())
2327            .with_score_threshold(score_threshold)
2328            .with_iou_threshold(iou_threshold)
2329            .build()
2330            .unwrap()
2331    }
    /// Split segmentation-detection decoder with i32 inputs. Boxes and protos
    /// are both widened by 2^23, with scales divided and zero points multiplied
    /// by the same factor. The quantized config path must match the low-level
    /// float decoder on the dequantized tensors.
    #[test]
    fn test_decoder_masks_config_i32() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        let boxes_raw = load_yolov8_boxes();
        // 2^23 is a power of two, so dividing the f32 scales by it is exact.
        // If the fixture values are i8, |x * 2^23| <= 2^30, which fits in i32.
        let scale = 1 << 23;
        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();

        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);

        let protos_raw = load_yolov8_protos();
        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);

        let decoder = build_yolo_split_segdet_decoder(
            score_threshold,
            iou_threshold,
            quant_boxes,
            quant_protos,
        );

        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
        let mut output_masks: Vec<_> = Vec::with_capacity(500);

        // Quantized split decode: rows 0..4 boxes, 4..84 scores, 84.. mask
        // coefficients, followed by the protos tensor.
        decoder
            .decode_quantized(
                &[
                    boxes.slice(s![.., ..4, ..]).into(),
                    boxes.slice(s![.., 4..84, ..]).into(),
                    boxes.slice(s![.., 84.., ..]).into(),
                    protos.view().into(),
                ],
                &mut output_boxes,
                &mut output_masks,
            )
            .unwrap();

        // Reference: low-level float segdet decode of batch 0.
        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
        decode_yolo_segdet_float(
            seg.slice(s![0, .., ..]),
            protos.slice(s![0, .., .., ..]),
            score_threshold,
            iou_threshold,
            Some(configs::Nms::ClassAgnostic),
            &mut output_boxes_f32,
            &mut output_masks_f32,
        )
        .unwrap();

        assert_eq!(output_boxes.len(), output_boxes_f32.len());
        assert_eq!(output_masks.len(), output_masks_f32.len());

        compare_outputs(
            (&output_boxes, &output_boxes_f32),
            (&output_masks, &output_masks_f32),
        );
    }
2394
    /// Runs several decoder instances concurrently on separate threads.
    ///
    /// Four threads run a quantized YOLOv8 detection decoder and four threads
    /// run a quantized ModelPack split detection + segmentation decoder. Each
    /// thread builds its own decoder and decodes the same fixture 100 times,
    /// checking the output on every iteration. The intent is that any
    /// interference between concurrently running decoders surfaces as a wrong
    /// result or a panic in one of the threads.
    #[test]
    fn test_context_switch() {
        // YOLOv8 detection on a quantized [1, 84, 8400] output; every
        // iteration must reproduce the same two leading boxes and no masks.
        let yolo_det = || {
            let score_threshold = 0.25;
            let iou_threshold = 0.7;
            let out = load_yolov8s_det();
            let quant = (0.0040811873, -123).into();

            let decoder = DecoderBuilder::default()
                .with_config_yolo_det(
                    configs::Detection {
                        decoder: DecoderType::Ultralytics,
                        shape: vec![1, 84, 8400],
                        anchors: None,
                        quantization: Some(quant),
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumFeatures, 84),
                            (DimName::NumBoxes, 8400),
                        ],
                        normalized: None,
                    },
                    None,
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();

            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
            let mut output_masks: Vec<_> = Vec::with_capacity(50);

            // The output buffers are reused across iterations; presumably each
            // decode call clears them before writing (TODO confirm).
            for _ in 0..100 {
                decoder
                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
                    .unwrap();

                assert!(output_boxes[0].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.5285137,
                            ymin: 0.05305544,
                            xmax: 0.87541467,
                            ymax: 0.9998909,
                        },
                        score: 0.5591227,
                        label: 0
                    },
                    1e-6
                ));

                assert!(output_boxes[1].equal_within_delta(
                    &DetectBox {
                        bbox: BoundingBox {
                            xmin: 0.130598,
                            ymin: 0.43260583,
                            xmax: 0.35098213,
                            ymax: 0.9958097,
                        },
                        score: 0.33057618,
                        label: 75
                    },
                    1e-6
                ));
                // Detection-only model: no segmentation output expected.
                assert!(output_masks.is_empty());
            }
        };

        // ModelPack with two anchor-based detection heads (9x15 and 17x30)
        // plus a 2-class 160x160 segmentation output, all quantized u8.
        let modelpack_det_split = || {
            let score_threshold = 0.8;
            let iou_threshold = 0.5;

            let seg = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_seg_2x160x160.bin"
            ));
            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();

            let detect0 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_9x15x18.bin"
            ));
            let detect0 =
                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();

            let detect1 = include_bytes!(concat!(
                env!("CARGO_MANIFEST_DIR"),
                "/../../testdata/modelpack_split_17x30x18.bin"
            ));
            let detect1 =
                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();

            // Expected mask: the raw segmentation tensor with its class axis
            // moved last, i.e. (classes, H, W) -> (H, W, classes), covering
            // the whole normalized image [0, 1] x [0, 1].
            let mut mask = seg.slice(s![0, .., .., ..]);
            mask.swap_axes(0, 1);
            mask.swap_axes(1, 2);
            let mask = [Segmentation {
                xmin: 0.0,
                ymin: 0.0,
                xmax: 1.0,
                ymax: 1.0,
                segmentation: mask.into_owned(),
            }];
            let correct_boxes = [DetectBox {
                bbox: BoundingBox {
                    xmin: 0.43171933,
                    ymin: 0.68243736,
                    xmax: 0.5626645,
                    ymax: 0.808863,
                },
                score: 0.99240804,
                label: 0,
            }];

            let quant0 = (0.08547406643629074, 174).into();
            let quant1 = (0.09929127991199493, 183).into();
            let quant_seg = (1.0 / 255.0, 0).into();

            // Three (width, height) anchors per detection head.
            let anchors0 = vec![
                [0.36666667461395264, 0.31481480598449707],
                [0.38749998807907104, 0.4740740656852722],
                [0.5333333611488342, 0.644444465637207],
            ];
            let anchors1 = vec![
                [0.13750000298023224, 0.2074074000120163],
                [0.2541666626930237, 0.21481481194496155],
                [0.23125000298023224, 0.35185185074806213],
            ];

            // NOTE(review): the config lists the 17x30 head before the 9x15
            // head, while the inputs below are passed 9x15 first; presumably
            // the decoder matches inputs to configs by shape — TODO confirm.
            let decoder = DecoderBuilder::default()
                .with_config_modelpack_segdet_split(
                    vec![
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 17, 30, 18],
                            anchors: Some(anchors1),
                            quantization: Some(quant1),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 17),
                                (DimName::Width, 30),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                        configs::Detection {
                            decoder: DecoderType::ModelPack,
                            shape: vec![1, 9, 15, 18],
                            anchors: Some(anchors0),
                            quantization: Some(quant0),
                            dshape: vec![
                                (DimName::Batch, 1),
                                (DimName::Height, 9),
                                (DimName::Width, 15),
                                (DimName::NumAnchorsXFeatures, 18),
                            ],
                            normalized: None,
                        },
                    ],
                    configs::Segmentation {
                        decoder: DecoderType::ModelPack,
                        quantization: Some(quant_seg),
                        shape: vec![1, 2, 160, 160],
                        dshape: vec![
                            (DimName::Batch, 1),
                            (DimName::NumClasses, 2),
                            (DimName::Height, 160),
                            (DimName::Width, 160),
                        ],
                    },
                )
                .with_score_threshold(score_threshold)
                .with_iou_threshold(iou_threshold)
                .build()
                .unwrap();
            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
            let mut output_masks: Vec<_> = Vec::with_capacity(10);

            for _ in 0..100 {
                decoder
                    .decode_quantized(
                        &[
                            detect0.view().into(),
                            detect1.view().into(),
                            seg.view().into(),
                        ],
                        &mut output_boxes,
                        &mut output_masks,
                    )
                    .unwrap();

                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
            }
        };

        // Interleave the two workloads: four threads of each. Both closures
        // capture nothing, so they are `Copy` and can be spawned repeatedly.
        let handles = vec![
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
            std::thread::spawn(yolo_det),
            std::thread::spawn(modelpack_det_split),
        ];
        // `join` returns `Err` if the worker panicked (e.g. a failed assert);
        // unwrapping propagates that failure to the test.
        for handle in handles {
            handle.join().unwrap();
        }
    }
2604
2605    #[test]
2606    fn test_ndarray_to_xyxy_float() {
2607        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2608        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2609        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2610
2611        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2612        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2613        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2614    }
2615
2616    #[test]
2617    fn test_class_aware_nms_float() {
2618        use crate::float::nms_class_aware_float;
2619
2620        // Create two overlapping boxes with different classes
2621        let boxes = vec![
2622            DetectBox {
2623                bbox: BoundingBox {
2624                    xmin: 0.0,
2625                    ymin: 0.0,
2626                    xmax: 0.5,
2627                    ymax: 0.5,
2628                },
2629                score: 0.9,
2630                label: 0, // class 0
2631            },
2632            DetectBox {
2633                bbox: BoundingBox {
2634                    xmin: 0.1,
2635                    ymin: 0.1,
2636                    xmax: 0.6,
2637                    ymax: 0.6,
2638                },
2639                score: 0.8,
2640                label: 1, // class 1 - different class
2641            },
2642        ];
2643
2644        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2645        // threshold 0.3)
2646        let result = nms_class_aware_float(0.3, None, boxes.clone());
2647        assert_eq!(
2648            result.len(),
2649            2,
2650            "Class-aware NMS should keep both boxes with different classes"
2651        );
2652
2653        // Now test with same class - should suppress one
2654        let same_class_boxes = vec![
2655            DetectBox {
2656                bbox: BoundingBox {
2657                    xmin: 0.0,
2658                    ymin: 0.0,
2659                    xmax: 0.5,
2660                    ymax: 0.5,
2661                },
2662                score: 0.9,
2663                label: 0,
2664            },
2665            DetectBox {
2666                bbox: BoundingBox {
2667                    xmin: 0.1,
2668                    ymin: 0.1,
2669                    xmax: 0.6,
2670                    ymax: 0.6,
2671                },
2672                score: 0.8,
2673                label: 0, // same class
2674            },
2675        ];
2676
2677        let result = nms_class_aware_float(0.3, None, same_class_boxes);
2678        assert_eq!(
2679            result.len(),
2680            1,
2681            "Class-aware NMS should suppress overlapping box with same class"
2682        );
2683        assert_eq!(result[0].label, 0);
2684        assert!((result[0].score - 0.9).abs() < 1e-6);
2685    }
2686
2687    #[test]
2688    fn test_class_agnostic_vs_aware_nms() {
2689        use crate::float::{nms_class_aware_float, nms_float};
2690
2691        // Two overlapping boxes with different classes
2692        let boxes = vec![
2693            DetectBox {
2694                bbox: BoundingBox {
2695                    xmin: 0.0,
2696                    ymin: 0.0,
2697                    xmax: 0.5,
2698                    ymax: 0.5,
2699                },
2700                score: 0.9,
2701                label: 0,
2702            },
2703            DetectBox {
2704                bbox: BoundingBox {
2705                    xmin: 0.1,
2706                    ymin: 0.1,
2707                    xmax: 0.6,
2708                    ymax: 0.6,
2709                },
2710                score: 0.8,
2711                label: 1,
2712            },
2713        ];
2714
2715        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2716        let agnostic_result = nms_float(0.3, None, boxes.clone());
2717        assert_eq!(
2718            agnostic_result.len(),
2719            1,
2720            "Class-agnostic NMS should suppress overlapping boxes"
2721        );
2722
2723        // Class-aware should keep both (different classes)
2724        let aware_result = nms_class_aware_float(0.3, None, boxes);
2725        assert_eq!(
2726            aware_result.len(),
2727            2,
2728            "Class-aware NMS should keep boxes with different classes"
2729        );
2730    }
2731
2732    #[test]
2733    fn test_class_aware_nms_int() {
2734        use crate::byte::nms_class_aware_int;
2735
2736        // Create two overlapping boxes with different classes
2737        let boxes = vec![
2738            DetectBoxQuantized {
2739                bbox: BoundingBox {
2740                    xmin: 0.0,
2741                    ymin: 0.0,
2742                    xmax: 0.5,
2743                    ymax: 0.5,
2744                },
2745                score: 200_u8,
2746                label: 0,
2747            },
2748            DetectBoxQuantized {
2749                bbox: BoundingBox {
2750                    xmin: 0.1,
2751                    ymin: 0.1,
2752                    xmax: 0.6,
2753                    ymax: 0.6,
2754                },
2755                score: 180_u8,
2756                label: 1, // different class
2757            },
2758        ];
2759
2760        // Should keep both (different classes)
2761        let result = nms_class_aware_int(0.5, None, boxes);
2762        assert_eq!(
2763            result.len(),
2764            2,
2765            "Class-aware NMS (int) should keep boxes with different classes"
2766        );
2767    }
2768
2769    #[test]
2770    fn test_nms_enum_default() {
2771        // Test that Nms enum has the correct default
2772        let default_nms: configs::Nms = Default::default();
2773        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2774    }
2775
2776    #[test]
2777    fn test_decoder_nms_mode() {
2778        // Test that decoder properly stores NMS mode
2779        let decoder = DecoderBuilder::default()
2780            .with_config_yolo_det(
2781                configs::Detection {
2782                    anchors: None,
2783                    decoder: DecoderType::Ultralytics,
2784                    quantization: None,
2785                    shape: vec![1, 84, 8400],
2786                    dshape: Vec::new(),
2787                    normalized: Some(true),
2788                },
2789                None,
2790            )
2791            .with_nms(Some(configs::Nms::ClassAware))
2792            .build()
2793            .unwrap();
2794
2795        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2796    }
2797
2798    #[test]
2799    fn test_decoder_nms_bypass() {
2800        // Test that decoder can be configured with nms=None (bypass)
2801        let decoder = DecoderBuilder::default()
2802            .with_config_yolo_det(
2803                configs::Detection {
2804                    anchors: None,
2805                    decoder: DecoderType::Ultralytics,
2806                    quantization: None,
2807                    shape: vec![1, 84, 8400],
2808                    dshape: Vec::new(),
2809                    normalized: Some(true),
2810                },
2811                None,
2812            )
2813            .with_nms(None)
2814            .build()
2815            .unwrap();
2816
2817        assert_eq!(decoder.nms, None);
2818    }
2819
2820    #[test]
2821    fn test_decoder_normalized_boxes_true() {
2822        // Test that normalized_boxes returns Some(true) when explicitly set
2823        let decoder = DecoderBuilder::default()
2824            .with_config_yolo_det(
2825                configs::Detection {
2826                    anchors: None,
2827                    decoder: DecoderType::Ultralytics,
2828                    quantization: None,
2829                    shape: vec![1, 84, 8400],
2830                    dshape: Vec::new(),
2831                    normalized: Some(true),
2832                },
2833                None,
2834            )
2835            .build()
2836            .unwrap();
2837
2838        assert_eq!(decoder.normalized_boxes(), Some(true));
2839    }
2840
2841    #[test]
2842    fn test_decoder_normalized_boxes_false() {
2843        // Test that normalized_boxes returns Some(false) when config specifies
2844        // unnormalized
2845        let decoder = DecoderBuilder::default()
2846            .with_config_yolo_det(
2847                configs::Detection {
2848                    anchors: None,
2849                    decoder: DecoderType::Ultralytics,
2850                    quantization: None,
2851                    shape: vec![1, 84, 8400],
2852                    dshape: Vec::new(),
2853                    normalized: Some(false),
2854                },
2855                None,
2856            )
2857            .build()
2858            .unwrap();
2859
2860        assert_eq!(decoder.normalized_boxes(), Some(false));
2861    }
2862
2863    #[test]
2864    fn test_decoder_normalized_boxes_unknown() {
2865        // Test that normalized_boxes returns None when not specified in config
2866        let decoder = DecoderBuilder::default()
2867            .with_config_yolo_det(
2868                configs::Detection {
2869                    anchors: None,
2870                    decoder: DecoderType::Ultralytics,
2871                    quantization: None,
2872                    shape: vec![1, 84, 8400],
2873                    dshape: Vec::new(),
2874                    normalized: None,
2875                },
2876                Some(DecoderVersion::Yolo11),
2877            )
2878            .build()
2879            .unwrap();
2880
2881        assert_eq!(decoder.normalized_boxes(), None);
2882    }
2883
2884    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2885        input: ArrayView<F, D>,
2886        quant: Quantization,
2887    ) -> Array<T, D>
2888    where
2889        i32: num_traits::AsPrimitive<F>,
2890        f32: num_traits::AsPrimitive<F>,
2891    {
2892        let zero_point = quant.zero_point.as_();
2893        let div_scale = F::one() / quant.scale.as_();
2894        if zero_point != F::zero() {
2895            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2896        } else {
2897            input.mapv(|d| (d * div_scale).round().as_())
2898        }
2899    }
2900
2901    fn real_data_expected_boxes() -> [DetectBox; 2] {
2902        [
2903            DetectBox {
2904                bbox: BoundingBox {
2905                    xmin: 0.08515105,
2906                    ymin: 0.7131401,
2907                    xmax: 0.29802868,
2908                    ymax: 0.8195788,
2909                },
2910                score: 0.91537374,
2911                label: 23,
2912            },
2913            DetectBox {
2914                bbox: BoundingBox {
2915                    xmin: 0.59605736,
2916                    ymin: 0.25545314,
2917                    xmax: 0.93666154,
2918                    ymax: 0.72378385,
2919                },
2920                score: 0.91537374,
2921                label: 23,
2922            },
2923        ]
2924    }
2925
2926    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2927        [DetectBox {
2928            bbox: BoundingBox {
2929                xmin: 0.12549022,
2930                ymin: 0.12549022,
2931                xmax: 0.23529413,
2932                ymax: 0.23529413,
2933            },
2934            score: 0.98823535,
2935            label: 2,
2936        }]
2937    }
2938
2939    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2940        [DetectBox {
2941            bbox: BoundingBox {
2942                xmin: 0.1234,
2943                ymin: 0.1234,
2944                xmax: 0.2345,
2945                ymax: 0.2345,
2946            },
2947            score: 0.9876,
2948            label: 2,
2949        }]
2950    }
2951
    // Generates a `#[test]` that decodes the recorded YOLOv8 segmentation
    // fixture (boxes tensor [1, 116, 8400], protos tensor [1, 160, 160, 32])
    // through the "proto" decode path. It then compares the two resulting
    // boxes with `real_data_expected_boxes()`.
    //
    // Arms:
    // - `quantized`: feeds the raw int8 tensors to `decode_quantized_proto`.
    // - `float`: dequantizes both tensors first, then calls `decode_float_proto`.
    //
    // `$layout` is `combined` (one 116-feature tensor plus protos) or `split`
    // (box coords [..4], class scores [4..84], mask coefficients [84..],
    // plus protos).
    macro_rules! real_data_proto_test {
        ($name:ident, quantized, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: `i8` and `u8` have identical size and alignment, and
                // the pointer/length pair comes from a valid `'static` byte
                // array, so reinterpreting it as `&[i8]` is sound.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above (u8 -> i8 view
                // of a `'static` byte array).
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();

                // Pre-split copies of the boxes tensor. Only the `split` layout
                // uses them; they are built unconditionally for simplicity.
                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_i8;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                // Input order matches the decoder's output configs: boxes,
                // scores, mask coefficients, protos (split) or boxes, protos.
                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                    vec![
                        boxes_split.view().into(),
                        scores_split.view().into(),
                        mask_split.view().into(),
                        protos_i8.view().into(),
                    ]
                } else {
                    vec![boxes_combined.view().into(), protos_i8.view().into()]
                };
                decoder
                    .decode_quantized_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Tolerance of 1/160, presumably one cell of the 160x160 proto
                // grid.
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
        ($name:ident, float, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_boxes_116x8400.bin"
                ));
                // SAFETY: `i8` and `u8` have identical size and alignment, and
                // the pointer/length pair comes from a valid `'static` byte
                // array, so reinterpreting it as `&[i8]` is sound.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
                let boxes_f32: Array3<f32> =
                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

                let raw_protos = include_bytes!(concat!(
                    env!("CARGO_MANIFEST_DIR"),
                    "/../../testdata/yolov8_protos_160x160x32.bin"
                ));
                // SAFETY: same reasoning as for `raw_boxes` above (u8 -> i8 view
                // of a `'static` byte array).
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();
                let protos_f32: Array4<f32> =
                    dequantize_ndarray(protos_i8.view(), quant_protos.into());

                // Pre-split copies taken from the dequantized data; only the
                // `split` layout uses them.
                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_f32;

                // NOTE(review): the split decoder is still built with the
                // quantization parameters; presumably they are ignored on the
                // float decode path — TODO confirm.
                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                let inputs = if is_split {
                    vec![
                        boxes_split.view().into_dyn(),
                        scores_split.view().into_dyn(),
                        mask_split.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                } else {
                    vec![
                        boxes_combined.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                };
                decoder
                    .decode_float_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Same expectations and tolerance as the quantized arm.
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
    }

    // All four combinations of {quantized, float} x {combined, split}.
    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3106
    // YOLO26 end-to-end detection model with a single combined output of
    // shape [1, 10, 6]: 10 candidate boxes with 6 features each (presumably 4
    // box coords + score + class, mirroring the split config), quantized with
    // scale 2/255 and zero point 0, and normalized box coordinates.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3120
    // YOLO26 end-to-end segmentation model with two outputs. The first is a
    // combined detection tensor [1, 10, 38]: 10 boxes with 38 features each,
    // presumably the 6 detection features plus 32 mask coefficients (one per
    // proto). The second is a [1, 160, 160, 32] protos tensor quantized with
    // scale 1/255 and zero point 128.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3143
    // YOLO26 end-to-end detection model with separate outputs for boxes
    // [1, 10, 4], scores [1, 10, 1] and classes [1, 10, 1]. All three are
    // quantized with scale 2/255 and zero point 0.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3173
    // YOLO26 end-to-end segmentation model with fully split outputs:
    // - boxes [1, 10, 4]
    // - scores [1, 10, 1]
    // - classes [1, 10, 1]
    // - mask coefficients [1, 10, 32], all at scale 2/255 and zero point 0
    // - protos [1, 160, 160, 32], at scale 1/255 and zero point 128
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3220
3221    macro_rules! e2e_segdet_test {
3222        ($name:ident, quantized, $layout:ident, $output:ident) => {
3223            #[test]
3224            fn $name() {
3225                let is_split = matches!(stringify!($layout), "split");
3226                let is_proto = matches!(stringify!($output), "proto");
3227
3228                let score_threshold = 0.45;
3229                let iou_threshold = 0.45;
3230
3231                let mut boxes = Array2::zeros((10, 4));
3232                let mut scores = Array2::zeros((10, 1));
3233                let mut classes = Array2::zeros((10, 1));
3234                let mask = Array2::zeros((10, 32));
3235                let protos = Array3::<f64>::zeros((160, 160, 32));
3236                let protos = protos.insert_axis(Axis(0));
3237                let protos_quant = (1.0 / 255.0, 0.0);
3238                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3239
3240                boxes
3241                    .slice_mut(s![0, ..])
3242                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3243                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3244                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3245
3246                let detect_quant = (2.0 / 255.0, 0.0);
3247
3248                let decoder = if is_split {
3249                    DecoderBuilder::default()
3250                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3251                        .with_score_threshold(score_threshold)
3252                        .with_iou_threshold(iou_threshold)
3253                        .build()
3254                        .unwrap()
3255                } else {
3256                    DecoderBuilder::default()
3257                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3258                        .with_score_threshold(score_threshold)
3259                        .with_iou_threshold(iou_threshold)
3260                        .build()
3261                        .unwrap()
3262                };
3263
3264                let expected = e2e_expected_boxes_quant();
3265                let mut output_boxes = Vec::with_capacity(50);
3266
3267                if is_split {
3268                    let boxes = boxes.insert_axis(Axis(0));
3269                    let scores = scores.insert_axis(Axis(0));
3270                    let classes = classes.insert_axis(Axis(0));
3271                    let mask = mask.insert_axis(Axis(0));
3272
3273                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3274                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3275                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3276                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3277
3278                    if is_proto {
3279                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3280                            boxes.view().into(),
3281                            scores.view().into(),
3282                            classes.view().into(),
3283                            mask.view().into(),
3284                            protos.view().into(),
3285                        ];
3286                        decoder
3287                            .decode_quantized_proto(&inputs, &mut output_boxes)
3288                            .unwrap();
3289
3290                        assert_eq!(output_boxes.len(), 1);
3291                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3292                    } else {
3293                        let mut output_masks = Vec::with_capacity(50);
3294                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3295                            boxes.view().into(),
3296                            scores.view().into(),
3297                            classes.view().into(),
3298                            mask.view().into(),
3299                            protos.view().into(),
3300                        ];
3301                        decoder
3302                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3303                            .unwrap();
3304
3305                        assert_eq!(output_boxes.len(), 1);
3306                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3307                    }
3308                } else {
3309                    // Combined layout
3310                    let detect = ndarray::concatenate![
3311                        Axis(1),
3312                        boxes.view(),
3313                        scores.view(),
3314                        classes.view(),
3315                        mask.view()
3316                    ];
3317                    let detect = detect.insert_axis(Axis(0));
3318                    assert_eq!(detect.shape(), &[1, 10, 38]);
3319                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3320
3321                    if is_proto {
3322                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3323                            vec![detect.view().into(), protos.view().into()];
3324                        decoder
3325                            .decode_quantized_proto(&inputs, &mut output_boxes)
3326                            .unwrap();
3327
3328                        assert_eq!(output_boxes.len(), 1);
3329                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3330                    } else {
3331                        let mut output_masks = Vec::with_capacity(50);
3332                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3333                            vec![detect.view().into(), protos.view().into()];
3334                        decoder
3335                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3336                            .unwrap();
3337
3338                        assert_eq!(output_boxes.len(), 1);
3339                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3340                    }
3341                }
3342            }
3343        };
3344        ($name:ident, float, $layout:ident, $output:ident) => {
3345            #[test]
3346            fn $name() {
3347                let is_split = matches!(stringify!($layout), "split");
3348                let is_proto = matches!(stringify!($output), "proto");
3349
3350                let score_threshold = 0.45;
3351                let iou_threshold = 0.45;
3352
3353                let mut boxes = Array2::zeros((10, 4));
3354                let mut scores = Array2::zeros((10, 1));
3355                let mut classes = Array2::zeros((10, 1));
3356                let mask: Array2<f64> = Array2::zeros((10, 32));
3357                let protos = Array3::<f64>::zeros((160, 160, 32));
3358                let protos = protos.insert_axis(Axis(0));
3359
3360                boxes
3361                    .slice_mut(s![0, ..])
3362                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3363                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3364                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3365
3366                let decoder = if is_split {
3367                    DecoderBuilder::default()
3368                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3369                        .with_score_threshold(score_threshold)
3370                        .with_iou_threshold(iou_threshold)
3371                        .build()
3372                        .unwrap()
3373                } else {
3374                    DecoderBuilder::default()
3375                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3376                        .with_score_threshold(score_threshold)
3377                        .with_iou_threshold(iou_threshold)
3378                        .build()
3379                        .unwrap()
3380                };
3381
3382                let expected = e2e_expected_boxes_float();
3383                let mut output_boxes = Vec::with_capacity(50);
3384
3385                if is_split {
3386                    let boxes = boxes.insert_axis(Axis(0));
3387                    let scores = scores.insert_axis(Axis(0));
3388                    let classes = classes.insert_axis(Axis(0));
3389                    let mask = mask.insert_axis(Axis(0));
3390
3391                    if is_proto {
3392                        let inputs = vec![
3393                            boxes.view().into_dyn(),
3394                            scores.view().into_dyn(),
3395                            classes.view().into_dyn(),
3396                            mask.view().into_dyn(),
3397                            protos.view().into_dyn(),
3398                        ];
3399                        decoder
3400                            .decode_float_proto(&inputs, &mut output_boxes)
3401                            .unwrap();
3402
3403                        assert_eq!(output_boxes.len(), 1);
3404                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3405                    } else {
3406                        let mut output_masks = Vec::with_capacity(50);
3407                        let inputs = vec![
3408                            boxes.view().into_dyn(),
3409                            scores.view().into_dyn(),
3410                            classes.view().into_dyn(),
3411                            mask.view().into_dyn(),
3412                            protos.view().into_dyn(),
3413                        ];
3414                        decoder
3415                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3416                            .unwrap();
3417
3418                        assert_eq!(output_boxes.len(), 1);
3419                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3420                    }
3421                } else {
3422                    // Combined layout
3423                    let detect = ndarray::concatenate![
3424                        Axis(1),
3425                        boxes.view(),
3426                        scores.view(),
3427                        classes.view(),
3428                        mask.view()
3429                    ];
3430                    let detect = detect.insert_axis(Axis(0));
3431                    assert_eq!(detect.shape(), &[1, 10, 38]);
3432
3433                    if is_proto {
3434                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3435                        decoder
3436                            .decode_float_proto(&inputs, &mut output_boxes)
3437                            .unwrap();
3438
3439                        assert_eq!(output_boxes.len(), 1);
3440                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3441                    } else {
3442                        let mut output_masks = Vec::with_capacity(50);
3443                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3444                        decoder
3445                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3446                            .unwrap();
3447
3448                        assert_eq!(output_boxes.len(), 1);
3449                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3450                    }
3451                }
3452            }
3453        };
3454    }
3455
3456    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
3457    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
3458    e2e_segdet_test!(
3459        test_decoder_end_to_end_segdet_proto,
3460        quantized,
3461        combined,
3462        proto
3463    );
3464    e2e_segdet_test!(
3465        test_decoder_end_to_end_segdet_proto_float,
3466        float,
3467        combined,
3468        proto
3469    );
3470    e2e_segdet_test!(
3471        test_decoder_end_to_end_segdet_split,
3472        quantized,
3473        split,
3474        masks
3475    );
3476    e2e_segdet_test!(
3477        test_decoder_end_to_end_segdet_split_float,
3478        float,
3479        split,
3480        masks
3481    );
3482    e2e_segdet_test!(
3483        test_decoder_end_to_end_segdet_split_proto,
3484        quantized,
3485        split,
3486        proto
3487    );
3488    e2e_segdet_test!(
3489        test_decoder_end_to_end_segdet_split_proto_float,
3490        float,
3491        split,
3492        proto
3493    );
3494
3495    macro_rules! e2e_det_test {
3496        ($name:ident, quantized, $layout:ident) => {
3497            #[test]
3498            fn $name() {
3499                let is_split = matches!(stringify!($layout), "split");
3500
3501                let score_threshold = 0.45;
3502                let iou_threshold = 0.45;
3503
3504                let mut boxes = Array3::zeros((1, 10, 4));
3505                let mut scores = Array3::zeros((1, 10, 1));
3506                let mut classes = Array3::zeros((1, 10, 1));
3507
3508                boxes
3509                    .slice_mut(s![0, 0, ..])
3510                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3511                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3512                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3513
3514                let detect_quant = (2.0 / 255.0, 0_i32);
3515
3516                let decoder = if is_split {
3517                    DecoderBuilder::default()
3518                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3519                        .with_score_threshold(score_threshold)
3520                        .with_iou_threshold(iou_threshold)
3521                        .build()
3522                        .unwrap()
3523                } else {
3524                    DecoderBuilder::default()
3525                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3526                        .with_score_threshold(score_threshold)
3527                        .with_iou_threshold(iou_threshold)
3528                        .build()
3529                        .unwrap()
3530                };
3531
3532                let expected = e2e_expected_boxes_quant();
3533                let mut output_boxes = Vec::with_capacity(50);
3534
3535                if is_split {
3536                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3537                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3538                    let classes: Array<u8, _> =
3539                        quantize_ndarray(classes.view(), detect_quant.into());
3540                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3541                        boxes.view().into(),
3542                        scores.view().into(),
3543                        classes.view().into(),
3544                    ];
3545                    decoder
3546                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3547                        .unwrap();
3548                } else {
3549                    let detect =
3550                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3551                    assert_eq!(detect.shape(), &[1, 10, 6]);
3552                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3553                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3554                        vec![detect.view().into()];
3555                    decoder
3556                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3557                        .unwrap();
3558                }
3559
3560                assert_eq!(output_boxes.len(), 1);
3561                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3562            }
3563        };
3564        ($name:ident, float, $layout:ident) => {
3565            #[test]
3566            fn $name() {
3567                let is_split = matches!(stringify!($layout), "split");
3568
3569                let score_threshold = 0.45;
3570                let iou_threshold = 0.45;
3571
3572                let mut boxes = Array3::zeros((1, 10, 4));
3573                let mut scores = Array3::zeros((1, 10, 1));
3574                let mut classes = Array3::zeros((1, 10, 1));
3575
3576                boxes
3577                    .slice_mut(s![0, 0, ..])
3578                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3579                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3580                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3581
3582                let decoder = if is_split {
3583                    DecoderBuilder::default()
3584                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3585                        .with_score_threshold(score_threshold)
3586                        .with_iou_threshold(iou_threshold)
3587                        .build()
3588                        .unwrap()
3589                } else {
3590                    DecoderBuilder::default()
3591                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3592                        .with_score_threshold(score_threshold)
3593                        .with_iou_threshold(iou_threshold)
3594                        .build()
3595                        .unwrap()
3596                };
3597
3598                let expected = e2e_expected_boxes_float();
3599                let mut output_boxes = Vec::with_capacity(50);
3600
3601                if is_split {
3602                    let inputs = vec![
3603                        boxes.view().into_dyn(),
3604                        scores.view().into_dyn(),
3605                        classes.view().into_dyn(),
3606                    ];
3607                    decoder
3608                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3609                        .unwrap();
3610                } else {
3611                    let detect =
3612                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3613                    assert_eq!(detect.shape(), &[1, 10, 6]);
3614                    let inputs = vec![detect.view().into_dyn()];
3615                    decoder
3616                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3617                        .unwrap();
3618                }
3619
3620                assert_eq!(output_boxes.len(), 1);
3621                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3622            }
3623        };
3624    }
3625
3626    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
3627    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
3628    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
3629    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3630
3631    #[test]
3632    fn test_decode_tensor() {
3633        let score_threshold = 0.45;
3634        let iou_threshold = 0.45;
3635
3636        let raw_boxes = include_bytes!(concat!(
3637            env!("CARGO_MANIFEST_DIR"),
3638            "/../../testdata/yolov8_boxes_116x8400.bin"
3639        ));
3640        let raw_boxes =
3641            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3642        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3643        boxes_i8
3644            .map()
3645            .unwrap()
3646            .as_mut_slice()
3647            .copy_from_slice(raw_boxes);
3648        let boxes_i8 = boxes_i8.into();
3649
3650        let raw_protos = include_bytes!(concat!(
3651            env!("CARGO_MANIFEST_DIR"),
3652            "/../../testdata/yolov8_protos_160x160x32.bin"
3653        ));
3654        let raw_protos = unsafe {
3655            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3656        };
3657        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3658        protos_i8
3659            .map()
3660            .unwrap()
3661            .as_mut_slice()
3662            .copy_from_slice(raw_protos);
3663        let protos_i8 = protos_i8.into();
3664
3665        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3666        let expected = real_data_expected_boxes();
3667        let mut output_boxes = Vec::with_capacity(50);
3668
3669        decoder
3670            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3671            .unwrap();
3672
3673        assert_eq!(output_boxes.len(), 2);
3674        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3675        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3676    }
3677
3678    #[test]
3679    fn test_decode_tensor_f32() {
3680        let score_threshold = 0.45;
3681        let iou_threshold = 0.45;
3682
3683        let quant_boxes = (0.021287762_f32, 31_i32);
3684        let quant_protos = (0.02491162_f32, -117_i32);
3685        let raw_boxes = include_bytes!(concat!(
3686            env!("CARGO_MANIFEST_DIR"),
3687            "/../../testdata/yolov8_boxes_116x8400.bin"
3688        ));
3689        let raw_boxes =
3690            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3691        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3692        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3693        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3694        boxes_f32
3695            .map()
3696            .unwrap()
3697            .as_mut_slice()
3698            .copy_from_slice(&raw_boxes_f32);
3699        let boxes_f32 = boxes_f32.into();
3700
3701        let raw_protos = include_bytes!(concat!(
3702            env!("CARGO_MANIFEST_DIR"),
3703            "/../../testdata/yolov8_protos_160x160x32.bin"
3704        ));
3705        let raw_protos = unsafe {
3706            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3707        };
3708        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3709        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3710        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3711        protos_f32
3712            .map()
3713            .unwrap()
3714            .as_mut_slice()
3715            .copy_from_slice(&raw_protos_f32);
3716        let protos_f32 = protos_f32.into();
3717
3718        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3719
3720        let expected = real_data_expected_boxes();
3721        let mut output_boxes = Vec::with_capacity(50);
3722
3723        decoder
3724            .decode(
3725                &[&boxes_f32, &protos_f32],
3726                &mut output_boxes,
3727                &mut Vec::new(),
3728            )
3729            .unwrap();
3730
3731        assert_eq!(output_boxes.len(), 2);
3732        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3733        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3734    }
3735
3736    #[test]
3737    fn test_decode_tensor_f64() {
3738        let score_threshold = 0.45;
3739        let iou_threshold = 0.45;
3740
3741        let quant_boxes = (0.021287762_f32, 31_i32);
3742        let quant_protos = (0.02491162_f32, -117_i32);
3743        let raw_boxes = include_bytes!(concat!(
3744            env!("CARGO_MANIFEST_DIR"),
3745            "/../../testdata/yolov8_boxes_116x8400.bin"
3746        ));
3747        let raw_boxes =
3748            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3749        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3750        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3751        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3752        boxes_f64
3753            .map()
3754            .unwrap()
3755            .as_mut_slice()
3756            .copy_from_slice(&raw_boxes_f64);
3757        let boxes_f64 = boxes_f64.into();
3758
3759        let raw_protos = include_bytes!(concat!(
3760            env!("CARGO_MANIFEST_DIR"),
3761            "/../../testdata/yolov8_protos_160x160x32.bin"
3762        ));
3763        let raw_protos = unsafe {
3764            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3765        };
3766        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3767        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3768        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3769        protos_f64
3770            .map()
3771            .unwrap()
3772            .as_mut_slice()
3773            .copy_from_slice(&raw_protos_f64);
3774        let protos_f64 = protos_f64.into();
3775
3776        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3777
3778        let expected = real_data_expected_boxes();
3779        let mut output_boxes = Vec::with_capacity(50);
3780
3781        decoder
3782            .decode(
3783                &[&boxes_f64, &protos_f64],
3784                &mut output_boxes,
3785                &mut Vec::new(),
3786            )
3787            .unwrap();
3788
3789        assert_eq!(output_boxes.len(), 2);
3790        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3791        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3792    }
3793
3794    #[test]
3795    fn test_decode_tensor_proto() {
3796        let score_threshold = 0.45;
3797        let iou_threshold = 0.45;
3798
3799        let raw_boxes = include_bytes!(concat!(
3800            env!("CARGO_MANIFEST_DIR"),
3801            "/../../testdata/yolov8_boxes_116x8400.bin"
3802        ));
3803        let raw_boxes =
3804            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3805        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3806        boxes_i8
3807            .map()
3808            .unwrap()
3809            .as_mut_slice()
3810            .copy_from_slice(raw_boxes);
3811        let boxes_i8 = boxes_i8.into();
3812
3813        let raw_protos = include_bytes!(concat!(
3814            env!("CARGO_MANIFEST_DIR"),
3815            "/../../testdata/yolov8_protos_160x160x32.bin"
3816        ));
3817        let raw_protos = unsafe {
3818            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3819        };
3820        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3821        protos_i8
3822            .map()
3823            .unwrap()
3824            .as_mut_slice()
3825            .copy_from_slice(raw_protos);
3826        let protos_i8 = protos_i8.into();
3827
3828        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3829
3830        let expected = real_data_expected_boxes();
3831        let mut output_boxes = Vec::with_capacity(50);
3832
3833        let proto_data = decoder
3834            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3835            .unwrap();
3836
3837        assert_eq!(output_boxes.len(), 2);
3838        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3839        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3840
3841        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3842        let coeffs_shape = proto_data.mask_coefficients.shape();
3843        assert_eq!(
3844            coeffs_shape[0],
3845            output_boxes.len(),
3846            "mask_coefficients count must match detection count"
3847        );
3848        assert_eq!(
3849            coeffs_shape[1], 32,
3850            "each detection should have 32 mask coefficients"
3851        );
3852    }
3853
3854    // =========================================================================
3855    // Physical-order contract regression tests
3856    //
3857    // These cover the TFLite VX-delegate vertical-stripe mask bug and the
3858    // Ara-2-via-overlay anchor-first split-tensor bug. The core claim:
3859    // declaring shape+dshape in physical memory order produces correct
3860    // strides at wrap time, and `swap_axes_if_needed` permutes the stride
3861    // tuple into canonical logical order without touching bytes.
3862    // =========================================================================
3863
3864    /// TFLite YOLOv8-seg protos arrive as physically-NHWC bytes with
3865    /// NNStreamer dim string `"32:160:160:1"` (innermost-first). The
3866    /// overlay's 2026-04-22 workaround declares `[1, 160, 160, 32]` with
3867    /// dshape `[batch, height, width, num_protos]` — shape matches
3868    /// physical order, so no permutation is needed and the view is
3869    /// already in canonical HWC after the batch-dim slice. This test
3870    /// confirms that path matches the reference (direct-HWC) decode.
3871    #[test]
3872    fn test_physical_order_tflite_nhwc_protos() {
3873        let score_threshold = 0.45;
3874        let iou_threshold = 0.45;
3875
3876        // Load HWC protos and dequantize to f32.
3877        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
3878        let quant_protos = Quantization::new(0.02491161972284317, -117);
3879        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
3880
3881        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
3882        let quant_boxes = Quantization::new(0.021287761628627777, 31);
3883        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
3884
3885        // Reference: direct call with canonical HWC protos.
3886        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
3887        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
3888        decode_yolo_segdet_float(
3889            seg.view(),
3890            protos_f32_hwc.view(),
3891            score_threshold,
3892            iou_threshold,
3893            Some(configs::Nms::ClassAgnostic),
3894            &mut ref_boxes,
3895            &mut ref_masks,
3896        )
3897        .unwrap();
3898
3899        // Build the same protos as a 4D NHWC tensor and feed it through
3900        // the config-driven path with dshape in physical order.
3901        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0)); // [1, 160, 160, 32]
3902        let seg_3d = seg.insert_axis(Axis(0)); // [1, 116, 8400]
3903
3904        let decoder = DecoderBuilder::default()
3905            .with_config_yolo_segdet(
3906                configs::Detection {
3907                    decoder: configs::DecoderType::Ultralytics,
3908                    quantization: None,
3909                    shape: vec![1, 116, 8400],
3910                    dshape: vec![
3911                        (DimName::Batch, 1),
3912                        (DimName::NumFeatures, 116),
3913                        (DimName::NumBoxes, 8400),
3914                    ],
3915                    normalized: Some(true),
3916                    anchors: None,
3917                },
3918                configs::Protos {
3919                    decoder: configs::DecoderType::Ultralytics,
3920                    quantization: None,
3921                    shape: vec![1, 160, 160, 32],
3922                    // Physical NHWC — matches the TFLite buffer layout.
3923                    dshape: vec![
3924                        (DimName::Batch, 1),
3925                        (DimName::Height, 160),
3926                        (DimName::Width, 160),
3927                        (DimName::NumProtos, 32),
3928                    ],
3929                },
3930                None,
3931            )
3932            .with_score_threshold(score_threshold)
3933            .with_iou_threshold(iou_threshold)
3934            .build()
3935            .expect("config with NHWC protos dshape must build");
3936
3937        let mut cfg_boxes = Vec::with_capacity(10);
3938        let mut cfg_masks = Vec::with_capacity(10);
3939        decoder
3940            .decode_float(
3941                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
3942                &mut cfg_boxes,
3943                &mut cfg_masks,
3944            )
3945            .unwrap();
3946
3947        assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
3948        for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
3949            assert!(
3950                c.equal_within_delta(r, 0.01),
3951                "NHWC-declared box does not match reference: {c:?} vs {r:?}"
3952            );
3953        }
3954        for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
3955            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
3956            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
3957            assert_eq!(
3958                cm_arr, rm_arr,
3959                "NHWC-declared mask must match reference pixel-for-pixel"
3960            );
3961        }
3962    }
3963
3964    /// Ara-2 split-tensor boxes arrive physically anchor-first: NNStreamer
3965    /// dim string `"4:1:8400:1"` means innermost axis is 4 (coords), next
3966    /// is 1 (padding), next is 8400 (anchors), outermost is 1 (batch).
3967    /// The correct physical-order declaration is `[1, 8400, 1, 4]` with
3968    /// dshape `[batch, num_boxes, padding, box_coords]`. This test
3969    /// confirms that declaration produces the same decoded boxes as the
3970    /// canonical features-first TFLite declaration over the same logical
3971    /// data.
3972    #[test]
3973    fn test_physical_order_ara2_anchor_first_split_boxes() {
3974        use configs::{Boxes, Scores};
3975
3976        // Build synthetic boxes data in features-first canonical layout:
3977        // shape [1, 4, 8400] with one detection at anchor 42.
3978        const N: usize = 8400;
3979        let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
3980        let target_anchor = 42usize;
3981        boxes_canonical[[0, 0, target_anchor]] = 0.4; // xc
3982        boxes_canonical[[0, 1, target_anchor]] = 0.5; // yc
3983        boxes_canonical[[0, 2, target_anchor]] = 0.2; // w
3984        boxes_canonical[[0, 3, target_anchor]] = 0.2; // h
3985
3986        // Scores: one class at 0.9 for the same anchor.
3987        let mut scores_canonical = Array3::<f32>::zeros((1, 80, N));
3988        scores_canonical[[0, 0, target_anchor]] = 0.9;
3989
3990        // Reference: canonical (features-first) shape declaration.
3991        let ref_decoder = DecoderBuilder::default()
3992            .with_config_yolo_split_det(
3993                Boxes {
3994                    decoder: configs::DecoderType::Ultralytics,
3995                    quantization: None,
3996                    shape: vec![1, 4, N],
3997                    dshape: vec![
3998                        (DimName::Batch, 1),
3999                        (DimName::BoxCoords, 4),
4000                        (DimName::NumBoxes, N),
4001                    ],
4002                    normalized: Some(true),
4003                },
4004                Scores {
4005                    decoder: configs::DecoderType::Ultralytics,
4006                    quantization: None,
4007                    shape: vec![1, 80, N],
4008                    dshape: vec![
4009                        (DimName::Batch, 1),
4010                        (DimName::NumClasses, 80),
4011                        (DimName::NumBoxes, N),
4012                    ],
4013                },
4014            )
4015            .with_score_threshold(0.5)
4016            .with_iou_threshold(0.5)
4017            .with_nms(Some(configs::Nms::ClassAgnostic))
4018            .build()
4019            .expect("reference canonical split decoder must build");
4020
4021        let mut ref_boxes = Vec::with_capacity(4);
4022        let mut ref_masks = Vec::with_capacity(0);
4023        ref_decoder
4024            .decode_float(
4025                &[
4026                    boxes_canonical.view().into_dyn(),
4027                    scores_canonical.view().into_dyn(),
4028                ],
4029                &mut ref_boxes,
4030                &mut ref_masks,
4031            )
4032            .unwrap();
4033        assert_eq!(ref_boxes.len(), 1, "reference should produce one box");
4034
4035        // Ara-2 physical layout: transpose axes 1↔2 so the innermost
4036        // axis is BoxCoords / NumClasses. Materialise to get a
4037        // C-contiguous Array3 with strides matching the physical order
4038        // the Ara-2 backend would write into the tensor.
4039        let boxes_ara2 = boxes_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 4]
4040        let scores_ara2 = scores_canonical.view().permuted_axes([0, 2, 1]).to_owned(); // [1, 8400, 80]
4041
4042        let ara2_decoder = DecoderBuilder::default()
4043            .with_config_yolo_split_det(
4044                Boxes {
4045                    decoder: configs::DecoderType::Ultralytics,
4046                    quantization: None,
4047                    shape: vec![1, N, 4],
4048                    dshape: vec![
4049                        (DimName::Batch, 1),
4050                        (DimName::NumBoxes, N),
4051                        (DimName::BoxCoords, 4),
4052                    ],
4053                    normalized: Some(true),
4054                },
4055                Scores {
4056                    decoder: configs::DecoderType::Ultralytics,
4057                    quantization: None,
4058                    shape: vec![1, N, 80],
4059                    dshape: vec![
4060                        (DimName::Batch, 1),
4061                        (DimName::NumBoxes, N),
4062                        (DimName::NumClasses, 80),
4063                    ],
4064                },
4065            )
4066            .with_score_threshold(0.5)
4067            .with_iou_threshold(0.5)
4068            .with_nms(Some(configs::Nms::ClassAgnostic))
4069            .build()
4070            .expect("Ara-2 anchor-first decoder must build");
4071
4072        let mut ara2_boxes = Vec::with_capacity(4);
4073        let mut ara2_masks = Vec::with_capacity(0);
4074        ara2_decoder
4075            .decode_float(
4076                &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
4077                &mut ara2_boxes,
4078                &mut ara2_masks,
4079            )
4080            .unwrap();
4081
4082        assert_eq!(
4083            ara2_boxes.len(),
4084            ref_boxes.len(),
4085            "Ara-2 anchor-first declaration must produce the same number \
4086             of boxes as the canonical features-first reference"
4087        );
4088        for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
4089            assert!(
4090                a.equal_within_delta(r, 1e-4),
4091                "Ara-2 box differs from reference: {a:?} vs {r:?}"
4092            );
4093        }
4094    }
4095
4096    /// Dshape whose per-axis sizes disagree with shape (the exact
4097    /// failure mode that caused the TFLite stripe bug — shape declared
4098    /// in one order, dshape in another) is rejected with a clear error.
4099    #[test]
4100    fn test_physical_order_rejects_shape_dshape_mismatch() {
4101        let result = DecoderBuilder::default()
4102            .with_config_yolo_segdet(
4103                configs::Detection {
4104                    decoder: configs::DecoderType::Ultralytics,
4105                    quantization: None,
4106                    shape: vec![1, 116, 8400],
4107                    dshape: vec![
4108                        (DimName::Batch, 1),
4109                        (DimName::NumFeatures, 116),
4110                        (DimName::NumBoxes, 8400),
4111                    ],
4112                    normalized: Some(true),
4113                    anchors: None,
4114                },
4115                configs::Protos {
4116                    decoder: configs::DecoderType::Ultralytics,
4117                    quantization: None,
4118                    // Shape in NCHW order...
4119                    shape: vec![1, 32, 160, 160],
4120                    // ...but dshape in NHWC order — the sizes don't line
4121                    // up positionally (dshape[1]=160 vs shape[1]=32).
4122                    dshape: vec![
4123                        (DimName::Batch, 1),
4124                        (DimName::Height, 160),
4125                        (DimName::Width, 160),
4126                        (DimName::NumProtos, 32),
4127                    ],
4128                },
4129                None,
4130            )
4131            .build();
4132
4133        match result {
4134            Err(DecoderError::InvalidConfig(msg)) => {
4135                assert!(
4136                    msg.contains("does not match shape"),
4137                    "expected shape/dshape size mismatch error, got: {msg}"
4138                );
4139            }
4140            other => panic!("expected InvalidConfig, got {other:?}"),
4141        }
4142    }
4143
4144    /// Duplicate dim name in dshape is rejected (two axes mapping to the
4145    /// same canonical slot would break the permutation).
4146    #[test]
4147    fn test_physical_order_rejects_duplicate_dshape_axis() {
4148        let result = DecoderBuilder::default()
4149            .with_config_yolo_split_det(
4150                configs::Boxes {
4151                    decoder: configs::DecoderType::Ultralytics,
4152                    quantization: None,
4153                    shape: vec![1, 4, 8400],
4154                    dshape: vec![
4155                        (DimName::Batch, 1),
4156                        (DimName::BoxCoords, 4),
4157                        (DimName::BoxCoords, 4), // duplicate (size matches shape[2]=8400? no)
4158                    ],
4159                    normalized: Some(true),
4160                },
4161                configs::Scores {
4162                    decoder: configs::DecoderType::Ultralytics,
4163                    quantization: None,
4164                    shape: vec![1, 80, 8400],
4165                    dshape: vec![
4166                        (DimName::Batch, 1),
4167                        (DimName::NumClasses, 80),
4168                        (DimName::NumBoxes, 8400),
4169                    ],
4170                },
4171            )
4172            .build();
4173
4174        // Shape[2] = 8400 ≠ dshape[2] size 4, so the positional-mismatch
4175        // check fires first — which is fine: any mis-shaped dshape
4176        // should be rejected. Check for either error as long as it's a
4177        // clear InvalidConfig.
4178        match result {
4179            Err(DecoderError::InvalidConfig(msg)) => {
4180                assert!(
4181                    msg.contains("appears at both index") || msg.contains("does not match shape"),
4182                    "expected positional or duplicate-axis error, got: {msg}"
4183                );
4184            }
4185            other => panic!("expected InvalidConfig, got {other:?}"),
4186        }
4187
4188        // Separate case: size-consistent duplicate to exercise the
4189        // duplicate-axis code path explicitly. Use `Batch` (no size
4190        // constraint, can repeat with size 1 against a shape that
4191        // carries two singleton dims).
4192        let result = DecoderBuilder::default()
4193            .with_config_yolo_split_det(
4194                configs::Boxes {
4195                    decoder: configs::DecoderType::Ultralytics,
4196                    quantization: None,
4197                    shape: vec![1, 1, 4, 8400],
4198                    dshape: vec![
4199                        (DimName::Batch, 1),
4200                        (DimName::Batch, 1), // duplicate, sizes match
4201                        (DimName::BoxCoords, 4),
4202                        (DimName::NumBoxes, 8400),
4203                    ],
4204                    normalized: Some(true),
4205                },
4206                configs::Scores {
4207                    decoder: configs::DecoderType::Ultralytics,
4208                    quantization: None,
4209                    shape: vec![1, 80, 8400],
4210                    dshape: vec![
4211                        (DimName::Batch, 1),
4212                        (DimName::NumClasses, 80),
4213                        (DimName::NumBoxes, 8400),
4214                    ],
4215                },
4216            )
4217            .build();
4218        match result {
4219            Err(DecoderError::InvalidConfig(msg)) => {
4220                assert!(
4221                    msg.contains("appears at both index"),
4222                    "expected duplicate-axis error, got: {msg}"
4223                );
4224            }
4225            other => panic!("expected InvalidConfig, got {other:?}"),
4226        }
4227    }
4228
4229    /// Canonical (dshape-omitted) high-level decode produces the same
4230    /// numeric result as the dshape-populated form. This closes a
4231    /// coverage gap flagged during review: dshape omission had been
4232    /// exercised only at the builder level, not through the full
4233    /// `decode_float` pipeline.
4234    #[test]
4235    fn test_physical_order_dshape_omitted_decodes_numerically() {
4236        let score_threshold = 0.45;
4237        let iou_threshold = 0.45;
4238
4239        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
4240        let quant_protos = Quantization::new(0.02491161972284317, -117);
4241        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
4242
4243        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
4244        let quant_boxes = Quantization::new(0.021287761628627777, 31);
4245        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
4246
4247        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0));
4248        let seg_3d = seg.insert_axis(Axis(0));
4249
4250        let build_decoder = |det_dshape: Vec<(DimName, usize)>,
4251                             proto_dshape: Vec<(DimName, usize)>| {
4252            DecoderBuilder::default()
4253                .with_config_yolo_segdet(
4254                    configs::Detection {
4255                        decoder: configs::DecoderType::Ultralytics,
4256                        quantization: None,
4257                        shape: vec![1, 116, 8400],
4258                        dshape: det_dshape,
4259                        normalized: Some(true),
4260                        anchors: None,
4261                    },
4262                    configs::Protos {
4263                        decoder: configs::DecoderType::Ultralytics,
4264                        quantization: None,
4265                        shape: vec![1, 160, 160, 32],
4266                        dshape: proto_dshape,
4267                    },
4268                    None,
4269                )
4270                .with_score_threshold(score_threshold)
4271                .with_iou_threshold(iou_threshold)
4272                .build()
4273                .unwrap()
4274        };
4275
4276        // Baseline: dshape populated in physical order.
4277        let dshaped = build_decoder(
4278            vec![
4279                (DimName::Batch, 1),
4280                (DimName::NumFeatures, 116),
4281                (DimName::NumBoxes, 8400),
4282            ],
4283            vec![
4284                (DimName::Batch, 1),
4285                (DimName::Height, 160),
4286                (DimName::Width, 160),
4287                (DimName::NumProtos, 32),
4288            ],
4289        );
4290        let mut dshaped_boxes = Vec::new();
4291        let mut dshaped_masks = Vec::new();
4292        dshaped
4293            .decode_float(
4294                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4295                &mut dshaped_boxes,
4296                &mut dshaped_masks,
4297            )
4298            .unwrap();
4299
4300        // Variant: dshape omitted — caller asserts shape is already in
4301        // the decoder's canonical order.
4302        let bare = build_decoder(vec![], vec![]);
4303        let mut bare_boxes = Vec::new();
4304        let mut bare_masks = Vec::new();
4305        bare.decode_float(
4306            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
4307            &mut bare_boxes,
4308            &mut bare_masks,
4309        )
4310        .unwrap();
4311
4312        assert_eq!(bare_boxes.len(), dshaped_boxes.len());
4313        for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
4314            assert!(
4315                b.equal_within_delta(d, 1e-4),
4316                "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
4317            );
4318        }
4319        for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
4320            let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
4321            let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
4322            assert_eq!(
4323                bm_arr, dm_arr,
4324                "dshape-omitted mask must match dshape-populated pixel-for-pixel"
4325            );
4326        }
4327    }
4328
4329    /// 4D anchor-first boxes schema with a trailing `padding` axis — the
4330    /// exact Ara-2 DVM shape (`"4:1:8400:1"` innermost-first = physical
4331    /// `[1, 8400, 1, 4]`). Exercises the schema path's
4332    /// `squeeze_padding_dims`: the caller declares a 4D physical-order
4333    /// layout including `padding`, HAL squeezes the size-1 axis during
4334    /// `to_legacy_config_outputs`, and the decoder sees the resulting
4335    /// 3D shape. Callers feed HAL a 3D squeezed view of the same bytes.
4336    /// Complements the 3D `test_physical_order_ara2_anchor_first_split_boxes`
4337    /// which exercises the programmatic-builder path.
4338    #[test]
4339    fn test_physical_order_ara2_4d_anchor_first_with_padding() {
4340        // Build synthetic data in anchor-first 3D layout (the shape HAL
4341        // actually operates on after squeezing). The 4D-with-padding
4342        // declaration in the schema is a producer-side convention; the
4343        // caller is expected to present HAL with a squeezed view.
4344        const N: usize = 8400;
4345        let mut boxes = Array3::<f32>::zeros((1, N, 4));
4346        let target = 42usize;
4347        boxes[[0, target, 0]] = 0.4;
4348        boxes[[0, target, 1]] = 0.5;
4349        boxes[[0, target, 2]] = 0.2;
4350        boxes[[0, target, 3]] = 0.2;
4351        let mut scores = Array3::<f32>::zeros((1, N, 80));
4352        scores[[0, target, 0]] = 0.9;
4353
4354        // Schema JSON declares shape+dshape in physical anchor-first
4355        // order including padding — mirrors `ara2_int8_edgefirst.json`'s
4356        // feature-first declaration but with the axes in physical
4357        // anchor-first order. `to_legacy_config_outputs` squeezes
4358        // padding before the decoder's rank-3 verification.
4359        let json = r#"{
4360          "schema_version": 2,
4361          "decoder_version": "yolov8",
4362          "nms": "class_agnostic",
4363          "outputs": [
4364            {"name": "boxes", "type": "boxes",
4365             "shape": [1, 8400, 1, 4],
4366             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
4367             "encoding": "direct",
4368             "decoder": "ultralytics",
4369             "normalized": true},
4370            {"name": "scores", "type": "scores",
4371             "shape": [1, 8400, 1, 80],
4372             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
4373             "decoder": "ultralytics",
4374             "score_format": "per_class"}
4375          ]
4376        }"#;
4377        let decoder = DecoderBuilder::default()
4378            .with_config_json_str(json.to_string())
4379            .with_score_threshold(0.5)
4380            .with_iou_threshold(0.5)
4381            .build()
4382            .expect("4D anchor-first schema should build via squeeze_padding_dims");
4383
4384        let mut out_boxes = Vec::with_capacity(4);
4385        let mut out_masks = Vec::with_capacity(0);
4386        decoder
4387            .decode_float(
4388                &[boxes.view().into_dyn(), scores.view().into_dyn()],
4389                &mut out_boxes,
4390                &mut out_masks,
4391            )
4392            .unwrap();
4393
4394        assert_eq!(
4395            out_boxes.len(),
4396            1,
4397            "4D anchor-first with padding should decode exactly one box from the seeded anchor"
4398        );
4399        let b = &out_boxes[0];
4400        // xywh(0.4, 0.5, 0.2, 0.2) → xyxy(0.3, 0.4, 0.5, 0.6)
4401        assert!((b.bbox.xmin - 0.3).abs() < 1e-3, "xmin wrong: {b:?}");
4402        assert!((b.bbox.ymin - 0.4).abs() < 1e-3, "ymin wrong: {b:?}");
4403        assert!((b.bbox.xmax - 0.5).abs() < 1e-3, "xmax wrong: {b:?}");
4404        assert!((b.bbox.ymax - 0.6).abs() < 1e-3, "ymax wrong: {b:?}");
4405        assert_eq!(b.label, 0);
4406        assert!(b.score > 0.85, "score {}: {b:?}", b.score);
4407    }
4408}
4409
4410#[cfg(feature = "tracker")]
4411#[cfg(test)]
4412#[cfg_attr(coverage_nightly, coverage(off))]
4413mod decoder_tracked_tests {
4414
4415    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4416    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4417    use num_traits::{AsPrimitive, Float, PrimInt};
4418    use rand::{RngExt, SeedableRng};
4419    use rand_distr::StandardNormal;
4420
4421    use crate::{
4422        configs::{self, DimName},
4423        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4424    };
4425
4426    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4427        input: ArrayView<F, D>,
4428        quant: Quantization,
4429    ) -> Array<T, D>
4430    where
4431        i32: num_traits::AsPrimitive<F>,
4432        f32: num_traits::AsPrimitive<F>,
4433    {
4434        let zero_point = quant.zero_point.as_();
4435        let div_scale = F::one() / quant.scale.as_();
4436        if zero_point != F::zero() {
4437            input.mapv(|d| (d * div_scale + zero_point).round().as_())
4438        } else {
4439            input.mapv(|d| (d * div_scale).round().as_())
4440        }
4441    }
4442
4443    #[test]
4444    fn test_decoder_tracked_random_jitter() {
4445        use crate::configs::{DecoderType, Nms};
4446        use crate::DecoderBuilder;
4447
4448        let score_threshold = 0.25;
4449        let iou_threshold = 0.1;
4450        let out = include_bytes!(concat!(
4451            env!("CARGO_MANIFEST_DIR"),
4452            "/../../testdata/yolov8s_80_classes.bin"
4453        ));
4454        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4455        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4456        let quant = (0.0040811873, -123).into();
4457
4458        let decoder = DecoderBuilder::default()
4459            .with_config_yolo_det(
4460                crate::configs::Detection {
4461                    decoder: DecoderType::Ultralytics,
4462                    shape: vec![1, 84, 8400],
4463                    anchors: None,
4464                    quantization: Some(quant),
4465                    dshape: vec![
4466                        (crate::configs::DimName::Batch, 1),
4467                        (crate::configs::DimName::NumFeatures, 84),
4468                        (crate::configs::DimName::NumBoxes, 8400),
4469                    ],
4470                    normalized: Some(true),
4471                },
4472                None,
4473            )
4474            .with_score_threshold(score_threshold)
4475            .with_iou_threshold(iou_threshold)
4476            .with_nms(Some(Nms::ClassAgnostic))
4477            .build()
4478            .unwrap();
4479        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
4480
4481        let expected_boxes = [
4482            crate::DetectBox {
4483                bbox: crate::BoundingBox {
4484                    xmin: 0.5285137,
4485                    ymin: 0.05305544,
4486                    xmax: 0.87541467,
4487                    ymax: 0.9998909,
4488                },
4489                score: 0.5591227,
4490                label: 0,
4491            },
4492            crate::DetectBox {
4493                bbox: crate::BoundingBox {
4494                    xmin: 0.130598,
4495                    ymin: 0.43260583,
4496                    xmax: 0.35098213,
4497                    ymax: 0.9958097,
4498                },
4499                score: 0.33057618,
4500                label: 75,
4501            },
4502        ];
4503
4504        let mut tracker = ByteTrackBuilder::new()
4505            .track_update(0.1)
4506            .track_high_conf(0.3)
4507            .build();
4508
4509        let mut output_boxes = Vec::with_capacity(50);
4510        let mut output_masks = Vec::with_capacity(50);
4511        let mut output_tracks = Vec::with_capacity(50);
4512
4513        decoder
4514            .decode_tracked_quantized(
4515                &mut tracker,
4516                0,
4517                &[out.view().into()],
4518                &mut output_boxes,
4519                &mut output_masks,
4520                &mut output_tracks,
4521            )
4522            .unwrap();
4523
4524        assert_eq!(output_boxes.len(), 2);
4525        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4526        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
4527
4528        let mut last_boxes = output_boxes.clone();
4529
4530        for i in 1..=100 {
4531            let mut out = out.clone();
4532            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
4533            let mut x_values = out.slice_mut(s![0, 0, ..]);
4534            for x in x_values.iter_mut() {
4535                let r: f32 = rng.sample(StandardNormal);
4536                let r = r.clamp(-2.0, 2.0) / 2.0;
4537                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
4538            }
4539
4540            let mut y_values = out.slice_mut(s![0, 1, ..]);
4541            for y in y_values.iter_mut() {
4542                let r: f32 = rng.sample(StandardNormal);
4543                let r = r.clamp(-2.0, 2.0) / 2.0;
4544                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
4545            }
4546
4547            decoder
4548                .decode_tracked_quantized(
4549                    &mut tracker,
4550                    100_000_000 * i / 3, // simulate 33.333ms between frames
4551                    &[out.view().into()],
4552                    &mut output_boxes,
4553                    &mut output_masks,
4554                    &mut output_tracks,
4555                )
4556                .unwrap();
4557
4558            assert_eq!(output_boxes.len(), 2);
4559            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
4560            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
4561
4562            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
4563            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
4564            last_boxes = output_boxes.clone();
4565        }
4566    }
4567
4568    // ─── Shared helpers for tracked decoder tests ────────────────────
4569
4570    fn real_data_expected_boxes() -> [DetectBox; 2] {
4571        [
4572            DetectBox {
4573                bbox: BoundingBox {
4574                    xmin: 0.08515105,
4575                    ymin: 0.7131401,
4576                    xmax: 0.29802868,
4577                    ymax: 0.8195788,
4578                },
4579                score: 0.91537374,
4580                label: 23,
4581            },
4582            DetectBox {
4583                bbox: BoundingBox {
4584                    xmin: 0.59605736,
4585                    ymin: 0.25545314,
4586                    xmax: 0.93666154,
4587                    ymax: 0.72378385,
4588                },
4589                score: 0.91537374,
4590                label: 23,
4591            },
4592        ]
4593    }
4594
4595    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4596        [DetectBox {
4597            bbox: BoundingBox {
4598                xmin: 0.12549022,
4599                ymin: 0.12549022,
4600                xmax: 0.23529413,
4601                ymax: 0.23529413,
4602            },
4603            score: 0.98823535,
4604            label: 2,
4605        }]
4606    }
4607
4608    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4609        [DetectBox {
4610            bbox: BoundingBox {
4611                xmin: 0.1234,
4612                ymin: 0.1234,
4613                xmax: 0.2345,
4614                ymax: 0.2345,
4615            },
4616            score: 0.9876,
4617            label: 2,
4618        }]
4619    }
4620
4621    fn build_yolo_split_segdet_decoder(
4622        score_threshold: f32,
4623        iou_threshold: f32,
4624        quant_boxes: (f32, i32),
4625        quant_protos: (f32, i32),
4626    ) -> crate::Decoder {
4627        DecoderBuilder::default()
4628            .with_config_yolo_split_segdet(
4629                configs::Boxes {
4630                    decoder: configs::DecoderType::Ultralytics,
4631                    quantization: Some(quant_boxes.into()),
4632                    shape: vec![1, 4, 8400],
4633                    dshape: vec![
4634                        (DimName::Batch, 1),
4635                        (DimName::BoxCoords, 4),
4636                        (DimName::NumBoxes, 8400),
4637                    ],
4638                    normalized: Some(true),
4639                },
4640                configs::Scores {
4641                    decoder: configs::DecoderType::Ultralytics,
4642                    quantization: Some(quant_boxes.into()),
4643                    shape: vec![1, 80, 8400],
4644                    dshape: vec![
4645                        (DimName::Batch, 1),
4646                        (DimName::NumClasses, 80),
4647                        (DimName::NumBoxes, 8400),
4648                    ],
4649                },
4650                configs::MaskCoefficients {
4651                    decoder: configs::DecoderType::Ultralytics,
4652                    quantization: Some(quant_boxes.into()),
4653                    shape: vec![1, 32, 8400],
4654                    dshape: vec![
4655                        (DimName::Batch, 1),
4656                        (DimName::NumProtos, 32),
4657                        (DimName::NumBoxes, 8400),
4658                    ],
4659                },
4660                configs::Protos {
4661                    decoder: configs::DecoderType::Ultralytics,
4662                    quantization: Some(quant_protos.into()),
4663                    shape: vec![1, 160, 160, 32],
4664                    dshape: vec![
4665                        (DimName::Batch, 1),
4666                        (DimName::Height, 160),
4667                        (DimName::Width, 160),
4668                        (DimName::NumProtos, 32),
4669                    ],
4670                },
4671            )
4672            .with_score_threshold(score_threshold)
4673            .with_iou_threshold(iou_threshold)
4674            .build()
4675            .unwrap()
4676    }
4677
4678    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4679        let config_yaml = include_str!(concat!(
4680            env!("CARGO_MANIFEST_DIR"),
4681            "/../../testdata/yolov8_seg.yaml"
4682        ));
4683        DecoderBuilder::default()
4684            .with_config_yaml_str(config_yaml.to_string())
4685            .with_score_threshold(score_threshold)
4686            .with_iou_threshold(iou_threshold)
4687            .build()
4688            .unwrap()
4689    }
4690
4691    // ─── Real-data tracked test macro ───────────────────────────────
4692    //
4693    // Generates tests that load i8 binary test data from testdata/ and
4694    // exercise all (quant/float) × (combined/split) × (masks/proto)
4695    // decoder paths.
4696
4697    macro_rules! real_data_tracked_test {
4698        ($name:ident, quantized, $layout:ident, $output:ident) => {
4699            #[test]
4700            fn $name() {
4701                let is_split = matches!(stringify!($layout), "split");
4702                let is_proto = matches!(stringify!($output), "proto");
4703
4704                let score_threshold = 0.45;
4705                let iou_threshold = 0.45;
4706                let quant_boxes = (0.021287762_f32, 31_i32);
4707                let quant_protos = (0.02491162_f32, -117_i32);
4708
4709                let raw_boxes = include_bytes!(concat!(
4710                    env!("CARGO_MANIFEST_DIR"),
4711                    "/../../testdata/yolov8_boxes_116x8400.bin"
4712                ));
4713                let raw_boxes = unsafe {
4714                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4715                };
4716                let boxes_i8 =
4717                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4718
4719                let raw_protos = include_bytes!(concat!(
4720                    env!("CARGO_MANIFEST_DIR"),
4721                    "/../../testdata/yolov8_protos_160x160x32.bin"
4722                ));
4723                let raw_protos = unsafe {
4724                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4725                };
4726                let protos_i8 =
4727                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4728                        .unwrap();
4729
4730                // Pre-split (unused for combined, but harmless)
4731                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4732                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4733                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4734                let mut boxes_combined = boxes_i8;
4735
4736                let decoder = if is_split {
4737                    build_yolo_split_segdet_decoder(
4738                        score_threshold,
4739                        iou_threshold,
4740                        quant_boxes,
4741                        quant_protos,
4742                    )
4743                } else {
4744                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4745                };
4746
4747                let expected = real_data_expected_boxes();
4748                let mut tracker = ByteTrackBuilder::new()
4749                    .track_update(0.1)
4750                    .track_high_conf(0.7)
4751                    .build();
4752                let mut output_boxes = Vec::with_capacity(50);
4753                let mut output_tracks = Vec::with_capacity(50);
4754
4755                // Frame 1: decode
4756                if is_proto {
4757                    {
4758                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4759                            vec![
4760                                boxes_split.view().into(),
4761                                scores_split.view().into(),
4762                                mask_split.view().into(),
4763                                protos_i8.view().into(),
4764                            ]
4765                        } else {
4766                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4767                        };
4768                        decoder
4769                            .decode_tracked_quantized_proto(
4770                                &mut tracker,
4771                                0,
4772                                &inputs,
4773                                &mut output_boxes,
4774                                &mut output_tracks,
4775                            )
4776                            .unwrap();
4777                    }
4778                    assert_eq!(output_boxes.len(), 2);
4779                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4780                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4781
4782                    // Zero scores for frame 2
4783                    if is_split {
4784                        for score in scores_split.iter_mut() {
4785                            *score = i8::MIN;
4786                        }
4787                    } else {
4788                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4789                            *score = i8::MIN;
4790                        }
4791                    }
4792
4793                    let proto_result = {
4794                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4795                            vec![
4796                                boxes_split.view().into(),
4797                                scores_split.view().into(),
4798                                mask_split.view().into(),
4799                                protos_i8.view().into(),
4800                            ]
4801                        } else {
4802                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4803                        };
4804                        decoder
4805                            .decode_tracked_quantized_proto(
4806                                &mut tracker,
4807                                100_000_000 / 3,
4808                                &inputs,
4809                                &mut output_boxes,
4810                                &mut output_tracks,
4811                            )
4812                            .unwrap()
4813                    };
4814                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4815                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4816                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4817                } else {
4818                    let mut output_masks = Vec::with_capacity(50);
4819                    {
4820                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4821                            vec![
4822                                boxes_split.view().into(),
4823                                scores_split.view().into(),
4824                                mask_split.view().into(),
4825                                protos_i8.view().into(),
4826                            ]
4827                        } else {
4828                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4829                        };
4830                        decoder
4831                            .decode_tracked_quantized(
4832                                &mut tracker,
4833                                0,
4834                                &inputs,
4835                                &mut output_boxes,
4836                                &mut output_masks,
4837                                &mut output_tracks,
4838                            )
4839                            .unwrap();
4840                    }
4841                    assert_eq!(output_boxes.len(), 2);
4842                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4843                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4844
4845                    if is_split {
4846                        for score in scores_split.iter_mut() {
4847                            *score = i8::MIN;
4848                        }
4849                    } else {
4850                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4851                            *score = i8::MIN;
4852                        }
4853                    }
4854
4855                    {
4856                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4857                            vec![
4858                                boxes_split.view().into(),
4859                                scores_split.view().into(),
4860                                mask_split.view().into(),
4861                                protos_i8.view().into(),
4862                            ]
4863                        } else {
4864                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4865                        };
4866                        decoder
4867                            .decode_tracked_quantized(
4868                                &mut tracker,
4869                                100_000_000 / 3,
4870                                &inputs,
4871                                &mut output_boxes,
4872                                &mut output_masks,
4873                                &mut output_tracks,
4874                            )
4875                            .unwrap();
4876                    }
4877                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4878                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4879                    assert!(output_masks.is_empty());
4880                }
4881            }
4882        };
4883        ($name:ident, float, $layout:ident, $output:ident) => {
4884            #[test]
4885            fn $name() {
4886                let is_split = matches!(stringify!($layout), "split");
4887                let is_proto = matches!(stringify!($output), "proto");
4888
4889                let score_threshold = 0.45;
4890                let iou_threshold = 0.45;
4891                let quant_boxes = (0.021287762_f32, 31_i32);
4892                let quant_protos = (0.02491162_f32, -117_i32);
4893
4894                let raw_boxes = include_bytes!(concat!(
4895                    env!("CARGO_MANIFEST_DIR"),
4896                    "/../../testdata/yolov8_boxes_116x8400.bin"
4897                ));
4898                let raw_boxes = unsafe {
4899                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4900                };
4901                let boxes_i8 =
4902                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4903                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4904
4905                let raw_protos = include_bytes!(concat!(
4906                    env!("CARGO_MANIFEST_DIR"),
4907                    "/../../testdata/yolov8_protos_160x160x32.bin"
4908                ));
4909                let raw_protos = unsafe {
4910                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4911                };
4912                let protos_i8 =
4913                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4914                        .unwrap();
4915                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4916
4917                // Pre-split from dequantized data
4918                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4919                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4920                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4921                let mut boxes_combined = boxes_f32;
4922
4923                let decoder = if is_split {
4924                    build_yolo_split_segdet_decoder(
4925                        score_threshold,
4926                        iou_threshold,
4927                        quant_boxes,
4928                        quant_protos,
4929                    )
4930                } else {
4931                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4932                };
4933
4934                let expected = real_data_expected_boxes();
4935                let mut tracker = ByteTrackBuilder::new()
4936                    .track_update(0.1)
4937                    .track_high_conf(0.7)
4938                    .build();
4939                let mut output_boxes = Vec::with_capacity(50);
4940                let mut output_tracks = Vec::with_capacity(50);
4941
4942                if is_proto {
4943                    {
4944                        let inputs = if is_split {
4945                            vec![
4946                                boxes_split.view().into_dyn(),
4947                                scores_split.view().into_dyn(),
4948                                mask_split.view().into_dyn(),
4949                                protos_f32.view().into_dyn(),
4950                            ]
4951                        } else {
4952                            vec![
4953                                boxes_combined.view().into_dyn(),
4954                                protos_f32.view().into_dyn(),
4955                            ]
4956                        };
4957                        decoder
4958                            .decode_tracked_float_proto(
4959                                &mut tracker,
4960                                0,
4961                                &inputs,
4962                                &mut output_boxes,
4963                                &mut output_tracks,
4964                            )
4965                            .unwrap();
4966                    }
4967                    assert_eq!(output_boxes.len(), 2);
4968                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4969                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4970
4971                    if is_split {
4972                        for score in scores_split.iter_mut() {
4973                            *score = 0.0;
4974                        }
4975                    } else {
4976                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4977                            *score = 0.0;
4978                        }
4979                    }
4980
4981                    let proto_result = {
4982                        let inputs = if is_split {
4983                            vec![
4984                                boxes_split.view().into_dyn(),
4985                                scores_split.view().into_dyn(),
4986                                mask_split.view().into_dyn(),
4987                                protos_f32.view().into_dyn(),
4988                            ]
4989                        } else {
4990                            vec![
4991                                boxes_combined.view().into_dyn(),
4992                                protos_f32.view().into_dyn(),
4993                            ]
4994                        };
4995                        decoder
4996                            .decode_tracked_float_proto(
4997                                &mut tracker,
4998                                100_000_000 / 3,
4999                                &inputs,
5000                                &mut output_boxes,
5001                                &mut output_tracks,
5002                            )
5003                            .unwrap()
5004                    };
5005                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5006                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5007                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5008                } else {
5009                    let mut output_masks = Vec::with_capacity(50);
5010                    {
5011                        let inputs = if is_split {
5012                            vec![
5013                                boxes_split.view().into_dyn(),
5014                                scores_split.view().into_dyn(),
5015                                mask_split.view().into_dyn(),
5016                                protos_f32.view().into_dyn(),
5017                            ]
5018                        } else {
5019                            vec![
5020                                boxes_combined.view().into_dyn(),
5021                                protos_f32.view().into_dyn(),
5022                            ]
5023                        };
5024                        decoder
5025                            .decode_tracked_float(
5026                                &mut tracker,
5027                                0,
5028                                &inputs,
5029                                &mut output_boxes,
5030                                &mut output_masks,
5031                                &mut output_tracks,
5032                            )
5033                            .unwrap();
5034                    }
5035                    assert_eq!(output_boxes.len(), 2);
5036                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5037                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
5038
5039                    if is_split {
5040                        for score in scores_split.iter_mut() {
5041                            *score = 0.0;
5042                        }
5043                    } else {
5044                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
5045                            *score = 0.0;
5046                        }
5047                    }
5048
5049                    {
5050                        let inputs = if is_split {
5051                            vec![
5052                                boxes_split.view().into_dyn(),
5053                                scores_split.view().into_dyn(),
5054                                mask_split.view().into_dyn(),
5055                                protos_f32.view().into_dyn(),
5056                            ]
5057                        } else {
5058                            vec![
5059                                boxes_combined.view().into_dyn(),
5060                                protos_f32.view().into_dyn(),
5061                            ]
5062                        };
5063                        decoder
5064                            .decode_tracked_float(
5065                                &mut tracker,
5066                                100_000_000 / 3,
5067                                &inputs,
5068                                &mut output_boxes,
5069                                &mut output_masks,
5070                                &mut output_tracks,
5071                            )
5072                            .unwrap();
5073                    }
5074                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5075                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
5076                    assert!(output_masks.is_empty());
5077                }
5078            }
5079        };
5080    }
5081
5082    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
5083    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
5084    real_data_tracked_test!(
5085        test_decoder_tracked_segdet_proto,
5086        quantized,
5087        combined,
5088        proto
5089    );
5090    real_data_tracked_test!(
5091        test_decoder_tracked_segdet_proto_float,
5092        float,
5093        combined,
5094        proto
5095    );
5096    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
5097    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
5098    real_data_tracked_test!(
5099        test_decoder_tracked_segdet_split_proto,
5100        quantized,
5101        split,
5102        proto
5103    );
5104    real_data_tracked_test!(
5105        test_decoder_tracked_segdet_split_proto_float,
5106        float,
5107        split,
5108        proto
5109    );
5110
5111    // ─── End-to-end tracked test macro ──────────────────────────────
5112    //
5113    // Generates tests with synthetic data to exercise all tracked
5114    // decode paths without needing real model output files.
5115
    /// YAML decoder config (`decoder_version: yolo26`) for the synthetic
    /// end-to-end tracked tests that use the combined output layout.
    ///
    /// Declares two outputs:
    /// - `detection`: shape `[1, 10, 38]` (batch, num_boxes, num_features),
    ///   normalized, quantized with scale 2/255 (0.00784…) and zero point 0.
    ///   The e2e test builds the 38 features per box as 4 box coordinates,
    ///   1 score, 1 class and 32 mask coefficients.
    /// - `protos`: shape `[1, 160, 160, 32]` (batch, height, width,
    ///   num_protos), quantized with scale 1/255 and zero point 128.
    ///
    /// NOTE(review): the e2e test quantizes its all-zero protos with zero
    /// point 0 rather than the 128 declared here — confirm this is intended.
    const E2E_COMBINED_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5138
    /// YAML decoder config (`decoder_version: yolo26`) for the synthetic
    /// end-to-end tracked tests that use the split output layout.
    ///
    /// Declares five outputs, all with batch 1 and 10 boxes:
    /// - `boxes` `[1, 10, 4]` (normalized),
    /// - `scores` `[1, 10, 1]`,
    /// - `classes` `[1, 10, 1]`,
    /// - `mask_coefficients` `[1, 10, 32]`,
    ///
    /// each quantized with scale 2/255 (0.00784…) and zero point 0.
    /// A fifth output, `protos` `[1, 160, 160, 32]`, is quantized with
    /// scale 1/255 and zero point 128.
    ///
    /// NOTE(review): the `classes` output labels its last axis `num_classes`
    /// with size 1, the same as `scores` — presumably one class id per box;
    /// verify against the yolo26 split-output parser.
    const E2E_SPLIT_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
5185
5186    macro_rules! e2e_tracked_test {
5187        ($name:ident, quantized, $layout:ident, $output:ident) => {
5188            #[test]
5189            fn $name() {
5190                let is_split = matches!(stringify!($layout), "split");
5191                let is_proto = matches!(stringify!($output), "proto");
5192
5193                let score_threshold = 0.45;
5194                let iou_threshold = 0.45;
5195
5196                let mut boxes = Array2::zeros((10, 4));
5197                let mut scores = Array2::zeros((10, 1));
5198                let mut classes = Array2::zeros((10, 1));
5199                let mask = Array2::zeros((10, 32));
5200                let protos = Array3::<f64>::zeros((160, 160, 32));
5201                let protos = protos.insert_axis(Axis(0));
5202                let protos_quant = (1.0 / 255.0, 0.0);
5203                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5204
5205                boxes
5206                    .slice_mut(s![0, ..])
5207                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5208                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5209                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5210
5211                let detect_quant = (2.0 / 255.0, 0.0);
5212
5213                let decoder = if is_split {
5214                    DecoderBuilder::default()
5215                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5216                        .with_score_threshold(score_threshold)
5217                        .with_iou_threshold(iou_threshold)
5218                        .build()
5219                        .unwrap()
5220                } else {
5221                    DecoderBuilder::default()
5222                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5223                        .with_score_threshold(score_threshold)
5224                        .with_iou_threshold(iou_threshold)
5225                        .build()
5226                        .unwrap()
5227                };
5228
5229                let expected = e2e_expected_boxes_quant();
5230                let mut tracker = ByteTrackBuilder::new()
5231                    .track_update(0.1)
5232                    .track_high_conf(0.7)
5233                    .build();
5234                let mut output_boxes = Vec::with_capacity(50);
5235                let mut output_tracks = Vec::with_capacity(50);
5236
5237                if is_split {
5238                    let boxes = boxes.insert_axis(Axis(0));
5239                    let scores = scores.insert_axis(Axis(0));
5240                    let classes = classes.insert_axis(Axis(0));
5241                    let mask = mask.insert_axis(Axis(0));
5242
5243                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5244                    let mut scores: Array3<u8> =
5245                        quantize_ndarray(scores.view(), detect_quant.into());
5246                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5247                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5248
5249                    if is_proto {
5250                        {
5251                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5252                                boxes.view().into(),
5253                                scores.view().into(),
5254                                classes.view().into(),
5255                                mask.view().into(),
5256                                protos.view().into(),
5257                            ];
5258                            decoder
5259                                .decode_tracked_quantized_proto(
5260                                    &mut tracker,
5261                                    0,
5262                                    &inputs,
5263                                    &mut output_boxes,
5264                                    &mut output_tracks,
5265                                )
5266                                .unwrap();
5267                        }
5268                        assert_eq!(output_boxes.len(), 1);
5269                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5270
5271                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5272                            *score = u8::MIN;
5273                        }
5274                        let proto_result = {
5275                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5276                                boxes.view().into(),
5277                                scores.view().into(),
5278                                classes.view().into(),
5279                                mask.view().into(),
5280                                protos.view().into(),
5281                            ];
5282                            decoder
5283                                .decode_tracked_quantized_proto(
5284                                    &mut tracker,
5285                                    100_000_000 / 3,
5286                                    &inputs,
5287                                    &mut output_boxes,
5288                                    &mut output_tracks,
5289                                )
5290                                .unwrap()
5291                        };
5292                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5293                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5294                    } else {
5295                        let mut output_masks = Vec::with_capacity(50);
5296                        {
5297                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5298                                boxes.view().into(),
5299                                scores.view().into(),
5300                                classes.view().into(),
5301                                mask.view().into(),
5302                                protos.view().into(),
5303                            ];
5304                            decoder
5305                                .decode_tracked_quantized(
5306                                    &mut tracker,
5307                                    0,
5308                                    &inputs,
5309                                    &mut output_boxes,
5310                                    &mut output_masks,
5311                                    &mut output_tracks,
5312                                )
5313                                .unwrap();
5314                        }
5315                        assert_eq!(output_boxes.len(), 1);
5316                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5317
5318                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5319                            *score = u8::MIN;
5320                        }
5321                        {
5322                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5323                                boxes.view().into(),
5324                                scores.view().into(),
5325                                classes.view().into(),
5326                                mask.view().into(),
5327                                protos.view().into(),
5328                            ];
5329                            decoder
5330                                .decode_tracked_quantized(
5331                                    &mut tracker,
5332                                    100_000_000 / 3,
5333                                    &inputs,
5334                                    &mut output_boxes,
5335                                    &mut output_masks,
5336                                    &mut output_tracks,
5337                                )
5338                                .unwrap();
5339                        }
5340                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5341                        assert!(output_masks.is_empty());
5342                    }
5343                } else {
5344                    // Combined layout
5345                    let detect = ndarray::concatenate![
5346                        Axis(1),
5347                        boxes.view(),
5348                        scores.view(),
5349                        classes.view(),
5350                        mask.view()
5351                    ];
5352                    let detect = detect.insert_axis(Axis(0));
5353                    assert_eq!(detect.shape(), &[1, 10, 38]);
5354                    let mut detect: Array3<u8> =
5355                        quantize_ndarray(detect.view(), detect_quant.into());
5356
5357                    if is_proto {
5358                        {
5359                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5360                                vec![detect.view().into(), protos.view().into()];
5361                            decoder
5362                                .decode_tracked_quantized_proto(
5363                                    &mut tracker,
5364                                    0,
5365                                    &inputs,
5366                                    &mut output_boxes,
5367                                    &mut output_tracks,
5368                                )
5369                                .unwrap();
5370                        }
5371                        assert_eq!(output_boxes.len(), 1);
5372                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5373
5374                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5375                            *score = u8::MIN;
5376                        }
5377                        let proto_result = {
5378                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5379                                vec![detect.view().into(), protos.view().into()];
5380                            decoder
5381                                .decode_tracked_quantized_proto(
5382                                    &mut tracker,
5383                                    100_000_000 / 3,
5384                                    &inputs,
5385                                    &mut output_boxes,
5386                                    &mut output_tracks,
5387                                )
5388                                .unwrap()
5389                        };
5390                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5391                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5392                    } else {
5393                        let mut output_masks = Vec::with_capacity(50);
5394                        {
5395                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5396                                vec![detect.view().into(), protos.view().into()];
5397                            decoder
5398                                .decode_tracked_quantized(
5399                                    &mut tracker,
5400                                    0,
5401                                    &inputs,
5402                                    &mut output_boxes,
5403                                    &mut output_masks,
5404                                    &mut output_tracks,
5405                                )
5406                                .unwrap();
5407                        }
5408                        assert_eq!(output_boxes.len(), 1);
5409                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5410
5411                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5412                            *score = u8::MIN;
5413                        }
5414                        {
5415                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5416                                vec![detect.view().into(), protos.view().into()];
5417                            decoder
5418                                .decode_tracked_quantized(
5419                                    &mut tracker,
5420                                    100_000_000 / 3,
5421                                    &inputs,
5422                                    &mut output_boxes,
5423                                    &mut output_masks,
5424                                    &mut output_tracks,
5425                                )
5426                                .unwrap();
5427                        }
5428                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5429                        assert!(output_masks.is_empty());
5430                    }
5431                }
5432            }
5433        };
5434        ($name:ident, float, $layout:ident, $output:ident) => {
5435            #[test]
5436            fn $name() {
5437                let is_split = matches!(stringify!($layout), "split");
5438                let is_proto = matches!(stringify!($output), "proto");
5439
5440                let score_threshold = 0.45;
5441                let iou_threshold = 0.45;
5442
5443                let mut boxes = Array2::zeros((10, 4));
5444                let mut scores = Array2::zeros((10, 1));
5445                let mut classes = Array2::zeros((10, 1));
5446                let mask: Array2<f64> = Array2::zeros((10, 32));
5447                let protos = Array3::<f64>::zeros((160, 160, 32));
5448                let protos = protos.insert_axis(Axis(0));
5449
5450                boxes
5451                    .slice_mut(s![0, ..])
5452                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5453                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5454                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5455
5456                let decoder = if is_split {
5457                    DecoderBuilder::default()
5458                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5459                        .with_score_threshold(score_threshold)
5460                        .with_iou_threshold(iou_threshold)
5461                        .build()
5462                        .unwrap()
5463                } else {
5464                    DecoderBuilder::default()
5465                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5466                        .with_score_threshold(score_threshold)
5467                        .with_iou_threshold(iou_threshold)
5468                        .build()
5469                        .unwrap()
5470                };
5471
5472                let expected = e2e_expected_boxes_float();
5473                let mut tracker = ByteTrackBuilder::new()
5474                    .track_update(0.1)
5475                    .track_high_conf(0.7)
5476                    .build();
5477                let mut output_boxes = Vec::with_capacity(50);
5478                let mut output_tracks = Vec::with_capacity(50);
5479
5480                if is_split {
5481                    let boxes = boxes.insert_axis(Axis(0));
5482                    let mut scores = scores.insert_axis(Axis(0));
5483                    let classes = classes.insert_axis(Axis(0));
5484                    let mask = mask.insert_axis(Axis(0));
5485
5486                    if is_proto {
5487                        {
5488                            let inputs = vec![
5489                                boxes.view().into_dyn(),
5490                                scores.view().into_dyn(),
5491                                classes.view().into_dyn(),
5492                                mask.view().into_dyn(),
5493                                protos.view().into_dyn(),
5494                            ];
5495                            decoder
5496                                .decode_tracked_float_proto(
5497                                    &mut tracker,
5498                                    0,
5499                                    &inputs,
5500                                    &mut output_boxes,
5501                                    &mut output_tracks,
5502                                )
5503                                .unwrap();
5504                        }
5505                        assert_eq!(output_boxes.len(), 1);
5506                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5507
5508                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5509                            *score = 0.0;
5510                        }
5511                        let proto_result = {
5512                            let inputs = vec![
5513                                boxes.view().into_dyn(),
5514                                scores.view().into_dyn(),
5515                                classes.view().into_dyn(),
5516                                mask.view().into_dyn(),
5517                                protos.view().into_dyn(),
5518                            ];
5519                            decoder
5520                                .decode_tracked_float_proto(
5521                                    &mut tracker,
5522                                    100_000_000 / 3,
5523                                    &inputs,
5524                                    &mut output_boxes,
5525                                    &mut output_tracks,
5526                                )
5527                                .unwrap()
5528                        };
5529                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5530                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5531                    } else {
5532                        let mut output_masks = Vec::with_capacity(50);
5533                        {
5534                            let inputs = vec![
5535                                boxes.view().into_dyn(),
5536                                scores.view().into_dyn(),
5537                                classes.view().into_dyn(),
5538                                mask.view().into_dyn(),
5539                                protos.view().into_dyn(),
5540                            ];
5541                            decoder
5542                                .decode_tracked_float(
5543                                    &mut tracker,
5544                                    0,
5545                                    &inputs,
5546                                    &mut output_boxes,
5547                                    &mut output_masks,
5548                                    &mut output_tracks,
5549                                )
5550                                .unwrap();
5551                        }
5552                        assert_eq!(output_boxes.len(), 1);
5553                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5554
5555                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5556                            *score = 0.0;
5557                        }
5558                        {
5559                            let inputs = vec![
5560                                boxes.view().into_dyn(),
5561                                scores.view().into_dyn(),
5562                                classes.view().into_dyn(),
5563                                mask.view().into_dyn(),
5564                                protos.view().into_dyn(),
5565                            ];
5566                            decoder
5567                                .decode_tracked_float(
5568                                    &mut tracker,
5569                                    100_000_000 / 3,
5570                                    &inputs,
5571                                    &mut output_boxes,
5572                                    &mut output_masks,
5573                                    &mut output_tracks,
5574                                )
5575                                .unwrap();
5576                        }
5577                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5578                        assert!(output_masks.is_empty());
5579                    }
5580                } else {
5581                    // Combined layout
5582                    let detect = ndarray::concatenate![
5583                        Axis(1),
5584                        boxes.view(),
5585                        scores.view(),
5586                        classes.view(),
5587                        mask.view()
5588                    ];
5589                    let mut detect = detect.insert_axis(Axis(0));
5590                    assert_eq!(detect.shape(), &[1, 10, 38]);
5591
5592                    if is_proto {
5593                        {
5594                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5595                            decoder
5596                                .decode_tracked_float_proto(
5597                                    &mut tracker,
5598                                    0,
5599                                    &inputs,
5600                                    &mut output_boxes,
5601                                    &mut output_tracks,
5602                                )
5603                                .unwrap();
5604                        }
5605                        assert_eq!(output_boxes.len(), 1);
5606                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5607
5608                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5609                            *score = 0.0;
5610                        }
5611                        let proto_result = {
5612                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5613                            decoder
5614                                .decode_tracked_float_proto(
5615                                    &mut tracker,
5616                                    100_000_000 / 3,
5617                                    &inputs,
5618                                    &mut output_boxes,
5619                                    &mut output_tracks,
5620                                )
5621                                .unwrap()
5622                        };
5623                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5624                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5625                    } else {
5626                        let mut output_masks = Vec::with_capacity(50);
5627                        {
5628                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5629                            decoder
5630                                .decode_tracked_float(
5631                                    &mut tracker,
5632                                    0,
5633                                    &inputs,
5634                                    &mut output_boxes,
5635                                    &mut output_masks,
5636                                    &mut output_tracks,
5637                                )
5638                                .unwrap();
5639                        }
5640                        assert_eq!(output_boxes.len(), 1);
5641                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5642
5643                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5644                            *score = 0.0;
5645                        }
5646                        {
5647                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5648                            decoder
5649                                .decode_tracked_float(
5650                                    &mut tracker,
5651                                    100_000_000 / 3,
5652                                    &inputs,
5653                                    &mut output_boxes,
5654                                    &mut output_masks,
5655                                    &mut output_tracks,
5656                                )
5657                                .unwrap();
5658                        }
5659                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5660                        assert!(output_masks.is_empty());
5661                    }
5662                }
5663            }
5664        };
5665    }
5666
5667    e2e_tracked_test!(
5668        test_decoder_tracked_end_to_end_segdet,
5669        quantized,
5670        combined,
5671        masks
5672    );
5673    e2e_tracked_test!(
5674        test_decoder_tracked_end_to_end_segdet_float,
5675        float,
5676        combined,
5677        masks
5678    );
5679    e2e_tracked_test!(
5680        test_decoder_tracked_end_to_end_segdet_proto,
5681        quantized,
5682        combined,
5683        proto
5684    );
5685    e2e_tracked_test!(
5686        test_decoder_tracked_end_to_end_segdet_proto_float,
5687        float,
5688        combined,
5689        proto
5690    );
5691    e2e_tracked_test!(
5692        test_decoder_tracked_end_to_end_segdet_split,
5693        quantized,
5694        split,
5695        masks
5696    );
5697    e2e_tracked_test!(
5698        test_decoder_tracked_end_to_end_segdet_split_float,
5699        float,
5700        split,
5701        masks
5702    );
5703    e2e_tracked_test!(
5704        test_decoder_tracked_end_to_end_segdet_split_proto,
5705        quantized,
5706        split,
5707        proto
5708    );
5709    e2e_tracked_test!(
5710        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5711        float,
5712        split,
5713        proto
5714    );
5715
5716    // ─── End-to-end tracked TensorDyn test macro ────────────────────
5717    //
5718    // Same as e2e_tracked_test but wraps data in TensorDyn and exercises
5719    // the public decode_tracked / decode_proto_tracked API.
5720
5721    macro_rules! e2e_tracked_tensor_test {
5722        ($name:ident, quantized, $layout:ident, $output:ident) => {
5723            #[test]
5724            fn $name() {
5725                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5726
5727                let is_split = matches!(stringify!($layout), "split");
5728                let is_proto = matches!(stringify!($output), "proto");
5729
5730                let score_threshold = 0.45;
5731                let iou_threshold = 0.45;
5732
5733                let mut boxes = Array2::zeros((10, 4));
5734                let mut scores = Array2::zeros((10, 1));
5735                let mut classes = Array2::zeros((10, 1));
5736                let mask = Array2::zeros((10, 32));
5737                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5738                let protos_f64 = protos_f64.insert_axis(Axis(0));
5739                let protos_quant = (1.0 / 255.0, 0.0);
5740                let protos_u8: Array4<u8> =
5741                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5742
5743                boxes
5744                    .slice_mut(s![0, ..])
5745                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5746                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5747                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5748
5749                let detect_quant = (2.0 / 255.0, 0.0);
5750
5751                let decoder = if is_split {
5752                    DecoderBuilder::default()
5753                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5754                        .with_score_threshold(score_threshold)
5755                        .with_iou_threshold(iou_threshold)
5756                        .build()
5757                        .unwrap()
5758                } else {
5759                    DecoderBuilder::default()
5760                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5761                        .with_score_threshold(score_threshold)
5762                        .with_iou_threshold(iou_threshold)
5763                        .build()
5764                        .unwrap()
5765                };
5766
5767                // Helper to wrap a u8 slice into a TensorDyn
5768                let make_u8_tensor =
5769                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5770                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5771                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5772                        t.into()
5773                    };
5774
5775                let expected = e2e_expected_boxes_quant();
5776                let mut tracker = ByteTrackBuilder::new()
5777                    .track_update(0.1)
5778                    .track_high_conf(0.7)
5779                    .build();
5780                let mut output_boxes = Vec::with_capacity(50);
5781                let mut output_tracks = Vec::with_capacity(50);
5782
5783                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5784
5785                if is_split {
5786                    let boxes = boxes.insert_axis(Axis(0));
5787                    let scores = scores.insert_axis(Axis(0));
5788                    let classes = classes.insert_axis(Axis(0));
5789                    let mask = mask.insert_axis(Axis(0));
5790
5791                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5792                    let mut scores_q: Array3<u8> =
5793                        quantize_ndarray(scores.view(), detect_quant.into());
5794                    let classes_q: Array3<u8> =
5795                        quantize_ndarray(classes.view(), detect_quant.into());
5796                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5797
5798                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5799                    let classes_td =
5800                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5801                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5802
5803                    if is_proto {
5804                        let scores_td =
5805                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5806                        decoder
5807                            .decode_proto_tracked(
5808                                &mut tracker,
5809                                0,
5810                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5811                                &mut output_boxes,
5812                                &mut output_tracks,
5813                            )
5814                            .unwrap();
5815
5816                        assert_eq!(output_boxes.len(), 1);
5817                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5818
5819                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5820                            *score = u8::MIN;
5821                        }
5822                        let scores_td =
5823                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5824                        let proto_result = decoder
5825                            .decode_proto_tracked(
5826                                &mut tracker,
5827                                100_000_000 / 3,
5828                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5829                                &mut output_boxes,
5830                                &mut output_tracks,
5831                            )
5832                            .unwrap();
5833                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5834                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5835                    } else {
5836                        let scores_td =
5837                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5838                        let mut output_masks = Vec::with_capacity(50);
5839                        decoder
5840                            .decode_tracked(
5841                                &mut tracker,
5842                                0,
5843                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5844                                &mut output_boxes,
5845                                &mut output_masks,
5846                                &mut output_tracks,
5847                            )
5848                            .unwrap();
5849
5850                        assert_eq!(output_boxes.len(), 1);
5851                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5852
5853                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5854                            *score = u8::MIN;
5855                        }
5856                        let scores_td =
5857                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5858                        decoder
5859                            .decode_tracked(
5860                                &mut tracker,
5861                                100_000_000 / 3,
5862                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5863                                &mut output_boxes,
5864                                &mut output_masks,
5865                                &mut output_tracks,
5866                            )
5867                            .unwrap();
5868                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5869                        assert!(output_masks.is_empty());
5870                    }
5871                } else {
5872                    // Combined layout
5873                    let detect = ndarray::concatenate![
5874                        Axis(1),
5875                        boxes.view(),
5876                        scores.view(),
5877                        classes.view(),
5878                        mask.view()
5879                    ];
5880                    let detect = detect.insert_axis(Axis(0));
5881                    assert_eq!(detect.shape(), &[1, 10, 38]);
5882                    // Ensure contiguous layout after concatenation for as_slice()
5883                    let detect =
5884                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5885                            .unwrap();
5886                    let mut detect_q: Array3<u8> =
5887                        quantize_ndarray(detect.view(), detect_quant.into());
5888
5889                    if is_proto {
5890                        let detect_td =
5891                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5892                        decoder
5893                            .decode_proto_tracked(
5894                                &mut tracker,
5895                                0,
5896                                &[&detect_td, &protos_td],
5897                                &mut output_boxes,
5898                                &mut output_tracks,
5899                            )
5900                            .unwrap();
5901
5902                        assert_eq!(output_boxes.len(), 1);
5903                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5904
5905                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5906                            *score = u8::MIN;
5907                        }
5908                        let detect_td =
5909                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5910                        let proto_result = decoder
5911                            .decode_proto_tracked(
5912                                &mut tracker,
5913                                100_000_000 / 3,
5914                                &[&detect_td, &protos_td],
5915                                &mut output_boxes,
5916                                &mut output_tracks,
5917                            )
5918                            .unwrap();
5919                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5920                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5921                    } else {
5922                        let detect_td =
5923                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5924                        let mut output_masks = Vec::with_capacity(50);
5925                        decoder
5926                            .decode_tracked(
5927                                &mut tracker,
5928                                0,
5929                                &[&detect_td, &protos_td],
5930                                &mut output_boxes,
5931                                &mut output_masks,
5932                                &mut output_tracks,
5933                            )
5934                            .unwrap();
5935
5936                        assert_eq!(output_boxes.len(), 1);
5937                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5938
5939                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5940                            *score = u8::MIN;
5941                        }
5942                        let detect_td =
5943                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5944                        decoder
5945                            .decode_tracked(
5946                                &mut tracker,
5947                                100_000_000 / 3,
5948                                &[&detect_td, &protos_td],
5949                                &mut output_boxes,
5950                                &mut output_masks,
5951                                &mut output_tracks,
5952                            )
5953                            .unwrap();
5954                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5955                        assert!(output_masks.is_empty());
5956                    }
5957                }
5958            }
5959        };
5960        ($name:ident, float, $layout:ident, $output:ident) => {
5961            #[test]
5962            fn $name() {
5963                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5964
5965                let is_split = matches!(stringify!($layout), "split");
5966                let is_proto = matches!(stringify!($output), "proto");
5967
5968                let score_threshold = 0.45;
5969                let iou_threshold = 0.45;
5970
5971                let mut boxes = Array2::zeros((10, 4));
5972                let mut scores = Array2::zeros((10, 1));
5973                let mut classes = Array2::zeros((10, 1));
5974                let mask: Array2<f64> = Array2::zeros((10, 32));
5975                let protos = Array3::<f64>::zeros((160, 160, 32));
5976                let protos = protos.insert_axis(Axis(0));
5977
5978                boxes
5979                    .slice_mut(s![0, ..])
5980                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5981                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5982                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5983
5984                let decoder = if is_split {
5985                    DecoderBuilder::default()
5986                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5987                        .with_score_threshold(score_threshold)
5988                        .with_iou_threshold(iou_threshold)
5989                        .build()
5990                        .unwrap()
5991                } else {
5992                    DecoderBuilder::default()
5993                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5994                        .with_score_threshold(score_threshold)
5995                        .with_iou_threshold(iou_threshold)
5996                        .build()
5997                        .unwrap()
5998                };
5999
6000                // Helper to wrap an f64 slice into a TensorDyn
6001                let make_f64_tensor =
6002                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
6003                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
6004                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
6005                        t.into()
6006                    };
6007
6008                let expected = e2e_expected_boxes_float();
6009                let mut tracker = ByteTrackBuilder::new()
6010                    .track_update(0.1)
6011                    .track_high_conf(0.7)
6012                    .build();
6013                let mut output_boxes = Vec::with_capacity(50);
6014                let mut output_tracks = Vec::with_capacity(50);
6015
6016                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
6017
6018                if is_split {
6019                    let boxes = boxes.insert_axis(Axis(0));
6020                    let mut scores = scores.insert_axis(Axis(0));
6021                    let classes = classes.insert_axis(Axis(0));
6022                    let mask = mask.insert_axis(Axis(0));
6023
6024                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
6025                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
6026                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
6027
6028                    if is_proto {
6029                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6030                        decoder
6031                            .decode_proto_tracked(
6032                                &mut tracker,
6033                                0,
6034                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6035                                &mut output_boxes,
6036                                &mut output_tracks,
6037                            )
6038                            .unwrap();
6039
6040                        assert_eq!(output_boxes.len(), 1);
6041                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6042
6043                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6044                            *score = 0.0;
6045                        }
6046                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6047                        let proto_result = decoder
6048                            .decode_proto_tracked(
6049                                &mut tracker,
6050                                100_000_000 / 3,
6051                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6052                                &mut output_boxes,
6053                                &mut output_tracks,
6054                            )
6055                            .unwrap();
6056                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6057                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6058                    } else {
6059                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6060                        let mut output_masks = Vec::with_capacity(50);
6061                        decoder
6062                            .decode_tracked(
6063                                &mut tracker,
6064                                0,
6065                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6066                                &mut output_boxes,
6067                                &mut output_masks,
6068                                &mut output_tracks,
6069                            )
6070                            .unwrap();
6071
6072                        assert_eq!(output_boxes.len(), 1);
6073                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6074
6075                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
6076                            *score = 0.0;
6077                        }
6078                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
6079                        decoder
6080                            .decode_tracked(
6081                                &mut tracker,
6082                                100_000_000 / 3,
6083                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
6084                                &mut output_boxes,
6085                                &mut output_masks,
6086                                &mut output_tracks,
6087                            )
6088                            .unwrap();
6089                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6090                        assert!(output_masks.is_empty());
6091                    }
6092                } else {
6093                    // Combined layout
6094                    let detect = ndarray::concatenate![
6095                        Axis(1),
6096                        boxes.view(),
6097                        scores.view(),
6098                        classes.view(),
6099                        mask.view()
6100                    ];
6101                    let detect = detect.insert_axis(Axis(0));
6102                    assert_eq!(detect.shape(), &[1, 10, 38]);
6103                    // Ensure contiguous layout after concatenation for as_slice()
6104                    let mut detect =
6105                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
6106                            .unwrap();
6107
6108                    if is_proto {
6109                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6110                        decoder
6111                            .decode_proto_tracked(
6112                                &mut tracker,
6113                                0,
6114                                &[&detect_td, &protos_td],
6115                                &mut output_boxes,
6116                                &mut output_tracks,
6117                            )
6118                            .unwrap();
6119
6120                        assert_eq!(output_boxes.len(), 1);
6121                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6122
6123                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6124                            *score = 0.0;
6125                        }
6126                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6127                        let proto_result = decoder
6128                            .decode_proto_tracked(
6129                                &mut tracker,
6130                                100_000_000 / 3,
6131                                &[&detect_td, &protos_td],
6132                                &mut output_boxes,
6133                                &mut output_tracks,
6134                            )
6135                            .unwrap();
6136                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6137                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6138                    } else {
6139                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6140                        let mut output_masks = Vec::with_capacity(50);
6141                        decoder
6142                            .decode_tracked(
6143                                &mut tracker,
6144                                0,
6145                                &[&detect_td, &protos_td],
6146                                &mut output_boxes,
6147                                &mut output_masks,
6148                                &mut output_tracks,
6149                            )
6150                            .unwrap();
6151
6152                        assert_eq!(output_boxes.len(), 1);
6153                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6154
6155                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6156                            *score = 0.0;
6157                        }
6158                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6159                        decoder
6160                            .decode_tracked(
6161                                &mut tracker,
6162                                100_000_000 / 3,
6163                                &[&detect_td, &protos_td],
6164                                &mut output_boxes,
6165                                &mut output_masks,
6166                                &mut output_tracks,
6167                            )
6168                            .unwrap();
6169                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6170                        assert!(output_masks.is_empty());
6171                    }
6172                }
6173            }
6174        };
6175    }
6176
6177    e2e_tracked_tensor_test!(
6178        test_decoder_tracked_tensor_end_to_end_segdet,
6179        quantized,
6180        combined,
6181        masks
6182    );
6183    e2e_tracked_tensor_test!(
6184        test_decoder_tracked_tensor_end_to_end_segdet_float,
6185        float,
6186        combined,
6187        masks
6188    );
6189    e2e_tracked_tensor_test!(
6190        test_decoder_tracked_tensor_end_to_end_segdet_proto,
6191        quantized,
6192        combined,
6193        proto
6194    );
6195    e2e_tracked_tensor_test!(
6196        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6197        float,
6198        combined,
6199        proto
6200    );
6201    e2e_tracked_tensor_test!(
6202        test_decoder_tracked_tensor_end_to_end_segdet_split,
6203        quantized,
6204        split,
6205        masks
6206    );
6207    e2e_tracked_tensor_test!(
6208        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6209        float,
6210        split,
6211        masks
6212    );
6213    e2e_tracked_tensor_test!(
6214        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6215        quantized,
6216        split,
6217        proto
6218    );
6219    e2e_tracked_tensor_test!(
6220        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6221        float,
6222        split,
6223        proto
6224    );
6225
6226    #[test]
6227    fn test_decoder_tracked_linear_motion() {
6228        use crate::configs::{DecoderType, Nms};
6229        use crate::DecoderBuilder;
6230
6231        let score_threshold = 0.25;
6232        let iou_threshold = 0.1;
6233        let out = include_bytes!(concat!(
6234            env!("CARGO_MANIFEST_DIR"),
6235            "/../../testdata/yolov8s_80_classes.bin"
6236        ));
6237        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6238        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6239        let quant = (0.0040811873, -123).into();
6240
6241        let decoder = DecoderBuilder::default()
6242            .with_config_yolo_det(
6243                crate::configs::Detection {
6244                    decoder: DecoderType::Ultralytics,
6245                    shape: vec![1, 84, 8400],
6246                    anchors: None,
6247                    quantization: Some(quant),
6248                    dshape: vec![
6249                        (crate::configs::DimName::Batch, 1),
6250                        (crate::configs::DimName::NumFeatures, 84),
6251                        (crate::configs::DimName::NumBoxes, 8400),
6252                    ],
6253                    normalized: Some(true),
6254                },
6255                None,
6256            )
6257            .with_score_threshold(score_threshold)
6258            .with_iou_threshold(iou_threshold)
6259            .with_nms(Some(Nms::ClassAgnostic))
6260            .build()
6261            .unwrap();
6262
6263        let mut expected_boxes = [
6264            DetectBox {
6265                bbox: BoundingBox {
6266                    xmin: 0.5285137,
6267                    ymin: 0.05305544,
6268                    xmax: 0.87541467,
6269                    ymax: 0.9998909,
6270                },
6271                score: 0.5591227,
6272                label: 0,
6273            },
6274            DetectBox {
6275                bbox: BoundingBox {
6276                    xmin: 0.130598,
6277                    ymin: 0.43260583,
6278                    xmax: 0.35098213,
6279                    ymax: 0.9958097,
6280                },
6281                score: 0.33057618,
6282                label: 75,
6283            },
6284        ];
6285
6286        let mut tracker = ByteTrackBuilder::new()
6287            .track_update(0.1)
6288            .track_high_conf(0.3)
6289            .build();
6290
6291        let mut output_boxes = Vec::with_capacity(50);
6292        let mut output_masks = Vec::with_capacity(50);
6293        let mut output_tracks = Vec::with_capacity(50);
6294
6295        decoder
6296            .decode_tracked_quantized(
6297                &mut tracker,
6298                0,
6299                &[out.view().into()],
6300                &mut output_boxes,
6301                &mut output_masks,
6302                &mut output_tracks,
6303            )
6304            .unwrap();
6305
6306        assert_eq!(output_boxes.len(), 2);
6307        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6308        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6309
6310        for i in 1..=100 {
6311            let mut out = out.clone();
6312            // introduce linear movement into the XY coordinates
6313            let mut x_values = out.slice_mut(s![0, 0, ..]);
6314            for x in x_values.iter_mut() {
6315                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6316            }
6317
6318            decoder
6319                .decode_tracked_quantized(
6320                    &mut tracker,
6321                    100_000_000 * i / 3, // simulate 33.333ms between frames
6322                    &[out.view().into()],
6323                    &mut output_boxes,
6324                    &mut output_masks,
6325                    &mut output_tracks,
6326                )
6327                .unwrap();
6328
6329            assert_eq!(output_boxes.len(), 2);
6330        }
6331        let tracks = tracker.get_active_tracks();
6332        let predicted_boxes: Vec<_> = tracks
6333            .iter()
6334            .map(|track| {
6335                let mut l = track.last_box;
6336                l.bbox = track.info.tracked_location.into();
6337                l
6338            })
6339            .collect();
6340        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
6341        expected_boxes[0].bbox.xmax += 0.1;
6342        expected_boxes[1].bbox.xmin += 0.1;
6343        expected_boxes[1].bbox.xmax += 0.1;
6344
6345        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6346        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6347
6348        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6349        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6350        for score in scores_values.iter_mut() {
6351            *score = i8::MIN; // set all scores to minimum to simulate no detections
6352        }
6353        decoder
6354            .decode_tracked_quantized(
6355                &mut tracker,
6356                100_000_000 * 101 / 3,
6357                &[out.view().into()],
6358                &mut output_boxes,
6359                &mut output_masks,
6360                &mut output_tracks,
6361            )
6362            .unwrap();
6363        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
6364        expected_boxes[0].bbox.xmax += 0.001;
6365        expected_boxes[1].bbox.xmin += 0.001;
6366        expected_boxes[1].bbox.xmax += 0.001;
6367
6368        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6369        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6370    }
6371
6372    #[test]
6373    fn test_decoder_tracked_end_to_end_float() {
6374        let score_threshold = 0.45;
6375        let iou_threshold = 0.45;
6376
6377        let mut boxes = Array2::zeros((10, 4));
6378        let mut scores = Array2::zeros((10, 1));
6379        let mut classes = Array2::zeros((10, 1));
6380
6381        boxes
6382            .slice_mut(s![0, ..,])
6383            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6384        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6385        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6386
6387        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6388        let mut detect = detect.insert_axis(Axis(0));
6389        assert_eq!(detect.shape(), &[1, 10, 6]);
6390        let config = "
6391decoder_version: yolo26
6392outputs:
6393 - type: detection
6394   decoder: ultralytics
6395   quantization: [0.00784313725490196, 0]
6396   shape: [1, 10, 6]
6397   dshape:
6398    - [batch, 1]
6399    - [num_boxes, 10]
6400    - [num_features, 6]
6401   normalized: true
6402";
6403
6404        let decoder = DecoderBuilder::default()
6405            .with_config_yaml_str(config.to_string())
6406            .with_score_threshold(score_threshold)
6407            .with_iou_threshold(iou_threshold)
6408            .build()
6409            .unwrap();
6410
6411        let expected_boxes = [DetectBox {
6412            bbox: BoundingBox {
6413                xmin: 0.1234,
6414                ymin: 0.1234,
6415                xmax: 0.2345,
6416                ymax: 0.2345,
6417            },
6418            score: 0.9876,
6419            label: 2,
6420        }];
6421
6422        let mut tracker = ByteTrackBuilder::new()
6423            .track_update(0.1)
6424            .track_high_conf(0.7)
6425            .build();
6426
6427        let mut output_boxes = Vec::with_capacity(50);
6428        let mut output_masks = Vec::with_capacity(50);
6429        let mut output_tracks = Vec::with_capacity(50);
6430
6431        decoder
6432            .decode_tracked_float(
6433                &mut tracker,
6434                0,
6435                &[detect.view().into_dyn()],
6436                &mut output_boxes,
6437                &mut output_masks,
6438                &mut output_tracks,
6439            )
6440            .unwrap();
6441
6442        assert_eq!(output_boxes.len(), 1);
6443        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6444
6445        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6446
6447        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6448            *score = 0.0; // set all scores to minimum to simulate no detections
6449        }
6450
6451        decoder
6452            .decode_tracked_float(
6453                &mut tracker,
6454                100_000_000 / 3,
6455                &[detect.view().into_dyn()],
6456                &mut output_boxes,
6457                &mut output_masks,
6458                &mut output_tracks,
6459            )
6460            .unwrap();
6461        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6462    }
6463}