Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
7It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
59the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
60to not dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod per_scale;
77pub mod schema;
78pub mod yolo;
79
80mod decoder;
81pub use decoder::*;
82
83pub use configs::{DecoderVersion, Nms};
84pub use error::{DecoderError, DecoderResult};
85pub use per_scale::DecodeDtype;
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Quantization parameters (scale and zero point) for a tensor.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Scale factor of the quantized representation.
    pub scale: f32,
    /// Zero point of the quantized representation.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from an explicit scale and zero point.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }

    /// Returns the no-op quantization: `scale = 1.0`, `zero_point = 0`.
    ///
    /// Used by the per-scale pipeline as a sentinel for float-to-float
    /// passthrough cells where no affine transform is required.
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::identity();
    /// assert_eq!(quant.scale, 1.0);
    /// assert_eq!(quant.zero_point, 0);
    /// ```
    pub fn identity() -> Self {
        Self::new(1.0, 0)
    }
}
263
264impl From<QuantTuple> for Quantization {
265    /// Creates a new Quantization struct from a QuantTuple
266    /// # Examples
267    /// ```
268    /// # use edgefirst_decoder::Quantization;
269    /// # use edgefirst_decoder::configs::QuantTuple;
270    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
271    /// let quant = Quantization::from(quant_tuple);
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from(quant_tuple: QuantTuple) -> Quantization {
276        Quantization {
277            scale: quant_tuple.0,
278            zero_point: quant_tuple.1,
279        }
280    }
281}
282
283impl<S, Z> From<(S, Z)> for Quantization
284where
285    S: AsPrimitive<f32>,
286    Z: AsPrimitive<i32>,
287{
288    /// Creates a new Quantization struct from a tuple
289    /// # Examples
290    /// ```
291    /// # use edgefirst_decoder::Quantization;
292    /// let quant = Quantization::from((0.1_f64, -128_i64));
293    /// assert_eq!(quant.scale, 0.1);
294    /// assert_eq!(quant.zero_point, -128);
295    /// ```
296    fn from((scale, zp): (S, Z)) -> Quantization {
297        Self {
298            scale: scale.as_(),
299            zero_point: zp.as_(),
300        }
301    }
302}
303
304impl Default for Quantization {
305    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
306    /// # Examples
307    /// ```rust
308    /// # use edgefirst_decoder::Quantization;
309    /// let quant = Quantization::default();
310    /// assert_eq!(quant.scale, 1.0);
311    /// assert_eq!(quant.zero_point, 0);
312    /// ```
313    fn default() -> Self {
314        Self {
315            scale: 1.0,
316            zero_point: 0,
317        }
318    }
319}
320
321/// A detection box with f32 bbox and score
322#[derive(Debug, Clone, Copy, PartialEq, Default)]
323pub struct DetectBox {
324    pub bbox: BoundingBox,
325    /// model-specific score for this detection, higher implies more confidence
326    pub score: f32,
327    /// label index for this detection
328    pub label: usize,
329}
330
/// Bounding box with `f32` coordinates stored in XYXY order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
#[repr(C)]
pub struct BoundingBox {
    /// Left-most normalized coordinate of the bounding box.
    pub xmin: f32,
    /// Top-most normalized coordinate of the bounding box.
    pub ymin: f32,
    /// Right-most normalized coordinate of the bounding box.
    pub xmax: f32,
    /// Bottom-most normalized coordinate of the bounding box.
    pub ymax: f32,
}

// Compile-time layout checks: BoundingBox must be exactly four contiguous
// f32 fields so it can be read with NEON vld1q_f32 loads.
const _: () = {
    assert!(std::mem::size_of::<BoundingBox>() == 4 * std::mem::size_of::<f32>());
    assert!(std::mem::align_of::<BoundingBox>() == std::mem::align_of::<f32>());
};

