Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust,no_run
16use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::{self, DecoderVersion}};
17use edgefirst_tensor::TensorDyn;
18
19fn main() -> DecoderResult<()> {
20    // Create a decoder for a YOLOv8 model with quantized int8 output
21    let decoder = DecoderBuilder::new()
22        .with_config_yolo_det(configs::Detection {
23            anchors: None,
24            decoder: configs::DecoderType::Ultralytics,
25            quantization: Some(configs::QuantTuple(0.012345, 26)),
26            shape: vec![1, 84, 8400],
27            dshape: Vec::new(),
28            normalized: Some(true),
29        },
30        Some(DecoderVersion::Yolov8))
31        .with_score_threshold(0.25)
32        .with_iou_threshold(0.7)
33        .build()?;
34
35    // Get the model output tensors from inference
36    let model_output: Vec<TensorDyn> = vec![/* tensors from inference */];
37    let tensor_refs: Vec<&TensorDyn> = model_output.iter().collect();
38
39    let mut output_boxes = Vec::with_capacity(10);
40    let mut output_masks = Vec::with_capacity(10);
41
42    // Decode model output into detection boxes and segmentation masks
43    decoder.decode(&tensor_refs, &mut output_boxes, &mut output_masks)?;
44    Ok(())
45}
46```
47
48# Overview
49
50The primary components of this crate are:
51- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
52- `yolo` module: Contains functions specific to decoding YOLO model outputs.
53- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
54
55The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
56It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
57When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model outputs are not supported by the decoder, the user can manually dequantize the model outputs using
the dequantize functions provided in this crate (`dequantize_ndarray`, `dequantize_cpu`, `dequantize_cpu_chunked`), and then use the floating-point
decoding functions. However, when the decoder does support the integer types, it is recommended not to dequantize the model outputs manually before
passing them to the decoder, as the quantized decoder functions are optimized for performance.
61
62The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
63which can be used if the model type and output formats are known in advance.
64
65
66*/
67#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
68
69use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
70use num_traits::{AsPrimitive, Float, PrimInt};
71
72pub mod byte;
73pub mod error;
74pub mod float;
75pub mod modelpack;
76pub mod yolo;
77
78mod decoder;
79pub use decoder::*;
80
81pub use configs::{DecoderVersion, Nms};
82pub use error::{DecoderError, DecoderResult};
83
84use crate::{
85    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
86    yolo::yolo_segmentation_to_mask,
87};
88
89/// Trait to convert bounding box formats to XYXY float format
90pub trait BBoxTypeTrait {
91    /// Converts the bbox into XYXY float format.
92    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
93
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
96        input: &[B; 4],
97        quant: Quantization,
98    ) -> [A; 4]
99    where
100        f32: AsPrimitive<A>,
101        i32: AsPrimitive<A>;
102
103    /// Converts the bbox into XYXY float format.
104    ///
105    /// # Examples
106    /// ```rust
107    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
108    /// # use ndarray::array;
109    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
110    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
111    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
112    /// ```
113    #[inline(always)]
114    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
115        input: ArrayView1<B>,
116    ) -> [A; 4] {
117        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
118    }
119
120    #[inline(always)]
121    /// Converts the bbox into XYXY float format.
122    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
123        input: ArrayView1<B>,
124        quant: Quantization,
125    ) -> [A; 4]
126    where
127        f32: AsPrimitive<A>,
128        i32: AsPrimitive<A>,
129    {
130        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
131    }
132}
133
134/// Converts XYXY bounding boxes to XYXY
135#[derive(Debug, Clone, Copy, PartialEq, Eq)]
136pub struct XYXY {}
137
138impl BBoxTypeTrait for XYXY {
139    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
140        input.map(|b| b.as_())
141    }
142
143    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
144        input: &[B; 4],
145        quant: Quantization,
146    ) -> [A; 4]
147    where
148        f32: AsPrimitive<A>,
149        i32: AsPrimitive<A>,
150    {
151        let scale = quant.scale.as_();
152        let zp = quant.zero_point.as_();
153        input.map(|b| (b.as_() - zp) * scale)
154    }
155
156    #[inline(always)]
157    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
158        input: ArrayView1<B>,
159    ) -> [A; 4] {
160        [
161            input[0].as_(),
162            input[1].as_(),
163            input[2].as_(),
164            input[3].as_(),
165        ]
166    }
167}
168
169/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
170/// box
171#[derive(Debug, Clone, Copy, PartialEq, Eq)]
172pub struct XYWH {}
173
174impl BBoxTypeTrait for XYWH {
175    #[inline(always)]
176    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
177        let half = A::one() / (A::one() + A::one());
178        [
179            (input[0].as_()) - (input[2].as_() * half),
180            (input[1].as_()) - (input[3].as_() * half),
181            (input[0].as_()) + (input[2].as_() * half),
182            (input[1].as_()) + (input[3].as_() * half),
183        ]
184    }
185
186    #[inline(always)]
187    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
188        input: &[B; 4],
189        quant: Quantization,
190    ) -> [A; 4]
191    where
192        f32: AsPrimitive<A>,
193        i32: AsPrimitive<A>,
194    {
195        let scale = quant.scale.as_();
196        let half_scale = (quant.scale * 0.5).as_();
197        let zp = quant.zero_point.as_();
198        let [x, y, w, h] = [
199            (input[0].as_() - zp) * scale,
200            (input[1].as_() - zp) * scale,
201            (input[2].as_() - zp) * half_scale,
202            (input[3].as_() - zp) * half_scale,
203        ];
204
205        [x - w, y - h, x + w, y + h]
206    }
207
208    #[inline(always)]
209    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
210        input: ArrayView1<B>,
211    ) -> [A; 4] {
212        let half = A::one() / (A::one() + A::one());
213        [
214            (input[0].as_()) - (input[2].as_() * half),
215            (input[1].as_()) - (input[3].as_() * half),
216            (input[0].as_()) + (input[2].as_() * half),
217            (input[1].as_()) + (input[3].as_() * half),
218        ]
219    }
220}
221
/// Affine quantization parameters of a tensor.
///
/// The dequantization routines in this crate map a quantized value `q` to
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Real-valued step between adjacent quantized levels.
    pub scale: f32,
    /// Quantized value that corresponds to real zero.
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a [`Quantization`] from an explicit `scale` and `zero_point`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { zero_point, scale }
    }
}
242
243impl From<QuantTuple> for Quantization {
244    /// Creates a new Quantization struct from a QuantTuple
245    /// # Examples
246    /// ```
247    /// # use edgefirst_decoder::Quantization;
248    /// # use edgefirst_decoder::configs::QuantTuple;
249    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
250    /// let quant = Quantization::from(quant_tuple);
251    /// assert_eq!(quant.scale, 0.1);
252    /// assert_eq!(quant.zero_point, -128);
253    /// ```
254    fn from(quant_tuple: QuantTuple) -> Quantization {
255        Quantization {
256            scale: quant_tuple.0,
257            zero_point: quant_tuple.1,
258        }
259    }
260}
261
262impl<S, Z> From<(S, Z)> for Quantization
263where
264    S: AsPrimitive<f32>,
265    Z: AsPrimitive<i32>,
266{
267    /// Creates a new Quantization struct from a tuple
268    /// # Examples
269    /// ```
270    /// # use edgefirst_decoder::Quantization;
271    /// let quant = Quantization::from((0.1_f64, -128_i64));
272    /// assert_eq!(quant.scale, 0.1);
273    /// assert_eq!(quant.zero_point, -128);
274    /// ```
275    fn from((scale, zp): (S, Z)) -> Quantization {
276        Self {
277            scale: scale.as_(),
278            zero_point: zp.as_(),
279        }
280    }
281}
282
283impl Default for Quantization {
284    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
285    /// # Examples
286    /// ```rust
287    /// # use edgefirst_decoder::Quantization;
288    /// let quant = Quantization::default();
289    /// assert_eq!(quant.scale, 1.0);
290    /// assert_eq!(quant.zero_point, 0);
291    /// ```
292    fn default() -> Self {
293        Self {
294            scale: 1.0,
295            zero_point: 0,
296        }
297    }
298}
299
300/// A detection box with f32 bbox and score
301#[derive(Debug, Clone, Copy, PartialEq, Default)]
302pub struct DetectBox {
303    pub bbox: BoundingBox,
304    /// model-specific score for this detection, higher implies more confidence
305    pub score: f32,
306    /// label index for this detection
307    pub label: usize,
308}
309
/// Axis-aligned box stored as `f32` corner coordinates in XYXY order.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// Left edge: normalized x coordinate of the minimum corner.
    pub xmin: f32,
    /// Top edge: normalized y coordinate of the minimum corner.
    pub ymin: f32,
    /// Right edge: normalized x coordinate of the maximum corner.
    pub xmax: f32,
    /// Bottom edge: normalized y coordinate of the maximum corner.
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a box from its four corner coordinates, stored exactly as given
    /// (no reordering; see [`BoundingBox::to_canonical`]).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy whose corners are ordered so that `xmin <= xmax` and
    /// `ymin <= ymax`, swapping coordinates where necessary.
    ///
    /// Uses `f32::min` / `f32::max`, so if one coordinate of a pair is NaN
    /// the other coordinate is used for both edges.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (left, right) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (top, bottom) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(left, top, right, bottom)
    }
}
355
356impl From<BoundingBox> for [f32; 4] {
357    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
358    /// xmax, ymax order
359    /// # Examples
360    /// ```
361    /// # use edgefirst_decoder::BoundingBox;
362    /// let bbox = BoundingBox {
363    ///     xmin: 0.1,
364    ///     ymin: 0.2,
365    ///     xmax: 0.3,
366    ///     ymax: 0.4,
367    /// };
368    /// let arr: [f32; 4] = bbox.into();
369    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
370    /// ```
371    fn from(b: BoundingBox) -> Self {
372        [b.xmin, b.ymin, b.xmax, b.ymax]
373    }
374}
375
376impl From<[f32; 4]> for BoundingBox {
377    // Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into a
378    // BoundingBox
379    fn from(arr: [f32; 4]) -> Self {
380        BoundingBox {
381            xmin: arr[0],
382            ymin: arr[1],
383            xmax: arr[2],
384            ymax: arr[3],
385        }
386    }
387}
388
389impl DetectBox {
390    /// Returns true if one detect box is equal to another detect box, within
391    /// the given `eps`
392    ///
393    /// # Examples
394    /// ```
395    /// # use edgefirst_decoder::DetectBox;
396    /// let box1 = DetectBox {
397    ///     bbox: edgefirst_decoder::BoundingBox {
398    ///         xmin: 0.1,
399    ///         ymin: 0.2,
400    ///         xmax: 0.3,
401    ///         ymax: 0.4,
402    ///     },
403    ///     score: 0.5,
404    ///     label: 1,
405    /// };
406    /// let box2 = DetectBox {
407    ///     bbox: edgefirst_decoder::BoundingBox {
408    ///         xmin: 0.101,
409    ///         ymin: 0.199,
410    ///         xmax: 0.301,
411    ///         ymax: 0.399,
412    ///     },
413    ///     score: 0.510,
414    ///     label: 1,
415    /// };
416    /// assert!(box1.equal_within_delta(&box2, 0.011));
417    /// ```
418    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
419        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
420        self.label == rhs.label
421            && eq_delta(self.score, rhs.score)
422            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
423            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
424            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
425            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
426    }
427}
428
429/// A segmentation result with a segmentation mask, and a normalized bounding
430/// box representing the area that the segmentation mask covers
431#[derive(Debug, Clone, PartialEq, Default)]
432pub struct Segmentation {
433    /// left-most normalized coordinate of the segmentation box
434    pub xmin: f32,
435    /// top-most normalized coordinate of the segmentation box
436    pub ymin: f32,
437    /// right-most normalized coordinate of the segmentation box
438    pub xmax: f32,
439    /// bottom-most normalized coordinate of the segmentation box
440    pub ymax: f32,
441    /// 3D segmentation array of shape `(H, W, C)`.
442    ///
443    /// For instance segmentation (e.g. YOLO): `C=1` — binary per-instance
444    /// mask where values >= 128 indicate object presence.
445    ///
446    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
447    /// class scores where the object class is the argmax index.
448    pub segmentation: Array3<u8>,
449}
450
451/// Prototype tensor variants for fused decode+render pipelines.
452///
453/// Carries either raw quantized data (to skip CPU dequantization and let the
454/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
455/// paths).
456#[derive(Debug, Clone)]
457pub enum ProtoTensor {
458    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
459    /// The GPU fragment shader will dequantize per-texel using the scale and
460    /// zero_point.
461    Quantized {
462        protos: Array3<i8>,
463        quantization: Quantization,
464    },
465    /// Dequantized f32 protos (from float models or legacy path).
466    Float(Array3<f32>),
467}
468
469impl ProtoTensor {
470    /// Returns `true` if this is the quantized variant.
471    pub fn is_quantized(&self) -> bool {
472        matches!(self, ProtoTensor::Quantized { .. })
473    }
474
475    /// Returns the spatial dimensions `(height, width, num_protos)`.
476    pub fn dim(&self) -> (usize, usize, usize) {
477        match self {
478            ProtoTensor::Quantized { protos, .. } => protos.dim(),
479            ProtoTensor::Float(arr) => arr.dim(),
480        }
481    }
482
483    /// Returns dequantized f32 protos. For the `Float` variant this is a
484    /// no-copy reference; for `Quantized` it allocates and dequantizes.
485    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
486        match self {
487            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
488            ProtoTensor::Quantized {
489                protos,
490                quantization,
491            } => {
492                let scale = quantization.scale;
493                let zp = quantization.zero_point as f32;
494                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
495            }
496        }
497    }
498}
499
500/// Raw prototype data for fused decode+render pipelines.
501///
502/// Holds post-NMS intermediate state before mask materialization, allowing the
503/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
504/// shader) without materializing intermediate `Array3<u8>` masks.
505#[derive(Debug, Clone)]
506pub struct ProtoData {
507    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
508    pub mask_coefficients: Vec<Vec<f32>>,
509    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
510    pub protos: ProtoTensor,
511}
512
513/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
514///
515///  # Examples
516/// ```
517/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
518/// let quant = Quantization::new(0.1, -128);
519/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
520/// let detect_quant = DetectBoxQuantized {
521///     bbox,
522///     score: 100_i8,
523///     label: 1,
524/// };
525/// let detect = dequant_detect_box(&detect_quant, quant);
526/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
527/// assert_eq!(detect.label, 1);
528/// assert_eq!(detect.bbox, bbox);
529/// ```
530pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
531    detect: &DetectBoxQuantized<SCORE>,
532    quant_scores: Quantization,
533) -> DetectBox {
534    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
535    DetectBox {
536        bbox: detect.bbox,
537        score: quant_scores.scale * detect.score.as_() + scaled_zp,
538        label: detect.label,
539    }
540}
541/// A detection box with a f32 bbox and quantized score
542#[derive(Debug, Clone, Copy, PartialEq)]
543pub struct DetectBoxQuantized<
544    // BOX: Signed + PrimInt + AsPrimitive<f32>,
545    SCORE: PrimInt + AsPrimitive<f32>,
546> {
547    // pub bbox: BoundingBoxQuantized<BOX>,
548    pub bbox: BoundingBox,
549    /// model-specific score for this detection, higher implies more
550    /// confidence.
551    pub score: SCORE,
552    /// label index for this detect
553    pub label: usize,
554}
555
556/// Dequantizes an ndarray from quantized values to f32 values using the given
557/// quantization parameters
558///
559/// # Examples
560/// ```
561/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
562/// let quant = Quantization::new(0.1, -128);
563/// let input: Vec<i8> = vec![0, 127, -128, 64];
564/// let input_array = ndarray::Array1::from(input);
565/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
566/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
567/// ```
568pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
569    input: ArrayView<T, D>,
570    quant: Quantization,
571) -> Array<F, D>
572where
573    i32: num_traits::AsPrimitive<F>,
574    f32: num_traits::AsPrimitive<F>,
575{
576    let zero_point = quant.zero_point.as_();
577    let scale = quant.scale.as_();
578    if zero_point != F::zero() {
579        let scaled_zero = -zero_point * scale;
580        input.mapv(|d| d.as_() * scale + scaled_zero)
581    } else {
582        input.mapv(|d| d.as_() * scale)
583    }
584}
585
586/// Dequantizes a slice from quantized values to float values using the given
587/// quantization parameters
588///
589/// # Examples
590/// ```
591/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
592/// let quant = Quantization::new(0.1, -128);
593/// let input: Vec<i8> = vec![0, 127, -128, 64];
594/// let mut output: Vec<f32> = vec![0.0; input.len()];
595/// dequantize_cpu(&input, quant, &mut output);
596/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
597/// ```
598pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
599    input: &[T],
600    quant: Quantization,
601    output: &mut [F],
602) where
603    f32: num_traits::AsPrimitive<F>,
604    i32: num_traits::AsPrimitive<F>,
605{
606    assert!(input.len() == output.len());
607    let zero_point = quant.zero_point.as_();
608    let scale = quant.scale.as_();
609    if zero_point != F::zero() {
610        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
611        input
612            .iter()
613            .zip(output)
614            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
615    } else {
616        input
617            .iter()
618            .zip(output)
619            .for_each(|(d, deq)| *deq = d.as_() * scale);
620    }
621}
622
623/// Dequantizes a slice from quantized values to float values using the given
624/// quantization parameters, using chunked processing. This is around 5% faster
625/// than `dequantize_cpu` for large slices.
626///
627/// # Examples
628/// ```
629/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
630/// let quant = Quantization::new(0.1, -128);
631/// let input: Vec<i8> = vec![0, 127, -128, 64];
632/// let mut output: Vec<f32> = vec![0.0; input.len()];
633/// dequantize_cpu_chunked(&input, quant, &mut output);
634/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
635/// ```
636pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
637    input: &[T],
638    quant: Quantization,
639    output: &mut [F],
640) where
641    f32: num_traits::AsPrimitive<F>,
642    i32: num_traits::AsPrimitive<F>,
643{
644    assert!(input.len() == output.len());
645    let zero_point = quant.zero_point.as_();
646    let scale = quant.scale.as_();
647
648    let input = input.as_chunks::<4>();
649    let output = output.as_chunks_mut::<4>();
650
651    if zero_point != F::zero() {
652        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
653
654        input
655            .0
656            .iter()
657            .zip(output.0)
658            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
659        input
660            .1
661            .iter()
662            .zip(output.1)
663            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
664    } else {
665        input
666            .0
667            .iter()
668            .zip(output.0)
669            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
670        input
671            .1
672            .iter()
673            .zip(output.1)
674            .for_each(|(d, deq)| *deq = d.as_() * scale);
675    }
676}
677
678/// Converts a segmentation tensor into a 2D mask
679/// If the last dimension of the segmentation tensor is 1, values equal or
680/// above 128 are considered objects. Otherwise the object is the argmax index
681///
682/// # Errors
683///
684/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
685/// invalid shape.
686///
687/// # Examples
688/// ```
689/// # use edgefirst_decoder::segmentation_to_mask;
690/// let segmentation =
691///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
692/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
693/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
694/// ```
695pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
696    if segmentation.shape()[2] == 0 {
697        return Err(DecoderError::InvalidShape(
698            "Segmentation tensor must have non-zero depth".to_string(),
699        ));
700    }
701    if segmentation.shape()[2] == 1 {
702        yolo_segmentation_to_mask(segmentation, 128)
703    } else {
704        Ok(modelpack_segmentation_to_mask(segmentation))
705    }
706}
707
708/// Returns the maximum value and its index from a 1D array
709fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
710    score
711        .iter()
712        .enumerate()
713        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
714            if max > *s {
715                (max, arg_max)
716            } else {
717                (*s, ind)
718            }
719        })
720}
721#[cfg(test)]
722#[cfg_attr(coverage_nightly, coverage(off))]
723mod decoder_tests {
724    #![allow(clippy::excessive_precision)]
725    use crate::{
726        configs::{DecoderType, DimName, Protos},
727        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
728        yolo::{
729            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
730            decode_yolo_segdet_quant,
731        },
732        *,
733    };
734    use ndarray::{array, s, Array4};
735    use ndarray_stats::DeviationExt;
736
737    fn compare_outputs(
738        boxes: (&[DetectBox], &[DetectBox]),
739        masks: (&[Segmentation], &[Segmentation]),
740    ) {
741        let (boxes0, boxes1) = boxes;
742        let (masks0, masks1) = masks;
743
744        assert_eq!(boxes0.len(), boxes1.len());
745        assert_eq!(masks0.len(), masks1.len());
746
747        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
748            assert!(
749                b_i8.equal_within_delta(b_f32, 1e-6),
750                "{b_i8:?} is not equal to {b_f32:?}"
751            );
752        }
753
754        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
755            assert_eq!(
756                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
757                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
758            );
759            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
760            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
761            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
762            let diff = &mask_i8 - &mask_f32;
763            for x in 0..diff.shape()[0] {
764                for y in 0..diff.shape()[1] {
765                    for z in 0..diff.shape()[2] {
766                        let val = diff[[x, y, z]];
767                        assert!(
768                            val.abs() <= 1,
769                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
770                            x,
771                            y,
772                            z,
773                            val
774                        );
775                    }
776                }
777            }
778            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
779            assert!(
780                mean_sq_err < 1e-2,
781                "Mean Square Error between masks was greater than 1%: {:.2}%",
782                mean_sq_err * 100.0
783            );
784        }
785    }
786
787    #[test]
788    fn test_decoder_modelpack() {
789        let score_threshold = 0.45;
790        let iou_threshold = 0.45;
791        let boxes = include_bytes!(concat!(
792            env!("CARGO_MANIFEST_DIR"),
793            "/../../testdata/modelpack_boxes_1935x1x4.bin"
794        ));
795        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
796
797        let scores = include_bytes!(concat!(
798            env!("CARGO_MANIFEST_DIR"),
799            "/../../testdata/modelpack_scores_1935x1.bin"
800        ));
801        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
802
803        let quant_boxes = (0.004656755365431309, 21).into();
804        let quant_scores = (0.0019603664986789227, 0).into();
805
806        let decoder = DecoderBuilder::default()
807            .with_config_modelpack_det(
808                configs::Boxes {
809                    decoder: DecoderType::ModelPack,
810                    quantization: Some(quant_boxes),
811                    shape: vec![1, 1935, 1, 4],
812                    dshape: vec![
813                        (DimName::Batch, 1),
814                        (DimName::NumBoxes, 1935),
815                        (DimName::Padding, 1),
816                        (DimName::BoxCoords, 4),
817                    ],
818                    normalized: Some(true),
819                },
820                configs::Scores {
821                    decoder: DecoderType::ModelPack,
822                    quantization: Some(quant_scores),
823                    shape: vec![1, 1935, 1],
824                    dshape: vec![
825                        (DimName::Batch, 1),
826                        (DimName::NumBoxes, 1935),
827                        (DimName::NumClasses, 1),
828                    ],
829                },
830            )
831            .with_score_threshold(score_threshold)
832            .with_iou_threshold(iou_threshold)
833            .build()
834            .unwrap();
835
836        let quant_boxes = quant_boxes.into();
837        let quant_scores = quant_scores.into();
838
839        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
840        decode_modelpack_det(
841            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
842            (scores.slice(s![0, .., ..]), quant_scores),
843            score_threshold,
844            iou_threshold,
845            &mut output_boxes,
846        );
847        assert!(output_boxes[0].equal_within_delta(
848            &DetectBox {
849                bbox: BoundingBox {
850                    xmin: 0.40513772,
851                    ymin: 0.6379755,
852                    xmax: 0.5122431,
853                    ymax: 0.7730214,
854                },
855                score: 0.4861709,
856                label: 0
857            },
858            1e-6
859        ));
860
861        let mut output_boxes1 = Vec::with_capacity(50);
862        let mut output_masks1 = Vec::with_capacity(50);
863
864        decoder
865            .decode_quantized(
866                &[boxes.view().into(), scores.view().into()],
867                &mut output_boxes1,
868                &mut output_masks1,
869            )
870            .unwrap();
871
872        let mut output_boxes_float = Vec::with_capacity(50);
873        let mut output_masks_float = Vec::with_capacity(50);
874
875        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
876        let scores = dequantize_ndarray(scores.view(), quant_scores);
877
878        decoder
879            .decode_float::<f32>(
880                &[boxes.view().into_dyn(), scores.view().into_dyn()],
881                &mut output_boxes_float,
882                &mut output_masks_float,
883            )
884            .unwrap();
885
886        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
887        compare_outputs(
888            (&output_boxes, &output_boxes_float),
889            (&[], &output_masks_float),
890        );
891    }
892
893    #[test]
894    fn test_decoder_modelpack_split_u8() {
895        let score_threshold = 0.45;
896        let iou_threshold = 0.45;
897        let detect0 = include_bytes!(concat!(
898            env!("CARGO_MANIFEST_DIR"),
899            "/../../testdata/modelpack_split_9x15x18.bin"
900        ));
901        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
902
903        let detect1 = include_bytes!(concat!(
904            env!("CARGO_MANIFEST_DIR"),
905            "/../../testdata/modelpack_split_17x30x18.bin"
906        ));
907        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
908
909        let quant0 = (0.08547406643629074, 174).into();
910        let quant1 = (0.09929127991199493, 183).into();
911        let anchors0 = vec![
912            [0.36666667461395264, 0.31481480598449707],
913            [0.38749998807907104, 0.4740740656852722],
914            [0.5333333611488342, 0.644444465637207],
915        ];
916        let anchors1 = vec![
917            [0.13750000298023224, 0.2074074000120163],
918            [0.2541666626930237, 0.21481481194496155],
919            [0.23125000298023224, 0.35185185074806213],
920        ];
921
922        let detect_config0 = configs::Detection {
923            decoder: DecoderType::ModelPack,
924            shape: vec![1, 9, 15, 18],
925            anchors: Some(anchors0.clone()),
926            quantization: Some(quant0),
927            dshape: vec![
928                (DimName::Batch, 1),
929                (DimName::Height, 9),
930                (DimName::Width, 15),
931                (DimName::NumAnchorsXFeatures, 18),
932            ],
933            normalized: Some(true),
934        };
935
936        let detect_config1 = configs::Detection {
937            decoder: DecoderType::ModelPack,
938            shape: vec![1, 17, 30, 18],
939            anchors: Some(anchors1.clone()),
940            quantization: Some(quant1),
941            dshape: vec![
942                (DimName::Batch, 1),
943                (DimName::Height, 17),
944                (DimName::Width, 30),
945                (DimName::NumAnchorsXFeatures, 18),
946            ],
947            normalized: Some(true),
948        };
949
950        let config0 = (&detect_config0).try_into().unwrap();
951        let config1 = (&detect_config1).try_into().unwrap();
952
953        let decoder = DecoderBuilder::default()
954            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
955            .with_score_threshold(score_threshold)
956            .with_iou_threshold(iou_threshold)
957            .build()
958            .unwrap();
959
960        let quant0 = quant0.into();
961        let quant1 = quant1.into();
962
963        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
964        decode_modelpack_split_quant(
965            &[
966                detect0.slice(s![0, .., .., ..]),
967                detect1.slice(s![0, .., .., ..]),
968            ],
969            &[config0, config1],
970            score_threshold,
971            iou_threshold,
972            &mut output_boxes,
973        );
974        assert!(output_boxes[0].equal_within_delta(
975            &DetectBox {
976                bbox: BoundingBox {
977                    xmin: 0.43171933,
978                    ymin: 0.68243736,
979                    xmax: 0.5626645,
980                    ymax: 0.808863,
981                },
982                score: 0.99240804,
983                label: 0
984            },
985            1e-6
986        ));
987
988        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
989        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
990        decoder
991            .decode_quantized(
992                &[detect0.view().into(), detect1.view().into()],
993                &mut output_boxes1,
994                &mut output_masks1,
995            )
996            .unwrap();
997
998        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
999        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1000
1001        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1002        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1003        decoder
1004            .decode_float::<f32>(
1005                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1006                &mut output_boxes1_f32,
1007                &mut output_masks1_f32,
1008            )
1009            .unwrap();
1010
1011        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1012        compare_outputs(
1013            (&output_boxes, &output_boxes1_f32),
1014            (&[], &output_masks1_f32),
1015        );
1016    }
1017
1018    #[test]
1019    fn test_decoder_parse_config_modelpack_split_u8() {
1020        let score_threshold = 0.45;
1021        let iou_threshold = 0.45;
1022        let detect0 = include_bytes!(concat!(
1023            env!("CARGO_MANIFEST_DIR"),
1024            "/../../testdata/modelpack_split_9x15x18.bin"
1025        ));
1026        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1027
1028        let detect1 = include_bytes!(concat!(
1029            env!("CARGO_MANIFEST_DIR"),
1030            "/../../testdata/modelpack_split_17x30x18.bin"
1031        ));
1032        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1033
1034        let decoder = DecoderBuilder::default()
1035            .with_config_yaml_str(
1036                include_str!(concat!(
1037                    env!("CARGO_MANIFEST_DIR"),
1038                    "/../../testdata/modelpack_split.yaml"
1039                ))
1040                .to_string(),
1041            )
1042            .with_score_threshold(score_threshold)
1043            .with_iou_threshold(iou_threshold)
1044            .build()
1045            .unwrap();
1046
1047        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1048        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1049        decoder
1050            .decode_quantized(
1051                &[
1052                    ArrayViewDQuantized::from(detect1.view()),
1053                    ArrayViewDQuantized::from(detect0.view()),
1054                ],
1055                &mut output_boxes,
1056                &mut output_masks,
1057            )
1058            .unwrap();
1059        assert!(output_boxes[0].equal_within_delta(
1060            &DetectBox {
1061                bbox: BoundingBox {
1062                    xmin: 0.43171933,
1063                    ymin: 0.68243736,
1064                    xmax: 0.5626645,
1065                    ymax: 0.808863,
1066                },
1067                score: 0.99240804,
1068                label: 0
1069            },
1070            1e-6
1071        ));
1072    }
1073
1074    #[test]
1075    fn test_modelpack_seg() {
1076        let out = include_bytes!(concat!(
1077            env!("CARGO_MANIFEST_DIR"),
1078            "/../../testdata/modelpack_seg_2x160x160.bin"
1079        ));
1080        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1081        let quant = (1.0 / 255.0, 0).into();
1082
1083        let decoder = DecoderBuilder::default()
1084            .with_config_modelpack_seg(configs::Segmentation {
1085                decoder: DecoderType::ModelPack,
1086                quantization: Some(quant),
1087                shape: vec![1, 2, 160, 160],
1088                dshape: vec![
1089                    (DimName::Batch, 1),
1090                    (DimName::NumClasses, 2),
1091                    (DimName::Height, 160),
1092                    (DimName::Width, 160),
1093                ],
1094            })
1095            .build()
1096            .unwrap();
1097        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1098        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1099        decoder
1100            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1101            .unwrap();
1102
1103        let mut mask = out.slice(s![0, .., .., ..]);
1104        mask.swap_axes(0, 1);
1105        mask.swap_axes(1, 2);
1106        let mask = [Segmentation {
1107            xmin: 0.0,
1108            ymin: 0.0,
1109            xmax: 1.0,
1110            ymax: 1.0,
1111            segmentation: mask.into_owned(),
1112        }];
1113        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1114
1115        decoder
1116            .decode_float::<f32>(
1117                &[dequantize_ndarray(out.view(), quant.into())
1118                    .view()
1119                    .into_dyn()],
1120                &mut output_boxes,
1121                &mut output_masks,
1122            )
1123            .unwrap();
1124
1125        // not expected for float decoder to have same values as quantized decoder, as
1126        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1127        // the model output. Thus the float output is the same as the quantized output
1128        // but scaled differently. However, it is expected that the mask after argmax
1129        // will be the same.
1130        compare_outputs((&[], &output_boxes), (&[], &[]));
1131        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1132        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1133
1134        assert_eq!(mask0, mask1);
1135    }
1136    #[test]
1137    fn test_modelpack_seg_quant() {
1138        let out = include_bytes!(concat!(
1139            env!("CARGO_MANIFEST_DIR"),
1140            "/../../testdata/modelpack_seg_2x160x160.bin"
1141        ));
1142        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1143        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1144        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1145        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1146        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1147        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1148
1149        let quant = (1.0 / 255.0, 0).into();
1150
1151        let decoder = DecoderBuilder::default()
1152            .with_config_modelpack_seg(configs::Segmentation {
1153                decoder: DecoderType::ModelPack,
1154                quantization: Some(quant),
1155                shape: vec![1, 2, 160, 160],
1156                dshape: vec![
1157                    (DimName::Batch, 1),
1158                    (DimName::NumClasses, 2),
1159                    (DimName::Height, 160),
1160                    (DimName::Width, 160),
1161                ],
1162            })
1163            .build()
1164            .unwrap();
1165        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1166        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1167        decoder
1168            .decode_quantized(
1169                &[out_u8.view().into()],
1170                &mut output_boxes,
1171                &mut output_masks_u8,
1172            )
1173            .unwrap();
1174
1175        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1176        decoder
1177            .decode_quantized(
1178                &[out_i8.view().into()],
1179                &mut output_boxes,
1180                &mut output_masks_i8,
1181            )
1182            .unwrap();
1183
1184        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1185        decoder
1186            .decode_quantized(
1187                &[out_u16.view().into()],
1188                &mut output_boxes,
1189                &mut output_masks_u16,
1190            )
1191            .unwrap();
1192
1193        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1194        decoder
1195            .decode_quantized(
1196                &[out_i16.view().into()],
1197                &mut output_boxes,
1198                &mut output_masks_i16,
1199            )
1200            .unwrap();
1201
1202        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1203        decoder
1204            .decode_quantized(
1205                &[out_u32.view().into()],
1206                &mut output_boxes,
1207                &mut output_masks_u32,
1208            )
1209            .unwrap();
1210
1211        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1212        decoder
1213            .decode_quantized(
1214                &[out_i32.view().into()],
1215                &mut output_boxes,
1216                &mut output_masks_i32,
1217            )
1218            .unwrap();
1219
1220        compare_outputs((&[], &output_boxes), (&[], &[]));
1221        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1222        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1223        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1224        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1225        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1226        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1227        assert_eq!(mask_u8, mask_i8);
1228        assert_eq!(mask_u8, mask_u16);
1229        assert_eq!(mask_u8, mask_i16);
1230        assert_eq!(mask_u8, mask_u32);
1231        assert_eq!(mask_u8, mask_i32);
1232    }
1233
1234    #[test]
1235    fn test_modelpack_segdet() {
1236        let score_threshold = 0.45;
1237        let iou_threshold = 0.45;
1238
1239        let boxes = include_bytes!(concat!(
1240            env!("CARGO_MANIFEST_DIR"),
1241            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1242        ));
1243        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1244
1245        let scores = include_bytes!(concat!(
1246            env!("CARGO_MANIFEST_DIR"),
1247            "/../../testdata/modelpack_scores_1935x1.bin"
1248        ));
1249        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1250
1251        let seg = include_bytes!(concat!(
1252            env!("CARGO_MANIFEST_DIR"),
1253            "/../../testdata/modelpack_seg_2x160x160.bin"
1254        ));
1255        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1256
1257        let quant_boxes = (0.004656755365431309, 21).into();
1258        let quant_scores = (0.0019603664986789227, 0).into();
1259        let quant_seg = (1.0 / 255.0, 0).into();
1260
1261        let decoder = DecoderBuilder::default()
1262            .with_config_modelpack_segdet(
1263                configs::Boxes {
1264                    decoder: DecoderType::ModelPack,
1265                    quantization: Some(quant_boxes),
1266                    shape: vec![1, 1935, 1, 4],
1267                    dshape: vec![
1268                        (DimName::Batch, 1),
1269                        (DimName::NumBoxes, 1935),
1270                        (DimName::Padding, 1),
1271                        (DimName::BoxCoords, 4),
1272                    ],
1273                    normalized: Some(true),
1274                },
1275                configs::Scores {
1276                    decoder: DecoderType::ModelPack,
1277                    quantization: Some(quant_scores),
1278                    shape: vec![1, 1935, 1],
1279                    dshape: vec![
1280                        (DimName::Batch, 1),
1281                        (DimName::NumBoxes, 1935),
1282                        (DimName::NumClasses, 1),
1283                    ],
1284                },
1285                configs::Segmentation {
1286                    decoder: DecoderType::ModelPack,
1287                    quantization: Some(quant_seg),
1288                    shape: vec![1, 2, 160, 160],
1289                    dshape: vec![
1290                        (DimName::Batch, 1),
1291                        (DimName::NumClasses, 2),
1292                        (DimName::Height, 160),
1293                        (DimName::Width, 160),
1294                    ],
1295                },
1296            )
1297            .with_iou_threshold(iou_threshold)
1298            .with_score_threshold(score_threshold)
1299            .build()
1300            .unwrap();
1301        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1302        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1303        decoder
1304            .decode_quantized(
1305                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1306                &mut output_boxes,
1307                &mut output_masks,
1308            )
1309            .unwrap();
1310
1311        let mut mask = seg.slice(s![0, .., .., ..]);
1312        mask.swap_axes(0, 1);
1313        mask.swap_axes(1, 2);
1314        let mask = [Segmentation {
1315            xmin: 0.0,
1316            ymin: 0.0,
1317            xmax: 1.0,
1318            ymax: 1.0,
1319            segmentation: mask.into_owned(),
1320        }];
1321        let correct_boxes = [DetectBox {
1322            bbox: BoundingBox {
1323                xmin: 0.40513772,
1324                ymin: 0.6379755,
1325                xmax: 0.5122431,
1326                ymax: 0.7730214,
1327            },
1328            score: 0.4861709,
1329            label: 0,
1330        }];
1331        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1332
1333        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1334        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1335        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1336        decoder
1337            .decode_float::<f32>(
1338                &[
1339                    scores.view().into_dyn(),
1340                    boxes.view().into_dyn(),
1341                    seg.view().into_dyn(),
1342                ],
1343                &mut output_boxes,
1344                &mut output_masks,
1345            )
1346            .unwrap();
1347
1348        // not expected for float segmentation decoder to have same values as quantized
1349        // segmentation decoder, as float decoder ensures the data fills 0-255,
1350        // quantized decoder uses whatever the model output. Thus the float
1351        // output is the same as the quantized output but scaled differently.
1352        // However, it is expected that the mask after argmax will be the same.
1353        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1354        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1355        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1356
1357        assert_eq!(mask0, mask1);
1358    }
1359
1360    #[test]
1361    fn test_modelpack_segdet_split() {
1362        let score_threshold = 0.8;
1363        let iou_threshold = 0.5;
1364
1365        let seg = include_bytes!(concat!(
1366            env!("CARGO_MANIFEST_DIR"),
1367            "/../../testdata/modelpack_seg_2x160x160.bin"
1368        ));
1369        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1370
1371        let detect0 = include_bytes!(concat!(
1372            env!("CARGO_MANIFEST_DIR"),
1373            "/../../testdata/modelpack_split_9x15x18.bin"
1374        ));
1375        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1376
1377        let detect1 = include_bytes!(concat!(
1378            env!("CARGO_MANIFEST_DIR"),
1379            "/../../testdata/modelpack_split_17x30x18.bin"
1380        ));
1381        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1382
1383        let quant0 = (0.08547406643629074, 174).into();
1384        let quant1 = (0.09929127991199493, 183).into();
1385        let quant_seg = (1.0 / 255.0, 0).into();
1386
1387        let anchors0 = vec![
1388            [0.36666667461395264, 0.31481480598449707],
1389            [0.38749998807907104, 0.4740740656852722],
1390            [0.5333333611488342, 0.644444465637207],
1391        ];
1392        let anchors1 = vec![
1393            [0.13750000298023224, 0.2074074000120163],
1394            [0.2541666626930237, 0.21481481194496155],
1395            [0.23125000298023224, 0.35185185074806213],
1396        ];
1397
1398        let decoder = DecoderBuilder::default()
1399            .with_config_modelpack_segdet_split(
1400                vec![
1401                    configs::Detection {
1402                        decoder: DecoderType::ModelPack,
1403                        shape: vec![1, 17, 30, 18],
1404                        anchors: Some(anchors1),
1405                        quantization: Some(quant1),
1406                        dshape: vec![
1407                            (DimName::Batch, 1),
1408                            (DimName::Height, 17),
1409                            (DimName::Width, 30),
1410                            (DimName::NumAnchorsXFeatures, 18),
1411                        ],
1412                        normalized: Some(true),
1413                    },
1414                    configs::Detection {
1415                        decoder: DecoderType::ModelPack,
1416                        shape: vec![1, 9, 15, 18],
1417                        anchors: Some(anchors0),
1418                        quantization: Some(quant0),
1419                        dshape: vec![
1420                            (DimName::Batch, 1),
1421                            (DimName::Height, 9),
1422                            (DimName::Width, 15),
1423                            (DimName::NumAnchorsXFeatures, 18),
1424                        ],
1425                        normalized: Some(true),
1426                    },
1427                ],
1428                configs::Segmentation {
1429                    decoder: DecoderType::ModelPack,
1430                    quantization: Some(quant_seg),
1431                    shape: vec![1, 2, 160, 160],
1432                    dshape: vec![
1433                        (DimName::Batch, 1),
1434                        (DimName::NumClasses, 2),
1435                        (DimName::Height, 160),
1436                        (DimName::Width, 160),
1437                    ],
1438                },
1439            )
1440            .with_score_threshold(score_threshold)
1441            .with_iou_threshold(iou_threshold)
1442            .build()
1443            .unwrap();
1444        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1445        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1446        decoder
1447            .decode_quantized(
1448                &[
1449                    detect0.view().into(),
1450                    detect1.view().into(),
1451                    seg.view().into(),
1452                ],
1453                &mut output_boxes,
1454                &mut output_masks,
1455            )
1456            .unwrap();
1457
1458        let mut mask = seg.slice(s![0, .., .., ..]);
1459        mask.swap_axes(0, 1);
1460        mask.swap_axes(1, 2);
1461        let mask = [Segmentation {
1462            xmin: 0.0,
1463            ymin: 0.0,
1464            xmax: 1.0,
1465            ymax: 1.0,
1466            segmentation: mask.into_owned(),
1467        }];
1468        let correct_boxes = [DetectBox {
1469            bbox: BoundingBox {
1470                xmin: 0.43171933,
1471                ymin: 0.68243736,
1472                xmax: 0.5626645,
1473                ymax: 0.808863,
1474            },
1475            score: 0.99240804,
1476            label: 0,
1477        }];
1478        println!("Output Boxes: {:?}", output_boxes);
1479        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1480
1481        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1482        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1483        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1484        decoder
1485            .decode_float::<f32>(
1486                &[
1487                    detect0.view().into_dyn(),
1488                    detect1.view().into_dyn(),
1489                    seg.view().into_dyn(),
1490                ],
1491                &mut output_boxes,
1492                &mut output_masks,
1493            )
1494            .unwrap();
1495
1496        // not expected for float segmentation decoder to have same values as quantized
1497        // segmentation decoder, as float decoder ensures the data fills 0-255,
1498        // quantized decoder uses whatever the model output. Thus the float
1499        // output is the same as the quantized output but scaled differently.
1500        // However, it is expected that the mask after argmax will be the same.
1501        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1502        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1503        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1504
1505        assert_eq!(mask0, mask1);
1506    }
1507
1508    #[test]
1509    fn test_dequant_chunked() {
1510        let out = include_bytes!(concat!(
1511            env!("CARGO_MANIFEST_DIR"),
1512            "/../../testdata/yolov8s_80_classes.bin"
1513        ));
1514        let mut out =
1515            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1516        out.push(123); // make sure to test non multiple of 16 length
1517
1518        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1519        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1520        let quant = Quantization::new(0.0040811873, -123);
1521        dequantize_cpu(&out, quant, &mut out_dequant);
1522
1523        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1524        assert_eq!(out_dequant, out_dequant_simd);
1525
1526        let quant = Quantization::new(0.0040811873, 0);
1527        dequantize_cpu(&out, quant, &mut out_dequant);
1528
1529        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1530        assert_eq!(out_dequant, out_dequant_simd);
1531    }
1532
1533    #[test]
1534    fn test_decoder_yolo_det() {
1535        let score_threshold = 0.25;
1536        let iou_threshold = 0.7;
1537        let out = include_bytes!(concat!(
1538            env!("CARGO_MANIFEST_DIR"),
1539            "/../../testdata/yolov8s_80_classes.bin"
1540        ));
1541        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1542        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1543        let quant = (0.0040811873, -123).into();
1544
1545        let decoder = DecoderBuilder::default()
1546            .with_config_yolo_det(
1547                configs::Detection {
1548                    decoder: DecoderType::Ultralytics,
1549                    shape: vec![1, 84, 8400],
1550                    anchors: None,
1551                    quantization: Some(quant),
1552                    dshape: vec![
1553                        (DimName::Batch, 1),
1554                        (DimName::NumFeatures, 84),
1555                        (DimName::NumBoxes, 8400),
1556                    ],
1557                    normalized: Some(true),
1558                },
1559                Some(DecoderVersion::Yolo11),
1560            )
1561            .with_score_threshold(score_threshold)
1562            .with_iou_threshold(iou_threshold)
1563            .build()
1564            .unwrap();
1565
1566        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1567        decode_yolo_det(
1568            (out.slice(s![0, .., ..]), quant.into()),
1569            score_threshold,
1570            iou_threshold,
1571            Some(configs::Nms::ClassAgnostic),
1572            &mut output_boxes,
1573        );
1574        assert!(output_boxes[0].equal_within_delta(
1575            &DetectBox {
1576                bbox: BoundingBox {
1577                    xmin: 0.5285137,
1578                    ymin: 0.05305544,
1579                    xmax: 0.87541467,
1580                    ymax: 0.9998909,
1581                },
1582                score: 0.5591227,
1583                label: 0
1584            },
1585            1e-6
1586        ));
1587
1588        assert!(output_boxes[1].equal_within_delta(
1589            &DetectBox {
1590                bbox: BoundingBox {
1591                    xmin: 0.130598,
1592                    ymin: 0.43260583,
1593                    xmax: 0.35098213,
1594                    ymax: 0.9958097,
1595                },
1596                score: 0.33057618,
1597                label: 75
1598            },
1599            1e-6
1600        ));
1601
1602        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1603        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1604        decoder
1605            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1606            .unwrap();
1607
1608        let out = dequantize_ndarray(out.view(), quant.into());
1609        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1610        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1611        decoder
1612            .decode_float::<f32>(
1613                &[out.view().into_dyn()],
1614                &mut output_boxes_f32,
1615                &mut output_masks_f32,
1616            )
1617            .unwrap();
1618
1619        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1620        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1621    }
1622
1623    #[test]
1624    fn test_decoder_masks() {
1625        let score_threshold = 0.45;
1626        let iou_threshold = 0.45;
1627        let boxes = include_bytes!(concat!(
1628            env!("CARGO_MANIFEST_DIR"),
1629            "/../../testdata/yolov8_boxes_116x8400.bin"
1630        ));
1631        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1632        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1633        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1634
1635        let protos = include_bytes!(concat!(
1636            env!("CARGO_MANIFEST_DIR"),
1637            "/../../testdata/yolov8_protos_160x160x32.bin"
1638        ));
1639        let protos =
1640            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1641        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1642        let quant_protos = Quantization::new(0.02491161972284317, -117);
1643        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1644        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1645        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1646        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1647        decode_yolo_segdet_float(
1648            seg.view(),
1649            protos.view(),
1650            score_threshold,
1651            iou_threshold,
1652            Some(configs::Nms::ClassAgnostic),
1653            &mut output_boxes,
1654            &mut output_masks,
1655        )
1656        .unwrap();
1657        assert_eq!(output_boxes.len(), 2);
1658        assert_eq!(output_boxes.len(), output_masks.len());
1659
1660        for (b, m) in output_boxes.iter().zip(&output_masks) {
1661            assert!(b.bbox.xmin >= m.xmin);
1662            assert!(b.bbox.ymin >= m.ymin);
1663            assert!(b.bbox.xmax >= m.xmax);
1664            assert!(b.bbox.ymax >= m.ymax);
1665        }
1666        assert!(output_boxes[0].equal_within_delta(
1667            &DetectBox {
1668                bbox: BoundingBox {
1669                    xmin: 0.08515105,
1670                    ymin: 0.7131401,
1671                    xmax: 0.29802868,
1672                    ymax: 0.8195788,
1673                },
1674                score: 0.91537374,
1675                label: 23
1676            },
1677            1.0 / 160.0, // wider range because mask will expand the box
1678        ));
1679
1680        assert!(output_boxes[1].equal_within_delta(
1681            &DetectBox {
1682                bbox: BoundingBox {
1683                    xmin: 0.59605736,
1684                    ymin: 0.25545314,
1685                    xmax: 0.93666154,
1686                    ymax: 0.72378385,
1687                },
1688                score: 0.91537374,
1689                label: 23
1690            },
1691            1.0 / 160.0, // wider range because mask will expand the box
1692        ));
1693
1694        let full_mask = include_bytes!(concat!(
1695            env!("CARGO_MANIFEST_DIR"),
1696            "/../../testdata/yolov8_mask_results.bin"
1697        ));
1698        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1699
1700        let cropped_mask = full_mask.slice(ndarray::s![
1701            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1702            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1703        ]);
1704
1705        assert_eq!(
1706            cropped_mask,
1707            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1708        );
1709    }
1710
1711    /// Regression test: config-driven path with NCHW protos (no dshape).
1712    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1713    /// and the YAML config has no dshape field — the exact scenario from
1714    /// hal_mask_matmul_bug.md.
1715    #[test]
1716    fn test_decoder_masks_nchw_protos() {
1717        let score_threshold = 0.45;
1718        let iou_threshold = 0.45;
1719
1720        // Load test data — boxes as [116, 8400]
1721        let boxes_raw = include_bytes!(concat!(
1722            env!("CARGO_MANIFEST_DIR"),
1723            "/../../testdata/yolov8_boxes_116x8400.bin"
1724        ));
1725        let boxes_raw =
1726            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1727        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1728        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1729
1730        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1731        let protos_raw = include_bytes!(concat!(
1732            env!("CARGO_MANIFEST_DIR"),
1733            "/../../testdata/yolov8_protos_160x160x32.bin"
1734        ));
1735        let protos_raw = unsafe {
1736            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1737        };
1738        let protos_hwc =
1739            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1740        let quant_protos = Quantization::new(0.02491161972284317, -117);
1741        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1742
1743        // ---- Reference: direct call with HWC protos (known working) ----
1744        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1745        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1746        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1747        decode_yolo_segdet_float(
1748            seg.view(),
1749            protos_f32_hwc.view(),
1750            score_threshold,
1751            iou_threshold,
1752            Some(configs::Nms::ClassAgnostic),
1753            &mut ref_boxes,
1754            &mut ref_masks,
1755        )
1756        .unwrap();
1757        assert_eq!(ref_boxes.len(), 2);
1758
1759        // ---- Config-driven path: NCHW protos, no dshape ----
1760        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1761        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1762        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1763
1764        // Build boxes as [1, 116, 8400] f32
1765        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1766
1767        // Build decoder from config with no dshape on protos
1768        let decoder = DecoderBuilder::default()
1769            .with_config_yolo_segdet(
1770                configs::Detection {
1771                    decoder: configs::DecoderType::Ultralytics,
1772                    quantization: None,
1773                    shape: vec![1, 116, 8400],
1774                    dshape: vec![],
1775                    normalized: Some(true),
1776                    anchors: None,
1777                },
1778                configs::Protos {
1779                    decoder: configs::DecoderType::Ultralytics,
1780                    quantization: None,
1781                    shape: vec![1, 32, 160, 160],
1782                    dshape: vec![], // No dshape — simulates YAML without dshape
1783                },
1784                None, // decoder version
1785            )
1786            .with_score_threshold(score_threshold)
1787            .with_iou_threshold(iou_threshold)
1788            .build()
1789            .unwrap();
1790
1791        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1792        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1793        decoder
1794            .decode_float(
1795                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1796                &mut cfg_boxes,
1797                &mut cfg_masks,
1798            )
1799            .unwrap();
1800
1801        // Must produce the same number of detections
1802        assert_eq!(
1803            cfg_boxes.len(),
1804            ref_boxes.len(),
1805            "config path produced {} boxes, reference produced {}",
1806            cfg_boxes.len(),
1807            ref_boxes.len()
1808        );
1809
1810        // Boxes must match
1811        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1812            assert!(
1813                cb.equal_within_delta(rb, 0.01),
1814                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1815            );
1816        }
1817
1818        // Masks must match pixel-for-pixel
1819        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1820            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1821            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1822            assert_eq!(
1823                cm_arr, rm_arr,
1824                "mask {i} pixel mismatch between config-driven and reference paths"
1825            );
1826        }
1827    }
1828
1829    #[test]
1830    fn test_decoder_masks_i8() {
1831        let score_threshold = 0.45;
1832        let iou_threshold = 0.45;
1833        let boxes = include_bytes!(concat!(
1834            env!("CARGO_MANIFEST_DIR"),
1835            "/../../testdata/yolov8_boxes_116x8400.bin"
1836        ));
1837        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1838        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1839        let quant_boxes = (0.021287761628627777, 31).into();
1840
1841        let protos = include_bytes!(concat!(
1842            env!("CARGO_MANIFEST_DIR"),
1843            "/../../testdata/yolov8_protos_160x160x32.bin"
1844        ));
1845        let protos =
1846            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1847        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1848        let quant_protos = (0.02491161972284317, -117).into();
1849        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1850        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1851
1852        let decoder = DecoderBuilder::default()
1853            .with_config_yolo_segdet(
1854                configs::Detection {
1855                    decoder: configs::DecoderType::Ultralytics,
1856                    quantization: Some(quant_boxes),
1857                    shape: vec![1, 116, 8400],
1858                    anchors: None,
1859                    dshape: vec![
1860                        (DimName::Batch, 1),
1861                        (DimName::NumFeatures, 116),
1862                        (DimName::NumBoxes, 8400),
1863                    ],
1864                    normalized: Some(true),
1865                },
1866                Protos {
1867                    decoder: configs::DecoderType::Ultralytics,
1868                    quantization: Some(quant_protos),
1869                    shape: vec![1, 160, 160, 32],
1870                    dshape: vec![
1871                        (DimName::Batch, 1),
1872                        (DimName::Height, 160),
1873                        (DimName::Width, 160),
1874                        (DimName::NumProtos, 32),
1875                    ],
1876                },
1877                Some(DecoderVersion::Yolo11),
1878            )
1879            .with_score_threshold(score_threshold)
1880            .with_iou_threshold(iou_threshold)
1881            .build()
1882            .unwrap();
1883
1884        let quant_boxes = quant_boxes.into();
1885        let quant_protos = quant_protos.into();
1886
1887        decode_yolo_segdet_quant(
1888            (boxes.slice(s![0, .., ..]), quant_boxes),
1889            (protos.slice(s![0, .., .., ..]), quant_protos),
1890            score_threshold,
1891            iou_threshold,
1892            Some(configs::Nms::ClassAgnostic),
1893            &mut output_boxes,
1894            &mut output_masks,
1895        )
1896        .unwrap();
1897
1898        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1899        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1900
1901        decoder
1902            .decode_quantized(
1903                &[boxes.view().into(), protos.view().into()],
1904                &mut output_boxes1,
1905                &mut output_masks1,
1906            )
1907            .unwrap();
1908
1909        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1910        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1911
1912        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1913        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1914        decode_yolo_segdet_float(
1915            seg.slice(s![0, .., ..]),
1916            protos.slice(s![0, .., .., ..]),
1917            score_threshold,
1918            iou_threshold,
1919            Some(configs::Nms::ClassAgnostic),
1920            &mut output_boxes_f32,
1921            &mut output_masks_f32,
1922        )
1923        .unwrap();
1924
1925        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1926        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1927
1928        decoder
1929            .decode_float(
1930                &[seg.view().into_dyn(), protos.view().into_dyn()],
1931                &mut output_boxes1_f32,
1932                &mut output_masks1_f32,
1933            )
1934            .unwrap();
1935
1936        compare_outputs(
1937            (&output_boxes, &output_boxes1),
1938            (&output_masks, &output_masks1),
1939        );
1940
1941        compare_outputs(
1942            (&output_boxes, &output_boxes_f32),
1943            (&output_masks, &output_masks_f32),
1944        );
1945
1946        compare_outputs(
1947            (&output_boxes_f32, &output_boxes1_f32),
1948            (&output_masks_f32, &output_masks1_f32),
1949        );
1950    }
1951
1952    #[test]
1953    fn test_decoder_yolo_split() {
1954        let score_threshold = 0.45;
1955        let iou_threshold = 0.45;
1956        let boxes = include_bytes!(concat!(
1957            env!("CARGO_MANIFEST_DIR"),
1958            "/../../testdata/yolov8_boxes_116x8400.bin"
1959        ));
1960        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1961        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1962        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1963
1964        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1965
1966        let decoder = DecoderBuilder::default()
1967            .with_config_yolo_split_det(
1968                configs::Boxes {
1969                    decoder: configs::DecoderType::Ultralytics,
1970                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1971                    shape: vec![1, 4, 8400],
1972                    dshape: vec![
1973                        (DimName::Batch, 1),
1974                        (DimName::BoxCoords, 4),
1975                        (DimName::NumBoxes, 8400),
1976                    ],
1977                    normalized: Some(true),
1978                },
1979                configs::Scores {
1980                    decoder: configs::DecoderType::Ultralytics,
1981                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1982                    shape: vec![1, 80, 8400],
1983                    dshape: vec![
1984                        (DimName::Batch, 1),
1985                        (DimName::NumClasses, 80),
1986                        (DimName::NumBoxes, 8400),
1987                    ],
1988                },
1989            )
1990            .with_score_threshold(score_threshold)
1991            .with_iou_threshold(iou_threshold)
1992            .build()
1993            .unwrap();
1994
1995        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1996        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1997
1998        decoder
1999            .decode_quantized(
2000                &[
2001                    boxes.slice(s![.., ..4, ..]).into(),
2002                    boxes.slice(s![.., 4..84, ..]).into(),
2003                ],
2004                &mut output_boxes,
2005                &mut output_masks,
2006            )
2007            .unwrap();
2008
2009        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2010        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2011        decode_yolo_det_float(
2012            seg.slice(s![0, ..84, ..]),
2013            score_threshold,
2014            iou_threshold,
2015            Some(configs::Nms::ClassAgnostic),
2016            &mut output_boxes_f32,
2017        );
2018
2019        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2020        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2021
2022        decoder
2023            .decode_float(
2024                &[
2025                    seg.slice(s![.., ..4, ..]).into_dyn(),
2026                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2027                ],
2028                &mut output_boxes1,
2029                &mut output_masks1,
2030            )
2031            .unwrap();
2032        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2033        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2034    }
2035
2036    #[test]
2037    fn test_decoder_masks_config_mixed() {
2038        let score_threshold = 0.45;
2039        let iou_threshold = 0.45;
2040        let boxes = include_bytes!(concat!(
2041            env!("CARGO_MANIFEST_DIR"),
2042            "/../../testdata/yolov8_boxes_116x8400.bin"
2043        ));
2044        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2045        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2046        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2047
2048        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2049
2050        let protos = include_bytes!(concat!(
2051            env!("CARGO_MANIFEST_DIR"),
2052            "/../../testdata/yolov8_protos_160x160x32.bin"
2053        ));
2054        let protos =
2055            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2056        let protos: Vec<_> = protos.to_vec();
2057        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2058        let quant_protos = Quantization::new(0.02491161972284317, -117);
2059
2060        let decoder = DecoderBuilder::default()
2061            .with_config_yolo_split_segdet(
2062                configs::Boxes {
2063                    decoder: configs::DecoderType::Ultralytics,
2064                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2065                    shape: vec![1, 4, 8400],
2066                    dshape: vec![
2067                        (DimName::Batch, 1),
2068                        (DimName::BoxCoords, 4),
2069                        (DimName::NumBoxes, 8400),
2070                    ],
2071                    normalized: Some(true),
2072                },
2073                configs::Scores {
2074                    decoder: configs::DecoderType::Ultralytics,
2075                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2076                    shape: vec![1, 80, 8400],
2077                    dshape: vec![
2078                        (DimName::Batch, 1),
2079                        (DimName::NumClasses, 80),
2080                        (DimName::NumBoxes, 8400),
2081                    ],
2082                },
2083                configs::MaskCoefficients {
2084                    decoder: configs::DecoderType::Ultralytics,
2085                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2086                    shape: vec![1, 32, 8400],
2087                    dshape: vec![
2088                        (DimName::Batch, 1),
2089                        (DimName::NumProtos, 32),
2090                        (DimName::NumBoxes, 8400),
2091                    ],
2092                },
2093                configs::Protos {
2094                    decoder: configs::DecoderType::Ultralytics,
2095                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2096                    shape: vec![1, 160, 160, 32],
2097                    dshape: vec![
2098                        (DimName::Batch, 1),
2099                        (DimName::Height, 160),
2100                        (DimName::Width, 160),
2101                        (DimName::NumProtos, 32),
2102                    ],
2103                },
2104            )
2105            .with_score_threshold(score_threshold)
2106            .with_iou_threshold(iou_threshold)
2107            .build()
2108            .unwrap();
2109
2110        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2111        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2112
2113        decoder
2114            .decode_quantized(
2115                &[
2116                    boxes.slice(s![.., ..4, ..]).into(),
2117                    boxes.slice(s![.., 4..84, ..]).into(),
2118                    boxes.slice(s![.., 84.., ..]).into(),
2119                    protos.view().into(),
2120                ],
2121                &mut output_boxes,
2122                &mut output_masks,
2123            )
2124            .unwrap();
2125
2126        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2127        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2128        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2129        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2130        decode_yolo_segdet_float(
2131            seg.slice(s![0, .., ..]),
2132            protos.slice(s![0, .., .., ..]),
2133            score_threshold,
2134            iou_threshold,
2135            Some(configs::Nms::ClassAgnostic),
2136            &mut output_boxes_f32,
2137            &mut output_masks_f32,
2138        )
2139        .unwrap();
2140
2141        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2142        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2143
2144        decoder
2145            .decode_float(
2146                &[
2147                    seg.slice(s![.., ..4, ..]).into_dyn(),
2148                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2149                    seg.slice(s![.., 84.., ..]).into_dyn(),
2150                    protos.view().into_dyn(),
2151                ],
2152                &mut output_boxes1,
2153                &mut output_masks1,
2154            )
2155            .unwrap();
2156        compare_outputs(
2157            (&output_boxes, &output_boxes_f32),
2158            (&output_masks, &output_masks_f32),
2159        );
2160        compare_outputs(
2161            (&output_boxes_f32, &output_boxes1),
2162            (&output_masks_f32, &output_masks1),
2163        );
2164    }
2165
2166    #[test]
2167    fn test_decoder_masks_config_i32() {
2168        let score_threshold = 0.45;
2169        let iou_threshold = 0.45;
2170        let boxes = include_bytes!(concat!(
2171            env!("CARGO_MANIFEST_DIR"),
2172            "/../../testdata/yolov8_boxes_116x8400.bin"
2173        ));
2174        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2175        let scale = 1 << 23;
2176        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2177        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2178
2179        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2180
2181        let protos = include_bytes!(concat!(
2182            env!("CARGO_MANIFEST_DIR"),
2183            "/../../testdata/yolov8_protos_160x160x32.bin"
2184        ));
2185        let protos =
2186            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2187        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2188        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2189        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2190
2191        let decoder = DecoderBuilder::default()
2192            .with_config_yolo_split_segdet(
2193                configs::Boxes {
2194                    decoder: configs::DecoderType::Ultralytics,
2195                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2196                    shape: vec![1, 4, 8400],
2197                    dshape: vec![
2198                        (DimName::Batch, 1),
2199                        (DimName::BoxCoords, 4),
2200                        (DimName::NumBoxes, 8400),
2201                    ],
2202                    normalized: Some(true),
2203                },
2204                configs::Scores {
2205                    decoder: configs::DecoderType::Ultralytics,
2206                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2207                    shape: vec![1, 80, 8400],
2208                    dshape: vec![
2209                        (DimName::Batch, 1),
2210                        (DimName::NumClasses, 80),
2211                        (DimName::NumBoxes, 8400),
2212                    ],
2213                },
2214                configs::MaskCoefficients {
2215                    decoder: configs::DecoderType::Ultralytics,
2216                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2217                    shape: vec![1, 32, 8400],
2218                    dshape: vec![
2219                        (DimName::Batch, 1),
2220                        (DimName::NumProtos, 32),
2221                        (DimName::NumBoxes, 8400),
2222                    ],
2223                },
2224                configs::Protos {
2225                    decoder: configs::DecoderType::Ultralytics,
2226                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2227                    shape: vec![1, 160, 160, 32],
2228                    dshape: vec![
2229                        (DimName::Batch, 1),
2230                        (DimName::Height, 160),
2231                        (DimName::Width, 160),
2232                        (DimName::NumProtos, 32),
2233                    ],
2234                },
2235            )
2236            .with_score_threshold(score_threshold)
2237            .with_iou_threshold(iou_threshold)
2238            .build()
2239            .unwrap();
2240
2241        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2242        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2243
2244        decoder
2245            .decode_quantized(
2246                &[
2247                    boxes.slice(s![.., ..4, ..]).into(),
2248                    boxes.slice(s![.., 4..84, ..]).into(),
2249                    boxes.slice(s![.., 84.., ..]).into(),
2250                    protos.view().into(),
2251                ],
2252                &mut output_boxes,
2253                &mut output_masks,
2254            )
2255            .unwrap();
2256
2257        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2258        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2259        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2260        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2261        decode_yolo_segdet_float(
2262            seg.slice(s![0, .., ..]),
2263            protos.slice(s![0, .., .., ..]),
2264            score_threshold,
2265            iou_threshold,
2266            Some(configs::Nms::ClassAgnostic),
2267            &mut output_boxes_f32,
2268            &mut output_masks_f32,
2269        )
2270        .unwrap();
2271
2272        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2273        assert_eq!(output_masks.len(), output_masks_f32.len());
2274
2275        compare_outputs(
2276            (&output_boxes, &output_boxes_f32),
2277            (&output_masks, &output_masks_f32),
2278        );
2279    }
2280
2281    /// test running multiple decoders concurrently
2282    #[test]
2283    fn test_context_switch() {
2284        let yolo_det = || {
2285            let score_threshold = 0.25;
2286            let iou_threshold = 0.7;
2287            let out = include_bytes!(concat!(
2288                env!("CARGO_MANIFEST_DIR"),
2289                "/../../testdata/yolov8s_80_classes.bin"
2290            ));
2291            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2292            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2293            let quant = (0.0040811873, -123).into();
2294
2295            let decoder = DecoderBuilder::default()
2296                .with_config_yolo_det(
2297                    configs::Detection {
2298                        decoder: DecoderType::Ultralytics,
2299                        shape: vec![1, 84, 8400],
2300                        anchors: None,
2301                        quantization: Some(quant),
2302                        dshape: vec![
2303                            (DimName::Batch, 1),
2304                            (DimName::NumFeatures, 84),
2305                            (DimName::NumBoxes, 8400),
2306                        ],
2307                        normalized: None,
2308                    },
2309                    None,
2310                )
2311                .with_score_threshold(score_threshold)
2312                .with_iou_threshold(iou_threshold)
2313                .build()
2314                .unwrap();
2315
2316            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2317            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2318
2319            for _ in 0..100 {
2320                decoder
2321                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2322                    .unwrap();
2323
2324                assert!(output_boxes[0].equal_within_delta(
2325                    &DetectBox {
2326                        bbox: BoundingBox {
2327                            xmin: 0.5285137,
2328                            ymin: 0.05305544,
2329                            xmax: 0.87541467,
2330                            ymax: 0.9998909,
2331                        },
2332                        score: 0.5591227,
2333                        label: 0
2334                    },
2335                    1e-6
2336                ));
2337
2338                assert!(output_boxes[1].equal_within_delta(
2339                    &DetectBox {
2340                        bbox: BoundingBox {
2341                            xmin: 0.130598,
2342                            ymin: 0.43260583,
2343                            xmax: 0.35098213,
2344                            ymax: 0.9958097,
2345                        },
2346                        score: 0.33057618,
2347                        label: 75
2348                    },
2349                    1e-6
2350                ));
2351                assert!(output_masks.is_empty());
2352            }
2353        };
2354
2355        let modelpack_det_split = || {
2356            let score_threshold = 0.8;
2357            let iou_threshold = 0.5;
2358
2359            let seg = include_bytes!(concat!(
2360                env!("CARGO_MANIFEST_DIR"),
2361                "/../../testdata/modelpack_seg_2x160x160.bin"
2362            ));
2363            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2364
2365            let detect0 = include_bytes!(concat!(
2366                env!("CARGO_MANIFEST_DIR"),
2367                "/../../testdata/modelpack_split_9x15x18.bin"
2368            ));
2369            let detect0 =
2370                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2371
2372            let detect1 = include_bytes!(concat!(
2373                env!("CARGO_MANIFEST_DIR"),
2374                "/../../testdata/modelpack_split_17x30x18.bin"
2375            ));
2376            let detect1 =
2377                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2378
2379            let mut mask = seg.slice(s![0, .., .., ..]);
2380            mask.swap_axes(0, 1);
2381            mask.swap_axes(1, 2);
2382            let mask = [Segmentation {
2383                xmin: 0.0,
2384                ymin: 0.0,
2385                xmax: 1.0,
2386                ymax: 1.0,
2387                segmentation: mask.into_owned(),
2388            }];
2389            let correct_boxes = [DetectBox {
2390                bbox: BoundingBox {
2391                    xmin: 0.43171933,
2392                    ymin: 0.68243736,
2393                    xmax: 0.5626645,
2394                    ymax: 0.808863,
2395                },
2396                score: 0.99240804,
2397                label: 0,
2398            }];
2399
2400            let quant0 = (0.08547406643629074, 174).into();
2401            let quant1 = (0.09929127991199493, 183).into();
2402            let quant_seg = (1.0 / 255.0, 0).into();
2403
2404            let anchors0 = vec![
2405                [0.36666667461395264, 0.31481480598449707],
2406                [0.38749998807907104, 0.4740740656852722],
2407                [0.5333333611488342, 0.644444465637207],
2408            ];
2409            let anchors1 = vec![
2410                [0.13750000298023224, 0.2074074000120163],
2411                [0.2541666626930237, 0.21481481194496155],
2412                [0.23125000298023224, 0.35185185074806213],
2413            ];
2414
2415            let decoder = DecoderBuilder::default()
2416                .with_config_modelpack_segdet_split(
2417                    vec![
2418                        configs::Detection {
2419                            decoder: DecoderType::ModelPack,
2420                            shape: vec![1, 17, 30, 18],
2421                            anchors: Some(anchors1),
2422                            quantization: Some(quant1),
2423                            dshape: vec![
2424                                (DimName::Batch, 1),
2425                                (DimName::Height, 17),
2426                                (DimName::Width, 30),
2427                                (DimName::NumAnchorsXFeatures, 18),
2428                            ],
2429                            normalized: None,
2430                        },
2431                        configs::Detection {
2432                            decoder: DecoderType::ModelPack,
2433                            shape: vec![1, 9, 15, 18],
2434                            anchors: Some(anchors0),
2435                            quantization: Some(quant0),
2436                            dshape: vec![
2437                                (DimName::Batch, 1),
2438                                (DimName::Height, 9),
2439                                (DimName::Width, 15),
2440                                (DimName::NumAnchorsXFeatures, 18),
2441                            ],
2442                            normalized: None,
2443                        },
2444                    ],
2445                    configs::Segmentation {
2446                        decoder: DecoderType::ModelPack,
2447                        quantization: Some(quant_seg),
2448                        shape: vec![1, 2, 160, 160],
2449                        dshape: vec![
2450                            (DimName::Batch, 1),
2451                            (DimName::NumClasses, 2),
2452                            (DimName::Height, 160),
2453                            (DimName::Width, 160),
2454                        ],
2455                    },
2456                )
2457                .with_score_threshold(score_threshold)
2458                .with_iou_threshold(iou_threshold)
2459                .build()
2460                .unwrap();
2461            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2462            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2463
2464            for _ in 0..100 {
2465                decoder
2466                    .decode_quantized(
2467                        &[
2468                            detect0.view().into(),
2469                            detect1.view().into(),
2470                            seg.view().into(),
2471                        ],
2472                        &mut output_boxes,
2473                        &mut output_masks,
2474                    )
2475                    .unwrap();
2476
2477                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2478            }
2479        };
2480
2481        let handles = vec![
2482            std::thread::spawn(yolo_det),
2483            std::thread::spawn(modelpack_det_split),
2484            std::thread::spawn(yolo_det),
2485            std::thread::spawn(modelpack_det_split),
2486            std::thread::spawn(yolo_det),
2487            std::thread::spawn(modelpack_det_split),
2488            std::thread::spawn(yolo_det),
2489            std::thread::spawn(modelpack_det_split),
2490        ];
2491        for handle in handles {
2492            handle.join().unwrap();
2493        }
2494    }
2495
2496    #[test]
2497    fn test_ndarray_to_xyxy_float() {
2498        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2499        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2500        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2501
2502        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2503        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2504        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2505    }
2506
2507    #[test]
2508    fn test_class_aware_nms_float() {
2509        use crate::float::nms_class_aware_float;
2510
2511        // Create two overlapping boxes with different classes
2512        let boxes = vec![
2513            DetectBox {
2514                bbox: BoundingBox {
2515                    xmin: 0.0,
2516                    ymin: 0.0,
2517                    xmax: 0.5,
2518                    ymax: 0.5,
2519                },
2520                score: 0.9,
2521                label: 0, // class 0
2522            },
2523            DetectBox {
2524                bbox: BoundingBox {
2525                    xmin: 0.1,
2526                    ymin: 0.1,
2527                    xmax: 0.6,
2528                    ymax: 0.6,
2529                },
2530                score: 0.8,
2531                label: 1, // class 1 - different class
2532            },
2533        ];
2534
2535        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2536        // threshold 0.3)
2537        let result = nms_class_aware_float(0.3, boxes.clone());
2538        assert_eq!(
2539            result.len(),
2540            2,
2541            "Class-aware NMS should keep both boxes with different classes"
2542        );
2543
2544        // Now test with same class - should suppress one
2545        let same_class_boxes = vec![
2546            DetectBox {
2547                bbox: BoundingBox {
2548                    xmin: 0.0,
2549                    ymin: 0.0,
2550                    xmax: 0.5,
2551                    ymax: 0.5,
2552                },
2553                score: 0.9,
2554                label: 0,
2555            },
2556            DetectBox {
2557                bbox: BoundingBox {
2558                    xmin: 0.1,
2559                    ymin: 0.1,
2560                    xmax: 0.6,
2561                    ymax: 0.6,
2562                },
2563                score: 0.8,
2564                label: 0, // same class
2565            },
2566        ];
2567
2568        let result = nms_class_aware_float(0.3, same_class_boxes);
2569        assert_eq!(
2570            result.len(),
2571            1,
2572            "Class-aware NMS should suppress overlapping box with same class"
2573        );
2574        assert_eq!(result[0].label, 0);
2575        assert!((result[0].score - 0.9).abs() < 1e-6);
2576    }
2577
2578    #[test]
2579    fn test_class_agnostic_vs_aware_nms() {
2580        use crate::float::{nms_class_aware_float, nms_float};
2581
2582        // Two overlapping boxes with different classes
2583        let boxes = vec![
2584            DetectBox {
2585                bbox: BoundingBox {
2586                    xmin: 0.0,
2587                    ymin: 0.0,
2588                    xmax: 0.5,
2589                    ymax: 0.5,
2590                },
2591                score: 0.9,
2592                label: 0,
2593            },
2594            DetectBox {
2595                bbox: BoundingBox {
2596                    xmin: 0.1,
2597                    ymin: 0.1,
2598                    xmax: 0.6,
2599                    ymax: 0.6,
2600                },
2601                score: 0.8,
2602                label: 1,
2603            },
2604        ];
2605
2606        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2607        let agnostic_result = nms_float(0.3, boxes.clone());
2608        assert_eq!(
2609            agnostic_result.len(),
2610            1,
2611            "Class-agnostic NMS should suppress overlapping boxes"
2612        );
2613
2614        // Class-aware should keep both (different classes)
2615        let aware_result = nms_class_aware_float(0.3, boxes);
2616        assert_eq!(
2617            aware_result.len(),
2618            2,
2619            "Class-aware NMS should keep boxes with different classes"
2620        );
2621    }
2622
2623    #[test]
2624    fn test_class_aware_nms_int() {
2625        use crate::byte::nms_class_aware_int;
2626
2627        // Create two overlapping boxes with different classes
2628        let boxes = vec![
2629            DetectBoxQuantized {
2630                bbox: BoundingBox {
2631                    xmin: 0.0,
2632                    ymin: 0.0,
2633                    xmax: 0.5,
2634                    ymax: 0.5,
2635                },
2636                score: 200_u8,
2637                label: 0,
2638            },
2639            DetectBoxQuantized {
2640                bbox: BoundingBox {
2641                    xmin: 0.1,
2642                    ymin: 0.1,
2643                    xmax: 0.6,
2644                    ymax: 0.6,
2645                },
2646                score: 180_u8,
2647                label: 1, // different class
2648            },
2649        ];
2650
2651        // Should keep both (different classes)
2652        let result = nms_class_aware_int(0.5, boxes);
2653        assert_eq!(
2654            result.len(),
2655            2,
2656            "Class-aware NMS (int) should keep boxes with different classes"
2657        );
2658    }
2659
2660    #[test]
2661    fn test_nms_enum_default() {
2662        // Test that Nms enum has the correct default
2663        let default_nms: configs::Nms = Default::default();
2664        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2665    }
2666
2667    #[test]
2668    fn test_decoder_nms_mode() {
2669        // Test that decoder properly stores NMS mode
2670        let decoder = DecoderBuilder::default()
2671            .with_config_yolo_det(
2672                configs::Detection {
2673                    anchors: None,
2674                    decoder: DecoderType::Ultralytics,
2675                    quantization: None,
2676                    shape: vec![1, 84, 8400],
2677                    dshape: Vec::new(),
2678                    normalized: Some(true),
2679                },
2680                None,
2681            )
2682            .with_nms(Some(configs::Nms::ClassAware))
2683            .build()
2684            .unwrap();
2685
2686        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2687    }
2688
2689    #[test]
2690    fn test_decoder_nms_bypass() {
2691        // Test that decoder can be configured with nms=None (bypass)
2692        let decoder = DecoderBuilder::default()
2693            .with_config_yolo_det(
2694                configs::Detection {
2695                    anchors: None,
2696                    decoder: DecoderType::Ultralytics,
2697                    quantization: None,
2698                    shape: vec![1, 84, 8400],
2699                    dshape: Vec::new(),
2700                    normalized: Some(true),
2701                },
2702                None,
2703            )
2704            .with_nms(None)
2705            .build()
2706            .unwrap();
2707
2708        assert_eq!(decoder.nms, None);
2709    }
2710
2711    #[test]
2712    fn test_decoder_normalized_boxes_true() {
2713        // Test that normalized_boxes returns Some(true) when explicitly set
2714        let decoder = DecoderBuilder::default()
2715            .with_config_yolo_det(
2716                configs::Detection {
2717                    anchors: None,
2718                    decoder: DecoderType::Ultralytics,
2719                    quantization: None,
2720                    shape: vec![1, 84, 8400],
2721                    dshape: Vec::new(),
2722                    normalized: Some(true),
2723                },
2724                None,
2725            )
2726            .build()
2727            .unwrap();
2728
2729        assert_eq!(decoder.normalized_boxes(), Some(true));
2730    }
2731
2732    #[test]
2733    fn test_decoder_normalized_boxes_false() {
2734        // Test that normalized_boxes returns Some(false) when config specifies
2735        // unnormalized
2736        let decoder = DecoderBuilder::default()
2737            .with_config_yolo_det(
2738                configs::Detection {
2739                    anchors: None,
2740                    decoder: DecoderType::Ultralytics,
2741                    quantization: None,
2742                    shape: vec![1, 84, 8400],
2743                    dshape: Vec::new(),
2744                    normalized: Some(false),
2745                },
2746                None,
2747            )
2748            .build()
2749            .unwrap();
2750
2751        assert_eq!(decoder.normalized_boxes(), Some(false));
2752    }
2753
2754    #[test]
2755    fn test_decoder_normalized_boxes_unknown() {
2756        // Test that normalized_boxes returns None when not specified in config
2757        let decoder = DecoderBuilder::default()
2758            .with_config_yolo_det(
2759                configs::Detection {
2760                    anchors: None,
2761                    decoder: DecoderType::Ultralytics,
2762                    quantization: None,
2763                    shape: vec![1, 84, 8400],
2764                    dshape: Vec::new(),
2765                    normalized: None,
2766                },
2767                Some(DecoderVersion::Yolo11),
2768            )
2769            .build()
2770            .unwrap();
2771
2772        assert_eq!(decoder.normalized_boxes(), None);
2773    }
2774}
2775
2776#[cfg(feature = "tracker")]
2777#[cfg(test)]
2778#[cfg_attr(coverage_nightly, coverage(off))]
2779mod decoder_tracked_tests {
2780
2781    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2782    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2783    use num_traits::{AsPrimitive, Float, PrimInt};
2784    use rand::{RngExt, SeedableRng};
2785    use rand_distr::StandardNormal;
2786
2787    use crate::{
2788        configs::{self, DimName},
2789        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2790    };
2791
2792    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2793        input: ArrayView<F, D>,
2794        quant: Quantization,
2795    ) -> Array<T, D>
2796    where
2797        i32: num_traits::AsPrimitive<F>,
2798        f32: num_traits::AsPrimitive<F>,
2799    {
2800        let zero_point = quant.zero_point.as_();
2801        let div_scale = F::one() / quant.scale.as_();
2802        if zero_point != F::zero() {
2803            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2804        } else {
2805            input.mapv(|d| (d * div_scale).round().as_())
2806        }
2807    }
2808
2809    #[test]
2810    fn test_decoder_tracked_random_jitter() {
2811        use crate::configs::{DecoderType, Nms};
2812        use crate::DecoderBuilder;
2813
2814        let score_threshold = 0.25;
2815        let iou_threshold = 0.1;
2816        let out = include_bytes!(concat!(
2817            env!("CARGO_MANIFEST_DIR"),
2818            "/../../testdata/yolov8s_80_classes.bin"
2819        ));
2820        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2821        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2822        let quant = (0.0040811873, -123).into();
2823
2824        let decoder = DecoderBuilder::default()
2825            .with_config_yolo_det(
2826                crate::configs::Detection {
2827                    decoder: DecoderType::Ultralytics,
2828                    shape: vec![1, 84, 8400],
2829                    anchors: None,
2830                    quantization: Some(quant),
2831                    dshape: vec![
2832                        (crate::configs::DimName::Batch, 1),
2833                        (crate::configs::DimName::NumFeatures, 84),
2834                        (crate::configs::DimName::NumBoxes, 8400),
2835                    ],
2836                    normalized: Some(true),
2837                },
2838                None,
2839            )
2840            .with_score_threshold(score_threshold)
2841            .with_iou_threshold(iou_threshold)
2842            .with_nms(Some(Nms::ClassAgnostic))
2843            .build()
2844            .unwrap();
2845        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
2846
2847        let expected_boxes = [
2848            crate::DetectBox {
2849                bbox: crate::BoundingBox {
2850                    xmin: 0.5285137,
2851                    ymin: 0.05305544,
2852                    xmax: 0.87541467,
2853                    ymax: 0.9998909,
2854                },
2855                score: 0.5591227,
2856                label: 0,
2857            },
2858            crate::DetectBox {
2859                bbox: crate::BoundingBox {
2860                    xmin: 0.130598,
2861                    ymin: 0.43260583,
2862                    xmax: 0.35098213,
2863                    ymax: 0.9958097,
2864                },
2865                score: 0.33057618,
2866                label: 75,
2867            },
2868        ];
2869
2870        let mut tracker = ByteTrackBuilder::new()
2871            .track_update(0.1)
2872            .track_high_conf(0.3)
2873            .build();
2874
2875        let mut output_boxes = Vec::with_capacity(50);
2876        let mut output_masks = Vec::with_capacity(50);
2877        let mut output_tracks = Vec::with_capacity(50);
2878
2879        decoder
2880            .decode_tracked_quantized(
2881                &mut tracker,
2882                0,
2883                &[out.view().into()],
2884                &mut output_boxes,
2885                &mut output_masks,
2886                &mut output_tracks,
2887            )
2888            .unwrap();
2889
2890        assert_eq!(output_boxes.len(), 2);
2891        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2892        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2893
2894        let mut last_boxes = output_boxes.clone();
2895
2896        for i in 1..=100 {
2897            let mut out = out.clone();
2898            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
2899            let mut x_values = out.slice_mut(s![0, 0, ..]);
2900            for x in x_values.iter_mut() {
2901                let r: f32 = rng.sample(StandardNormal);
2902                let r = r.clamp(-2.0, 2.0) / 2.0;
2903                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2904            }
2905
2906            let mut y_values = out.slice_mut(s![0, 1, ..]);
2907            for y in y_values.iter_mut() {
2908                let r: f32 = rng.sample(StandardNormal);
2909                let r = r.clamp(-2.0, 2.0) / 2.0;
2910                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2911            }
2912
2913            decoder
2914                .decode_tracked_quantized(
2915                    &mut tracker,
2916                    100_000_000 * i / 3, // simulate 33.333ms between frames
2917                    &[out.view().into()],
2918                    &mut output_boxes,
2919                    &mut output_masks,
2920                    &mut output_tracks,
2921                )
2922                .unwrap();
2923
2924            assert_eq!(output_boxes.len(), 2);
2925            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2926            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2927
2928            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2929            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2930            last_boxes = output_boxes.clone();
2931        }
2932    }
2933
2934    #[test]
2935    fn test_decoder_tracked_segdet() {
2936        use crate::configs::Nms;
2937        use crate::DecoderBuilder;
2938
2939        let score_threshold = 0.45;
2940        let iou_threshold = 0.45;
2941        let boxes = include_bytes!(concat!(
2942            env!("CARGO_MANIFEST_DIR"),
2943            "/../../testdata/yolov8_boxes_116x8400.bin"
2944        ));
2945        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2946        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2947
2948        let protos = include_bytes!(concat!(
2949            env!("CARGO_MANIFEST_DIR"),
2950            "/../../testdata/yolov8_protos_160x160x32.bin"
2951        ));
2952        let protos =
2953            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2954        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2955
2956        let config = include_str!(concat!(
2957            env!("CARGO_MANIFEST_DIR"),
2958            "/../../testdata/yolov8_seg.yaml"
2959        ));
2960
2961        let decoder = DecoderBuilder::default()
2962            .with_config_yaml_str(config.to_string())
2963            .with_score_threshold(score_threshold)
2964            .with_iou_threshold(iou_threshold)
2965            .with_nms(Some(Nms::ClassAgnostic))
2966            .build()
2967            .unwrap();
2968
2969        let expected_boxes = [
2970            DetectBox {
2971                bbox: BoundingBox {
2972                    xmin: 0.08515105,
2973                    ymin: 0.7131401,
2974                    xmax: 0.29802868,
2975                    ymax: 0.8195788,
2976                },
2977                score: 0.91537374,
2978                label: 23,
2979            },
2980            DetectBox {
2981                bbox: BoundingBox {
2982                    xmin: 0.59605736,
2983                    ymin: 0.25545314,
2984                    xmax: 0.93666154,
2985                    ymax: 0.72378385,
2986                },
2987                score: 0.91537374,
2988                label: 23,
2989            },
2990        ];
2991
2992        let mut tracker = ByteTrackBuilder::new()
2993            .track_update(0.1)
2994            .track_high_conf(0.7)
2995            .build();
2996
2997        let mut output_boxes = Vec::with_capacity(50);
2998        let mut output_masks = Vec::with_capacity(50);
2999        let mut output_tracks = Vec::with_capacity(50);
3000
3001        decoder
3002            .decode_tracked_quantized(
3003                &mut tracker,
3004                0,
3005                &[boxes.view().into(), protos.view().into()],
3006                &mut output_boxes,
3007                &mut output_masks,
3008                &mut output_tracks,
3009            )
3010            .unwrap();
3011
3012        assert_eq!(output_boxes.len(), 2);
3013        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3014        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3015
3016        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3017        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3018        for score in scores_values.iter_mut() {
3019            *score = i8::MIN; // set all scores to minimum to simulate no detections
3020        }
3021        decoder
3022            .decode_tracked_quantized(
3023                &mut tracker,
3024                100_000_000 / 3,
3025                &[boxes.view().into(), protos.view().into()],
3026                &mut output_boxes,
3027                &mut output_masks,
3028                &mut output_tracks,
3029            )
3030            .unwrap();
3031
3032        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3033        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3034
3035        // no masks when the boxes are from tracker prediction without a matching detection
3036        assert!(output_masks.is_empty())
3037    }
3038
3039    #[test]
3040    fn test_decoder_tracked_segdet_float() {
3041        use crate::configs::Nms;
3042        use crate::DecoderBuilder;
3043
3044        let score_threshold = 0.45;
3045        let iou_threshold = 0.45;
3046        let boxes = include_bytes!(concat!(
3047            env!("CARGO_MANIFEST_DIR"),
3048            "/../../testdata/yolov8_boxes_116x8400.bin"
3049        ));
3050        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3051        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3052        let quant_boxes = (0.021287762, 31);
3053        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3054
3055        let protos = include_bytes!(concat!(
3056            env!("CARGO_MANIFEST_DIR"),
3057            "/../../testdata/yolov8_protos_160x160x32.bin"
3058        ));
3059        let protos =
3060            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3061        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3062        let quant_protos = (0.02491162, -117);
3063        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3064
3065        let config = include_str!(concat!(
3066            env!("CARGO_MANIFEST_DIR"),
3067            "/../../testdata/yolov8_seg.yaml"
3068        ));
3069
3070        let decoder = DecoderBuilder::default()
3071            .with_config_yaml_str(config.to_string())
3072            .with_score_threshold(score_threshold)
3073            .with_iou_threshold(iou_threshold)
3074            .with_nms(Some(Nms::ClassAgnostic))
3075            .build()
3076            .unwrap();
3077
3078        let expected_boxes = [
3079            DetectBox {
3080                bbox: BoundingBox {
3081                    xmin: 0.08515105,
3082                    ymin: 0.7131401,
3083                    xmax: 0.29802868,
3084                    ymax: 0.8195788,
3085                },
3086                score: 0.91537374,
3087                label: 23,
3088            },
3089            DetectBox {
3090                bbox: BoundingBox {
3091                    xmin: 0.59605736,
3092                    ymin: 0.25545314,
3093                    xmax: 0.93666154,
3094                    ymax: 0.72378385,
3095                },
3096                score: 0.91537374,
3097                label: 23,
3098            },
3099        ];
3100
3101        let mut tracker = ByteTrackBuilder::new()
3102            .track_update(0.1)
3103            .track_high_conf(0.7)
3104            .build();
3105
3106        let mut output_boxes = Vec::with_capacity(50);
3107        let mut output_masks = Vec::with_capacity(50);
3108        let mut output_tracks = Vec::with_capacity(50);
3109
3110        decoder
3111            .decode_tracked_float(
3112                &mut tracker,
3113                0,
3114                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3115                &mut output_boxes,
3116                &mut output_masks,
3117                &mut output_tracks,
3118            )
3119            .unwrap();
3120
3121        assert_eq!(output_boxes.len(), 2);
3122        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3123        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3124
3125        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3126        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3127        for score in scores_values.iter_mut() {
3128            *score = 0.0; // set all scores to minimum to simulate no detections
3129        }
3130        decoder
3131            .decode_tracked_float(
3132                &mut tracker,
3133                100_000_000 / 3,
3134                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3135                &mut output_boxes,
3136                &mut output_masks,
3137                &mut output_tracks,
3138            )
3139            .unwrap();
3140
3141        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3142        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3143
3144        // no masks when the boxes are from tracker prediction without a matching detection
3145        assert!(output_masks.is_empty())
3146    }
3147
3148    #[test]
3149    fn test_decoder_tracked_segdet_proto() {
3150        use crate::configs::Nms;
3151        use crate::DecoderBuilder;
3152
3153        let score_threshold = 0.45;
3154        let iou_threshold = 0.45;
3155        let boxes = include_bytes!(concat!(
3156            env!("CARGO_MANIFEST_DIR"),
3157            "/../../testdata/yolov8_boxes_116x8400.bin"
3158        ));
3159        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3160        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3161
3162        let protos = include_bytes!(concat!(
3163            env!("CARGO_MANIFEST_DIR"),
3164            "/../../testdata/yolov8_protos_160x160x32.bin"
3165        ));
3166        let protos =
3167            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3168        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3169
3170        let config = include_str!(concat!(
3171            env!("CARGO_MANIFEST_DIR"),
3172            "/../../testdata/yolov8_seg.yaml"
3173        ));
3174
3175        let decoder = DecoderBuilder::default()
3176            .with_config_yaml_str(config.to_string())
3177            .with_score_threshold(score_threshold)
3178            .with_iou_threshold(iou_threshold)
3179            .with_nms(Some(Nms::ClassAgnostic))
3180            .build()
3181            .unwrap();
3182
3183        let expected_boxes = [
3184            DetectBox {
3185                bbox: BoundingBox {
3186                    xmin: 0.08515105,
3187                    ymin: 0.7131401,
3188                    xmax: 0.29802868,
3189                    ymax: 0.8195788,
3190                },
3191                score: 0.91537374,
3192                label: 23,
3193            },
3194            DetectBox {
3195                bbox: BoundingBox {
3196                    xmin: 0.59605736,
3197                    ymin: 0.25545314,
3198                    xmax: 0.93666154,
3199                    ymax: 0.72378385,
3200                },
3201                score: 0.91537374,
3202                label: 23,
3203            },
3204        ];
3205
3206        let mut tracker = ByteTrackBuilder::new()
3207            .track_update(0.1)
3208            .track_high_conf(0.7)
3209            .build();
3210
3211        let mut output_boxes = Vec::with_capacity(50);
3212        let mut output_tracks = Vec::with_capacity(50);
3213
3214        decoder
3215            .decode_tracked_quantized_proto(
3216                &mut tracker,
3217                0,
3218                &[boxes.view().into(), protos.view().into()],
3219                &mut output_boxes,
3220                &mut output_tracks,
3221            )
3222            .unwrap();
3223
3224        assert_eq!(output_boxes.len(), 2);
3225        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3226        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3227
3228        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3229        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3230        for score in scores_values.iter_mut() {
3231            *score = i8::MIN; // set all scores to minimum to simulate no detections
3232        }
3233        let protos = decoder
3234            .decode_tracked_quantized_proto(
3235                &mut tracker,
3236                100_000_000 / 3,
3237                &[boxes.view().into(), protos.view().into()],
3238                &mut output_boxes,
3239                &mut output_tracks,
3240            )
3241            .unwrap();
3242
3243        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3244        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3245
3246        // no masks when the boxes are from tracker prediction without a matching detection
3247        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3248    }
3249
3250    #[test]
3251    fn test_decoder_tracked_segdet_proto_float() {
3252        use crate::configs::Nms;
3253        use crate::DecoderBuilder;
3254
3255        let score_threshold = 0.45;
3256        let iou_threshold = 0.45;
3257        let boxes = include_bytes!(concat!(
3258            env!("CARGO_MANIFEST_DIR"),
3259            "/../../testdata/yolov8_boxes_116x8400.bin"
3260        ));
3261        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3262        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3263        let quant_boxes = (0.021287762, 31);
3264        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3265
3266        let protos = include_bytes!(concat!(
3267            env!("CARGO_MANIFEST_DIR"),
3268            "/../../testdata/yolov8_protos_160x160x32.bin"
3269        ));
3270        let protos =
3271            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3272        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3273        let quant_protos = (0.02491162, -117);
3274        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3275
3276        let config = include_str!(concat!(
3277            env!("CARGO_MANIFEST_DIR"),
3278            "/../../testdata/yolov8_seg.yaml"
3279        ));
3280
3281        let decoder = DecoderBuilder::default()
3282            .with_config_yaml_str(config.to_string())
3283            .with_score_threshold(score_threshold)
3284            .with_iou_threshold(iou_threshold)
3285            .with_nms(Some(Nms::ClassAgnostic))
3286            .build()
3287            .unwrap();
3288
3289        let expected_boxes = [
3290            DetectBox {
3291                bbox: BoundingBox {
3292                    xmin: 0.08515105,
3293                    ymin: 0.7131401,
3294                    xmax: 0.29802868,
3295                    ymax: 0.8195788,
3296                },
3297                score: 0.91537374,
3298                label: 23,
3299            },
3300            DetectBox {
3301                bbox: BoundingBox {
3302                    xmin: 0.59605736,
3303                    ymin: 0.25545314,
3304                    xmax: 0.93666154,
3305                    ymax: 0.72378385,
3306                },
3307                score: 0.91537374,
3308                label: 23,
3309            },
3310        ];
3311
3312        let mut tracker = ByteTrackBuilder::new()
3313            .track_update(0.1)
3314            .track_high_conf(0.7)
3315            .build();
3316
3317        let mut output_boxes = Vec::with_capacity(50);
3318        let mut output_tracks = Vec::with_capacity(50);
3319
3320        decoder
3321            .decode_tracked_float_proto(
3322                &mut tracker,
3323                0,
3324                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3325                &mut output_boxes,
3326                &mut output_tracks,
3327            )
3328            .unwrap();
3329
3330        assert_eq!(output_boxes.len(), 2);
3331        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3332        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3333
3334        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3335        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3336        for score in scores_values.iter_mut() {
3337            *score = 0.0; // set all scores to minimum to simulate no detections
3338        }
3339        let protos = decoder
3340            .decode_tracked_float_proto(
3341                &mut tracker,
3342                100_000_000 / 3,
3343                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3344                &mut output_boxes,
3345                &mut output_tracks,
3346            )
3347            .unwrap();
3348
3349        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3350        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3351
3352        // no masks when the boxes are from tracker prediction without a matching detection
3353        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3354    }
3355
3356    #[test]
3357    fn test_decoder_tracked_segdet_split() {
3358        let score_threshold = 0.45;
3359        let iou_threshold = 0.45;
3360
3361        let boxes = include_bytes!(concat!(
3362            env!("CARGO_MANIFEST_DIR"),
3363            "/../../testdata/yolov8_boxes_116x8400.bin"
3364        ));
3365        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3366        let boxes = boxes.to_vec();
3367        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3368
3369        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3370        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3371        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3372
3373        let quant_boxes = (0.021287762, 31);
3374
3375        let protos = include_bytes!(concat!(
3376            env!("CARGO_MANIFEST_DIR"),
3377            "/../../testdata/yolov8_protos_160x160x32.bin"
3378        ));
3379        let protos =
3380            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3381        let protos = protos.to_vec();
3382        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3383        let quant_protos = (0.02491162, -117);
3384        let decoder = DecoderBuilder::default()
3385            .with_config_yolo_split_segdet(
3386                configs::Boxes {
3387                    decoder: configs::DecoderType::Ultralytics,
3388                    quantization: Some(quant_boxes.into()),
3389                    shape: vec![1, 4, 8400],
3390                    dshape: vec![
3391                        (DimName::Batch, 1),
3392                        (DimName::BoxCoords, 4),
3393                        (DimName::NumBoxes, 8400),
3394                    ],
3395                    normalized: Some(true),
3396                },
3397                configs::Scores {
3398                    decoder: configs::DecoderType::Ultralytics,
3399                    quantization: Some(quant_boxes.into()),
3400                    shape: vec![1, 80, 8400],
3401                    dshape: vec![
3402                        (DimName::Batch, 1),
3403                        (DimName::NumClasses, 80),
3404                        (DimName::NumBoxes, 8400),
3405                    ],
3406                },
3407                configs::MaskCoefficients {
3408                    decoder: configs::DecoderType::Ultralytics,
3409                    quantization: Some(quant_boxes.into()),
3410                    shape: vec![1, 32, 8400],
3411                    dshape: vec![
3412                        (DimName::Batch, 1),
3413                        (DimName::NumProtos, 32),
3414                        (DimName::NumBoxes, 8400),
3415                    ],
3416                },
3417                configs::Protos {
3418                    decoder: configs::DecoderType::Ultralytics,
3419                    quantization: Some(quant_protos.into()),
3420                    shape: vec![1, 160, 160, 32],
3421                    dshape: vec![
3422                        (DimName::Batch, 1),
3423                        (DimName::Height, 160),
3424                        (DimName::Width, 160),
3425                        (DimName::NumProtos, 32),
3426                    ],
3427                },
3428            )
3429            .with_score_threshold(score_threshold)
3430            .with_iou_threshold(iou_threshold)
3431            .build()
3432            .unwrap();
3433
3434        let expected_boxes = [
3435            DetectBox {
3436                bbox: BoundingBox {
3437                    xmin: 0.08515105,
3438                    ymin: 0.7131401,
3439                    xmax: 0.29802868,
3440                    ymax: 0.8195788,
3441                },
3442                score: 0.91537374,
3443                label: 23,
3444            },
3445            DetectBox {
3446                bbox: BoundingBox {
3447                    xmin: 0.59605736,
3448                    ymin: 0.25545314,
3449                    xmax: 0.93666154,
3450                    ymax: 0.72378385,
3451                },
3452                score: 0.91537374,
3453                label: 23,
3454            },
3455        ];
3456
3457        let mut tracker = ByteTrackBuilder::new()
3458            .track_update(0.1)
3459            .track_high_conf(0.7)
3460            .build();
3461
3462        let mut output_boxes = Vec::with_capacity(50);
3463        let mut output_masks = Vec::with_capacity(50);
3464        let mut output_tracks = Vec::with_capacity(50);
3465
3466        decoder
3467            .decode_tracked_quantized(
3468                &mut tracker,
3469                0,
3470                &[
3471                    boxes.view().into(),
3472                    scores.view().into(),
3473                    mask.view().into(),
3474                    protos.view().into(),
3475                ],
3476                &mut output_boxes,
3477                &mut output_masks,
3478                &mut output_tracks,
3479            )
3480            .unwrap();
3481
3482        assert_eq!(output_boxes.len(), 2);
3483        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3484        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3485
3486        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3487
3488        for score in scores.iter_mut() {
3489            *score = i8::MIN; // set all scores to minimum to simulate no detections
3490        }
3491        decoder
3492            .decode_tracked_quantized(
3493                &mut tracker,
3494                100_000_000 / 3,
3495                &[
3496                    boxes.view().into(),
3497                    scores.view().into(),
3498                    mask.view().into(),
3499                    protos.view().into(),
3500                ],
3501                &mut output_boxes,
3502                &mut output_masks,
3503                &mut output_tracks,
3504            )
3505            .unwrap();
3506
3507        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3508        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3509
3510        // no masks when the boxes are from tracker prediction without a matching detection
3511        assert!(output_masks.is_empty())
3512    }
3513
3514    #[test]
3515    fn test_decoder_tracked_segdet_split_float() {
3516        let score_threshold = 0.45;
3517        let iou_threshold = 0.45;
3518
3519        let boxes = include_bytes!(concat!(
3520            env!("CARGO_MANIFEST_DIR"),
3521            "/../../testdata/yolov8_boxes_116x8400.bin"
3522        ));
3523        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3524        let boxes = boxes.to_vec();
3525        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3526        let quant_boxes = (0.021287762, 31);
3527        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3528
3529        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3530        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3531        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3532
3533        let protos = include_bytes!(concat!(
3534            env!("CARGO_MANIFEST_DIR"),
3535            "/../../testdata/yolov8_protos_160x160x32.bin"
3536        ));
3537        let protos =
3538            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3539        let protos = protos.to_vec();
3540        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3541        let quant_protos = (0.02491162, -117);
3542        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3543
3544        let decoder = DecoderBuilder::default()
3545            .with_config_yolo_split_segdet(
3546                configs::Boxes {
3547                    decoder: configs::DecoderType::Ultralytics,
3548                    quantization: Some(quant_boxes.into()),
3549                    shape: vec![1, 4, 8400],
3550                    dshape: vec![
3551                        (DimName::Batch, 1),
3552                        (DimName::BoxCoords, 4),
3553                        (DimName::NumBoxes, 8400),
3554                    ],
3555                    normalized: Some(true),
3556                },
3557                configs::Scores {
3558                    decoder: configs::DecoderType::Ultralytics,
3559                    quantization: Some(quant_boxes.into()),
3560                    shape: vec![1, 80, 8400],
3561                    dshape: vec![
3562                        (DimName::Batch, 1),
3563                        (DimName::NumClasses, 80),
3564                        (DimName::NumBoxes, 8400),
3565                    ],
3566                },
3567                configs::MaskCoefficients {
3568                    decoder: configs::DecoderType::Ultralytics,
3569                    quantization: Some(quant_boxes.into()),
3570                    shape: vec![1, 32, 8400],
3571                    dshape: vec![
3572                        (DimName::Batch, 1),
3573                        (DimName::NumProtos, 32),
3574                        (DimName::NumBoxes, 8400),
3575                    ],
3576                },
3577                configs::Protos {
3578                    decoder: configs::DecoderType::Ultralytics,
3579                    quantization: Some(quant_protos.into()),
3580                    shape: vec![1, 160, 160, 32],
3581                    dshape: vec![
3582                        (DimName::Batch, 1),
3583                        (DimName::Height, 160),
3584                        (DimName::Width, 160),
3585                        (DimName::NumProtos, 32),
3586                    ],
3587                },
3588            )
3589            .with_score_threshold(score_threshold)
3590            .with_iou_threshold(iou_threshold)
3591            .build()
3592            .unwrap();
3593
3594        let expected_boxes = [
3595            DetectBox {
3596                bbox: BoundingBox {
3597                    xmin: 0.08515105,
3598                    ymin: 0.7131401,
3599                    xmax: 0.29802868,
3600                    ymax: 0.8195788,
3601                },
3602                score: 0.91537374,
3603                label: 23,
3604            },
3605            DetectBox {
3606                bbox: BoundingBox {
3607                    xmin: 0.59605736,
3608                    ymin: 0.25545314,
3609                    xmax: 0.93666154,
3610                    ymax: 0.72378385,
3611                },
3612                score: 0.91537374,
3613                label: 23,
3614            },
3615        ];
3616
3617        let mut tracker = ByteTrackBuilder::new()
3618            .track_update(0.1)
3619            .track_high_conf(0.7)
3620            .build();
3621
3622        let mut output_boxes = Vec::with_capacity(50);
3623        let mut output_masks = Vec::with_capacity(50);
3624        let mut output_tracks = Vec::with_capacity(50);
3625
3626        decoder
3627            .decode_tracked_float(
3628                &mut tracker,
3629                0,
3630                &[
3631                    boxes.view().into_dyn(),
3632                    scores.view().into_dyn(),
3633                    mask.view().into_dyn(),
3634                    protos.view().into_dyn(),
3635                ],
3636                &mut output_boxes,
3637                &mut output_masks,
3638                &mut output_tracks,
3639            )
3640            .unwrap();
3641
3642        assert_eq!(output_boxes.len(), 2);
3643        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3644        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3645
3646        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3647
3648        for score in scores.iter_mut() {
3649            *score = 0.0; // set all scores to minimum to simulate no detections
3650        }
3651        decoder
3652            .decode_tracked_float(
3653                &mut tracker,
3654                100_000_000 / 3,
3655                &[
3656                    boxes.view().into_dyn(),
3657                    scores.view().into_dyn(),
3658                    mask.view().into_dyn(),
3659                    protos.view().into_dyn(),
3660                ],
3661                &mut output_boxes,
3662                &mut output_masks,
3663                &mut output_tracks,
3664            )
3665            .unwrap();
3666
3667        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3668        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3669
3670        // no masks when the boxes are from tracker prediction without a matching detection
3671        assert!(output_masks.is_empty())
3672    }
3673
3674    #[test]
3675    fn test_decoder_tracked_segdet_split_proto() {
3676        let score_threshold = 0.45;
3677        let iou_threshold = 0.45;
3678
3679        let boxes = include_bytes!(concat!(
3680            env!("CARGO_MANIFEST_DIR"),
3681            "/../../testdata/yolov8_boxes_116x8400.bin"
3682        ));
3683        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3684        let boxes = boxes.to_vec();
3685        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3686
3687        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3688        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3689        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3690
3691        let quant_boxes = (0.021287762, 31);
3692
3693        let protos = include_bytes!(concat!(
3694            env!("CARGO_MANIFEST_DIR"),
3695            "/../../testdata/yolov8_protos_160x160x32.bin"
3696        ));
3697        let protos =
3698            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3699        let protos = protos.to_vec();
3700        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3701        let quant_protos = (0.02491162, -117);
3702        let decoder = DecoderBuilder::default()
3703            .with_config_yolo_split_segdet(
3704                configs::Boxes {
3705                    decoder: configs::DecoderType::Ultralytics,
3706                    quantization: Some(quant_boxes.into()),
3707                    shape: vec![1, 4, 8400],
3708                    dshape: vec![
3709                        (DimName::Batch, 1),
3710                        (DimName::BoxCoords, 4),
3711                        (DimName::NumBoxes, 8400),
3712                    ],
3713                    normalized: Some(true),
3714                },
3715                configs::Scores {
3716                    decoder: configs::DecoderType::Ultralytics,
3717                    quantization: Some(quant_boxes.into()),
3718                    shape: vec![1, 80, 8400],
3719                    dshape: vec![
3720                        (DimName::Batch, 1),
3721                        (DimName::NumClasses, 80),
3722                        (DimName::NumBoxes, 8400),
3723                    ],
3724                },
3725                configs::MaskCoefficients {
3726                    decoder: configs::DecoderType::Ultralytics,
3727                    quantization: Some(quant_boxes.into()),
3728                    shape: vec![1, 32, 8400],
3729                    dshape: vec![
3730                        (DimName::Batch, 1),
3731                        (DimName::NumProtos, 32),
3732                        (DimName::NumBoxes, 8400),
3733                    ],
3734                },
3735                configs::Protos {
3736                    decoder: configs::DecoderType::Ultralytics,
3737                    quantization: Some(quant_protos.into()),
3738                    shape: vec![1, 160, 160, 32],
3739                    dshape: vec![
3740                        (DimName::Batch, 1),
3741                        (DimName::Height, 160),
3742                        (DimName::Width, 160),
3743                        (DimName::NumProtos, 32),
3744                    ],
3745                },
3746            )
3747            .with_score_threshold(score_threshold)
3748            .with_iou_threshold(iou_threshold)
3749            .build()
3750            .unwrap();
3751
3752        let expected_boxes = [
3753            DetectBox {
3754                bbox: BoundingBox {
3755                    xmin: 0.08515105,
3756                    ymin: 0.7131401,
3757                    xmax: 0.29802868,
3758                    ymax: 0.8195788,
3759                },
3760                score: 0.91537374,
3761                label: 23,
3762            },
3763            DetectBox {
3764                bbox: BoundingBox {
3765                    xmin: 0.59605736,
3766                    ymin: 0.25545314,
3767                    xmax: 0.93666154,
3768                    ymax: 0.72378385,
3769                },
3770                score: 0.91537374,
3771                label: 23,
3772            },
3773        ];
3774
3775        let mut tracker = ByteTrackBuilder::new()
3776            .track_update(0.1)
3777            .track_high_conf(0.7)
3778            .build();
3779
3780        let mut output_boxes = Vec::with_capacity(50);
3781        let mut output_tracks = Vec::with_capacity(50);
3782
3783        decoder
3784            .decode_tracked_quantized_proto(
3785                &mut tracker,
3786                0,
3787                &[
3788                    boxes.view().into(),
3789                    scores.view().into(),
3790                    mask.view().into(),
3791                    protos.view().into(),
3792                ],
3793                &mut output_boxes,
3794                &mut output_tracks,
3795            )
3796            .unwrap();
3797
3798        assert_eq!(output_boxes.len(), 2);
3799        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3800        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3801
3802        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3803
3804        for score in scores.iter_mut() {
3805            *score = i8::MIN; // set all scores to minimum to simulate no detections
3806        }
3807        let protos = decoder
3808            .decode_tracked_quantized_proto(
3809                &mut tracker,
3810                100_000_000 / 3,
3811                &[
3812                    boxes.view().into(),
3813                    scores.view().into(),
3814                    mask.view().into(),
3815                    protos.view().into(),
3816                ],
3817                &mut output_boxes,
3818                &mut output_tracks,
3819            )
3820            .unwrap();
3821
3822        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3823        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3824
3825        // no masks when the boxes are from tracker prediction without a matching detection
3826        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3827    }
3828
3829    #[test]
3830    fn test_decoder_tracked_segdet_split_proto_float() {
3831        let score_threshold = 0.45;
3832        let iou_threshold = 0.45;
3833
3834        let boxes = include_bytes!(concat!(
3835            env!("CARGO_MANIFEST_DIR"),
3836            "/../../testdata/yolov8_boxes_116x8400.bin"
3837        ));
3838        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3839        let boxes = boxes.to_vec();
3840        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3841        let quant_boxes = (0.021287762, 31);
3842        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3843
3844        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3845        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3846        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3847
3848        let protos = include_bytes!(concat!(
3849            env!("CARGO_MANIFEST_DIR"),
3850            "/../../testdata/yolov8_protos_160x160x32.bin"
3851        ));
3852        let protos =
3853            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3854        let protos = protos.to_vec();
3855        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3856        let quant_protos = (0.02491162, -117);
3857        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3858
3859        let decoder = DecoderBuilder::default()
3860            .with_config_yolo_split_segdet(
3861                configs::Boxes {
3862                    decoder: configs::DecoderType::Ultralytics,
3863                    quantization: Some(quant_boxes.into()),
3864                    shape: vec![1, 4, 8400],
3865                    dshape: vec![
3866                        (DimName::Batch, 1),
3867                        (DimName::BoxCoords, 4),
3868                        (DimName::NumBoxes, 8400),
3869                    ],
3870                    normalized: Some(true),
3871                },
3872                configs::Scores {
3873                    decoder: configs::DecoderType::Ultralytics,
3874                    quantization: Some(quant_boxes.into()),
3875                    shape: vec![1, 80, 8400],
3876                    dshape: vec![
3877                        (DimName::Batch, 1),
3878                        (DimName::NumClasses, 80),
3879                        (DimName::NumBoxes, 8400),
3880                    ],
3881                },
3882                configs::MaskCoefficients {
3883                    decoder: configs::DecoderType::Ultralytics,
3884                    quantization: Some(quant_boxes.into()),
3885                    shape: vec![1, 32, 8400],
3886                    dshape: vec![
3887                        (DimName::Batch, 1),
3888                        (DimName::NumProtos, 32),
3889                        (DimName::NumBoxes, 8400),
3890                    ],
3891                },
3892                configs::Protos {
3893                    decoder: configs::DecoderType::Ultralytics,
3894                    quantization: Some(quant_protos.into()),
3895                    shape: vec![1, 160, 160, 32],
3896                    dshape: vec![
3897                        (DimName::Batch, 1),
3898                        (DimName::Height, 160),
3899                        (DimName::Width, 160),
3900                        (DimName::NumProtos, 32),
3901                    ],
3902                },
3903            )
3904            .with_score_threshold(score_threshold)
3905            .with_iou_threshold(iou_threshold)
3906            .build()
3907            .unwrap();
3908
3909        let expected_boxes = [
3910            DetectBox {
3911                bbox: BoundingBox {
3912                    xmin: 0.08515105,
3913                    ymin: 0.7131401,
3914                    xmax: 0.29802868,
3915                    ymax: 0.8195788,
3916                },
3917                score: 0.91537374,
3918                label: 23,
3919            },
3920            DetectBox {
3921                bbox: BoundingBox {
3922                    xmin: 0.59605736,
3923                    ymin: 0.25545314,
3924                    xmax: 0.93666154,
3925                    ymax: 0.72378385,
3926                },
3927                score: 0.91537374,
3928                label: 23,
3929            },
3930        ];
3931
3932        let mut tracker = ByteTrackBuilder::new()
3933            .track_update(0.1)
3934            .track_high_conf(0.7)
3935            .build();
3936
3937        let mut output_boxes = Vec::with_capacity(50);
3938        let mut output_tracks = Vec::with_capacity(50);
3939
3940        decoder
3941            .decode_tracked_float_proto(
3942                &mut tracker,
3943                0,
3944                &[
3945                    boxes.view().into_dyn(),
3946                    scores.view().into_dyn(),
3947                    mask.view().into_dyn(),
3948                    protos.view().into_dyn(),
3949                ],
3950                &mut output_boxes,
3951                &mut output_tracks,
3952            )
3953            .unwrap();
3954
3955        assert_eq!(output_boxes.len(), 2);
3956        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3957        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3958
3959        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3960
3961        for score in scores.iter_mut() {
3962            *score = 0.0; // set all scores to minimum to simulate no detections
3963        }
3964        let protos = decoder
3965            .decode_tracked_float_proto(
3966                &mut tracker,
3967                100_000_000 / 3,
3968                &[
3969                    boxes.view().into_dyn(),
3970                    scores.view().into_dyn(),
3971                    mask.view().into_dyn(),
3972                    protos.view().into_dyn(),
3973                ],
3974                &mut output_boxes,
3975                &mut output_tracks,
3976            )
3977            .unwrap();
3978
3979        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3980        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3981
3982        // no masks when the boxes are from tracker prediction without a matching detection
3983        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3984    }
3985
/// End-to-end tracked decode of a quantized YOLO26 segmentation-detection model whose
/// detections arrive as one combined `[1, 10, 38]` tensor (4 box coordinates, score, class,
/// 32 mask coefficients) plus a `[1, 160, 160, 32]` protos tensor.
///
/// Frame 1 contains a single confident detection and must yield exactly one box and one mask.
/// Frame 2 zeroes every score so nothing passes the threshold: the tracker must still report
/// the predicted box of the existing track, and no mask may be produced for it because it has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    // NOTE(review): protos are quantized with zero point 0 while the config below declares
    // zero point 128; presumably harmless here because every mask coefficient is zero.
    let protos_quant = (1.0 / 255.0, 0.0);
    let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Feature layout: [0..4) box, 4 score, 5 class, [6..38) mask coefficients.
    let detect = ndarray::concatenate![
        Axis(1),
        boxes.view(),
        scores.view(),
        classes.view(),
        mask.view()
    ];
    let detect = detect.insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 38]);
    // 2 / 255 matches the 0.00784313725490196 scale declared for the detection output below.
    let detect_quant = (2.0 / 255.0, 0.0);
    let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
    let config = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Expected boxes don't match the float values exactly due to quantization error.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[detect.view().into(), protos.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
    // The matched detection must carry a mask; otherwise the "no masks" check on the next
    // frame would pass vacuously.
    assert_eq!(output_masks.len(), 1);

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score (feature 4).
    detect.slice_mut(s![.., .., 4]).fill(u8::MIN);

    decoder
        .decode_tracked_quantized(
            &mut tracker,
            second_frame_timestamp,
            &[detect.view().into(), protos.view().into()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(output_masks.is_empty());
}
4103
/// End-to-end tracked decode of a floating-point YOLO26 segmentation-detection model whose
/// detections arrive as one combined `[1, 10, 38]` tensor (4 box coordinates, score, class,
/// 32 mask coefficients) plus a `[1, 160, 160, 32]` protos tensor.
///
/// Frame 1 contains a single confident detection and must yield exactly one box and one mask.
/// Frame 2 zeroes every score so nothing passes the threshold: the tracker must still report
/// the predicted box of the existing track, and no mask may be produced for it because it has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_float() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Feature layout: [0..4) box, 4 score, 5 class, [6..38) mask coefficients.
    let detect = ndarray::concatenate![
        Axis(1),
        boxes.view(),
        scores.view(),
        classes.view(),
        mask.view()
    ];
    let mut detect = detect.insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 38]);
    let config = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Float inputs are not quantized, so the expected box is the exact input.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[detect.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
    // The matched detection must carry a mask; otherwise the "no masks" check on the next
    // frame would pass vacuously.
    assert_eq!(output_masks.len(), 1);

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score (feature 4).
    detect.slice_mut(s![.., .., 4]).fill(0.0);

    decoder
        .decode_tracked_float(
            &mut tracker,
            second_frame_timestamp,
            &[detect.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(output_masks.is_empty());
}
4215
/// End-to-end tracked proto decode of a quantized YOLO26 segmentation-detection model whose
/// detections arrive as one combined `[1, 10, 38]` tensor (4 box coordinates, score, class,
/// 32 mask coefficients) plus a `[1, 160, 160, 32]` protos tensor.
///
/// Frame 1 contains a single confident detection: it must yield exactly one box, and the
/// returned proto data must carry mask coefficients. Frame 2 zeroes every score so nothing
/// passes the threshold: the tracker must still report the predicted box of the existing
/// track, and the returned proto data must carry no mask coefficients because that box has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_proto() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    // NOTE(review): protos are quantized with zero point 0 while the config below declares
    // zero point 128; presumably harmless here because every mask coefficient is zero.
    let protos_quant = (1.0 / 255.0, 0.0);
    let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Feature layout: [0..4) box, 4 score, 5 class, [6..38) mask coefficients.
    let detect = ndarray::concatenate![
        Axis(1),
        boxes.view(),
        scores.view(),
        classes.view(),
        mask.view()
    ];
    let detect = detect.insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 38]);
    // 2 / 255 matches the 0.00784313725490196 scale declared for the detection output below.
    let detect_quant = (2.0 / 255.0, 0.0);
    let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
    let config = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Expected boxes don't match the float values exactly due to quantization error.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    let first_frame = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            0,
            &[detect.view().into(), protos.view().into()],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // The matched detection must contribute mask coefficients; otherwise the empty check on
    // the next frame would pass vacuously.
    assert!(first_frame.is_some_and(|x| !x.mask_coefficients.is_empty()));

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score (feature 4).
    detect.slice_mut(s![.., .., 4]).fill(u8::MIN);

    // Named distinctly so it does not shadow the `protos` input passed in the same call.
    let second_frame = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            second_frame_timestamp,
            &[detect.view().into(), protos.view().into()],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(second_frame.is_some_and(|x| x.mask_coefficients.is_empty()));
}
4330
/// End-to-end tracked proto decode of a floating-point YOLO26 segmentation-detection model
/// whose detections arrive as one combined `[1, 10, 38]` tensor (4 box coordinates, score,
/// class, 32 mask coefficients) plus a `[1, 160, 160, 32]` protos tensor.
///
/// Frame 1 contains a single confident detection: it must yield exactly one box, and the
/// returned proto data must carry mask coefficients. Frame 2 zeroes every score so nothing
/// passes the threshold: the tracker must still report the predicted box of the existing
/// track, and the returned proto data must carry no mask coefficients because that box has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_proto_float() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    // Feature layout: [0..4) box, 4 score, 5 class, [6..38) mask coefficients.
    let detect = ndarray::concatenate![
        Axis(1),
        boxes.view(),
        scores.view(),
        classes.view(),
        mask.view()
    ];
    let mut detect = detect.insert_axis(Axis(0));
    assert_eq!(detect.shape(), &[1, 10, 38]);
    let config = "
