Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and for ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
- `Decoder`/`DecoderBuilder` structs: provide high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the `dequantize_*` functions provided in this crate (`dequantize_ndarray`, `dequantize_cpu`, `dequantize_cpu_chunked`), and then use the
floating-point decoding functions. However, it is recommended not to dequantize the model outputs manually before passing them to the decoder,
as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86    yolo::yolo_segmentation_to_mask,
87};
88
89/// Trait to convert bounding box formats to XYXY float format
90pub trait BBoxTypeTrait {
91    /// Converts the bbox into XYXY float format.
92    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96        input: &[B; 4],
97        quant: Quantization,
98    ) -> [A; 4]
99    where
100        f32: AsPrimitive<A>,
101        i32: AsPrimitive<A>;
102
103    /// Converts the bbox into XYXY float format.
104    ///
105    /// # Examples
106    /// ```rust
107    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
108    /// # use ndarray::array;
109    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
110    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
111    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
112    /// ```
113    #[inline(always)]
114    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115        input: ArrayView1<B>,
116    ) -> [A; 4] {
117        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118    }
119
120    #[inline(always)]
121    /// Converts the bbox into XYXY float format.
122    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123        input: ArrayView1<B>,
124        quant: Quantization,
125    ) -> [A; 4]
126    where
127        f32: AsPrimitive<A>,
128        i32: AsPrimitive<A>,
129    {
130        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131    }
132}
133
134/// Converts XYXY bounding boxes to XYXY
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140        input.map(|b| b.as_())
141    }
142
143    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144        input: &[B; 4],
145        quant: Quantization,
146    ) -> [A; 4]
147    where
148        f32: AsPrimitive<A>,
149        i32: AsPrimitive<A>,
150    {
151        let scale = quant.scale.as_();
152        let zp = quant.zero_point.as_();
153        input.map(|b| (b.as_() - zp) * scale)
154    }
155
156    #[inline(always)]
157    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158        input: ArrayView1<B>,
159    ) -> [A; 4] {
160        [
161            input[0].as_(),
162            input[1].as_(),
163            input[2].as_(),
164            input[3].as_(),
165        ]
166    }
167}
168
169/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
170/// box
171#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175    #[inline(always)]
176    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177        let half = A::one() / (A::one() + A::one());
178        [
179            (input[0].as_()) - (input[2].as_() * half),
180            (input[1].as_()) - (input[3].as_() * half),
181            (input[0].as_()) + (input[2].as_() * half),
182            (input[1].as_()) + (input[3].as_() * half),
183        ]
184    }
185
186    #[inline(always)]
187    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188        input: &[B; 4],
189        quant: Quantization,
190    ) -> [A; 4]
191    where
192        f32: AsPrimitive<A>,
193        i32: AsPrimitive<A>,
194    {
195        let scale = quant.scale.as_();
196        let half_scale = (quant.scale * 0.5).as_();
197        let zp = quant.zero_point.as_();
198        let [x, y, w, h] = [
199            (input[0].as_() - zp) * scale,
200            (input[1].as_() - zp) * scale,
201            (input[2].as_() - zp) * half_scale,
202            (input[3].as_() - zp) * half_scale,
203        ];
204
205        [x - w, y - h, x + w, y + h]
206    }
207
208    #[inline(always)]
209    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210        input: ArrayView1<B>,
211    ) -> [A; 4] {
212        let half = A::one() / (A::one() + A::one());
213        [
214            (input[0].as_()) - (input[2].as_() * half),
215            (input[1].as_()) - (input[3].as_() * half),
216            (input[0].as_()) + (input[2].as_() * half),
217            (input[1].as_()) + (input[3].as_() * half),
218        ]
219    }
220}
221
/// Affine quantization parameters for a tensor.
///
/// A real value is recovered from a quantized value `q` as
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied after the zero point has been removed.
    pub scale: f32,
    /// Quantized value that corresponds to a real `0.0`.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds quantization parameters from an explicit `scale` and
    /// `zero_point`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
242
243impl From<QuantTuple> for Quantization {
244    /// Creates a new Quantization struct from a QuantTuple
245    /// # Examples
246    /// ```
247    /// # use edgefirst_decoder::Quantization;
248    /// # use edgefirst_decoder::configs::QuantTuple;
249    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
250    /// let quant = Quantization::from(quant_tuple);
251    /// assert_eq!(quant.scale, 0.1);
252    /// assert_eq!(quant.zero_point, -128);
253    /// ```
254    fn from(quant_tuple: QuantTuple) -> Quantization {
255        Quantization {
256            scale: quant_tuple.0,
257            zero_point: quant_tuple.1,
258        }
259    }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264    S: AsPrimitive<f32>,
265    Z: AsPrimitive<i32>,
266{
267    /// Creates a new Quantization struct from a tuple
268    /// # Examples
269    /// ```
270    /// # use edgefirst_decoder::Quantization;
271    /// let quant = Quantization::from((0.1_f64, -128_i64));
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from((scale, zp): (S, Z)) -> Quantization {
276        Self {
277            scale: scale.as_(),
278            zero_point: zp.as_(),
279        }
280    }
281}
282
283impl Default for Quantization {
284    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
285    /// # Examples
286    /// ```rust
287    /// # use edgefirst_decoder::Quantization;
288    /// let quant = Quantization::default();
289    /// assert_eq!(quant.scale, 1.0);
290    /// assert_eq!(quant.zero_point, 0);
291    /// ```
292    fn default() -> Self {
293        Self {
294            scale: 1.0,
295            zero_point: 0,
296        }
297    }
298}
299
300/// A detection box with f32 bbox and score
301#[derive(Debug, Clone, Copy, PartialEq, Default)]
302pub struct DetectBox {
303    pub bbox: BoundingBox,
304    /// model-specific score for this detection, higher implies more confidence
305    pub score: f32,
306    /// label index for this detection
307    pub label: usize,
308}
309
/// An axis-aligned bounding box with f32 coordinates in XYXY order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Normalized left edge of the box.
    pub xmin: f32,
    /// Normalized top edge of the box.
    pub ymin: f32,
    /// Normalized right edge of the box.
    pub xmax: f32,
    /// Normalized bottom edge of the box.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its four edge coordinates, stored exactly as given
    /// (no reordering or validation).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy with each axis ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let ordered = |a: f32, b: f32| (a.min(b), a.max(b));
        let (xmin, xmax) = ordered(self.xmin, self.xmax);
        let (ymin, ymax) = ordered(self.ymin, self.ymax);
        Self::new(xmin, ymin, xmax, ymax)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens a [`BoundingBox`] into `[xmin, ymin, xmax, ymax]`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a [`BoundingBox`] from `[xmin, ymin, xmax, ymax]`.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
388
389impl DetectBox {
390    /// Returns true if one detect box is equal to another detect box, within
391    /// the given `eps`
392    ///
393    /// # Examples
394    /// ```
395    /// # use edgefirst_decoder::DetectBox;
396    /// let box1 = DetectBox {
397    ///     bbox: edgefirst_decoder::BoundingBox {
398    ///         xmin: 0.1,
399    ///         ymin: 0.2,
400    ///         xmax: 0.3,
401    ///         ymax: 0.4,
402    ///     },
403    ///     score: 0.5,
404    ///     label: 1,
405    /// };
406    /// let box2 = DetectBox {
407    ///     bbox: edgefirst_decoder::BoundingBox {
408    ///         xmin: 0.101,
409    ///         ymin: 0.199,
410    ///         xmax: 0.301,
411    ///         ymax: 0.399,
412    ///     },
413    ///     score: 0.510,
414    ///     label: 1,
415    /// };
416    /// assert!(box1.equal_within_delta(&box2, 0.011));
417    /// ```
418    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420        self.label == rhs.label
421            && eq_delta(self.score, rhs.score)
422            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426    }
427}
428
429/// A segmentation result with a segmentation mask, and a normalized bounding
430/// box representing the area that the segmentation mask covers
431#[derive(Debug, Clone, PartialEq, Default)]
432pub struct Segmentation {
433    /// left-most normalized coordinate of the segmentation box
434    pub xmin: f32,
435    /// top-most normalized coordinate of the segmentation box
436    pub ymin: f32,
437    /// right-most normalized coordinate of the segmentation box
438    pub xmax: f32,
439    /// bottom-most normalized coordinate of the segmentation box
440    pub ymax: f32,
441    /// 3D segmentation array of shape `(H, W, C)`.
442    ///
443    /// For instance segmentation (e.g. YOLO): `C=1` — per-instance mask with
444    /// continuous sigmoid confidence values quantized to u8 (0 = background,
445    /// 255 = full confidence). Renderers typically threshold at 128 (sigmoid
446    /// 0.5) or use smooth interpolation for anti-aliased edges.
447    ///
448    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
449    /// class scores where the object class is the argmax index.
450    pub segmentation: Array3<u8>,
451}
452
453/// Prototype tensor variants for fused decode+render pipelines.
454///
455/// Carries either raw quantized data (to skip CPU dequantization and let the
456/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
457/// paths).
458#[derive(Debug, Clone)]
459pub enum ProtoTensor {
460    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
461    /// The GPU fragment shader will dequantize per-texel using the scale and
462    /// zero_point.
463    Quantized {
464        protos: Array3<i8>,
465        quantization: Quantization,
466    },
467    /// Dequantized f32 protos (from float models or legacy path).
468    Float(Array3<f32>),
469}
470
471impl ProtoTensor {
472    /// Returns `true` if this is the quantized variant.
473    pub fn is_quantized(&self) -> bool {
474        matches!(self, ProtoTensor::Quantized { .. })
475    }
476
477    /// Returns the spatial dimensions `(height, width, num_protos)`.
478    pub fn dim(&self) -> (usize, usize, usize) {
479        match self {
480            ProtoTensor::Quantized { protos, .. } => protos.dim(),
481            ProtoTensor::Float(arr) => arr.dim(),
482        }
483    }
484
485    /// Returns dequantized f32 protos. For the `Float` variant this is a
486    /// no-copy reference; for `Quantized` it allocates and dequantizes.
487    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
488        match self {
489            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
490            ProtoTensor::Quantized {
491                protos,
492                quantization,
493            } => {
494                let scale = quantization.scale;
495                let zp = quantization.zero_point as f32;
496                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
497            }
498        }
499    }
500}
501
502/// Raw prototype data for fused decode+render pipelines.
503///
504/// Holds post-NMS intermediate state before mask materialization, allowing the
505/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
506/// shader) without materializing intermediate `Array3<u8>` masks.
507#[derive(Debug, Clone)]
508pub struct ProtoData {
509    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
510    pub mask_coefficients: Vec<Vec<f32>>,
511    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
512    pub protos: ProtoTensor,
513}
514
515/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
516///
517///  # Examples
518/// ```
519/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
520/// let quant = Quantization::new(0.1, -128);
521/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
522/// let detect_quant = DetectBoxQuantized {
523///     bbox,
524///     score: 100_i8,
525///     label: 1,
526/// };
527/// let detect = dequant_detect_box(&detect_quant, quant);
528/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
529/// assert_eq!(detect.label, 1);
530/// assert_eq!(detect.bbox, bbox);
531/// ```
532pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
533    detect: &DetectBoxQuantized<SCORE>,
534    quant_scores: Quantization,
535) -> DetectBox {
536    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
537    DetectBox {
538        bbox: detect.bbox,
539        score: quant_scores.scale * detect.score.as_() + scaled_zp,
540        label: detect.label,
541    }
542}
543/// A detection box with a f32 bbox and quantized score
544#[derive(Debug, Clone, Copy, PartialEq)]
545pub struct DetectBoxQuantized<
546    // BOX: Signed + PrimInt + AsPrimitive<f32>,
547    SCORE: PrimInt + AsPrimitive<f32>,
548> {
549    // pub bbox: BoundingBoxQuantized<BOX>,
550    pub bbox: BoundingBox,
551    /// model-specific score for this detection, higher implies more
552    /// confidence.
553    pub score: SCORE,
554    /// label index for this detect
555    pub label: usize,
556}
557
558/// Dequantizes an ndarray from quantized values to f32 values using the given
559/// quantization parameters
560///
561/// # Examples
562/// ```
563/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
564/// let quant = Quantization::new(0.1, -128);
565/// let input: Vec<i8> = vec![0, 127, -128, 64];
566/// let input_array = ndarray::Array1::from(input);
567/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
568/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
569/// ```
570pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
571    input: ArrayView<T, D>,
572    quant: Quantization,
573) -> Array<F, D>
574where
575    i32: num_traits::AsPrimitive<F>,
576    f32: num_traits::AsPrimitive<F>,
577{
578    let zero_point = quant.zero_point.as_();
579    let scale = quant.scale.as_();
580    if zero_point != F::zero() {
581        let scaled_zero = -zero_point * scale;
582        input.mapv(|d| d.as_() * scale + scaled_zero)
583    } else {
584        input.mapv(|d| d.as_() * scale)
585    }
586}
587
588/// Dequantizes a slice from quantized values to float values using the given
589/// quantization parameters
590///
591/// # Examples
592/// ```
593/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
594/// let quant = Quantization::new(0.1, -128);
595/// let input: Vec<i8> = vec![0, 127, -128, 64];
596/// let mut output: Vec<f32> = vec![0.0; input.len()];
597/// dequantize_cpu(&input, quant, &mut output);
598/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
599/// ```
600pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
601    input: &[T],
602    quant: Quantization,
603    output: &mut [F],
604) where
605    f32: num_traits::AsPrimitive<F>,
606    i32: num_traits::AsPrimitive<F>,
607{
608    assert!(input.len() == output.len());
609    let zero_point = quant.zero_point.as_();
610    let scale = quant.scale.as_();
611    if zero_point != F::zero() {
612        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
613        input
614            .iter()
615            .zip(output)
616            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
617    } else {
618        input
619            .iter()
620            .zip(output)
621            .for_each(|(d, deq)| *deq = d.as_() * scale);
622    }
623}
624
625/// Dequantizes a slice from quantized values to float values using the given
626/// quantization parameters, using chunked processing. This is around 5% faster
627/// than `dequantize_cpu` for large slices.
628///
629/// # Examples
630/// ```
631/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
632/// let quant = Quantization::new(0.1, -128);
633/// let input: Vec<i8> = vec![0, 127, -128, 64];
634/// let mut output: Vec<f32> = vec![0.0; input.len()];
635/// dequantize_cpu_chunked(&input, quant, &mut output);
636/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
637/// ```
638pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
639    input: &[T],
640    quant: Quantization,
641    output: &mut [F],
642) where
643    f32: num_traits::AsPrimitive<F>,
644    i32: num_traits::AsPrimitive<F>,
645{
646    assert!(input.len() == output.len());
647    let zero_point = quant.zero_point.as_();
648    let scale = quant.scale.as_();
649
650    let input = input.as_chunks::<4>();
651    let output = output.as_chunks_mut::<4>();
652
653    if zero_point != F::zero() {
654        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
655
656        input
657            .0
658            .iter()
659            .zip(output.0)
660            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
661        input
662            .1
663            .iter()
664            .zip(output.1)
665            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
666    } else {
667        input
668            .0
669            .iter()
670            .zip(output.0)
671            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
672        input
673            .1
674            .iter()
675            .zip(output.1)
676            .for_each(|(d, deq)| *deq = d.as_() * scale);
677    }
678}
679
680/// Converts a segmentation tensor into a 2D mask
681/// If the last dimension of the segmentation tensor is 1, values equal or
682/// above 128 are considered objects. Otherwise the object is the argmax index
683///
684/// # Errors
685///
686/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
687/// invalid shape.
688///
689/// # Examples
690/// ```
691/// # use edgefirst_decoder::segmentation_to_mask;
692/// let segmentation =
693///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
694/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
695/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
696/// ```
697pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
698    if segmentation.shape()[2] == 0 {
699        return Err(DecoderError::InvalidShape(
700            "Segmentation tensor must have non-zero depth".to_string(),
701        ));
702    }
703    if segmentation.shape()[2] == 1 {
704        yolo_segmentation_to_mask(segmentation, 128)
705    } else {
706        Ok(modelpack_segmentation_to_mask(segmentation))
707    }
708}
709
710/// Returns the maximum value and its index from a 1D array
711fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
712    score
713        .iter()
714        .enumerate()
715        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
716            if max > *s {
717                (max, arg_max)
718            } else {
719                (*s, ind)
720            }
721        })
722}
723#[cfg(test)]
724#[cfg_attr(coverage_nightly, coverage(off))]
725mod decoder_tests {
726    #![allow(clippy::excessive_precision)]
727    use crate::{
728        configs::{DecoderType, DimName, Protos},
729        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
730        yolo::{
731            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
732            decode_yolo_segdet_quant,
733        },
734        *,
735    };
736    use ndarray::{array, s, Array4};
737    use ndarray_stats::DeviationExt;
738
739    fn compare_outputs(
740        boxes: (&[DetectBox], &[DetectBox]),
741        masks: (&[Segmentation], &[Segmentation]),
742    ) {
743        let (boxes0, boxes1) = boxes;
744        let (masks0, masks1) = masks;
745
746        assert_eq!(boxes0.len(), boxes1.len());
747        assert_eq!(masks0.len(), masks1.len());
748
749        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
750            assert!(
751                b_i8.equal_within_delta(b_f32, 1e-6),
752                "{b_i8:?} is not equal to {b_f32:?}"
753            );
754        }
755
756        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
757            assert_eq!(
758                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
759                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
760            );
761            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
762            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
763            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
764            let diff = &mask_i8 - &mask_f32;
765            for x in 0..diff.shape()[0] {
766                for y in 0..diff.shape()[1] {
767                    for z in 0..diff.shape()[2] {
768                        let val = diff[[x, y, z]];
769                        assert!(
770                            val.abs() <= 1,
771                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
772                            x,
773                            y,
774                            z,
775                            val
776                        );
777                    }
778                }
779            }
780            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
781            assert!(
782                mean_sq_err < 1e-2,
783                "Mean Square Error between masks was greater than 1%: {:.2}%",
784                mean_sq_err * 100.0
785            );
786        }
787    }
788
    // End-to-end check of the ModelPack detection path: the same quantized u8
    // test tensors are decoded by the low-level `decode_modelpack_det`, by
    // `Decoder::decode_quantized`, and by `Decoder::decode_float` on
    // dequantized copies, and the three results must agree.
    #[test]
    fn test_decoder_modelpack() {
        let score_threshold = 0.45;
        let iou_threshold = 0.45;
        // Raw box bytes from testdata, reshaped as u8 to
        // (batch, num_boxes, padding, box_coords) to match the dshape below.
        let boxes = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_boxes_1935x1x4.bin"
        ));
        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();

        // Raw score bytes reshaped as u8 to (batch, num_boxes, num_classes).
        let scores = include_bytes!(concat!(
            env!("CARGO_MANIFEST_DIR"),
            "/../../testdata/modelpack_scores_1935x1.bin"
        ));
        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();

        // The target type of these `.into()` calls is inferred from the
        // `quantization` fields of the configs below (presumably
        // `configs::QuantTuple` — confirm against `configs`).
        let quant_boxes = (0.004656755365431309, 21).into();
        let quant_scores = (0.0019603664986789227, 0).into();

        // High-level decoder configured with the same shapes, quantization
        // and thresholds as the low-level call below.
        let decoder = DecoderBuilder::default()
            .with_config_modelpack_det(
                configs::Boxes {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_boxes),
                    shape: vec![1, 1935, 1, 4],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::Padding, 1),
                        (DimName::BoxCoords, 4),
                    ],
                    normalized: Some(true),
                },
                configs::Scores {
                    decoder: DecoderType::ModelPack,
                    quantization: Some(quant_scores),
                    shape: vec![1, 1935, 1],
                    dshape: vec![
                        (DimName::Batch, 1),
                        (DimName::NumBoxes, 1935),
                        (DimName::NumClasses, 1),
                    ],
                },
            )
            .with_score_threshold(score_threshold)
            .with_iou_threshold(iou_threshold)
            .build()
            .unwrap();

        // Re-convert into `Quantization`, the type `dequantize_ndarray`
        // (used below) takes.
        let quant_boxes = quant_boxes.into();
        let quant_scores = quant_scores.into();

        // Reference result: low-level decode on the (1935, 4) box slice and
        // the (1935, 1) score slice.
        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
        decode_modelpack_det(
            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
            (scores.slice(s![0, .., ..]), quant_scores),
            score_threshold,
            iou_threshold,
            &mut output_boxes,
        );
        // Pin the first detection to known-good values.
        assert!(output_boxes[0].equal_within_delta(
            &DetectBox {
                bbox: BoundingBox {
                    xmin: 0.40513772,
                    ymin: 0.6379755,
                    xmax: 0.5122431,
                    ymax: 0.7730214,
                },
                score: 0.4861709,
                label: 0
            },
            1e-6
        ));

        // Second path: high-level decoder on the quantized tensors.
        let mut output_boxes1 = Vec::with_capacity(50);
        let mut output_masks1 = Vec::with_capacity(50);

        decoder
            .decode_quantized(
                &[boxes.view().into(), scores.view().into()],
                &mut output_boxes1,
                &mut output_masks1,
            )
            .unwrap();

        // Third path: high-level decoder on dequantized f32 copies.
        let mut output_boxes_float = Vec::with_capacity(50);
        let mut output_masks_float = Vec::with_capacity(50);

        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
        let scores = dequantize_ndarray(scores.view(), quant_scores);

        decoder
            .decode_float::<f32>(
                &[boxes.view().into_dyn(), scores.view().into_dyn()],
                &mut output_boxes_float,
                &mut output_masks_float,
            )
            .unwrap();

        // Both decoder paths must reproduce the reference boxes. Passing `&[]`
        // as the expected masks asserts that no masks were produced.
        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
        compare_outputs(
            (&output_boxes, &output_boxes_float),
            (&[], &output_masks_float),
        );
    }