impl BoundingBox {
    /// Builds a `BoundingBox` from its four coordinates.
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let Self {
            xmin: x0,
            ymin: y0,
            xmax: x1,
            ymax: y1,
        } = *self;
        Self::new(x0.min(x1), y0.min(y1), x0.max(x1), y0.max(y1))
    }
}
381
382impl From<BoundingBox> for [f32; 4] {
383    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
384    /// xmax, ymax order
385    /// # Examples
386    /// ```
387    /// # use edgefirst_decoder::BoundingBox;
388    /// let bbox = BoundingBox {
389    ///     xmin: 0.1,
390    ///     ymin: 0.2,
391    ///     xmax: 0.3,
392    ///     ymax: 0.4,
393    /// };
394    /// let arr: [f32; 4] = bbox.into();
395    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
396    /// ```
397    fn from(b: BoundingBox) -> Self {
398        [b.xmin, b.ymin, b.xmax, b.ymax]
399    }
400}
401
402impl From<[f32; 4]> for BoundingBox {
403    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
404    // BoundingBox
405    fn from(arr: [f32; 4]) -> Self {
406        BoundingBox {
407            xmin: arr[0],
408            ymin: arr[1],
409            xmax: arr[2],
410            ymax: arr[3],
411        }
412    }
413}
414
415impl DetectBox {
416    /// Returns true if one detect box is equal to another detect box, within
417    /// the given `eps`
418    ///
419    /// # Examples
420    /// ```
421    /// # use edgefirst_decoder::DetectBox;
422    /// let box1 = DetectBox {
423    ///     bbox: edgefirst_decoder::BoundingBox {
424    ///         xmin: 0.1,
425    ///         ymin: 0.2,
426    ///         xmax: 0.3,
427    ///         ymax: 0.4,
428    ///     },
429    ///     score: 0.5,
430    ///     label: 1,
431    /// };
432    /// let box2 = DetectBox {
433    ///     bbox: edgefirst_decoder::BoundingBox {
434    ///         xmin: 0.101,
435    ///         ymin: 0.199,
436    ///         xmax: 0.301,
437    ///         ymax: 0.399,
438    ///     },
439    ///     score: 0.510,
440    ///     label: 1,
441    /// };
442    /// assert!(box1.equal_within_delta(&box2, 0.011));
443    /// ```
444    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
445        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
446        self.label == rhs.label
447            && eq_delta(self.score, rhs.score)
448            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
449            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
450            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
451            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
452    }
453}
454
455/// A segmentation result paired with the normalized bounding box that
456/// describes the spatial extent of `segmentation`.
457///
458/// The `xmin`/`ymin`/`xmax`/`ymax` fields describe the **mask region** —
459/// the proto-grid-aligned crop the [`segmentation`] tensor was sliced
460/// from. They are quantized to the proto grid step (`1/proto_height` and
461/// `1/proto_width`, typically `1/160`), so the region's origin floors and
462/// its extent ceils relative to the companion [`DetectBox`]'s `bbox`. The
463/// detection bbox itself stays un-snapped (see EDGEAI-1304); use
464/// `Segmentation` bounds for rendering masks, and `DetectBox.bbox` for
465/// IoU evaluation, drawing detection rectangles, etc. The mask region
466/// always encloses the detection bbox.
467#[derive(Debug, Clone, PartialEq, Default)]
468pub struct Segmentation {
469    /// Left-most normalized coordinate of the mask region (proto-grid
470    /// floor of the companion `DetectBox.bbox.xmin`).
471    pub xmin: f32,
472    /// Top-most normalized coordinate of the mask region (proto-grid
473    /// floor of the companion `DetectBox.bbox.ymin`).
474    pub ymin: f32,
475    /// Right-most normalized coordinate of the mask region (proto-grid
476    /// ceil of the companion `DetectBox.bbox.xmax`).
477    pub xmax: f32,
478    /// Bottom-most normalized coordinate of the mask region (proto-grid
479    /// ceil of the companion `DetectBox.bbox.ymax`).
480    pub ymax: f32,
481    /// 3D segmentation array of shape `(H, W, C)`.
482    ///
483    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
484    /// continuous sigmoid confidence values quantized to u8 (0 = background,
485    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
486    /// 0.5) or use smooth interpolation for anti-aliased edges.
487    ///
488    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
489    /// class scores where the object class is the argmax index.
490    pub segmentation: Array3<u8>,
491}
492
/// Memory layout of the prototype tensor carried by [`ProtoData`].
///
/// Models may emit protos channel-last (NHWC) or channel-first (NCHW). The
/// mask materialisation kernels dispatch on this value so that a costly
/// per-frame transpose can be avoided.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ProtoLayout {
    /// Channel-last: the tensor shape is `[H, W, K]`, contiguous along K.
    /// This is the traditional layout produced by the `extract_proto_data`
    /// path after transposing from NCHW model outputs.
    Nhwc,
    /// Channel-first: the tensor shape is `[K, H, W]` and each channel plane
    /// is contiguous. Skipping the NCHW→NHWC transpose saves ~3 ms per frame
    /// on Cortex-A53/A55 targets (819 KB for 32×160×160 protos).
    Nchw,
}
509
510/// Raw prototype data for fused decode+render pipelines.
511///
512/// Holds post-NMS intermediate state before mask materialization, allowing the
513/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
514/// shader) without materializing intermediate `Array3<u8>` masks.
515///
516/// Both fields are carried as [`TensorDyn`] so downstream consumers (Rust, C
517/// API, Python) get zero-copy typed access through the HAL's shared tensor
518/// infrastructure. Dtype policy:
519///
520/// | Source model | protos dtype | mask_coefficients dtype | protos.quantization |
521/// |---|---|---|---|
522/// | int8 quantized | [`TensorDyn::I8`] | [`TensorDyn::I8`] (raw + quantization) | `Some(q)` |
523/// | f32 | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
524/// | f16 (TensorRT fp16) | [`TensorDyn::F16`] | [`TensorDyn::F16`] | `None` |
525/// | f64 (narrowed) | [`TensorDyn::F32`] | [`TensorDyn::F32`] | `None` |
526///
527/// Quantization metadata lives on the proto tensor itself via
528/// [`edgefirst_tensor::Tensor::quantization`] — float tensors cannot carry
529/// quantization (compile-time gated on the `IntegerType` sealed trait).
530///
531/// `TensorDyn` is not `Clone`, so neither is `ProtoData`. Consumers that need
532/// to share the proto buffer across threads should use `TensorDyn::clone_fd`
533/// / `dmabuf_clone` to dup the backing fd.
534#[derive(Debug)]
535pub struct ProtoData {
536    /// Per-detection mask coefficients, shape `[num_detections, num_protos]`.
537    pub mask_coefficients: edgefirst_tensor::TensorDyn,
538    /// Prototype tensor.
539    ///
540    /// - When `layout == ProtoLayout::Nhwc`: shape is `[proto_h, proto_w, num_protos]`.
541    /// - When `layout == ProtoLayout::Nchw`: shape is `[num_protos, proto_h, proto_w]`.
542    pub protos: edgefirst_tensor::TensorDyn,
543    /// Physical memory layout of the `protos` tensor.
544    pub layout: ProtoLayout,
545}
546
547/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
548///
549///  # Examples
550/// ```
551/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
552/// let quant = Quantization::new(0.1, -128);
553/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
554/// let detect_quant = DetectBoxQuantized {
555///     bbox,
556///     score: 100_i8,
557///     label: 1,
558/// };
559/// let detect = dequant_detect_box(&detect_quant, quant);
560/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
561/// assert_eq!(detect.label, 1);
562/// assert_eq!(detect.bbox, bbox);
563/// ```
564pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
565    detect: &DetectBoxQuantized<SCORE>,
566    quant_scores: Quantization,
567) -> DetectBox {
568    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
569    DetectBox {
570        bbox: detect.bbox,
571        score: quant_scores.scale * detect.score.as_() + scaled_zp,
572        label: detect.label,
573    }
574}
575/// A detection box with a f32 bbox and quantized score
576#[derive(Debug, Clone, Copy, PartialEq)]
577pub struct DetectBoxQuantized<
578    // BOX: Signed + PrimInt + AsPrimitive<f32>,
579    SCORE: PrimInt + AsPrimitive<f32>,
580> {
581    // pub bbox: BoundingBoxQuantized<BOX>,
582    pub bbox: BoundingBox,
583    /// model-specific score for this detection, higher implies more
584    /// confidence.
585    pub score: SCORE,
586    /// label index for this detect
587    pub label: usize,
588}
589
590/// Dequantizes an ndarray from quantized values to f32 values using the given
591/// quantization parameters
592///
593/// # Examples
594/// ```
595/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
596/// let quant = Quantization::new(0.1, -128);
597/// let input: Vec<i8> = vec![0, 127, -128, 64];
598/// let input_array = ndarray::Array1::from(input);
599/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
600/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
601/// ```
602pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
603    input: ArrayView<T, D>,
604    quant: Quantization,
605) -> Array<F, D>
606where
607    i32: num_traits::AsPrimitive<F>,
608    f32: num_traits::AsPrimitive<F>,
609{
610    let zero_point = quant.zero_point.as_();
611    let scale = quant.scale.as_();
612    if zero_point != F::zero() {
613        let scaled_zero = -zero_point * scale;
614        input.mapv(|d| d.as_() * scale + scaled_zero)
615    } else {
616        input.mapv(|d| d.as_() * scale)
617    }
618}
619
620/// Dequantizes a slice from quantized values to float values using the given
621/// quantization parameters
622///
623/// # Examples
624/// ```
625/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
626/// let quant = Quantization::new(0.1, -128);
627/// let input: Vec<i8> = vec![0, 127, -128, 64];
628/// let mut output: Vec<f32> = vec![0.0; input.len()];
629/// dequantize_cpu(&input, quant, &mut output);
630/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
631/// ```
632pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
633    input: &[T],
634    quant: Quantization,
635    output: &mut [F],
636) where
637    f32: num_traits::AsPrimitive<F>,
638    i32: num_traits::AsPrimitive<F>,
639{
640    assert!(input.len() == output.len());
641    let zero_point = quant.zero_point.as_();
642    let scale = quant.scale.as_();
643    if zero_point != F::zero() {
644        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
645        input
646            .iter()
647            .zip(output)
648            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
649    } else {
650        input
651            .iter()
652            .zip(output)
653            .for_each(|(d, deq)| *deq = d.as_() * scale);
654    }
655}
656
657/// Dequantizes a slice from quantized values to float values using the given
658/// quantization parameters, using chunked processing. This is around 5% faster
659/// than `dequantize_cpu` for large slices.
660///
661/// # Examples
662/// ```
663/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
664/// let quant = Quantization::new(0.1, -128);
665/// let input: Vec<i8> = vec![0, 127, -128, 64];
666/// let mut output: Vec<f32> = vec![0.0; input.len()];
667/// dequantize_cpu_chunked(&input, quant, &mut output);
668/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
669/// ```
670pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
671    input: &[T],
672    quant: Quantization,
673    output: &mut [F],
674) where
675    f32: num_traits::AsPrimitive<F>,
676    i32: num_traits::AsPrimitive<F>,
677{
678    assert!(input.len() == output.len());
679    let zero_point = quant.zero_point.as_();
680    let scale = quant.scale.as_();
681
682    let input = input.as_chunks::<4>();
683    let output = output.as_chunks_mut::<4>();
684
685    if zero_point != F::zero() {
686        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
687
688        input
689            .0
690            .iter()
691            .zip(output.0)
692            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
693        input
694            .1
695            .iter()
696            .zip(output.1)
697            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
698    } else {
699        input
700            .0
701            .iter()
702            .zip(output.0)
703            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
704        input
705            .1
706            .iter()
707            .zip(output.1)
708            .for_each(|(d, deq)| *deq = d.as_() * scale);
709    }
710}
711
712/// Converts a segmentation tensor into a 2D mask
713/// If the last dimension of the segmentation tensor is 1, values equal or
714/// above 128 are considered objects. Otherwise the object is the argmax index
715///
716/// # Errors
717///
718/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
719/// invalid shape.
720///
721/// # Examples
722/// ```
723/// # use edgefirst_decoder::segmentation_to_mask;
724/// let segmentation =
725///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
726/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
727/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
728/// ```
729pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
730    if segmentation.shape()[2] == 0 {
731        return Err(DecoderError::InvalidShape(
732            "Segmentation tensor must have non-zero depth".to_string(),
733        ));
734    }
735    if segmentation.shape()[2] == 1 {
736        yolo_segmentation_to_mask(segmentation, 128)
737    } else {
738        Ok(modelpack_segmentation_to_mask(segmentation))
739    }
740}
741
742/// Returns the maximum value and its index from a 1D array
743fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
744    score
745        .iter()
746        .enumerate()
747        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
748            if max > *s {
749                (max, arg_max)
750            } else {
751                (*s, ind)
752            }
753        })
754}
755
/// NEON-accelerated argmax for `i8` slices on aarch64.
///
/// Returns the largest value in `scores` together with the index of its
/// **last** occurrence, so ties resolve to the highest index (matching the
/// semantics of [`arg_max`]).
///
/// The maximum over every complete 16-element block is reduced with NEON.
/// The trailing `len % 16` elements (the whole slice when it is shorter
/// than 16) are folded in with scalar code.
///
/// This function is only compiled for `target_arch = "aarch64"`; other
/// targets need a different implementation such as [`arg_max`].
///
/// # Panics
///
/// Panics if `scores` is empty.
#[cfg(target_arch = "aarch64")]
pub(crate) fn arg_max_i8(scores: &[i8]) -> (i8, usize) {
    use std::arch::aarch64::{vld1q_s8, vmaxq_s8, vmaxvq_s8};

    /// Number of `i8` lanes in one 128-bit NEON register.
    const LANES: usize = 16;

    assert!(
        !scores.is_empty(),
        "arg_max_i8 requires a non-empty slice"
    );

    // Phase 1: find the global maximum value.
    let mut blocks = scores.chunks_exact(LANES);
    let tail = blocks.remainder();
    let first_block = blocks.next();

    let block_max = match first_block {
        Some(first) => {
            // SAFETY: every slice yielded by `chunks_exact(LANES)` is exactly
            // 16 initialized `i8` values, which is precisely what one
            // `vld1q_s8` load reads, so each load stays in bounds. The
            // intrinsics require NEON, which this aarch64-only function
            // assumes is available, as the original code did.
            unsafe {
                let mut acc = vld1q_s8(first.as_ptr());
                for block in blocks {
                    acc = vmaxq_s8(acc, vld1q_s8(block.as_ptr()));
                }
                vmaxvq_s8(acc)
            }
        }
        // Fewer than 16 elements: everything is in `tail`, which is
        // non-empty here, so the `i8::MIN` seed is always replaced by a real
        // element (or equals one).
        None => i8::MIN,
    };
    let max = tail.iter().copied().fold(block_max, i8::max);

    // Phase 2: locate the LAST occurrence of the maximum to preserve the tie
    // semantics. The value is always present, so the fallback is unreachable.
    let idx = scores.iter().rposition(|&s| s == max).unwrap_or(0);
    (max, idx)
}
810#[cfg(test)]
811#[cfg_attr(coverage_nightly, coverage(off))]
812mod decoder_tests {
813    #![allow(clippy::excessive_precision)]
814    use crate::{
815        configs::{DecoderType, DimName, Protos},
816        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
817        yolo::{
818            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
819            decode_yolo_segdet_quant,
820        },
821        *,
822    };
823    use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
824    use ndarray::Dimension;
825    use ndarray::{array, s, Array2, Array3, Array4, Axis};
826    use ndarray_stats::DeviationExt;
827    use num_traits::{AsPrimitive, PrimInt};
828
829    fn compare_outputs(
830        boxes: (&[DetectBox], &[DetectBox]),
831        masks: (&[Segmentation], &[Segmentation]),
832    ) {
833        let (boxes0, boxes1) = boxes;
834        let (masks0, masks1) = masks;
835
836        assert_eq!(boxes0.len(), boxes1.len());
837        assert_eq!(masks0.len(), masks1.len());
838
839        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
840            assert!(
841                b_i8.equal_within_delta(b_f32, 1e-6),
842                "{b_i8:?} is not equal to {b_f32:?}"
843            );
844        }
845
846        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
847            assert_eq!(
848                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
849                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
850            );
851            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
852            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
853            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
854            let diff = &mask_i8 - &mask_f32;
855            for x in 0..diff.shape()[0] {
856                for y in 0..diff.shape()[1] {
857                    for z in 0..diff.shape()[2] {
858                        let val = diff[[x, y, z]];
859                        assert!(
860                            val.abs() <= 1,
861                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
862                            x,
863                            y,
864                            z,
865                            val
866                        );
867                    }
868                }
869            }
870            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
871            assert!(
872                mean_sq_err < 1e-2,
873                "Mean Square Error between masks was greater than 1%: {:.2}%",
874                mean_sq_err * 100.0
875            );
876        }
877    }
878
879    // ─── Shared test data loaders ────────────────────────
880
881    fn load_yolov8_boxes() -> Array3<i8> {
882        let raw = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
883        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
884        Array3::from_shape_vec((1, 116, 8400), raw.to_vec()).unwrap()
885    }
886
887    fn load_yolov8_protos() -> Array4<i8> {
888        let raw = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
889        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
890        Array4::from_shape_vec((1, 160, 160, 32), raw.to_vec()).unwrap()
891    }
892
893    fn load_yolov8s_det() -> Array3<i8> {
894        let raw = edgefirst_bench::testdata::read("yolov8s_80_classes.bin");
895        let raw = unsafe { std::slice::from_raw_parts(raw.as_ptr() as *const i8, raw.len()) };
896        Array3::from_shape_vec((1, 84, 8400), raw.to_vec()).unwrap()
897    }
898
899    #[test]
900    fn test_decoder_modelpack() {
901        let score_threshold = 0.45;
902        let iou_threshold = 0.45;
903        let boxes = edgefirst_bench::testdata::read("modelpack_boxes_1935x1x4.bin");
904        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
905
906        let scores = edgefirst_bench::testdata::read("modelpack_scores_1935x1.bin");
907        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
908
909        let quant_boxes = (0.004656755365431309, 21).into();
910        let quant_scores = (0.0019603664986789227, 0).into();
911
912        let decoder = DecoderBuilder::default()
913            .with_config_modelpack_det(
914                configs::Boxes {
915                    decoder: DecoderType::ModelPack,
916                    quantization: Some(quant_boxes),
917                    shape: vec![1, 1935, 1, 4],
918                    dshape: vec![
919                        (DimName::Batch, 1),
920                        (DimName::NumBoxes, 1935),
921                        (DimName::Padding, 1),
922                        (DimName::BoxCoords, 4),
923                    ],
924                    normalized: Some(true),
925                },
926                configs::Scores {
927                    decoder: DecoderType::ModelPack,
928                    quantization: Some(quant_scores),
929                    shape: vec![1, 1935, 1],
930                    dshape: vec![
931                        (DimName::Batch, 1),
932                        (DimName::NumBoxes, 1935),
933                        (DimName::NumClasses, 1),
934                    ],
935                },
936            )
937            .with_score_threshold(score_threshold)
938            .with_iou_threshold(iou_threshold)
939            .build()
940            .unwrap();
941
942        let quant_boxes = quant_boxes.into();
943        let quant_scores = quant_scores.into();
944
945        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
946        decode_modelpack_det(
947            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
948            (scores.slice(s![0, .., ..]), quant_scores),
949            score_threshold,
950            iou_threshold,
951            300,
952            &mut output_boxes,
953        );
954        assert!(output_boxes[0].equal_within_delta(
955            &DetectBox {
956                bbox: BoundingBox {
957                    xmin: 0.40513772,
958                    ymin: 0.6379755,
959                    xmax: 0.5122431,
960                    ymax: 0.7730214,
961                },
962                score: 0.4861709,
963                label: 0
964            },
965            1e-6
966        ));
967
968        let mut output_boxes1 = Vec::with_capacity(50);
969        let mut output_masks1 = Vec::with_capacity(50);
970
971        decoder
972            .decode_quantized(
973                &[boxes.view().into(), scores.view().into()],
974                &mut output_boxes1,
975                &mut output_masks1,
976            )
977            .unwrap();
978
979        let mut output_boxes_float = Vec::with_capacity(50);
980        let mut output_masks_float = Vec::with_capacity(50);
981
982        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
983        let scores = dequantize_ndarray(scores.view(), quant_scores);
984
985        decoder
986            .decode_float::<f32>(
987                &[boxes.view().into_dyn(), scores.view().into_dyn()],
988                &mut output_boxes_float,
989                &mut output_masks_float,
990            )
991            .unwrap();
992
993        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
994        compare_outputs(
995            (&output_boxes, &output_boxes_float),
996            (&[], &output_masks_float),
997        );
998    }
999
1000    #[test]
1001    fn test_decoder_modelpack_split_u8() {
1002        let score_threshold = 0.45;
1003        let iou_threshold = 0.45;
1004        let detect0 = edgefirst_bench::testdata::read("modelpack_split_9x15x18.bin");
1005        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1006
1007        let detect1 = edgefirst_bench::testdata::read("modelpack_split_17x30x18.bin");
1008        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1009
1010        let quant0 = (0.08547406643629074, 174).into();
1011        let quant1 = (0.09929127991199493, 183).into();
1012        let anchors0 = vec![
1013            [0.36666667461395264, 0.31481480598449707],
1014            [0.38749998807907104, 0.4740740656852722],
1015            [0.5333333611488342, 0.644444465637207],
1016        ];
1017        let anchors1 = vec![
1018            [0.13750000298023224, 0.2074074000120163],
1019            [0.2541666626930237, 0.21481481194496155],
1020            [0.23125000298023224, 0.35185185074806213],
1021        ];
1022
1023        let detect_config0 = configs::Detection {
1024            decoder: DecoderType::ModelPack,
1025            shape: vec![1, 9, 15, 18],
1026            anchors: Some(anchors0.clone()),
1027            quantization: Some(quant0),
1028            dshape: vec![
1029                (DimName::Batch, 1),
1030                (DimName::Height, 9),
1031                (DimName::Width, 15),
1032                (DimName::NumAnchorsXFeatures, 18),
1033            ],
1034            normalized: Some(true),
1035        };
1036
1037        let detect_config1 = configs::Detection {
1038            decoder: DecoderType::ModelPack,
1039            shape: vec![1, 17, 30, 18],
1040            anchors: Some(anchors1.clone()),
1041            quantization: Some(quant1),
1042            dshape: vec![
1043                (DimName::Batch, 1),
1044                (DimName::Height, 17),
1045                (DimName::Width, 30),
1046                (DimName::NumAnchorsXFeatures, 18),
1047            ],
1048            normalized: Some(true),
1049        };
1050
1051        let config0 = (&detect_config0).try_into().unwrap();
1052        let config1 = (&detect_config1).try_into().unwrap();
1053
1054        let decoder = DecoderBuilder::default()
1055            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
1056            .with_score_threshold(score_threshold)
1057            .with_iou_threshold(iou_threshold)
1058            .build()
1059            .unwrap();
1060
1061        let quant0 = quant0.into();
1062        let quant1 = quant1.into();
1063
1064        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1065        decode_modelpack_split_quant(
1066            &[
1067                detect0.slice(s![0, .., .., ..]),
1068                detect1.slice(s![0, .., .., ..]),
1069            ],
1070            &[config0, config1],
1071            score_threshold,
1072            iou_threshold,
1073            300,
1074            &mut output_boxes,
1075        );
1076        assert!(output_boxes[0].equal_within_delta(
1077            &DetectBox {
1078                bbox: BoundingBox {
1079                    xmin: 0.43171933,
1080                    ymin: 0.68243736,
1081                    xmax: 0.5626645,
1082                    ymax: 0.808863,
1083                },
1084                score: 0.99240804,
1085                label: 0
1086            },
1087            1e-6
1088        ));
1089
1090        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
1091        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
1092        decoder
1093            .decode_quantized(
1094                &[detect0.view().into(), detect1.view().into()],
1095                &mut output_boxes1,
1096                &mut output_masks1,
1097            )
1098            .unwrap();
1099
1100        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1101        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1102
1103        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1104        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1105        decoder
1106            .decode_float::<f32>(
1107                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1108                &mut output_boxes1_f32,
1109                &mut output_masks1_f32,
1110            )
1111            .unwrap();
1112
1113        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1114        compare_outputs(
1115            (&output_boxes, &output_boxes1_f32),
1116            (&[], &output_masks1_f32),
1117        );
1118    }
1119
1120    #[test]
1121    fn test_decoder_parse_config_modelpack_split_u8() {
1122        let score_threshold = 0.45;
1123        let iou_threshold = 0.45;
1124        let detect0 = edgefirst_bench::testdata::read("modelpack_split_9x15x18.bin");
1125        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1126
1127        let detect1 = edgefirst_bench::testdata::read("modelpack_split_17x30x18.bin");
1128        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1129
1130        let decoder = DecoderBuilder::default()
1131            .with_config_yaml_str(
1132                edgefirst_bench::testdata::read_to_string("modelpack_split.yaml").to_string(),
1133            )
1134            .with_score_threshold(score_threshold)
1135            .with_iou_threshold(iou_threshold)
1136            .build()
1137            .unwrap();
1138
1139        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1140        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1141        decoder
1142            .decode_quantized(
1143                &[
1144                    ArrayViewDQuantized::from(detect1.view()),
1145                    ArrayViewDQuantized::from(detect0.view()),
1146                ],
1147                &mut output_boxes,
1148                &mut output_masks,
1149            )
1150            .unwrap();
1151        assert!(output_boxes[0].equal_within_delta(
1152            &DetectBox {
1153                bbox: BoundingBox {
1154                    xmin: 0.43171933,
1155                    ymin: 0.68243736,
1156                    xmax: 0.5626645,
1157                    ymax: 0.808863,
1158                },
1159                score: 0.99240804,
1160                label: 0
1161            },
1162            1e-6
1163        ));
1164    }
1165
1166    #[test]
1167    fn test_modelpack_seg() {
1168        let out = edgefirst_bench::testdata::read("modelpack_seg_2x160x160.bin");
1169        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1170        let quant = (1.0 / 255.0, 0).into();
1171
1172        let decoder = DecoderBuilder::default()
1173            .with_config_modelpack_seg(configs::Segmentation {
1174                decoder: DecoderType::ModelPack,
1175                quantization: Some(quant),
1176                shape: vec![1, 2, 160, 160],
1177                dshape: vec![
1178                    (DimName::Batch, 1),
1179                    (DimName::NumClasses, 2),
1180                    (DimName::Height, 160),
1181                    (DimName::Width, 160),
1182                ],
1183            })
1184            .build()
1185            .unwrap();
1186        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1187        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1188        decoder
1189            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1190            .unwrap();
1191
1192        let mut mask = out.slice(s![0, .., .., ..]);
1193        mask.swap_axes(0, 1);
1194        mask.swap_axes(1, 2);
1195        let mask = [Segmentation {
1196            xmin: 0.0,
1197            ymin: 0.0,
1198            xmax: 1.0,
1199            ymax: 1.0,
1200            segmentation: mask.into_owned(),
1201        }];
1202        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1203
1204        decoder
1205            .decode_float::<f32>(
1206                &[dequantize_ndarray(out.view(), quant.into())
1207                    .view()
1208                    .into_dyn()],
1209                &mut output_boxes,
1210                &mut output_masks,
1211            )
1212            .unwrap();
1213
1214        // not expected for float decoder to have same values as quantized decoder, as
1215        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1216        // the model output. Thus the float output is the same as the quantized output
1217        // but scaled differently. However, it is expected that the mask after argmax
1218        // will be the same.
1219        compare_outputs((&[], &output_boxes), (&[], &[]));
1220        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1221        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1222
1223        assert_eq!(mask0, mask1);
1224    }
1225    #[test]
1226    fn test_modelpack_seg_quant() {
1227        let out = edgefirst_bench::testdata::read("modelpack_seg_2x160x160.bin");
1228        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1229        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1230        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1231        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1232        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1233        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1234
1235        let quant = (1.0 / 255.0, 0).into();
1236
1237        let decoder = DecoderBuilder::default()
1238            .with_config_modelpack_seg(configs::Segmentation {
1239                decoder: DecoderType::ModelPack,
1240                quantization: Some(quant),
1241                shape: vec![1, 2, 160, 160],
1242                dshape: vec![
1243                    (DimName::Batch, 1),
1244                    (DimName::NumClasses, 2),
1245                    (DimName::Height, 160),
1246                    (DimName::Width, 160),
1247                ],
1248            })
1249            .build()
1250            .unwrap();
1251        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1252        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1253        decoder
1254            .decode_quantized(
1255                &[out_u8.view().into()],
1256                &mut output_boxes,
1257                &mut output_masks_u8,
1258            )
1259            .unwrap();
1260
1261        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1262        decoder
1263            .decode_quantized(
1264                &[out_i8.view().into()],
1265                &mut output_boxes,
1266                &mut output_masks_i8,
1267            )
1268            .unwrap();
1269
1270        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1271        decoder
1272            .decode_quantized(
1273                &[out_u16.view().into()],
1274                &mut output_boxes,
1275                &mut output_masks_u16,
1276            )
1277            .unwrap();
1278
1279        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1280        decoder
1281            .decode_quantized(
1282                &[out_i16.view().into()],
1283                &mut output_boxes,
1284                &mut output_masks_i16,
1285            )
1286            .unwrap();
1287
1288        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1289        decoder
1290            .decode_quantized(
1291                &[out_u32.view().into()],
1292                &mut output_boxes,
1293                &mut output_masks_u32,
1294            )
1295            .unwrap();
1296
1297        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1298        decoder
1299            .decode_quantized(
1300                &[out_i32.view().into()],
1301                &mut output_boxes,
1302                &mut output_masks_i32,
1303            )
1304            .unwrap();
1305
1306        compare_outputs((&[], &output_boxes), (&[], &[]));
1307        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1308        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1309        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1310        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1311        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1312        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1313        assert_eq!(mask_u8, mask_i8);
1314        assert_eq!(mask_u8, mask_u16);
1315        assert_eq!(mask_u8, mask_i16);
1316        assert_eq!(mask_u8, mask_u32);
1317        assert_eq!(mask_u8, mask_i32);
1318    }
1319
1320    #[test]
1321    fn test_modelpack_segdet() {
1322        let score_threshold = 0.45;
1323        let iou_threshold = 0.45;
1324
1325        let boxes = edgefirst_bench::testdata::read("modelpack_boxes_1935x1x4.bin");
1326        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1327
1328        let scores = edgefirst_bench::testdata::read("modelpack_scores_1935x1.bin");
1329        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1330
1331        let seg = edgefirst_bench::testdata::read("modelpack_seg_2x160x160.bin");
1332        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1333
1334        let quant_boxes = (0.004656755365431309, 21).into();
1335        let quant_scores = (0.0019603664986789227, 0).into();
1336        let quant_seg = (1.0 / 255.0, 0).into();
1337
1338        let decoder = DecoderBuilder::default()
1339            .with_config_modelpack_segdet(
1340                configs::Boxes {
1341                    decoder: DecoderType::ModelPack,
1342                    quantization: Some(quant_boxes),
1343                    shape: vec![1, 1935, 1, 4],
1344                    dshape: vec![
1345                        (DimName::Batch, 1),
1346                        (DimName::NumBoxes, 1935),
1347                        (DimName::Padding, 1),
1348                        (DimName::BoxCoords, 4),
1349                    ],
1350                    normalized: Some(true),
1351                },
1352                configs::Scores {
1353                    decoder: DecoderType::ModelPack,
1354                    quantization: Some(quant_scores),
1355                    shape: vec![1, 1935, 1],
1356                    dshape: vec![
1357                        (DimName::Batch, 1),
1358                        (DimName::NumBoxes, 1935),
1359                        (DimName::NumClasses, 1),
1360                    ],
1361                },
1362                configs::Segmentation {
1363                    decoder: DecoderType::ModelPack,
1364                    quantization: Some(quant_seg),
1365                    shape: vec![1, 2, 160, 160],
1366                    dshape: vec![
1367                        (DimName::Batch, 1),
1368                        (DimName::NumClasses, 2),
1369                        (DimName::Height, 160),
1370                        (DimName::Width, 160),
1371                    ],
1372                },
1373            )
1374            .with_iou_threshold(iou_threshold)
1375            .with_score_threshold(score_threshold)
1376            .build()
1377            .unwrap();
1378        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1379        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1380        decoder
1381            .decode_quantized(
1382                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1383                &mut output_boxes,
1384                &mut output_masks,
1385            )
1386            .unwrap();
1387
1388        let mut mask = seg.slice(s![0, .., .., ..]);
1389        mask.swap_axes(0, 1);
1390        mask.swap_axes(1, 2);
1391        let mask = [Segmentation {
1392            xmin: 0.0,
1393            ymin: 0.0,
1394            xmax: 1.0,
1395            ymax: 1.0,
1396            segmentation: mask.into_owned(),
1397        }];
1398        let correct_boxes = [DetectBox {
1399            bbox: BoundingBox {
1400                xmin: 0.40513772,
1401                ymin: 0.6379755,
1402                xmax: 0.5122431,
1403                ymax: 0.7730214,
1404            },
1405            score: 0.4861709,
1406            label: 0,
1407        }];
1408        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1409
1410        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1411        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1412        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1413        decoder
1414            .decode_float::<f32>(
1415                &[
1416                    scores.view().into_dyn(),
1417                    boxes.view().into_dyn(),
1418                    seg.view().into_dyn(),
1419                ],
1420                &mut output_boxes,
1421                &mut output_masks,
1422            )
1423            .unwrap();
1424
1425        // not expected for float segmentation decoder to have same values as quantized
1426        // segmentation decoder, as float decoder ensures the data fills 0-255,
1427        // quantized decoder uses whatever the model output. Thus the float
1428        // output is the same as the quantized output but scaled differently.
1429        // However, it is expected that the mask after argmax will be the same.
1430        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1431        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1432        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1433
1434        assert_eq!(mask0, mask1);
1435    }
1436
1437    #[test]
1438    fn test_modelpack_segdet_split() {
1439        let score_threshold = 0.8;
1440        let iou_threshold = 0.5;
1441
1442        let seg = edgefirst_bench::testdata::read("modelpack_seg_2x160x160.bin");
1443        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1444
1445        let detect0 = edgefirst_bench::testdata::read("modelpack_split_9x15x18.bin");
1446        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1447
1448        let detect1 = edgefirst_bench::testdata::read("modelpack_split_17x30x18.bin");
1449        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1450
1451        let quant0 = (0.08547406643629074, 174).into();
1452        let quant1 = (0.09929127991199493, 183).into();
1453        let quant_seg = (1.0 / 255.0, 0).into();
1454
1455        let anchors0 = vec![
1456            [0.36666667461395264, 0.31481480598449707],
1457            [0.38749998807907104, 0.4740740656852722],
1458            [0.5333333611488342, 0.644444465637207],
1459        ];
1460        let anchors1 = vec![
1461            [0.13750000298023224, 0.2074074000120163],
1462            [0.2541666626930237, 0.21481481194496155],
1463            [0.23125000298023224, 0.35185185074806213],
1464        ];
1465
1466        let decoder = DecoderBuilder::default()
1467            .with_config_modelpack_segdet_split(
1468                vec![
1469                    configs::Detection {
1470                        decoder: DecoderType::ModelPack,
1471                        shape: vec![1, 17, 30, 18],
1472                        anchors: Some(anchors1),
1473                        quantization: Some(quant1),
1474                        dshape: vec![
1475                            (DimName::Batch, 1),
1476                            (DimName::Height, 17),
1477                            (DimName::Width, 30),
1478                            (DimName::NumAnchorsXFeatures, 18),
1479                        ],
1480                        normalized: Some(true),
1481                    },
1482                    configs::Detection {
1483                        decoder: DecoderType::ModelPack,
1484                        shape: vec![1, 9, 15, 18],
1485                        anchors: Some(anchors0),
1486                        quantization: Some(quant0),
1487                        dshape: vec![
1488                            (DimName::Batch, 1),
1489                            (DimName::Height, 9),
1490                            (DimName::Width, 15),
1491                            (DimName::NumAnchorsXFeatures, 18),
1492                        ],
1493                        normalized: Some(true),
1494                    },
1495                ],
1496                configs::Segmentation {
1497                    decoder: DecoderType::ModelPack,
1498                    quantization: Some(quant_seg),
1499                    shape: vec![1, 2, 160, 160],
1500                    dshape: vec![
1501                        (DimName::Batch, 1),
1502                        (DimName::NumClasses, 2),
1503                        (DimName::Height, 160),
1504                        (DimName::Width, 160),
1505                    ],
1506                },
1507            )
1508            .with_score_threshold(score_threshold)
1509            .with_iou_threshold(iou_threshold)
1510            .build()
1511            .unwrap();
1512        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1513        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1514        decoder
1515            .decode_quantized(
1516                &[
1517                    detect0.view().into(),
1518                    detect1.view().into(),
1519                    seg.view().into(),
1520                ],
1521                &mut output_boxes,
1522                &mut output_masks,
1523            )
1524            .unwrap();
1525
1526        let mut mask = seg.slice(s![0, .., .., ..]);
1527        mask.swap_axes(0, 1);
1528        mask.swap_axes(1, 2);
1529        let mask = [Segmentation {
1530            xmin: 0.0,
1531            ymin: 0.0,
1532            xmax: 1.0,
1533            ymax: 1.0,
1534            segmentation: mask.into_owned(),
1535        }];
1536        let correct_boxes = [DetectBox {
1537            bbox: BoundingBox {
1538                xmin: 0.43171933,
1539                ymin: 0.68243736,
1540                xmax: 0.5626645,
1541                ymax: 0.808863,
1542            },
1543            score: 0.99240804,
1544            label: 0,
1545        }];
1546        println!("Output Boxes: {:?}", output_boxes);
1547        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1548
1549        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1550        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1551        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1552        decoder
1553            .decode_float::<f32>(
1554                &[
1555                    detect0.view().into_dyn(),
1556                    detect1.view().into_dyn(),
1557                    seg.view().into_dyn(),
1558                ],
1559                &mut output_boxes,
1560                &mut output_masks,
1561            )
1562            .unwrap();
1563
1564        // not expected for float segmentation decoder to have same values as quantized
1565        // segmentation decoder, as float decoder ensures the data fills 0-255,
1566        // quantized decoder uses whatever the model output. Thus the float
1567        // output is the same as the quantized output but scaled differently.
1568        // However, it is expected that the mask after argmax will be the same.
1569        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1570        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1571        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1572
1573        assert_eq!(mask0, mask1);
1574    }
1575
1576    #[test]
1577    fn test_dequant_chunked() {
1578        let mut out = load_yolov8s_det().into_raw_vec_and_offset().0;
1579        out.push(123); // make sure to test non multiple of 16 length
1580
1581        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1582        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1583        let quant = Quantization::new(0.0040811873, -123);
1584        dequantize_cpu(&out, quant, &mut out_dequant);
1585
1586        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1587        assert_eq!(out_dequant, out_dequant_simd);
1588
1589        let quant = Quantization::new(0.0040811873, 0);
1590        dequantize_cpu(&out, quant, &mut out_dequant);
1591
1592        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1593        assert_eq!(out_dequant, out_dequant_simd);
1594    }
1595
1596    #[test]
1597    fn test_dequant_ground_truth() {
1598        // Formula: output = (input - zero_point) * scale
1599        // Verify both dequantize_cpu and dequantize_cpu_chunked against hand-computed values.
1600
1601        // Case 1: scale=0.1, zero_point=-128 (from doc example)
1602        let quant = Quantization::new(0.1, -128);
1603        let input: Vec<i8> = vec![0, 127, -128, 64];
1604        let mut output = vec![0.0f32; 4];
1605        let mut output_chunked = vec![0.0f32; 4];
1606        dequantize_cpu(&input, quant, &mut output);
1607        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1608        // (0 - (-128)) * 0.1 = 12.8
1609        // (127 - (-128)) * 0.1 = 25.5
1610        // (-128 - (-128)) * 0.1 = 0.0
1611        // (64 - (-128)) * 0.1 = 19.2
1612        let expected: Vec<f32> = vec![12.8, 25.5, 0.0, 19.2];
1613        for (i, (&out, &exp)) in output.iter().zip(expected.iter()).enumerate() {
1614            assert!((out - exp).abs() < 1e-5, "cpu[{i}]: {out} != {exp}");
1615        }
1616        for (i, (&out, &exp)) in output_chunked.iter().zip(expected.iter()).enumerate() {
1617            assert!((out - exp).abs() < 1e-5, "chunked[{i}]: {out} != {exp}");
1618        }
1619
1620        // Case 2: scale=1.0, zero_point=0 (identity-like)
1621        let quant = Quantization::new(1.0, 0);
1622        dequantize_cpu(&input, quant, &mut output);
1623        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1624        let expected: Vec<f32> = vec![0.0, 127.0, -128.0, 64.0];
1625        assert_eq!(output, expected);
1626        assert_eq!(output_chunked, expected);
1627
1628        // Case 3: scale=0.5, zero_point=0
1629        let quant = Quantization::new(0.5, 0);
1630        dequantize_cpu(&input, quant, &mut output);
1631        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1632        let expected: Vec<f32> = vec![0.0, 63.5, -64.0, 32.0];
1633        assert_eq!(output, expected);
1634        assert_eq!(output_chunked, expected);
1635
1636        // Case 4: i8 min/max boundaries with typical quantization params
1637        let quant = Quantization::new(0.021287762, 31);
1638        let input: Vec<i8> = vec![-128, -1, 0, 1, 31, 127];
1639        let mut output = vec![0.0f32; 6];
1640        let mut output_chunked = vec![0.0f32; 6];
1641        dequantize_cpu(&input, quant, &mut output);
1642        dequantize_cpu_chunked(&input, quant, &mut output_chunked);
1643        for i in 0..6 {
1644            let expected = (input[i] as f32 - 31.0) * 0.021287762;
1645            assert!(
1646                (output[i] - expected).abs() < 1e-5,
1647                "cpu[{i}]: {} != {expected}",
1648                output[i]
1649            );
1650            assert!(
1651                (output_chunked[i] - expected).abs() < 1e-5,
1652                "chunked[{i}]: {} != {expected}",
1653                output_chunked[i]
1654            );
1655        }
1656    }
1657
1658    #[test]
1659    fn test_decoder_yolo_det() {
1660        let score_threshold = 0.25;
1661        let iou_threshold = 0.7;
1662        let out = load_yolov8s_det();
1663        let quant = (0.0040811873, -123).into();
1664
1665        let decoder = DecoderBuilder::default()
1666            .with_config_yolo_det(
1667                configs::Detection {
1668                    decoder: DecoderType::Ultralytics,
1669                    shape: vec![1, 84, 8400],
1670                    anchors: None,
1671                    quantization: Some(quant),
1672                    dshape: vec![
1673                        (DimName::Batch, 1),
1674                        (DimName::NumFeatures, 84),
1675                        (DimName::NumBoxes, 8400),
1676                    ],
1677                    normalized: Some(true),
1678                },
1679                Some(DecoderVersion::Yolo11),
1680            )
1681            .with_score_threshold(score_threshold)
1682            .with_iou_threshold(iou_threshold)
1683            .build()
1684            .unwrap();
1685
1686        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1687        decode_yolo_det(
1688            (out.slice(s![0, .., ..]), quant.into()),
1689            score_threshold,
1690            iou_threshold,
1691            Some(configs::Nms::ClassAgnostic),
1692            &mut output_boxes,
1693        );
1694        assert!(output_boxes[0].equal_within_delta(
1695            &DetectBox {
1696                bbox: BoundingBox {
1697                    xmin: 0.5285137,
1698                    ymin: 0.05305544,
1699                    xmax: 0.87541467,
1700                    ymax: 0.9998909,
1701                },
1702                score: 0.5591227,
1703                label: 0
1704            },
1705            1e-6
1706        ));
1707
1708        assert!(output_boxes[1].equal_within_delta(
1709            &DetectBox {
1710                bbox: BoundingBox {
1711                    xmin: 0.130598,
1712                    ymin: 0.43260583,
1713                    xmax: 0.35098213,
1714                    ymax: 0.9958097,
1715                },
1716                score: 0.33057618,
1717                label: 75
1718            },
1719            1e-6
1720        ));
1721
1722        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1723        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1724        decoder
1725            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1726            .unwrap();
1727
1728        let out = dequantize_ndarray(out.view(), quant.into());
1729        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1730        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1731        decoder
1732            .decode_float::<f32>(
1733                &[out.view().into_dyn()],
1734                &mut output_boxes_f32,
1735                &mut output_masks_f32,
1736            )
1737            .unwrap();
1738
1739        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1740        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1741    }
1742
1743    #[test]
1744    fn test_decoder_masks() {
1745        let score_threshold = 0.45;
1746        let iou_threshold = 0.45;
1747        let boxes = load_yolov8_boxes();
1748        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1749
1750        let protos = load_yolov8_protos();
1751        let quant_protos = Quantization::new(0.02491161972284317, -117);
1752        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1753        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1754        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1755        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1756        decode_yolo_segdet_float(
1757            seg.slice(s![0, .., ..]),
1758            protos.slice(s![0, .., .., ..]),
1759            score_threshold,
1760            iou_threshold,
1761            Some(configs::Nms::ClassAgnostic),
1762            &mut output_boxes,
1763            &mut output_masks,
1764        )
1765        .unwrap();
1766        assert_eq!(output_boxes.len(), 2);
1767        assert_eq!(output_boxes.len(), output_masks.len());
1768
1769        for (b, m) in output_boxes.iter().zip(&output_masks) {
1770            // Mask region is the proto-grid-aligned crop (floor for min,
1771            // ceil for max), so it encloses the post-NMS bbox (EDGEAI-1304).
1772            assert!(b.bbox.xmin >= m.xmin);
1773            assert!(b.bbox.ymin >= m.ymin);
1774            assert!(b.bbox.xmax <= m.xmax);
1775            assert!(b.bbox.ymax <= m.ymax);
1776        }
1777        assert!(output_boxes[0].equal_within_delta(
1778            &DetectBox {
1779                bbox: BoundingBox {
1780                    xmin: 0.08515105,
1781                    ymin: 0.7131401,
1782                    xmax: 0.29802868,
1783                    ymax: 0.8195788,
1784                },
1785                score: 0.91537374,
1786                label: 23
1787            },
1788            1.0 / 160.0, // wider range because mask will expand the box
1789        ));
1790
1791        assert!(output_boxes[1].equal_within_delta(
1792            &DetectBox {
1793                bbox: BoundingBox {
1794                    xmin: 0.59605736,
1795                    ymin: 0.25545314,
1796                    xmax: 0.93666154,
1797                    ymax: 0.72378385,
1798                },
1799                score: 0.91537374,
1800                label: 23
1801            },
1802            1.0 / 160.0, // wider range because mask will expand the box
1803        ));
1804
1805        let full_mask = edgefirst_bench::testdata::read("yolov8_mask_results.bin");
1806        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1807
1808        let cropped_mask = full_mask.slice(ndarray::s![
1809            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1810            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1811        ]);
1812
1813        assert_eq!(
1814            cropped_mask,
1815            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1816        );
1817    }
1818
1819    /// Regression test: config-driven path with physically-NCHW protos.
1820    /// Simulates YOLOv8-seg ONNX outputs where the producer emits protos
1821    /// as `(1, 32, 160, 160)` in CHW memory order — the caller declares
1822    /// shape and dshape matching that physical order and HAL permutes to
1823    /// canonical HWC via `swap_axes_if_needed`.
1824    ///
1825    /// This is the counterpart to the NHWC-producer case (TFLite /
1826    /// Ara-2) where shape+dshape are `(1, 160, 160, 32)` +
1827    /// `[batch, height, width, num_protos]` and no reordering is needed.
1828    #[test]
1829    fn test_decoder_masks_nchw_protos() {
1830        let score_threshold = 0.45;
1831        let iou_threshold = 0.45;
1832
1833        // Load test data — boxes as [116, 8400]
1834        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
1835        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1836
1837        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1838        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
1839        let quant_protos = Quantization::new(0.02491161972284317, -117);
1840        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1841
1842        // ---- Reference: direct call with HWC protos (known working) ----
1843        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1844        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1845        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1846        decode_yolo_segdet_float(
1847            seg.view(),
1848            protos_f32_hwc.view(),
1849            score_threshold,
1850            iou_threshold,
1851            Some(configs::Nms::ClassAgnostic),
1852            &mut ref_boxes,
1853            &mut ref_masks,
1854        )
1855        .unwrap();
1856        assert_eq!(ref_boxes.len(), 2);
1857
1858        // ---- Config-driven path: NCHW protos declared in physical order ----
1859        // Permute the HWC test data to CHW memory order — this is what an
1860        // ONNX-style producer would emit into the tensor buffer.
1861        // `to_owned` materialises a C-contiguous Array3<f32> with CHW
1862        // strides, matching a producer that writes channels-outer.
1863        let protos_f32_chw_view = protos_f32_hwc.view().permuted_axes([2, 0, 1]); // [32, 160, 160]
1864        let protos_f32_chw = protos_f32_chw_view.to_owned();
1865        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1866
1867        // Build boxes as [1, 116, 8400] f32
1868        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1869
1870        // Declare shape and dshape in physical memory order (outermost
1871        // first). `swap_axes_if_needed` uses dshape role lookup to
1872        // permute the stride tuple into canonical HWC.
1873        let decoder = DecoderBuilder::default()
1874            .with_config_yolo_segdet(
1875                configs::Detection {
1876                    decoder: configs::DecoderType::Ultralytics,
1877                    quantization: None,
1878                    shape: vec![1, 116, 8400],
1879                    dshape: vec![
1880                        (configs::DimName::Batch, 1),
1881                        (configs::DimName::NumFeatures, 116),
1882                        (configs::DimName::NumBoxes, 8400),
1883                    ],
1884                    normalized: Some(true),
1885                    anchors: None,
1886                },
1887                configs::Protos {
1888                    decoder: configs::DecoderType::Ultralytics,
1889                    quantization: None,
1890                    shape: vec![1, 32, 160, 160],
1891                    dshape: vec![
1892                        (configs::DimName::Batch, 1),
1893                        (configs::DimName::NumProtos, 32),
1894                        (configs::DimName::Height, 160),
1895                        (configs::DimName::Width, 160),
1896                    ],
1897                },
1898                None, // decoder version
1899            )
1900            .with_score_threshold(score_threshold)
1901            .with_iou_threshold(iou_threshold)
1902            .build()
1903            .unwrap();
1904
1905        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1906        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1907        decoder
1908            .decode_float(
1909                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1910                &mut cfg_boxes,
1911                &mut cfg_masks,
1912            )
1913            .unwrap();
1914
1915        // Must produce the same number of detections
1916        assert_eq!(
1917            cfg_boxes.len(),
1918            ref_boxes.len(),
1919            "config path produced {} boxes, reference produced {}",
1920            cfg_boxes.len(),
1921            ref_boxes.len()
1922        );
1923
1924        // Boxes must match
1925        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1926            assert!(
1927                cb.equal_within_delta(rb, 0.01),
1928                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1929            );
1930        }
1931
1932        // Masks must match pixel-for-pixel
1933        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1934            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1935            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1936            assert_eq!(
1937                cm_arr, rm_arr,
1938                "mask {i} pixel mismatch between config-driven and reference paths"
1939            );
1940        }
1941    }
1942
1943    #[test]
1944    fn test_decoder_masks_i8() {
1945        let score_threshold = 0.45;
1946        let iou_threshold = 0.45;
1947        let boxes = load_yolov8_boxes();
1948        let quant_boxes = (0.021287761628627777, 31).into();
1949
1950        let protos = load_yolov8_protos();
1951        let quant_protos = (0.02491161972284317, -117).into();
1952        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1953        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1954
1955        let decoder = DecoderBuilder::default()
1956            .with_config_yolo_segdet(
1957                configs::Detection {
1958                    decoder: configs::DecoderType::Ultralytics,
1959                    quantization: Some(quant_boxes),
1960                    shape: vec![1, 116, 8400],
1961                    anchors: None,
1962                    dshape: vec![
1963                        (DimName::Batch, 1),
1964                        (DimName::NumFeatures, 116),
1965                        (DimName::NumBoxes, 8400),
1966                    ],
1967                    normalized: Some(true),
1968                },
1969                Protos {
1970                    decoder: configs::DecoderType::Ultralytics,
1971                    quantization: Some(quant_protos),
1972                    shape: vec![1, 160, 160, 32],
1973                    dshape: vec![
1974                        (DimName::Batch, 1),
1975                        (DimName::Height, 160),
1976                        (DimName::Width, 160),
1977                        (DimName::NumProtos, 32),
1978                    ],
1979                },
1980                Some(DecoderVersion::Yolo11),
1981            )
1982            .with_score_threshold(score_threshold)
1983            .with_iou_threshold(iou_threshold)
1984            .build()
1985            .unwrap();
1986
1987        let quant_boxes = quant_boxes.into();
1988        let quant_protos = quant_protos.into();
1989
1990        decode_yolo_segdet_quant(
1991            (boxes.slice(s![0, .., ..]), quant_boxes),
1992            (protos.slice(s![0, .., .., ..]), quant_protos),
1993            score_threshold,
1994            iou_threshold,
1995            Some(configs::Nms::ClassAgnostic),
1996            &mut output_boxes,
1997            &mut output_masks,
1998        )
1999        .unwrap();
2000
2001        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2002        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2003
2004        decoder
2005            .decode_quantized(
2006                &[boxes.view().into(), protos.view().into()],
2007                &mut output_boxes1,
2008                &mut output_masks1,
2009            )
2010            .unwrap();
2011
2012        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2013        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2014
2015        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2016        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2017        decode_yolo_segdet_float(
2018            seg.slice(s![0, .., ..]),
2019            protos.slice(s![0, .., .., ..]),
2020            score_threshold,
2021            iou_threshold,
2022            Some(configs::Nms::ClassAgnostic),
2023            &mut output_boxes_f32,
2024            &mut output_masks_f32,
2025        )
2026        .unwrap();
2027
2028        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
2029        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
2030
2031        decoder
2032            .decode_float(
2033                &[seg.view().into_dyn(), protos.view().into_dyn()],
2034                &mut output_boxes1_f32,
2035                &mut output_masks1_f32,
2036            )
2037            .unwrap();
2038
2039        compare_outputs(
2040            (&output_boxes, &output_boxes1),
2041            (&output_masks, &output_masks1),
2042        );
2043
2044        compare_outputs(
2045            (&output_boxes, &output_boxes_f32),
2046            (&output_masks, &output_masks_f32),
2047        );
2048
2049        compare_outputs(
2050            (&output_boxes_f32, &output_boxes1_f32),
2051            (&output_masks_f32, &output_masks1_f32),
2052        );
2053    }
2054
2055    #[test]
2056    fn test_decoder_yolo_split() {
2057        let score_threshold = 0.45;
2058        let iou_threshold = 0.45;
2059        let boxes = load_yolov8_boxes();
2060        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2061        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2062
2063        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2064
2065        let decoder = DecoderBuilder::default()
2066            .with_config_yolo_split_det(
2067                configs::Boxes {
2068                    decoder: configs::DecoderType::Ultralytics,
2069                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2070                    shape: vec![1, 4, 8400],
2071                    dshape: vec![
2072                        (DimName::Batch, 1),
2073                        (DimName::BoxCoords, 4),
2074                        (DimName::NumBoxes, 8400),
2075                    ],
2076                    normalized: Some(true),
2077                },
2078                configs::Scores {
2079                    decoder: configs::DecoderType::Ultralytics,
2080                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2081                    shape: vec![1, 80, 8400],
2082                    dshape: vec![
2083                        (DimName::Batch, 1),
2084                        (DimName::NumClasses, 80),
2085                        (DimName::NumBoxes, 8400),
2086                    ],
2087                },
2088            )
2089            .with_score_threshold(score_threshold)
2090            .with_iou_threshold(iou_threshold)
2091            .build()
2092            .unwrap();
2093
2094        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2095        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2096
2097        decoder
2098            .decode_quantized(
2099                &[
2100                    boxes.slice(s![.., ..4, ..]).into(),
2101                    boxes.slice(s![.., 4..84, ..]).into(),
2102                ],
2103                &mut output_boxes,
2104                &mut output_masks,
2105            )
2106            .unwrap();
2107
2108        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2109        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2110        decode_yolo_det_float(
2111            seg.slice(s![0, ..84, ..]),
2112            score_threshold,
2113            iou_threshold,
2114            Some(configs::Nms::ClassAgnostic),
2115            &mut output_boxes_f32,
2116        );
2117
2118        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2119        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2120
2121        decoder
2122            .decode_float(
2123                &[
2124                    seg.slice(s![.., ..4, ..]).into_dyn(),
2125                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2126                ],
2127                &mut output_boxes1,
2128                &mut output_masks1,
2129            )
2130            .unwrap();
2131        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2132        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2133    }
2134
2135    #[test]
2136    fn test_decoder_masks_config_mixed() {
2137        let score_threshold = 0.45;
2138        let iou_threshold = 0.45;
2139        let boxes_raw = load_yolov8_boxes();
2140        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i16 * 256).collect();
2141        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2142
2143        let quant_boxes = (0.021287761628627777 / 256.0, 31 * 256);
2144
2145        let protos = load_yolov8_protos();
2146        let quant_protos = (0.02491161972284317, -117);
2147
2148        let decoder = build_yolo_split_segdet_decoder(
2149            score_threshold,
2150            iou_threshold,
2151            quant_boxes,
2152            quant_protos,
2153        );
2154        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2155        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2156
2157        decoder
2158            .decode_quantized(
2159                &[
2160                    boxes.slice(s![.., ..4, ..]).into(),
2161                    boxes.slice(s![.., 4..84, ..]).into(),
2162                    boxes.slice(s![.., 84.., ..]).into(),
2163                    protos.view().into(),
2164                ],
2165                &mut output_boxes,
2166                &mut output_masks,
2167            )
2168            .unwrap();
2169
2170        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2171        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2172        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2173        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2174        decode_yolo_segdet_float(
2175            seg.slice(s![0, .., ..]),
2176            protos.slice(s![0, .., .., ..]),
2177            score_threshold,
2178            iou_threshold,
2179            Some(configs::Nms::ClassAgnostic),
2180            &mut output_boxes_f32,
2181            &mut output_masks_f32,
2182        )
2183        .unwrap();
2184
2185        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2186        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2187
2188        decoder
2189            .decode_float(
2190                &[
2191                    seg.slice(s![.., ..4, ..]).into_dyn(),
2192                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2193                    seg.slice(s![.., 84.., ..]).into_dyn(),
2194                    protos.view().into_dyn(),
2195                ],
2196                &mut output_boxes1,
2197                &mut output_masks1,
2198            )
2199            .unwrap();
2200        compare_outputs(
2201            (&output_boxes, &output_boxes_f32),
2202            (&output_masks, &output_masks_f32),
2203        );
2204        compare_outputs(
2205            (&output_boxes_f32, &output_boxes1),
2206            (&output_masks_f32, &output_masks1),
2207        );
2208    }
2209
2210    fn build_yolo_split_segdet_decoder(
2211        score_threshold: f32,
2212        iou_threshold: f32,
2213        quant_boxes: (f32, i32),
2214        quant_protos: (f32, i32),
2215    ) -> crate::Decoder {
2216        DecoderBuilder::default()
2217            .with_config_yolo_split_segdet(
2218                configs::Boxes {
2219                    decoder: configs::DecoderType::Ultralytics,
2220                    quantization: Some(quant_boxes.into()),
2221                    shape: vec![1, 4, 8400],
2222                    dshape: vec![
2223                        (DimName::Batch, 1),
2224                        (DimName::BoxCoords, 4),
2225                        (DimName::NumBoxes, 8400),
2226                    ],
2227                    normalized: Some(true),
2228                },
2229                configs::Scores {
2230                    decoder: configs::DecoderType::Ultralytics,
2231                    quantization: Some(quant_boxes.into()),
2232                    shape: vec![1, 80, 8400],
2233                    dshape: vec![
2234                        (DimName::Batch, 1),
2235                        (DimName::NumClasses, 80),
2236                        (DimName::NumBoxes, 8400),
2237                    ],
2238                },
2239                configs::MaskCoefficients {
2240                    decoder: configs::DecoderType::Ultralytics,
2241                    quantization: Some(quant_boxes.into()),
2242                    shape: vec![1, 32, 8400],
2243                    dshape: vec![
2244                        (DimName::Batch, 1),
2245                        (DimName::NumProtos, 32),
2246                        (DimName::NumBoxes, 8400),
2247                    ],
2248                },
2249                configs::Protos {
2250                    decoder: configs::DecoderType::Ultralytics,
2251                    quantization: Some(quant_protos.into()),
2252                    shape: vec![1, 160, 160, 32],
2253                    dshape: vec![
2254                        (DimName::Batch, 1),
2255                        (DimName::Height, 160),
2256                        (DimName::Width, 160),
2257                        (DimName::NumProtos, 32),
2258                    ],
2259                },
2260            )
2261            .with_score_threshold(score_threshold)
2262            .with_iou_threshold(iou_threshold)
2263            .build()
2264            .unwrap()
2265    }
2266
2267    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
2268        let config_yaml = edgefirst_bench::testdata::read_to_string("yolov8_seg.yaml");
2269        DecoderBuilder::default()
2270            .with_config_yaml_str(config_yaml.to_string())
2271            .with_score_threshold(score_threshold)
2272            .with_iou_threshold(iou_threshold)
2273            .build()
2274            .unwrap()
2275    }
2276    #[test]
2277    fn test_decoder_masks_config_i32() {
2278        let score_threshold = 0.45;
2279        let iou_threshold = 0.45;
2280        let boxes_raw = load_yolov8_boxes();
2281        let scale = 1 << 23;
2282        let boxes: Vec<_> = boxes_raw.iter().map(|x| *x as i32 * scale).collect();
2283        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2284
2285        let quant_boxes = (0.021287761628627777 / scale as f32, 31 * scale);
2286
2287        let protos_raw = load_yolov8_protos();
2288        let protos: Vec<_> = protos_raw.iter().map(|x| *x as i32 * scale).collect();
2289        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
2290        let quant_protos = (0.02491161972284317 / scale as f32, -117 * scale);
2291
2292        let decoder = build_yolo_split_segdet_decoder(
2293            score_threshold,
2294            iou_threshold,
2295            quant_boxes,
2296            quant_protos,
2297        );
2298
2299        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2300        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2301
2302        decoder
2303            .decode_quantized(
2304                &[
2305                    boxes.slice(s![.., ..4, ..]).into(),
2306                    boxes.slice(s![.., 4..84, ..]).into(),
2307                    boxes.slice(s![.., 84.., ..]).into(),
2308                    protos.view().into(),
2309                ],
2310                &mut output_boxes,
2311                &mut output_masks,
2312            )
2313            .unwrap();
2314
2315        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos.into());
2316        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes.into());
2317        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2318        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2319        decode_yolo_segdet_float(
2320            seg.slice(s![0, .., ..]),
2321            protos.slice(s![0, .., .., ..]),
2322            score_threshold,
2323            iou_threshold,
2324            Some(configs::Nms::ClassAgnostic),
2325            &mut output_boxes_f32,
2326            &mut output_masks_f32,
2327        )
2328        .unwrap();
2329
2330        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2331        assert_eq!(output_masks.len(), output_masks_f32.len());
2332
2333        compare_outputs(
2334            (&output_boxes, &output_boxes_f32),
2335            (&output_masks, &output_masks_f32),
2336        );
2337    }
2338
2339    /// test running multiple decoders concurrently
2340    #[test]
2341    fn test_context_switch() {
2342        let yolo_det = || {
2343            let score_threshold = 0.25;
2344            let iou_threshold = 0.7;
2345            let out = load_yolov8s_det();
2346            let quant = (0.0040811873, -123).into();
2347
2348            let decoder = DecoderBuilder::default()
2349                .with_config_yolo_det(
2350                    configs::Detection {
2351                        decoder: DecoderType::Ultralytics,
2352                        shape: vec![1, 84, 8400],
2353                        anchors: None,
2354                        quantization: Some(quant),
2355                        dshape: vec![
2356                            (DimName::Batch, 1),
2357                            (DimName::NumFeatures, 84),
2358                            (DimName::NumBoxes, 8400),
2359                        ],
2360                        normalized: None,
2361                    },
2362                    None,
2363                )
2364                .with_score_threshold(score_threshold)
2365                .with_iou_threshold(iou_threshold)
2366                .build()
2367                .unwrap();
2368
2369            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2370            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2371
2372            for _ in 0..100 {
2373                decoder
2374                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2375                    .unwrap();
2376
2377                assert!(output_boxes[0].equal_within_delta(
2378                    &DetectBox {
2379                        bbox: BoundingBox {
2380                            xmin: 0.5285137,
2381                            ymin: 0.05305544,
2382                            xmax: 0.87541467,
2383                            ymax: 0.9998909,
2384                        },
2385                        score: 0.5591227,
2386                        label: 0
2387                    },
2388                    1e-6
2389                ));
2390
2391                assert!(output_boxes[1].equal_within_delta(
2392                    &DetectBox {
2393                        bbox: BoundingBox {
2394                            xmin: 0.130598,
2395                            ymin: 0.43260583,
2396                            xmax: 0.35098213,
2397                            ymax: 0.9958097,
2398                        },
2399                        score: 0.33057618,
2400                        label: 75
2401                    },
2402                    1e-6
2403                ));
2404                assert!(output_masks.is_empty());
2405            }
2406        };
2407
2408        let modelpack_det_split = || {
2409            let score_threshold = 0.8;
2410            let iou_threshold = 0.5;
2411
2412            let seg = edgefirst_bench::testdata::read("modelpack_seg_2x160x160.bin");
2413            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2414
2415            let detect0 = edgefirst_bench::testdata::read("modelpack_split_9x15x18.bin");
2416            let detect0 =
2417                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2418
2419            let detect1 = edgefirst_bench::testdata::read("modelpack_split_17x30x18.bin");
2420            let detect1 =
2421                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2422
2423            let mut mask = seg.slice(s![0, .., .., ..]);
2424            mask.swap_axes(0, 1);
2425            mask.swap_axes(1, 2);
2426            let mask = [Segmentation {
2427                xmin: 0.0,
2428                ymin: 0.0,
2429                xmax: 1.0,
2430                ymax: 1.0,
2431                segmentation: mask.into_owned(),
2432            }];
2433            let correct_boxes = [DetectBox {
2434                bbox: BoundingBox {
2435                    xmin: 0.43171933,
2436                    ymin: 0.68243736,
2437                    xmax: 0.5626645,
2438                    ymax: 0.808863,
2439                },
2440                score: 0.99240804,
2441                label: 0,
2442            }];
2443
2444            let quant0 = (0.08547406643629074, 174).into();
2445            let quant1 = (0.09929127991199493, 183).into();
2446            let quant_seg = (1.0 / 255.0, 0).into();
2447
2448            let anchors0 = vec![
2449                [0.36666667461395264, 0.31481480598449707],
2450                [0.38749998807907104, 0.4740740656852722],
2451                [0.5333333611488342, 0.644444465637207],
2452            ];
2453            let anchors1 = vec![
2454                [0.13750000298023224, 0.2074074000120163],
2455                [0.2541666626930237, 0.21481481194496155],
2456                [0.23125000298023224, 0.35185185074806213],
2457            ];
2458
2459            let decoder = DecoderBuilder::default()
2460                .with_config_modelpack_segdet_split(
2461                    vec![
2462                        configs::Detection {
2463                            decoder: DecoderType::ModelPack,
2464                            shape: vec![1, 17, 30, 18],
2465                            anchors: Some(anchors1),
2466                            quantization: Some(quant1),
2467                            dshape: vec![
2468                                (DimName::Batch, 1),
2469                                (DimName::Height, 17),
2470                                (DimName::Width, 30),
2471                                (DimName::NumAnchorsXFeatures, 18),
2472                            ],
2473                            normalized: None,
2474                        },
2475                        configs::Detection {
2476                            decoder: DecoderType::ModelPack,
2477                            shape: vec![1, 9, 15, 18],
2478                            anchors: Some(anchors0),
2479                            quantization: Some(quant0),
2480                            dshape: vec![
2481                                (DimName::Batch, 1),
2482                                (DimName::Height, 9),
2483                                (DimName::Width, 15),
2484                                (DimName::NumAnchorsXFeatures, 18),
2485                            ],
2486                            normalized: None,
2487                        },
2488                    ],
2489                    configs::Segmentation {
2490                        decoder: DecoderType::ModelPack,
2491                        quantization: Some(quant_seg),
2492                        shape: vec![1, 2, 160, 160],
2493                        dshape: vec![
2494                            (DimName::Batch, 1),
2495                            (DimName::NumClasses, 2),
2496                            (DimName::Height, 160),
2497                            (DimName::Width, 160),
2498                        ],
2499                    },
2500                )
2501                .with_score_threshold(score_threshold)
2502                .with_iou_threshold(iou_threshold)
2503                .build()
2504                .unwrap();
2505            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2506            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2507
2508            for _ in 0..100 {
2509                decoder
2510                    .decode_quantized(
2511                        &[
2512                            detect0.view().into(),
2513                            detect1.view().into(),
2514                            seg.view().into(),
2515                        ],
2516                        &mut output_boxes,
2517                        &mut output_masks,
2518                    )
2519                    .unwrap();
2520
2521                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2522            }
2523        };
2524
2525        let handles = vec![
2526            std::thread::spawn(yolo_det),
2527            std::thread::spawn(modelpack_det_split),
2528            std::thread::spawn(yolo_det),
2529            std::thread::spawn(modelpack_det_split),
2530            std::thread::spawn(yolo_det),
2531            std::thread::spawn(modelpack_det_split),
2532            std::thread::spawn(yolo_det),
2533            std::thread::spawn(modelpack_det_split),
2534        ];
2535        for handle in handles {
2536            handle.join().unwrap();
2537        }
2538    }
2539
2540    #[test]
2541    fn test_ndarray_to_xyxy_float() {
2542        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2543        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2544        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2545
2546        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2547        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2548        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2549    }
2550
2551    #[test]
2552    fn test_class_aware_nms_float() {
2553        use crate::float::nms_class_aware_float;
2554
2555        // Create two overlapping boxes with different classes
2556        let boxes = vec![
2557            DetectBox {
2558                bbox: BoundingBox {
2559                    xmin: 0.0,
2560                    ymin: 0.0,
2561                    xmax: 0.5,
2562                    ymax: 0.5,
2563                },
2564                score: 0.9,
2565                label: 0, // class 0
2566            },
2567            DetectBox {
2568                bbox: BoundingBox {
2569                    xmin: 0.1,
2570                    ymin: 0.1,
2571                    xmax: 0.6,
2572                    ymax: 0.6,
2573                },
2574                score: 0.8,
2575                label: 1, // class 1 - different class
2576            },
2577        ];
2578
2579        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2580        // threshold 0.3)
2581        let result = nms_class_aware_float(0.3, None, boxes.clone());
2582        assert_eq!(
2583            result.len(),
2584            2,
2585            "Class-aware NMS should keep both boxes with different classes"
2586        );
2587
2588        // Now test with same class - should suppress one
2589        let same_class_boxes = vec![
2590            DetectBox {
2591                bbox: BoundingBox {
2592                    xmin: 0.0,
2593                    ymin: 0.0,
2594                    xmax: 0.5,
2595                    ymax: 0.5,
2596                },
2597                score: 0.9,
2598                label: 0,
2599            },
2600            DetectBox {
2601                bbox: BoundingBox {
2602                    xmin: 0.1,
2603                    ymin: 0.1,
2604                    xmax: 0.6,
2605                    ymax: 0.6,
2606                },
2607                score: 0.8,
2608                label: 0, // same class
2609            },
2610        ];
2611
2612        let result = nms_class_aware_float(0.3, None, same_class_boxes);
2613        assert_eq!(
2614            result.len(),
2615            1,
2616            "Class-aware NMS should suppress overlapping box with same class"
2617        );
2618        assert_eq!(result[0].label, 0);
2619        assert!((result[0].score - 0.9).abs() < 1e-6);
2620    }
2621
2622    #[test]
2623    fn test_class_agnostic_vs_aware_nms() {
2624        use crate::float::{nms_class_aware_float, nms_float};
2625
2626        // Two overlapping boxes with different classes
2627        let boxes = vec![
2628            DetectBox {
2629                bbox: BoundingBox {
2630                    xmin: 0.0,
2631                    ymin: 0.0,
2632                    xmax: 0.5,
2633                    ymax: 0.5,
2634                },
2635                score: 0.9,
2636                label: 0,
2637            },
2638            DetectBox {
2639                bbox: BoundingBox {
2640                    xmin: 0.1,
2641                    ymin: 0.1,
2642                    xmax: 0.6,
2643                    ymax: 0.6,
2644                },
2645                score: 0.8,
2646                label: 1,
2647            },
2648        ];
2649
2650        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2651        let agnostic_result = nms_float(0.3, None, boxes.clone());
2652        assert_eq!(
2653            agnostic_result.len(),
2654            1,
2655            "Class-agnostic NMS should suppress overlapping boxes"
2656        );
2657
2658        // Class-aware should keep both (different classes)
2659        let aware_result = nms_class_aware_float(0.3, None, boxes);
2660        assert_eq!(
2661            aware_result.len(),
2662            2,
2663            "Class-aware NMS should keep boxes with different classes"
2664        );
2665    }
2666
2667    #[test]
2668    fn test_class_aware_nms_int() {
2669        use crate::byte::nms_class_aware_int;
2670
2671        // Create two overlapping boxes with different classes
2672        let boxes = vec![
2673            DetectBoxQuantized {
2674                bbox: BoundingBox {
2675                    xmin: 0.0,
2676                    ymin: 0.0,
2677                    xmax: 0.5,
2678                    ymax: 0.5,
2679                },
2680                score: 200_u8,
2681                label: 0,
2682            },
2683            DetectBoxQuantized {
2684                bbox: BoundingBox {
2685                    xmin: 0.1,
2686                    ymin: 0.1,
2687                    xmax: 0.6,
2688                    ymax: 0.6,
2689                },
2690                score: 180_u8,
2691                label: 1, // different class
2692            },
2693        ];
2694
2695        // Should keep both (different classes)
2696        let result = nms_class_aware_int(0.5, None, boxes);
2697        assert_eq!(
2698            result.len(),
2699            2,
2700            "Class-aware NMS (int) should keep boxes with different classes"
2701        );
2702    }
2703
2704    #[test]
2705    fn test_nms_enum_default() {
2706        // Test that Nms enum has the correct default
2707        let default_nms: configs::Nms = Default::default();
2708        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2709    }
2710
2711    #[test]
2712    fn test_decoder_nms_mode() {
2713        // Test that decoder properly stores NMS mode
2714        let decoder = DecoderBuilder::default()
2715            .with_config_yolo_det(
2716                configs::Detection {
2717                    anchors: None,
2718                    decoder: DecoderType::Ultralytics,
2719                    quantization: None,
2720                    shape: vec![1, 84, 8400],
2721                    dshape: Vec::new(),
2722                    normalized: Some(true),
2723                },
2724                None,
2725            )
2726            .with_nms(Some(configs::Nms::ClassAware))
2727            .build()
2728            .unwrap();
2729
2730        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2731    }
2732
2733    #[test]
2734    fn test_decoder_nms_bypass() {
2735        // Test that decoder can be configured with nms=None (bypass)
2736        let decoder = DecoderBuilder::default()
2737            .with_config_yolo_det(
2738                configs::Detection {
2739                    anchors: None,
2740                    decoder: DecoderType::Ultralytics,
2741                    quantization: None,
2742                    shape: vec![1, 84, 8400],
2743                    dshape: Vec::new(),
2744                    normalized: Some(true),
2745                },
2746                None,
2747            )
2748            .with_nms(None)
2749            .build()
2750            .unwrap();
2751
2752        assert_eq!(decoder.nms, None);
2753    }
2754
2755    #[test]
2756    fn test_decoder_normalized_boxes_true() {
2757        // Test that normalized_boxes returns Some(true) when explicitly set
2758        let decoder = DecoderBuilder::default()
2759            .with_config_yolo_det(
2760                configs::Detection {
2761                    anchors: None,
2762                    decoder: DecoderType::Ultralytics,
2763                    quantization: None,
2764                    shape: vec![1, 84, 8400],
2765                    dshape: Vec::new(),
2766                    normalized: Some(true),
2767                },
2768                None,
2769            )
2770            .build()
2771            .unwrap();
2772
2773        assert_eq!(decoder.normalized_boxes(), Some(true));
2774    }
2775
2776    #[test]
2777    fn test_decoder_normalized_boxes_false() {
2778        // Test that normalized_boxes returns Some(false) when config specifies
2779        // unnormalized
2780        let decoder = DecoderBuilder::default()
2781            .with_config_yolo_det(
2782                configs::Detection {
2783                    anchors: None,
2784                    decoder: DecoderType::Ultralytics,
2785                    quantization: None,
2786                    shape: vec![1, 84, 8400],
2787                    dshape: Vec::new(),
2788                    normalized: Some(false),
2789                },
2790                None,
2791            )
2792            .build()
2793            .unwrap();
2794
2795        assert_eq!(decoder.normalized_boxes(), Some(false));
2796    }
2797
2798    #[test]
2799    fn test_decoder_normalized_boxes_unknown() {
2800        // Test that normalized_boxes returns None when not specified in config
2801        let decoder = DecoderBuilder::default()
2802            .with_config_yolo_det(
2803                configs::Detection {
2804                    anchors: None,
2805                    decoder: DecoderType::Ultralytics,
2806                    quantization: None,
2807                    shape: vec![1, 84, 8400],
2808                    dshape: Vec::new(),
2809                    normalized: None,
2810                },
2811                Some(DecoderVersion::Yolo11),
2812            )
2813            .build()
2814            .unwrap();
2815
2816        assert_eq!(decoder.normalized_boxes(), None);
2817    }
2818
2819    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2820        input: ArrayView<F, D>,
2821        quant: Quantization,
2822    ) -> Array<T, D>
2823    where
2824        i32: num_traits::AsPrimitive<F>,
2825        f32: num_traits::AsPrimitive<F>,
2826    {
2827        let zero_point = quant.zero_point.as_();
2828        let div_scale = F::one() / quant.scale.as_();
2829        if zero_point != F::zero() {
2830            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2831        } else {
2832            input.mapv(|d| (d * div_scale).round().as_())
2833        }
2834    }
2835
2836    fn real_data_expected_boxes() -> [DetectBox; 2] {
2837        [
2838            DetectBox {
2839                bbox: BoundingBox {
2840                    xmin: 0.08515105,
2841                    ymin: 0.7131401,
2842                    xmax: 0.29802868,
2843                    ymax: 0.8195788,
2844                },
2845                score: 0.91537374,
2846                label: 23,
2847            },
2848            DetectBox {
2849                bbox: BoundingBox {
2850                    xmin: 0.59605736,
2851                    ymin: 0.25545314,
2852                    xmax: 0.93666154,
2853                    ymax: 0.72378385,
2854                },
2855                score: 0.91537374,
2856                label: 23,
2857            },
2858        ]
2859    }
2860
2861    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
2862        [DetectBox {
2863            bbox: BoundingBox {
2864                xmin: 0.12549022,
2865                ymin: 0.12549022,
2866                xmax: 0.23529413,
2867                ymax: 0.23529413,
2868            },
2869            score: 0.98823535,
2870            label: 2,
2871        }]
2872    }
2873
2874    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
2875        [DetectBox {
2876            bbox: BoundingBox {
2877                xmin: 0.1234,
2878                ymin: 0.1234,
2879                xmax: 0.2345,
2880                ymax: 0.2345,
2881            },
2882            score: 0.9876,
2883            label: 2,
2884        }]
2885    }
2886
    // Generates a `#[test]` named `$name` that decodes recorded YOLOv8
    // segmentation output (`yolov8_boxes_116x8400.bin` plus
    // `yolov8_protos_160x160x32.bin`) through the proto decoding path and
    // compares the two resulting boxes against `real_data_expected_boxes()`.
    //
    // Arguments:
    //   $name   - identifier of the generated test function
    //   2nd     - literal `quantized` (raw i8 tensors, `decode_quantized_proto`)
    //             or `float` (dequantized f32 tensors, `decode_float_proto`)
    //   $layout - `split` feeds separate boxes / scores / mask-coefficient
    //             tensors; any other identifier feeds the single combined
    //             detection tensor
    macro_rules! real_data_proto_test {
        ($name:ident, quantized, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
                // SAFETY: reinterprets the loaded buffer as `i8` with the same
                // element count; the slice is copied by `to_vec()` right below
                // while the shadowed owner is still alive.
                // NOTE(review): assumes `testdata::read` yields single-byte
                // elements (`u8`) — confirm.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();

                let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
                // SAFETY: same reinterpretation as for `raw_boxes` above; the
                // slice is copied by `to_vec()` before the owner goes away.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();

                // Pre-split (unused for combined, but harmless)
                // Rows 0..4 are the box values, 4..84 the scores, 84.. the
                // mask coefficients.
                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_i8;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
                    vec![
                        boxes_split.view().into(),
                        scores_split.view().into(),
                        mask_split.view().into(),
                        protos_i8.view().into(),
                    ]
                } else {
                    vec![boxes_combined.view().into(), protos_i8.view().into()]
                };
                decoder
                    .decode_quantized_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Boxes are compared with a tolerance of 1/160.
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
        ($name:ident, float, $layout:ident) => {
            #[test]
            fn $name() {
                let is_split = matches!(stringify!($layout), "split");

                let score_threshold = 0.45;
                let iou_threshold = 0.45;
                let quant_boxes = (0.021287762_f32, 31_i32);
                let quant_protos = (0.02491162_f32, -117_i32);

                let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
                // SAFETY: reinterprets the loaded buffer as `i8` with the same
                // element count; the slice is copied by `to_vec()` right below
                // while the shadowed owner is still alive.
                // NOTE(review): assumes `testdata::read` yields single-byte
                // elements (`u8`) — confirm.
                let raw_boxes = unsafe {
                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
                };
                let boxes_i8 =
                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
                let boxes_f32: Array3<f32> =
                    dequantize_ndarray(boxes_i8.view(), quant_boxes.into());

                let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
                // SAFETY: same reinterpretation as for `raw_boxes` above; the
                // slice is copied by `to_vec()` before the owner goes away.
                let raw_protos = unsafe {
                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
                };
                let protos_i8 =
                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
                        .unwrap();
                let protos_f32: Array4<f32> =
                    dequantize_ndarray(protos_i8.view(), quant_protos.into());

                // Pre-split from dequantized data
                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
                let scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
                let boxes_combined = boxes_f32;

                let decoder = if is_split {
                    build_yolo_split_segdet_decoder(
                        score_threshold,
                        iou_threshold,
                        quant_boxes,
                        quant_protos,
                    )
                } else {
                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
                };

                let expected = real_data_expected_boxes();
                let mut output_boxes = Vec::with_capacity(50);

                let inputs = if is_split {
                    vec![
                        boxes_split.view().into_dyn(),
                        scores_split.view().into_dyn(),
                        mask_split.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                } else {
                    vec![
                        boxes_combined.view().into_dyn(),
                        protos_f32.view().into_dyn(),
                    ]
                };
                decoder
                    .decode_float_proto(&inputs, &mut output_boxes)
                    .unwrap();

                // Boxes are compared with a tolerance of 1/160.
                assert_eq!(output_boxes.len(), 2);
                assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
                assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
            }
        };
    }
3024
    // Instantiate the real-data proto tests for every combination of
    // quantized/float input and combined/split tensor layout.
    real_data_proto_test!(test_decoder_segdet_proto, quantized, combined);
    real_data_proto_test!(test_decoder_segdet_proto_float, float, combined);
    real_data_proto_test!(test_decoder_segdet_split_proto, quantized, split);
    real_data_proto_test!(test_decoder_segdet_split_proto_float, float, split);
3029
    /// YAML decoder config (`decoder_version: yolo26`) with one combined
    /// `detection` output of shape `[1, 10, 6]`: 10 boxes with 6 features
    /// each, normalized coordinates.
    const E2E_COMBINED_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 6]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 6]
   normalized: true