decoder_version: yolo26
outputs:
 - type: detection
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 38]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_features, 38]
   normalized: true
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Float inputs are not quantized, so the expected box is the exact input.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    let first_frame = decoder
        .decode_tracked_float_proto(
            &mut tracker,
            0,
            &[detect.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // The matched detection must contribute mask coefficients; otherwise the empty check on
    // the next frame would pass vacuously.
    assert!(first_frame.is_some_and(|x| !x.mask_coefficients.is_empty()));

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score (feature 4).
    detect.slice_mut(s![.., .., 4]).fill(0.0);

    // Named distinctly so it does not shadow the `protos` input passed in the same call.
    let second_frame = decoder
        .decode_tracked_float_proto(
            &mut tracker,
            second_frame_timestamp,
            &[detect.view().into_dyn(), protos.view().into_dyn()],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(second_frame.is_some_and(|x| x.mask_coefficients.is_empty()));
}
4439
/// End-to-end tracked decode of a quantized YOLO26 segmentation-detection model whose outputs
/// are split into separate boxes `[1, 10, 4]`, scores `[1, 10, 1]`, classes `[1, 10, 1]`,
/// mask-coefficient `[1, 10, 32]` and protos `[1, 160, 160, 32]` tensors.
///
/// Frame 1 contains a single confident detection and must yield exactly one box and one mask.
/// Frame 2 zeroes every score so nothing passes the threshold: the tracker must still report
/// the predicted box of the existing track, and no mask may be produced for it because it has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask: Array2<f64> = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    // NOTE(review): protos are quantized with zero point 0 while the config below declares
    // zero point 128; presumably harmless here because every mask coefficient is zero.
    let protos_quant = (1.0 / 255.0, 0.0);
    let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    let boxes = boxes.insert_axis(Axis(0));
    let scores = scores.insert_axis(Axis(0));
    let classes = classes.insert_axis(Axis(0));
    let mask = mask.insert_axis(Axis(0));

    // 2 / 255 matches the 0.00784313725490196 scale declared for the boxes, scores, classes
    // and mask_coefficients outputs below.
    let detect_quant = (2.0 / 255.0, 0.0);
    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
    let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());

    let config = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Expected boxes don't match the float values exactly due to quantization error.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    decoder
        .decode_tracked_quantized(
            &mut tracker,
            0,
            &[
                boxes.view().into(),
                scores.view().into(),
                classes.view().into(),
                mask.view().into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
    // The matched detection must carry a mask; otherwise the "no masks" check on the next
    // frame would pass vacuously.
    assert_eq!(output_masks.len(), 1);

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score.
    scores.fill(u8::MIN);

    decoder
        .decode_tracked_quantized(
            &mut tracker,
            second_frame_timestamp,
            &[
                boxes.view().into(),
                scores.view().into(),
                classes.view().into(),
                mask.view().into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(output_masks.is_empty());
}
/// End-to-end tracked decode of a floating-point YOLO26 segmentation-detection model whose
/// outputs are split into separate boxes `[1, 10, 4]`, scores `[1, 10, 1]`, classes
/// `[1, 10, 1]`, mask-coefficient `[1, 10, 32]` and protos `[1, 160, 160, 32]` tensors.
///
/// Frame 1 contains a single confident detection and must yield exactly one box and one mask.
/// Frame 2 zeroes every score so nothing passes the threshold: the tracker must still report
/// the predicted box of the existing track, and no mask may be produced for it because it has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split_float() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask: Array2<f64> = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    let boxes = boxes.insert_axis(Axis(0));
    let mut scores = scores.insert_axis(Axis(0));
    let classes = classes.insert_axis(Axis(0));
    let mask = mask.insert_axis(Axis(0));

    let config = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Float inputs are not quantized, so the expected box is the exact input.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.1234,
            ymin: 0.1234,
            xmax: 0.2345,
            ymax: 0.2345,
        },
        score: 0.9876,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_masks = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    decoder
        .decode_tracked_float(
            &mut tracker,
            0,
            &[
                boxes.view().into_dyn(),
                scores.view().into_dyn(),
                classes.view().into_dyn(),
                mask.view().into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
    // The matched detection must carry a mask; otherwise the "no masks" check on the next
    // frame would pass vacuously.
    assert_eq!(output_masks.len(), 1);

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score.
    scores.fill(0.0);

    decoder
        .decode_tracked_float(
            &mut tracker,
            second_frame_timestamp,
            &[
                boxes.view().into_dyn(),
                scores.view().into_dyn(),
                classes.view().into_dyn(),
                mask.view().into_dyn(),
                protos.view().into_dyn(),
            ],
            &mut output_boxes,
            &mut output_masks,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(output_masks.is_empty());
}
4737
/// End-to-end tracked proto decode of a quantized YOLO26 segmentation-detection model whose
/// outputs are split into separate boxes `[1, 10, 4]`, scores `[1, 10, 1]`, classes
/// `[1, 10, 1]`, mask-coefficient `[1, 10, 32]` and protos `[1, 160, 160, 32]` tensors.
///
/// Frame 1 contains a single confident detection: it must yield exactly one box, and the
/// returned proto data must carry mask coefficients. Frame 2 zeroes every score so nothing
/// passes the threshold: the tracker must still report the predicted box of the existing
/// track, and the returned proto data must carry no mask coefficients because that box has
/// no matching detection.
#[test]
fn test_decoder_tracked_end_to_end_segdet_split_proto() {
    let score_threshold = 0.45;
    let iou_threshold = 0.45;
    // Timestamp passed with the second frame (the first uses 0). Numerically equal to
    // 1_000_000_000 / 30, i.e. one 30 fps frame if timestamps are nanoseconds (TODO confirm).
    let second_frame_timestamp = 100_000_000 / 3;

    let mut boxes = Array2::zeros((10, 4));
    let mut scores = Array2::zeros((10, 1));
    let mut classes = Array2::zeros((10, 1));
    let mask: Array2<f64> = Array2::zeros((10, 32));
    let protos = Array3::<f64>::zeros((160, 160, 32));
    let protos = protos.insert_axis(Axis(0));

    // NOTE(review): protos are quantized with zero point 0 while the config below declares
    // zero point 128; presumably harmless here because every mask coefficient is zero.
    let protos_quant = (1.0 / 255.0, 0.0);
    let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());

    boxes
        .slice_mut(s![0, ..])
        .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
    scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
    classes.slice_mut(s![0, ..]).assign(&array![2.0]);

    let boxes = boxes.insert_axis(Axis(0));
    let scores = scores.insert_axis(Axis(0));
    let classes = classes.insert_axis(Axis(0));
    let mask = mask.insert_axis(Axis(0));

    // 2 / 255 matches the 0.00784313725490196 scale declared for the boxes, scores, classes
    // and mask_coefficients outputs below.
    let detect_quant = (2.0 / 255.0, 0.0);
    let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
    let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
    let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
    let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());

    let config = "
decoder_version: yolo26
outputs:
 - type: boxes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 4]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [box_coords, 4]
   normalized: true
 - type: scores
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: classes
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 1]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_classes, 1]
 - type: mask_coefficients
   decoder: ultralytics
   quantization: [0.00784313725490196, 0]
   shape: [1, 10, 32]
   dshape:
    - [batch, 1]
    - [num_boxes, 10]
    - [num_protos, 32]
 - type: protos
   decoder: ultralytics
   quantization: [0.0039215686274509803921568627451, 128]
   shape: [1, 160, 160, 32]
   dshape:
    - [batch, 1]
    - [height, 160]
    - [width, 160]
    - [num_protos, 32]
";

    let decoder = DecoderBuilder::default()
        .with_config_yaml_str(config.to_string())
        .with_score_threshold(score_threshold)
        .with_iou_threshold(iou_threshold)
        .build()
        .unwrap();

    // Expected boxes don't match the float values exactly due to quantization error.
    let expected_boxes = [DetectBox {
        bbox: BoundingBox {
            xmin: 0.12549022,
            ymin: 0.12549022,
            xmax: 0.23529413,
            ymax: 0.23529413,
        },
        score: 0.98823535,
        label: 2,
    }];

    let mut tracker = ByteTrackBuilder::new()
        .track_update(0.1)
        .track_high_conf(0.7)
        .build();

    let mut output_boxes = Vec::with_capacity(50);
    let mut output_tracks = Vec::with_capacity(50);

    let first_frame = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            0,
            &[
                boxes.view().into(),
                scores.view().into(),
                classes.view().into(),
                mask.view().into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // The matched detection must contribute mask coefficients; otherwise the empty check on
    // the next frame would pass vacuously.
    assert!(first_frame.is_some_and(|x| !x.mask_coefficients.is_empty()));

    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));

    // Give the decoder a final frame with no detections to ensure tracks are properly
    // predicting forward when the detection is missing: zero every score.
    scores.fill(u8::MIN);

    // Named distinctly so it does not shadow the `protos` input passed in the same call.
    let second_frame = decoder
        .decode_tracked_quantized_proto(
            &mut tracker,
            second_frame_timestamp,
            &[
                boxes.view().into(),
                scores.view().into(),
                classes.view().into(),
                mask.view().into(),
                protos.view().into(),
            ],
            &mut output_boxes,
            &mut output_tracks,
        )
        .unwrap();
    // Exactly the one predicted track; a stale or duplicated box must not hide behind the
    // index-0 comparison below.
    assert_eq!(output_boxes.len(), 1);
    assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
    // No masks when the boxes are from tracker prediction without a matching detection.
    assert!(second_frame.is_some_and(|x| x.mask_coefficients.is_empty()));
}
4888
4889    #[test]
4890    fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4891        let score_threshold = 0.45;
4892        let iou_threshold = 0.45;
4893
4894        let mut boxes = Array2::zeros((10, 4));
4895        let mut scores = Array2::zeros((10, 1));
4896        let mut classes = Array2::zeros((10, 1));
4897        let mask: Array2<f64> = Array2::zeros((10, 32));
4898        let protos = Array3::<f64>::zeros((160, 160, 32));
4899        let protos = protos.insert_axis(Axis(0));
4900
4901        boxes
4902            .slice_mut(s![0, ..,])
4903            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4904        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4905        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4906
4907        let boxes = boxes.insert_axis(Axis(0));
4908        let mut scores = scores.insert_axis(Axis(0));
4909        let classes = classes.insert_axis(Axis(0));
4910        let mask = mask.insert_axis(Axis(0));
4911
4912        let config = "
4913decoder_version: yolo26
4914outputs:
4915 - type: boxes
4916   decoder: ultralytics
4917   quantization: [0.00784313725490196, 0]
4918   shape: [1, 10, 4]
4919   dshape:
4920    - [batch, 1]
4921    - [num_boxes, 10]
4922    - [box_coords, 4]
4923   normalized: true
4924 - type: scores
4925   decoder: ultralytics
4926   quantization: [0.00784313725490196, 0]
4927   shape: [1, 10, 1]
4928   dshape:
4929    - [batch, 1]
4930    - [num_boxes, 10]
4931    - [num_classes, 1]
4932 - type: classes
4933   decoder: ultralytics
4934   quantization: [0.00784313725490196, 0]
4935   shape: [1, 10, 1]
4936   dshape:
4937    - [batch, 1]
4938    - [num_boxes, 10]
4939    - [num_classes, 1]
4940 - type: mask_coefficients
4941   decoder: ultralytics
4942   quantization: [0.00784313725490196, 0]
4943   shape: [1, 10, 32]
4944   dshape:
4945    - [batch, 1]
4946    - [num_boxes, 10]
4947    - [num_protos, 32]
4948 - type: protos
4949   decoder: ultralytics
4950   quantization: [0.0039215686274509803921568627451, 128]
4951   shape: [1, 160, 160, 32]
4952   dshape:
4953    - [batch, 1]
4954    - [height, 160]
4955    - [width, 160]
4956    - [num_protos, 32]
4957";
4958
4959        let decoder = DecoderBuilder::default()
4960            .with_config_yaml_str(config.to_string())
4961            .with_score_threshold(score_threshold)
4962            .with_iou_threshold(iou_threshold)
4963            .build()
4964            .unwrap();
4965
4966        // Expected boxes doesn't match the float values exactly due to quantization error
4967        let expected_boxes = [DetectBox {
4968            bbox: BoundingBox {
4969                xmin: 0.1234,
4970                ymin: 0.1234,
4971                xmax: 0.2345,
4972                ymax: 0.2345,
4973            },
4974            score: 0.9876,
4975            label: 2,
4976        }];
4977
4978        let mut tracker = ByteTrackBuilder::new()
4979            .track_update(0.1)
4980            .track_high_conf(0.7)
4981            .build();
4982
4983        let mut output_boxes = Vec::with_capacity(50);
4984        let mut output_tracks = Vec::with_capacity(50);
4985
4986        decoder
4987            .decode_tracked_float_proto(
4988                &mut tracker,
4989                0,
4990                &[
4991                    boxes.view().into_dyn(),
4992                    scores.view().into_dyn(),
4993                    classes.view().into_dyn(),
4994                    mask.view().into_dyn(),
4995                    protos.view().into_dyn(),
4996                ],
4997                &mut output_boxes,
4998                &mut output_tracks,
4999            )
5000            .unwrap();
5001
5002        assert_eq!(output_boxes.len(), 1);
5003        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5004
5005        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5006
5007        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5008            *score = 0.0; // set all scores to minimum to simulate no detections
5009        }
5010
5011        let protos = decoder
5012            .decode_tracked_float_proto(
5013                &mut tracker,
5014                100_000_000 / 3,
5015                &[
5016                    boxes.view().into_dyn(),
5017                    scores.view().into_dyn(),
5018                    classes.view().into_dyn(),
5019                    mask.view().into_dyn(),
5020                    protos.view().into_dyn(),
5021                ],
5022                &mut output_boxes,
5023                &mut output_tracks,
5024            )
5025            .unwrap();
5026        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5027        // no masks when the boxes are from tracker prediction without a matching detection
5028        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5029    }
5030
5031    #[test]
5032    fn test_decoder_tracked_linear_motion() {
5033        use crate::configs::{DecoderType, Nms};
5034        use crate::DecoderBuilder;
5035
5036        let score_threshold = 0.25;
5037        let iou_threshold = 0.1;
5038        let out = include_bytes!(concat!(
5039            env!("CARGO_MANIFEST_DIR"),
5040            "/../../testdata/yolov8s_80_classes.bin"
5041        ));
5042        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5043        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5044        let quant = (0.0040811873, -123).into();
5045
5046        let decoder = DecoderBuilder::default()
5047            .with_config_yolo_det(
5048                crate::configs::Detection {
5049                    decoder: DecoderType::Ultralytics,
5050                    shape: vec![1, 84, 8400],
5051                    anchors: None,
5052                    quantization: Some(quant),
5053                    dshape: vec![
5054                        (crate::configs::DimName::Batch, 1),
5055                        (crate::configs::DimName::NumFeatures, 84),
5056                        (crate::configs::DimName::NumBoxes, 8400),
5057                    ],
5058                    normalized: Some(true),
5059                },
5060                None,
5061            )
5062            .with_score_threshold(score_threshold)
5063            .with_iou_threshold(iou_threshold)
5064            .with_nms(Some(Nms::ClassAgnostic))
5065            .build()
5066            .unwrap();
5067
5068        let mut expected_boxes = [
5069            DetectBox {
5070                bbox: BoundingBox {
5071                    xmin: 0.5285137,
5072                    ymin: 0.05305544,
5073                    xmax: 0.87541467,
5074                    ymax: 0.9998909,
5075                },
5076                score: 0.5591227,
5077                label: 0,
5078            },
5079            DetectBox {
5080                bbox: BoundingBox {
5081                    xmin: 0.130598,
5082                    ymin: 0.43260583,
5083                    xmax: 0.35098213,
5084                    ymax: 0.9958097,
5085                },
5086                score: 0.33057618,
5087                label: 75,
5088            },
5089        ];
5090
5091        let mut tracker = ByteTrackBuilder::new()
5092            .track_update(0.1)
5093            .track_high_conf(0.3)
5094            .build();
5095
5096        let mut output_boxes = Vec::with_capacity(50);
5097        let mut output_masks = Vec::with_capacity(50);
5098        let mut output_tracks = Vec::with_capacity(50);
5099
5100        decoder
5101            .decode_tracked_quantized(
5102                &mut tracker,
5103                0,
5104                &[out.view().into()],
5105                &mut output_boxes,
5106                &mut output_masks,
5107                &mut output_tracks,
5108            )
5109            .unwrap();
5110
5111        assert_eq!(output_boxes.len(), 2);
5112        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5113        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5114
5115        for i in 1..=100 {
5116            let mut out = out.clone();
5117            // introduce linear movement into the XY coordinates
5118            let mut x_values = out.slice_mut(s![0, 0, ..]);
5119            for x in x_values.iter_mut() {
5120                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5121            }
5122
5123            decoder
5124                .decode_tracked_quantized(
5125                    &mut tracker,
5126                    100_000_000 * i / 3, // simulate 33.333ms between frames
5127                    &[out.view().into()],
5128                    &mut output_boxes,
5129                    &mut output_masks,
5130                    &mut output_tracks,
5131                )
5132                .unwrap();
5133
5134            assert_eq!(output_boxes.len(), 2);
5135        }
5136        let tracks = tracker.get_active_tracks();
5137        let predicted_boxes: Vec<_> = tracks
5138            .iter()
5139            .map(|track| {
5140                let mut l = track.last_box;
5141                l.bbox = track.info.tracked_location.into();
5142                l
5143            })
5144            .collect();
5145        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
5146        expected_boxes[0].bbox.xmax += 0.1;
5147        expected_boxes[1].bbox.xmin += 0.1;
5148        expected_boxes[1].bbox.xmax += 0.1;
5149
5150        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5151        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5152
5153        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5154        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5155        for score in scores_values.iter_mut() {
5156            *score = i8::MIN; // set all scores to minimum to simulate no detections
5157        }
5158        decoder
5159            .decode_tracked_quantized(
5160                &mut tracker,
5161                100_000_000 * 101 / 3,
5162                &[out.view().into()],
5163                &mut output_boxes,
5164                &mut output_masks,
5165                &mut output_tracks,
5166            )
5167            .unwrap();
5168        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
5169        expected_boxes[0].bbox.xmax += 0.001;
5170        expected_boxes[1].bbox.xmin += 0.001;
5171        expected_boxes[1].bbox.xmax += 0.001;
5172
5173        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5174        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5175    }
5176
5177    #[test]
5178    fn test_decoder_tracked_end_to_end_float() {
5179        let score_threshold = 0.45;
5180        let iou_threshold = 0.45;
5181
5182        let mut boxes = Array2::zeros((10, 4));
5183        let mut scores = Array2::zeros((10, 1));
5184        let mut classes = Array2::zeros((10, 1));
5185
5186        boxes
5187            .slice_mut(s![0, ..,])
5188            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5189        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5190        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5191
5192        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5193        let mut detect = detect.insert_axis(Axis(0));
5194        assert_eq!(detect.shape(), &[1, 10, 6]);
5195        let config = "
5196decoder_version: yolo26
5197outputs:
5198 - type: detection
5199   decoder: ultralytics
5200   quantization: [0.00784313725490196, 0]
5201   shape: [1, 10, 6]
5202   dshape:
5203    - [batch, 1]
5204    - [num_boxes, 10]
5205    - [num_features, 6]
5206   normalized: true
5207";
5208
5209        let decoder = DecoderBuilder::default()
5210            .with_config_yaml_str(config.to_string())
5211            .with_score_threshold(score_threshold)
5212            .with_iou_threshold(iou_threshold)
5213            .build()
5214            .unwrap();
5215
5216        let expected_boxes = [DetectBox {
5217            bbox: BoundingBox {
5218                xmin: 0.1234,
5219                ymin: 0.1234,
5220                xmax: 0.2345,
5221                ymax: 0.2345,
5222            },
5223            score: 0.9876,
5224            label: 2,
5225        }];
5226
5227        let mut tracker = ByteTrackBuilder::new()
5228            .track_update(0.1)
5229            .track_high_conf(0.7)
5230            .build();
5231
5232        let mut output_boxes = Vec::with_capacity(50);
5233        let mut output_masks = Vec::with_capacity(50);
5234        let mut output_tracks = Vec::with_capacity(50);
5235
5236        decoder
5237            .decode_tracked_float(
5238                &mut tracker,
5239                0,
5240                &[detect.view().into_dyn()],
5241                &mut output_boxes,
5242                &mut output_masks,
5243                &mut output_tracks,
5244            )
5245            .unwrap();
5246
5247        assert_eq!(output_boxes.len(), 1);
5248        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5249
5250        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5251
5252        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5253            *score = 0.0; // set all scores to minimum to simulate no detections
5254        }
5255
5256        decoder
5257            .decode_tracked_float(
5258                &mut tracker,
5259                100_000_000 / 3,
5260                &[detect.view().into_dyn()],
5261                &mut output_boxes,
5262                &mut output_masks,
5263                &mut output_tracks,
5264            )
5265            .unwrap();
5266        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5267    }
5268}