894
895    #[test]
896    fn test_decoder_modelpack_split_u8() {
897        let score_threshold = 0.45;
898        let iou_threshold = 0.45;
899        let detect0 = include_bytes!(concat!(
900            env!("CARGO_MANIFEST_DIR"),
901            "/../../testdata/modelpack_split_9x15x18.bin"
902        ));
903        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
904
905        let detect1 = include_bytes!(concat!(
906            env!("CARGO_MANIFEST_DIR"),
907            "/../../testdata/modelpack_split_17x30x18.bin"
908        ));
909        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
910
911        let quant0 = (0.08547406643629074, 174).into();
912        let quant1 = (0.09929127991199493, 183).into();
913        let anchors0 = vec![
914            [0.36666667461395264, 0.31481480598449707],
915            [0.38749998807907104, 0.4740740656852722],
916            [0.5333333611488342, 0.644444465637207],
917        ];
918        let anchors1 = vec![
919            [0.13750000298023224, 0.2074074000120163],
920            [0.2541666626930237, 0.21481481194496155],
921            [0.23125000298023224, 0.35185185074806213],
922        ];
923
924        let detect_config0 = configs::Detection {
925            decoder: DecoderType::ModelPack,
926            shape: vec![1, 9, 15, 18],
927            anchors: Some(anchors0.clone()),
928            quantization: Some(quant0),
929            dshape: vec![
930                (DimName::Batch, 1),
931                (DimName::Height, 9),
932                (DimName::Width, 15),
933                (DimName::NumAnchorsXFeatures, 18),
934            ],
935            normalized: Some(true),
936        };
937
938        let detect_config1 = configs::Detection {
939            decoder: DecoderType::ModelPack,
940            shape: vec![1, 17, 30, 18],
941            anchors: Some(anchors1.clone()),
942            quantization: Some(quant1),
943            dshape: vec![
944                (DimName::Batch, 1),
945                (DimName::Height, 17),
946                (DimName::Width, 30),
947                (DimName::NumAnchorsXFeatures, 18),
948            ],
949            normalized: Some(true),
950        };
951
952        let config0 = (&detect_config0).try_into().unwrap();
953        let config1 = (&detect_config1).try_into().unwrap();
954
955        let decoder = DecoderBuilder::default()
956            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
957            .with_score_threshold(score_threshold)
958            .with_iou_threshold(iou_threshold)
959            .build()
960            .unwrap();
961
962        let quant0 = quant0.into();
963        let quant1 = quant1.into();
964
965        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
966        decode_modelpack_split_quant(
967            &[
968                detect0.slice(s![0, .., .., ..]),
969                detect1.slice(s![0, .., .., ..]),
970            ],
971            &[config0, config1],
972            score_threshold,
973            iou_threshold,
974            &mut output_boxes,
975        );
976        assert!(output_boxes[0].equal_within_delta(
977            &DetectBox {
978                bbox: BoundingBox {
979                    xmin: 0.43171933,
980                    ymin: 0.68243736,
981                    xmax: 0.5626645,
982                    ymax: 0.808863,
983                },
984                score: 0.99240804,
985                label: 0
986            },
987            1e-6
988        ));
989
990        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
991        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
992        decoder
993            .decode_quantized(
994                &[detect0.view().into(), detect1.view().into()],
995                &mut output_boxes1,
996                &mut output_masks1,
997            )
998            .unwrap();
999
1000        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1001        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1002
1003        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1004        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1005        decoder
1006            .decode_float::<f32>(
1007                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1008                &mut output_boxes1_f32,
1009                &mut output_masks1_f32,
1010            )
1011            .unwrap();
1012
1013        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1014        compare_outputs(
1015            (&output_boxes, &output_boxes1_f32),
1016            (&[], &output_masks1_f32),
1017        );
1018    }
1019
1020    #[test]
1021    fn test_decoder_parse_config_modelpack_split_u8() {
1022        let score_threshold = 0.45;
1023        let iou_threshold = 0.45;
1024        let detect0 = include_bytes!(concat!(
1025            env!("CARGO_MANIFEST_DIR"),
1026            "/../../testdata/modelpack_split_9x15x18.bin"
1027        ));
1028        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1029
1030        let detect1 = include_bytes!(concat!(
1031            env!("CARGO_MANIFEST_DIR"),
1032            "/../../testdata/modelpack_split_17x30x18.bin"
1033        ));
1034        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1035
1036        let decoder = DecoderBuilder::default()
1037            .with_config_yaml_str(
1038                include_str!(concat!(
1039                    env!("CARGO_MANIFEST_DIR"),
1040                    "/../../testdata/modelpack_split.yaml"
1041                ))
1042                .to_string(),
1043            )
1044            .with_score_threshold(score_threshold)
1045            .with_iou_threshold(iou_threshold)
1046            .build()
1047            .unwrap();
1048
1049        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1050        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1051        decoder
1052            .decode_quantized(
1053                &[
1054                    ArrayViewDQuantized::from(detect1.view()),
1055                    ArrayViewDQuantized::from(detect0.view()),
1056                ],
1057                &mut output_boxes,
1058                &mut output_masks,
1059            )
1060            .unwrap();
1061        assert!(output_boxes[0].equal_within_delta(
1062            &DetectBox {
1063                bbox: BoundingBox {
1064                    xmin: 0.43171933,
1065                    ymin: 0.68243736,
1066                    xmax: 0.5626645,
1067                    ymax: 0.808863,
1068                },
1069                score: 0.99240804,
1070                label: 0
1071            },
1072            1e-6
1073        ));
1074    }
1075
1076    #[test]
1077    fn test_modelpack_seg() {
1078        let out = include_bytes!(concat!(
1079            env!("CARGO_MANIFEST_DIR"),
1080            "/../../testdata/modelpack_seg_2x160x160.bin"
1081        ));
1082        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1083        let quant = (1.0 / 255.0, 0).into();
1084
1085        let decoder = DecoderBuilder::default()
1086            .with_config_modelpack_seg(configs::Segmentation {
1087                decoder: DecoderType::ModelPack,
1088                quantization: Some(quant),
1089                shape: vec![1, 2, 160, 160],
1090                dshape: vec![
1091                    (DimName::Batch, 1),
1092                    (DimName::NumClasses, 2),
1093                    (DimName::Height, 160),
1094                    (DimName::Width, 160),
1095                ],
1096            })
1097            .build()
1098            .unwrap();
1099        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1100        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1101        decoder
1102            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1103            .unwrap();
1104
1105        let mut mask = out.slice(s![0, .., .., ..]);
1106        mask.swap_axes(0, 1);
1107        mask.swap_axes(1, 2);
1108        let mask = [Segmentation {
1109            xmin: 0.0,
1110            ymin: 0.0,
1111            xmax: 1.0,
1112            ymax: 1.0,
1113            segmentation: mask.into_owned(),
1114        }];
1115        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1116
1117        decoder
1118            .decode_float::<f32>(
1119                &[dequantize_ndarray(out.view(), quant.into())
1120                    .view()
1121                    .into_dyn()],
1122                &mut output_boxes,
1123                &mut output_masks,
1124            )
1125            .unwrap();
1126
1127        // not expected for float decoder to have same values as quantized decoder, as
1128        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1129        // the model output. Thus the float output is the same as the quantized output
1130        // but scaled differently. However, it is expected that the mask after argmax
1131        // will be the same.
1132        compare_outputs((&[], &output_boxes), (&[], &[]));
1133        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1134        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1135
1136        assert_eq!(mask0, mask1);
1137    }
1138    #[test]
1139    fn test_modelpack_seg_quant() {
1140        let out = include_bytes!(concat!(
1141            env!("CARGO_MANIFEST_DIR"),
1142            "/../../testdata/modelpack_seg_2x160x160.bin"
1143        ));
1144        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1145        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1146        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1147        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1148        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1149        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1150
1151        let quant = (1.0 / 255.0, 0).into();
1152
1153        let decoder = DecoderBuilder::default()
1154            .with_config_modelpack_seg(configs::Segmentation {
1155                decoder: DecoderType::ModelPack,
1156                quantization: Some(quant),
1157                shape: vec![1, 2, 160, 160],
1158                dshape: vec![
1159                    (DimName::Batch, 1),
1160                    (DimName::NumClasses, 2),
1161                    (DimName::Height, 160),
1162                    (DimName::Width, 160),
1163                ],
1164            })
1165            .build()
1166            .unwrap();
1167        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1168        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1169        decoder
1170            .decode_quantized(
1171                &[out_u8.view().into()],
1172                &mut output_boxes,
1173                &mut output_masks_u8,
1174            )
1175            .unwrap();
1176
1177        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1178        decoder
1179            .decode_quantized(
1180                &[out_i8.view().into()],
1181                &mut output_boxes,
1182                &mut output_masks_i8,
1183            )
1184            .unwrap();
1185
1186        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1187        decoder
1188            .decode_quantized(
1189                &[out_u16.view().into()],
1190                &mut output_boxes,
1191                &mut output_masks_u16,
1192            )
1193            .unwrap();
1194
1195        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1196        decoder
1197            .decode_quantized(
1198                &[out_i16.view().into()],
1199                &mut output_boxes,
1200                &mut output_masks_i16,
1201            )
1202            .unwrap();
1203
1204        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1205        decoder
1206            .decode_quantized(
1207                &[out_u32.view().into()],
1208                &mut output_boxes,
1209                &mut output_masks_u32,
1210            )
1211            .unwrap();
1212
1213        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1214        decoder
1215            .decode_quantized(
1216                &[out_i32.view().into()],
1217                &mut output_boxes,
1218                &mut output_masks_i32,
1219            )
1220            .unwrap();
1221
1222        compare_outputs((&[], &output_boxes), (&[], &[]));
1223        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1224        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1225        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1226        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1227        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1228        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1229        assert_eq!(mask_u8, mask_i8);
1230        assert_eq!(mask_u8, mask_u16);
1231        assert_eq!(mask_u8, mask_i16);
1232        assert_eq!(mask_u8, mask_u32);
1233        assert_eq!(mask_u8, mask_i32);
1234    }
1235
1236    #[test]
1237    fn test_modelpack_segdet() {
1238        let score_threshold = 0.45;
1239        let iou_threshold = 0.45;
1240
1241        let boxes = include_bytes!(concat!(
1242            env!("CARGO_MANIFEST_DIR"),
1243            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1244        ));
1245        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1246
1247        let scores = include_bytes!(concat!(
1248            env!("CARGO_MANIFEST_DIR"),
1249            "/../../testdata/modelpack_scores_1935x1.bin"
1250        ));
1251        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1252
1253        let seg = include_bytes!(concat!(
1254            env!("CARGO_MANIFEST_DIR"),
1255            "/../../testdata/modelpack_seg_2x160x160.bin"
1256        ));
1257        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1258
1259        let quant_boxes = (0.004656755365431309, 21).into();
1260        let quant_scores = (0.0019603664986789227, 0).into();
1261        let quant_seg = (1.0 / 255.0, 0).into();
1262
1263        let decoder = DecoderBuilder::default()
1264            .with_config_modelpack_segdet(
1265                configs::Boxes {
1266                    decoder: DecoderType::ModelPack,
1267                    quantization: Some(quant_boxes),
1268                    shape: vec![1, 1935, 1, 4],
1269                    dshape: vec![
1270                        (DimName::Batch, 1),
1271                        (DimName::NumBoxes, 1935),
1272                        (DimName::Padding, 1),
1273                        (DimName::BoxCoords, 4),
1274                    ],
1275                    normalized: Some(true),
1276                },
1277                configs::Scores {
1278                    decoder: DecoderType::ModelPack,
1279                    quantization: Some(quant_scores),
1280                    shape: vec![1, 1935, 1],
1281                    dshape: vec![
1282                        (DimName::Batch, 1),
1283                        (DimName::NumBoxes, 1935),
1284                        (DimName::NumClasses, 1),
1285                    ],
1286                },
1287                configs::Segmentation {
1288                    decoder: DecoderType::ModelPack,
1289                    quantization: Some(quant_seg),
1290                    shape: vec![1, 2, 160, 160],
1291                    dshape: vec![
1292                        (DimName::Batch, 1),
1293                        (DimName::NumClasses, 2),
1294                        (DimName::Height, 160),
1295                        (DimName::Width, 160),
1296                    ],
1297                },
1298            )
1299            .with_iou_threshold(iou_threshold)
1300            .with_score_threshold(score_threshold)
1301            .build()
1302            .unwrap();
1303        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1304        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1305        decoder
1306            .decode_quantized(
1307                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1308                &mut output_boxes,
1309                &mut output_masks,
1310            )
1311            .unwrap();
1312
1313        let mut mask = seg.slice(s![0, .., .., ..]);
1314        mask.swap_axes(0, 1);
1315        mask.swap_axes(1, 2);
1316        let mask = [Segmentation {
1317            xmin: 0.0,
1318            ymin: 0.0,
1319            xmax: 1.0,
1320            ymax: 1.0,
1321            segmentation: mask.into_owned(),
1322        }];
1323        let correct_boxes = [DetectBox {
1324            bbox: BoundingBox {
1325                xmin: 0.40513772,
1326                ymin: 0.6379755,
1327                xmax: 0.5122431,
1328                ymax: 0.7730214,
1329            },
1330            score: 0.4861709,
1331            label: 0,
1332        }];
1333        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1334
1335        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1336        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1337        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1338        decoder
1339            .decode_float::<f32>(
1340                &[
1341                    scores.view().into_dyn(),
1342                    boxes.view().into_dyn(),
1343                    seg.view().into_dyn(),
1344                ],
1345                &mut output_boxes,
1346                &mut output_masks,
1347            )
1348            .unwrap();
1349
1350        // not expected for float segmentation decoder to have same values as quantized
1351        // segmentation decoder, as float decoder ensures the data fills 0-255,
1352        // quantized decoder uses whatever the model output. Thus the float
1353        // output is the same as the quantized output but scaled differently.
1354        // However, it is expected that the mask after argmax will be the same.
1355        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1356        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1357        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1358
1359        assert_eq!(mask0, mask1);
1360    }
1361
1362    #[test]
1363    fn test_modelpack_segdet_split() {
1364        let score_threshold = 0.8;
1365        let iou_threshold = 0.5;
1366
1367        let seg = include_bytes!(concat!(
1368            env!("CARGO_MANIFEST_DIR"),
1369            "/../../testdata/modelpack_seg_2x160x160.bin"
1370        ));
1371        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1372
1373        let detect0 = include_bytes!(concat!(
1374            env!("CARGO_MANIFEST_DIR"),
1375            "/../../testdata/modelpack_split_9x15x18.bin"
1376        ));
1377        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1378
1379        let detect1 = include_bytes!(concat!(
1380            env!("CARGO_MANIFEST_DIR"),
1381            "/../../testdata/modelpack_split_17x30x18.bin"
1382        ));
1383        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1384
1385        let quant0 = (0.08547406643629074, 174).into();
1386        let quant1 = (0.09929127991199493, 183).into();
1387        let quant_seg = (1.0 / 255.0, 0).into();
1388
1389        let anchors0 = vec![
1390            [0.36666667461395264, 0.31481480598449707],
1391            [0.38749998807907104, 0.4740740656852722],
1392            [0.5333333611488342, 0.644444465637207],
1393        ];
1394        let anchors1 = vec![
1395            [0.13750000298023224, 0.2074074000120163],
1396            [0.2541666626930237, 0.21481481194496155],
1397            [0.23125000298023224, 0.35185185074806213],
1398        ];
1399
1400        let decoder = DecoderBuilder::default()
1401            .with_config_modelpack_segdet_split(
1402                vec![
1403                    configs::Detection {
1404                        decoder: DecoderType::ModelPack,
1405                        shape: vec![1, 17, 30, 18],
1406                        anchors: Some(anchors1),
1407                        quantization: Some(quant1),
1408                        dshape: vec![
1409                            (DimName::Batch, 1),
1410                            (DimName::Height, 17),
1411                            (DimName::Width, 30),
1412                            (DimName::NumAnchorsXFeatures, 18),
1413                        ],
1414                        normalized: Some(true),
1415                    },
1416                    configs::Detection {
1417                        decoder: DecoderType::ModelPack,
1418                        shape: vec![1, 9, 15, 18],
1419                        anchors: Some(anchors0),
1420                        quantization: Some(quant0),
1421                        dshape: vec![
1422                            (DimName::Batch, 1),
1423                            (DimName::Height, 9),
1424                            (DimName::Width, 15),
1425                            (DimName::NumAnchorsXFeatures, 18),
1426                        ],
1427                        normalized: Some(true),
1428                    },
1429                ],
1430                configs::Segmentation {
1431                    decoder: DecoderType::ModelPack,
1432                    quantization: Some(quant_seg),
1433                    shape: vec![1, 2, 160, 160],
1434                    dshape: vec![
1435                        (DimName::Batch, 1),
1436                        (DimName::NumClasses, 2),
1437                        (DimName::Height, 160),
1438                        (DimName::Width, 160),
1439                    ],
1440                },
1441            )
1442            .with_score_threshold(score_threshold)
1443            .with_iou_threshold(iou_threshold)
1444            .build()
1445            .unwrap();
1446        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1447        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1448        decoder
1449            .decode_quantized(
1450                &[
1451                    detect0.view().into(),
1452                    detect1.view().into(),
1453                    seg.view().into(),
1454                ],
1455                &mut output_boxes,
1456                &mut output_masks,
1457            )
1458            .unwrap();
1459
1460        let mut mask = seg.slice(s![0, .., .., ..]);
1461        mask.swap_axes(0, 1);
1462        mask.swap_axes(1, 2);
1463        let mask = [Segmentation {
1464            xmin: 0.0,
1465            ymin: 0.0,
1466            xmax: 1.0,
1467            ymax: 1.0,
1468            segmentation: mask.into_owned(),
1469        }];
1470        let correct_boxes = [DetectBox {
1471            bbox: BoundingBox {
1472                xmin: 0.43171933,
1473                ymin: 0.68243736,
1474                xmax: 0.5626645,
1475                ymax: 0.808863,
1476            },
1477            score: 0.99240804,
1478            label: 0,
1479        }];
1480        println!("Output Boxes: {:?}", output_boxes);
1481        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1482
1483        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1484        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1485        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1486        decoder
1487            .decode_float::<f32>(
1488                &[
1489                    detect0.view().into_dyn(),
1490                    detect1.view().into_dyn(),
1491                    seg.view().into_dyn(),
1492                ],
1493                &mut output_boxes,
1494                &mut output_masks,
1495            )
1496            .unwrap();
1497
1498        // not expected for float segmentation decoder to have same values as quantized
1499        // segmentation decoder, as float decoder ensures the data fills 0-255,
1500        // quantized decoder uses whatever the model output. Thus the float
1501        // output is the same as the quantized output but scaled differently.
1502        // However, it is expected that the mask after argmax will be the same.
1503        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1504        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1505        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1506
1507        assert_eq!(mask0, mask1);
1508    }
1509
1510    #[test]
1511    fn test_dequant_chunked() {
1512        let out = include_bytes!(concat!(
1513            env!("CARGO_MANIFEST_DIR"),
1514            "/../../testdata/yolov8s_80_classes.bin"
1515        ));
1516        let mut out =
1517            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1518        out.push(123); // make sure to test non multiple of 16 length
1519
1520        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1521        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1522        let quant = Quantization::new(0.0040811873, -123);
1523        dequantize_cpu(&out, quant, &mut out_dequant);
1524
1525        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1526        assert_eq!(out_dequant, out_dequant_simd);
1527
1528        let quant = Quantization::new(0.0040811873, 0);
1529        dequantize_cpu(&out, quant, &mut out_dequant);
1530
1531        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1532        assert_eq!(out_dequant, out_dequant_simd);
1533    }
1534
1535    #[test]
1536    fn test_decoder_yolo_det() {
1537        let score_threshold = 0.25;
1538        let iou_threshold = 0.7;
1539        let out = include_bytes!(concat!(
1540            env!("CARGO_MANIFEST_DIR"),
1541            "/../../testdata/yolov8s_80_classes.bin"
1542        ));
1543        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1544        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1545        let quant = (0.0040811873, -123).into();
1546
1547        let decoder = DecoderBuilder::default()
1548            .with_config_yolo_det(
1549                configs::Detection {
1550                    decoder: DecoderType::Ultralytics,
1551                    shape: vec![1, 84, 8400],
1552                    anchors: None,
1553                    quantization: Some(quant),
1554                    dshape: vec![
1555                        (DimName::Batch, 1),
1556                        (DimName::NumFeatures, 84),
1557                        (DimName::NumBoxes, 8400),
1558                    ],
1559                    normalized: Some(true),
1560                },
1561                Some(DecoderVersion::Yolo11),
1562            )
1563            .with_score_threshold(score_threshold)
1564            .with_iou_threshold(iou_threshold)
1565            .build()
1566            .unwrap();
1567
1568        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1569        decode_yolo_det(
1570            (out.slice(s![0, .., ..]), quant.into()),
1571            score_threshold,
1572            iou_threshold,
1573            Some(configs::Nms::ClassAgnostic),
1574            &mut output_boxes,
1575        );
1576        assert!(output_boxes[0].equal_within_delta(
1577            &DetectBox {
1578                bbox: BoundingBox {
1579                    xmin: 0.5285137,
1580                    ymin: 0.05305544,
1581                    xmax: 0.87541467,
1582                    ymax: 0.9998909,
1583                },
1584                score: 0.5591227,
1585                label: 0
1586            },
1587            1e-6
1588        ));
1589
1590        assert!(output_boxes[1].equal_within_delta(
1591            &DetectBox {
1592                bbox: BoundingBox {
1593                    xmin: 0.130598,
1594                    ymin: 0.43260583,
1595                    xmax: 0.35098213,
1596                    ymax: 0.9958097,
1597                },
1598                score: 0.33057618,
1599                label: 75
1600            },
1601            1e-6
1602        ));
1603
1604        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1605        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1606        decoder
1607            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1608            .unwrap();
1609
1610        let out = dequantize_ndarray(out.view(), quant.into());
1611        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1612        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1613        decoder
1614            .decode_float::<f32>(
1615                &[out.view().into_dyn()],
1616                &mut output_boxes_f32,
1617                &mut output_masks_f32,
1618            )
1619            .unwrap();
1620
1621        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1622        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1623    }
1624
1625    #[test]
1626    fn test_decoder_masks() {
1627        let score_threshold = 0.45;
1628        let iou_threshold = 0.45;
1629        let boxes = include_bytes!(concat!(
1630            env!("CARGO_MANIFEST_DIR"),
1631            "/../../testdata/yolov8_boxes_116x8400.bin"
1632        ));
1633        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1634        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1635        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1636
1637        let protos = include_bytes!(concat!(
1638            env!("CARGO_MANIFEST_DIR"),
1639            "/../../testdata/yolov8_protos_160x160x32.bin"
1640        ));
1641        let protos =
1642            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1643        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1644        let quant_protos = Quantization::new(0.02491161972284317, -117);
1645        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1646        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1647        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1648        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1649        decode_yolo_segdet_float(
1650            seg.view(),
1651            protos.view(),
1652            score_threshold,
1653            iou_threshold,
1654            Some(configs::Nms::ClassAgnostic),
1655            &mut output_boxes,
1656            &mut output_masks,
1657        )
1658        .unwrap();
1659        assert_eq!(output_boxes.len(), 2);
1660        assert_eq!(output_boxes.len(), output_masks.len());
1661
1662        for (b, m) in output_boxes.iter().zip(&output_masks) {
1663            assert!(b.bbox.xmin >= m.xmin);
1664            assert!(b.bbox.ymin >= m.ymin);
1665            assert!(b.bbox.xmax >= m.xmax);
1666            assert!(b.bbox.ymax >= m.ymax);
1667        }
1668        assert!(output_boxes[0].equal_within_delta(
1669            &DetectBox {
1670                bbox: BoundingBox {
1671                    xmin: 0.08515105,
1672                    ymin: 0.7131401,
1673                    xmax: 0.29802868,
1674                    ymax: 0.8195788,
1675                },
1676                score: 0.91537374,
1677                label: 23
1678            },
1679            1.0 / 160.0, // wider range because mask will expand the box
1680        ));
1681
1682        assert!(output_boxes[1].equal_within_delta(
1683            &DetectBox {
1684                bbox: BoundingBox {
1685                    xmin: 0.59605736,
1686                    ymin: 0.25545314,
1687                    xmax: 0.93666154,
1688                    ymax: 0.72378385,
1689                },
1690                score: 0.91537374,
1691                label: 23
1692            },
1693            1.0 / 160.0, // wider range because mask will expand the box
1694        ));
1695
1696        let full_mask = include_bytes!(concat!(
1697            env!("CARGO_MANIFEST_DIR"),
1698            "/../../testdata/yolov8_mask_results.bin"
1699        ));
1700        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1701
1702        let cropped_mask = full_mask.slice(ndarray::s![
1703            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1704            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1705        ]);
1706
1707        assert_eq!(
1708            cropped_mask,
1709            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1710        );
1711    }
1712
1713    /// Regression test: config-driven path with NCHW protos (no dshape).
1714    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1715    /// and the YAML config has no dshape field — the exact scenario from
1716    /// hal_mask_matmul_bug.md.
1717    #[test]
1718    fn test_decoder_masks_nchw_protos() {
1719        let score_threshold = 0.45;
1720        let iou_threshold = 0.45;
1721
1722        // Load test data — boxes as [116, 8400]
1723        let boxes_raw = include_bytes!(concat!(
1724            env!("CARGO_MANIFEST_DIR"),
1725            "/../../testdata/yolov8_boxes_116x8400.bin"
1726        ));
1727        let boxes_raw =
1728            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1729        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1730        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1731
1732        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1733        let protos_raw = include_bytes!(concat!(
1734            env!("CARGO_MANIFEST_DIR"),
1735            "/../../testdata/yolov8_protos_160x160x32.bin"
1736        ));
1737        let protos_raw = unsafe {
1738            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1739        };
1740        let protos_hwc =
1741            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1742        let quant_protos = Quantization::new(0.02491161972284317, -117);
1743        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1744
1745        // ---- Reference: direct call with HWC protos (known working) ----
1746        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1747        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1748        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1749        decode_yolo_segdet_float(
1750            seg.view(),
1751            protos_f32_hwc.view(),
1752            score_threshold,
1753            iou_threshold,
1754            Some(configs::Nms::ClassAgnostic),
1755            &mut ref_boxes,
1756            &mut ref_masks,
1757        )
1758        .unwrap();
1759        assert_eq!(ref_boxes.len(), 2);
1760
1761        // ---- Config-driven path: NCHW protos, no dshape ----
1762        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1763        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1764        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1765
1766        // Build boxes as [1, 116, 8400] f32
1767        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1768
1769        // Build decoder from config with no dshape on protos
1770        let decoder = DecoderBuilder::default()
1771            .with_config_yolo_segdet(
1772                configs::Detection {
1773                    decoder: configs::DecoderType::Ultralytics,
1774                    quantization: None,
1775                    shape: vec![1, 116, 8400],
1776                    dshape: vec![],
1777                    normalized: Some(true),
1778                    anchors: None,
1779                },
1780                configs::Protos {
1781                    decoder: configs::DecoderType::Ultralytics,
1782                    quantization: None,
1783                    shape: vec![1, 32, 160, 160],
1784                    dshape: vec![], // No dshape — simulates YAML without dshape
1785                },
1786                None, // decoder version
1787            )
1788            .with_score_threshold(score_threshold)
1789            .with_iou_threshold(iou_threshold)
1790            .build()
1791            .unwrap();
1792
1793        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1794        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1795        decoder
1796            .decode_float(
1797                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1798                &mut cfg_boxes,
1799                &mut cfg_masks,
1800            )
1801            .unwrap();
1802
1803        // Must produce the same number of detections
1804        assert_eq!(
1805            cfg_boxes.len(),
1806            ref_boxes.len(),
1807            "config path produced {} boxes, reference produced {}",
1808            cfg_boxes.len(),
1809            ref_boxes.len()
1810        );
1811
1812        // Boxes must match
1813        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1814            assert!(
1815                cb.equal_within_delta(rb, 0.01),
1816                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1817            );
1818        }
1819
1820        // Masks must match pixel-for-pixel
1821        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1822            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1823            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1824            assert_eq!(
1825                cm_arr, rm_arr,
1826                "mask {i} pixel mismatch between config-driven and reference paths"
1827            );
1828        }
1829    }
1830
1831    #[test]
1832    fn test_decoder_masks_i8() {
1833        let score_threshold = 0.45;
1834        let iou_threshold = 0.45;
1835        let boxes = include_bytes!(concat!(
1836            env!("CARGO_MANIFEST_DIR"),
1837            "/../../testdata/yolov8_boxes_116x8400.bin"
1838        ));
1839        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1840        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1841        let quant_boxes = (0.021287761628627777, 31).into();
1842
1843        let protos = include_bytes!(concat!(
1844            env!("CARGO_MANIFEST_DIR"),
1845            "/../../testdata/yolov8_protos_160x160x32.bin"
1846        ));
1847        let protos =
1848            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1849        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1850        let quant_protos = (0.02491161972284317, -117).into();
1851        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1852        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1853
1854        let decoder = DecoderBuilder::default()
1855            .with_config_yolo_segdet(
1856                configs::Detection {
1857                    decoder: configs::DecoderType::Ultralytics,
1858                    quantization: Some(quant_boxes),
1859                    shape: vec![1, 116, 8400],
1860                    anchors: None,
1861                    dshape: vec![
1862                        (DimName::Batch, 1),
1863                        (DimName::NumFeatures, 116),
1864                        (DimName::NumBoxes, 8400),
1865                    ],
1866                    normalized: Some(true),
1867                },
1868                Protos {
1869                    decoder: configs::DecoderType::Ultralytics,
1870                    quantization: Some(quant_protos),
1871                    shape: vec![1, 160, 160, 32],
1872                    dshape: vec![
1873                        (DimName::Batch, 1),
1874                        (DimName::Height, 160),
1875                        (DimName::Width, 160),
1876                        (DimName::NumProtos, 32),
1877                    ],
1878                },
1879                Some(DecoderVersion::Yolo11),
1880            )
1881            .with_score_threshold(score_threshold)
1882            .with_iou_threshold(iou_threshold)
1883            .build()
1884            .unwrap();
1885
1886        let quant_boxes = quant_boxes.into();
1887        let quant_protos = quant_protos.into();
1888
1889        decode_yolo_segdet_quant(
1890            (boxes.slice(s![0, .., ..]), quant_boxes),
1891            (protos.slice(s![0, .., .., ..]), quant_protos),
1892            score_threshold,
1893            iou_threshold,
1894            Some(configs::Nms::ClassAgnostic),
1895            &mut output_boxes,
1896            &mut output_masks,
1897        )
1898        .unwrap();
1899
1900        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1901        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1902
1903        decoder
1904            .decode_quantized(
1905                &[boxes.view().into(), protos.view().into()],
1906                &mut output_boxes1,
1907                &mut output_masks1,
1908            )
1909            .unwrap();
1910
1911        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1912        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1913
1914        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1915        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1916        decode_yolo_segdet_float(
1917            seg.slice(s![0, .., ..]),
1918            protos.slice(s![0, .., .., ..]),
1919            score_threshold,
1920            iou_threshold,
1921            Some(configs::Nms::ClassAgnostic),
1922            &mut output_boxes_f32,
1923            &mut output_masks_f32,
1924        )
1925        .unwrap();
1926
1927        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1928        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1929
1930        decoder
1931            .decode_float(
1932                &[seg.view().into_dyn(), protos.view().into_dyn()],
1933                &mut output_boxes1_f32,
1934                &mut output_masks1_f32,
1935            )
1936            .unwrap();
1937
1938        compare_outputs(
1939            (&output_boxes, &output_boxes1),
1940            (&output_masks, &output_masks1),
1941        );
1942
1943        compare_outputs(
1944            (&output_boxes, &output_boxes_f32),
1945            (&output_masks, &output_masks_f32),
1946        );
1947
1948        compare_outputs(
1949            (&output_boxes_f32, &output_boxes1_f32),
1950            (&output_masks_f32, &output_masks1_f32),
1951        );
1952    }
1953
1954    #[test]
1955    fn test_decoder_yolo_split() {
1956        let score_threshold = 0.45;
1957        let iou_threshold = 0.45;
1958        let boxes = include_bytes!(concat!(
1959            env!("CARGO_MANIFEST_DIR"),
1960            "/../../testdata/yolov8_boxes_116x8400.bin"
1961        ));
1962        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1963        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1964        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1965
1966        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1967
1968        let decoder = DecoderBuilder::default()
1969            .with_config_yolo_split_det(
1970                configs::Boxes {
1971                    decoder: configs::DecoderType::Ultralytics,
1972                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1973                    shape: vec![1, 4, 8400],
1974                    dshape: vec![
1975                        (DimName::Batch, 1),
1976                        (DimName::BoxCoords, 4),
1977                        (DimName::NumBoxes, 8400),
1978                    ],
1979                    normalized: Some(true),
1980                },
1981                configs::Scores {
1982                    decoder: configs::DecoderType::Ultralytics,
1983                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1984                    shape: vec![1, 80, 8400],
1985                    dshape: vec![
1986                        (DimName::Batch, 1),
1987                        (DimName::NumClasses, 80),
1988                        (DimName::NumBoxes, 8400),
1989                    ],
1990                },
1991            )
1992            .with_score_threshold(score_threshold)
1993            .with_iou_threshold(iou_threshold)
1994            .build()
1995            .unwrap();
1996
1997        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1998        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1999
2000        decoder
2001            .decode_quantized(
2002                &[
2003                    boxes.slice(s![.., ..4, ..]).into(),
2004                    boxes.slice(s![.., 4..84, ..]).into(),
2005                ],
2006                &mut output_boxes,
2007                &mut output_masks,
2008            )
2009            .unwrap();
2010
2011        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2012        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2013        decode_yolo_det_float(
2014            seg.slice(s![0, ..84, ..]),
2015            score_threshold,
2016            iou_threshold,
2017            Some(configs::Nms::ClassAgnostic),
2018            &mut output_boxes_f32,
2019        );
2020
2021        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2022        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2023
2024        decoder
2025            .decode_float(
2026                &[
2027                    seg.slice(s![.., ..4, ..]).into_dyn(),
2028                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2029                ],
2030                &mut output_boxes1,
2031                &mut output_masks1,
2032            )
2033            .unwrap();
2034        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2035        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2036    }
2037
2038    #[test]
2039    fn test_decoder_masks_config_mixed() {
2040        let score_threshold = 0.45;
2041        let iou_threshold = 0.45;
2042        let boxes = include_bytes!(concat!(
2043            env!("CARGO_MANIFEST_DIR"),
2044            "/../../testdata/yolov8_boxes_116x8400.bin"
2045        ));
2046        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2047        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2048        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2049
2050        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2051
2052        let protos = include_bytes!(concat!(
2053            env!("CARGO_MANIFEST_DIR"),
2054            "/../../testdata/yolov8_protos_160x160x32.bin"
2055        ));
2056        let protos =
2057            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2058        let protos: Vec<_> = protos.to_vec();
2059        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2060        let quant_protos = Quantization::new(0.02491161972284317, -117);
2061
2062        let decoder = DecoderBuilder::default()
2063            .with_config_yolo_split_segdet(
2064                configs::Boxes {
2065                    decoder: configs::DecoderType::Ultralytics,
2066                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2067                    shape: vec![1, 4, 8400],
2068                    dshape: vec![
2069                        (DimName::Batch, 1),
2070                        (DimName::BoxCoords, 4),
2071                        (DimName::NumBoxes, 8400),
2072                    ],
2073                    normalized: Some(true),
2074                },
2075                configs::Scores {
2076                    decoder: configs::DecoderType::Ultralytics,
2077                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2078                    shape: vec![1, 80, 8400],
2079                    dshape: vec![
2080                        (DimName::Batch, 1),
2081                        (DimName::NumClasses, 80),
2082                        (DimName::NumBoxes, 8400),
2083                    ],
2084                },
2085                configs::MaskCoefficients {
2086                    decoder: configs::DecoderType::Ultralytics,
2087                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2088                    shape: vec![1, 32, 8400],
2089                    dshape: vec![
2090                        (DimName::Batch, 1),
2091                        (DimName::NumProtos, 32),
2092                        (DimName::NumBoxes, 8400),
2093                    ],
2094                },
2095                configs::Protos {
2096                    decoder: configs::DecoderType::Ultralytics,
2097                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2098                    shape: vec![1, 160, 160, 32],
2099                    dshape: vec![
2100                        (DimName::Batch, 1),
2101                        (DimName::Height, 160),
2102                        (DimName::Width, 160),
2103                        (DimName::NumProtos, 32),
2104                    ],
2105                },
2106            )
2107            .with_score_threshold(score_threshold)
2108            .with_iou_threshold(iou_threshold)
2109            .build()
2110            .unwrap();
2111
2112        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2113        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2114
2115        decoder
2116            .decode_quantized(
2117                &[
2118                    boxes.slice(s![.., ..4, ..]).into(),
2119                    boxes.slice(s![.., 4..84, ..]).into(),
2120                    boxes.slice(s![.., 84.., ..]).into(),
2121                    protos.view().into(),
2122                ],
2123                &mut output_boxes,
2124                &mut output_masks,
2125            )
2126            .unwrap();
2127
2128        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2129        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2130        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2131        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2132        decode_yolo_segdet_float(
2133            seg.slice(s![0, .., ..]),
2134            protos.slice(s![0, .., .., ..]),
2135            score_threshold,
2136            iou_threshold,
2137            Some(configs::Nms::ClassAgnostic),
2138            &mut output_boxes_f32,
2139            &mut output_masks_f32,
2140        )
2141        .unwrap();
2142
2143        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2144        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2145
2146        decoder
2147            .decode_float(
2148                &[
2149                    seg.slice(s![.., ..4, ..]).into_dyn(),
2150                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2151                    seg.slice(s![.., 84.., ..]).into_dyn(),
2152                    protos.view().into_dyn(),
2153                ],
2154                &mut output_boxes1,
2155                &mut output_masks1,
2156            )
2157            .unwrap();
2158        compare_outputs(
2159            (&output_boxes, &output_boxes_f32),
2160            (&output_masks, &output_masks_f32),
2161        );
2162        compare_outputs(
2163            (&output_boxes_f32, &output_boxes1),
2164            (&output_masks_f32, &output_masks1),
2165        );
2166    }
2167
2168    #[test]
2169    fn test_decoder_masks_config_i32() {
2170        let score_threshold = 0.45;
2171        let iou_threshold = 0.45;
2172        let boxes = include_bytes!(concat!(
2173            env!("CARGO_MANIFEST_DIR"),
2174            "/../../testdata/yolov8_boxes_116x8400.bin"
2175        ));
2176        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2177        let scale = 1 << 23;
2178        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2179        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2180
2181        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2182
2183        let protos = include_bytes!(concat!(
2184            env!("CARGO_MANIFEST_DIR"),
2185            "/../../testdata/yolov8_protos_160x160x32.bin"
2186        ));
2187        let protos =
2188            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2189        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2190        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2191        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2192
2193        let decoder = DecoderBuilder::default()
2194            .with_config_yolo_split_segdet(
2195                configs::Boxes {
2196                    decoder: configs::DecoderType::Ultralytics,
2197                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2198                    shape: vec![1, 4, 8400],
2199                    dshape: vec![
2200                        (DimName::Batch, 1),
2201                        (DimName::BoxCoords, 4),
2202                        (DimName::NumBoxes, 8400),
2203                    ],
2204                    normalized: Some(true),
2205                },
2206                configs::Scores {
2207                    decoder: configs::DecoderType::Ultralytics,
2208                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2209                    shape: vec![1, 80, 8400],
2210                    dshape: vec![
2211                        (DimName::Batch, 1),
2212                        (DimName::NumClasses, 80),
2213                        (DimName::NumBoxes, 8400),
2214                    ],
2215                },
2216                configs::MaskCoefficients {
2217                    decoder: configs::DecoderType::Ultralytics,
2218                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2219                    shape: vec![1, 32, 8400],
2220                    dshape: vec![
2221                        (DimName::Batch, 1),
2222                        (DimName::NumProtos, 32),
2223                        (DimName::NumBoxes, 8400),
2224                    ],
2225                },
2226                configs::Protos {
2227                    decoder: configs::DecoderType::Ultralytics,
2228                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2229                    shape: vec![1, 160, 160, 32],
2230                    dshape: vec![
2231                        (DimName::Batch, 1),
2232                        (DimName::Height, 160),
2233                        (DimName::Width, 160),
2234                        (DimName::NumProtos, 32),
2235                    ],
2236                },
2237            )
2238            .with_score_threshold(score_threshold)
2239            .with_iou_threshold(iou_threshold)
2240            .build()
2241            .unwrap();
2242
2243        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2244        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2245
2246        decoder
2247            .decode_quantized(
2248                &[
2249                    boxes.slice(s![.., ..4, ..]).into(),
2250                    boxes.slice(s![.., 4..84, ..]).into(),
2251                    boxes.slice(s![.., 84.., ..]).into(),
2252                    protos.view().into(),
2253                ],
2254                &mut output_boxes,
2255                &mut output_masks,
2256            )
2257            .unwrap();
2258
2259        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2260        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2261        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2262        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2263        decode_yolo_segdet_float(
2264            seg.slice(s![0, .., ..]),
2265            protos.slice(s![0, .., .., ..]),
2266            score_threshold,
2267            iou_threshold,
2268            Some(configs::Nms::ClassAgnostic),
2269            &mut output_boxes_f32,
2270            &mut output_masks_f32,
2271        )
2272        .unwrap();
2273
2274        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2275        assert_eq!(output_masks.len(), output_masks_f32.len());
2276
2277        compare_outputs(
2278            (&output_boxes, &output_boxes_f32),
2279            (&output_masks, &output_masks_f32),
2280        );
2281    }
2282
2283    /// test running multiple decoders concurrently
2284    #[test]
2285    fn test_context_switch() {
2286        let yolo_det = || {
2287            let score_threshold = 0.25;
2288            let iou_threshold = 0.7;
2289            let out = include_bytes!(concat!(
2290                env!("CARGO_MANIFEST_DIR"),
2291                "/../../testdata/yolov8s_80_classes.bin"
2292            ));
2293            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2294            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2295            let quant = (0.0040811873, -123).into();
2296
2297            let decoder = DecoderBuilder::default()
2298                .with_config_yolo_det(
2299                    configs::Detection {
2300                        decoder: DecoderType::Ultralytics,
2301                        shape: vec![1, 84, 8400],
2302                        anchors: None,
2303                        quantization: Some(quant),
2304                        dshape: vec![
2305                            (DimName::Batch, 1),
2306                            (DimName::NumFeatures, 84),
2307                            (DimName::NumBoxes, 8400),
2308                        ],
2309                        normalized: None,
2310                    },
2311                    None,
2312                )
2313                .with_score_threshold(score_threshold)
2314                .with_iou_threshold(iou_threshold)
2315                .build()
2316                .unwrap();
2317
2318            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2319            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2320
2321            for _ in 0..100 {
2322                decoder
2323                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2324                    .unwrap();
2325
2326                assert!(output_boxes[0].equal_within_delta(
2327                    &DetectBox {
2328                        bbox: BoundingBox {
2329                            xmin: 0.5285137,
2330                            ymin: 0.05305544,
2331                            xmax: 0.87541467,
2332                            ymax: 0.9998909,
2333                        },
2334                        score: 0.5591227,
2335                        label: 0
2336                    },
2337                    1e-6
2338                ));
2339
2340                assert!(output_boxes[1].equal_within_delta(
2341                    &DetectBox {
2342                        bbox: BoundingBox {
2343                            xmin: 0.130598,
2344                            ymin: 0.43260583,
2345                            xmax: 0.35098213,
2346                            ymax: 0.9958097,
2347                        },
2348                        score: 0.33057618,
2349                        label: 75
2350                    },
2351                    1e-6
2352                ));
2353                assert!(output_masks.is_empty());
2354            }
2355        };
2356
2357        let modelpack_det_split = || {
2358            let score_threshold = 0.8;
2359            let iou_threshold = 0.5;
2360
2361            let seg = include_bytes!(concat!(
2362                env!("CARGO_MANIFEST_DIR"),
2363                "/../../testdata/modelpack_seg_2x160x160.bin"
2364            ));
2365            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2366
2367            let detect0 = include_bytes!(concat!(
2368                env!("CARGO_MANIFEST_DIR"),
2369                "/../../testdata/modelpack_split_9x15x18.bin"
2370            ));
2371            let detect0 =
2372                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2373
2374            let detect1 = include_bytes!(concat!(
2375                env!("CARGO_MANIFEST_DIR"),
2376                "/../../testdata/modelpack_split_17x30x18.bin"
2377            ));
2378            let detect1 =
2379                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2380
2381            let mut mask = seg.slice(s![0, .., .., ..]);
2382            mask.swap_axes(0, 1);
2383            mask.swap_axes(1, 2);
2384            let mask = [Segmentation {
2385                xmin: 0.0,
2386                ymin: 0.0,
2387                xmax: 1.0,
2388                ymax: 1.0,
2389                segmentation: mask.into_owned(),
2390            }];
2391            let correct_boxes = [DetectBox {
2392                bbox: BoundingBox {
2393                    xmin: 0.43171933,
2394                    ymin: 0.68243736,
2395                    xmax: 0.5626645,
2396                    ymax: 0.808863,
2397                },
2398                score: 0.99240804,
2399                label: 0,
2400            }];
2401
2402            let quant0 = (0.08547406643629074, 174).into();
2403            let quant1 = (0.09929127991199493, 183).into();
2404            let quant_seg = (1.0 / 255.0, 0).into();
2405
2406            let anchors0 = vec![
2407                [0.36666667461395264, 0.31481480598449707],
2408                [0.38749998807907104, 0.4740740656852722],
2409                [0.5333333611488342, 0.644444465637207],
2410            ];
2411            let anchors1 = vec![
2412                [0.13750000298023224, 0.2074074000120163],
2413                [0.2541666626930237, 0.21481481194496155],
2414                [0.23125000298023224, 0.35185185074806213],
2415            ];
2416
2417            let decoder = DecoderBuilder::default()
2418                .with_config_modelpack_segdet_split(
2419                    vec![
2420                        configs::Detection {
2421                            decoder: DecoderType::ModelPack,
2422                            shape: vec![1, 17, 30, 18],
2423                            anchors: Some(anchors1),
2424                            quantization: Some(quant1),
2425                            dshape: vec![
2426                                (DimName::Batch, 1),
2427                                (DimName::Height, 17),
2428                                (DimName::Width, 30),
2429                                (DimName::NumAnchorsXFeatures, 18),
2430                            ],
2431                            normalized: None,
2432                        },
2433                        configs::Detection {
2434                            decoder: DecoderType::ModelPack,
2435                            shape: vec![1, 9, 15, 18],
2436                            anchors: Some(anchors0),
2437                            quantization: Some(quant0),
2438                            dshape: vec![
2439                                (DimName::Batch, 1),
2440                                (DimName::Height, 9),
2441                                (DimName::Width, 15),
2442                                (DimName::NumAnchorsXFeatures, 18),
2443                            ],
2444                            normalized: None,
2445                        },
2446                    ],
2447                    configs::Segmentation {
2448                        decoder: DecoderType::ModelPack,
2449                        quantization: Some(quant_seg),
2450                        shape: vec![1, 2, 160, 160],
2451                        dshape: vec![
2452                            (DimName::Batch, 1),
2453                            (DimName::NumClasses, 2),
2454                            (DimName::Height, 160),
2455                            (DimName::Width, 160),
2456                        ],
2457                    },
2458                )
2459                .with_score_threshold(score_threshold)
2460                .with_iou_threshold(iou_threshold)
2461                .build()
2462                .unwrap();
2463            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2464            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2465
2466            for _ in 0..100 {
2467                decoder
2468                    .decode_quantized(
2469                        &[
2470                            detect0.view().into(),
2471                            detect1.view().into(),
2472                            seg.view().into(),
2473                        ],
2474                        &mut output_boxes,
2475                        &mut output_masks,
2476                    )
2477                    .unwrap();
2478
2479                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2480            }
2481        };
2482
2483        let handles = vec![
2484            std::thread::spawn(yolo_det),
2485            std::thread::spawn(modelpack_det_split),
2486            std::thread::spawn(yolo_det),
2487            std::thread::spawn(modelpack_det_split),
2488            std::thread::spawn(yolo_det),
2489            std::thread::spawn(modelpack_det_split),
2490            std::thread::spawn(yolo_det),
2491            std::thread::spawn(modelpack_det_split),
2492        ];
2493        for handle in handles {
2494            handle.join().unwrap();
2495        }
2496    }
2497
2498    #[test]
2499    fn test_ndarray_to_xyxy_float() {
2500        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2501        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2502        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2503
2504        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2505        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2506        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2507    }
2508
2509    #[test]
2510    fn test_class_aware_nms_float() {
2511        use crate::float::nms_class_aware_float;
2512
2513        // Create two overlapping boxes with different classes
2514        let boxes = vec![
2515            DetectBox {
2516                bbox: BoundingBox {
2517                    xmin: 0.0,
2518                    ymin: 0.0,
2519                    xmax: 0.5,
2520                    ymax: 0.5,
2521                },
2522                score: 0.9,
2523                label: 0, // class 0
2524            },
2525            DetectBox {
2526                bbox: BoundingBox {
2527                    xmin: 0.1,
2528                    ymin: 0.1,
2529                    xmax: 0.6,
2530                    ymax: 0.6,
2531                },
2532                score: 0.8,
2533                label: 1, // class 1 - different class
2534            },
2535        ];
2536
2537        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2538        // threshold 0.3)
2539        let result = nms_class_aware_float(0.3, boxes.clone());
2540        assert_eq!(
2541            result.len(),
2542            2,
2543            "Class-aware NMS should keep both boxes with different classes"
2544        );
2545
2546        // Now test with same class - should suppress one
2547        let same_class_boxes = vec![
2548            DetectBox {
2549                bbox: BoundingBox {
2550                    xmin: 0.0,
2551                    ymin: 0.0,
2552                    xmax: 0.5,
2553                    ymax: 0.5,
2554                },
2555                score: 0.9,
2556                label: 0,
2557            },
2558            DetectBox {
2559                bbox: BoundingBox {
2560                    xmin: 0.1,
2561                    ymin: 0.1,
2562                    xmax: 0.6,
2563                    ymax: 0.6,
2564                },
2565                score: 0.8,
2566                label: 0, // same class
2567            },
2568        ];
2569
2570        let result = nms_class_aware_float(0.3, same_class_boxes);
2571        assert_eq!(
2572            result.len(),
2573            1,
2574            "Class-aware NMS should suppress overlapping box with same class"
2575        );
2576        assert_eq!(result[0].label, 0);
2577        assert!((result[0].score - 0.9).abs() < 1e-6);
2578    }
2579
2580    #[test]
2581    fn test_class_agnostic_vs_aware_nms() {
2582        use crate::float::{nms_class_aware_float, nms_float};
2583
2584        // Two overlapping boxes with different classes
2585        let boxes = vec![
2586            DetectBox {
2587                bbox: BoundingBox {
2588                    xmin: 0.0,
2589                    ymin: 0.0,
2590                    xmax: 0.5,
2591                    ymax: 0.5,
2592                },
2593                score: 0.9,
2594                label: 0,
2595            },
2596            DetectBox {
2597                bbox: BoundingBox {
2598                    xmin: 0.1,
2599                    ymin: 0.1,
2600                    xmax: 0.6,
2601                    ymax: 0.6,
2602                },
2603                score: 0.8,
2604                label: 1,
2605            },
2606        ];
2607
2608        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2609        let agnostic_result = nms_float(0.3, boxes.clone());
2610        assert_eq!(
2611            agnostic_result.len(),
2612            1,
2613            "Class-agnostic NMS should suppress overlapping boxes"
2614        );
2615
2616        // Class-aware should keep both (different classes)
2617        let aware_result = nms_class_aware_float(0.3, boxes);
2618        assert_eq!(
2619            aware_result.len(),
2620            2,
2621            "Class-aware NMS should keep boxes with different classes"
2622        );
2623    }
2624
2625    #[test]
2626    fn test_class_aware_nms_int() {
2627        use crate::byte::nms_class_aware_int;
2628
2629        // Create two overlapping boxes with different classes
2630        let boxes = vec![
2631            DetectBoxQuantized {
2632                bbox: BoundingBox {
2633                    xmin: 0.0,
2634                    ymin: 0.0,
2635                    xmax: 0.5,
2636                    ymax: 0.5,
2637                },
2638                score: 200_u8,
2639                label: 0,
2640            },
2641            DetectBoxQuantized {
2642                bbox: BoundingBox {
2643                    xmin: 0.1,
2644                    ymin: 0.1,
2645                    xmax: 0.6,
2646                    ymax: 0.6,
2647                },
2648                score: 180_u8,
2649                label: 1, // different class
2650            },
2651        ];
2652
2653        // Should keep both (different classes)
2654        let result = nms_class_aware_int(0.5, boxes);
2655        assert_eq!(
2656            result.len(),
2657            2,
2658            "Class-aware NMS (int) should keep boxes with different classes"
2659        );
2660    }
2661
2662    #[test]
2663    fn test_nms_enum_default() {
2664        // Test that Nms enum has the correct default
2665        let default_nms: configs::Nms = Default::default();
2666        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2667    }
2668
2669    #[test]
2670    fn test_decoder_nms_mode() {
2671        // Test that decoder properly stores NMS mode
2672        let decoder = DecoderBuilder::default()
2673            .with_config_yolo_det(
2674                configs::Detection {
2675                    anchors: None,
2676                    decoder: DecoderType::Ultralytics,
2677                    quantization: None,
2678                    shape: vec![1, 84, 8400],
2679                    dshape: Vec::new(),
2680                    normalized: Some(true),
2681                },
2682                None,
2683            )
2684            .with_nms(Some(configs::Nms::ClassAware))
2685            .build()
2686            .unwrap();
2687
2688        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2689    }
2690
2691    #[test]
2692    fn test_decoder_nms_bypass() {
2693        // Test that decoder can be configured with nms=None (bypass)
2694        let decoder = DecoderBuilder::default()
2695            .with_config_yolo_det(
2696                configs::Detection {
2697                    anchors: None,
2698                    decoder: DecoderType::Ultralytics,
2699                    quantization: None,
2700                    shape: vec![1, 84, 8400],
2701                    dshape: Vec::new(),
2702                    normalized: Some(true),
2703                },
2704                None,
2705            )
2706            .with_nms(None)
2707            .build()
2708            .unwrap();
2709
2710        assert_eq!(decoder.nms, None);
2711    }
2712
2713    #[test]
2714    fn test_decoder_normalized_boxes_true() {
2715        // Test that normalized_boxes returns Some(true) when explicitly set
2716        let decoder = DecoderBuilder::default()
2717            .with_config_yolo_det(
2718                configs::Detection {
2719                    anchors: None,
2720                    decoder: DecoderType::Ultralytics,
2721                    quantization: None,
2722                    shape: vec![1, 84, 8400],
2723                    dshape: Vec::new(),
2724                    normalized: Some(true),
2725                },
2726                None,
2727            )
2728            .build()
2729            .unwrap();
2730
2731        assert_eq!(decoder.normalized_boxes(), Some(true));
2732    }
2733
2734    #[test]
2735    fn test_decoder_normalized_boxes_false() {
2736        // Test that normalized_boxes returns Some(false) when config specifies
2737        // unnormalized
2738        let decoder = DecoderBuilder::default()
2739            .with_config_yolo_det(
2740                configs::Detection {
2741                    anchors: None,
2742                    decoder: DecoderType::Ultralytics,
2743                    quantization: None,
2744                    shape: vec![1, 84, 8400],
2745                    dshape: Vec::new(),
2746                    normalized: Some(false),
2747                },
2748                None,
2749            )
2750            .build()
2751            .unwrap();
2752
2753        assert_eq!(decoder.normalized_boxes(), Some(false));
2754    }
2755
2756    #[test]
2757    fn test_decoder_normalized_boxes_unknown() {
2758        // Test that normalized_boxes returns None when not specified in config
2759        let decoder = DecoderBuilder::default()
2760            .with_config_yolo_det(
2761                configs::Detection {
2762                    anchors: None,
2763                    decoder: DecoderType::Ultralytics,
2764                    quantization: None,
2765                    shape: vec![1, 84, 8400],
2766                    dshape: Vec::new(),
2767                    normalized: None,
2768                },
2769                Some(DecoderVersion::Yolo11),
2770            )
2771            .build()
2772            .unwrap();
2773
2774        assert_eq!(decoder.normalized_boxes(), None);
2775    }
2776}
2777
2778#[cfg(feature = "tracker")]
2779#[cfg(test)]
2780#[cfg_attr(coverage_nightly, coverage(off))]
2781mod decoder_tracked_tests {
2782
2783    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2784    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2785    use num_traits::{AsPrimitive, Float, PrimInt};
2786    use rand::{RngExt, SeedableRng};
2787    use rand_distr::StandardNormal;
2788
2789    use crate::{
2790        configs::{self, DimName},
2791        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2792    };
2793
2794    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2795        input: ArrayView<F, D>,
2796        quant: Quantization,
2797    ) -> Array<T, D>
2798    where
2799        i32: num_traits::AsPrimitive<F>,
2800        f32: num_traits::AsPrimitive<F>,
2801    {
2802        let zero_point = quant.zero_point.as_();
2803        let div_scale = F::one() / quant.scale.as_();
2804        if zero_point != F::zero() {
2805            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2806        } else {
2807            input.mapv(|d| (d * div_scale).round().as_())
2808        }
2809    }
2810
2811    #[test]
2812    fn test_decoder_tracked_random_jitter() {
2813        use crate::configs::{DecoderType, Nms};
2814        use crate::DecoderBuilder;
2815
2816        let score_threshold = 0.25;
2817        let iou_threshold = 0.1;
2818        let out = include_bytes!(concat!(
2819            env!("CARGO_MANIFEST_DIR"),
2820            "/../../testdata/yolov8s_80_classes.bin"
2821        ));
2822        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2823        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2824        let quant = (0.0040811873, -123).into();
2825
2826        let decoder = DecoderBuilder::default()
2827            .with_config_yolo_det(
2828                crate::configs::Detection {
2829                    decoder: DecoderType::Ultralytics,
2830                    shape: vec![1, 84, 8400],
2831                    anchors: None,
2832                    quantization: Some(quant),
2833                    dshape: vec![
2834                        (crate::configs::DimName::Batch, 1),
2835                        (crate::configs::DimName::NumFeatures, 84),
2836                        (crate::configs::DimName::NumBoxes, 8400),
2837                    ],
2838                    normalized: Some(true),
2839                },
2840                None,
2841            )
2842            .with_score_threshold(score_threshold)
2843            .with_iou_threshold(iou_threshold)
2844            .with_nms(Some(Nms::ClassAgnostic))
2845            .build()
2846            .unwrap();
2847        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
2848
2849        let expected_boxes = [
2850            crate::DetectBox {
2851                bbox: crate::BoundingBox {
2852                    xmin: 0.5285137,
2853                    ymin: 0.05305544,
2854                    xmax: 0.87541467,
2855                    ymax: 0.9998909,
2856                },
2857                score: 0.5591227,
2858                label: 0,
2859            },
2860            crate::DetectBox {
2861                bbox: crate::BoundingBox {
2862                    xmin: 0.130598,
2863                    ymin: 0.43260583,
2864                    xmax: 0.35098213,
2865                    ymax: 0.9958097,
2866                },
2867                score: 0.33057618,
2868                label: 75,
2869            },
2870        ];
2871
2872        let mut tracker = ByteTrackBuilder::new()
2873            .track_update(0.1)
2874            .track_high_conf(0.3)
2875            .build();
2876
2877        let mut output_boxes = Vec::with_capacity(50);
2878        let mut output_masks = Vec::with_capacity(50);
2879        let mut output_tracks = Vec::with_capacity(50);
2880
2881        decoder
2882            .decode_tracked_quantized(
2883                &mut tracker,
2884                0,
2885                &[out.view().into()],
2886                &mut output_boxes,
2887                &mut output_masks,
2888                &mut output_tracks,
2889            )
2890            .unwrap();
2891
2892        assert_eq!(output_boxes.len(), 2);
2893        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2894        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2895
2896        let mut last_boxes = output_boxes.clone();
2897
2898        for i in 1..=100 {
2899            let mut out = out.clone();
2900            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
2901            let mut x_values = out.slice_mut(s![0, 0, ..]);
2902            for x in x_values.iter_mut() {
2903                let r: f32 = rng.sample(StandardNormal);
2904                let r = r.clamp(-2.0, 2.0) / 2.0;
2905                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2906            }
2907
2908            let mut y_values = out.slice_mut(s![0, 1, ..]);
2909            for y in y_values.iter_mut() {
2910                let r: f32 = rng.sample(StandardNormal);
2911                let r = r.clamp(-2.0, 2.0) / 2.0;
2912                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2913            }
2914
2915            decoder
2916                .decode_tracked_quantized(
2917                    &mut tracker,
2918                    100_000_000 * i / 3, // simulate 33.333ms between frames
2919                    &[out.view().into()],
2920                    &mut output_boxes,
2921                    &mut output_masks,
2922                    &mut output_tracks,
2923                )
2924                .unwrap();
2925
2926            assert_eq!(output_boxes.len(), 2);
2927            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2928            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2929
2930            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2931            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2932            last_boxes = output_boxes.clone();
2933        }
2934    }
2935
2936    #[test]
2937    fn test_decoder_tracked_segdet() {
2938        use crate::configs::Nms;
2939        use crate::DecoderBuilder;
2940
2941        let score_threshold = 0.45;
2942        let iou_threshold = 0.45;
2943        let boxes = include_bytes!(concat!(
2944            env!("CARGO_MANIFEST_DIR"),
2945            "/../../testdata/yolov8_boxes_116x8400.bin"
2946        ));
2947        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2948        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2949
2950        let protos = include_bytes!(concat!(
2951            env!("CARGO_MANIFEST_DIR"),
2952            "/../../testdata/yolov8_protos_160x160x32.bin"
2953        ));
2954        let protos =
2955            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2956        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2957
2958        let config = include_str!(concat!(
2959            env!("CARGO_MANIFEST_DIR"),
2960            "/../../testdata/yolov8_seg.yaml"
2961        ));
2962
2963        let decoder = DecoderBuilder::default()
2964            .with_config_yaml_str(config.to_string())
2965            .with_score_threshold(score_threshold)
2966            .with_iou_threshold(iou_threshold)
2967            .with_nms(Some(Nms::ClassAgnostic))
2968            .build()
2969            .unwrap();
2970
2971        let expected_boxes = [
2972            DetectBox {
2973                bbox: BoundingBox {
2974                    xmin: 0.08515105,
2975                    ymin: 0.7131401,
2976                    xmax: 0.29802868,
2977                    ymax: 0.8195788,
2978                },
2979                score: 0.91537374,
2980                label: 23,
2981            },
2982            DetectBox {
2983                bbox: BoundingBox {
2984                    xmin: 0.59605736,
2985                    ymin: 0.25545314,
2986                    xmax: 0.93666154,
2987                    ymax: 0.72378385,
2988                },
2989                score: 0.91537374,
2990                label: 23,
2991            },
2992        ];
2993
2994        let mut tracker = ByteTrackBuilder::new()
2995            .track_update(0.1)
2996            .track_high_conf(0.7)
2997            .build();
2998
2999        let mut output_boxes = Vec::with_capacity(50);
3000        let mut output_masks = Vec::with_capacity(50);
3001        let mut output_tracks = Vec::with_capacity(50);
3002
3003        decoder
3004            .decode_tracked_quantized(
3005                &mut tracker,
3006                0,
3007                &[boxes.view().into(), protos.view().into()],
3008                &mut output_boxes,
3009                &mut output_masks,
3010                &mut output_tracks,
3011            )
3012            .unwrap();
3013
3014        assert_eq!(output_boxes.len(), 2);
3015        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3016        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3017
3018        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3019        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3020        for score in scores_values.iter_mut() {
3021            *score = i8::MIN; // set all scores to minimum to simulate no detections
3022        }
3023        decoder
3024            .decode_tracked_quantized(
3025                &mut tracker,
3026                100_000_000 / 3,
3027                &[boxes.view().into(), protos.view().into()],
3028                &mut output_boxes,
3029                &mut output_masks,
3030                &mut output_tracks,
3031            )
3032            .unwrap();
3033
3034        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3035        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3036
3037        // no masks when the boxes are from tracker prediction without a matching detection
3038        assert!(output_masks.is_empty())
3039    }
3040
3041    #[test]
3042    fn test_decoder_tracked_segdet_float() {
3043        use crate::configs::Nms;
3044        use crate::DecoderBuilder;
3045
3046        let score_threshold = 0.45;
3047        let iou_threshold = 0.45;
3048        let boxes = include_bytes!(concat!(
3049            env!("CARGO_MANIFEST_DIR"),
3050            "/../../testdata/yolov8_boxes_116x8400.bin"
3051        ));
3052        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3053        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3054        let quant_boxes = (0.021287762, 31);
3055        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3056
3057        let protos = include_bytes!(concat!(
3058            env!("CARGO_MANIFEST_DIR"),
3059            "/../../testdata/yolov8_protos_160x160x32.bin"
3060        ));
3061        let protos =
3062            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3063        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3064        let quant_protos = (0.02491162, -117);
3065        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3066
3067        let config = include_str!(concat!(
3068            env!("CARGO_MANIFEST_DIR"),
3069            "/../../testdata/yolov8_seg.yaml"
3070        ));
3071
3072        let decoder = DecoderBuilder::default()
3073            .with_config_yaml_str(config.to_string())
3074            .with_score_threshold(score_threshold)
3075            .with_iou_threshold(iou_threshold)
3076            .with_nms(Some(Nms::ClassAgnostic))
3077            .build()
3078            .unwrap();
3079
3080        let expected_boxes = [
3081            DetectBox {
3082                bbox: BoundingBox {
3083                    xmin: 0.08515105,
3084                    ymin: 0.7131401,
3085                    xmax: 0.29802868,
3086                    ymax: 0.8195788,
3087                },
3088                score: 0.91537374,
3089                label: 23,
3090            },
3091            DetectBox {
3092                bbox: BoundingBox {
3093                    xmin: 0.59605736,
3094                    ymin: 0.25545314,
3095                    xmax: 0.93666154,
3096                    ymax: 0.72378385,
3097                },
3098                score: 0.91537374,
3099                label: 23,
3100            },
3101        ];
3102
3103        let mut tracker = ByteTrackBuilder::new()
3104            .track_update(0.1)
3105            .track_high_conf(0.7)
3106            .build();
3107
3108        let mut output_boxes = Vec::with_capacity(50);
3109        let mut output_masks = Vec::with_capacity(50);
3110        let mut output_tracks = Vec::with_capacity(50);
3111
3112        decoder
3113            .decode_tracked_float(
3114                &mut tracker,
3115                0,
3116                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3117                &mut output_boxes,
3118                &mut output_masks,
3119                &mut output_tracks,
3120            )
3121            .unwrap();
3122
3123        assert_eq!(output_boxes.len(), 2);
3124        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3125        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3126
3127        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3128        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3129        for score in scores_values.iter_mut() {
3130            *score = 0.0; // set all scores to minimum to simulate no detections
3131        }
3132        decoder
3133            .decode_tracked_float(
3134                &mut tracker,
3135                100_000_000 / 3,
3136                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3137                &mut output_boxes,
3138                &mut output_masks,
3139                &mut output_tracks,
3140            )
3141            .unwrap();
3142
3143        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3144        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3145
3146        // no masks when the boxes are from tracker prediction without a matching detection
3147        assert!(output_masks.is_empty())
3148    }
3149
3150    #[test]
3151    fn test_decoder_tracked_segdet_proto() {
3152        use crate::configs::Nms;
3153        use crate::DecoderBuilder;
3154
3155        let score_threshold = 0.45;
3156        let iou_threshold = 0.45;
3157        let boxes = include_bytes!(concat!(
3158            env!("CARGO_MANIFEST_DIR"),
3159            "/../../testdata/yolov8_boxes_116x8400.bin"
3160        ));
3161        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3162        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3163
3164        let protos = include_bytes!(concat!(
3165            env!("CARGO_MANIFEST_DIR"),
3166            "/../../testdata/yolov8_protos_160x160x32.bin"
3167        ));
3168        let protos =
3169            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3170        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3171
3172        let config = include_str!(concat!(
3173            env!("CARGO_MANIFEST_DIR"),
3174            "/../../testdata/yolov8_seg.yaml"
3175        ));
3176
3177        let decoder = DecoderBuilder::default()
3178            .with_config_yaml_str(config.to_string())
3179            .with_score_threshold(score_threshold)
3180            .with_iou_threshold(iou_threshold)
3181            .with_nms(Some(Nms::ClassAgnostic))
3182            .build()
3183            .unwrap();
3184
3185        let expected_boxes = [
3186            DetectBox {
3187                bbox: BoundingBox {
3188                    xmin: 0.08515105,
3189                    ymin: 0.7131401,
3190                    xmax: 0.29802868,
3191                    ymax: 0.8195788,
3192                },
3193                score: 0.91537374,
3194                label: 23,
3195            },
3196            DetectBox {
3197                bbox: BoundingBox {
3198                    xmin: 0.59605736,
3199                    ymin: 0.25545314,
3200                    xmax: 0.93666154,
3201                    ymax: 0.72378385,
3202                },
3203                score: 0.91537374,
3204                label: 23,
3205            },
3206        ];
3207
3208        let mut tracker = ByteTrackBuilder::new()
3209            .track_update(0.1)
3210            .track_high_conf(0.7)
3211            .build();
3212
3213        let mut output_boxes = Vec::with_capacity(50);
3214        let mut output_tracks = Vec::with_capacity(50);
3215
3216        decoder
3217            .decode_tracked_quantized_proto(
3218                &mut tracker,
3219                0,
3220                &[boxes.view().into(), protos.view().into()],
3221                &mut output_boxes,
3222                &mut output_tracks,
3223            )
3224            .unwrap();
3225
3226        assert_eq!(output_boxes.len(), 2);
3227        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3228        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3229
3230        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3231        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3232        for score in scores_values.iter_mut() {
3233            *score = i8::MIN; // set all scores to minimum to simulate no detections
3234        }
3235        let protos = decoder
3236            .decode_tracked_quantized_proto(
3237                &mut tracker,
3238                100_000_000 / 3,
3239                &[boxes.view().into(), protos.view().into()],
3240                &mut output_boxes,
3241                &mut output_tracks,
3242            )
3243            .unwrap();
3244
3245        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3246        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3247
3248        // no masks when the boxes are from tracker prediction without a matching detection
3249        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3250    }
3251
3252    #[test]
3253    fn test_decoder_tracked_segdet_proto_float() {
3254        use crate::configs::Nms;
3255        use crate::DecoderBuilder;
3256
3257        let score_threshold = 0.45;
3258        let iou_threshold = 0.45;
3259        let boxes = include_bytes!(concat!(
3260            env!("CARGO_MANIFEST_DIR"),
3261            "/../../testdata/yolov8_boxes_116x8400.bin"
3262        ));
3263        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3264        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3265        let quant_boxes = (0.021287762, 31);
3266        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3267
3268        let protos = include_bytes!(concat!(
3269            env!("CARGO_MANIFEST_DIR"),
3270            "/../../testdata/yolov8_protos_160x160x32.bin"
3271        ));
3272        let protos =
3273            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3274        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3275        let quant_protos = (0.02491162, -117);
3276        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3277
3278        let config = include_str!(concat!(
3279            env!("CARGO_MANIFEST_DIR"),
3280            "/../../testdata/yolov8_seg.yaml"
3281        ));
3282
3283        let decoder = DecoderBuilder::default()
3284            .with_config_yaml_str(config.to_string())
3285            .with_score_threshold(score_threshold)
3286            .with_iou_threshold(iou_threshold)
3287            .with_nms(Some(Nms::ClassAgnostic))
3288            .build()
3289            .unwrap();
3290
3291        let expected_boxes = [
3292            DetectBox {
3293                bbox: BoundingBox {
3294                    xmin: 0.08515105,
3295                    ymin: 0.7131401,
3296                    xmax: 0.29802868,
3297                    ymax: 0.8195788,
3298                },
3299                score: 0.91537374,
3300                label: 23,
3301            },
3302            DetectBox {
3303                bbox: BoundingBox {
3304                    xmin: 0.59605736,
3305                    ymin: 0.25545314,
3306                    xmax: 0.93666154,
3307                    ymax: 0.72378385,
3308                },
3309                score: 0.91537374,
3310                label: 23,
3311            },
3312        ];
3313
3314        let mut tracker = ByteTrackBuilder::new()
3315            .track_update(0.1)
3316            .track_high_conf(0.7)
3317            .build();
3318
3319        let mut output_boxes = Vec::with_capacity(50);
3320        let mut output_tracks = Vec::with_capacity(50);
3321
3322        decoder
3323            .decode_tracked_float_proto(
3324                &mut tracker,
3325                0,
3326                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3327                &mut output_boxes,
3328                &mut output_tracks,
3329            )
3330            .unwrap();
3331
3332        assert_eq!(output_boxes.len(), 2);
3333        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3334        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3335
3336        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3337        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3338        for score in scores_values.iter_mut() {
3339            *score = 0.0; // set all scores to minimum to simulate no detections
3340        }
3341        let protos = decoder
3342            .decode_tracked_float_proto(
3343                &mut tracker,
3344                100_000_000 / 3,
3345                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3346                &mut output_boxes,
3347                &mut output_tracks,
3348            )
3349            .unwrap();
3350
3351        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3352        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3353
3354        // no masks when the boxes are from tracker prediction without a matching detection
3355        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3356    }
3357
3358    #[test]
3359    fn test_decoder_tracked_segdet_split() {
3360        let score_threshold = 0.45;
3361        let iou_threshold = 0.45;
3362
3363        let boxes = include_bytes!(concat!(
3364            env!("CARGO_MANIFEST_DIR"),
3365            "/../../testdata/yolov8_boxes_116x8400.bin"
3366        ));
3367        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3368        let boxes = boxes.to_vec();
3369        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3370
3371        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3372        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3373        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3374
3375        let quant_boxes = (0.021287762, 31);
3376
3377        let protos = include_bytes!(concat!(
3378            env!("CARGO_MANIFEST_DIR"),
3379            "/../../testdata/yolov8_protos_160x160x32.bin"
3380        ));
3381        let protos =
3382            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3383        let protos = protos.to_vec();
3384        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3385        let quant_protos = (0.02491162, -117);
3386        let decoder = DecoderBuilder::default()
3387            .with_config_yolo_split_segdet(
3388                configs::Boxes {
3389                    decoder: configs::DecoderType::Ultralytics,
3390                    quantization: Some(quant_boxes.into()),
3391                    shape: vec![1, 4, 8400],
3392                    dshape: vec![
3393                        (DimName::Batch, 1),
3394                        (DimName::BoxCoords, 4),
3395                        (DimName::NumBoxes, 8400),
3396                    ],
3397                    normalized: Some(true),
3398                },
3399                configs::Scores {
3400                    decoder: configs::DecoderType::Ultralytics,
3401                    quantization: Some(quant_boxes.into()),
3402                    shape: vec![1, 80, 8400],
3403                    dshape: vec![
3404                        (DimName::Batch, 1),
3405                        (DimName::NumClasses, 80),
3406                        (DimName::NumBoxes, 8400),
3407                    ],
3408                },
3409                configs::MaskCoefficients {
3410                    decoder: configs::DecoderType::Ultralytics,
3411                    quantization: Some(quant_boxes.into()),
3412                    shape: vec![1, 32, 8400],
3413                    dshape: vec![
3414                        (DimName::Batch, 1),
3415                        (DimName::NumProtos, 32),
3416                        (DimName::NumBoxes, 8400),
3417                    ],
3418                },
3419                configs::Protos {
3420                    decoder: configs::DecoderType::Ultralytics,
3421                    quantization: Some(quant_protos.into()),
3422                    shape: vec![1, 160, 160, 32],
3423                    dshape: vec![
3424                        (DimName::Batch, 1),
3425                        (DimName::Height, 160),
3426                        (DimName::Width, 160),
3427                        (DimName::NumProtos, 32),
3428                    ],
3429                },
3430            )
3431            .with_score_threshold(score_threshold)
3432            .with_iou_threshold(iou_threshold)
3433            .build()
3434            .unwrap();
3435
3436        let expected_boxes = [
3437            DetectBox {
3438                bbox: BoundingBox {
3439                    xmin: 0.08515105,
3440                    ymin: 0.7131401,
3441                    xmax: 0.29802868,
3442                    ymax: 0.8195788,
3443                },
3444                score: 0.91537374,
3445                label: 23,
3446            },
3447            DetectBox {
3448                bbox: BoundingBox {
3449                    xmin: 0.59605736,
3450                    ymin: 0.25545314,
3451                    xmax: 0.93666154,
3452                    ymax: 0.72378385,
3453                },
3454                score: 0.91537374,
3455                label: 23,
3456            },
3457        ];
3458
3459        let mut tracker = ByteTrackBuilder::new()
3460            .track_update(0.1)
3461            .track_high_conf(0.7)
3462            .build();
3463
3464        let mut output_boxes = Vec::with_capacity(50);
3465        let mut output_masks = Vec::with_capacity(50);
3466        let mut output_tracks = Vec::with_capacity(50);
3467
3468        decoder
3469            .decode_tracked_quantized(
3470                &mut tracker,
3471                0,
3472                &[
3473                    boxes.view().into(),
3474                    scores.view().into(),
3475                    mask.view().into(),
3476                    protos.view().into(),
3477                ],
3478                &mut output_boxes,
3479                &mut output_masks,
3480                &mut output_tracks,
3481            )
3482            .unwrap();
3483
3484        assert_eq!(output_boxes.len(), 2);
3485        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3486        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3487
3488        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3489
3490        for score in scores.iter_mut() {
3491            *score = i8::MIN; // set all scores to minimum to simulate no detections
3492        }
3493        decoder
3494            .decode_tracked_quantized(
3495                &mut tracker,
3496                100_000_000 / 3,
3497                &[
3498                    boxes.view().into(),
3499                    scores.view().into(),
3500                    mask.view().into(),
3501                    protos.view().into(),
3502                ],
3503                &mut output_boxes,
3504                &mut output_masks,
3505                &mut output_tracks,
3506            )
3507            .unwrap();
3508
3509        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3510        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3511
3512        // no masks when the boxes are from tracker prediction without a matching detection
3513        assert!(output_masks.is_empty())
3514    }
3515
3516    #[test]
3517    fn test_decoder_tracked_segdet_split_float() {
3518        let score_threshold = 0.45;
3519        let iou_threshold = 0.45;
3520
3521        let boxes = include_bytes!(concat!(
3522            env!("CARGO_MANIFEST_DIR"),
3523            "/../../testdata/yolov8_boxes_116x8400.bin"
3524        ));
3525        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3526        let boxes = boxes.to_vec();
3527        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3528        let quant_boxes = (0.021287762, 31);
3529        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3530
3531        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3532        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3533        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3534
3535        let protos = include_bytes!(concat!(
3536            env!("CARGO_MANIFEST_DIR"),
3537            "/../../testdata/yolov8_protos_160x160x32.bin"
3538        ));
3539        let protos =
3540            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3541        let protos = protos.to_vec();
3542        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3543        let quant_protos = (0.02491162, -117);
3544        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3545
3546        let decoder = DecoderBuilder::default()
3547            .with_config_yolo_split_segdet(
3548                configs::Boxes {
3549                    decoder: configs::DecoderType::Ultralytics,
3550                    quantization: Some(quant_boxes.into()),
3551                    shape: vec![1, 4, 8400],
3552                    dshape: vec![
3553                        (DimName::Batch, 1),
3554                        (DimName::BoxCoords, 4),
3555                        (DimName::NumBoxes, 8400),
3556                    ],
3557                    normalized: Some(true),
3558                },
3559                configs::Scores {
3560                    decoder: configs::DecoderType::Ultralytics,
3561                    quantization: Some(quant_boxes.into()),
3562                    shape: vec![1, 80, 8400],
3563                    dshape: vec![
3564                        (DimName::Batch, 1),
3565                        (DimName::NumClasses, 80),
3566                        (DimName::NumBoxes, 8400),
3567                    ],
3568                },
3569                configs::MaskCoefficients {
3570                    decoder: configs::DecoderType::Ultralytics,
3571                    quantization: Some(quant_boxes.into()),
3572                    shape: vec![1, 32, 8400],
3573                    dshape: vec![
3574                        (DimName::Batch, 1),
3575                        (DimName::NumProtos, 32),
3576                        (DimName::NumBoxes, 8400),
3577                    ],
3578                },
3579                configs::Protos {
3580                    decoder: configs::DecoderType::Ultralytics,
3581                    quantization: Some(quant_protos.into()),
3582                    shape: vec![1, 160, 160, 32],
3583                    dshape: vec![
3584                        (DimName::Batch, 1),
3585                        (DimName::Height, 160),
3586                        (DimName::Width, 160),
3587                        (DimName::NumProtos, 32),
3588                    ],
3589                },
3590            )
3591            .with_score_threshold(score_threshold)
3592            .with_iou_threshold(iou_threshold)
3593            .build()
3594            .unwrap();
3595
3596        let expected_boxes = [
3597            DetectBox {
3598                bbox: BoundingBox {
3599                    xmin: 0.08515105,
3600                    ymin: 0.7131401,
3601                    xmax: 0.29802868,
3602                    ymax: 0.8195788,
3603                },
3604                score: 0.91537374,
3605                label: 23,
3606            },
3607            DetectBox {
3608                bbox: BoundingBox {
3609                    xmin: 0.59605736,
3610                    ymin: 0.25545314,
3611                    xmax: 0.93666154,
3612                    ymax: 0.72378385,
3613                },
3614                score: 0.91537374,
3615                label: 23,
3616            },
3617        ];
3618
3619        let mut tracker = ByteTrackBuilder::new()
3620            .track_update(0.1)
3621            .track_high_conf(0.7)
3622            .build();
3623
3624        let mut output_boxes = Vec::with_capacity(50);
3625        let mut output_masks = Vec::with_capacity(50);
3626        let mut output_tracks = Vec::with_capacity(50);
3627
3628        decoder
3629            .decode_tracked_float(
3630                &mut tracker,
3631                0,
3632                &[
3633                    boxes.view().into_dyn(),
3634                    scores.view().into_dyn(),
3635                    mask.view().into_dyn(),
3636                    protos.view().into_dyn(),
3637                ],
3638                &mut output_boxes,
3639                &mut output_masks,
3640                &mut output_tracks,
3641            )
3642            .unwrap();
3643
3644        assert_eq!(output_boxes.len(), 2);
3645        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3646        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3647
3648        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3649
3650        for score in scores.iter_mut() {
3651            *score = 0.0; // set all scores to minimum to simulate no detections
3652        }
3653        decoder
3654            .decode_tracked_float(
3655                &mut tracker,
3656                100_000_000 / 3,
3657                &[
3658                    boxes.view().into_dyn(),
3659                    scores.view().into_dyn(),
3660                    mask.view().into_dyn(),
3661                    protos.view().into_dyn(),
3662                ],
3663                &mut output_boxes,
3664                &mut output_masks,
3665                &mut output_tracks,
3666            )
3667            .unwrap();
3668
3669        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3670        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3671
3672        // no masks when the boxes are from tracker prediction without a matching detection
3673        assert!(output_masks.is_empty())
3674    }
3675
3676    #[test]
3677    fn test_decoder_tracked_segdet_split_proto() {
3678        let score_threshold = 0.45;
3679        let iou_threshold = 0.45;
3680
3681        let boxes = include_bytes!(concat!(
3682            env!("CARGO_MANIFEST_DIR"),
3683            "/../../testdata/yolov8_boxes_116x8400.bin"
3684        ));
3685        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3686        let boxes = boxes.to_vec();
3687        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3688
3689        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3690        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3691        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3692
3693        let quant_boxes = (0.021287762, 31);
3694
3695        let protos = include_bytes!(concat!(
3696            env!("CARGO_MANIFEST_DIR"),
3697            "/../../testdata/yolov8_protos_160x160x32.bin"
3698        ));
3699        let protos =
3700            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3701        let protos = protos.to_vec();
3702        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3703        let quant_protos = (0.02491162, -117);
3704        let decoder = DecoderBuilder::default()
3705            .with_config_yolo_split_segdet(
3706                configs::Boxes {
3707                    decoder: configs::DecoderType::Ultralytics,
3708                    quantization: Some(quant_boxes.into()),
3709                    shape: vec![1, 4, 8400],
3710                    dshape: vec![
3711                        (DimName::Batch, 1),
3712                        (DimName::BoxCoords, 4),
3713                        (DimName::NumBoxes, 8400),
3714                    ],
3715                    normalized: Some(true),
3716                },
3717                configs::Scores {
3718                    decoder: configs::DecoderType::Ultralytics,
3719                    quantization: Some(quant_boxes.into()),
3720                    shape: vec![1, 80, 8400],
3721                    dshape: vec![
3722                        (DimName::Batch, 1),
3723                        (DimName::NumClasses, 80),
3724                        (DimName::NumBoxes, 8400),
3725                    ],
3726                },
3727                configs::MaskCoefficients {
3728                    decoder: configs::DecoderType::Ultralytics,
3729                    quantization: Some(quant_boxes.into()),
3730                    shape: vec![1, 32, 8400],
3731                    dshape: vec![
3732                        (DimName::Batch, 1),
3733                        (DimName::NumProtos, 32),
3734                        (DimName::NumBoxes, 8400),
3735                    ],
3736                },
3737                configs::Protos {
3738                    decoder: configs::DecoderType::Ultralytics,
3739                    quantization: Some(quant_protos.into()),
3740                    shape: vec![1, 160, 160, 32],
3741                    dshape: vec![
3742                        (DimName::Batch, 1),
3743                        (DimName::Height, 160),
3744                        (DimName::Width, 160),
3745                        (DimName::NumProtos, 32),
3746                    ],
3747                },
3748            )
3749            .with_score_threshold(score_threshold)
3750            .with_iou_threshold(iou_threshold)
3751            .build()
3752            .unwrap();
3753
3754        let expected_boxes = [
3755            DetectBox {
3756                bbox: BoundingBox {
3757                    xmin: 0.08515105,
3758                    ymin: 0.7131401,
3759                    xmax: 0.29802868,
3760                    ymax: 0.8195788,
3761                },
3762                score: 0.91537374,
3763                label: 23,
3764            },
3765            DetectBox {
3766                bbox: BoundingBox {
3767                    xmin: 0.59605736,
3768                    ymin: 0.25545314,
3769                    xmax: 0.93666154,
3770                    ymax: 0.72378385,
3771                },
3772                score: 0.91537374,
3773                label: 23,
3774            },
3775        ];
3776
3777        let mut tracker = ByteTrackBuilder::new()
3778            .track_update(0.1)
3779            .track_high_conf(0.7)
3780            .build();
3781
3782        let mut output_boxes = Vec::with_capacity(50);
3783        let mut output_tracks = Vec::with_capacity(50);
3784
3785        decoder
3786            .decode_tracked_quantized_proto(
3787                &mut tracker,
3788                0,
3789                &[
3790                    boxes.view().into(),
3791                    scores.view().into(),
3792                    mask.view().into(),
3793                    protos.view().into(),
3794                ],
3795                &mut output_boxes,
3796                &mut output_tracks,
3797            )
3798            .unwrap();
3799
3800        assert_eq!(output_boxes.len(), 2);
3801        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3802        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3803
3804        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3805
3806        for score in scores.iter_mut() {
3807            *score = i8::MIN; // set all scores to minimum to simulate no detections
3808        }
3809        let protos = decoder
3810            .decode_tracked_quantized_proto(
3811                &mut tracker,
3812                100_000_000 / 3,
3813                &[
3814                    boxes.view().into(),
3815                    scores.view().into(),
3816                    mask.view().into(),
3817                    protos.view().into(),
3818                ],
3819                &mut output_boxes,
3820                &mut output_tracks,
3821            )
3822            .unwrap();
3823
3824        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3825        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3826
3827        // no masks when the boxes are from tracker prediction without a matching detection
3828        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3829    }
3830
3831    #[test]
3832    fn test_decoder_tracked_segdet_split_proto_float() {
3833        let score_threshold = 0.45;
3834        let iou_threshold = 0.45;
3835
3836        let boxes = include_bytes!(concat!(
3837            env!("CARGO_MANIFEST_DIR"),
3838            "/../../testdata/yolov8_boxes_116x8400.bin"
3839        ));
3840        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3841        let boxes = boxes.to_vec();
3842        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3843        let quant_boxes = (0.021287762, 31);
3844        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3845
3846        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3847        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3848        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3849
3850        let protos = include_bytes!(concat!(
3851            env!("CARGO_MANIFEST_DIR"),
3852            "/../../testdata/yolov8_protos_160x160x32.bin"
3853        ));
3854        let protos =
3855            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3856        let protos = protos.to_vec();
3857        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3858        let quant_protos = (0.02491162, -117);
3859        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3860
3861        let decoder = DecoderBuilder::default()
3862            .with_config_yolo_split_segdet(
3863                configs::Boxes {
3864                    decoder: configs::DecoderType::Ultralytics,
3865                    quantization: Some(quant_boxes.into()),
3866                    shape: vec![1, 4, 8400],
3867                    dshape: vec![
3868                        (DimName::Batch, 1),
3869                        (DimName::BoxCoords, 4),
3870                        (DimName::NumBoxes, 8400),
3871                    ],
3872                    normalized: Some(true),
3873                },
3874                configs::Scores {
3875                    decoder: configs::DecoderType::Ultralytics,
3876                    quantization: Some(quant_boxes.into()),
3877                    shape: vec![1, 80, 8400],
3878                    dshape: vec![
3879                        (DimName::Batch, 1),
3880                        (DimName::NumClasses, 80),
3881                        (DimName::NumBoxes, 8400),
3882                    ],
3883                },
3884                configs::MaskCoefficients {
3885                    decoder: configs::DecoderType::Ultralytics,
3886                    quantization: Some(quant_boxes.into()),
3887                    shape: vec![1, 32, 8400],
3888                    dshape: vec![
3889                        (DimName::Batch, 1),
3890                        (DimName::NumProtos, 32),
3891                        (DimName::NumBoxes, 8400),
3892                    ],
3893                },
3894                configs::Protos {
3895                    decoder: configs::DecoderType::Ultralytics,
3896                    quantization: Some(quant_protos.into()),
3897                    shape: vec![1, 160, 160, 32],
3898                    dshape: vec![
3899                        (DimName::Batch, 1),
3900                        (DimName::Height, 160),
3901                        (DimName::Width, 160),
3902                        (DimName::NumProtos, 32),
3903                    ],
3904                },
3905            )
3906            .with_score_threshold(score_threshold)
3907            .with_iou_threshold(iou_threshold)
3908            .build()
3909            .unwrap();
3910
3911        let expected_boxes = [
3912            DetectBox {
3913                bbox: BoundingBox {
3914                    xmin: 0.08515105,
3915                    ymin: 0.7131401,
3916                    xmax: 0.29802868,
3917                    ymax: 0.8195788,
3918                },
3919                score: 0.91537374,
3920                label: 23,
3921            },
3922            DetectBox {
3923                bbox: BoundingBox {
3924                    xmin: 0.59605736,
3925                    ymin: 0.25545314,
3926                    xmax: 0.93666154,
3927                    ymax: 0.72378385,
3928                },
3929                score: 0.91537374,
3930                label: 23,
3931            },
3932        ];
3933
3934        let mut tracker = ByteTrackBuilder::new()
3935            .track_update(0.1)
3936            .track_high_conf(0.7)
3937            .build();
3938
3939        let mut output_boxes = Vec::with_capacity(50);
3940        let mut output_tracks = Vec::with_capacity(50);
3941
3942        decoder
3943            .decode_tracked_float_proto(
3944                &mut tracker,
3945                0,
3946                &[
3947                    boxes.view().into_dyn(),
3948                    scores.view().into_dyn(),
3949                    mask.view().into_dyn(),
3950                    protos.view().into_dyn(),
3951                ],
3952                &mut output_boxes,
3953                &mut output_tracks,
3954            )
3955            .unwrap();
3956
3957        assert_eq!(output_boxes.len(), 2);
3958        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3959        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3960
3961        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3962
3963        for score in scores.iter_mut() {
3964            *score = 0.0; // set all scores to minimum to simulate no detections
3965        }
3966        let protos = decoder
3967            .decode_tracked_float_proto(
3968                &mut tracker,
3969                100_000_000 / 3,
3970                &[
3971                    boxes.view().into_dyn(),
3972                    scores.view().into_dyn(),
3973                    mask.view().into_dyn(),
3974                    protos.view().into_dyn(),
3975                ],
3976                &mut output_boxes,
3977                &mut output_tracks,
3978            )
3979            .unwrap();
3980
3981        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3982        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3983
3984        // no masks when the boxes are from tracker prediction without a matching detection
3985        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3986    }
3987
3988    #[test]
3989    fn test_decoder_tracked_end_to_end_segdet() {
3990        let score_threshold = 0.45;
3991        let iou_threshold = 0.45;
3992
3993        let mut boxes = Array2::zeros((10, 4));
3994        let mut scores = Array2::zeros((10, 1));
3995        let mut classes = Array2::zeros((10, 1));
3996        let mask = Array2::zeros((10, 32));
3997        let protos = Array3::<f64>::zeros((160, 160, 32));
3998        let protos = protos.insert_axis(Axis(0));
3999
4000        let protos_quant = (1.0 / 255.0, 0.0);
4001        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4002
4003        boxes
4004            .slice_mut(s![0, ..,])
4005            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4006        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4007        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4008
4009        let detect = ndarray::concatenate![
4010            Axis(1),
4011            boxes.view(),
4012            scores.view(),
4013            classes.view(),
4014            mask.view()
4015        ];
4016        let detect = detect.insert_axis(Axis(0));
4017        assert_eq!(detect.shape(), &[1, 10, 38]);
4018        let detect_quant = (2.0 / 255.0, 0.0);
4019        let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4020        let config = "
4021decoder_version: yolo26
4022outputs:
4023 - type: detection
4024   decoder: ultralytics
4025   quantization: [0.00784313725490196, 0]
4026   shape: [1, 10, 38]
4027   dshape:
4028    - [batch, 1]
4029    - [num_boxes, 10]
4030    - [num_features, 38]
4031   normalized: true
4032 - type: protos
4033   decoder: ultralytics
4034   quantization: [0.0039215686274509803921568627451, 128]
4035   shape: [1, 160, 160, 32]
4036   dshape:
4037    - [batch, 1]
4038    - [height, 160]
4039    - [width, 160]
4040    - [num_protos, 32]
4041";
4042
4043        let decoder = DecoderBuilder::default()
4044            .with_config_yaml_str(config.to_string())
4045            .with_score_threshold(score_threshold)
4046            .with_iou_threshold(iou_threshold)
4047            .build()
4048            .unwrap();
4049
4050        // Expected boxes doesn't match the float values exactly due to quantization error
4051        let expected_boxes = [DetectBox {
4052            bbox: BoundingBox {
4053                xmin: 0.12549022,
4054                ymin: 0.12549022,
4055                xmax: 0.23529413,
4056                ymax: 0.23529413,
4057            },
4058            score: 0.98823535,
4059            label: 2,
4060        }];
4061
4062        let mut tracker = ByteTrackBuilder::new()
4063            .track_update(0.1)
4064            .track_high_conf(0.7)
4065            .build();
4066
4067        let mut output_boxes = Vec::with_capacity(50);
4068        let mut output_masks = Vec::with_capacity(50);
4069        let mut output_tracks = Vec::with_capacity(50);
4070
4071        decoder
4072            .decode_tracked_quantized(
4073                &mut tracker,
4074                0,
4075                &[detect.view().into(), protos.view().into()],
4076                &mut output_boxes,
4077                &mut output_masks,
4078                &mut output_tracks,
4079            )
4080            .unwrap();
4081
4082        assert_eq!(output_boxes.len(), 1);
4083        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4084
4085        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4086
4087        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4088            *score = u8::MIN; // set all scores to minimum to simulate no detections
4089        }
4090
4091        decoder
4092            .decode_tracked_quantized(
4093                &mut tracker,
4094                100_000_000 / 3,
4095                &[detect.view().into(), protos.view().into()],
4096                &mut output_boxes,
4097                &mut output_masks,
4098                &mut output_tracks,
4099            )
4100            .unwrap();
4101        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4102        // no masks when the boxes are from tracker prediction without a matching detection
4103        assert!(output_masks.is_empty())
4104    }
4105
4106    #[test]
4107    fn test_decoder_tracked_end_to_end_segdet_float() {
4108        let score_threshold = 0.45;
4109        let iou_threshold = 0.45;
4110
4111        let mut boxes = Array2::zeros((10, 4));
4112        let mut scores = Array2::zeros((10, 1));
4113        let mut classes = Array2::zeros((10, 1));
4114        let mask = Array2::zeros((10, 32));
4115        let protos = Array3::<f64>::zeros((160, 160, 32));
4116        let protos = protos.insert_axis(Axis(0));
4117
4118        boxes
4119            .slice_mut(s![0, ..,])
4120            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4121        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4122        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4123
4124        let detect = ndarray::concatenate![
4125            Axis(1),
4126            boxes.view(),
4127            scores.view(),
4128            classes.view(),
4129            mask.view()
4130        ];
4131        let mut detect = detect.insert_axis(Axis(0));
4132        assert_eq!(detect.shape(), &[1, 10, 38]);
4133        let config = "
4134decoder_version: yolo26
4135outputs:
4136 - type: detection
4137   decoder: ultralytics
4138   quantization: [0.00784313725490196, 0]
4139   shape: [1, 10, 38]
4140   dshape:
4141    - [batch, 1]
4142    - [num_boxes, 10]
4143    - [num_features, 38]
4144   normalized: true
4145 - type: protos
4146   decoder: ultralytics
4147   quantization: [0.0039215686274509803921568627451, 128]
4148   shape: [1, 160, 160, 32]
4149   dshape:
4150    - [batch, 1]
4151    - [height, 160]
4152    - [width, 160]
4153    - [num_protos, 32]
4154";
4155
4156        let decoder = DecoderBuilder::default()
4157            .with_config_yaml_str(config.to_string())
4158            .with_score_threshold(score_threshold)
4159            .with_iou_threshold(iou_threshold)
4160            .build()
4161            .unwrap();
4162
4163        let expected_boxes = [DetectBox {
4164            bbox: BoundingBox {
4165                xmin: 0.1234,
4166                ymin: 0.1234,
4167                xmax: 0.2345,
4168                ymax: 0.2345,
4169            },
4170            score: 0.9876,
4171            label: 2,
4172        }];
4173
4174        let mut tracker = ByteTrackBuilder::new()
4175            .track_update(0.1)
4176            .track_high_conf(0.7)
4177            .build();
4178
4179        let mut output_boxes = Vec::with_capacity(50);
4180        let mut output_masks = Vec::with_capacity(50);
4181        let mut output_tracks = Vec::with_capacity(50);
4182
4183        decoder
4184            .decode_tracked_float(
4185                &mut tracker,
4186                0,
4187                &[detect.view().into_dyn(), protos.view().into_dyn()],
4188                &mut output_boxes,
4189                &mut output_masks,
4190                &mut output_tracks,
4191            )
4192            .unwrap();
4193
4194        assert_eq!(output_boxes.len(), 1);
4195        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4196
4197        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4198
4199        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4200            *score = 0.0; // set all scores to minimum to simulate no detections
4201        }
4202
4203        decoder
4204            .decode_tracked_float(
4205                &mut tracker,
4206                100_000_000 / 3,
4207                &[detect.view().into_dyn(), protos.view().into_dyn()],
4208                &mut output_boxes,
4209                &mut output_masks,
4210                &mut output_tracks,
4211            )
4212            .unwrap();
4213        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4214        // no masks when the boxes are from tracker prediction without a matching detection
4215        assert!(output_masks.is_empty())
4216    }
4217
4218    #[test]
4219    fn test_decoder_tracked_end_to_end_segdet_proto() {
4220        let score_threshold = 0.45;
4221        let iou_threshold = 0.45;
4222
4223        let mut boxes = Array2::zeros((10, 4));
4224        let mut scores = Array2::zeros((10, 1));
4225        let mut classes = Array2::zeros((10, 1));
4226        let mask = Array2::zeros((10, 32));
4227        let protos = Array3::<f64>::zeros((160, 160, 32));
4228        let protos = protos.insert_axis(Axis(0));
4229
4230        let protos_quant = (1.0 / 255.0, 0.0);
4231        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4232
4233        boxes
4234            .slice_mut(s![0, ..,])
4235            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4236        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4237        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4238
4239        let detect = ndarray::concatenate![
4240            Axis(1),
4241            boxes.view(),
4242            scores.view(),
4243            classes.view(),
4244            mask.view()
4245        ];
4246        let detect = detect.insert_axis(Axis(0));
4247        assert_eq!(detect.shape(), &[1, 10, 38]);
4248        let detect_quant = (2.0 / 255.0, 0.0);
4249        let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4250        let config = "
4251decoder_version: yolo26
4252outputs:
4253 - type: detection
4254   decoder: ultralytics
4255   quantization: [0.00784313725490196, 0]
4256   shape: [1, 10, 38]
4257   dshape:
4258    - [batch, 1]
4259    - [num_boxes, 10]
4260    - [num_features, 38]
4261   normalized: true
4262 - type: protos
4263   decoder: ultralytics
4264   quantization: [0.0039215686274509803921568627451, 128]
4265   shape: [1, 160, 160, 32]
4266   dshape:
4267    - [batch, 1]
4268    - [height, 160]
4269    - [width, 160]
4270    - [num_protos, 32]
4271";
4272
4273        let decoder = DecoderBuilder::default()
4274            .with_config_yaml_str(config.to_string())
4275            .with_score_threshold(score_threshold)
4276            .with_iou_threshold(iou_threshold)
4277            .build()
4278            .unwrap();
4279
4280        // Expected boxes doesn't match the float values exactly due to quantization error
4281        let expected_boxes = [DetectBox {
4282            bbox: BoundingBox {
4283                xmin: 0.12549022,
4284                ymin: 0.12549022,
4285                xmax: 0.23529413,
4286                ymax: 0.23529413,
4287            },
4288            score: 0.98823535,
4289            label: 2,
4290        }];
4291
4292        let mut tracker = ByteTrackBuilder::new()
4293            .track_update(0.1)
4294            .track_high_conf(0.7)
4295            .build();
4296
4297        let mut output_boxes = Vec::with_capacity(50);
4298        let mut output_tracks = Vec::with_capacity(50);
4299
4300        decoder
4301            .decode_tracked_quantized_proto(
4302                &mut tracker,
4303                0,
4304                &[detect.view().into(), protos.view().into()],
4305                &mut output_boxes,
4306                &mut output_tracks,
4307            )
4308            .unwrap();
4309
4310        assert_eq!(output_boxes.len(), 1);
4311        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4312
4313        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4314
4315        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4316            *score = u8::MIN; // set all scores to minimum to simulate no detections
4317        }
4318
4319        let protos = decoder
4320            .decode_tracked_quantized_proto(
4321                &mut tracker,
4322                100_000_000 / 3,
4323                &[detect.view().into(), protos.view().into()],
4324                &mut output_boxes,
4325                &mut output_tracks,
4326            )
4327            .unwrap();
4328        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4329        // no masks when the boxes are from tracker prediction without a matching detection
4330        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4331    }
4332
4333    #[test]
4334    fn test_decoder_tracked_end_to_end_segdet_proto_float() {
4335        let score_threshold = 0.45;
4336        let iou_threshold = 0.45;
4337
4338        let mut boxes = Array2::zeros((10, 4));
4339        let mut scores = Array2::zeros((10, 1));
4340        let mut classes = Array2::zeros((10, 1));
4341        let mask = Array2::zeros((10, 32));
4342        let protos = Array3::<f64>::zeros((160, 160, 32));
4343        let protos = protos.insert_axis(Axis(0));
4344
4345        boxes
4346            .slice_mut(s![0, ..,])
4347            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4348        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4349        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4350
4351        let detect = ndarray::concatenate![
4352            Axis(1),
4353            boxes.view(),
4354            scores.view(),
4355            classes.view(),
4356            mask.view()
4357        ];
4358        let mut detect = detect.insert_axis(Axis(0));
4359        assert_eq!(detect.shape(), &[1, 10, 38]);
4360        let config = "
4361decoder_version: yolo26
4362outputs:
4363 - type: detection
4364   decoder: ultralytics
4365   quantization: [0.00784313725490196, 0]
4366   shape: [1, 10, 38]
4367   dshape:
4368    - [batch, 1]
4369    - [num_boxes, 10]
4370    - [num_features, 38]
4371   normalized: true
4372 - type: protos
4373   decoder: ultralytics
4374   quantization: [0.0039215686274509803921568627451, 128]
4375   shape: [1, 160, 160, 32]
4376   dshape:
4377    - [batch, 1]
4378    - [height, 160]
4379    - [width, 160]
4380    - [num_protos, 32]
4381";
4382
4383        let decoder = DecoderBuilder::default()
4384            .with_config_yaml_str(config.to_string())
4385            .with_score_threshold(score_threshold)
4386            .with_iou_threshold(iou_threshold)
4387            .build()
4388            .unwrap();
4389
4390        let expected_boxes = [DetectBox {
4391            bbox: BoundingBox {
4392                xmin: 0.1234,
4393                ymin: 0.1234,
4394                xmax: 0.2345,
4395                ymax: 0.2345,
4396            },
4397            score: 0.9876,
4398            label: 2,
4399        }];
4400
4401        let mut tracker = ByteTrackBuilder::new()
4402            .track_update(0.1)
4403            .track_high_conf(0.7)
4404            .build();
4405
4406        let mut output_boxes = Vec::with_capacity(50);
4407        let mut output_tracks = Vec::with_capacity(50);
4408
4409        decoder
4410            .decode_tracked_float_proto(
4411                &mut tracker,
4412                0,
4413                &[detect.view().into_dyn(), protos.view().into_dyn()],
4414                &mut output_boxes,
4415                &mut output_tracks,
4416            )
4417            .unwrap();
4418
4419        assert_eq!(output_boxes.len(), 1);
4420        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4421
4422        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4423
4424        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4425            *score = 0.0; // set all scores to minimum to simulate no detections
4426        }
4427
4428        let protos = decoder
4429            .decode_tracked_float_proto(
4430                &mut tracker,
4431                100_000_000 / 3,
4432                &[detect.view().into_dyn(), protos.view().into_dyn()],
4433                &mut output_boxes,
4434                &mut output_tracks,
4435            )
4436            .unwrap();
4437        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4438        // no masks when the boxes are from tracker prediction without a matching detection
4439        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4440    }
4441
4442    #[test]
4443    fn test_decoder_tracked_end_to_end_segdet_split() {
4444        let score_threshold = 0.45;
4445        let iou_threshold = 0.45;
4446
4447        let mut boxes = Array2::zeros((10, 4));
4448        let mut scores = Array2::zeros((10, 1));
4449        let mut classes = Array2::zeros((10, 1));
4450        let mask: Array2<f64> = Array2::zeros((10, 32));
4451        let protos = Array3::<f64>::zeros((160, 160, 32));
4452        let protos = protos.insert_axis(Axis(0));
4453
4454        let protos_quant = (1.0 / 255.0, 0.0);
4455        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4456
4457        boxes
4458            .slice_mut(s![0, ..,])
4459            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4460        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4461        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4462
4463        let boxes = boxes.insert_axis(Axis(0));
4464        let scores = scores.insert_axis(Axis(0));
4465        let classes = classes.insert_axis(Axis(0));
4466        let mask = mask.insert_axis(Axis(0));
4467
4468        let detect_quant = (2.0 / 255.0, 0.0);
4469        let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4470        let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4471        let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4472        let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4473
4474        let config = "
4475decoder_version: yolo26
4476outputs:
4477 - type: boxes
4478   decoder: ultralytics
4479   quantization: [0.00784313725490196, 0]
4480   shape: [1, 10, 4]
4481   dshape:
4482    - [batch, 1]
4483    - [num_boxes, 10]
4484    - [box_coords, 4]
4485   normalized: true
4486 - type: scores
4487   decoder: ultralytics
4488   quantization: [0.00784313725490196, 0]
4489   shape: [1, 10, 1]
4490   dshape:
4491    - [batch, 1]
4492    - [num_boxes, 10]
4493    - [num_classes, 1]
4494 - type: classes
4495   decoder: ultralytics
4496   quantization: [0.00784313725490196, 0]
4497   shape: [1, 10, 1]
4498   dshape:
4499    - [batch, 1]
4500    - [num_boxes, 10]
4501    - [num_classes, 1]
4502 - type: mask_coefficients
4503   decoder: ultralytics
4504   quantization: [0.00784313725490196, 0]
4505   shape: [1, 10, 32]
4506   dshape:
4507    - [batch, 1]
4508    - [num_boxes, 10]
4509    - [num_protos, 32]
4510 - type: protos
4511   decoder: ultralytics
4512   quantization: [0.0039215686274509803921568627451, 128]
4513   shape: [1, 160, 160, 32]
4514   dshape:
4515    - [batch, 1]
4516    - [height, 160]
4517    - [width, 160]
4518    - [num_protos, 32]
4519";
4520
4521        let decoder = DecoderBuilder::default()
4522            .with_config_yaml_str(config.to_string())
4523            .with_score_threshold(score_threshold)
4524            .with_iou_threshold(iou_threshold)
4525            .build()
4526            .unwrap();
4527
4528        // Expected boxes doesn't match the float values exactly due to quantization error
4529        let expected_boxes = [DetectBox {
4530            bbox: BoundingBox {
4531                xmin: 0.12549022,
4532                ymin: 0.12549022,
4533                xmax: 0.23529413,
4534                ymax: 0.23529413,
4535            },
4536            score: 0.98823535,
4537            label: 2,
4538        }];
4539
4540        let mut tracker = ByteTrackBuilder::new()
4541            .track_update(0.1)
4542            .track_high_conf(0.7)
4543            .build();
4544
4545        let mut output_boxes = Vec::with_capacity(50);
4546        let mut output_masks = Vec::with_capacity(50);
4547        let mut output_tracks = Vec::with_capacity(50);
4548
4549        decoder
4550            .decode_tracked_quantized(
4551                &mut tracker,
4552                0,
4553                &[
4554                    boxes.view().into(),
4555                    scores.view().into(),
4556                    classes.view().into(),
4557                    mask.view().into(),
4558                    protos.view().into(),
4559                ],
4560                &mut output_boxes,
4561                &mut output_masks,
4562                &mut output_tracks,
4563            )
4564            .unwrap();
4565
4566        assert_eq!(output_boxes.len(), 1);
4567        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4568
4569        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4570
4571        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4572            *score = u8::MIN; // set all scores to minimum to simulate no detections
4573        }
4574
4575        decoder
4576            .decode_tracked_quantized(
4577                &mut tracker,
4578                100_000_000 / 3,
4579                &[
4580                    boxes.view().into(),
4581                    scores.view().into(),
4582                    classes.view().into(),
4583                    mask.view().into(),
4584                    protos.view().into(),
4585                ],
4586                &mut output_boxes,
4587                &mut output_masks,
4588                &mut output_tracks,
4589            )
4590            .unwrap();
4591        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4592        // no masks when the boxes are from tracker prediction without a matching detection
4593        assert!(output_masks.is_empty())
4594    }
4595    #[test]
4596    fn test_decoder_tracked_end_to_end_segdet_split_float() {
4597        let score_threshold = 0.45;
4598        let iou_threshold = 0.45;
4599
4600        let mut boxes = Array2::zeros((10, 4));
4601        let mut scores = Array2::zeros((10, 1));
4602        let mut classes = Array2::zeros((10, 1));
4603        let mask: Array2<f64> = Array2::zeros((10, 32));
4604        let protos = Array3::<f64>::zeros((160, 160, 32));
4605        let protos = protos.insert_axis(Axis(0));
4606
4607        boxes
4608            .slice_mut(s![0, ..,])
4609            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4610        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4611        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4612
4613        let boxes = boxes.insert_axis(Axis(0));
4614        let mut scores = scores.insert_axis(Axis(0));
4615        let classes = classes.insert_axis(Axis(0));
4616        let mask = mask.insert_axis(Axis(0));
4617
4618        let config = "
4619decoder_version: yolo26
4620outputs:
4621 - type: boxes
4622   decoder: ultralytics
4623   quantization: [0.00784313725490196, 0]
4624   shape: [1, 10, 4]
4625   dshape:
4626    - [batch, 1]
4627    - [num_boxes, 10]
4628    - [box_coords, 4]
4629   normalized: true
4630 - type: scores
4631   decoder: ultralytics
4632   quantization: [0.00784313725490196, 0]
4633   shape: [1, 10, 1]
4634   dshape:
4635    - [batch, 1]
4636    - [num_boxes, 10]
4637    - [num_classes, 1]
4638 - type: classes
4639   decoder: ultralytics
4640   quantization: [0.00784313725490196, 0]
4641   shape: [1, 10, 1]
4642   dshape:
4643    - [batch, 1]
4644    - [num_boxes, 10]
4645    - [num_classes, 1]
4646 - type: mask_coefficients
4647   decoder: ultralytics
4648   quantization: [0.00784313725490196, 0]
4649   shape: [1, 10, 32]
4650   dshape:
4651    - [batch, 1]
4652    - [num_boxes, 10]
4653    - [num_protos, 32]
4654 - type: protos
4655   decoder: ultralytics
4656   quantization: [0.0039215686274509803921568627451, 128]
4657   shape: [1, 160, 160, 32]
4658   dshape:
4659    - [batch, 1]
4660    - [height, 160]
4661    - [width, 160]
4662    - [num_protos, 32]
4663";
4664
4665        let decoder = DecoderBuilder::default()
4666            .with_config_yaml_str(config.to_string())
4667            .with_score_threshold(score_threshold)
4668            .with_iou_threshold(iou_threshold)
4669            .build()
4670            .unwrap();
4671
4672        // Expected boxes doesn't match the float values exactly due to quantization error
4673        let expected_boxes = [DetectBox {
4674            bbox: BoundingBox {
4675                xmin: 0.1234,
4676                ymin: 0.1234,
4677                xmax: 0.2345,
4678                ymax: 0.2345,
4679            },
4680            score: 0.9876,
4681            label: 2,
4682        }];
4683
4684        let mut tracker = ByteTrackBuilder::new()
4685            .track_update(0.1)
4686            .track_high_conf(0.7)
4687            .build();
4688
4689        let mut output_boxes = Vec::with_capacity(50);
4690        let mut output_masks = Vec::with_capacity(50);
4691        let mut output_tracks = Vec::with_capacity(50);
4692
4693        decoder
4694            .decode_tracked_float(
4695                &mut tracker,
4696                0,
4697                &[
4698                    boxes.view().into_dyn(),
4699                    scores.view().into_dyn(),
4700                    classes.view().into_dyn(),
4701                    mask.view().into_dyn(),
4702                    protos.view().into_dyn(),
4703                ],
4704                &mut output_boxes,
4705                &mut output_masks,
4706                &mut output_tracks,
4707            )
4708            .unwrap();
4709
4710        assert_eq!(output_boxes.len(), 1);
4711        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4712
4713        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4714
4715        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4716            *score = 0.0; // set all scores to minimum to simulate no detections
4717        }
4718
4719        decoder
4720            .decode_tracked_float(
4721                &mut tracker,
4722                100_000_000 / 3,
4723                &[
4724                    boxes.view().into_dyn(),
4725                    scores.view().into_dyn(),
4726                    classes.view().into_dyn(),
4727                    mask.view().into_dyn(),
4728                    protos.view().into_dyn(),
4729                ],
4730                &mut output_boxes,
4731                &mut output_masks,
4732                &mut output_tracks,
4733            )
4734            .unwrap();
4735        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4736        // no masks when the boxes are from tracker prediction without a matching detection
4737        assert!(output_masks.is_empty())
4738    }
4739
4740    #[test]
4741    fn test_decoder_tracked_end_to_end_segdet_split_proto() {
4742        let score_threshold = 0.45;
4743        let iou_threshold = 0.45;
4744
4745        let mut boxes = Array2::zeros((10, 4));
4746        let mut scores = Array2::zeros((10, 1));
4747        let mut classes = Array2::zeros((10, 1));
4748        let mask: Array2<f64> = Array2::zeros((10, 32));
4749        let protos = Array3::<f64>::zeros((160, 160, 32));
4750        let protos = protos.insert_axis(Axis(0));
4751
4752        let protos_quant = (1.0 / 255.0, 0.0);
4753        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4754
4755        boxes
4756            .slice_mut(s![0, ..,])
4757            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4758        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4759        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4760
4761        let boxes = boxes.insert_axis(Axis(0));
4762        let scores = scores.insert_axis(Axis(0));
4763        let classes = classes.insert_axis(Axis(0));
4764        let mask = mask.insert_axis(Axis(0));
4765
4766        let detect_quant = (2.0 / 255.0, 0.0);
4767        let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4768        let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4769        let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4770        let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4771
4772        let config = "
4773decoder_version: yolo26
4774outputs:
4775 - type: boxes
4776   decoder: ultralytics
4777   quantization: [0.00784313725490196, 0]
4778   shape: [1, 10, 4]
4779   dshape:
4780    - [batch, 1]
4781    - [num_boxes, 10]
4782    - [box_coords, 4]
4783   normalized: true
4784 - type: scores
4785   decoder: ultralytics
4786   quantization: [0.00784313725490196, 0]
4787   shape: [1, 10, 1]
4788   dshape:
4789    - [batch, 1]
4790    - [num_boxes, 10]
4791    - [num_classes, 1]
4792 - type: classes
4793   decoder: ultralytics
4794   quantization: [0.00784313725490196, 0]
4795   shape: [1, 10, 1]
4796   dshape:
4797    - [batch, 1]
4798    - [num_boxes, 10]
4799    - [num_classes, 1]
4800 - type: mask_coefficients
4801   decoder: ultralytics
4802   quantization: [0.00784313725490196, 0]
4803   shape: [1, 10, 32]
4804   dshape:
4805    - [batch, 1]
4806    - [num_boxes, 10]
4807    - [num_protos, 32]
4808 - type: protos
4809   decoder: ultralytics
4810   quantization: [0.0039215686274509803921568627451, 128]
4811   shape: [1, 160, 160, 32]
4812   dshape:
4813    - [batch, 1]
4814    - [height, 160]
4815    - [width, 160]
4816    - [num_protos, 32]
4817";
4818
4819        let decoder = DecoderBuilder::default()
4820            .with_config_yaml_str(config.to_string())
4821            .with_score_threshold(score_threshold)
4822            .with_iou_threshold(iou_threshold)
4823            .build()
4824            .unwrap();
4825
4826        // Expected boxes doesn't match the float values exactly due to quantization error
4827        let expected_boxes = [DetectBox {
4828            bbox: BoundingBox {
4829                xmin: 0.12549022,
4830                ymin: 0.12549022,
4831                xmax: 0.23529413,
4832                ymax: 0.23529413,
4833            },
4834            score: 0.98823535,
4835            label: 2,
4836        }];
4837
4838        let mut tracker = ByteTrackBuilder::new()
4839            .track_update(0.1)
4840            .track_high_conf(0.7)
4841            .build();
4842
4843        let mut output_boxes = Vec::with_capacity(50);
4844        let mut output_tracks = Vec::with_capacity(50);
4845
4846        decoder
4847            .decode_tracked_quantized_proto(
4848                &mut tracker,
4849                0,
4850                &[
4851                    boxes.view().into(),
4852                    scores.view().into(),
4853                    classes.view().into(),
4854                    mask.view().into(),
4855                    protos.view().into(),
4856                ],
4857                &mut output_boxes,
4858                &mut output_tracks,
4859            )
4860            .unwrap();
4861
4862        assert_eq!(output_boxes.len(), 1);
4863        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4864
4865        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4866
4867        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4868            *score = u8::MIN; // set all scores to minimum to simulate no detections
4869        }
4870
4871        let protos = decoder
4872            .decode_tracked_quantized_proto(
4873                &mut tracker,
4874                100_000_000 / 3,
4875                &[
4876                    boxes.view().into(),
4877                    scores.view().into(),
4878                    classes.view().into(),
4879                    mask.view().into(),
4880                    protos.view().into(),
4881                ],
4882                &mut output_boxes,
4883                &mut output_tracks,
4884            )
4885            .unwrap();
4886        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4887        // no masks when the boxes are from tracker prediction without a matching detection
4888        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4889    }
4890
4891    #[test]
4892    fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4893        let score_threshold = 0.45;
4894        let iou_threshold = 0.45;
4895
4896        let mut boxes = Array2::zeros((10, 4));
4897        let mut scores = Array2::zeros((10, 1));
4898        let mut classes = Array2::zeros((10, 1));
4899        let mask: Array2<f64> = Array2::zeros((10, 32));
4900        let protos = Array3::<f64>::zeros((160, 160, 32));
4901        let protos = protos.insert_axis(Axis(0));
4902
4903        boxes
4904            .slice_mut(s![0, ..,])
4905            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4906        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4907        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4908
4909        let boxes = boxes.insert_axis(Axis(0));
4910        let mut scores = scores.insert_axis(Axis(0));
4911        let classes = classes.insert_axis(Axis(0));
4912        let mask = mask.insert_axis(Axis(0));
4913
4914        let config = "
4915decoder_version: yolo26
4916outputs:
4917 - type: boxes
4918   decoder: ultralytics
4919   quantization: [0.00784313725490196, 0]
4920   shape: [1, 10, 4]
4921   dshape:
4922    - [batch, 1]
4923    - [num_boxes, 10]
4924    - [box_coords, 4]
4925   normalized: true
4926 - type: scores
4927   decoder: ultralytics
4928   quantization: [0.00784313725490196, 0]
4929   shape: [1, 10, 1]
4930   dshape:
4931    - [batch, 1]
4932    - [num_boxes, 10]
4933    - [num_classes, 1]
4934 - type: classes
4935   decoder: ultralytics
4936   quantization: [0.00784313725490196, 0]
4937   shape: [1, 10, 1]
4938   dshape:
4939    - [batch, 1]
4940    - [num_boxes, 10]
4941    - [num_classes, 1]
4942 - type: mask_coefficients
4943   decoder: ultralytics
4944   quantization: [0.00784313725490196, 0]
4945   shape: [1, 10, 32]
4946   dshape:
4947    - [batch, 1]
4948    - [num_boxes, 10]
4949    - [num_protos, 32]
4950 - type: protos
4951   decoder: ultralytics
4952   quantization: [0.0039215686274509803921568627451, 128]
4953   shape: [1, 160, 160, 32]
4954   dshape:
4955    - [batch, 1]
4956    - [height, 160]
4957    - [width, 160]
4958    - [num_protos, 32]
4959";
4960
4961        let decoder = DecoderBuilder::default()
4962            .with_config_yaml_str(config.to_string())
4963            .with_score_threshold(score_threshold)
4964            .with_iou_threshold(iou_threshold)
4965            .build()
4966            .unwrap();
4967
4968        // Expected boxes doesn't match the float values exactly due to quantization error
4969        let expected_boxes = [DetectBox {
4970            bbox: BoundingBox {
4971                xmin: 0.1234,
4972                ymin: 0.1234,
4973                xmax: 0.2345,
4974                ymax: 0.2345,
4975            },
4976            score: 0.9876,
4977            label: 2,
4978        }];
4979
4980        let mut tracker = ByteTrackBuilder::new()
4981            .track_update(0.1)
4982            .track_high_conf(0.7)
4983            .build();
4984
4985        let mut output_boxes = Vec::with_capacity(50);
4986        let mut output_tracks = Vec::with_capacity(50);
4987
4988        decoder
4989            .decode_tracked_float_proto(
4990                &mut tracker,
4991                0,
4992                &[
4993                    boxes.view().into_dyn(),
4994                    scores.view().into_dyn(),
4995                    classes.view().into_dyn(),
4996                    mask.view().into_dyn(),
4997                    protos.view().into_dyn(),
4998                ],
4999                &mut output_boxes,
5000                &mut output_tracks,
5001            )
5002            .unwrap();
5003
5004        assert_eq!(output_boxes.len(), 1);
5005        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5006
5007        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5008
5009        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5010            *score = 0.0; // set all scores to minimum to simulate no detections
5011        }
5012
5013        let protos = decoder
5014            .decode_tracked_float_proto(
5015                &mut tracker,
5016                100_000_000 / 3,
5017                &[
5018                    boxes.view().into_dyn(),
5019                    scores.view().into_dyn(),
5020                    classes.view().into_dyn(),
5021                    mask.view().into_dyn(),
5022                    protos.view().into_dyn(),
5023                ],
5024                &mut output_boxes,
5025                &mut output_tracks,
5026            )
5027            .unwrap();
5028        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5029        // no masks when the boxes are from tracker prediction without a matching detection
5030        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5031    }
5032
5033    #[test]
5034    fn test_decoder_tracked_linear_motion() {
5035        use crate::configs::{DecoderType, Nms};
5036        use crate::DecoderBuilder;
5037
5038        let score_threshold = 0.25;
5039        let iou_threshold = 0.1;
5040        let out = include_bytes!(concat!(
5041            env!("CARGO_MANIFEST_DIR"),
5042            "/../../testdata/yolov8s_80_classes.bin"
5043        ));
5044        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5045        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5046        let quant = (0.0040811873, -123).into();
5047
5048        let decoder = DecoderBuilder::default()
5049            .with_config_yolo_det(
5050                crate::configs::Detection {
5051                    decoder: DecoderType::Ultralytics,
5052                    shape: vec![1, 84, 8400],
5053                    anchors: None,
5054                    quantization: Some(quant),
5055                    dshape: vec![
5056                        (crate::configs::DimName::Batch, 1),
5057                        (crate::configs::DimName::NumFeatures, 84),
5058                        (crate::configs::DimName::NumBoxes, 8400),
5059                    ],
5060                    normalized: Some(true),
5061                },
5062                None,
5063            )
5064            .with_score_threshold(score_threshold)
5065            .with_iou_threshold(iou_threshold)
5066            .with_nms(Some(Nms::ClassAgnostic))
5067            .build()
5068            .unwrap();
5069
5070        let mut expected_boxes = [
5071            DetectBox {
5072                bbox: BoundingBox {
5073                    xmin: 0.5285137,
5074                    ymin: 0.05305544,
5075                    xmax: 0.87541467,
5076                    ymax: 0.9998909,
5077                },
5078                score: 0.5591227,
5079                label: 0,
5080            },
5081            DetectBox {
5082                bbox: BoundingBox {
5083                    xmin: 0.130598,
5084                    ymin: 0.43260583,
5085                    xmax: 0.35098213,
5086                    ymax: 0.9958097,
5087                },
5088                score: 0.33057618,
5089                label: 75,
5090            },
5091        ];
5092
5093        let mut tracker = ByteTrackBuilder::new()
5094            .track_update(0.1)
5095            .track_high_conf(0.3)
5096            .build();
5097
5098        let mut output_boxes = Vec::with_capacity(50);
5099        let mut output_masks = Vec::with_capacity(50);
5100        let mut output_tracks = Vec::with_capacity(50);
5101
5102        decoder
5103            .decode_tracked_quantized(
5104                &mut tracker,
5105                0,
5106                &[out.view().into()],
5107                &mut output_boxes,
5108                &mut output_masks,
5109                &mut output_tracks,
5110            )
5111            .unwrap();
5112
5113        assert_eq!(output_boxes.len(), 2);
5114        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5115        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5116
5117        for i in 1..=100 {
5118            let mut out = out.clone();
5119            // introduce linear movement into the XY coordinates
5120            let mut x_values = out.slice_mut(s![0, 0, ..]);
5121            for x in x_values.iter_mut() {
5122                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5123            }
5124
5125            decoder
5126                .decode_tracked_quantized(
5127                    &mut tracker,
5128                    100_000_000 * i / 3, // simulate 33.333ms between frames
5129                    &[out.view().into()],
5130                    &mut output_boxes,
5131                    &mut output_masks,
5132                    &mut output_tracks,
5133                )
5134                .unwrap();
5135
5136            assert_eq!(output_boxes.len(), 2);
5137        }
5138        let tracks = tracker.get_active_tracks();
5139        let predicted_boxes: Vec<_> = tracks
5140            .iter()
5141            .map(|track| {
5142                let mut l = track.last_box;
5143                l.bbox = track.info.tracked_location.into();
5144                l
5145            })
5146            .collect();
5147        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
5148        expected_boxes[0].bbox.xmax += 0.1;
5149        expected_boxes[1].bbox.xmin += 0.1;
5150        expected_boxes[1].bbox.xmax += 0.1;
5151
5152        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5153        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5154
5155        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5156        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5157        for score in scores_values.iter_mut() {
5158            *score = i8::MIN; // set all scores to minimum to simulate no detections
5159        }
5160        decoder
5161            .decode_tracked_quantized(
5162                &mut tracker,
5163                100_000_000 * 101 / 3,
5164                &[out.view().into()],
5165                &mut output_boxes,
5166                &mut output_masks,
5167                &mut output_tracks,
5168            )
5169            .unwrap();
5170        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
5171        expected_boxes[0].bbox.xmax += 0.001;
5172        expected_boxes[1].bbox.xmin += 0.001;
5173        expected_boxes[1].bbox.xmax += 0.001;
5174
5175        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5176        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5177    }
5178
5179    #[test]
5180    fn test_decoder_tracked_end_to_end_float() {
5181        let score_threshold = 0.45;
5182        let iou_threshold = 0.45;
5183
5184        let mut boxes = Array2::zeros((10, 4));
5185        let mut scores = Array2::zeros((10, 1));
5186        let mut classes = Array2::zeros((10, 1));
5187
5188        boxes
5189            .slice_mut(s![0, ..,])
5190            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5191        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5192        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5193
5194        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5195        let mut detect = detect.insert_axis(Axis(0));
5196        assert_eq!(detect.shape(), &[1, 10, 6]);
5197        let config = "
5198decoder_version: yolo26
5199outputs:
5200 - type: detection
5201   decoder: ultralytics
5202   quantization: [0.00784313725490196, 0]
5203   shape: [1, 10, 6]
5204   dshape:
5205    - [batch, 1]
5206    - [num_boxes, 10]
5207    - [num_features, 6]
5208   normalized: true
5209";
5210
5211        let decoder = DecoderBuilder::default()
5212            .with_config_yaml_str(config.to_string())
5213            .with_score_threshold(score_threshold)
5214            .with_iou_threshold(iou_threshold)
5215            .build()
5216            .unwrap();
5217
5218        let expected_boxes = [DetectBox {
5219            bbox: BoundingBox {
5220                xmin: 0.1234,
5221                ymin: 0.1234,
5222                xmax: 0.2345,
5223                ymax: 0.2345,
5224            },
5225            score: 0.9876,
5226            label: 2,
5227        }];
5228
5229        let mut tracker = ByteTrackBuilder::new()
5230            .track_update(0.1)
5231            .track_high_conf(0.7)
5232            .build();
5233
5234        let mut output_boxes = Vec::with_capacity(50);
5235        let mut output_masks = Vec::with_capacity(50);
5236        let mut output_tracks = Vec::with_capacity(50);
5237
5238        decoder
5239            .decode_tracked_float(
5240                &mut tracker,
5241                0,
5242                &[detect.view().into_dyn()],
5243                &mut output_boxes,
5244                &mut output_masks,
5245                &mut output_tracks,
5246            )
5247            .unwrap();
5248
5249        assert_eq!(output_boxes.len(), 1);
5250        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5251
5252        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5253
5254        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5255            *score = 0.0; // set all scores to minimum to simulate no detections
5256        }
5257
5258        decoder
5259            .decode_tracked_float(
5260                &mut tracker,
5261                100_000_000 / 3,
5262                &[detect.view().into_dyn()],
5263                &mut output_boxes,
5264                &mut output_masks,
5265                &mut output_tracks,
5266            )
5267            .unwrap();
5268        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5269    }
5270}