";
3043
    /// YAML decoder config (`decoder_version: yolo26`) with one combined
    /// `detection` output of shape `[1, 10, 38]` and one `protos` output of
    /// shape `[1, 160, 160, 32]`.
    const E2E_COMBINED_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3066
    /// YAML decoder config (`decoder_version: yolo26`) with the detection
    /// split into three outputs: `boxes` `[1, 10, 4]`, `scores` `[1, 10, 1]`
    /// and `classes` `[1, 10, 1]`.
    const E2E_SPLIT_DET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
";
3096
    /// YAML decoder config (`decoder_version: yolo26`) with split detection
    /// and segmentation outputs: `boxes` `[1, 10, 4]`, `scores` `[1, 10, 1]`,
    /// `classes` `[1, 10, 1]`, `mask_coefficients` `[1, 10, 32]` and `protos`
    /// `[1, 160, 160, 32]`.
    const E2E_SPLIT_SEGDET_CONFIG: &str = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";
3143
3144    macro_rules! e2e_segdet_test {
3145        ($name:ident, quantized, $layout:ident, $output:ident) => {
3146            #[test]
3147            fn $name() {
3148                let is_split = matches!(stringify!($layout), "split");
3149                let is_proto = matches!(stringify!($output), "proto");
3150
3151                let score_threshold = 0.45;
3152                let iou_threshold = 0.45;
3153
3154                let mut boxes = Array2::zeros((10, 4));
3155                let mut scores = Array2::zeros((10, 1));
3156                let mut classes = Array2::zeros((10, 1));
3157                let mask = Array2::zeros((10, 32));
3158                let protos = Array3::<f64>::zeros((160, 160, 32));
3159                let protos = protos.insert_axis(Axis(0));
3160                let protos_quant = (1.0 / 255.0, 0.0);
3161                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
3162
3163                boxes
3164                    .slice_mut(s![0, ..])
3165                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3166                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3167                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3168
3169                let detect_quant = (2.0 / 255.0, 0.0);
3170
3171                let decoder = if is_split {
3172                    DecoderBuilder::default()
3173                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3174                        .with_score_threshold(score_threshold)
3175                        .with_iou_threshold(iou_threshold)
3176                        .build()
3177                        .unwrap()
3178                } else {
3179                    DecoderBuilder::default()
3180                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3181                        .with_score_threshold(score_threshold)
3182                        .with_iou_threshold(iou_threshold)
3183                        .build()
3184                        .unwrap()
3185                };
3186
3187                let expected = e2e_expected_boxes_quant();
3188                let mut output_boxes = Vec::with_capacity(50);
3189
3190                if is_split {
3191                    let boxes = boxes.insert_axis(Axis(0));
3192                    let scores = scores.insert_axis(Axis(0));
3193                    let classes = classes.insert_axis(Axis(0));
3194                    let mask = mask.insert_axis(Axis(0));
3195
3196                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
3197                    let scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
3198                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
3199                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
3200
3201                    if is_proto {
3202                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3203                            boxes.view().into(),
3204                            scores.view().into(),
3205                            classes.view().into(),
3206                            mask.view().into(),
3207                            protos.view().into(),
3208                        ];
3209                        decoder
3210                            .decode_quantized_proto(&inputs, &mut output_boxes)
3211                            .unwrap();
3212
3213                        assert_eq!(output_boxes.len(), 1);
3214                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3215                    } else {
3216                        let mut output_masks = Vec::with_capacity(50);
3217                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3218                            boxes.view().into(),
3219                            scores.view().into(),
3220                            classes.view().into(),
3221                            mask.view().into(),
3222                            protos.view().into(),
3223                        ];
3224                        decoder
3225                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3226                            .unwrap();
3227
3228                        assert_eq!(output_boxes.len(), 1);
3229                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3230                    }
3231                } else {
3232                    // Combined layout
3233                    let detect = ndarray::concatenate![
3234                        Axis(1),
3235                        boxes.view(),
3236                        scores.view(),
3237                        classes.view(),
3238                        mask.view()
3239                    ];
3240                    let detect = detect.insert_axis(Axis(0));
3241                    assert_eq!(detect.shape(), &[1, 10, 38]);
3242                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3243
3244                    if is_proto {
3245                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3246                            vec![detect.view().into(), protos.view().into()];
3247                        decoder
3248                            .decode_quantized_proto(&inputs, &mut output_boxes)
3249                            .unwrap();
3250
3251                        assert_eq!(output_boxes.len(), 1);
3252                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3253                    } else {
3254                        let mut output_masks = Vec::with_capacity(50);
3255                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3256                            vec![detect.view().into(), protos.view().into()];
3257                        decoder
3258                            .decode_quantized(&inputs, &mut output_boxes, &mut output_masks)
3259                            .unwrap();
3260
3261                        assert_eq!(output_boxes.len(), 1);
3262                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3263                    }
3264                }
3265            }
3266        };
3267        ($name:ident, float, $layout:ident, $output:ident) => {
3268            #[test]
3269            fn $name() {
3270                let is_split = matches!(stringify!($layout), "split");
3271                let is_proto = matches!(stringify!($output), "proto");
3272
3273                let score_threshold = 0.45;
3274                let iou_threshold = 0.45;
3275
3276                let mut boxes = Array2::zeros((10, 4));
3277                let mut scores = Array2::zeros((10, 1));
3278                let mut classes = Array2::zeros((10, 1));
3279                let mask: Array2<f64> = Array2::zeros((10, 32));
3280                let protos = Array3::<f64>::zeros((160, 160, 32));
3281                let protos = protos.insert_axis(Axis(0));
3282
3283                boxes
3284                    .slice_mut(s![0, ..])
3285                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3286                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
3287                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
3288
3289                let decoder = if is_split {
3290                    DecoderBuilder::default()
3291                        .with_config_yaml_str(E2E_SPLIT_SEGDET_CONFIG.to_string())
3292                        .with_score_threshold(score_threshold)
3293                        .with_iou_threshold(iou_threshold)
3294                        .build()
3295                        .unwrap()
3296                } else {
3297                    DecoderBuilder::default()
3298                        .with_config_yaml_str(E2E_COMBINED_SEGDET_CONFIG.to_string())
3299                        .with_score_threshold(score_threshold)
3300                        .with_iou_threshold(iou_threshold)
3301                        .build()
3302                        .unwrap()
3303                };
3304
3305                let expected = e2e_expected_boxes_float();
3306                let mut output_boxes = Vec::with_capacity(50);
3307
3308                if is_split {
3309                    let boxes = boxes.insert_axis(Axis(0));
3310                    let scores = scores.insert_axis(Axis(0));
3311                    let classes = classes.insert_axis(Axis(0));
3312                    let mask = mask.insert_axis(Axis(0));
3313
3314                    if is_proto {
3315                        let inputs = vec![
3316                            boxes.view().into_dyn(),
3317                            scores.view().into_dyn(),
3318                            classes.view().into_dyn(),
3319                            mask.view().into_dyn(),
3320                            protos.view().into_dyn(),
3321                        ];
3322                        decoder
3323                            .decode_float_proto(&inputs, &mut output_boxes)
3324                            .unwrap();
3325
3326                        assert_eq!(output_boxes.len(), 1);
3327                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3328                    } else {
3329                        let mut output_masks = Vec::with_capacity(50);
3330                        let inputs = vec![
3331                            boxes.view().into_dyn(),
3332                            scores.view().into_dyn(),
3333                            classes.view().into_dyn(),
3334                            mask.view().into_dyn(),
3335                            protos.view().into_dyn(),
3336                        ];
3337                        decoder
3338                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3339                            .unwrap();
3340
3341                        assert_eq!(output_boxes.len(), 1);
3342                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3343                    }
3344                } else {
3345                    // Combined layout
3346                    let detect = ndarray::concatenate![
3347                        Axis(1),
3348                        boxes.view(),
3349                        scores.view(),
3350                        classes.view(),
3351                        mask.view()
3352                    ];
3353                    let detect = detect.insert_axis(Axis(0));
3354                    assert_eq!(detect.shape(), &[1, 10, 38]);
3355
3356                    if is_proto {
3357                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3358                        decoder
3359                            .decode_float_proto(&inputs, &mut output_boxes)
3360                            .unwrap();
3361
3362                        assert_eq!(output_boxes.len(), 1);
3363                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3364                    } else {
3365                        let mut output_masks = Vec::with_capacity(50);
3366                        let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
3367                        decoder
3368                            .decode_float(&inputs, &mut output_boxes, &mut output_masks)
3369                            .unwrap();
3370
3371                        assert_eq!(output_boxes.len(), 1);
3372                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3373                    }
3374                }
3375            }
3376        };
3377    }
3378
    // Instantiate the segmentation + detection end-to-end tests. After the
    // test name, the arguments pick the element type (`quantized` / `float`),
    // the tensor layout (`split`; any other token is handled as combined) and
    // the output kind (`proto`; any other token is handled as decoded masks).
    e2e_segdet_test!(test_decoder_end_to_end_segdet, quantized, combined, masks);
    e2e_segdet_test!(test_decoder_end_to_end_segdet_float, float, combined, masks);
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto,
        quantized,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_proto_float,
        float,
        combined,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split,
        quantized,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_float,
        float,
        split,
        masks
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto,
        quantized,
        split,
        proto
    );
    e2e_segdet_test!(
        test_decoder_end_to_end_segdet_split_proto_float,
        float,
        split,
        proto
    );
3417
3418    macro_rules! e2e_det_test {
3419        ($name:ident, quantized, $layout:ident) => {
3420            #[test]
3421            fn $name() {
3422                let is_split = matches!(stringify!($layout), "split");
3423
3424                let score_threshold = 0.45;
3425                let iou_threshold = 0.45;
3426
3427                let mut boxes = Array3::zeros((1, 10, 4));
3428                let mut scores = Array3::zeros((1, 10, 1));
3429                let mut classes = Array3::zeros((1, 10, 1));
3430
3431                boxes
3432                    .slice_mut(s![0, 0, ..])
3433                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3434                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3435                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3436
3437                let detect_quant = (2.0 / 255.0, 0_i32);
3438
3439                let decoder = if is_split {
3440                    DecoderBuilder::default()
3441                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3442                        .with_score_threshold(score_threshold)
3443                        .with_iou_threshold(iou_threshold)
3444                        .build()
3445                        .unwrap()
3446                } else {
3447                    DecoderBuilder::default()
3448                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3449                        .with_score_threshold(score_threshold)
3450                        .with_iou_threshold(iou_threshold)
3451                        .build()
3452                        .unwrap()
3453                };
3454
3455                let expected = e2e_expected_boxes_quant();
3456                let mut output_boxes = Vec::with_capacity(50);
3457
3458                if is_split {
3459                    let boxes: Array<u8, _> = quantize_ndarray(boxes.view(), detect_quant.into());
3460                    let scores: Array<u8, _> = quantize_ndarray(scores.view(), detect_quant.into());
3461                    let classes: Array<u8, _> =
3462                        quantize_ndarray(classes.view(), detect_quant.into());
3463                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
3464                        boxes.view().into(),
3465                        scores.view().into(),
3466                        classes.view().into(),
3467                    ];
3468                    decoder
3469                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3470                        .unwrap();
3471                } else {
3472                    let detect =
3473                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3474                    assert_eq!(detect.shape(), &[1, 10, 6]);
3475                    let detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
3476                    let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
3477                        vec![detect.view().into()];
3478                    decoder
3479                        .decode_quantized(&inputs, &mut output_boxes, &mut Vec::new())
3480                        .unwrap();
3481                }
3482
3483                assert_eq!(output_boxes.len(), 1);
3484                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3485            }
3486        };
3487        ($name:ident, float, $layout:ident) => {
3488            #[test]
3489            fn $name() {
3490                let is_split = matches!(stringify!($layout), "split");
3491
3492                let score_threshold = 0.45;
3493                let iou_threshold = 0.45;
3494
3495                let mut boxes = Array3::zeros((1, 10, 4));
3496                let mut scores = Array3::zeros((1, 10, 1));
3497                let mut classes = Array3::zeros((1, 10, 1));
3498
3499                boxes
3500                    .slice_mut(s![0, 0, ..])
3501                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
3502                scores.slice_mut(s![0, 0, ..]).assign(&array![0.9876]);
3503                classes.slice_mut(s![0, 0, ..]).assign(&array![2.0]);
3504
3505                let decoder = if is_split {
3506                    DecoderBuilder::default()
3507                        .with_config_yaml_str(E2E_SPLIT_DET_CONFIG.to_string())
3508                        .with_score_threshold(score_threshold)
3509                        .with_iou_threshold(iou_threshold)
3510                        .build()
3511                        .unwrap()
3512                } else {
3513                    DecoderBuilder::default()
3514                        .with_config_yaml_str(E2E_COMBINED_DET_CONFIG.to_string())
3515                        .with_score_threshold(score_threshold)
3516                        .with_iou_threshold(iou_threshold)
3517                        .build()
3518                        .unwrap()
3519                };
3520
3521                let expected = e2e_expected_boxes_float();
3522                let mut output_boxes = Vec::with_capacity(50);
3523
3524                if is_split {
3525                    let inputs = vec![
3526                        boxes.view().into_dyn(),
3527                        scores.view().into_dyn(),
3528                        classes.view().into_dyn(),
3529                    ];
3530                    decoder
3531                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3532                        .unwrap();
3533                } else {
3534                    let detect =
3535                        ndarray::concatenate![Axis(2), boxes.view(), scores.view(), classes.view()];
3536                    assert_eq!(detect.shape(), &[1, 10, 6]);
3537                    let inputs = vec![detect.view().into_dyn()];
3538                    decoder
3539                        .decode_float(&inputs, &mut output_boxes, &mut Vec::new())
3540                        .unwrap();
3541                }
3542
3543                assert_eq!(output_boxes.len(), 1);
3544                assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
3545            }
3546        };
3547    }
3548
    // Instantiate the detection-only end-to-end tests for both element types
    // (`quantized` / `float`) and both tensor layouts (`combined` / `split`).
    e2e_det_test!(test_decoder_end_to_end_combined_det, quantized, combined);
    e2e_det_test!(test_decoder_end_to_end_combined_det_float, float, combined);
    e2e_det_test!(test_decoder_end_to_end_split_det, quantized, split);
    e2e_det_test!(test_decoder_end_to_end_split_det_float, float, split);
3553
3554    #[test]
3555    fn test_decode_tensor() {
3556        let score_threshold = 0.45;
3557        let iou_threshold = 0.45;
3558
3559        let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
3560        let raw_boxes =
3561            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3562        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3563        boxes_i8
3564            .map()
3565            .unwrap()
3566            .as_mut_slice()
3567            .copy_from_slice(raw_boxes);
3568        let boxes_i8 = boxes_i8.into();
3569
3570        let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
3571        let raw_protos = unsafe {
3572            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3573        };
3574        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3575        protos_i8
3576            .map()
3577            .unwrap()
3578            .as_mut_slice()
3579            .copy_from_slice(raw_protos);
3580        let protos_i8 = protos_i8.into();
3581
3582        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3583        let expected = real_data_expected_boxes();
3584        let mut output_boxes = Vec::with_capacity(50);
3585
3586        decoder
3587            .decode(&[&boxes_i8, &protos_i8], &mut output_boxes, &mut Vec::new())
3588            .unwrap();
3589
3590        assert_eq!(output_boxes.len(), 2);
3591        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3592        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3593    }
3594
3595    #[test]
3596    fn test_decode_tensor_f32() {
3597        let score_threshold = 0.45;
3598        let iou_threshold = 0.45;
3599
3600        let quant_boxes = (0.021287762_f32, 31_i32);
3601        let quant_protos = (0.02491162_f32, -117_i32);
3602        let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
3603        let raw_boxes =
3604            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3605        let mut raw_boxes_f32 = vec![0f32; raw_boxes.len()];
3606        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f32);
3607        let boxes_f32: Tensor<f32> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3608        boxes_f32
3609            .map()
3610            .unwrap()
3611            .as_mut_slice()
3612            .copy_from_slice(&raw_boxes_f32);
3613        let boxes_f32 = boxes_f32.into();
3614
3615        let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
3616        let raw_protos = unsafe {
3617            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3618        };
3619        let mut raw_protos_f32 = vec![0f32; raw_protos.len()];
3620        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f32);
3621        let protos_f32: Tensor<f32> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3622        protos_f32
3623            .map()
3624            .unwrap()
3625            .as_mut_slice()
3626            .copy_from_slice(&raw_protos_f32);
3627        let protos_f32 = protos_f32.into();
3628
3629        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3630
3631        let expected = real_data_expected_boxes();
3632        let mut output_boxes = Vec::with_capacity(50);
3633
3634        decoder
3635            .decode(
3636                &[&boxes_f32, &protos_f32],
3637                &mut output_boxes,
3638                &mut Vec::new(),
3639            )
3640            .unwrap();
3641
3642        assert_eq!(output_boxes.len(), 2);
3643        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3644        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3645    }
3646
3647    #[test]
3648    fn test_decode_tensor_f64() {
3649        let score_threshold = 0.45;
3650        let iou_threshold = 0.45;
3651
3652        let quant_boxes = (0.021287762_f32, 31_i32);
3653        let quant_protos = (0.02491162_f32, -117_i32);
3654        let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
3655        let raw_boxes =
3656            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3657        let mut raw_boxes_f64 = vec![0f64; raw_boxes.len()];
3658        dequantize_cpu(raw_boxes, quant_boxes.into(), &mut raw_boxes_f64);
3659        let boxes_f64: Tensor<f64> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3660        boxes_f64
3661            .map()
3662            .unwrap()
3663            .as_mut_slice()
3664            .copy_from_slice(&raw_boxes_f64);
3665        let boxes_f64 = boxes_f64.into();
3666
3667        let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
3668        let raw_protos = unsafe {
3669            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3670        };
3671        let mut raw_protos_f64 = vec![0f64; raw_protos.len()];
3672        dequantize_cpu(raw_protos, quant_protos.into(), &mut raw_protos_f64);
3673        let protos_f64: Tensor<f64> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3674        protos_f64
3675            .map()
3676            .unwrap()
3677            .as_mut_slice()
3678            .copy_from_slice(&raw_protos_f64);
3679        let protos_f64 = protos_f64.into();
3680
3681        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3682
3683        let expected = real_data_expected_boxes();
3684        let mut output_boxes = Vec::with_capacity(50);
3685
3686        decoder
3687            .decode(
3688                &[&boxes_f64, &protos_f64],
3689                &mut output_boxes,
3690                &mut Vec::new(),
3691            )
3692            .unwrap();
3693
3694        assert_eq!(output_boxes.len(), 2);
3695        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3696        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3697    }
3698
3699    #[test]
3700    fn test_decode_tensor_proto() {
3701        let score_threshold = 0.45;
3702        let iou_threshold = 0.45;
3703
3704        let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
3705        let raw_boxes =
3706            unsafe { std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len()) };
3707        let boxes_i8: Tensor<i8> = Tensor::new(&[1, 116, 8400], None, None).unwrap();
3708        boxes_i8
3709            .map()
3710            .unwrap()
3711            .as_mut_slice()
3712            .copy_from_slice(raw_boxes);
3713        let boxes_i8 = boxes_i8.into();
3714
3715        let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
3716        let raw_protos = unsafe {
3717            std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
3718        };
3719        let protos_i8: Tensor<i8> = Tensor::new(&[1, 160, 160, 32], None, None).unwrap();
3720        protos_i8
3721            .map()
3722            .unwrap()
3723            .as_mut_slice()
3724            .copy_from_slice(raw_protos);
3725        let protos_i8 = protos_i8.into();
3726
3727        let decoder = build_yolov8_seg_decoder(score_threshold, iou_threshold);
3728
3729        let expected = real_data_expected_boxes();
3730        let mut output_boxes = Vec::with_capacity(50);
3731
3732        let proto_data = decoder
3733            .decode_proto(&[&boxes_i8, &protos_i8], &mut output_boxes)
3734            .unwrap();
3735
3736        assert_eq!(output_boxes.len(), 2);
3737        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
3738        assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
3739
3740        let proto_data = proto_data.expect("segmentation model should return ProtoData");
3741        let coeffs_shape = proto_data.mask_coefficients.shape();
3742        assert_eq!(
3743            coeffs_shape[0],
3744            output_boxes.len(),
3745            "mask_coefficients count must match detection count"
3746        );
3747        assert_eq!(
3748            coeffs_shape[1], 32,
3749            "each detection should have 32 mask coefficients"
3750        );
3751    }
3752
3753    // =========================================================================
3754    // Physical-order contract regression tests
3755    //
3756    // These cover the TFLite VX-delegate vertical-stripe mask bug and the
3757    // Ara-2-via-overlay anchor-first split-tensor bug. The core claim:
3758    // declaring shape+dshape in physical memory order produces correct
3759    // strides at wrap time, and `swap_axes_if_needed` permutes the stride
3760    // tuple into canonical logical order without touching bytes.
3761    // =========================================================================
3762
3763    /// TFLite YOLOv8-seg protos arrive as physically-NHWC bytes with
3764    /// NNStreamer dim string `"32:160:160:1"` (innermost-first). The
3765    /// overlay's 2026-04-22 workaround declares `[1, 160, 160, 32]` with
3766    /// dshape `[batch, height, width, num_protos]` — shape matches
3767    /// physical order, so no permutation is needed and the view is
3768    /// already in canonical HWC after the batch-dim slice. This test
3769    /// confirms that path matches the reference (direct-HWC) decode.
3770    #[test]
3771    fn test_physical_order_tflite_nhwc_protos() {
3772        let score_threshold = 0.45;
3773        let iou_threshold = 0.45;
3774
3775        // Load HWC protos and dequantize to f32.
3776        let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
3777        let quant_protos = Quantization::new(0.02491161972284317, -117);
3778        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
3779
3780        let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
3781        let quant_boxes = Quantization::new(0.021287761628627777, 31);
3782        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
3783
3784        // Reference: direct call with canonical HWC protos.
3785        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
3786        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
3787        decode_yolo_segdet_float(
3788            seg.view(),
3789            protos_f32_hwc.view(),
3790            score_threshold,
3791            iou_threshold,
3792            Some(configs::Nms::ClassAgnostic),
3793            &mut ref_boxes,
3794            &mut ref_masks,
3795        )
3796        .unwrap();
3797
3798        // Build the same protos as a 4D NHWC tensor and feed it through
3799        // the config-driven path with dshape in physical order.
3800        let protos_nhwc = protos_f32_hwc.clone().insert_axis(Axis(0)); // [1, 160, 160, 32]
3801        let seg_3d = seg.insert_axis(Axis(0)); // [1, 116, 8400]
3802
3803        let decoder = DecoderBuilder::default()
3804            .with_config_yolo_segdet(
3805                configs::Detection {
3806                    decoder: configs::DecoderType::Ultralytics,
3807                    quantization: None,
3808                    shape: vec![1, 116, 8400],
3809                    dshape: vec![
3810                        (DimName::Batch, 1),
3811                        (DimName::NumFeatures, 116),
3812                        (DimName::NumBoxes, 8400),
3813                    ],
3814                    normalized: Some(true),
3815                    anchors: None,
3816                },
3817                configs::Protos {
3818                    decoder: configs::DecoderType::Ultralytics,
3819                    quantization: None,
3820                    shape: vec![1, 160, 160, 32],
3821                    // Physical NHWC — matches the TFLite buffer layout.
3822                    dshape: vec![
3823                        (DimName::Batch, 1),
3824                        (DimName::Height, 160),
3825                        (DimName::Width, 160),
3826                        (DimName::NumProtos, 32),
3827                    ],
3828                },
3829                None,
3830            )
3831            .with_score_threshold(score_threshold)
3832            .with_iou_threshold(iou_threshold)
3833            .build()
3834            .expect("config with NHWC protos dshape must build");
3835
3836        let mut cfg_boxes = Vec::with_capacity(10);
3837        let mut cfg_masks = Vec::with_capacity(10);
3838        decoder
3839            .decode_float(
3840                &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
3841                &mut cfg_boxes,
3842                &mut cfg_masks,
3843            )
3844            .unwrap();
3845
3846        assert_eq!(cfg_boxes.len(), ref_boxes.len(), "box count mismatch");
3847        for (c, r) in cfg_boxes.iter().zip(&ref_boxes) {
3848            assert!(
3849                c.equal_within_delta(r, 0.01),
3850                "NHWC-declared box does not match reference: {c:?} vs {r:?}"
3851            );
3852        }
3853        for (cm, rm) in cfg_masks.iter().zip(&ref_masks) {
3854            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
3855            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
3856            assert_eq!(
3857                cm_arr, rm_arr,
3858                "NHWC-declared mask must match reference pixel-for-pixel"
3859            );
3860        }
3861    }
3862
/// Ara-2 split-tensor boxes arrive physically anchor-first: NNStreamer
/// dim string `"4:1:8400:1"` means innermost axis is 4 (coords), next
/// is 1 (padding), next is 8400 (anchors), outermost is 1 (batch), i.e.
/// physical order `[1, 8400, 1, 4]`. This test declares the equivalent
/// 3D layout without the padding axis — shape `[1, 8400, 4]` with dshape
/// `[batch, num_boxes, box_coords]` — through the programmatic builder
/// and confirms it produces the same decoded boxes as the canonical
/// features-first declaration over the same logical data.
#[test]
fn test_physical_order_ara2_anchor_first_split_boxes() {
    use configs::{Boxes, Scores};

    const N: usize = 8400;
    const NUM_CLASSES: usize = 80;
    let target_anchor = 42usize;

    // Both decoders share thresholds and NMS mode; only the declared
    // tensor layouts differ between them.
    let build_decoder = |boxes: Boxes, scores: Scores, failure: &str| {
        DecoderBuilder::default()
            .with_config_yolo_split_det(boxes, scores)
            .with_score_threshold(0.5)
            .with_iou_threshold(0.5)
            .with_nms(Some(configs::Nms::ClassAgnostic))
            .build()
            .expect(failure)
    };

    // Synthetic boxes in features-first canonical layout: shape
    // [1, 4, 8400] with a single detection (xc, yc, w, h) at anchor 42.
    let mut boxes_canonical = Array3::<f32>::zeros((1, 4, N));
    for (coord, value) in [0.4_f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes_canonical[[0, coord, target_anchor]] = *value;
    }

    // Scores: one class at 0.9 for the same anchor.
    let mut scores_canonical = Array3::<f32>::zeros((1, NUM_CLASSES, N));
    scores_canonical[[0, 0, target_anchor]] = 0.9;

    // Reference: canonical (features-first) shape declaration.
    let ref_decoder = build_decoder(
        Boxes {
            decoder: configs::DecoderType::Ultralytics,
            quantization: None,
            shape: vec![1, 4, N],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::BoxCoords, 4),
                (DimName::NumBoxes, N),
            ],
            normalized: Some(true),
        },
        Scores {
            decoder: configs::DecoderType::Ultralytics,
            quantization: None,
            shape: vec![1, NUM_CLASSES, N],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumClasses, NUM_CLASSES),
                (DimName::NumBoxes, N),
            ],
        },
        "reference canonical split decoder must build",
    );

    let mut ref_boxes = Vec::with_capacity(4);
    let mut ref_masks = Vec::with_capacity(0);
    ref_decoder
        .decode_float(
            &[
                boxes_canonical.view().into_dyn(),
                scores_canonical.view().into_dyn(),
            ],
            &mut ref_boxes,
            &mut ref_masks,
        )
        .unwrap();
    assert_eq!(ref_boxes.len(), 1, "reference should produce one box");

    // Ara-2 physical layout: swap axes 1↔2 so the innermost axis is
    // BoxCoords / NumClasses. `to_owned()` on a permuted contiguous view
    // keeps the source's memory order, so materialise through
    // `as_standard_layout()` to obtain a C-contiguous buffer whose
    // strides really are anchor-first.
    let boxes_ara2 = boxes_canonical
        .view()
        .permuted_axes([0, 2, 1])
        .as_standard_layout()
        .into_owned(); // [1, 8400, 4]
    let scores_ara2 = scores_canonical
        .view()
        .permuted_axes([0, 2, 1])
        .as_standard_layout()
        .into_owned(); // [1, 8400, 80]
    assert!(
        boxes_ara2.is_standard_layout() && scores_ara2.is_standard_layout(),
        "anchor-first test tensors must be C-contiguous"
    );

    let ara2_decoder = build_decoder(
        Boxes {
            decoder: configs::DecoderType::Ultralytics,
            quantization: None,
            shape: vec![1, N, 4],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, N),
                (DimName::BoxCoords, 4),
            ],
            normalized: Some(true),
        },
        Scores {
            decoder: configs::DecoderType::Ultralytics,
            quantization: None,
            shape: vec![1, N, NUM_CLASSES],
            dshape: vec![
                (DimName::Batch, 1),
                (DimName::NumBoxes, N),
                (DimName::NumClasses, NUM_CLASSES),
            ],
        },
        "Ara-2 anchor-first decoder must build",
    );

    let mut ara2_boxes = Vec::with_capacity(4);
    let mut ara2_masks = Vec::with_capacity(0);
    ara2_decoder
        .decode_float(
            &[boxes_ara2.view().into_dyn(), scores_ara2.view().into_dyn()],
            &mut ara2_boxes,
            &mut ara2_masks,
        )
        .unwrap();

    assert_eq!(
        ara2_boxes.len(),
        ref_boxes.len(),
        "Ara-2 anchor-first declaration must produce the same number \
         of boxes as the canonical features-first reference"
    );
    for (a, r) in ara2_boxes.iter().zip(&ref_boxes) {
        assert!(
            a.equal_within_delta(r, 1e-4),
            "Ara-2 box differs from reference: {a:?} vs {r:?}"
        );
    }
}
3994
/// A dshape whose per-axis sizes disagree positionally with `shape` must
/// be refused at build time with a clear `InvalidConfig` error. This is
/// the failure mode behind the TFLite stripe bug: shape declared in one
/// axis order, dshape in another.
#[test]
fn test_physical_order_rejects_shape_dshape_mismatch() {
    // Detection output is internally consistent.
    let detection = configs::Detection {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 116, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
        anchors: None,
    };

    // Protos declare shape in NCHW order but dshape in NHWC order, so
    // the sizes do not line up (dshape[1]=160 vs shape[1]=32).
    let protos = configs::Protos {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 32, 160, 160],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    };

    let outcome = DecoderBuilder::default()
        .with_config_yolo_segdet(detection, protos, None)
        .build();

    let msg = match outcome {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("does not match shape"),
        "expected shape/dshape size mismatch error, got: {msg}"
    );
}
4042
/// A dshape that names the same dimension twice must be rejected: two
/// axes mapping to the same canonical slot would break the permutation.
#[test]
fn test_physical_order_rejects_duplicate_dshape_axis() {
    // The scores output is well-formed in both scenarios; only the boxes
    // declaration is malformed.
    let valid_scores = || configs::Scores {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 80, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::NumClasses, 80),
            (DimName::NumBoxes, 8400),
        ],
    };
    let try_build = |boxes: configs::Boxes| {
        DecoderBuilder::default()
            .with_config_yolo_split_det(boxes, valid_scores())
            .build()
    };

    // Case 1: `BoxCoords` repeated. The repeated entry has size 4 while
    // shape[2] is 8400, so the positional size check may fire before the
    // duplicate check. Either rejection is acceptable as long as it is a
    // clear InvalidConfig.
    let size_mismatched = try_build(configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::BoxCoords, 4),
            (DimName::BoxCoords, 4), // duplicate; size 4 != shape[2] = 8400
        ],
        normalized: Some(true),
    });
    let msg = match size_mismatched {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("appears at both index") || msg.contains("does not match shape"),
        "expected positional or duplicate-axis error, got: {msg}"
    );

    // Case 2: size-consistent duplicate, to exercise the duplicate-axis
    // path explicitly. `Batch` is repeated with size 1 against a shape
    // that carries two singleton dims, so only the duplicate check can
    // reject it.
    let size_consistent = try_build(configs::Boxes {
        decoder: configs::DecoderType::Ultralytics,
        quantization: None,
        shape: vec![1, 1, 4, 8400],
        dshape: vec![
            (DimName::Batch, 1),
            (DimName::Batch, 1), // duplicate, sizes match
            (DimName::BoxCoords, 4),
            (DimName::NumBoxes, 8400),
        ],
        normalized: Some(true),
    });
    let msg = match size_consistent {
        Err(DecoderError::InvalidConfig(msg)) => msg,
        other => panic!("expected InvalidConfig, got {other:?}"),
    };
    assert!(
        msg.contains("appears at both index"),
        "expected duplicate-axis error, got: {msg}"
    );
}
4127
/// Canonical (dshape-omitted) high-level decode produces the same
/// numeric result as the dshape-populated form. This closes a
/// coverage gap flagged during review: dshape omission had been
/// exercised only at the builder level, not through the full
/// `decode_float` pipeline.
///
/// Boxes are compared within a small tolerance; masks are compared
/// pixel-for-pixel, and the number of masks must match as well so a
/// missing mask on either side cannot slip through the pairwise check.
#[test]
fn test_physical_order_dshape_omitted_decodes_numerically() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;

    let protos_hwc = load_yolov8_protos().slice_move(s![0, .., .., ..]);
    let quant_protos = Quantization::new(0.02491161972284317, -117);
    let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);

    let boxes_2d = load_yolov8_boxes().slice_move(s![0, .., ..]);
    let quant_boxes = Quantization::new(0.021287761628627777, 31);
    let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);

    // Re-attach the batch axis the decoder configs below declare. The
    // dequantized arrays are not needed afterwards, so move them.
    let protos_nhwc = protos_f32_hwc.insert_axis(Axis(0));
    let seg_3d = seg.insert_axis(Axis(0));

    // Same shapes and thresholds for both decoders; only the dshape
    // declarations vary.
    let build_decoder = |det_dshape: Vec<(DimName, usize)>,
                         proto_dshape: Vec<(DimName, usize)>| {
        DecoderBuilder::default()
            .with_config_yolo_segdet(
                configs::Detection {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 116, 8400],
                    dshape: det_dshape,
                    normalized: Some(true),
                    anchors: None,
                },
                configs::Protos {
                    decoder: configs::DecoderType::Ultralytics,
                    quantization: None,
                    shape: vec![1, 160, 160, 32],
                    dshape: proto_dshape,
                },
                None,
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap()
    };

    // Baseline: dshape populated in physical order.
    let dshaped = build_decoder(
        vec![
            (DimName::Batch, 1),
            (DimName::NumFeatures, 116),
            (DimName::NumBoxes, 8400),
        ],
        vec![
            (DimName::Batch, 1),
            (DimName::Height, 160),
            (DimName::Width, 160),
            (DimName::NumProtos, 32),
        ],
    );
    let mut dshaped_boxes = Vec::new();
    let mut dshaped_masks = Vec::new();
    dshaped
        .decode_float(
            &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
            &mut dshaped_boxes,
            &mut dshaped_masks,
        )
        .unwrap();

    // Variant: dshape omitted — caller asserts shape is already in
    // the decoder's canonical order.
    let bare = build_decoder(vec![], vec![]);
    let mut bare_boxes = Vec::new();
    let mut bare_masks = Vec::new();
    bare.decode_float(
        &[seg_3d.view().into_dyn(), protos_nhwc.view().into_dyn()],
        &mut bare_boxes,
        &mut bare_masks,
    )
    .unwrap();

    assert_eq!(bare_boxes.len(), dshaped_boxes.len());
    for (b, d) in bare_boxes.iter().zip(&dshaped_boxes) {
        assert!(
            b.equal_within_delta(d, 1e-4),
            "dshape-omitted box {b:?} differs from dshape-populated {d:?}"
        );
    }

    // `zip` stops at the shorter side, so pin the counts first.
    assert_eq!(
        bare_masks.len(),
        dshaped_masks.len(),
        "dshape-omitted decode must yield the same number of masks as dshape-populated"
    );
    for (bm, dm) in bare_masks.iter().zip(&dshaped_masks) {
        let bm_arr = segmentation_to_mask(bm.segmentation.view()).unwrap();
        let dm_arr = segmentation_to_mask(dm.segmentation.view()).unwrap();
        assert_eq!(
            bm_arr, dm_arr,
            "dshape-omitted mask must match dshape-populated pixel-for-pixel"
        );
    }
}
4227
/// 4D anchor-first boxes schema with a `padding` axis — the Ara-2 DVM
/// shape (`"4:1:8400:1"` innermost-first = physical `[1, 8400, 1, 4]`).
/// Goes through the JSON schema path: the schema declares a 4D
/// physical-order layout including `padding`, while the tensors handed
/// to `decode_float` are the 3D squeezed views of the same data. Per the
/// original notes, HAL's `squeeze_padding_dims` drops the size-1 axis
/// during `to_legacy_config_outputs` before the decoder checks rank.
/// Complements the 3D `test_physical_order_ara2_anchor_first_split_boxes`
/// which exercises the programmatic-builder path.
#[test]
fn test_physical_order_ara2_4d_anchor_first_with_padding() {
    const N: usize = 8400;
    let target = 42usize;

    // Synthetic data in anchor-first 3D layout (the squeezed form the
    // caller is expected to present): one xywh box and one class score
    // seeded at `target`.
    let mut boxes = Array3::<f32>::zeros((1, N, 4));
    for (coord, value) in [0.4_f32, 0.5, 0.2, 0.2].iter().enumerate() {
        boxes[[0, target, coord]] = *value;
    }
    let mut scores = Array3::<f32>::zeros((1, N, 80));
    scores[[0, target, 0]] = 0.9;

    // Schema JSON declares shape+dshape in physical anchor-first order
    // including padding — mirrors `ara2_int8_edgefirst.json`'s
    // feature-first declaration but with the axes in physical
    // anchor-first order.
    let json = r#"{
          "schema_version": 2,
          "decoder_version": "yolov8",
          "nms": "class_agnostic",
          "outputs": [
            {"name": "boxes", "type": "boxes",
             "shape": [1, 8400, 1, 4],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"box_coords":4}],
             "encoding": "direct",
             "decoder": "ultralytics",
             "normalized": true},
            {"name": "scores", "type": "scores",
             "shape": [1, 8400, 1, 80],
             "dshape": [{"batch":1},{"num_boxes":8400},{"padding":1},{"num_classes":80}],
             "decoder": "ultralytics",
             "score_format": "per_class"}
          ]
        }"#;
    let decoder = DecoderBuilder::default()
        .with_config_json_str(json.to_string())
        .with_score_threshold(0.5)
        .with_iou_threshold(0.5)
        .build()
        .expect("4D anchor-first schema should build via squeeze_padding_dims");

    let mut out_boxes = Vec::with_capacity(4);
    let mut out_masks = Vec::with_capacity(0);
    let inputs = [boxes.view().into_dyn(), scores.view().into_dyn()];
    decoder
        .decode_float(&inputs, &mut out_boxes, &mut out_masks)
        .unwrap();

    assert_eq!(
        out_boxes.len(),
        1,
        "4D anchor-first with padding should decode exactly one box from the seeded anchor"
    );
    let b = &out_boxes[0];
    // xywh(0.4, 0.5, 0.2, 0.2) → xyxy(0.3, 0.4, 0.5, 0.6)
    for (edge, actual, expected) in [
        ("xmin", b.bbox.xmin, 0.3),
        ("ymin", b.bbox.ymin, 0.4),
        ("xmax", b.bbox.xmax, 0.5),
        ("ymax", b.bbox.ymax, 0.6),
    ] {
        assert!((actual - expected).abs() < 1e-3, "{edge} wrong: {b:?}");
    }
    assert_eq!(b.label, 0);
    assert!(b.score > 0.85, "score {}: {b:?}", b.score);
}
4307}
4308
4309#[cfg(feature = "tracker")]
4310#[cfg(test)]
4311#[cfg_attr(coverage_nightly, coverage(off))]
4312mod decoder_tracked_tests {
4313
4314    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
4315    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
4316    use num_traits::{AsPrimitive, Float, PrimInt};
4317    use rand::{RngExt, SeedableRng};
4318    use rand_distr::StandardNormal;
4319
4320    use crate::{
4321        configs::{self, DimName},
4322        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
4323    };
4324
4325    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
4326        input: ArrayView<F, D>,
4327        quant: Quantization,
4328    ) -> Array<T, D>
4329    where
4330        i32: num_traits::AsPrimitive<F>,
4331        f32: num_traits::AsPrimitive<F>,
4332    {
4333        let zero_point = quant.zero_point.as_();
4334        let div_scale = F::one() / quant.scale.as_();
4335        if zero_point != F::zero() {
4336            input.mapv(|d| (d * div_scale + zero_point).round().as_())
4337        } else {
4338            input.mapv(|d| (d * div_scale).round().as_())
4339        }
4340    }
4341
/// Decodes a quantized YOLOv8 detection tensor through the tracked
/// decoder, then replays it for 100 frames with small random jitter
/// applied to the first two feature rows (the X and Y coordinates).
/// Every frame must still yield the two expected detections, close to
/// the unjittered reference and close to the previous frame's output.
#[test]
fn test_decoder_tracked_random_jitter() {
    use crate::configs::{DecoderType, Nms};

    let score_threshold = 0.25;
    let iou_threshold = 0.1;

    // The fixture holds raw int8 model output; reinterpret each byte as a
    // signed value with a plain cast rather than an unsafe pointer cast.
    let raw = edgefirst_bench::testdata::read("yolov8s_80_classes.bin");
    let quantized: Vec<i8> = raw.iter().map(|&byte| byte as i8).collect();
    let out = Array3::from_shape_vec((1, 84, 8400), quantized).unwrap();
    let quant = (0.0040811873, -123).into();

    let decoder = DecoderBuilder::default()
        .with_config_yolo_det(
            configs::Detection {
                decoder: DecoderType::Ultralytics,
                shape: vec![1, 84, 8400],
                anchors: None,
                quantization: Some(quant),
                dshape: vec![
                    (DimName::Batch, 1),
                    (DimName::NumFeatures, 84),
                    (DimName::NumBoxes, 8400),
                ],
                normalized: Some(true),
            },
            None,
        )
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .with_nms(Some(Nms::ClassAgnostic))
        .build()
        .unwrap();
    let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility

    let expected_boxes = [
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.5285137,
                ymin: 0.05305544,
                xmax: 0.87541467,
                ymax: 0.9998909,
            },
            score: 0.5591227,
            label: 0,
        },
        DetectBox {
            bbox: BoundingBox {
                xmin: 0.130598,
                ymin: 0.43260583,
                xmax: 0.35098213,
                ymax: 0.9958097,
            },
            score: 0.33057618,
            label: 75,
        },
    ];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.3)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    // Frame 0: unmodified data must reproduce the reference detections.
    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[out.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 2);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));

    let mut last_boxes = output_boxes.clone();

    for i in 1..=100 {
        let mut frame = out.clone();
        // Introduce jitter into the X (row 0) and then Y (row 1)
        // coordinates to simulate movement and test tracking stability.
        // Rows are visited in this order so the RNG sequence is stable.
        for row in 0..2usize {
            let mut coords = frame.slice_mut(s![0, row, ..]);
            for value in coords.iter_mut() {
                let r: f32 = rng.sample(StandardNormal);
                let r = r.clamp(-2.0, 2.0) / 2.0;
                *value = value.saturating_add((r * 1e-2 / quant.0) as i8);
            }
        }

        decoder
            .decode_tracked_quantized(
                &mut tracker,
                100_000_000 * i / 3, // simulate 33.333ms between frames
                &[frame.view().into()],
                &mut output_boxes,
                &mut output_masks,
                &mut output_tracks,
            )
            .unwrap();

        assert_eq!(output_boxes.len(), 2);
        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));

        assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
        assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
        last_boxes = output_boxes.clone();
    }
}
4463
4464    // ─── Shared helpers for tracked decoder tests ────────────────────
4465
4466    fn real_data_expected_boxes() -> [DetectBox; 2] {
4467        [
4468            DetectBox {
4469                bbox: BoundingBox {
4470                    xmin: 0.08515105,
4471                    ymin: 0.7131401,
4472                    xmax: 0.29802868,
4473                    ymax: 0.8195788,
4474                },
4475                score: 0.91537374,
4476                label: 23,
4477            },
4478            DetectBox {
4479                bbox: BoundingBox {
4480                    xmin: 0.59605736,
4481                    ymin: 0.25545314,
4482                    xmax: 0.93666154,
4483                    ymax: 0.72378385,
4484                },
4485                score: 0.91537374,
4486                label: 23,
4487            },
4488        ]
4489    }
4490
4491    fn e2e_expected_boxes_quant() -> [DetectBox; 1] {
4492        [DetectBox {
4493            bbox: BoundingBox {
4494                xmin: 0.12549022,
4495                ymin: 0.12549022,
4496                xmax: 0.23529413,
4497                ymax: 0.23529413,
4498            },
4499            score: 0.98823535,
4500            label: 2,
4501        }]
4502    }
4503
4504    fn e2e_expected_boxes_float() -> [DetectBox; 1] {
4505        [DetectBox {
4506            bbox: BoundingBox {
4507                xmin: 0.1234,
4508                ymin: 0.1234,
4509                xmax: 0.2345,
4510                ymax: 0.2345,
4511            },
4512            score: 0.9876,
4513            label: 2,
4514        }]
4515    }
4516
4517    fn build_yolo_split_segdet_decoder(
4518        score_threshold: f32,
4519        iou_threshold: f32,
4520        quant_boxes: (f32, i32),
4521        quant_protos: (f32, i32),
4522    ) -> crate::Decoder {
4523        DecoderBuilder::default()
4524            .with_config_yolo_split_segdet(
4525                configs::Boxes {
4526                    decoder: configs::DecoderType::Ultralytics,
4527                    quantization: Some(quant_boxes.into()),
4528                    shape: vec![1, 4, 8400],
4529                    dshape: vec![
4530                        (DimName::Batch, 1),
4531                        (DimName::BoxCoords, 4),
4532                        (DimName::NumBoxes, 8400),
4533                    ],
4534                    normalized: Some(true),
4535                },
4536                configs::Scores {
4537                    decoder: configs::DecoderType::Ultralytics,
4538                    quantization: Some(quant_boxes.into()),
4539                    shape: vec![1, 80, 8400],
4540                    dshape: vec![
4541                        (DimName::Batch, 1),
4542                        (DimName::NumClasses, 80),
4543                        (DimName::NumBoxes, 8400),
4544                    ],
4545                },
4546                configs::MaskCoefficients {
4547                    decoder: configs::DecoderType::Ultralytics,
4548                    quantization: Some(quant_boxes.into()),
4549                    shape: vec![1, 32, 8400],
4550                    dshape: vec![
4551                        (DimName::Batch, 1),
4552                        (DimName::NumProtos, 32),
4553                        (DimName::NumBoxes, 8400),
4554                    ],
4555                },
4556                configs::Protos {
4557                    decoder: configs::DecoderType::Ultralytics,
4558                    quantization: Some(quant_protos.into()),
4559                    shape: vec![1, 160, 160, 32],
4560                    dshape: vec![
4561                        (DimName::Batch, 1),
4562                        (DimName::Height, 160),
4563                        (DimName::Width, 160),
4564                        (DimName::NumProtos, 32),
4565                    ],
4566                },
4567            )
4568            .with_score_threshold(score_threshold)
4569            .with_iou_threshold(iou_threshold)
4570            .build()
4571            .unwrap()
4572    }
4573
4574    fn build_yolov8_seg_decoder(score_threshold: f32, iou_threshold: f32) -> crate::Decoder {
4575        let config_yaml = edgefirst_bench::testdata::read_to_string("yolov8_seg.yaml");
4576        DecoderBuilder::default()
4577            .with_config_yaml_str(config_yaml.to_string())
4578            .with_score_threshold(score_threshold)
4579            .with_iou_threshold(iou_threshold)
4580            .build()
4581            .unwrap()
4582    }
4583
4584    // ─── Real-data tracked test macro ───────────────────────────────
4585    //
4586    // Generates tests that load i8 binary test data from testdata/ and
4587    // exercise all (quant/float) × (combined/split) × (masks/proto)
4588    // decoder paths.
4589
4590    macro_rules! real_data_tracked_test {
4591        ($name:ident, quantized, $layout:ident, $output:ident) => {
4592            #[test]
4593            fn $name() {
4594                let is_split = matches!(stringify!($layout), "split");
4595                let is_proto = matches!(stringify!($output), "proto");
4596
4597                let score_threshold = 0.45;
4598                let iou_threshold = 0.45;
4599                let quant_boxes = (0.021287762_f32, 31_i32);
4600                let quant_protos = (0.02491162_f32, -117_i32);
4601
4602                let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
4603                let raw_boxes = unsafe {
4604                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4605                };
4606                let boxes_i8 =
4607                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4608
4609                let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
4610                let raw_protos = unsafe {
4611                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4612                };
4613                let protos_i8 =
4614                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4615                        .unwrap();
4616
4617                // Pre-split (unused for combined, but harmless)
4618                let mask_split = boxes_i8.slice(s![.., 84.., ..]).to_owned();
4619                let mut scores_split = boxes_i8.slice(s![.., 4..84, ..]).to_owned();
4620                let boxes_split = boxes_i8.slice(s![.., ..4, ..]).to_owned();
4621                let mut boxes_combined = boxes_i8;
4622
4623                let decoder = if is_split {
4624                    build_yolo_split_segdet_decoder(
4625                        score_threshold,
4626                        iou_threshold,
4627                        quant_boxes,
4628                        quant_protos,
4629                    )
4630                } else {
4631                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4632                };
4633
4634                let expected = real_data_expected_boxes();
4635                let mut tracker = ByteTrackBuilder::new()
4636                    .track_update(0.1)
4637                    .track_high_conf(0.7)
4638                    .build();
4639                let mut output_boxes = Vec::with_capacity(50);
4640                let mut output_tracks = Vec::with_capacity(50);
4641
4642                // Frame 1: decode
4643                if is_proto {
4644                    {
4645                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4646                            vec![
4647                                boxes_split.view().into(),
4648                                scores_split.view().into(),
4649                                mask_split.view().into(),
4650                                protos_i8.view().into(),
4651                            ]
4652                        } else {
4653                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4654                        };
4655                        decoder
4656                            .decode_tracked_quantized_proto(
4657                                &mut tracker,
4658                                0,
4659                                &inputs,
4660                                &mut output_boxes,
4661                                &mut output_tracks,
4662                            )
4663                            .unwrap();
4664                    }
4665                    assert_eq!(output_boxes.len(), 2);
4666                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4667                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4668
4669                    // Zero scores for frame 2
4670                    if is_split {
4671                        for score in scores_split.iter_mut() {
4672                            *score = i8::MIN;
4673                        }
4674                    } else {
4675                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4676                            *score = i8::MIN;
4677                        }
4678                    }
4679
4680                    let proto_result = {
4681                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4682                            vec![
4683                                boxes_split.view().into(),
4684                                scores_split.view().into(),
4685                                mask_split.view().into(),
4686                                protos_i8.view().into(),
4687                            ]
4688                        } else {
4689                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4690                        };
4691                        decoder
4692                            .decode_tracked_quantized_proto(
4693                                &mut tracker,
4694                                100_000_000 / 3,
4695                                &inputs,
4696                                &mut output_boxes,
4697                                &mut output_tracks,
4698                            )
4699                            .unwrap()
4700                    };
4701                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4702                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4703                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4704                } else {
4705                    let mut output_masks = Vec::with_capacity(50);
4706                    {
4707                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4708                            vec![
4709                                boxes_split.view().into(),
4710                                scores_split.view().into(),
4711                                mask_split.view().into(),
4712                                protos_i8.view().into(),
4713                            ]
4714                        } else {
4715                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4716                        };
4717                        decoder
4718                            .decode_tracked_quantized(
4719                                &mut tracker,
4720                                0,
4721                                &inputs,
4722                                &mut output_boxes,
4723                                &mut output_masks,
4724                                &mut output_tracks,
4725                            )
4726                            .unwrap();
4727                    }
4728                    assert_eq!(output_boxes.len(), 2);
4729                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4730                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4731
4732                    if is_split {
4733                        for score in scores_split.iter_mut() {
4734                            *score = i8::MIN;
4735                        }
4736                    } else {
4737                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4738                            *score = i8::MIN;
4739                        }
4740                    }
4741
4742                    {
4743                        let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = if is_split {
4744                            vec![
4745                                boxes_split.view().into(),
4746                                scores_split.view().into(),
4747                                mask_split.view().into(),
4748                                protos_i8.view().into(),
4749                            ]
4750                        } else {
4751                            vec![boxes_combined.view().into(), protos_i8.view().into()]
4752                        };
4753                        decoder
4754                            .decode_tracked_quantized(
4755                                &mut tracker,
4756                                100_000_000 / 3,
4757                                &inputs,
4758                                &mut output_boxes,
4759                                &mut output_masks,
4760                                &mut output_tracks,
4761                            )
4762                            .unwrap();
4763                    }
4764                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4765                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4766                    assert!(output_masks.is_empty());
4767                }
4768            }
4769        };
4770        ($name:ident, float, $layout:ident, $output:ident) => {
4771            #[test]
4772            fn $name() {
4773                let is_split = matches!(stringify!($layout), "split");
4774                let is_proto = matches!(stringify!($output), "proto");
4775
4776                let score_threshold = 0.45;
4777                let iou_threshold = 0.45;
4778                let quant_boxes = (0.021287762_f32, 31_i32);
4779                let quant_protos = (0.02491162_f32, -117_i32);
4780
4781                let raw_boxes = edgefirst_bench::testdata::read("yolov8_boxes_116x8400.bin");
4782                let raw_boxes = unsafe {
4783                    std::slice::from_raw_parts(raw_boxes.as_ptr() as *const i8, raw_boxes.len())
4784                };
4785                let boxes_i8 =
4786                    ndarray::Array3::from_shape_vec((1, 116, 8400), raw_boxes.to_vec()).unwrap();
4787                let boxes_f32 = dequantize_ndarray(boxes_i8.view(), quant_boxes.into());
4788
4789                let raw_protos = edgefirst_bench::testdata::read("yolov8_protos_160x160x32.bin");
4790                let raw_protos = unsafe {
4791                    std::slice::from_raw_parts(raw_protos.as_ptr() as *const i8, raw_protos.len())
4792                };
4793                let protos_i8 =
4794                    ndarray::Array4::from_shape_vec((1, 160, 160, 32), raw_protos.to_vec())
4795                        .unwrap();
4796                let protos_f32 = dequantize_ndarray(protos_i8.view(), quant_protos.into());
4797
4798                // Pre-split from dequantized data
4799                let mask_split = boxes_f32.slice(s![.., 84.., ..]).to_owned();
4800                let mut scores_split = boxes_f32.slice(s![.., 4..84, ..]).to_owned();
4801                let boxes_split = boxes_f32.slice(s![.., ..4, ..]).to_owned();
4802                let mut boxes_combined = boxes_f32;
4803
4804                let decoder = if is_split {
4805                    build_yolo_split_segdet_decoder(
4806                        score_threshold,
4807                        iou_threshold,
4808                        quant_boxes,
4809                        quant_protos,
4810                    )
4811                } else {
4812                    build_yolov8_seg_decoder(score_threshold, iou_threshold)
4813                };
4814
4815                let expected = real_data_expected_boxes();
4816                let mut tracker = ByteTrackBuilder::new()
4817                    .track_update(0.1)
4818                    .track_high_conf(0.7)
4819                    .build();
4820                let mut output_boxes = Vec::with_capacity(50);
4821                let mut output_tracks = Vec::with_capacity(50);
4822
4823                if is_proto {
4824                    {
4825                        let inputs = if is_split {
4826                            vec![
4827                                boxes_split.view().into_dyn(),
4828                                scores_split.view().into_dyn(),
4829                                mask_split.view().into_dyn(),
4830                                protos_f32.view().into_dyn(),
4831                            ]
4832                        } else {
4833                            vec![
4834                                boxes_combined.view().into_dyn(),
4835                                protos_f32.view().into_dyn(),
4836                            ]
4837                        };
4838                        decoder
4839                            .decode_tracked_float_proto(
4840                                &mut tracker,
4841                                0,
4842                                &inputs,
4843                                &mut output_boxes,
4844                                &mut output_tracks,
4845                            )
4846                            .unwrap();
4847                    }
4848                    assert_eq!(output_boxes.len(), 2);
4849                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4850                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4851
4852                    if is_split {
4853                        for score in scores_split.iter_mut() {
4854                            *score = 0.0;
4855                        }
4856                    } else {
4857                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4858                            *score = 0.0;
4859                        }
4860                    }
4861
4862                    let proto_result = {
4863                        let inputs = if is_split {
4864                            vec![
4865                                boxes_split.view().into_dyn(),
4866                                scores_split.view().into_dyn(),
4867                                mask_split.view().into_dyn(),
4868                                protos_f32.view().into_dyn(),
4869                            ]
4870                        } else {
4871                            vec![
4872                                boxes_combined.view().into_dyn(),
4873                                protos_f32.view().into_dyn(),
4874                            ]
4875                        };
4876                        decoder
4877                            .decode_tracked_float_proto(
4878                                &mut tracker,
4879                                100_000_000 / 3,
4880                                &inputs,
4881                                &mut output_boxes,
4882                                &mut output_tracks,
4883                            )
4884                            .unwrap()
4885                    };
4886                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4887                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4888                    assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
4889                } else {
4890                    let mut output_masks = Vec::with_capacity(50);
4891                    {
4892                        let inputs = if is_split {
4893                            vec![
4894                                boxes_split.view().into_dyn(),
4895                                scores_split.view().into_dyn(),
4896                                mask_split.view().into_dyn(),
4897                                protos_f32.view().into_dyn(),
4898                            ]
4899                        } else {
4900                            vec![
4901                                boxes_combined.view().into_dyn(),
4902                                protos_f32.view().into_dyn(),
4903                            ]
4904                        };
4905                        decoder
4906                            .decode_tracked_float(
4907                                &mut tracker,
4908                                0,
4909                                &inputs,
4910                                &mut output_boxes,
4911                                &mut output_masks,
4912                                &mut output_tracks,
4913                            )
4914                            .unwrap();
4915                    }
4916                    assert_eq!(output_boxes.len(), 2);
4917                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
4918                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1.0 / 160.0));
4919
4920                    if is_split {
4921                        for score in scores_split.iter_mut() {
4922                            *score = 0.0;
4923                        }
4924                    } else {
4925                        for score in boxes_combined.slice_mut(s![0, 4..84, ..]).iter_mut() {
4926                            *score = 0.0;
4927                        }
4928                    }
4929
4930                    {
4931                        let inputs = if is_split {
4932                            vec![
4933                                boxes_split.view().into_dyn(),
4934                                scores_split.view().into_dyn(),
4935                                mask_split.view().into_dyn(),
4936                                protos_f32.view().into_dyn(),
4937                            ]
4938                        } else {
4939                            vec![
4940                                boxes_combined.view().into_dyn(),
4941                                protos_f32.view().into_dyn(),
4942                            ]
4943                        };
4944                        decoder
4945                            .decode_tracked_float(
4946                                &mut tracker,
4947                                100_000_000 / 3,
4948                                &inputs,
4949                                &mut output_boxes,
4950                                &mut output_masks,
4951                                &mut output_tracks,
4952                            )
4953                            .unwrap();
4954                    }
4955                    assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
4956                    assert!(output_boxes[1].equal_within_delta(&expected[1], 1e-6));
4957                    assert!(output_masks.is_empty());
4958                }
4959            }
4960        };
4961    }
4962
// Instantiate `real_data_tracked_test!` once for every combination of numeric
// path (`quantized` / `float`), tensor layout (`combined` / `split`) and output
// kind (`masks` / `proto`). The first argument becomes the name of the generated
// `#[test]` function.
//   * layout `split` feeds boxes, scores and mask coefficients as separate
//     tensors; any other layout feeds one combined detection tensor. The protos
//     tensor is always passed last.
//   * output `proto` drives the `decode_tracked_*_proto` entry points; any other
//     output drives `decode_tracked_*` together with an `output_masks` buffer.
4963    real_data_tracked_test!(test_decoder_tracked_segdet, quantized, combined, masks);
4964    real_data_tracked_test!(test_decoder_tracked_segdet_float, float, combined, masks);
4965    real_data_tracked_test!(
4966        test_decoder_tracked_segdet_proto,
4967        quantized,
4968        combined,
4969        proto
4970    );
4971    real_data_tracked_test!(
4972        test_decoder_tracked_segdet_proto_float,
4973        float,
4974        combined,
4975        proto
4976    );
4977    real_data_tracked_test!(test_decoder_tracked_segdet_split, quantized, split, masks);
4978    real_data_tracked_test!(test_decoder_tracked_segdet_split_float, float, split, masks);
4979    real_data_tracked_test!(
4980        test_decoder_tracked_segdet_split_proto,
4981        quantized,
4982        split,
4983        proto
4984    );
4985    real_data_tracked_test!(
4986        test_decoder_tracked_segdet_split_proto_float,
4987        float,
4988        split,
4989        proto
4990    );
4991
4992    // ─── End-to-end tracked test macro ──────────────────────────────
4993    //
4994    // Generates tests with synthetic data to exercise all tracked
4995    // decode paths without needing real model output files.
4996
// YAML decoder configuration used by the combined-layout `e2e_tracked_test!`
// cases; it is handed to `DecoderBuilder::with_config_yaml_str`.
// It declares two outputs:
//   * `detection` with shape [1, 10, 38]. The tests build this tensor by
//     concatenating 4 box, 1 score, 1 class and 32 mask columns for 10 boxes.
//   * `protos` with shape [1, 160, 160, 32].
// The detection scale 0.00784313725490196 is the `detect_quant` scale
// (2.0 / 255.0) that the quantized tests use when quantizing their inputs.
// NOTE(review): the quantized tests quantize the protos with zero point 0.0
// (`protos_quant`), while this config declares a zero point of 128. The protos
// are all zeros in those tests; confirm the mismatch is intentional.
4997    const E2E_COMBINED_CONFIG: &str = "
4998decoder_version: yolo26
4999outputs:
5000 - type: detection
5001   decoder: ultralytics
5002   quantization: [0.00784313725490196, 0]
5003   shape: [1, 10, 38]
5004   dshape:
5005    - [batch, 1]
5006    - [num_boxes, 10]
5007    - [num_features, 38]
5008   normalized: true
5009 - type: protos
5010   decoder: ultralytics
5011   quantization: [0.0039215686274509803921568627451, 128]
5012   shape: [1, 160, 160, 32]
5013   dshape:
5014    - [batch, 1]
5015    - [height, 160]
5016    - [width, 160]
5017    - [num_protos, 32]
5018";
5019
// YAML decoder configuration used by the split-layout `e2e_tracked_test!`
// cases; it is handed to `DecoderBuilder::with_config_yaml_str`.
// It declares five outputs, in the same order in which the tests pass their
// input tensors:
//   * `boxes`             [1, 10, 4]  (the only entry carrying `normalized: true`)
//   * `scores`            [1, 10, 1]
//   * `classes`           [1, 10, 1]
//   * `mask_coefficients` [1, 10, 32]
//   * `protos`            [1, 160, 160, 32]
// Every non-proto output shares the scale 0.00784313725490196, i.e. the
// `detect_quant` scale (2.0 / 255.0) used by the quantized tests.
// NOTE(review): the quantized tests quantize the protos with zero point 0.0
// (`protos_quant`), while this config declares a zero point of 128. The protos
// are all zeros in those tests; confirm the mismatch is intentional.
5020    const E2E_SPLIT_CONFIG: &str = "
5021decoder_version: yolo26
5022outputs:
5023 - type: boxes
5024   decoder: ultralytics
5025   quantization: [0.00784313725490196, 0]
5026   shape: [1, 10, 4]
5027   dshape:
5028    - [batch, 1]
5029    - [num_boxes, 10]
5030    - [box_coords, 4]
5031   normalized: true
5032 - type: scores
5033   decoder: ultralytics
5034   quantization: [0.00784313725490196, 0]
5035   shape: [1, 10, 1]
5036   dshape:
5037    - [batch, 1]
5038    - [num_boxes, 10]
5039    - [num_classes, 1]
5040 - type: classes
5041   decoder: ultralytics
5042   quantization: [0.00784313725490196, 0]
5043   shape: [1, 10, 1]
5044   dshape:
5045    - [batch, 1]
5046    - [num_boxes, 10]
5047    - [num_classes, 1]
5048 - type: mask_coefficients
5049   decoder: ultralytics
5050   quantization: [0.00784313725490196, 0]
5051   shape: [1, 10, 32]
5052   dshape:
5053    - [batch, 1]
5054    - [num_boxes, 10]
5055    - [num_protos, 32]
5056 - type: protos
5057   decoder: ultralytics
5058   quantization: [0.0039215686274509803921568627451, 128]
5059   shape: [1, 160, 160, 32]
5060   dshape:
5061    - [batch, 1]
5062    - [height, 160]
5063    - [width, 160]
5064    - [num_protos, 32]
5065";
5066
5067    macro_rules! e2e_tracked_test {
5068        ($name:ident, quantized, $layout:ident, $output:ident) => {
5069            #[test]
5070            fn $name() {
5071                let is_split = matches!(stringify!($layout), "split");
5072                let is_proto = matches!(stringify!($output), "proto");
5073
5074                let score_threshold = 0.45;
5075                let iou_threshold = 0.45;
5076
5077                let mut boxes = Array2::zeros((10, 4));
5078                let mut scores = Array2::zeros((10, 1));
5079                let mut classes = Array2::zeros((10, 1));
5080                let mask = Array2::zeros((10, 32));
5081                let protos = Array3::<f64>::zeros((160, 160, 32));
5082                let protos = protos.insert_axis(Axis(0));
5083                let protos_quant = (1.0 / 255.0, 0.0);
5084                let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
5085
5086                boxes
5087                    .slice_mut(s![0, ..])
5088                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5089                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5090                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5091
5092                let detect_quant = (2.0 / 255.0, 0.0);
5093
5094                let decoder = if is_split {
5095                    DecoderBuilder::default()
5096                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5097                        .with_score_threshold(score_threshold)
5098                        .with_iou_threshold(iou_threshold)
5099                        .build()
5100                        .unwrap()
5101                } else {
5102                    DecoderBuilder::default()
5103                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5104                        .with_score_threshold(score_threshold)
5105                        .with_iou_threshold(iou_threshold)
5106                        .build()
5107                        .unwrap()
5108                };
5109
5110                let expected = e2e_expected_boxes_quant();
5111                let mut tracker = ByteTrackBuilder::new()
5112                    .track_update(0.1)
5113                    .track_high_conf(0.7)
5114                    .build();
5115                let mut output_boxes = Vec::with_capacity(50);
5116                let mut output_tracks = Vec::with_capacity(50);
5117
5118                if is_split {
5119                    let boxes = boxes.insert_axis(Axis(0));
5120                    let scores = scores.insert_axis(Axis(0));
5121                    let classes = classes.insert_axis(Axis(0));
5122                    let mask = mask.insert_axis(Axis(0));
5123
5124                    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5125                    let mut scores: Array3<u8> =
5126                        quantize_ndarray(scores.view(), detect_quant.into());
5127                    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
5128                    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5129
5130                    if is_proto {
5131                        {
5132                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5133                                boxes.view().into(),
5134                                scores.view().into(),
5135                                classes.view().into(),
5136                                mask.view().into(),
5137                                protos.view().into(),
5138                            ];
5139                            decoder
5140                                .decode_tracked_quantized_proto(
5141                                    &mut tracker,
5142                                    0,
5143                                    &inputs,
5144                                    &mut output_boxes,
5145                                    &mut output_tracks,
5146                                )
5147                                .unwrap();
5148                        }
5149                        assert_eq!(output_boxes.len(), 1);
5150                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5151
5152                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5153                            *score = u8::MIN;
5154                        }
5155                        let proto_result = {
5156                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5157                                boxes.view().into(),
5158                                scores.view().into(),
5159                                classes.view().into(),
5160                                mask.view().into(),
5161                                protos.view().into(),
5162                            ];
5163                            decoder
5164                                .decode_tracked_quantized_proto(
5165                                    &mut tracker,
5166                                    100_000_000 / 3,
5167                                    &inputs,
5168                                    &mut output_boxes,
5169                                    &mut output_tracks,
5170                                )
5171                                .unwrap()
5172                        };
5173                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5174                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5175                    } else {
5176                        let mut output_masks = Vec::with_capacity(50);
5177                        {
5178                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5179                                boxes.view().into(),
5180                                scores.view().into(),
5181                                classes.view().into(),
5182                                mask.view().into(),
5183                                protos.view().into(),
5184                            ];
5185                            decoder
5186                                .decode_tracked_quantized(
5187                                    &mut tracker,
5188                                    0,
5189                                    &inputs,
5190                                    &mut output_boxes,
5191                                    &mut output_masks,
5192                                    &mut output_tracks,
5193                                )
5194                                .unwrap();
5195                        }
5196                        assert_eq!(output_boxes.len(), 1);
5197                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5198
5199                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5200                            *score = u8::MIN;
5201                        }
5202                        {
5203                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> = vec![
5204                                boxes.view().into(),
5205                                scores.view().into(),
5206                                classes.view().into(),
5207                                mask.view().into(),
5208                                protos.view().into(),
5209                            ];
5210                            decoder
5211                                .decode_tracked_quantized(
5212                                    &mut tracker,
5213                                    100_000_000 / 3,
5214                                    &inputs,
5215                                    &mut output_boxes,
5216                                    &mut output_masks,
5217                                    &mut output_tracks,
5218                                )
5219                                .unwrap();
5220                        }
5221                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5222                        assert!(output_masks.is_empty());
5223                    }
5224                } else {
5225                    // Combined layout
5226                    let detect = ndarray::concatenate![
5227                        Axis(1),
5228                        boxes.view(),
5229                        scores.view(),
5230                        classes.view(),
5231                        mask.view()
5232                    ];
5233                    let detect = detect.insert_axis(Axis(0));
5234                    assert_eq!(detect.shape(), &[1, 10, 38]);
5235                    let mut detect: Array3<u8> =
5236                        quantize_ndarray(detect.view(), detect_quant.into());
5237
5238                    if is_proto {
5239                        {
5240                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5241                                vec![detect.view().into(), protos.view().into()];
5242                            decoder
5243                                .decode_tracked_quantized_proto(
5244                                    &mut tracker,
5245                                    0,
5246                                    &inputs,
5247                                    &mut output_boxes,
5248                                    &mut output_tracks,
5249                                )
5250                                .unwrap();
5251                        }
5252                        assert_eq!(output_boxes.len(), 1);
5253                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5254
5255                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5256                            *score = u8::MIN;
5257                        }
5258                        let proto_result = {
5259                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5260                                vec![detect.view().into(), protos.view().into()];
5261                            decoder
5262                                .decode_tracked_quantized_proto(
5263                                    &mut tracker,
5264                                    100_000_000 / 3,
5265                                    &inputs,
5266                                    &mut output_boxes,
5267                                    &mut output_tracks,
5268                                )
5269                                .unwrap()
5270                        };
5271                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5272                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5273                    } else {
5274                        let mut output_masks = Vec::with_capacity(50);
5275                        {
5276                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5277                                vec![detect.view().into(), protos.view().into()];
5278                            decoder
5279                                .decode_tracked_quantized(
5280                                    &mut tracker,
5281                                    0,
5282                                    &inputs,
5283                                    &mut output_boxes,
5284                                    &mut output_masks,
5285                                    &mut output_tracks,
5286                                )
5287                                .unwrap();
5288                        }
5289                        assert_eq!(output_boxes.len(), 1);
5290                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5291
5292                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5293                            *score = u8::MIN;
5294                        }
5295                        {
5296                            let inputs: Vec<crate::decoder::ArrayViewDQuantized<'_>> =
5297                                vec![detect.view().into(), protos.view().into()];
5298                            decoder
5299                                .decode_tracked_quantized(
5300                                    &mut tracker,
5301                                    100_000_000 / 3,
5302                                    &inputs,
5303                                    &mut output_boxes,
5304                                    &mut output_masks,
5305                                    &mut output_tracks,
5306                                )
5307                                .unwrap();
5308                        }
5309                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5310                        assert!(output_masks.is_empty());
5311                    }
5312                }
5313            }
5314        };
5315        ($name:ident, float, $layout:ident, $output:ident) => {
5316            #[test]
5317            fn $name() {
5318                let is_split = matches!(stringify!($layout), "split");
5319                let is_proto = matches!(stringify!($output), "proto");
5320
5321                let score_threshold = 0.45;
5322                let iou_threshold = 0.45;
5323
5324                let mut boxes = Array2::zeros((10, 4));
5325                let mut scores = Array2::zeros((10, 1));
5326                let mut classes = Array2::zeros((10, 1));
5327                let mask: Array2<f64> = Array2::zeros((10, 32));
5328                let protos = Array3::<f64>::zeros((160, 160, 32));
5329                let protos = protos.insert_axis(Axis(0));
5330
5331                boxes
5332                    .slice_mut(s![0, ..])
5333                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5334                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5335                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5336
5337                let decoder = if is_split {
5338                    DecoderBuilder::default()
5339                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5340                        .with_score_threshold(score_threshold)
5341                        .with_iou_threshold(iou_threshold)
5342                        .build()
5343                        .unwrap()
5344                } else {
5345                    DecoderBuilder::default()
5346                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5347                        .with_score_threshold(score_threshold)
5348                        .with_iou_threshold(iou_threshold)
5349                        .build()
5350                        .unwrap()
5351                };
5352
5353                let expected = e2e_expected_boxes_float();
5354                let mut tracker = ByteTrackBuilder::new()
5355                    .track_update(0.1)
5356                    .track_high_conf(0.7)
5357                    .build();
5358                let mut output_boxes = Vec::with_capacity(50);
5359                let mut output_tracks = Vec::with_capacity(50);
5360
5361                if is_split {
5362                    let boxes = boxes.insert_axis(Axis(0));
5363                    let mut scores = scores.insert_axis(Axis(0));
5364                    let classes = classes.insert_axis(Axis(0));
5365                    let mask = mask.insert_axis(Axis(0));
5366
5367                    if is_proto {
5368                        {
5369                            let inputs = vec![
5370                                boxes.view().into_dyn(),
5371                                scores.view().into_dyn(),
5372                                classes.view().into_dyn(),
5373                                mask.view().into_dyn(),
5374                                protos.view().into_dyn(),
5375                            ];
5376                            decoder
5377                                .decode_tracked_float_proto(
5378                                    &mut tracker,
5379                                    0,
5380                                    &inputs,
5381                                    &mut output_boxes,
5382                                    &mut output_tracks,
5383                                )
5384                                .unwrap();
5385                        }
5386                        assert_eq!(output_boxes.len(), 1);
5387                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5388
5389                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5390                            *score = 0.0;
5391                        }
5392                        let proto_result = {
5393                            let inputs = vec![
5394                                boxes.view().into_dyn(),
5395                                scores.view().into_dyn(),
5396                                classes.view().into_dyn(),
5397                                mask.view().into_dyn(),
5398                                protos.view().into_dyn(),
5399                            ];
5400                            decoder
5401                                .decode_tracked_float_proto(
5402                                    &mut tracker,
5403                                    100_000_000 / 3,
5404                                    &inputs,
5405                                    &mut output_boxes,
5406                                    &mut output_tracks,
5407                                )
5408                                .unwrap()
5409                        };
5410                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5411                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5412                    } else {
5413                        let mut output_masks = Vec::with_capacity(50);
5414                        {
5415                            let inputs = vec![
5416                                boxes.view().into_dyn(),
5417                                scores.view().into_dyn(),
5418                                classes.view().into_dyn(),
5419                                mask.view().into_dyn(),
5420                                protos.view().into_dyn(),
5421                            ];
5422                            decoder
5423                                .decode_tracked_float(
5424                                    &mut tracker,
5425                                    0,
5426                                    &inputs,
5427                                    &mut output_boxes,
5428                                    &mut output_masks,
5429                                    &mut output_tracks,
5430                                )
5431                                .unwrap();
5432                        }
5433                        assert_eq!(output_boxes.len(), 1);
5434                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5435
5436                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5437                            *score = 0.0;
5438                        }
5439                        {
5440                            let inputs = vec![
5441                                boxes.view().into_dyn(),
5442                                scores.view().into_dyn(),
5443                                classes.view().into_dyn(),
5444                                mask.view().into_dyn(),
5445                                protos.view().into_dyn(),
5446                            ];
5447                            decoder
5448                                .decode_tracked_float(
5449                                    &mut tracker,
5450                                    100_000_000 / 3,
5451                                    &inputs,
5452                                    &mut output_boxes,
5453                                    &mut output_masks,
5454                                    &mut output_tracks,
5455                                )
5456                                .unwrap();
5457                        }
5458                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5459                        assert!(output_masks.is_empty());
5460                    }
5461                } else {
5462                    // Combined layout
5463                    let detect = ndarray::concatenate![
5464                        Axis(1),
5465                        boxes.view(),
5466                        scores.view(),
5467                        classes.view(),
5468                        mask.view()
5469                    ];
5470                    let mut detect = detect.insert_axis(Axis(0));
5471                    assert_eq!(detect.shape(), &[1, 10, 38]);
5472
5473                    if is_proto {
5474                        {
5475                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5476                            decoder
5477                                .decode_tracked_float_proto(
5478                                    &mut tracker,
5479                                    0,
5480                                    &inputs,
5481                                    &mut output_boxes,
5482                                    &mut output_tracks,
5483                                )
5484                                .unwrap();
5485                        }
5486                        assert_eq!(output_boxes.len(), 1);
5487                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5488
5489                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5490                            *score = 0.0;
5491                        }
5492                        let proto_result = {
5493                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5494                            decoder
5495                                .decode_tracked_float_proto(
5496                                    &mut tracker,
5497                                    100_000_000 / 3,
5498                                    &inputs,
5499                                    &mut output_boxes,
5500                                    &mut output_tracks,
5501                                )
5502                                .unwrap()
5503                        };
5504                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5505                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5506                    } else {
5507                        let mut output_masks = Vec::with_capacity(50);
5508                        {
5509                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5510                            decoder
5511                                .decode_tracked_float(
5512                                    &mut tracker,
5513                                    0,
5514                                    &inputs,
5515                                    &mut output_boxes,
5516                                    &mut output_masks,
5517                                    &mut output_tracks,
5518                                )
5519                                .unwrap();
5520                        }
5521                        assert_eq!(output_boxes.len(), 1);
5522                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5523
5524                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5525                            *score = 0.0;
5526                        }
5527                        {
5528                            let inputs = vec![detect.view().into_dyn(), protos.view().into_dyn()];
5529                            decoder
5530                                .decode_tracked_float(
5531                                    &mut tracker,
5532                                    100_000_000 / 3,
5533                                    &inputs,
5534                                    &mut output_boxes,
5535                                    &mut output_masks,
5536                                    &mut output_tracks,
5537                                )
5538                                .unwrap();
5539                        }
5540                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5541                        assert!(output_masks.is_empty());
5542                    }
5543                }
5544            }
5545        };
5546    }
5547
5548    e2e_tracked_test!(
5549        test_decoder_tracked_end_to_end_segdet,
5550        quantized,
5551        combined,
5552        masks
5553    );
5554    e2e_tracked_test!(
5555        test_decoder_tracked_end_to_end_segdet_float,
5556        float,
5557        combined,
5558        masks
5559    );
5560    e2e_tracked_test!(
5561        test_decoder_tracked_end_to_end_segdet_proto,
5562        quantized,
5563        combined,
5564        proto
5565    );
5566    e2e_tracked_test!(
5567        test_decoder_tracked_end_to_end_segdet_proto_float,
5568        float,
5569        combined,
5570        proto
5571    );
5572    e2e_tracked_test!(
5573        test_decoder_tracked_end_to_end_segdet_split,
5574        quantized,
5575        split,
5576        masks
5577    );
5578    e2e_tracked_test!(
5579        test_decoder_tracked_end_to_end_segdet_split_float,
5580        float,
5581        split,
5582        masks
5583    );
5584    e2e_tracked_test!(
5585        test_decoder_tracked_end_to_end_segdet_split_proto,
5586        quantized,
5587        split,
5588        proto
5589    );
5590    e2e_tracked_test!(
5591        test_decoder_tracked_end_to_end_segdet_split_proto_float,
5592        float,
5593        split,
5594        proto
5595    );
5596
5597    // ─── End-to-end tracked TensorDyn test macro ────────────────────
5598    //
5599    // Same as e2e_tracked_test but wraps data in TensorDyn and exercises
5600    // the public decode_tracked / decode_proto_tracked API.
5601
5602    macro_rules! e2e_tracked_tensor_test {
5603        ($name:ident, quantized, $layout:ident, $output:ident) => {
5604            #[test]
5605            fn $name() {
5606                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5607
5608                let is_split = matches!(stringify!($layout), "split");
5609                let is_proto = matches!(stringify!($output), "proto");
5610
5611                let score_threshold = 0.45;
5612                let iou_threshold = 0.45;
5613
5614                let mut boxes = Array2::zeros((10, 4));
5615                let mut scores = Array2::zeros((10, 1));
5616                let mut classes = Array2::zeros((10, 1));
5617                let mask = Array2::zeros((10, 32));
5618                let protos_f64 = Array3::<f64>::zeros((160, 160, 32));
5619                let protos_f64 = protos_f64.insert_axis(Axis(0));
5620                let protos_quant = (1.0 / 255.0, 0.0);
5621                let protos_u8: Array4<u8> =
5622                    quantize_ndarray(protos_f64.view(), protos_quant.into());
5623
5624                boxes
5625                    .slice_mut(s![0, ..])
5626                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5627                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5628                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5629
5630                let detect_quant = (2.0 / 255.0, 0.0);
5631
5632                let decoder = if is_split {
5633                    DecoderBuilder::default()
5634                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5635                        .with_score_threshold(score_threshold)
5636                        .with_iou_threshold(iou_threshold)
5637                        .build()
5638                        .unwrap()
5639                } else {
5640                    DecoderBuilder::default()
5641                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5642                        .with_score_threshold(score_threshold)
5643                        .with_iou_threshold(iou_threshold)
5644                        .build()
5645                        .unwrap()
5646                };
5647
5648                // Helper to wrap a u8 slice into a TensorDyn
5649                let make_u8_tensor =
5650                    |shape: &[usize], data: &[u8]| -> edgefirst_tensor::TensorDyn {
5651                        let t = Tensor::<u8>::new(shape, None, None).unwrap();
5652                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5653                        t.into()
5654                    };
5655
5656                let expected = e2e_expected_boxes_quant();
5657                let mut tracker = ByteTrackBuilder::new()
5658                    .track_update(0.1)
5659                    .track_high_conf(0.7)
5660                    .build();
5661                let mut output_boxes = Vec::with_capacity(50);
5662                let mut output_tracks = Vec::with_capacity(50);
5663
5664                let protos_td = make_u8_tensor(protos_u8.shape(), protos_u8.as_slice().unwrap());
5665
5666                if is_split {
5667                    let boxes = boxes.insert_axis(Axis(0));
5668                    let scores = scores.insert_axis(Axis(0));
5669                    let classes = classes.insert_axis(Axis(0));
5670                    let mask = mask.insert_axis(Axis(0));
5671
5672                    let boxes_q: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
5673                    let mut scores_q: Array3<u8> =
5674                        quantize_ndarray(scores.view(), detect_quant.into());
5675                    let classes_q: Array3<u8> =
5676                        quantize_ndarray(classes.view(), detect_quant.into());
5677                    let mask_q: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
5678
5679                    let boxes_td = make_u8_tensor(boxes_q.shape(), boxes_q.as_slice().unwrap());
5680                    let classes_td =
5681                        make_u8_tensor(classes_q.shape(), classes_q.as_slice().unwrap());
5682                    let mask_td = make_u8_tensor(mask_q.shape(), mask_q.as_slice().unwrap());
5683
5684                    if is_proto {
5685                        let scores_td =
5686                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5687                        decoder
5688                            .decode_proto_tracked(
5689                                &mut tracker,
5690                                0,
5691                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5692                                &mut output_boxes,
5693                                &mut output_tracks,
5694                            )
5695                            .unwrap();
5696
5697                        assert_eq!(output_boxes.len(), 1);
5698                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5699
5700                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5701                            *score = u8::MIN;
5702                        }
5703                        let scores_td =
5704                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5705                        let proto_result = decoder
5706                            .decode_proto_tracked(
5707                                &mut tracker,
5708                                100_000_000 / 3,
5709                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5710                                &mut output_boxes,
5711                                &mut output_tracks,
5712                            )
5713                            .unwrap();
5714                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5715                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5716                    } else {
5717                        let scores_td =
5718                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5719                        let mut output_masks = Vec::with_capacity(50);
5720                        decoder
5721                            .decode_tracked(
5722                                &mut tracker,
5723                                0,
5724                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5725                                &mut output_boxes,
5726                                &mut output_masks,
5727                                &mut output_tracks,
5728                            )
5729                            .unwrap();
5730
5731                        assert_eq!(output_boxes.len(), 1);
5732                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5733
5734                        for score in scores_q.slice_mut(s![.., .., ..]).iter_mut() {
5735                            *score = u8::MIN;
5736                        }
5737                        let scores_td =
5738                            make_u8_tensor(scores_q.shape(), scores_q.as_slice().unwrap());
5739                        decoder
5740                            .decode_tracked(
5741                                &mut tracker,
5742                                100_000_000 / 3,
5743                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5744                                &mut output_boxes,
5745                                &mut output_masks,
5746                                &mut output_tracks,
5747                            )
5748                            .unwrap();
5749                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5750                        assert!(output_masks.is_empty());
5751                    }
5752                } else {
5753                    // Combined layout
5754                    let detect = ndarray::concatenate![
5755                        Axis(1),
5756                        boxes.view(),
5757                        scores.view(),
5758                        classes.view(),
5759                        mask.view()
5760                    ];
5761                    let detect = detect.insert_axis(Axis(0));
5762                    assert_eq!(detect.shape(), &[1, 10, 38]);
5763                    // Ensure contiguous layout after concatenation for as_slice()
5764                    let detect =
5765                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5766                            .unwrap();
5767                    let mut detect_q: Array3<u8> =
5768                        quantize_ndarray(detect.view(), detect_quant.into());
5769
5770                    if is_proto {
5771                        let detect_td =
5772                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5773                        decoder
5774                            .decode_proto_tracked(
5775                                &mut tracker,
5776                                0,
5777                                &[&detect_td, &protos_td],
5778                                &mut output_boxes,
5779                                &mut output_tracks,
5780                            )
5781                            .unwrap();
5782
5783                        assert_eq!(output_boxes.len(), 1);
5784                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5785
5786                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5787                            *score = u8::MIN;
5788                        }
5789                        let detect_td =
5790                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5791                        let proto_result = decoder
5792                            .decode_proto_tracked(
5793                                &mut tracker,
5794                                100_000_000 / 3,
5795                                &[&detect_td, &protos_td],
5796                                &mut output_boxes,
5797                                &mut output_tracks,
5798                            )
5799                            .unwrap();
5800                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5801                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5802                    } else {
5803                        let detect_td =
5804                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5805                        let mut output_masks = Vec::with_capacity(50);
5806                        decoder
5807                            .decode_tracked(
5808                                &mut tracker,
5809                                0,
5810                                &[&detect_td, &protos_td],
5811                                &mut output_boxes,
5812                                &mut output_masks,
5813                                &mut output_tracks,
5814                            )
5815                            .unwrap();
5816
5817                        assert_eq!(output_boxes.len(), 1);
5818                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5819
5820                        for score in detect_q.slice_mut(s![.., .., 4]).iter_mut() {
5821                            *score = u8::MIN;
5822                        }
5823                        let detect_td =
5824                            make_u8_tensor(detect_q.shape(), detect_q.as_slice().unwrap());
5825                        decoder
5826                            .decode_tracked(
5827                                &mut tracker,
5828                                100_000_000 / 3,
5829                                &[&detect_td, &protos_td],
5830                                &mut output_boxes,
5831                                &mut output_masks,
5832                                &mut output_tracks,
5833                            )
5834                            .unwrap();
5835                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5836                        assert!(output_masks.is_empty());
5837                    }
5838                }
5839            }
5840        };
5841        ($name:ident, float, $layout:ident, $output:ident) => {
5842            #[test]
5843            fn $name() {
5844                use edgefirst_tensor::{Tensor, TensorMapTrait, TensorTrait};
5845
5846                let is_split = matches!(stringify!($layout), "split");
5847                let is_proto = matches!(stringify!($output), "proto");
5848
5849                let score_threshold = 0.45;
5850                let iou_threshold = 0.45;
5851
5852                let mut boxes = Array2::zeros((10, 4));
5853                let mut scores = Array2::zeros((10, 1));
5854                let mut classes = Array2::zeros((10, 1));
5855                let mask: Array2<f64> = Array2::zeros((10, 32));
5856                let protos = Array3::<f64>::zeros((160, 160, 32));
5857                let protos = protos.insert_axis(Axis(0));
5858
5859                boxes
5860                    .slice_mut(s![0, ..])
5861                    .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5862                scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5863                classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5864
5865                let decoder = if is_split {
5866                    DecoderBuilder::default()
5867                        .with_config_yaml_str(E2E_SPLIT_CONFIG.to_string())
5868                        .with_score_threshold(score_threshold)
5869                        .with_iou_threshold(iou_threshold)
5870                        .build()
5871                        .unwrap()
5872                } else {
5873                    DecoderBuilder::default()
5874                        .with_config_yaml_str(E2E_COMBINED_CONFIG.to_string())
5875                        .with_score_threshold(score_threshold)
5876                        .with_iou_threshold(iou_threshold)
5877                        .build()
5878                        .unwrap()
5879                };
5880
5881                // Helper to wrap an f64 slice into a TensorDyn
5882                let make_f64_tensor =
5883                    |shape: &[usize], data: &[f64]| -> edgefirst_tensor::TensorDyn {
5884                        let t = Tensor::<f64>::new(shape, None, None).unwrap();
5885                        t.map().unwrap().as_mut_slice()[..data.len()].copy_from_slice(data);
5886                        t.into()
5887                    };
5888
5889                let expected = e2e_expected_boxes_float();
5890                let mut tracker = ByteTrackBuilder::new()
5891                    .track_update(0.1)
5892                    .track_high_conf(0.7)
5893                    .build();
5894                let mut output_boxes = Vec::with_capacity(50);
5895                let mut output_tracks = Vec::with_capacity(50);
5896
5897                let protos_td = make_f64_tensor(protos.shape(), protos.as_slice().unwrap());
5898
5899                if is_split {
5900                    let boxes = boxes.insert_axis(Axis(0));
5901                    let mut scores = scores.insert_axis(Axis(0));
5902                    let classes = classes.insert_axis(Axis(0));
5903                    let mask = mask.insert_axis(Axis(0));
5904
5905                    let boxes_td = make_f64_tensor(boxes.shape(), boxes.as_slice().unwrap());
5906                    let classes_td = make_f64_tensor(classes.shape(), classes.as_slice().unwrap());
5907                    let mask_td = make_f64_tensor(mask.shape(), mask.as_slice().unwrap());
5908
5909                    if is_proto {
5910                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5911                        decoder
5912                            .decode_proto_tracked(
5913                                &mut tracker,
5914                                0,
5915                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5916                                &mut output_boxes,
5917                                &mut output_tracks,
5918                            )
5919                            .unwrap();
5920
5921                        assert_eq!(output_boxes.len(), 1);
5922                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5923
5924                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5925                            *score = 0.0;
5926                        }
5927                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5928                        let proto_result = decoder
5929                            .decode_proto_tracked(
5930                                &mut tracker,
5931                                100_000_000 / 3,
5932                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5933                                &mut output_boxes,
5934                                &mut output_tracks,
5935                            )
5936                            .unwrap();
5937                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5938                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
5939                    } else {
5940                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5941                        let mut output_masks = Vec::with_capacity(50);
5942                        decoder
5943                            .decode_tracked(
5944                                &mut tracker,
5945                                0,
5946                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5947                                &mut output_boxes,
5948                                &mut output_masks,
5949                                &mut output_tracks,
5950                            )
5951                            .unwrap();
5952
5953                        assert_eq!(output_boxes.len(), 1);
5954                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
5955
5956                        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5957                            *score = 0.0;
5958                        }
5959                        let scores_td = make_f64_tensor(scores.shape(), scores.as_slice().unwrap());
5960                        decoder
5961                            .decode_tracked(
5962                                &mut tracker,
5963                                100_000_000 / 3,
5964                                &[&boxes_td, &scores_td, &classes_td, &mask_td, &protos_td],
5965                                &mut output_boxes,
5966                                &mut output_masks,
5967                                &mut output_tracks,
5968                            )
5969                            .unwrap();
5970                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
5971                        assert!(output_masks.is_empty());
5972                    }
5973                } else {
5974                    // Combined layout
5975                    let detect = ndarray::concatenate![
5976                        Axis(1),
5977                        boxes.view(),
5978                        scores.view(),
5979                        classes.view(),
5980                        mask.view()
5981                    ];
5982                    let detect = detect.insert_axis(Axis(0));
5983                    assert_eq!(detect.shape(), &[1, 10, 38]);
5984                    // Ensure contiguous layout after concatenation for as_slice()
5985                    let mut detect =
5986                        Array3::from_shape_vec(detect.raw_dim(), detect.iter().copied().collect())
5987                            .unwrap();
5988
5989                    if is_proto {
5990                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
5991                        decoder
5992                            .decode_proto_tracked(
5993                                &mut tracker,
5994                                0,
5995                                &[&detect_td, &protos_td],
5996                                &mut output_boxes,
5997                                &mut output_tracks,
5998                            )
5999                            .unwrap();
6000
6001                        assert_eq!(output_boxes.len(), 1);
6002                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6003
6004                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6005                            *score = 0.0;
6006                        }
6007                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6008                        let proto_result = decoder
6009                            .decode_proto_tracked(
6010                                &mut tracker,
6011                                100_000_000 / 3,
6012                                &[&detect_td, &protos_td],
6013                                &mut output_boxes,
6014                                &mut output_tracks,
6015                            )
6016                            .unwrap();
6017                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6018                        assert!(proto_result.is_some_and(|x| x.mask_coefficients.shape()[0] == 0));
6019                    } else {
6020                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6021                        let mut output_masks = Vec::with_capacity(50);
6022                        decoder
6023                            .decode_tracked(
6024                                &mut tracker,
6025                                0,
6026                                &[&detect_td, &protos_td],
6027                                &mut output_boxes,
6028                                &mut output_masks,
6029                                &mut output_tracks,
6030                            )
6031                            .unwrap();
6032
6033                        assert_eq!(output_boxes.len(), 1);
6034                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1.0 / 160.0));
6035
6036                        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6037                            *score = 0.0;
6038                        }
6039                        let detect_td = make_f64_tensor(detect.shape(), detect.as_slice().unwrap());
6040                        decoder
6041                            .decode_tracked(
6042                                &mut tracker,
6043                                100_000_000 / 3,
6044                                &[&detect_td, &protos_td],
6045                                &mut output_boxes,
6046                                &mut output_masks,
6047                                &mut output_tracks,
6048                            )
6049                            .unwrap();
6050                        assert!(output_boxes[0].equal_within_delta(&expected[0], 1e-6));
6051                        assert!(output_masks.is_empty());
6052                    }
6053                }
6054            }
6055        };
6056    }
6057
6058    e2e_tracked_tensor_test!(
6059        test_decoder_tracked_tensor_end_to_end_segdet,
6060        quantized,
6061        combined,
6062        masks
6063    );
6064    e2e_tracked_tensor_test!(
6065        test_decoder_tracked_tensor_end_to_end_segdet_float,
6066        float,
6067        combined,
6068        masks
6069    );
6070    e2e_tracked_tensor_test!(
6071        test_decoder_tracked_tensor_end_to_end_segdet_proto,
6072        quantized,
6073        combined,
6074        proto
6075    );
6076    e2e_tracked_tensor_test!(
6077        test_decoder_tracked_tensor_end_to_end_segdet_proto_float,
6078        float,
6079        combined,
6080        proto
6081    );
6082    e2e_tracked_tensor_test!(
6083        test_decoder_tracked_tensor_end_to_end_segdet_split,
6084        quantized,
6085        split,
6086        masks
6087    );
6088    e2e_tracked_tensor_test!(
6089        test_decoder_tracked_tensor_end_to_end_segdet_split_float,
6090        float,
6091        split,
6092        masks
6093    );
6094    e2e_tracked_tensor_test!(
6095        test_decoder_tracked_tensor_end_to_end_segdet_split_proto,
6096        quantized,
6097        split,
6098        proto
6099    );
6100    e2e_tracked_tensor_test!(
6101        test_decoder_tracked_tensor_end_to_end_segdet_split_proto_float,
6102        float,
6103        split,
6104        proto
6105    );
6106
6107    #[test]
6108    fn test_decoder_tracked_linear_motion() {
6109        use crate::configs::{DecoderType, Nms};
6110        use crate::DecoderBuilder;
6111
6112        let score_threshold = 0.25;
6113        let iou_threshold = 0.1;
6114        let out = edgefirst_bench::testdata::read("yolov8s_80_classes.bin");
6115        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
6116        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
6117        let quant = (0.0040811873, -123).into();
6118
6119        let decoder = DecoderBuilder::default()
6120            .with_config_yolo_det(
6121                crate::configs::Detection {
6122                    decoder: DecoderType::Ultralytics,
6123                    shape: vec![1, 84, 8400],
6124                    anchors: None,
6125                    quantization: Some(quant),
6126                    dshape: vec![
6127                        (crate::configs::DimName::Batch, 1),
6128                        (crate::configs::DimName::NumFeatures, 84),
6129                        (crate::configs::DimName::NumBoxes, 8400),
6130                    ],
6131                    normalized: Some(true),
6132                },
6133                None,
6134            )
6135            .with_score_threshold(score_threshold)
6136            .with_iou_threshold(iou_threshold)
6137            .with_nms(Some(Nms::ClassAgnostic))
6138            .build()
6139            .unwrap();
6140
6141        let mut expected_boxes = [
6142            DetectBox {
6143                bbox: BoundingBox {
6144                    xmin: 0.5285137,
6145                    ymin: 0.05305544,
6146                    xmax: 0.87541467,
6147                    ymax: 0.9998909,
6148                },
6149                score: 0.5591227,
6150                label: 0,
6151            },
6152            DetectBox {
6153                bbox: BoundingBox {
6154                    xmin: 0.130598,
6155                    ymin: 0.43260583,
6156                    xmax: 0.35098213,
6157                    ymax: 0.9958097,
6158                },
6159                score: 0.33057618,
6160                label: 75,
6161            },
6162        ];
6163
6164        let mut tracker = ByteTrackBuilder::new()
6165            .track_update(0.1)
6166            .track_high_conf(0.3)
6167            .build();
6168
6169        let mut output_boxes = Vec::with_capacity(50);
6170        let mut output_masks = Vec::with_capacity(50);
6171        let mut output_tracks = Vec::with_capacity(50);
6172
6173        decoder
6174            .decode_tracked_quantized(
6175                &mut tracker,
6176                0,
6177                &[out.view().into()],
6178                &mut output_boxes,
6179                &mut output_masks,
6180                &mut output_tracks,
6181            )
6182            .unwrap();
6183
6184        assert_eq!(output_boxes.len(), 2);
6185        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6186        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
6187
6188        for i in 1..=100 {
6189            let mut out = out.clone();
6190            // introduce linear movement into the XY coordinates
6191            let mut x_values = out.slice_mut(s![0, 0, ..]);
6192            for x in x_values.iter_mut() {
6193                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
6194            }
6195
6196            decoder
6197                .decode_tracked_quantized(
6198                    &mut tracker,
6199                    100_000_000 * i / 3, // simulate 33.333ms between frames
6200                    &[out.view().into()],
6201                    &mut output_boxes,
6202                    &mut output_masks,
6203                    &mut output_tracks,
6204                )
6205                .unwrap();
6206
6207            assert_eq!(output_boxes.len(), 2);
6208        }
6209        let tracks = tracker.get_active_tracks();
6210        let predicted_boxes: Vec<_> = tracks
6211            .iter()
6212            .map(|track| {
6213                let mut l = track.last_box;
6214                l.bbox = track.info.tracked_location.into();
6215                l
6216            })
6217            .collect();
6218        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
6219        expected_boxes[0].bbox.xmax += 0.1;
6220        expected_boxes[1].bbox.xmin += 0.1;
6221        expected_boxes[1].bbox.xmax += 0.1;
6222
6223        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6224        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6225
6226        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6227        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
6228        for score in scores_values.iter_mut() {
6229            *score = i8::MIN; // set all scores to minimum to simulate no detections
6230        }
6231        decoder
6232            .decode_tracked_quantized(
6233                &mut tracker,
6234                100_000_000 * 101 / 3,
6235                &[out.view().into()],
6236                &mut output_boxes,
6237                &mut output_masks,
6238                &mut output_tracks,
6239            )
6240            .unwrap();
6241        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
6242        expected_boxes[0].bbox.xmax += 0.001;
6243        expected_boxes[1].bbox.xmin += 0.001;
6244        expected_boxes[1].bbox.xmax += 0.001;
6245
6246        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
6247        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
6248    }
6249
6250    #[test]
6251    fn test_decoder_tracked_end_to_end_float() {
6252        let score_threshold = 0.45;
6253        let iou_threshold = 0.45;
6254
6255        let mut boxes = Array2::zeros((10, 4));
6256        let mut scores = Array2::zeros((10, 1));
6257        let mut classes = Array2::zeros((10, 1));
6258
6259        boxes
6260            .slice_mut(s![0, ..,])
6261            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
6262        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
6263        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
6264
6265        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
6266        let mut detect = detect.insert_axis(Axis(0));
6267        assert_eq!(detect.shape(), &[1, 10, 6]);
6268        let config = "
6269decoder_version: yolo26
6270outputs:
6271 - type: detection
6272   decoder: ultralytics
6273   quantization: [0.00784313725490196, 0]
6274   shape: [1, 10, 6]
6275   dshape:
6276    - [batch, 1]
6277    - [num_boxes, 10]
6278    - [num_features, 6]
6279   normalized: true
6280";
6281
6282        let decoder = DecoderBuilder::default()
6283            .with_config_yaml_str(config.to_string())
6284            .with_score_threshold(score_threshold)
6285            .with_iou_threshold(iou_threshold)
6286            .build()
6287            .unwrap();
6288
6289        let expected_boxes = [DetectBox {
6290            bbox: BoundingBox {
6291                xmin: 0.1234,
6292                ymin: 0.1234,
6293                xmax: 0.2345,
6294                ymax: 0.2345,
6295            },
6296            score: 0.9876,
6297            label: 2,
6298        }];
6299
6300        let mut tracker = ByteTrackBuilder::new()
6301            .track_update(0.1)
6302            .track_high_conf(0.7)
6303            .build();
6304
6305        let mut output_boxes = Vec::with_capacity(50);
6306        let mut output_masks = Vec::with_capacity(50);
6307        let mut output_tracks = Vec::with_capacity(50);
6308
6309        decoder
6310            .decode_tracked_float(
6311                &mut tracker,
6312                0,
6313                &[detect.view().into_dyn()],
6314                &mut output_boxes,
6315                &mut output_masks,
6316                &mut output_tracks,
6317            )
6318            .unwrap();
6319
6320        assert_eq!(output_boxes.len(), 1);
6321        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6322
6323        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
6324
6325        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
6326            *score = 0.0; // set all scores to minimum to simulate no detections
6327        }
6328
6329        decoder
6330            .decode_tracked_float(
6331                &mut tracker,
6332                100_000_000 / 3,
6333                &[detect.view().into_dyn()],
6334                &mut output_boxes,
6335                &mut output_masks,
6336                &mut output_tracks,
6337            )
6338            .unwrap();
6339        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
6340    }
6341}