Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust
16# use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs::{self, DecoderVersion} };
17# fn main() -> DecoderResult<()> {
18// Create a decoder for a YOLOv8 model with quantized int8 output with 0.25 score threshold and 0.7 IOU threshold
19let decoder = DecoderBuilder::new()
20    .with_config_yolo_det(configs::Detection {
21        anchors: None,
22        decoder: configs::DecoderType::Ultralytics,
23        quantization: Some(configs::QuantTuple(0.012345, 26)),
24        shape: vec![1, 84, 8400],
25        dshape: Vec::new(),
26        normalized: Some(true),
27    },
28    Some(DecoderVersion::Yolov8))
29    .with_score_threshold(0.25)
30    .with_iou_threshold(0.7)
31    .build()?;
32
33// Get the model output from the model. Here we load it from a test data file for demonstration purposes.
34let model_output: Vec<i8> = include_bytes!(concat!(env!("CARGO_MANIFEST_DIR"), "/../../testdata/yolov8s_80_classes.bin"))
35    .iter()
36    .map(|b| *b as i8)
37    .collect();
38let model_output_array = ndarray::Array3::from_shape_vec((1, 84, 8400), model_output)?;
39
// The capacity is used to determine the maximum number of detections to decode.
41let mut output_boxes: Vec<_> = Vec::with_capacity(10);
42let mut output_masks: Vec<_> = Vec::with_capacity(10);
43
44// Decode the quantized model output into detection boxes and segmentation masks
45// Because this model is a detection-only model, the `output_masks` vector will remain empty.
46decoder.decode_quantized(&[model_output_array.view().into()], &mut output_boxes, &mut output_masks)?;
47# Ok(())
48# }
49```
50
51# Overview
52
53The primary components of this crate are:
54- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
55- `yolo` module: Contains functions specific to decoding YOLO model outputs.
56- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
57
58The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
59It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
60When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
62the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
63to not dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
64
65The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
66which can be used if the model type and output formats are known in advance.
67
68
69*/
70#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Quantization parameters of a tensor.
///
/// The dequantization helpers in this crate map a quantized value `q` to
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// multiplier applied to the zero-point-corrected quantized value
    pub scale: f32,
    /// quantized value that represents real zero
    pub zero_point: i32,
}

impl Quantization {
    /// Builds a `Quantization` from its scale and zero point.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    /// ```
    pub fn new(scale: f32, zero_point: i32) -> Self {
        Quantization { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247    /// Creates a new Quantization struct from a QuantTuple
248    /// # Examples
249    /// ```
250    /// # use edgefirst_decoder::Quantization;
251    /// # use edgefirst_decoder::configs::QuantTuple;
252    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
253    /// let quant = Quantization::from(quant_tuple);
254    /// assert_eq!(quant.scale, 0.1);
255    /// assert_eq!(quant.zero_point, -128);
256    /// ```
257    fn from(quant_tuple: QuantTuple) -> Quantization {
258        Quantization {
259            scale: quant_tuple.0,
260            zero_point: quant_tuple.1,
261        }
262    }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267    S: AsPrimitive<f32>,
268    Z: AsPrimitive<i32>,
269{
270    /// Creates a new Quantization struct from a tuple
271    /// # Examples
272    /// ```
273    /// # use edgefirst_decoder::Quantization;
274    /// let quant = Quantization::from((0.1_f64, -128_i64));
275    /// assert_eq!(quant.scale, 0.1);
276    /// assert_eq!(quant.zero_point, -128);
277    /// ```
278    fn from((scale, zp): (S, Z)) -> Quantization {
279        Self {
280            scale: scale.as_(),
281            zero_point: zp.as_(),
282        }
283    }
284}
285
286impl Default for Quantization {
287    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::Quantization;
291    /// let quant = Quantization::default();
292    /// assert_eq!(quant.scale, 1.0);
293    /// assert_eq!(quant.zero_point, 0);
294    /// ```
295    fn default() -> Self {
296        Self {
297            scale: 1.0,
298            zero_point: 0,
299        }
300    }
301}
302
303/// A detection box with f32 bbox and score
304#[derive(Debug, Clone, Copy, PartialEq, Default)]
305pub struct DetectBox {
306    pub bbox: BoundingBox,
307    /// model-specific score for this detection, higher implies more confidence
308    pub score: f32,
309    /// label index for this detection
310    pub label: usize,
311}
312
/// A bounding box with f32 coordinates in XYXY format
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Builds a `BoundingBox` from its four coordinates, stored as given
    /// (no reordering is performed).
    pub fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Returns a copy of the box whose coordinates are ordered so that
    /// `xmin <= xmax` and `ymin <= ymax`.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let (x_lo, x_hi) = (self.xmin.min(self.xmax), self.xmin.max(self.xmax));
        let (y_lo, y_hi) = (self.ymin.min(self.ymax), self.ymin.max(self.ymax));
        Self::new(x_lo, y_lo, x_hi, y_hi)
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Flattens a `BoundingBox` into `[xmin, ymin, xmax, ymax]`.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        let BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        } = b;
        [xmin, ymin, xmax, ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Builds a `BoundingBox` from an array given in
    /// `[xmin, ymin, xmax, ymax]` order.
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox::new(xmin, ymin, xmax, ymax)
    }
}
391
392impl DetectBox {
393    /// Returns true if one detect box is equal to another detect box, within
394    /// the given `eps`
395    ///
396    /// # Examples
397    /// ```
398    /// # use edgefirst_decoder::DetectBox;
399    /// let box1 = DetectBox {
400    ///     bbox: edgefirst_decoder::BoundingBox {
401    ///         xmin: 0.1,
402    ///         ymin: 0.2,
403    ///         xmax: 0.3,
404    ///         ymax: 0.4,
405    ///     },
406    ///     score: 0.5,
407    ///     label: 1,
408    /// };
409    /// let box2 = DetectBox {
410    ///     bbox: edgefirst_decoder::BoundingBox {
411    ///         xmin: 0.101,
412    ///         ymin: 0.199,
413    ///         xmax: 0.301,
414    ///         ymax: 0.399,
415    ///     },
416    ///     score: 0.510,
417    ///     label: 1,
418    /// };
419    /// assert!(box1.equal_within_delta(&box2, 0.011));
420    /// ```
421    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423        self.label == rhs.label
424            && eq_delta(self.score, rhs.score)
425            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429    }
430}
431
432/// A segmentation result with a segmentation mask, and a normalized bounding
433/// box representing the area that the segmentation mask covers
434#[derive(Debug, Clone, PartialEq, Default)]
435pub struct Segmentation {
436    /// left-most normalized coordinate of the segmentation box
437    pub xmin: f32,
438    /// top-most normalized coordinate of the segmentation box
439    pub ymin: f32,
440    /// right-most normalized coordinate of the segmentation box
441    pub xmax: f32,
442    /// bottom-most normalized coordinate of the segmentation box
443    pub ymax: f32,
444    /// 3D segmentation array of shape `(H, W, C)`.
445    ///
446    /// For instance segmentation (e.g. YOLO): `C=1` — binary per-instance
447    /// mask where values >= 128 indicate object presence.
448    ///
449    /// For semantic segmentation (e.g. ModelPack): `C=num_classes` — per-pixel
450    /// class scores where the object class is the argmax index.
451    pub segmentation: Array3<u8>,
452}
453
454/// Prototype tensor variants for fused decode+render pipelines.
455///
456/// Carries either raw quantized data (to skip CPU dequantization and let the
457/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
458/// paths).
459#[derive(Debug, Clone)]
460pub enum ProtoTensor {
461    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
462    /// The GPU fragment shader will dequantize per-texel using the scale and
463    /// zero_point.
464    Quantized {
465        protos: Array3<i8>,
466        quantization: Quantization,
467    },
468    /// Dequantized f32 protos (from float models or legacy path).
469    Float(Array3<f32>),
470}
471
472impl ProtoTensor {
473    /// Returns `true` if this is the quantized variant.
474    pub fn is_quantized(&self) -> bool {
475        matches!(self, ProtoTensor::Quantized { .. })
476    }
477
478    /// Returns the spatial dimensions `(height, width, num_protos)`.
479    pub fn dim(&self) -> (usize, usize, usize) {
480        match self {
481            ProtoTensor::Quantized { protos, .. } => protos.dim(),
482            ProtoTensor::Float(arr) => arr.dim(),
483        }
484    }
485
486    /// Returns dequantized f32 protos. For the `Float` variant this is a
487    /// no-copy reference; for `Quantized` it allocates and dequantizes.
488    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
489        match self {
490            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
491            ProtoTensor::Quantized {
492                protos,
493                quantization,
494            } => {
495                let scale = quantization.scale;
496                let zp = quantization.zero_point as f32;
497                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
498            }
499        }
500    }
501}
502
503/// Raw prototype data for fused decode+render pipelines.
504///
505/// Holds post-NMS intermediate state before mask materialization, allowing the
506/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
507/// shader) without materializing intermediate `Array3<u8>` masks.
508#[derive(Debug, Clone)]
509pub struct ProtoData {
510    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
511    pub mask_coefficients: Vec<Vec<f32>>,
512    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`.
513    pub protos: ProtoTensor,
514}
515
516/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
517///
518///  # Examples
519/// ```
520/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
521/// let quant = Quantization::new(0.1, -128);
522/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
523/// let detect_quant = DetectBoxQuantized {
524///     bbox,
525///     score: 100_i8,
526///     label: 1,
527/// };
528/// let detect = dequant_detect_box(&detect_quant, quant);
529/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
530/// assert_eq!(detect.label, 1);
531/// assert_eq!(detect.bbox, bbox);
532/// ```
533pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
534    detect: &DetectBoxQuantized<SCORE>,
535    quant_scores: Quantization,
536) -> DetectBox {
537    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
538    DetectBox {
539        bbox: detect.bbox,
540        score: quant_scores.scale * detect.score.as_() + scaled_zp,
541        label: detect.label,
542    }
543}
544/// A detection box with a f32 bbox and quantized score
545#[derive(Debug, Clone, Copy, PartialEq)]
546pub struct DetectBoxQuantized<
547    // BOX: Signed + PrimInt + AsPrimitive<f32>,
548    SCORE: PrimInt + AsPrimitive<f32>,
549> {
550    // pub bbox: BoundingBoxQuantized<BOX>,
551    pub bbox: BoundingBox,
552    /// model-specific score for this detection, higher implies more
553    /// confidence.
554    pub score: SCORE,
555    /// label index for this detect
556    pub label: usize,
557}
558
559/// Dequantizes an ndarray from quantized values to f32 values using the given
560/// quantization parameters
561///
562/// # Examples
563/// ```
564/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
565/// let quant = Quantization::new(0.1, -128);
566/// let input: Vec<i8> = vec![0, 127, -128, 64];
567/// let input_array = ndarray::Array1::from(input);
568/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
569/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
570/// ```
571pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
572    input: ArrayView<T, D>,
573    quant: Quantization,
574) -> Array<F, D>
575where
576    i32: num_traits::AsPrimitive<F>,
577    f32: num_traits::AsPrimitive<F>,
578{
579    let zero_point = quant.zero_point.as_();
580    let scale = quant.scale.as_();
581    if zero_point != F::zero() {
582        let scaled_zero = -zero_point * scale;
583        input.mapv(|d| d.as_() * scale + scaled_zero)
584    } else {
585        input.mapv(|d| d.as_() * scale)
586    }
587}
588
589/// Dequantizes a slice from quantized values to float values using the given
590/// quantization parameters
591///
592/// # Examples
593/// ```
594/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
595/// let quant = Quantization::new(0.1, -128);
596/// let input: Vec<i8> = vec![0, 127, -128, 64];
597/// let mut output: Vec<f32> = vec![0.0; input.len()];
598/// dequantize_cpu(&input, quant, &mut output);
599/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
600/// ```
601pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
602    input: &[T],
603    quant: Quantization,
604    output: &mut [F],
605) where
606    f32: num_traits::AsPrimitive<F>,
607    i32: num_traits::AsPrimitive<F>,
608{
609    assert!(input.len() == output.len());
610    let zero_point = quant.zero_point.as_();
611    let scale = quant.scale.as_();
612    if zero_point != F::zero() {
613        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
614        input
615            .iter()
616            .zip(output)
617            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
618    } else {
619        input
620            .iter()
621            .zip(output)
622            .for_each(|(d, deq)| *deq = d.as_() * scale);
623    }
624}
625
626/// Dequantizes a slice from quantized values to float values using the given
627/// quantization parameters, using chunked processing. This is around 5% faster
628/// than `dequantize_cpu` for large slices.
629///
630/// # Examples
631/// ```
632/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
633/// let quant = Quantization::new(0.1, -128);
634/// let input: Vec<i8> = vec![0, 127, -128, 64];
635/// let mut output: Vec<f32> = vec![0.0; input.len()];
636/// dequantize_cpu_chunked(&input, quant, &mut output);
637/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
638/// ```
639pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
640    input: &[T],
641    quant: Quantization,
642    output: &mut [F],
643) where
644    f32: num_traits::AsPrimitive<F>,
645    i32: num_traits::AsPrimitive<F>,
646{
647    assert!(input.len() == output.len());
648    let zero_point = quant.zero_point.as_();
649    let scale = quant.scale.as_();
650
651    let input = input.as_chunks::<4>();
652    let output = output.as_chunks_mut::<4>();
653
654    if zero_point != F::zero() {
655        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
656
657        input
658            .0
659            .iter()
660            .zip(output.0)
661            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
662        input
663            .1
664            .iter()
665            .zip(output.1)
666            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
667    } else {
668        input
669            .0
670            .iter()
671            .zip(output.0)
672            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
673        input
674            .1
675            .iter()
676            .zip(output.1)
677            .for_each(|(d, deq)| *deq = d.as_() * scale);
678    }
679}
680
681/// Converts a segmentation tensor into a 2D mask
682/// If the last dimension of the segmentation tensor is 1, values equal or
683/// above 128 are considered objects. Otherwise the object is the argmax index
684///
685/// # Errors
686///
687/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
688/// invalid shape.
689///
690/// # Examples
691/// ```
692/// # use edgefirst_decoder::segmentation_to_mask;
693/// let segmentation =
694///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
695/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
696/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
697/// ```
698pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
699    if segmentation.shape()[2] == 0 {
700        return Err(DecoderError::InvalidShape(
701            "Segmentation tensor must have non-zero depth".to_string(),
702        ));
703    }
704    if segmentation.shape()[2] == 1 {
705        yolo_segmentation_to_mask(segmentation, 128)
706    } else {
707        Ok(modelpack_segmentation_to_mask(segmentation))
708    }
709}
710
711/// Returns the maximum value and its index from a 1D array
712fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
713    score
714        .iter()
715        .enumerate()
716        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
717            if max > *s {
718                (max, arg_max)
719            } else {
720                (*s, ind)
721            }
722        })
723}
724#[cfg(test)]
725#[cfg_attr(coverage_nightly, coverage(off))]
726mod decoder_tests {
727    #![allow(clippy::excessive_precision)]
728    use crate::{
729        configs::{DecoderType, DimName, Protos},
730        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
731        yolo::{
732            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
733            decode_yolo_segdet_quant,
734        },
735        *,
736    };
737    use ndarray::{array, s, Array4};
738    use ndarray_stats::DeviationExt;
739
740    fn compare_outputs(
741        boxes: (&[DetectBox], &[DetectBox]),
742        masks: (&[Segmentation], &[Segmentation]),
743    ) {
744        let (boxes0, boxes1) = boxes;
745        let (masks0, masks1) = masks;
746
747        assert_eq!(boxes0.len(), boxes1.len());
748        assert_eq!(masks0.len(), masks1.len());
749
750        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
751            assert!(
752                b_i8.equal_within_delta(b_f32, 1e-6),
753                "{b_i8:?} is not equal to {b_f32:?}"
754            );
755        }
756
757        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
758            assert_eq!(
759                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
760                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
761            );
762            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
763            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
764            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
765            let diff = &mask_i8 - &mask_f32;
766            for x in 0..diff.shape()[0] {
767                for y in 0..diff.shape()[1] {
768                    for z in 0..diff.shape()[2] {
769                        let val = diff[[x, y, z]];
770                        assert!(
771                            val.abs() <= 1,
772                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
773                            x,
774                            y,
775                            z,
776                            val
777                        );
778                    }
779                }
780            }
781            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
782            assert!(
783                mean_sq_err < 1e-2,
784                "Mean Square Error between masks was greater than 1%: {:.2}%",
785                mean_sq_err * 100.0
786            );
787        }
788    }
789
790    #[test]
791    fn test_decoder_modelpack() {
792        let score_threshold = 0.45;
793        let iou_threshold = 0.45;
794        let boxes = include_bytes!(concat!(
795            env!("CARGO_MANIFEST_DIR"),
796            "/../../testdata/modelpack_boxes_1935x1x4.bin"
797        ));
798        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
799
800        let scores = include_bytes!(concat!(
801            env!("CARGO_MANIFEST_DIR"),
802            "/../../testdata/modelpack_scores_1935x1.bin"
803        ));
804        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
805
806        let quant_boxes = (0.004656755365431309, 21).into();
807        let quant_scores = (0.0019603664986789227, 0).into();
808
809        let decoder = DecoderBuilder::default()
810            .with_config_modelpack_det(
811                configs::Boxes {
812                    decoder: DecoderType::ModelPack,
813                    quantization: Some(quant_boxes),
814                    shape: vec![1, 1935, 1, 4],
815                    dshape: vec![
816                        (DimName::Batch, 1),
817                        (DimName::NumBoxes, 1935),
818                        (DimName::Padding, 1),
819                        (DimName::BoxCoords, 4),
820                    ],
821                    normalized: Some(true),
822                },
823                configs::Scores {
824                    decoder: DecoderType::ModelPack,
825                    quantization: Some(quant_scores),
826                    shape: vec![1, 1935, 1],
827                    dshape: vec![
828                        (DimName::Batch, 1),
829                        (DimName::NumBoxes, 1935),
830                        (DimName::NumClasses, 1),
831                    ],
832                },
833            )
834            .with_score_threshold(score_threshold)
835            .with_iou_threshold(iou_threshold)
836            .build()
837            .unwrap();
838
839        let quant_boxes = quant_boxes.into();
840        let quant_scores = quant_scores.into();
841
842        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
843        decode_modelpack_det(
844            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
845            (scores.slice(s![0, .., ..]), quant_scores),
846            score_threshold,
847            iou_threshold,
848            &mut output_boxes,
849        );
850        assert!(output_boxes[0].equal_within_delta(
851            &DetectBox {
852                bbox: BoundingBox {
853                    xmin: 0.40513772,
854                    ymin: 0.6379755,
855                    xmax: 0.5122431,
856                    ymax: 0.7730214,
857                },
858                score: 0.4861709,
859                label: 0
860            },
861            1e-6
862        ));
863
864        let mut output_boxes1 = Vec::with_capacity(50);
865        let mut output_masks1 = Vec::with_capacity(50);
866
867        decoder
868            .decode_quantized(
869                &[boxes.view().into(), scores.view().into()],
870                &mut output_boxes1,
871                &mut output_masks1,
872            )
873            .unwrap();
874
875        let mut output_boxes_float = Vec::with_capacity(50);
876        let mut output_masks_float = Vec::with_capacity(50);
877
878        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
879        let scores = dequantize_ndarray(scores.view(), quant_scores);
880
881        decoder
882            .decode_float::<f32>(
883                &[boxes.view().into_dyn(), scores.view().into_dyn()],
884                &mut output_boxes_float,
885                &mut output_masks_float,
886            )
887            .unwrap();
888
889        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
890        compare_outputs(
891            (&output_boxes, &output_boxes_float),
892            (&[], &output_masks_float),
893        );
894    }
895
896    #[test]
897    fn test_decoder_modelpack_split_u8() {
898        let score_threshold = 0.45;
899        let iou_threshold = 0.45;
900        let detect0 = include_bytes!(concat!(
901            env!("CARGO_MANIFEST_DIR"),
902            "/../../testdata/modelpack_split_9x15x18.bin"
903        ));
904        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
905
906        let detect1 = include_bytes!(concat!(
907            env!("CARGO_MANIFEST_DIR"),
908            "/../../testdata/modelpack_split_17x30x18.bin"
909        ));
910        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
911
912        let quant0 = (0.08547406643629074, 174).into();
913        let quant1 = (0.09929127991199493, 183).into();
914        let anchors0 = vec![
915            [0.36666667461395264, 0.31481480598449707],
916            [0.38749998807907104, 0.4740740656852722],
917            [0.5333333611488342, 0.644444465637207],
918        ];
919        let anchors1 = vec![
920            [0.13750000298023224, 0.2074074000120163],
921            [0.2541666626930237, 0.21481481194496155],
922            [0.23125000298023224, 0.35185185074806213],
923        ];
924
925        let detect_config0 = configs::Detection {
926            decoder: DecoderType::ModelPack,
927            shape: vec![1, 9, 15, 18],
928            anchors: Some(anchors0.clone()),
929            quantization: Some(quant0),
930            dshape: vec![
931                (DimName::Batch, 1),
932                (DimName::Height, 9),
933                (DimName::Width, 15),
934                (DimName::NumAnchorsXFeatures, 18),
935            ],
936            normalized: Some(true),
937        };
938
939        let detect_config1 = configs::Detection {
940            decoder: DecoderType::ModelPack,
941            shape: vec![1, 17, 30, 18],
942            anchors: Some(anchors1.clone()),
943            quantization: Some(quant1),
944            dshape: vec![
945                (DimName::Batch, 1),
946                (DimName::Height, 17),
947                (DimName::Width, 30),
948                (DimName::NumAnchorsXFeatures, 18),
949            ],
950            normalized: Some(true),
951        };
952
953        let config0 = (&detect_config0).try_into().unwrap();
954        let config1 = (&detect_config1).try_into().unwrap();
955
956        let decoder = DecoderBuilder::default()
957            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
958            .with_score_threshold(score_threshold)
959            .with_iou_threshold(iou_threshold)
960            .build()
961            .unwrap();
962
963        let quant0 = quant0.into();
964        let quant1 = quant1.into();
965
966        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
967        decode_modelpack_split_quant(
968            &[
969                detect0.slice(s![0, .., .., ..]),
970                detect1.slice(s![0, .., .., ..]),
971            ],
972            &[config0, config1],
973            score_threshold,
974            iou_threshold,
975            &mut output_boxes,
976        );
977        assert!(output_boxes[0].equal_within_delta(
978            &DetectBox {
979                bbox: BoundingBox {
980                    xmin: 0.43171933,
981                    ymin: 0.68243736,
982                    xmax: 0.5626645,
983                    ymax: 0.808863,
984                },
985                score: 0.99240804,
986                label: 0
987            },
988            1e-6
989        ));
990
991        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
992        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
993        decoder
994            .decode_quantized(
995                &[detect0.view().into(), detect1.view().into()],
996                &mut output_boxes1,
997                &mut output_masks1,
998            )
999            .unwrap();
1000
1001        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
1002        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
1003
1004        let detect0 = dequantize_ndarray(detect0.view(), quant0);
1005        let detect1 = dequantize_ndarray(detect1.view(), quant1);
1006        decoder
1007            .decode_float::<f32>(
1008                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
1009                &mut output_boxes1_f32,
1010                &mut output_masks1_f32,
1011            )
1012            .unwrap();
1013
1014        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1015        compare_outputs(
1016            (&output_boxes, &output_boxes1_f32),
1017            (&[], &output_masks1_f32),
1018        );
1019    }
1020
1021    #[test]
1022    fn test_decoder_parse_config_modelpack_split_u8() {
1023        let score_threshold = 0.45;
1024        let iou_threshold = 0.45;
1025        let detect0 = include_bytes!(concat!(
1026            env!("CARGO_MANIFEST_DIR"),
1027            "/../../testdata/modelpack_split_9x15x18.bin"
1028        ));
1029        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1030
1031        let detect1 = include_bytes!(concat!(
1032            env!("CARGO_MANIFEST_DIR"),
1033            "/../../testdata/modelpack_split_17x30x18.bin"
1034        ));
1035        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1036
1037        let decoder = DecoderBuilder::default()
1038            .with_config_yaml_str(
1039                include_str!(concat!(
1040                    env!("CARGO_MANIFEST_DIR"),
1041                    "/../../testdata/modelpack_split.yaml"
1042                ))
1043                .to_string(),
1044            )
1045            .with_score_threshold(score_threshold)
1046            .with_iou_threshold(iou_threshold)
1047            .build()
1048            .unwrap();
1049
1050        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1051        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1052        decoder
1053            .decode_quantized(
1054                &[
1055                    ArrayViewDQuantized::from(detect1.view()),
1056                    ArrayViewDQuantized::from(detect0.view()),
1057                ],
1058                &mut output_boxes,
1059                &mut output_masks,
1060            )
1061            .unwrap();
1062        assert!(output_boxes[0].equal_within_delta(
1063            &DetectBox {
1064                bbox: BoundingBox {
1065                    xmin: 0.43171933,
1066                    ymin: 0.68243736,
1067                    xmax: 0.5626645,
1068                    ymax: 0.808863,
1069                },
1070                score: 0.99240804,
1071                label: 0
1072            },
1073            1e-6
1074        ));
1075    }
1076
1077    #[test]
1078    fn test_modelpack_seg() {
1079        let out = include_bytes!(concat!(
1080            env!("CARGO_MANIFEST_DIR"),
1081            "/../../testdata/modelpack_seg_2x160x160.bin"
1082        ));
1083        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1084        let quant = (1.0 / 255.0, 0).into();
1085
1086        let decoder = DecoderBuilder::default()
1087            .with_config_modelpack_seg(configs::Segmentation {
1088                decoder: DecoderType::ModelPack,
1089                quantization: Some(quant),
1090                shape: vec![1, 2, 160, 160],
1091                dshape: vec![
1092                    (DimName::Batch, 1),
1093                    (DimName::NumClasses, 2),
1094                    (DimName::Height, 160),
1095                    (DimName::Width, 160),
1096                ],
1097            })
1098            .build()
1099            .unwrap();
1100        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1101        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1102        decoder
1103            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1104            .unwrap();
1105
1106        let mut mask = out.slice(s![0, .., .., ..]);
1107        mask.swap_axes(0, 1);
1108        mask.swap_axes(1, 2);
1109        let mask = [Segmentation {
1110            xmin: 0.0,
1111            ymin: 0.0,
1112            xmax: 1.0,
1113            ymax: 1.0,
1114            segmentation: mask.into_owned(),
1115        }];
1116        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1117
1118        decoder
1119            .decode_float::<f32>(
1120                &[dequantize_ndarray(out.view(), quant.into())
1121                    .view()
1122                    .into_dyn()],
1123                &mut output_boxes,
1124                &mut output_masks,
1125            )
1126            .unwrap();
1127
1128        // not expected for float decoder to have same values as quantized decoder, as
1129        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1130        // the model output. Thus the float output is the same as the quantized output
1131        // but scaled differently. However, it is expected that the mask after argmax
1132        // will be the same.
1133        compare_outputs((&[], &output_boxes), (&[], &[]));
1134        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1135        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1136
1137        assert_eq!(mask0, mask1);
1138    }
1139    #[test]
1140    fn test_modelpack_seg_quant() {
1141        let out = include_bytes!(concat!(
1142            env!("CARGO_MANIFEST_DIR"),
1143            "/../../testdata/modelpack_seg_2x160x160.bin"
1144        ));
1145        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1146        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1147        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1148        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1149        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1150        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1151
1152        let quant = (1.0 / 255.0, 0).into();
1153
1154        let decoder = DecoderBuilder::default()
1155            .with_config_modelpack_seg(configs::Segmentation {
1156                decoder: DecoderType::ModelPack,
1157                quantization: Some(quant),
1158                shape: vec![1, 2, 160, 160],
1159                dshape: vec![
1160                    (DimName::Batch, 1),
1161                    (DimName::NumClasses, 2),
1162                    (DimName::Height, 160),
1163                    (DimName::Width, 160),
1164                ],
1165            })
1166            .build()
1167            .unwrap();
1168        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1169        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1170        decoder
1171            .decode_quantized(
1172                &[out_u8.view().into()],
1173                &mut output_boxes,
1174                &mut output_masks_u8,
1175            )
1176            .unwrap();
1177
1178        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1179        decoder
1180            .decode_quantized(
1181                &[out_i8.view().into()],
1182                &mut output_boxes,
1183                &mut output_masks_i8,
1184            )
1185            .unwrap();
1186
1187        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1188        decoder
1189            .decode_quantized(
1190                &[out_u16.view().into()],
1191                &mut output_boxes,
1192                &mut output_masks_u16,
1193            )
1194            .unwrap();
1195
1196        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1197        decoder
1198            .decode_quantized(
1199                &[out_i16.view().into()],
1200                &mut output_boxes,
1201                &mut output_masks_i16,
1202            )
1203            .unwrap();
1204
1205        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1206        decoder
1207            .decode_quantized(
1208                &[out_u32.view().into()],
1209                &mut output_boxes,
1210                &mut output_masks_u32,
1211            )
1212            .unwrap();
1213
1214        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1215        decoder
1216            .decode_quantized(
1217                &[out_i32.view().into()],
1218                &mut output_boxes,
1219                &mut output_masks_i32,
1220            )
1221            .unwrap();
1222
1223        compare_outputs((&[], &output_boxes), (&[], &[]));
1224        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1225        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1226        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1227        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1228        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1229        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1230        assert_eq!(mask_u8, mask_i8);
1231        assert_eq!(mask_u8, mask_u16);
1232        assert_eq!(mask_u8, mask_i16);
1233        assert_eq!(mask_u8, mask_u32);
1234        assert_eq!(mask_u8, mask_i32);
1235    }
1236
1237    #[test]
1238    fn test_modelpack_segdet() {
1239        let score_threshold = 0.45;
1240        let iou_threshold = 0.45;
1241
1242        let boxes = include_bytes!(concat!(
1243            env!("CARGO_MANIFEST_DIR"),
1244            "/../../testdata/modelpack_boxes_1935x1x4.bin"
1245        ));
1246        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1247
1248        let scores = include_bytes!(concat!(
1249            env!("CARGO_MANIFEST_DIR"),
1250            "/../../testdata/modelpack_scores_1935x1.bin"
1251        ));
1252        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1253
1254        let seg = include_bytes!(concat!(
1255            env!("CARGO_MANIFEST_DIR"),
1256            "/../../testdata/modelpack_seg_2x160x160.bin"
1257        ));
1258        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1259
1260        let quant_boxes = (0.004656755365431309, 21).into();
1261        let quant_scores = (0.0019603664986789227, 0).into();
1262        let quant_seg = (1.0 / 255.0, 0).into();
1263
1264        let decoder = DecoderBuilder::default()
1265            .with_config_modelpack_segdet(
1266                configs::Boxes {
1267                    decoder: DecoderType::ModelPack,
1268                    quantization: Some(quant_boxes),
1269                    shape: vec![1, 1935, 1, 4],
1270                    dshape: vec![
1271                        (DimName::Batch, 1),
1272                        (DimName::NumBoxes, 1935),
1273                        (DimName::Padding, 1),
1274                        (DimName::BoxCoords, 4),
1275                    ],
1276                    normalized: Some(true),
1277                },
1278                configs::Scores {
1279                    decoder: DecoderType::ModelPack,
1280                    quantization: Some(quant_scores),
1281                    shape: vec![1, 1935, 1],
1282                    dshape: vec![
1283                        (DimName::Batch, 1),
1284                        (DimName::NumBoxes, 1935),
1285                        (DimName::NumClasses, 1),
1286                    ],
1287                },
1288                configs::Segmentation {
1289                    decoder: DecoderType::ModelPack,
1290                    quantization: Some(quant_seg),
1291                    shape: vec![1, 2, 160, 160],
1292                    dshape: vec![
1293                        (DimName::Batch, 1),
1294                        (DimName::NumClasses, 2),
1295                        (DimName::Height, 160),
1296                        (DimName::Width, 160),
1297                    ],
1298                },
1299            )
1300            .with_iou_threshold(iou_threshold)
1301            .with_score_threshold(score_threshold)
1302            .build()
1303            .unwrap();
1304        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1305        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1306        decoder
1307            .decode_quantized(
1308                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1309                &mut output_boxes,
1310                &mut output_masks,
1311            )
1312            .unwrap();
1313
1314        let mut mask = seg.slice(s![0, .., .., ..]);
1315        mask.swap_axes(0, 1);
1316        mask.swap_axes(1, 2);
1317        let mask = [Segmentation {
1318            xmin: 0.0,
1319            ymin: 0.0,
1320            xmax: 1.0,
1321            ymax: 1.0,
1322            segmentation: mask.into_owned(),
1323        }];
1324        let correct_boxes = [DetectBox {
1325            bbox: BoundingBox {
1326                xmin: 0.40513772,
1327                ymin: 0.6379755,
1328                xmax: 0.5122431,
1329                ymax: 0.7730214,
1330            },
1331            score: 0.4861709,
1332            label: 0,
1333        }];
1334        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1335
1336        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1337        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1338        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1339        decoder
1340            .decode_float::<f32>(
1341                &[
1342                    scores.view().into_dyn(),
1343                    boxes.view().into_dyn(),
1344                    seg.view().into_dyn(),
1345                ],
1346                &mut output_boxes,
1347                &mut output_masks,
1348            )
1349            .unwrap();
1350
1351        // not expected for float segmentation decoder to have same values as quantized
1352        // segmentation decoder, as float decoder ensures the data fills 0-255,
1353        // quantized decoder uses whatever the model output. Thus the float
1354        // output is the same as the quantized output but scaled differently.
1355        // However, it is expected that the mask after argmax will be the same.
1356        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1357        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1358        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1359
1360        assert_eq!(mask0, mask1);
1361    }
1362
1363    #[test]
1364    fn test_modelpack_segdet_split() {
1365        let score_threshold = 0.8;
1366        let iou_threshold = 0.5;
1367
1368        let seg = include_bytes!(concat!(
1369            env!("CARGO_MANIFEST_DIR"),
1370            "/../../testdata/modelpack_seg_2x160x160.bin"
1371        ));
1372        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1373
1374        let detect0 = include_bytes!(concat!(
1375            env!("CARGO_MANIFEST_DIR"),
1376            "/../../testdata/modelpack_split_9x15x18.bin"
1377        ));
1378        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1379
1380        let detect1 = include_bytes!(concat!(
1381            env!("CARGO_MANIFEST_DIR"),
1382            "/../../testdata/modelpack_split_17x30x18.bin"
1383        ));
1384        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1385
1386        let quant0 = (0.08547406643629074, 174).into();
1387        let quant1 = (0.09929127991199493, 183).into();
1388        let quant_seg = (1.0 / 255.0, 0).into();
1389
1390        let anchors0 = vec![
1391            [0.36666667461395264, 0.31481480598449707],
1392            [0.38749998807907104, 0.4740740656852722],
1393            [0.5333333611488342, 0.644444465637207],
1394        ];
1395        let anchors1 = vec![
1396            [0.13750000298023224, 0.2074074000120163],
1397            [0.2541666626930237, 0.21481481194496155],
1398            [0.23125000298023224, 0.35185185074806213],
1399        ];
1400
1401        let decoder = DecoderBuilder::default()
1402            .with_config_modelpack_segdet_split(
1403                vec![
1404                    configs::Detection {
1405                        decoder: DecoderType::ModelPack,
1406                        shape: vec![1, 17, 30, 18],
1407                        anchors: Some(anchors1),
1408                        quantization: Some(quant1),
1409                        dshape: vec![
1410                            (DimName::Batch, 1),
1411                            (DimName::Height, 17),
1412                            (DimName::Width, 30),
1413                            (DimName::NumAnchorsXFeatures, 18),
1414                        ],
1415                        normalized: Some(true),
1416                    },
1417                    configs::Detection {
1418                        decoder: DecoderType::ModelPack,
1419                        shape: vec![1, 9, 15, 18],
1420                        anchors: Some(anchors0),
1421                        quantization: Some(quant0),
1422                        dshape: vec![
1423                            (DimName::Batch, 1),
1424                            (DimName::Height, 9),
1425                            (DimName::Width, 15),
1426                            (DimName::NumAnchorsXFeatures, 18),
1427                        ],
1428                        normalized: Some(true),
1429                    },
1430                ],
1431                configs::Segmentation {
1432                    decoder: DecoderType::ModelPack,
1433                    quantization: Some(quant_seg),
1434                    shape: vec![1, 2, 160, 160],
1435                    dshape: vec![
1436                        (DimName::Batch, 1),
1437                        (DimName::NumClasses, 2),
1438                        (DimName::Height, 160),
1439                        (DimName::Width, 160),
1440                    ],
1441                },
1442            )
1443            .with_score_threshold(score_threshold)
1444            .with_iou_threshold(iou_threshold)
1445            .build()
1446            .unwrap();
1447        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1448        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1449        decoder
1450            .decode_quantized(
1451                &[
1452                    detect0.view().into(),
1453                    detect1.view().into(),
1454                    seg.view().into(),
1455                ],
1456                &mut output_boxes,
1457                &mut output_masks,
1458            )
1459            .unwrap();
1460
1461        let mut mask = seg.slice(s![0, .., .., ..]);
1462        mask.swap_axes(0, 1);
1463        mask.swap_axes(1, 2);
1464        let mask = [Segmentation {
1465            xmin: 0.0,
1466            ymin: 0.0,
1467            xmax: 1.0,
1468            ymax: 1.0,
1469            segmentation: mask.into_owned(),
1470        }];
1471        let correct_boxes = [DetectBox {
1472            bbox: BoundingBox {
1473                xmin: 0.43171933,
1474                ymin: 0.68243736,
1475                xmax: 0.5626645,
1476                ymax: 0.808863,
1477            },
1478            score: 0.99240804,
1479            label: 0,
1480        }];
1481        println!("Output Boxes: {:?}", output_boxes);
1482        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1483
1484        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1485        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1486        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1487        decoder
1488            .decode_float::<f32>(
1489                &[
1490                    detect0.view().into_dyn(),
1491                    detect1.view().into_dyn(),
1492                    seg.view().into_dyn(),
1493                ],
1494                &mut output_boxes,
1495                &mut output_masks,
1496            )
1497            .unwrap();
1498
1499        // not expected for float segmentation decoder to have same values as quantized
1500        // segmentation decoder, as float decoder ensures the data fills 0-255,
1501        // quantized decoder uses whatever the model output. Thus the float
1502        // output is the same as the quantized output but scaled differently.
1503        // However, it is expected that the mask after argmax will be the same.
1504        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1505        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1506        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1507
1508        assert_eq!(mask0, mask1);
1509    }
1510
1511    #[test]
1512    fn test_dequant_chunked() {
1513        let out = include_bytes!(concat!(
1514            env!("CARGO_MANIFEST_DIR"),
1515            "/../../testdata/yolov8s_80_classes.bin"
1516        ));
1517        let mut out =
1518            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1519        out.push(123); // make sure to test non multiple of 16 length
1520
1521        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1522        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1523        let quant = Quantization::new(0.0040811873, -123);
1524        dequantize_cpu(&out, quant, &mut out_dequant);
1525
1526        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1527        assert_eq!(out_dequant, out_dequant_simd);
1528
1529        let quant = Quantization::new(0.0040811873, 0);
1530        dequantize_cpu(&out, quant, &mut out_dequant);
1531
1532        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1533        assert_eq!(out_dequant, out_dequant_simd);
1534    }
1535
1536    #[test]
1537    fn test_decoder_yolo_det() {
1538        let score_threshold = 0.25;
1539        let iou_threshold = 0.7;
1540        let out = include_bytes!(concat!(
1541            env!("CARGO_MANIFEST_DIR"),
1542            "/../../testdata/yolov8s_80_classes.bin"
1543        ));
1544        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1545        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1546        let quant = (0.0040811873, -123).into();
1547
1548        let decoder = DecoderBuilder::default()
1549            .with_config_yolo_det(
1550                configs::Detection {
1551                    decoder: DecoderType::Ultralytics,
1552                    shape: vec![1, 84, 8400],
1553                    anchors: None,
1554                    quantization: Some(quant),
1555                    dshape: vec![
1556                        (DimName::Batch, 1),
1557                        (DimName::NumFeatures, 84),
1558                        (DimName::NumBoxes, 8400),
1559                    ],
1560                    normalized: Some(true),
1561                },
1562                Some(DecoderVersion::Yolo11),
1563            )
1564            .with_score_threshold(score_threshold)
1565            .with_iou_threshold(iou_threshold)
1566            .build()
1567            .unwrap();
1568
1569        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1570        decode_yolo_det(
1571            (out.slice(s![0, .., ..]), quant.into()),
1572            score_threshold,
1573            iou_threshold,
1574            Some(configs::Nms::ClassAgnostic),
1575            &mut output_boxes,
1576        );
1577        assert!(output_boxes[0].equal_within_delta(
1578            &DetectBox {
1579                bbox: BoundingBox {
1580                    xmin: 0.5285137,
1581                    ymin: 0.05305544,
1582                    xmax: 0.87541467,
1583                    ymax: 0.9998909,
1584                },
1585                score: 0.5591227,
1586                label: 0
1587            },
1588            1e-6
1589        ));
1590
1591        assert!(output_boxes[1].equal_within_delta(
1592            &DetectBox {
1593                bbox: BoundingBox {
1594                    xmin: 0.130598,
1595                    ymin: 0.43260583,
1596                    xmax: 0.35098213,
1597                    ymax: 0.9958097,
1598                },
1599                score: 0.33057618,
1600                label: 75
1601            },
1602            1e-6
1603        ));
1604
1605        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1606        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1607        decoder
1608            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1609            .unwrap();
1610
1611        let out = dequantize_ndarray(out.view(), quant.into());
1612        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1613        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1614        decoder
1615            .decode_float::<f32>(
1616                &[out.view().into_dyn()],
1617                &mut output_boxes_f32,
1618                &mut output_masks_f32,
1619            )
1620            .unwrap();
1621
1622        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1623        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1624    }
1625
1626    #[test]
1627    fn test_decoder_masks() {
1628        let score_threshold = 0.45;
1629        let iou_threshold = 0.45;
1630        let boxes = include_bytes!(concat!(
1631            env!("CARGO_MANIFEST_DIR"),
1632            "/../../testdata/yolov8_boxes_116x8400.bin"
1633        ));
1634        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1635        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1636        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1637
1638        let protos = include_bytes!(concat!(
1639            env!("CARGO_MANIFEST_DIR"),
1640            "/../../testdata/yolov8_protos_160x160x32.bin"
1641        ));
1642        let protos =
1643            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1644        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1645        let quant_protos = Quantization::new(0.02491161972284317, -117);
1646        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1647        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1648        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1649        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1650        decode_yolo_segdet_float(
1651            seg.view(),
1652            protos.view(),
1653            score_threshold,
1654            iou_threshold,
1655            Some(configs::Nms::ClassAgnostic),
1656            &mut output_boxes,
1657            &mut output_masks,
1658        )
1659        .unwrap();
1660        assert_eq!(output_boxes.len(), 2);
1661        assert_eq!(output_boxes.len(), output_masks.len());
1662
1663        for (b, m) in output_boxes.iter().zip(&output_masks) {
1664            assert!(b.bbox.xmin >= m.xmin);
1665            assert!(b.bbox.ymin >= m.ymin);
1666            assert!(b.bbox.xmax >= m.xmax);
1667            assert!(b.bbox.ymax >= m.ymax);
1668        }
1669        assert!(output_boxes[0].equal_within_delta(
1670            &DetectBox {
1671                bbox: BoundingBox {
1672                    xmin: 0.08515105,
1673                    ymin: 0.7131401,
1674                    xmax: 0.29802868,
1675                    ymax: 0.8195788,
1676                },
1677                score: 0.91537374,
1678                label: 23
1679            },
1680            1.0 / 160.0, // wider range because mask will expand the box
1681        ));
1682
1683        assert!(output_boxes[1].equal_within_delta(
1684            &DetectBox {
1685                bbox: BoundingBox {
1686                    xmin: 0.59605736,
1687                    ymin: 0.25545314,
1688                    xmax: 0.93666154,
1689                    ymax: 0.72378385,
1690                },
1691                score: 0.91537374,
1692                label: 23
1693            },
1694            1.0 / 160.0, // wider range because mask will expand the box
1695        ));
1696
1697        let full_mask = include_bytes!(concat!(
1698            env!("CARGO_MANIFEST_DIR"),
1699            "/../../testdata/yolov8_mask_results.bin"
1700        ));
1701        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1702
1703        let cropped_mask = full_mask.slice(ndarray::s![
1704            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1705            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1706        ]);
1707
1708        assert_eq!(
1709            cropped_mask,
1710            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1711        );
1712    }
1713
1714    /// Regression test: config-driven path with NCHW protos (no dshape).
1715    /// Simulates YOLOv8-seg ONNX outputs where protos are (1, 32, 160, 160)
1716    /// and the YAML config has no dshape field — the exact scenario from
1717    /// hal_mask_matmul_bug.md.
1718    #[test]
1719    fn test_decoder_masks_nchw_protos() {
1720        let score_threshold = 0.45;
1721        let iou_threshold = 0.45;
1722
1723        // Load test data — boxes as [116, 8400]
1724        let boxes_raw = include_bytes!(concat!(
1725            env!("CARGO_MANIFEST_DIR"),
1726            "/../../testdata/yolov8_boxes_116x8400.bin"
1727        ));
1728        let boxes_raw =
1729            unsafe { std::slice::from_raw_parts(boxes_raw.as_ptr() as *const i8, boxes_raw.len()) };
1730        let boxes_2d = ndarray::Array2::from_shape_vec((116, 8400), boxes_raw.to_vec()).unwrap();
1731        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1732
1733        // Load protos as HWC [160, 160, 32] (file layout) then dequantize
1734        let protos_raw = include_bytes!(concat!(
1735            env!("CARGO_MANIFEST_DIR"),
1736            "/../../testdata/yolov8_protos_160x160x32.bin"
1737        ));
1738        let protos_raw = unsafe {
1739            std::slice::from_raw_parts(protos_raw.as_ptr() as *const i8, protos_raw.len())
1740        };
1741        let protos_hwc =
1742            ndarray::Array3::from_shape_vec((160, 160, 32), protos_raw.to_vec()).unwrap();
1743        let quant_protos = Quantization::new(0.02491161972284317, -117);
1744        let protos_f32_hwc = dequantize_ndarray::<_, _, f32>(protos_hwc.view(), quant_protos);
1745
1746        // ---- Reference: direct call with HWC protos (known working) ----
1747        let seg = dequantize_ndarray::<_, _, f32>(boxes_2d.view(), quant_boxes);
1748        let mut ref_boxes: Vec<_> = Vec::with_capacity(10);
1749        let mut ref_masks: Vec<_> = Vec::with_capacity(10);
1750        decode_yolo_segdet_float(
1751            seg.view(),
1752            protos_f32_hwc.view(),
1753            score_threshold,
1754            iou_threshold,
1755            Some(configs::Nms::ClassAgnostic),
1756            &mut ref_boxes,
1757            &mut ref_masks,
1758        )
1759        .unwrap();
1760        assert_eq!(ref_boxes.len(), 2);
1761
1762        // ---- Config-driven path: NCHW protos, no dshape ----
1763        // Permute protos to NCHW [1, 32, 160, 160] as an ONNX model would output
1764        let protos_f32_chw = protos_f32_hwc.permuted_axes([2, 0, 1]); // [32, 160, 160]
1765        let protos_nchw = protos_f32_chw.insert_axis(ndarray::Axis(0)); // [1, 32, 160, 160]
1766
1767        // Build boxes as [1, 116, 8400] f32
1768        let seg_3d = seg.insert_axis(ndarray::Axis(0)); // [1, 116, 8400]
1769
1770        // Build decoder from config with no dshape on protos
1771        let decoder = DecoderBuilder::default()
1772            .with_config_yolo_segdet(
1773                configs::Detection {
1774                    decoder: configs::DecoderType::Ultralytics,
1775                    quantization: None,
1776                    shape: vec![1, 116, 8400],
1777                    dshape: vec![],
1778                    normalized: Some(true),
1779                    anchors: None,
1780                },
1781                configs::Protos {
1782                    decoder: configs::DecoderType::Ultralytics,
1783                    quantization: None,
1784                    shape: vec![1, 32, 160, 160],
1785                    dshape: vec![], // No dshape — simulates YAML without dshape
1786                },
1787                None, // decoder version
1788            )
1789            .with_score_threshold(score_threshold)
1790            .with_iou_threshold(iou_threshold)
1791            .build()
1792            .unwrap();
1793
1794        let mut cfg_boxes: Vec<_> = Vec::with_capacity(10);
1795        let mut cfg_masks: Vec<_> = Vec::with_capacity(10);
1796        decoder
1797            .decode_float(
1798                &[seg_3d.view().into_dyn(), protos_nchw.view().into_dyn()],
1799                &mut cfg_boxes,
1800                &mut cfg_masks,
1801            )
1802            .unwrap();
1803
1804        // Must produce the same number of detections
1805        assert_eq!(
1806            cfg_boxes.len(),
1807            ref_boxes.len(),
1808            "config path produced {} boxes, reference produced {}",
1809            cfg_boxes.len(),
1810            ref_boxes.len()
1811        );
1812
1813        // Boxes must match
1814        for (i, (cb, rb)) in cfg_boxes.iter().zip(&ref_boxes).enumerate() {
1815            assert!(
1816                cb.equal_within_delta(rb, 0.01),
1817                "box {i} mismatch: config={cb:?}, reference={rb:?}"
1818            );
1819        }
1820
1821        // Masks must match pixel-for-pixel
1822        for (i, (cm, rm)) in cfg_masks.iter().zip(&ref_masks).enumerate() {
1823            let cm_arr = segmentation_to_mask(cm.segmentation.view()).unwrap();
1824            let rm_arr = segmentation_to_mask(rm.segmentation.view()).unwrap();
1825            assert_eq!(
1826                cm_arr, rm_arr,
1827                "mask {i} pixel mismatch between config-driven and reference paths"
1828            );
1829        }
1830    }
1831
1832    #[test]
1833    fn test_decoder_masks_i8() {
1834        let score_threshold = 0.45;
1835        let iou_threshold = 0.45;
1836        let boxes = include_bytes!(concat!(
1837            env!("CARGO_MANIFEST_DIR"),
1838            "/../../testdata/yolov8_boxes_116x8400.bin"
1839        ));
1840        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1841        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
1842        let quant_boxes = (0.021287761628627777, 31).into();
1843
1844        let protos = include_bytes!(concat!(
1845            env!("CARGO_MANIFEST_DIR"),
1846            "/../../testdata/yolov8_protos_160x160x32.bin"
1847        ));
1848        let protos =
1849            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1850        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
1851        let quant_protos = (0.02491161972284317, -117).into();
1852        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1853        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1854
1855        let decoder = DecoderBuilder::default()
1856            .with_config_yolo_segdet(
1857                configs::Detection {
1858                    decoder: configs::DecoderType::Ultralytics,
1859                    quantization: Some(quant_boxes),
1860                    shape: vec![1, 116, 8400],
1861                    anchors: None,
1862                    dshape: vec![
1863                        (DimName::Batch, 1),
1864                        (DimName::NumFeatures, 116),
1865                        (DimName::NumBoxes, 8400),
1866                    ],
1867                    normalized: Some(true),
1868                },
1869                Protos {
1870                    decoder: configs::DecoderType::Ultralytics,
1871                    quantization: Some(quant_protos),
1872                    shape: vec![1, 160, 160, 32],
1873                    dshape: vec![
1874                        (DimName::Batch, 1),
1875                        (DimName::Height, 160),
1876                        (DimName::Width, 160),
1877                        (DimName::NumProtos, 32),
1878                    ],
1879                },
1880                Some(DecoderVersion::Yolo11),
1881            )
1882            .with_score_threshold(score_threshold)
1883            .with_iou_threshold(iou_threshold)
1884            .build()
1885            .unwrap();
1886
1887        let quant_boxes = quant_boxes.into();
1888        let quant_protos = quant_protos.into();
1889
1890        decode_yolo_segdet_quant(
1891            (boxes.slice(s![0, .., ..]), quant_boxes),
1892            (protos.slice(s![0, .., .., ..]), quant_protos),
1893            score_threshold,
1894            iou_threshold,
1895            Some(configs::Nms::ClassAgnostic),
1896            &mut output_boxes,
1897            &mut output_masks,
1898        )
1899        .unwrap();
1900
1901        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1902        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1903
1904        decoder
1905            .decode_quantized(
1906                &[boxes.view().into(), protos.view().into()],
1907                &mut output_boxes1,
1908                &mut output_masks1,
1909            )
1910            .unwrap();
1911
1912        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1913        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1914
1915        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1916        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1917        decode_yolo_segdet_float(
1918            seg.slice(s![0, .., ..]),
1919            protos.slice(s![0, .., .., ..]),
1920            score_threshold,
1921            iou_threshold,
1922            Some(configs::Nms::ClassAgnostic),
1923            &mut output_boxes_f32,
1924            &mut output_masks_f32,
1925        )
1926        .unwrap();
1927
1928        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1929        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1930
1931        decoder
1932            .decode_float(
1933                &[seg.view().into_dyn(), protos.view().into_dyn()],
1934                &mut output_boxes1_f32,
1935                &mut output_masks1_f32,
1936            )
1937            .unwrap();
1938
1939        compare_outputs(
1940            (&output_boxes, &output_boxes1),
1941            (&output_masks, &output_masks1),
1942        );
1943
1944        compare_outputs(
1945            (&output_boxes, &output_boxes_f32),
1946            (&output_masks, &output_masks_f32),
1947        );
1948
1949        compare_outputs(
1950            (&output_boxes_f32, &output_boxes1_f32),
1951            (&output_masks_f32, &output_masks1_f32),
1952        );
1953    }
1954
1955    #[test]
1956    fn test_decoder_yolo_split() {
1957        let score_threshold = 0.45;
1958        let iou_threshold = 0.45;
1959        let boxes = include_bytes!(concat!(
1960            env!("CARGO_MANIFEST_DIR"),
1961            "/../../testdata/yolov8_boxes_116x8400.bin"
1962        ));
1963        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1964        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
1965        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1966
1967        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1968
1969        let decoder = DecoderBuilder::default()
1970            .with_config_yolo_split_det(
1971                configs::Boxes {
1972                    decoder: configs::DecoderType::Ultralytics,
1973                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1974                    shape: vec![1, 4, 8400],
1975                    dshape: vec![
1976                        (DimName::Batch, 1),
1977                        (DimName::BoxCoords, 4),
1978                        (DimName::NumBoxes, 8400),
1979                    ],
1980                    normalized: Some(true),
1981                },
1982                configs::Scores {
1983                    decoder: configs::DecoderType::Ultralytics,
1984                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
1985                    shape: vec![1, 80, 8400],
1986                    dshape: vec![
1987                        (DimName::Batch, 1),
1988                        (DimName::NumClasses, 80),
1989                        (DimName::NumBoxes, 8400),
1990                    ],
1991                },
1992            )
1993            .with_score_threshold(score_threshold)
1994            .with_iou_threshold(iou_threshold)
1995            .build()
1996            .unwrap();
1997
1998        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1999        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2000
2001        decoder
2002            .decode_quantized(
2003                &[
2004                    boxes.slice(s![.., ..4, ..]).into(),
2005                    boxes.slice(s![.., 4..84, ..]).into(),
2006                ],
2007                &mut output_boxes,
2008                &mut output_masks,
2009            )
2010            .unwrap();
2011
2012        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2013        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2014        decode_yolo_det_float(
2015            seg.slice(s![0, ..84, ..]),
2016            score_threshold,
2017            iou_threshold,
2018            Some(configs::Nms::ClassAgnostic),
2019            &mut output_boxes_f32,
2020        );
2021
2022        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2023        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2024
2025        decoder
2026            .decode_float(
2027                &[
2028                    seg.slice(s![.., ..4, ..]).into_dyn(),
2029                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2030                ],
2031                &mut output_boxes1,
2032                &mut output_masks1,
2033            )
2034            .unwrap();
2035        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
2036        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
2037    }
2038
2039    #[test]
2040    fn test_decoder_masks_config_mixed() {
2041        let score_threshold = 0.45;
2042        let iou_threshold = 0.45;
2043        let boxes = include_bytes!(concat!(
2044            env!("CARGO_MANIFEST_DIR"),
2045            "/../../testdata/yolov8_boxes_116x8400.bin"
2046        ));
2047        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2048        let boxes: Vec<_> = boxes.iter().map(|x| *x as i16 * 256).collect();
2049        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2050
2051        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
2052
2053        let protos = include_bytes!(concat!(
2054            env!("CARGO_MANIFEST_DIR"),
2055            "/../../testdata/yolov8_protos_160x160x32.bin"
2056        ));
2057        let protos =
2058            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2059        let protos: Vec<_> = protos.to_vec();
2060        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2061        let quant_protos = Quantization::new(0.02491161972284317, -117);
2062
2063        let decoder = DecoderBuilder::default()
2064            .with_config_yolo_split_segdet(
2065                configs::Boxes {
2066                    decoder: configs::DecoderType::Ultralytics,
2067                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2068                    shape: vec![1, 4, 8400],
2069                    dshape: vec![
2070                        (DimName::Batch, 1),
2071                        (DimName::BoxCoords, 4),
2072                        (DimName::NumBoxes, 8400),
2073                    ],
2074                    normalized: Some(true),
2075                },
2076                configs::Scores {
2077                    decoder: configs::DecoderType::Ultralytics,
2078                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2079                    shape: vec![1, 80, 8400],
2080                    dshape: vec![
2081                        (DimName::Batch, 1),
2082                        (DimName::NumClasses, 80),
2083                        (DimName::NumBoxes, 8400),
2084                    ],
2085                },
2086                configs::MaskCoefficients {
2087                    decoder: configs::DecoderType::Ultralytics,
2088                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2089                    shape: vec![1, 32, 8400],
2090                    dshape: vec![
2091                        (DimName::Batch, 1),
2092                        (DimName::NumProtos, 32),
2093                        (DimName::NumBoxes, 8400),
2094                    ],
2095                },
2096                configs::Protos {
2097                    decoder: configs::DecoderType::Ultralytics,
2098                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2099                    shape: vec![1, 160, 160, 32],
2100                    dshape: vec![
2101                        (DimName::Batch, 1),
2102                        (DimName::Height, 160),
2103                        (DimName::Width, 160),
2104                        (DimName::NumProtos, 32),
2105                    ],
2106                },
2107            )
2108            .with_score_threshold(score_threshold)
2109            .with_iou_threshold(iou_threshold)
2110            .build()
2111            .unwrap();
2112
2113        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2114        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2115
2116        decoder
2117            .decode_quantized(
2118                &[
2119                    boxes.slice(s![.., ..4, ..]).into(),
2120                    boxes.slice(s![.., 4..84, ..]).into(),
2121                    boxes.slice(s![.., 84.., ..]).into(),
2122                    protos.view().into(),
2123                ],
2124                &mut output_boxes,
2125                &mut output_masks,
2126            )
2127            .unwrap();
2128
2129        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2130        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2131        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2132        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
2133        decode_yolo_segdet_float(
2134            seg.slice(s![0, .., ..]),
2135            protos.slice(s![0, .., .., ..]),
2136            score_threshold,
2137            iou_threshold,
2138            Some(configs::Nms::ClassAgnostic),
2139            &mut output_boxes_f32,
2140            &mut output_masks_f32,
2141        )
2142        .unwrap();
2143
2144        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
2145        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
2146
2147        decoder
2148            .decode_float(
2149                &[
2150                    seg.slice(s![.., ..4, ..]).into_dyn(),
2151                    seg.slice(s![.., 4..84, ..]).into_dyn(),
2152                    seg.slice(s![.., 84.., ..]).into_dyn(),
2153                    protos.view().into_dyn(),
2154                ],
2155                &mut output_boxes1,
2156                &mut output_masks1,
2157            )
2158            .unwrap();
2159        compare_outputs(
2160            (&output_boxes, &output_boxes_f32),
2161            (&output_masks, &output_masks_f32),
2162        );
2163        compare_outputs(
2164            (&output_boxes_f32, &output_boxes1),
2165            (&output_masks_f32, &output_masks1),
2166        );
2167    }
2168
2169    #[test]
2170    fn test_decoder_masks_config_i32() {
2171        let score_threshold = 0.45;
2172        let iou_threshold = 0.45;
2173        let boxes = include_bytes!(concat!(
2174            env!("CARGO_MANIFEST_DIR"),
2175            "/../../testdata/yolov8_boxes_116x8400.bin"
2176        ));
2177        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2178        let scale = 1 << 23;
2179        let boxes: Vec<_> = boxes.iter().map(|x| *x as i32 * scale).collect();
2180        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
2181
2182        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
2183
2184        let protos = include_bytes!(concat!(
2185            env!("CARGO_MANIFEST_DIR"),
2186            "/../../testdata/yolov8_protos_160x160x32.bin"
2187        ));
2188        let protos =
2189            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2190        let protos: Vec<_> = protos.iter().map(|x| *x as i32 * scale).collect();
2191        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2192        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
2193
2194        let decoder = DecoderBuilder::default()
2195            .with_config_yolo_split_segdet(
2196                configs::Boxes {
2197                    decoder: configs::DecoderType::Ultralytics,
2198                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2199                    shape: vec![1, 4, 8400],
2200                    dshape: vec![
2201                        (DimName::Batch, 1),
2202                        (DimName::BoxCoords, 4),
2203                        (DimName::NumBoxes, 8400),
2204                    ],
2205                    normalized: Some(true),
2206                },
2207                configs::Scores {
2208                    decoder: configs::DecoderType::Ultralytics,
2209                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2210                    shape: vec![1, 80, 8400],
2211                    dshape: vec![
2212                        (DimName::Batch, 1),
2213                        (DimName::NumClasses, 80),
2214                        (DimName::NumBoxes, 8400),
2215                    ],
2216                },
2217                configs::MaskCoefficients {
2218                    decoder: configs::DecoderType::Ultralytics,
2219                    quantization: Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point)),
2220                    shape: vec![1, 32, 8400],
2221                    dshape: vec![
2222                        (DimName::Batch, 1),
2223                        (DimName::NumProtos, 32),
2224                        (DimName::NumBoxes, 8400),
2225                    ],
2226                },
2227                configs::Protos {
2228                    decoder: configs::DecoderType::Ultralytics,
2229                    quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2230                    shape: vec![1, 160, 160, 32],
2231                    dshape: vec![
2232                        (DimName::Batch, 1),
2233                        (DimName::Height, 160),
2234                        (DimName::Width, 160),
2235                        (DimName::NumProtos, 32),
2236                    ],
2237                },
2238            )
2239            .with_score_threshold(score_threshold)
2240            .with_iou_threshold(iou_threshold)
2241            .build()
2242            .unwrap();
2243
2244        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2245        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2246
2247        decoder
2248            .decode_quantized(
2249                &[
2250                    boxes.slice(s![.., ..4, ..]).into(),
2251                    boxes.slice(s![.., 4..84, ..]).into(),
2252                    boxes.slice(s![.., 84.., ..]).into(),
2253                    protos.view().into(),
2254                ],
2255                &mut output_boxes,
2256                &mut output_masks,
2257            )
2258            .unwrap();
2259
2260        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2261        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2262        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2263        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2264        decode_yolo_segdet_float(
2265            seg.slice(s![0, .., ..]),
2266            protos.slice(s![0, .., .., ..]),
2267            score_threshold,
2268            iou_threshold,
2269            Some(configs::Nms::ClassAgnostic),
2270            &mut output_boxes_f32,
2271            &mut output_masks_f32,
2272        )
2273        .unwrap();
2274
2275        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2276        assert_eq!(output_masks.len(), output_masks_f32.len());
2277
2278        compare_outputs(
2279            (&output_boxes, &output_boxes_f32),
2280            (&output_masks, &output_masks_f32),
2281        );
2282    }
2283
2284    /// test running multiple decoders concurrently
2285    #[test]
2286    fn test_context_switch() {
2287        let yolo_det = || {
2288            let score_threshold = 0.25;
2289            let iou_threshold = 0.7;
2290            let out = include_bytes!(concat!(
2291                env!("CARGO_MANIFEST_DIR"),
2292                "/../../testdata/yolov8s_80_classes.bin"
2293            ));
2294            let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2295            let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2296            let quant = (0.0040811873, -123).into();
2297
2298            let decoder = DecoderBuilder::default()
2299                .with_config_yolo_det(
2300                    configs::Detection {
2301                        decoder: DecoderType::Ultralytics,
2302                        shape: vec![1, 84, 8400],
2303                        anchors: None,
2304                        quantization: Some(quant),
2305                        dshape: vec![
2306                            (DimName::Batch, 1),
2307                            (DimName::NumFeatures, 84),
2308                            (DimName::NumBoxes, 8400),
2309                        ],
2310                        normalized: None,
2311                    },
2312                    None,
2313                )
2314                .with_score_threshold(score_threshold)
2315                .with_iou_threshold(iou_threshold)
2316                .build()
2317                .unwrap();
2318
2319            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2320            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2321
2322            for _ in 0..100 {
2323                decoder
2324                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2325                    .unwrap();
2326
2327                assert!(output_boxes[0].equal_within_delta(
2328                    &DetectBox {
2329                        bbox: BoundingBox {
2330                            xmin: 0.5285137,
2331                            ymin: 0.05305544,
2332                            xmax: 0.87541467,
2333                            ymax: 0.9998909,
2334                        },
2335                        score: 0.5591227,
2336                        label: 0
2337                    },
2338                    1e-6
2339                ));
2340
2341                assert!(output_boxes[1].equal_within_delta(
2342                    &DetectBox {
2343                        bbox: BoundingBox {
2344                            xmin: 0.130598,
2345                            ymin: 0.43260583,
2346                            xmax: 0.35098213,
2347                            ymax: 0.9958097,
2348                        },
2349                        score: 0.33057618,
2350                        label: 75
2351                    },
2352                    1e-6
2353                ));
2354                assert!(output_masks.is_empty());
2355            }
2356        };
2357
2358        let modelpack_det_split = || {
2359            let score_threshold = 0.8;
2360            let iou_threshold = 0.5;
2361
2362            let seg = include_bytes!(concat!(
2363                env!("CARGO_MANIFEST_DIR"),
2364                "/../../testdata/modelpack_seg_2x160x160.bin"
2365            ));
2366            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2367
2368            let detect0 = include_bytes!(concat!(
2369                env!("CARGO_MANIFEST_DIR"),
2370                "/../../testdata/modelpack_split_9x15x18.bin"
2371            ));
2372            let detect0 =
2373                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2374
2375            let detect1 = include_bytes!(concat!(
2376                env!("CARGO_MANIFEST_DIR"),
2377                "/../../testdata/modelpack_split_17x30x18.bin"
2378            ));
2379            let detect1 =
2380                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2381
2382            let mut mask = seg.slice(s![0, .., .., ..]);
2383            mask.swap_axes(0, 1);
2384            mask.swap_axes(1, 2);
2385            let mask = [Segmentation {
2386                xmin: 0.0,
2387                ymin: 0.0,
2388                xmax: 1.0,
2389                ymax: 1.0,
2390                segmentation: mask.into_owned(),
2391            }];
2392            let correct_boxes = [DetectBox {
2393                bbox: BoundingBox {
2394                    xmin: 0.43171933,
2395                    ymin: 0.68243736,
2396                    xmax: 0.5626645,
2397                    ymax: 0.808863,
2398                },
2399                score: 0.99240804,
2400                label: 0,
2401            }];
2402
2403            let quant0 = (0.08547406643629074, 174).into();
2404            let quant1 = (0.09929127991199493, 183).into();
2405            let quant_seg = (1.0 / 255.0, 0).into();
2406
2407            let anchors0 = vec![
2408                [0.36666667461395264, 0.31481480598449707],
2409                [0.38749998807907104, 0.4740740656852722],
2410                [0.5333333611488342, 0.644444465637207],
2411            ];
2412            let anchors1 = vec![
2413                [0.13750000298023224, 0.2074074000120163],
2414                [0.2541666626930237, 0.21481481194496155],
2415                [0.23125000298023224, 0.35185185074806213],
2416            ];
2417
2418            let decoder = DecoderBuilder::default()
2419                .with_config_modelpack_segdet_split(
2420                    vec![
2421                        configs::Detection {
2422                            decoder: DecoderType::ModelPack,
2423                            shape: vec![1, 17, 30, 18],
2424                            anchors: Some(anchors1),
2425                            quantization: Some(quant1),
2426                            dshape: vec![
2427                                (DimName::Batch, 1),
2428                                (DimName::Height, 17),
2429                                (DimName::Width, 30),
2430                                (DimName::NumAnchorsXFeatures, 18),
2431                            ],
2432                            normalized: None,
2433                        },
2434                        configs::Detection {
2435                            decoder: DecoderType::ModelPack,
2436                            shape: vec![1, 9, 15, 18],
2437                            anchors: Some(anchors0),
2438                            quantization: Some(quant0),
2439                            dshape: vec![
2440                                (DimName::Batch, 1),
2441                                (DimName::Height, 9),
2442                                (DimName::Width, 15),
2443                                (DimName::NumAnchorsXFeatures, 18),
2444                            ],
2445                            normalized: None,
2446                        },
2447                    ],
2448                    configs::Segmentation {
2449                        decoder: DecoderType::ModelPack,
2450                        quantization: Some(quant_seg),
2451                        shape: vec![1, 2, 160, 160],
2452                        dshape: vec![
2453                            (DimName::Batch, 1),
2454                            (DimName::NumClasses, 2),
2455                            (DimName::Height, 160),
2456                            (DimName::Width, 160),
2457                        ],
2458                    },
2459                )
2460                .with_score_threshold(score_threshold)
2461                .with_iou_threshold(iou_threshold)
2462                .build()
2463                .unwrap();
2464            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2465            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2466
2467            for _ in 0..100 {
2468                decoder
2469                    .decode_quantized(
2470                        &[
2471                            detect0.view().into(),
2472                            detect1.view().into(),
2473                            seg.view().into(),
2474                        ],
2475                        &mut output_boxes,
2476                        &mut output_masks,
2477                    )
2478                    .unwrap();
2479
2480                compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
2481            }
2482        };
2483
2484        let handles = vec![
2485            std::thread::spawn(yolo_det),
2486            std::thread::spawn(modelpack_det_split),
2487            std::thread::spawn(yolo_det),
2488            std::thread::spawn(modelpack_det_split),
2489            std::thread::spawn(yolo_det),
2490            std::thread::spawn(modelpack_det_split),
2491            std::thread::spawn(yolo_det),
2492            std::thread::spawn(modelpack_det_split),
2493        ];
2494        for handle in handles {
2495            handle.join().unwrap();
2496        }
2497    }
2498
2499    #[test]
2500    fn test_ndarray_to_xyxy_float() {
2501        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2502        let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
2503        assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
2504
2505        let arr = array![10.0_f32, 20.0, 20.0, 20.0];
2506        let xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(arr.view());
2507        assert_eq!(xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2508    }
2509
2510    #[test]
2511    fn test_class_aware_nms_float() {
2512        use crate::float::nms_class_aware_float;
2513
2514        // Create two overlapping boxes with different classes
2515        let boxes = vec![
2516            DetectBox {
2517                bbox: BoundingBox {
2518                    xmin: 0.0,
2519                    ymin: 0.0,
2520                    xmax: 0.5,
2521                    ymax: 0.5,
2522                },
2523                score: 0.9,
2524                label: 0, // class 0
2525            },
2526            DetectBox {
2527                bbox: BoundingBox {
2528                    xmin: 0.1,
2529                    ymin: 0.1,
2530                    xmax: 0.6,
2531                    ymax: 0.6,
2532                },
2533                score: 0.8,
2534                label: 1, // class 1 - different class
2535            },
2536        ];
2537
2538        // Class-aware NMS should keep both boxes (different classes, IoU ~0.47 >
2539        // threshold 0.3)
2540        let result = nms_class_aware_float(0.3, boxes.clone());
2541        assert_eq!(
2542            result.len(),
2543            2,
2544            "Class-aware NMS should keep both boxes with different classes"
2545        );
2546
2547        // Now test with same class - should suppress one
2548        let same_class_boxes = vec![
2549            DetectBox {
2550                bbox: BoundingBox {
2551                    xmin: 0.0,
2552                    ymin: 0.0,
2553                    xmax: 0.5,
2554                    ymax: 0.5,
2555                },
2556                score: 0.9,
2557                label: 0,
2558            },
2559            DetectBox {
2560                bbox: BoundingBox {
2561                    xmin: 0.1,
2562                    ymin: 0.1,
2563                    xmax: 0.6,
2564                    ymax: 0.6,
2565                },
2566                score: 0.8,
2567                label: 0, // same class
2568            },
2569        ];
2570
2571        let result = nms_class_aware_float(0.3, same_class_boxes);
2572        assert_eq!(
2573            result.len(),
2574            1,
2575            "Class-aware NMS should suppress overlapping box with same class"
2576        );
2577        assert_eq!(result[0].label, 0);
2578        assert!((result[0].score - 0.9).abs() < 1e-6);
2579    }
2580
2581    #[test]
2582    fn test_class_agnostic_vs_aware_nms() {
2583        use crate::float::{nms_class_aware_float, nms_float};
2584
2585        // Two overlapping boxes with different classes
2586        let boxes = vec![
2587            DetectBox {
2588                bbox: BoundingBox {
2589                    xmin: 0.0,
2590                    ymin: 0.0,
2591                    xmax: 0.5,
2592                    ymax: 0.5,
2593                },
2594                score: 0.9,
2595                label: 0,
2596            },
2597            DetectBox {
2598                bbox: BoundingBox {
2599                    xmin: 0.1,
2600                    ymin: 0.1,
2601                    xmax: 0.6,
2602                    ymax: 0.6,
2603                },
2604                score: 0.8,
2605                label: 1,
2606            },
2607        ];
2608
2609        // Class-agnostic should suppress one (IoU ~0.47 > threshold 0.3)
2610        let agnostic_result = nms_float(0.3, boxes.clone());
2611        assert_eq!(
2612            agnostic_result.len(),
2613            1,
2614            "Class-agnostic NMS should suppress overlapping boxes"
2615        );
2616
2617        // Class-aware should keep both (different classes)
2618        let aware_result = nms_class_aware_float(0.3, boxes);
2619        assert_eq!(
2620            aware_result.len(),
2621            2,
2622            "Class-aware NMS should keep boxes with different classes"
2623        );
2624    }
2625
2626    #[test]
2627    fn test_class_aware_nms_int() {
2628        use crate::byte::nms_class_aware_int;
2629
2630        // Create two overlapping boxes with different classes
2631        let boxes = vec![
2632            DetectBoxQuantized {
2633                bbox: BoundingBox {
2634                    xmin: 0.0,
2635                    ymin: 0.0,
2636                    xmax: 0.5,
2637                    ymax: 0.5,
2638                },
2639                score: 200_u8,
2640                label: 0,
2641            },
2642            DetectBoxQuantized {
2643                bbox: BoundingBox {
2644                    xmin: 0.1,
2645                    ymin: 0.1,
2646                    xmax: 0.6,
2647                    ymax: 0.6,
2648                },
2649                score: 180_u8,
2650                label: 1, // different class
2651            },
2652        ];
2653
2654        // Should keep both (different classes)
2655        let result = nms_class_aware_int(0.5, boxes);
2656        assert_eq!(
2657            result.len(),
2658            2,
2659            "Class-aware NMS (int) should keep boxes with different classes"
2660        );
2661    }
2662
2663    #[test]
2664    fn test_nms_enum_default() {
2665        // Test that Nms enum has the correct default
2666        let default_nms: configs::Nms = Default::default();
2667        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2668    }
2669
2670    #[test]
2671    fn test_decoder_nms_mode() {
2672        // Test that decoder properly stores NMS mode
2673        let decoder = DecoderBuilder::default()
2674            .with_config_yolo_det(
2675                configs::Detection {
2676                    anchors: None,
2677                    decoder: DecoderType::Ultralytics,
2678                    quantization: None,
2679                    shape: vec![1, 84, 8400],
2680                    dshape: Vec::new(),
2681                    normalized: Some(true),
2682                },
2683                None,
2684            )
2685            .with_nms(Some(configs::Nms::ClassAware))
2686            .build()
2687            .unwrap();
2688
2689        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2690    }
2691
2692    #[test]
2693    fn test_decoder_nms_bypass() {
2694        // Test that decoder can be configured with nms=None (bypass)
2695        let decoder = DecoderBuilder::default()
2696            .with_config_yolo_det(
2697                configs::Detection {
2698                    anchors: None,
2699                    decoder: DecoderType::Ultralytics,
2700                    quantization: None,
2701                    shape: vec![1, 84, 8400],
2702                    dshape: Vec::new(),
2703                    normalized: Some(true),
2704                },
2705                None,
2706            )
2707            .with_nms(None)
2708            .build()
2709            .unwrap();
2710
2711        assert_eq!(decoder.nms, None);
2712    }
2713
2714    #[test]
2715    fn test_decoder_normalized_boxes_true() {
2716        // Test that normalized_boxes returns Some(true) when explicitly set
2717        let decoder = DecoderBuilder::default()
2718            .with_config_yolo_det(
2719                configs::Detection {
2720                    anchors: None,
2721                    decoder: DecoderType::Ultralytics,
2722                    quantization: None,
2723                    shape: vec![1, 84, 8400],
2724                    dshape: Vec::new(),
2725                    normalized: Some(true),
2726                },
2727                None,
2728            )
2729            .build()
2730            .unwrap();
2731
2732        assert_eq!(decoder.normalized_boxes(), Some(true));
2733    }
2734
2735    #[test]
2736    fn test_decoder_normalized_boxes_false() {
2737        // Test that normalized_boxes returns Some(false) when config specifies
2738        // unnormalized
2739        let decoder = DecoderBuilder::default()
2740            .with_config_yolo_det(
2741                configs::Detection {
2742                    anchors: None,
2743                    decoder: DecoderType::Ultralytics,
2744                    quantization: None,
2745                    shape: vec![1, 84, 8400],
2746                    dshape: Vec::new(),
2747                    normalized: Some(false),
2748                },
2749                None,
2750            )
2751            .build()
2752            .unwrap();
2753
2754        assert_eq!(decoder.normalized_boxes(), Some(false));
2755    }
2756
2757    #[test]
2758    fn test_decoder_normalized_boxes_unknown() {
2759        // Test that normalized_boxes returns None when not specified in config
2760        let decoder = DecoderBuilder::default()
2761            .with_config_yolo_det(
2762                configs::Detection {
2763                    anchors: None,
2764                    decoder: DecoderType::Ultralytics,
2765                    quantization: None,
2766                    shape: vec![1, 84, 8400],
2767                    dshape: Vec::new(),
2768                    normalized: None,
2769                },
2770                Some(DecoderVersion::Yolo11),
2771            )
2772            .build()
2773            .unwrap();
2774
2775        assert_eq!(decoder.normalized_boxes(), None);
2776    }
2777}
2778
2779#[cfg(feature = "tracker")]
2780#[cfg(test)]
2781#[cfg_attr(coverage_nightly, coverage(off))]
2782mod decoder_tracked_tests {
2783
2784    use edgefirst_tracker::{ByteTrackBuilder, Tracker};
2785    use ndarray::{array, s, Array, Array2, Array3, Array4, ArrayView, Axis, Dimension};
2786    use num_traits::{AsPrimitive, Float, PrimInt};
2787    use rand::{RngExt, SeedableRng};
2788    use rand_distr::StandardNormal;
2789
2790    use crate::{
2791        configs::{self, DimName},
2792        dequantize_ndarray, BoundingBox, DecoderBuilder, DetectBox, Quantization,
2793    };
2794
2795    pub fn quantize_ndarray<T: PrimInt + 'static, D: Dimension, F: Float + AsPrimitive<T>>(
2796        input: ArrayView<F, D>,
2797        quant: Quantization,
2798    ) -> Array<T, D>
2799    where
2800        i32: num_traits::AsPrimitive<F>,
2801        f32: num_traits::AsPrimitive<F>,
2802    {
2803        let zero_point = quant.zero_point.as_();
2804        let div_scale = F::one() / quant.scale.as_();
2805        if zero_point != F::zero() {
2806            input.mapv(|d| (d * div_scale + zero_point).round().as_())
2807        } else {
2808            input.mapv(|d| (d * div_scale).round().as_())
2809        }
2810    }
2811
2812    #[test]
2813    fn test_decoder_tracked_random_jitter() {
2814        use crate::configs::{DecoderType, Nms};
2815        use crate::DecoderBuilder;
2816
2817        let score_threshold = 0.25;
2818        let iou_threshold = 0.1;
2819        let out = include_bytes!(concat!(
2820            env!("CARGO_MANIFEST_DIR"),
2821            "/../../testdata/yolov8s_80_classes.bin"
2822        ));
2823        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2824        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
2825        let quant = (0.0040811873, -123).into();
2826
2827        let decoder = DecoderBuilder::default()
2828            .with_config_yolo_det(
2829                crate::configs::Detection {
2830                    decoder: DecoderType::Ultralytics,
2831                    shape: vec![1, 84, 8400],
2832                    anchors: None,
2833                    quantization: Some(quant),
2834                    dshape: vec![
2835                        (crate::configs::DimName::Batch, 1),
2836                        (crate::configs::DimName::NumFeatures, 84),
2837                        (crate::configs::DimName::NumBoxes, 8400),
2838                    ],
2839                    normalized: Some(true),
2840                },
2841                None,
2842            )
2843            .with_score_threshold(score_threshold)
2844            .with_iou_threshold(iou_threshold)
2845            .with_nms(Some(Nms::ClassAgnostic))
2846            .build()
2847            .unwrap();
2848        let mut rng = rand::rngs::StdRng::seed_from_u64(0xAB_BEEF); // fixed seed for reproducibility
2849
2850        let expected_boxes = [
2851            crate::DetectBox {
2852                bbox: crate::BoundingBox {
2853                    xmin: 0.5285137,
2854                    ymin: 0.05305544,
2855                    xmax: 0.87541467,
2856                    ymax: 0.9998909,
2857                },
2858                score: 0.5591227,
2859                label: 0,
2860            },
2861            crate::DetectBox {
2862                bbox: crate::BoundingBox {
2863                    xmin: 0.130598,
2864                    ymin: 0.43260583,
2865                    xmax: 0.35098213,
2866                    ymax: 0.9958097,
2867                },
2868                score: 0.33057618,
2869                label: 75,
2870            },
2871        ];
2872
2873        let mut tracker = ByteTrackBuilder::new()
2874            .track_update(0.1)
2875            .track_high_conf(0.3)
2876            .build();
2877
2878        let mut output_boxes = Vec::with_capacity(50);
2879        let mut output_masks = Vec::with_capacity(50);
2880        let mut output_tracks = Vec::with_capacity(50);
2881
2882        decoder
2883            .decode_tracked_quantized(
2884                &mut tracker,
2885                0,
2886                &[out.view().into()],
2887                &mut output_boxes,
2888                &mut output_masks,
2889                &mut output_tracks,
2890            )
2891            .unwrap();
2892
2893        assert_eq!(output_boxes.len(), 2);
2894        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
2895        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
2896
2897        let mut last_boxes = output_boxes.clone();
2898
2899        for i in 1..=100 {
2900            let mut out = out.clone();
2901            // introduce jitter into the XY coordinates to simulate movement and test tracking stability
2902            let mut x_values = out.slice_mut(s![0, 0, ..]);
2903            for x in x_values.iter_mut() {
2904                let r: f32 = rng.sample(StandardNormal);
2905                let r = r.clamp(-2.0, 2.0) / 2.0;
2906                *x = x.saturating_add((r * 1e-2 / quant.0) as i8);
2907            }
2908
2909            let mut y_values = out.slice_mut(s![0, 1, ..]);
2910            for y in y_values.iter_mut() {
2911                let r: f32 = rng.sample(StandardNormal);
2912                let r = r.clamp(-2.0, 2.0) / 2.0;
2913                *y = y.saturating_add((r * 1e-2 / quant.0) as i8);
2914            }
2915
2916            decoder
2917                .decode_tracked_quantized(
2918                    &mut tracker,
2919                    100_000_000 * i / 3, // simulate 33.333ms between frames
2920                    &[out.view().into()],
2921                    &mut output_boxes,
2922                    &mut output_masks,
2923                    &mut output_tracks,
2924                )
2925                .unwrap();
2926
2927            assert_eq!(output_boxes.len(), 2);
2928            assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 5e-3));
2929            assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 5e-3));
2930
2931            assert!(output_boxes[0].equal_within_delta(&last_boxes[0], 1e-3));
2932            assert!(output_boxes[1].equal_within_delta(&last_boxes[1], 1e-3));
2933            last_boxes = output_boxes.clone();
2934        }
2935    }
2936
2937    #[test]
2938    fn test_decoder_tracked_segdet() {
2939        use crate::configs::Nms;
2940        use crate::DecoderBuilder;
2941
2942        let score_threshold = 0.45;
2943        let iou_threshold = 0.45;
2944        let boxes = include_bytes!(concat!(
2945            env!("CARGO_MANIFEST_DIR"),
2946            "/../../testdata/yolov8_boxes_116x8400.bin"
2947        ));
2948        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
2949        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
2950
2951        let protos = include_bytes!(concat!(
2952            env!("CARGO_MANIFEST_DIR"),
2953            "/../../testdata/yolov8_protos_160x160x32.bin"
2954        ));
2955        let protos =
2956            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
2957        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
2958
2959        let config = include_str!(concat!(
2960            env!("CARGO_MANIFEST_DIR"),
2961            "/../../testdata/yolov8_seg.yaml"
2962        ));
2963
2964        let decoder = DecoderBuilder::default()
2965            .with_config_yaml_str(config.to_string())
2966            .with_score_threshold(score_threshold)
2967            .with_iou_threshold(iou_threshold)
2968            .with_nms(Some(Nms::ClassAgnostic))
2969            .build()
2970            .unwrap();
2971
2972        let expected_boxes = [
2973            DetectBox {
2974                bbox: BoundingBox {
2975                    xmin: 0.08515105,
2976                    ymin: 0.7131401,
2977                    xmax: 0.29802868,
2978                    ymax: 0.8195788,
2979                },
2980                score: 0.91537374,
2981                label: 23,
2982            },
2983            DetectBox {
2984                bbox: BoundingBox {
2985                    xmin: 0.59605736,
2986                    ymin: 0.25545314,
2987                    xmax: 0.93666154,
2988                    ymax: 0.72378385,
2989                },
2990                score: 0.91537374,
2991                label: 23,
2992            },
2993        ];
2994
2995        let mut tracker = ByteTrackBuilder::new()
2996            .track_update(0.1)
2997            .track_high_conf(0.7)
2998            .build();
2999
3000        let mut output_boxes = Vec::with_capacity(50);
3001        let mut output_masks = Vec::with_capacity(50);
3002        let mut output_tracks = Vec::with_capacity(50);
3003
3004        decoder
3005            .decode_tracked_quantized(
3006                &mut tracker,
3007                0,
3008                &[boxes.view().into(), protos.view().into()],
3009                &mut output_boxes,
3010                &mut output_masks,
3011                &mut output_tracks,
3012            )
3013            .unwrap();
3014
3015        assert_eq!(output_boxes.len(), 2);
3016        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3017        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3018
3019        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3020        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3021        for score in scores_values.iter_mut() {
3022            *score = i8::MIN; // set all scores to minimum to simulate no detections
3023        }
3024        decoder
3025            .decode_tracked_quantized(
3026                &mut tracker,
3027                100_000_000 / 3,
3028                &[boxes.view().into(), protos.view().into()],
3029                &mut output_boxes,
3030                &mut output_masks,
3031                &mut output_tracks,
3032            )
3033            .unwrap();
3034
3035        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3036        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3037
3038        // no masks when the boxes are from tracker prediction without a matching detection
3039        assert!(output_masks.is_empty())
3040    }
3041
3042    #[test]
3043    fn test_decoder_tracked_segdet_float() {
3044        use crate::configs::Nms;
3045        use crate::DecoderBuilder;
3046
3047        let score_threshold = 0.45;
3048        let iou_threshold = 0.45;
3049        let boxes = include_bytes!(concat!(
3050            env!("CARGO_MANIFEST_DIR"),
3051            "/../../testdata/yolov8_boxes_116x8400.bin"
3052        ));
3053        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3054        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3055        let quant_boxes = (0.021287762, 31);
3056        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3057
3058        let protos = include_bytes!(concat!(
3059            env!("CARGO_MANIFEST_DIR"),
3060            "/../../testdata/yolov8_protos_160x160x32.bin"
3061        ));
3062        let protos =
3063            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3064        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3065        let quant_protos = (0.02491162, -117);
3066        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3067
3068        let config = include_str!(concat!(
3069            env!("CARGO_MANIFEST_DIR"),
3070            "/../../testdata/yolov8_seg.yaml"
3071        ));
3072
3073        let decoder = DecoderBuilder::default()
3074            .with_config_yaml_str(config.to_string())
3075            .with_score_threshold(score_threshold)
3076            .with_iou_threshold(iou_threshold)
3077            .with_nms(Some(Nms::ClassAgnostic))
3078            .build()
3079            .unwrap();
3080
3081        let expected_boxes = [
3082            DetectBox {
3083                bbox: BoundingBox {
3084                    xmin: 0.08515105,
3085                    ymin: 0.7131401,
3086                    xmax: 0.29802868,
3087                    ymax: 0.8195788,
3088                },
3089                score: 0.91537374,
3090                label: 23,
3091            },
3092            DetectBox {
3093                bbox: BoundingBox {
3094                    xmin: 0.59605736,
3095                    ymin: 0.25545314,
3096                    xmax: 0.93666154,
3097                    ymax: 0.72378385,
3098                },
3099                score: 0.91537374,
3100                label: 23,
3101            },
3102        ];
3103
3104        let mut tracker = ByteTrackBuilder::new()
3105            .track_update(0.1)
3106            .track_high_conf(0.7)
3107            .build();
3108
3109        let mut output_boxes = Vec::with_capacity(50);
3110        let mut output_masks = Vec::with_capacity(50);
3111        let mut output_tracks = Vec::with_capacity(50);
3112
3113        decoder
3114            .decode_tracked_float(
3115                &mut tracker,
3116                0,
3117                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3118                &mut output_boxes,
3119                &mut output_masks,
3120                &mut output_tracks,
3121            )
3122            .unwrap();
3123
3124        assert_eq!(output_boxes.len(), 2);
3125        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3126        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3127
3128        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3129        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3130        for score in scores_values.iter_mut() {
3131            *score = 0.0; // set all scores to minimum to simulate no detections
3132        }
3133        decoder
3134            .decode_tracked_float(
3135                &mut tracker,
3136                100_000_000 / 3,
3137                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3138                &mut output_boxes,
3139                &mut output_masks,
3140                &mut output_tracks,
3141            )
3142            .unwrap();
3143
3144        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3145        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3146
3147        // no masks when the boxes are from tracker prediction without a matching detection
3148        assert!(output_masks.is_empty())
3149    }
3150
3151    #[test]
3152    fn test_decoder_tracked_segdet_proto() {
3153        use crate::configs::Nms;
3154        use crate::DecoderBuilder;
3155
3156        let score_threshold = 0.45;
3157        let iou_threshold = 0.45;
3158        let boxes = include_bytes!(concat!(
3159            env!("CARGO_MANIFEST_DIR"),
3160            "/../../testdata/yolov8_boxes_116x8400.bin"
3161        ));
3162        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3163        let mut boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3164
3165        let protos = include_bytes!(concat!(
3166            env!("CARGO_MANIFEST_DIR"),
3167            "/../../testdata/yolov8_protos_160x160x32.bin"
3168        ));
3169        let protos =
3170            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3171        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3172
3173        let config = include_str!(concat!(
3174            env!("CARGO_MANIFEST_DIR"),
3175            "/../../testdata/yolov8_seg.yaml"
3176        ));
3177
3178        let decoder = DecoderBuilder::default()
3179            .with_config_yaml_str(config.to_string())
3180            .with_score_threshold(score_threshold)
3181            .with_iou_threshold(iou_threshold)
3182            .with_nms(Some(Nms::ClassAgnostic))
3183            .build()
3184            .unwrap();
3185
3186        let expected_boxes = [
3187            DetectBox {
3188                bbox: BoundingBox {
3189                    xmin: 0.08515105,
3190                    ymin: 0.7131401,
3191                    xmax: 0.29802868,
3192                    ymax: 0.8195788,
3193                },
3194                score: 0.91537374,
3195                label: 23,
3196            },
3197            DetectBox {
3198                bbox: BoundingBox {
3199                    xmin: 0.59605736,
3200                    ymin: 0.25545314,
3201                    xmax: 0.93666154,
3202                    ymax: 0.72378385,
3203                },
3204                score: 0.91537374,
3205                label: 23,
3206            },
3207        ];
3208
3209        let mut tracker = ByteTrackBuilder::new()
3210            .track_update(0.1)
3211            .track_high_conf(0.7)
3212            .build();
3213
3214        let mut output_boxes = Vec::with_capacity(50);
3215        let mut output_tracks = Vec::with_capacity(50);
3216
3217        decoder
3218            .decode_tracked_quantized_proto(
3219                &mut tracker,
3220                0,
3221                &[boxes.view().into(), protos.view().into()],
3222                &mut output_boxes,
3223                &mut output_tracks,
3224            )
3225            .unwrap();
3226
3227        assert_eq!(output_boxes.len(), 2);
3228        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3229        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3230
3231        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3232        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3233        for score in scores_values.iter_mut() {
3234            *score = i8::MIN; // set all scores to minimum to simulate no detections
3235        }
3236        let protos = decoder
3237            .decode_tracked_quantized_proto(
3238                &mut tracker,
3239                100_000_000 / 3,
3240                &[boxes.view().into(), protos.view().into()],
3241                &mut output_boxes,
3242                &mut output_tracks,
3243            )
3244            .unwrap();
3245
3246        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3247        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3248
3249        // no masks when the boxes are from tracker prediction without a matching detection
3250        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3251    }
3252
3253    #[test]
3254    fn test_decoder_tracked_segdet_proto_float() {
3255        use crate::configs::Nms;
3256        use crate::DecoderBuilder;
3257
3258        let score_threshold = 0.45;
3259        let iou_threshold = 0.45;
3260        let boxes = include_bytes!(concat!(
3261            env!("CARGO_MANIFEST_DIR"),
3262            "/../../testdata/yolov8_boxes_116x8400.bin"
3263        ));
3264        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3265        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes.to_vec()).unwrap();
3266        let quant_boxes = (0.021287762, 31);
3267        let mut boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3268
3269        let protos = include_bytes!(concat!(
3270            env!("CARGO_MANIFEST_DIR"),
3271            "/../../testdata/yolov8_protos_160x160x32.bin"
3272        ));
3273        let protos =
3274            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3275        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos.to_vec()).unwrap();
3276        let quant_protos = (0.02491162, -117);
3277        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3278
3279        let config = include_str!(concat!(
3280            env!("CARGO_MANIFEST_DIR"),
3281            "/../../testdata/yolov8_seg.yaml"
3282        ));
3283
3284        let decoder = DecoderBuilder::default()
3285            .with_config_yaml_str(config.to_string())
3286            .with_score_threshold(score_threshold)
3287            .with_iou_threshold(iou_threshold)
3288            .with_nms(Some(Nms::ClassAgnostic))
3289            .build()
3290            .unwrap();
3291
3292        let expected_boxes = [
3293            DetectBox {
3294                bbox: BoundingBox {
3295                    xmin: 0.08515105,
3296                    ymin: 0.7131401,
3297                    xmax: 0.29802868,
3298                    ymax: 0.8195788,
3299                },
3300                score: 0.91537374,
3301                label: 23,
3302            },
3303            DetectBox {
3304                bbox: BoundingBox {
3305                    xmin: 0.59605736,
3306                    ymin: 0.25545314,
3307                    xmax: 0.93666154,
3308                    ymax: 0.72378385,
3309                },
3310                score: 0.91537374,
3311                label: 23,
3312            },
3313        ];
3314
3315        let mut tracker = ByteTrackBuilder::new()
3316            .track_update(0.1)
3317            .track_high_conf(0.7)
3318            .build();
3319
3320        let mut output_boxes = Vec::with_capacity(50);
3321        let mut output_tracks = Vec::with_capacity(50);
3322
3323        decoder
3324            .decode_tracked_float_proto(
3325                &mut tracker,
3326                0,
3327                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3328                &mut output_boxes,
3329                &mut output_tracks,
3330            )
3331            .unwrap();
3332
3333        assert_eq!(output_boxes.len(), 2);
3334        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3335        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3336
3337        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3338        let mut scores_values = boxes.slice_mut(s![0, 4..84, ..]);
3339        for score in scores_values.iter_mut() {
3340            *score = 0.0; // set all scores to minimum to simulate no detections
3341        }
3342        let protos = decoder
3343            .decode_tracked_float_proto(
3344                &mut tracker,
3345                100_000_000 / 3,
3346                &[boxes.view().into_dyn(), protos.view().into_dyn()],
3347                &mut output_boxes,
3348                &mut output_tracks,
3349            )
3350            .unwrap();
3351
3352        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3353        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3354
3355        // no masks when the boxes are from tracker prediction without a matching detection
3356        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3357    }
3358
3359    #[test]
3360    fn test_decoder_tracked_segdet_split() {
3361        let score_threshold = 0.45;
3362        let iou_threshold = 0.45;
3363
3364        let boxes = include_bytes!(concat!(
3365            env!("CARGO_MANIFEST_DIR"),
3366            "/../../testdata/yolov8_boxes_116x8400.bin"
3367        ));
3368        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3369        let boxes = boxes.to_vec();
3370        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3371
3372        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3373        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3374        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3375
3376        let quant_boxes = (0.021287762, 31);
3377
3378        let protos = include_bytes!(concat!(
3379            env!("CARGO_MANIFEST_DIR"),
3380            "/../../testdata/yolov8_protos_160x160x32.bin"
3381        ));
3382        let protos =
3383            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3384        let protos = protos.to_vec();
3385        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3386        let quant_protos = (0.02491162, -117);
3387        let decoder = DecoderBuilder::default()
3388            .with_config_yolo_split_segdet(
3389                configs::Boxes {
3390                    decoder: configs::DecoderType::Ultralytics,
3391                    quantization: Some(quant_boxes.into()),
3392                    shape: vec![1, 4, 8400],
3393                    dshape: vec![
3394                        (DimName::Batch, 1),
3395                        (DimName::BoxCoords, 4),
3396                        (DimName::NumBoxes, 8400),
3397                    ],
3398                    normalized: Some(true),
3399                },
3400                configs::Scores {
3401                    decoder: configs::DecoderType::Ultralytics,
3402                    quantization: Some(quant_boxes.into()),
3403                    shape: vec![1, 80, 8400],
3404                    dshape: vec![
3405                        (DimName::Batch, 1),
3406                        (DimName::NumClasses, 80),
3407                        (DimName::NumBoxes, 8400),
3408                    ],
3409                },
3410                configs::MaskCoefficients {
3411                    decoder: configs::DecoderType::Ultralytics,
3412                    quantization: Some(quant_boxes.into()),
3413                    shape: vec![1, 32, 8400],
3414                    dshape: vec![
3415                        (DimName::Batch, 1),
3416                        (DimName::NumProtos, 32),
3417                        (DimName::NumBoxes, 8400),
3418                    ],
3419                },
3420                configs::Protos {
3421                    decoder: configs::DecoderType::Ultralytics,
3422                    quantization: Some(quant_protos.into()),
3423                    shape: vec![1, 160, 160, 32],
3424                    dshape: vec![
3425                        (DimName::Batch, 1),
3426                        (DimName::Height, 160),
3427                        (DimName::Width, 160),
3428                        (DimName::NumProtos, 32),
3429                    ],
3430                },
3431            )
3432            .with_score_threshold(score_threshold)
3433            .with_iou_threshold(iou_threshold)
3434            .build()
3435            .unwrap();
3436
3437        let expected_boxes = [
3438            DetectBox {
3439                bbox: BoundingBox {
3440                    xmin: 0.08515105,
3441                    ymin: 0.7131401,
3442                    xmax: 0.29802868,
3443                    ymax: 0.8195788,
3444                },
3445                score: 0.91537374,
3446                label: 23,
3447            },
3448            DetectBox {
3449                bbox: BoundingBox {
3450                    xmin: 0.59605736,
3451                    ymin: 0.25545314,
3452                    xmax: 0.93666154,
3453                    ymax: 0.72378385,
3454                },
3455                score: 0.91537374,
3456                label: 23,
3457            },
3458        ];
3459
3460        let mut tracker = ByteTrackBuilder::new()
3461            .track_update(0.1)
3462            .track_high_conf(0.7)
3463            .build();
3464
3465        let mut output_boxes = Vec::with_capacity(50);
3466        let mut output_masks = Vec::with_capacity(50);
3467        let mut output_tracks = Vec::with_capacity(50);
3468
3469        decoder
3470            .decode_tracked_quantized(
3471                &mut tracker,
3472                0,
3473                &[
3474                    boxes.view().into(),
3475                    scores.view().into(),
3476                    mask.view().into(),
3477                    protos.view().into(),
3478                ],
3479                &mut output_boxes,
3480                &mut output_masks,
3481                &mut output_tracks,
3482            )
3483            .unwrap();
3484
3485        assert_eq!(output_boxes.len(), 2);
3486        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3487        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3488
3489        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3490
3491        for score in scores.iter_mut() {
3492            *score = i8::MIN; // set all scores to minimum to simulate no detections
3493        }
3494        decoder
3495            .decode_tracked_quantized(
3496                &mut tracker,
3497                100_000_000 / 3,
3498                &[
3499                    boxes.view().into(),
3500                    scores.view().into(),
3501                    mask.view().into(),
3502                    protos.view().into(),
3503                ],
3504                &mut output_boxes,
3505                &mut output_masks,
3506                &mut output_tracks,
3507            )
3508            .unwrap();
3509
3510        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3511        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3512
3513        // no masks when the boxes are from tracker prediction without a matching detection
3514        assert!(output_masks.is_empty())
3515    }
3516
3517    #[test]
3518    fn test_decoder_tracked_segdet_split_float() {
3519        let score_threshold = 0.45;
3520        let iou_threshold = 0.45;
3521
3522        let boxes = include_bytes!(concat!(
3523            env!("CARGO_MANIFEST_DIR"),
3524            "/../../testdata/yolov8_boxes_116x8400.bin"
3525        ));
3526        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3527        let boxes = boxes.to_vec();
3528        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3529        let quant_boxes = (0.021287762, 31);
3530        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3531
3532        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3533        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3534        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3535
3536        let protos = include_bytes!(concat!(
3537            env!("CARGO_MANIFEST_DIR"),
3538            "/../../testdata/yolov8_protos_160x160x32.bin"
3539        ));
3540        let protos =
3541            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3542        let protos = protos.to_vec();
3543        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3544        let quant_protos = (0.02491162, -117);
3545        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3546
3547        let decoder = DecoderBuilder::default()
3548            .with_config_yolo_split_segdet(
3549                configs::Boxes {
3550                    decoder: configs::DecoderType::Ultralytics,
3551                    quantization: Some(quant_boxes.into()),
3552                    shape: vec![1, 4, 8400],
3553                    dshape: vec![
3554                        (DimName::Batch, 1),
3555                        (DimName::BoxCoords, 4),
3556                        (DimName::NumBoxes, 8400),
3557                    ],
3558                    normalized: Some(true),
3559                },
3560                configs::Scores {
3561                    decoder: configs::DecoderType::Ultralytics,
3562                    quantization: Some(quant_boxes.into()),
3563                    shape: vec![1, 80, 8400],
3564                    dshape: vec![
3565                        (DimName::Batch, 1),
3566                        (DimName::NumClasses, 80),
3567                        (DimName::NumBoxes, 8400),
3568                    ],
3569                },
3570                configs::MaskCoefficients {
3571                    decoder: configs::DecoderType::Ultralytics,
3572                    quantization: Some(quant_boxes.into()),
3573                    shape: vec![1, 32, 8400],
3574                    dshape: vec![
3575                        (DimName::Batch, 1),
3576                        (DimName::NumProtos, 32),
3577                        (DimName::NumBoxes, 8400),
3578                    ],
3579                },
3580                configs::Protos {
3581                    decoder: configs::DecoderType::Ultralytics,
3582                    quantization: Some(quant_protos.into()),
3583                    shape: vec![1, 160, 160, 32],
3584                    dshape: vec![
3585                        (DimName::Batch, 1),
3586                        (DimName::Height, 160),
3587                        (DimName::Width, 160),
3588                        (DimName::NumProtos, 32),
3589                    ],
3590                },
3591            )
3592            .with_score_threshold(score_threshold)
3593            .with_iou_threshold(iou_threshold)
3594            .build()
3595            .unwrap();
3596
3597        let expected_boxes = [
3598            DetectBox {
3599                bbox: BoundingBox {
3600                    xmin: 0.08515105,
3601                    ymin: 0.7131401,
3602                    xmax: 0.29802868,
3603                    ymax: 0.8195788,
3604                },
3605                score: 0.91537374,
3606                label: 23,
3607            },
3608            DetectBox {
3609                bbox: BoundingBox {
3610                    xmin: 0.59605736,
3611                    ymin: 0.25545314,
3612                    xmax: 0.93666154,
3613                    ymax: 0.72378385,
3614                },
3615                score: 0.91537374,
3616                label: 23,
3617            },
3618        ];
3619
3620        let mut tracker = ByteTrackBuilder::new()
3621            .track_update(0.1)
3622            .track_high_conf(0.7)
3623            .build();
3624
3625        let mut output_boxes = Vec::with_capacity(50);
3626        let mut output_masks = Vec::with_capacity(50);
3627        let mut output_tracks = Vec::with_capacity(50);
3628
3629        decoder
3630            .decode_tracked_float(
3631                &mut tracker,
3632                0,
3633                &[
3634                    boxes.view().into_dyn(),
3635                    scores.view().into_dyn(),
3636                    mask.view().into_dyn(),
3637                    protos.view().into_dyn(),
3638                ],
3639                &mut output_boxes,
3640                &mut output_masks,
3641                &mut output_tracks,
3642            )
3643            .unwrap();
3644
3645        assert_eq!(output_boxes.len(), 2);
3646        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3647        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3648
3649        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3650
3651        for score in scores.iter_mut() {
3652            *score = 0.0; // set all scores to minimum to simulate no detections
3653        }
3654        decoder
3655            .decode_tracked_float(
3656                &mut tracker,
3657                100_000_000 / 3,
3658                &[
3659                    boxes.view().into_dyn(),
3660                    scores.view().into_dyn(),
3661                    mask.view().into_dyn(),
3662                    protos.view().into_dyn(),
3663                ],
3664                &mut output_boxes,
3665                &mut output_masks,
3666                &mut output_tracks,
3667            )
3668            .unwrap();
3669
3670        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3671        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3672
3673        // no masks when the boxes are from tracker prediction without a matching detection
3674        assert!(output_masks.is_empty())
3675    }
3676
3677    #[test]
3678    fn test_decoder_tracked_segdet_split_proto() {
3679        let score_threshold = 0.45;
3680        let iou_threshold = 0.45;
3681
3682        let boxes = include_bytes!(concat!(
3683            env!("CARGO_MANIFEST_DIR"),
3684            "/../../testdata/yolov8_boxes_116x8400.bin"
3685        ));
3686        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3687        let boxes = boxes.to_vec();
3688        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3689
3690        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3691        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3692        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3693
3694        let quant_boxes = (0.021287762, 31);
3695
3696        let protos = include_bytes!(concat!(
3697            env!("CARGO_MANIFEST_DIR"),
3698            "/../../testdata/yolov8_protos_160x160x32.bin"
3699        ));
3700        let protos =
3701            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3702        let protos = protos.to_vec();
3703        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3704        let quant_protos = (0.02491162, -117);
3705        let decoder = DecoderBuilder::default()
3706            .with_config_yolo_split_segdet(
3707                configs::Boxes {
3708                    decoder: configs::DecoderType::Ultralytics,
3709                    quantization: Some(quant_boxes.into()),
3710                    shape: vec![1, 4, 8400],
3711                    dshape: vec![
3712                        (DimName::Batch, 1),
3713                        (DimName::BoxCoords, 4),
3714                        (DimName::NumBoxes, 8400),
3715                    ],
3716                    normalized: Some(true),
3717                },
3718                configs::Scores {
3719                    decoder: configs::DecoderType::Ultralytics,
3720                    quantization: Some(quant_boxes.into()),
3721                    shape: vec![1, 80, 8400],
3722                    dshape: vec![
3723                        (DimName::Batch, 1),
3724                        (DimName::NumClasses, 80),
3725                        (DimName::NumBoxes, 8400),
3726                    ],
3727                },
3728                configs::MaskCoefficients {
3729                    decoder: configs::DecoderType::Ultralytics,
3730                    quantization: Some(quant_boxes.into()),
3731                    shape: vec![1, 32, 8400],
3732                    dshape: vec![
3733                        (DimName::Batch, 1),
3734                        (DimName::NumProtos, 32),
3735                        (DimName::NumBoxes, 8400),
3736                    ],
3737                },
3738                configs::Protos {
3739                    decoder: configs::DecoderType::Ultralytics,
3740                    quantization: Some(quant_protos.into()),
3741                    shape: vec![1, 160, 160, 32],
3742                    dshape: vec![
3743                        (DimName::Batch, 1),
3744                        (DimName::Height, 160),
3745                        (DimName::Width, 160),
3746                        (DimName::NumProtos, 32),
3747                    ],
3748                },
3749            )
3750            .with_score_threshold(score_threshold)
3751            .with_iou_threshold(iou_threshold)
3752            .build()
3753            .unwrap();
3754
3755        let expected_boxes = [
3756            DetectBox {
3757                bbox: BoundingBox {
3758                    xmin: 0.08515105,
3759                    ymin: 0.7131401,
3760                    xmax: 0.29802868,
3761                    ymax: 0.8195788,
3762                },
3763                score: 0.91537374,
3764                label: 23,
3765            },
3766            DetectBox {
3767                bbox: BoundingBox {
3768                    xmin: 0.59605736,
3769                    ymin: 0.25545314,
3770                    xmax: 0.93666154,
3771                    ymax: 0.72378385,
3772                },
3773                score: 0.91537374,
3774                label: 23,
3775            },
3776        ];
3777
3778        let mut tracker = ByteTrackBuilder::new()
3779            .track_update(0.1)
3780            .track_high_conf(0.7)
3781            .build();
3782
3783        let mut output_boxes = Vec::with_capacity(50);
3784        let mut output_tracks = Vec::with_capacity(50);
3785
3786        decoder
3787            .decode_tracked_quantized_proto(
3788                &mut tracker,
3789                0,
3790                &[
3791                    boxes.view().into(),
3792                    scores.view().into(),
3793                    mask.view().into(),
3794                    protos.view().into(),
3795                ],
3796                &mut output_boxes,
3797                &mut output_tracks,
3798            )
3799            .unwrap();
3800
3801        assert_eq!(output_boxes.len(), 2);
3802        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3803        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3804
3805        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3806
3807        for score in scores.iter_mut() {
3808            *score = i8::MIN; // set all scores to minimum to simulate no detections
3809        }
3810        let protos = decoder
3811            .decode_tracked_quantized_proto(
3812                &mut tracker,
3813                100_000_000 / 3,
3814                &[
3815                    boxes.view().into(),
3816                    scores.view().into(),
3817                    mask.view().into(),
3818                    protos.view().into(),
3819                ],
3820                &mut output_boxes,
3821                &mut output_tracks,
3822            )
3823            .unwrap();
3824
3825        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3826        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3827
3828        // no masks when the boxes are from tracker prediction without a matching detection
3829        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3830    }
3831
3832    #[test]
3833    fn test_decoder_tracked_segdet_split_proto_float() {
3834        let score_threshold = 0.45;
3835        let iou_threshold = 0.45;
3836
3837        let boxes = include_bytes!(concat!(
3838            env!("CARGO_MANIFEST_DIR"),
3839            "/../../testdata/yolov8_boxes_116x8400.bin"
3840        ));
3841        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
3842        let boxes = boxes.to_vec();
3843        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
3844        let quant_boxes = (0.021287762, 31);
3845        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
3846
3847        let mask = boxes.slice(s![.., 84.., ..]).to_owned();
3848        let mut scores = boxes.slice(s![.., 4..84, ..]).to_owned();
3849        let boxes = boxes.slice(s![.., ..4, ..]).to_owned();
3850
3851        let protos = include_bytes!(concat!(
3852            env!("CARGO_MANIFEST_DIR"),
3853            "/../../testdata/yolov8_protos_160x160x32.bin"
3854        ));
3855        let protos =
3856            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
3857        let protos = protos.to_vec();
3858        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
3859        let quant_protos = (0.02491162, -117);
3860        let protos = dequantize_ndarray(protos.view(), quant_protos.into());
3861
3862        let decoder = DecoderBuilder::default()
3863            .with_config_yolo_split_segdet(
3864                configs::Boxes {
3865                    decoder: configs::DecoderType::Ultralytics,
3866                    quantization: Some(quant_boxes.into()),
3867                    shape: vec![1, 4, 8400],
3868                    dshape: vec![
3869                        (DimName::Batch, 1),
3870                        (DimName::BoxCoords, 4),
3871                        (DimName::NumBoxes, 8400),
3872                    ],
3873                    normalized: Some(true),
3874                },
3875                configs::Scores {
3876                    decoder: configs::DecoderType::Ultralytics,
3877                    quantization: Some(quant_boxes.into()),
3878                    shape: vec![1, 80, 8400],
3879                    dshape: vec![
3880                        (DimName::Batch, 1),
3881                        (DimName::NumClasses, 80),
3882                        (DimName::NumBoxes, 8400),
3883                    ],
3884                },
3885                configs::MaskCoefficients {
3886                    decoder: configs::DecoderType::Ultralytics,
3887                    quantization: Some(quant_boxes.into()),
3888                    shape: vec![1, 32, 8400],
3889                    dshape: vec![
3890                        (DimName::Batch, 1),
3891                        (DimName::NumProtos, 32),
3892                        (DimName::NumBoxes, 8400),
3893                    ],
3894                },
3895                configs::Protos {
3896                    decoder: configs::DecoderType::Ultralytics,
3897                    quantization: Some(quant_protos.into()),
3898                    shape: vec![1, 160, 160, 32],
3899                    dshape: vec![
3900                        (DimName::Batch, 1),
3901                        (DimName::Height, 160),
3902                        (DimName::Width, 160),
3903                        (DimName::NumProtos, 32),
3904                    ],
3905                },
3906            )
3907            .with_score_threshold(score_threshold)
3908            .with_iou_threshold(iou_threshold)
3909            .build()
3910            .unwrap();
3911
3912        let expected_boxes = [
3913            DetectBox {
3914                bbox: BoundingBox {
3915                    xmin: 0.08515105,
3916                    ymin: 0.7131401,
3917                    xmax: 0.29802868,
3918                    ymax: 0.8195788,
3919                },
3920                score: 0.91537374,
3921                label: 23,
3922            },
3923            DetectBox {
3924                bbox: BoundingBox {
3925                    xmin: 0.59605736,
3926                    ymin: 0.25545314,
3927                    xmax: 0.93666154,
3928                    ymax: 0.72378385,
3929                },
3930                score: 0.91537374,
3931                label: 23,
3932            },
3933        ];
3934
3935        let mut tracker = ByteTrackBuilder::new()
3936            .track_update(0.1)
3937            .track_high_conf(0.7)
3938            .build();
3939
3940        let mut output_boxes = Vec::with_capacity(50);
3941        let mut output_tracks = Vec::with_capacity(50);
3942
3943        decoder
3944            .decode_tracked_float_proto(
3945                &mut tracker,
3946                0,
3947                &[
3948                    boxes.view().into_dyn(),
3949                    scores.view().into_dyn(),
3950                    mask.view().into_dyn(),
3951                    protos.view().into_dyn(),
3952                ],
3953                &mut output_boxes,
3954                &mut output_tracks,
3955            )
3956            .unwrap();
3957
3958        assert_eq!(output_boxes.len(), 2);
3959        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
3960        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1.0 / 160.0));
3961
3962        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
3963
3964        for score in scores.iter_mut() {
3965            *score = 0.0; // set all scores to minimum to simulate no detections
3966        }
3967        let protos = decoder
3968            .decode_tracked_float_proto(
3969                &mut tracker,
3970                100_000_000 / 3,
3971                &[
3972                    boxes.view().into_dyn(),
3973                    scores.view().into_dyn(),
3974                    mask.view().into_dyn(),
3975                    protos.view().into_dyn(),
3976                ],
3977                &mut output_boxes,
3978                &mut output_tracks,
3979            )
3980            .unwrap();
3981
3982        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
3983        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
3984
3985        // no masks when the boxes are from tracker prediction without a matching detection
3986        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()));
3987    }
3988
3989    #[test]
3990    fn test_decoder_tracked_end_to_end_segdet() {
3991        let score_threshold = 0.45;
3992        let iou_threshold = 0.45;
3993
3994        let mut boxes = Array2::zeros((10, 4));
3995        let mut scores = Array2::zeros((10, 1));
3996        let mut classes = Array2::zeros((10, 1));
3997        let mask = Array2::zeros((10, 32));
3998        let protos = Array3::<f64>::zeros((160, 160, 32));
3999        let protos = protos.insert_axis(Axis(0));
4000
4001        let protos_quant = (1.0 / 255.0, 0.0);
4002        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4003
4004        boxes
4005            .slice_mut(s![0, ..,])
4006            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4007        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4008        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4009
4010        let detect = ndarray::concatenate![
4011            Axis(1),
4012            boxes.view(),
4013            scores.view(),
4014            classes.view(),
4015            mask.view()
4016        ];
4017        let detect = detect.insert_axis(Axis(0));
4018        assert_eq!(detect.shape(), &[1, 10, 38]);
4019        let detect_quant = (2.0 / 255.0, 0.0);
4020        let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4021        let config = "
4022decoder_version: yolo26
4023outputs:
4024 - type: detection
4025   decoder: ultralytics
4026   quantization: [0.00784313725490196, 0]
4027   shape: [1, 10, 38]
4028   dshape:
4029    - [batch, 1]
4030    - [num_boxes, 10]
4031    - [num_features, 38]
4032   normalized: true
4033 - type: protos
4034   decoder: ultralytics
4035   quantization: [0.0039215686274509803921568627451, 128]
4036   shape: [1, 160, 160, 32]
4037   dshape:
4038    - [batch, 1]
4039    - [height, 160]
4040    - [width, 160]
4041    - [num_protos, 32]
4042";
4043
4044        let decoder = DecoderBuilder::default()
4045            .with_config_yaml_str(config.to_string())
4046            .with_score_threshold(score_threshold)
4047            .with_iou_threshold(iou_threshold)
4048            .build()
4049            .unwrap();
4050
4051        // Expected boxes doesn't match the float values exactly due to quantization error
4052        let expected_boxes = [DetectBox {
4053            bbox: BoundingBox {
4054                xmin: 0.12549022,
4055                ymin: 0.12549022,
4056                xmax: 0.23529413,
4057                ymax: 0.23529413,
4058            },
4059            score: 0.98823535,
4060            label: 2,
4061        }];
4062
4063        let mut tracker = ByteTrackBuilder::new()
4064            .track_update(0.1)
4065            .track_high_conf(0.7)
4066            .build();
4067
4068        let mut output_boxes = Vec::with_capacity(50);
4069        let mut output_masks = Vec::with_capacity(50);
4070        let mut output_tracks = Vec::with_capacity(50);
4071
4072        decoder
4073            .decode_tracked_quantized(
4074                &mut tracker,
4075                0,
4076                &[detect.view().into(), protos.view().into()],
4077                &mut output_boxes,
4078                &mut output_masks,
4079                &mut output_tracks,
4080            )
4081            .unwrap();
4082
4083        assert_eq!(output_boxes.len(), 1);
4084        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4085
4086        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4087
4088        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4089            *score = u8::MIN; // set all scores to minimum to simulate no detections
4090        }
4091
4092        decoder
4093            .decode_tracked_quantized(
4094                &mut tracker,
4095                100_000_000 / 3,
4096                &[detect.view().into(), protos.view().into()],
4097                &mut output_boxes,
4098                &mut output_masks,
4099                &mut output_tracks,
4100            )
4101            .unwrap();
4102        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4103        // no masks when the boxes are from tracker prediction without a matching detection
4104        assert!(output_masks.is_empty())
4105    }
4106
4107    #[test]
4108    fn test_decoder_tracked_end_to_end_segdet_float() {
4109        let score_threshold = 0.45;
4110        let iou_threshold = 0.45;
4111
4112        let mut boxes = Array2::zeros((10, 4));
4113        let mut scores = Array2::zeros((10, 1));
4114        let mut classes = Array2::zeros((10, 1));
4115        let mask = Array2::zeros((10, 32));
4116        let protos = Array3::<f64>::zeros((160, 160, 32));
4117        let protos = protos.insert_axis(Axis(0));
4118
4119        boxes
4120            .slice_mut(s![0, ..,])
4121            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4122        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4123        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4124
4125        let detect = ndarray::concatenate![
4126            Axis(1),
4127            boxes.view(),
4128            scores.view(),
4129            classes.view(),
4130            mask.view()
4131        ];
4132        let mut detect = detect.insert_axis(Axis(0));
4133        assert_eq!(detect.shape(), &[1, 10, 38]);
4134        let config = "
4135decoder_version: yolo26
4136outputs:
4137 - type: detection
4138   decoder: ultralytics
4139   quantization: [0.00784313725490196, 0]
4140   shape: [1, 10, 38]
4141   dshape:
4142    - [batch, 1]
4143    - [num_boxes, 10]
4144    - [num_features, 38]
4145   normalized: true
4146 - type: protos
4147   decoder: ultralytics
4148   quantization: [0.0039215686274509803921568627451, 128]
4149   shape: [1, 160, 160, 32]
4150   dshape:
4151    - [batch, 1]
4152    - [height, 160]
4153    - [width, 160]
4154    - [num_protos, 32]
4155";
4156
4157        let decoder = DecoderBuilder::default()
4158            .with_config_yaml_str(config.to_string())
4159            .with_score_threshold(score_threshold)
4160            .with_iou_threshold(iou_threshold)
4161            .build()
4162            .unwrap();
4163
4164        let expected_boxes = [DetectBox {
4165            bbox: BoundingBox {
4166                xmin: 0.1234,
4167                ymin: 0.1234,
4168                xmax: 0.2345,
4169                ymax: 0.2345,
4170            },
4171            score: 0.9876,
4172            label: 2,
4173        }];
4174
4175        let mut tracker = ByteTrackBuilder::new()
4176            .track_update(0.1)
4177            .track_high_conf(0.7)
4178            .build();
4179
4180        let mut output_boxes = Vec::with_capacity(50);
4181        let mut output_masks = Vec::with_capacity(50);
4182        let mut output_tracks = Vec::with_capacity(50);
4183
4184        decoder
4185            .decode_tracked_float(
4186                &mut tracker,
4187                0,
4188                &[detect.view().into_dyn(), protos.view().into_dyn()],
4189                &mut output_boxes,
4190                &mut output_masks,
4191                &mut output_tracks,
4192            )
4193            .unwrap();
4194
4195        assert_eq!(output_boxes.len(), 1);
4196        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4197
4198        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4199
4200        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4201            *score = 0.0; // set all scores to minimum to simulate no detections
4202        }
4203
4204        decoder
4205            .decode_tracked_float(
4206                &mut tracker,
4207                100_000_000 / 3,
4208                &[detect.view().into_dyn(), protos.view().into_dyn()],
4209                &mut output_boxes,
4210                &mut output_masks,
4211                &mut output_tracks,
4212            )
4213            .unwrap();
4214        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4215        // no masks when the boxes are from tracker prediction without a matching detection
4216        assert!(output_masks.is_empty())
4217    }
4218
4219    #[test]
4220    fn test_decoder_tracked_end_to_end_segdet_proto() {
4221        let score_threshold = 0.45;
4222        let iou_threshold = 0.45;
4223
4224        let mut boxes = Array2::zeros((10, 4));
4225        let mut scores = Array2::zeros((10, 1));
4226        let mut classes = Array2::zeros((10, 1));
4227        let mask = Array2::zeros((10, 32));
4228        let protos = Array3::<f64>::zeros((160, 160, 32));
4229        let protos = protos.insert_axis(Axis(0));
4230
4231        let protos_quant = (1.0 / 255.0, 0.0);
4232        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4233
4234        boxes
4235            .slice_mut(s![0, ..,])
4236            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4237        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4238        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4239
4240        let detect = ndarray::concatenate![
4241            Axis(1),
4242            boxes.view(),
4243            scores.view(),
4244            classes.view(),
4245            mask.view()
4246        ];
4247        let detect = detect.insert_axis(Axis(0));
4248        assert_eq!(detect.shape(), &[1, 10, 38]);
4249        let detect_quant = (2.0 / 255.0, 0.0);
4250        let mut detect: Array3<u8> = quantize_ndarray(detect.view(), detect_quant.into());
4251        let config = "
4252decoder_version: yolo26
4253outputs:
4254 - type: detection
4255   decoder: ultralytics
4256   quantization: [0.00784313725490196, 0]
4257   shape: [1, 10, 38]
4258   dshape:
4259    - [batch, 1]
4260    - [num_boxes, 10]
4261    - [num_features, 38]
4262   normalized: true
4263 - type: protos
4264   decoder: ultralytics
4265   quantization: [0.0039215686274509803921568627451, 128]
4266   shape: [1, 160, 160, 32]
4267   dshape:
4268    - [batch, 1]
4269    - [height, 160]
4270    - [width, 160]
4271    - [num_protos, 32]
4272";
4273
4274        let decoder = DecoderBuilder::default()
4275            .with_config_yaml_str(config.to_string())
4276            .with_score_threshold(score_threshold)
4277            .with_iou_threshold(iou_threshold)
4278            .build()
4279            .unwrap();
4280
4281        // Expected boxes doesn't match the float values exactly due to quantization error
4282        let expected_boxes = [DetectBox {
4283            bbox: BoundingBox {
4284                xmin: 0.12549022,
4285                ymin: 0.12549022,
4286                xmax: 0.23529413,
4287                ymax: 0.23529413,
4288            },
4289            score: 0.98823535,
4290            label: 2,
4291        }];
4292
4293        let mut tracker = ByteTrackBuilder::new()
4294            .track_update(0.1)
4295            .track_high_conf(0.7)
4296            .build();
4297
4298        let mut output_boxes = Vec::with_capacity(50);
4299        let mut output_tracks = Vec::with_capacity(50);
4300
4301        decoder
4302            .decode_tracked_quantized_proto(
4303                &mut tracker,
4304                0,
4305                &[detect.view().into(), protos.view().into()],
4306                &mut output_boxes,
4307                &mut output_tracks,
4308            )
4309            .unwrap();
4310
4311        assert_eq!(output_boxes.len(), 1);
4312        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4313
4314        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4315
4316        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4317            *score = u8::MIN; // set all scores to minimum to simulate no detections
4318        }
4319
4320        let protos = decoder
4321            .decode_tracked_quantized_proto(
4322                &mut tracker,
4323                100_000_000 / 3,
4324                &[detect.view().into(), protos.view().into()],
4325                &mut output_boxes,
4326                &mut output_tracks,
4327            )
4328            .unwrap();
4329        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4330        // no masks when the boxes are from tracker prediction without a matching detection
4331        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4332    }
4333
4334    #[test]
4335    fn test_decoder_tracked_end_to_end_segdet_proto_float() {
4336        let score_threshold = 0.45;
4337        let iou_threshold = 0.45;
4338
4339        let mut boxes = Array2::zeros((10, 4));
4340        let mut scores = Array2::zeros((10, 1));
4341        let mut classes = Array2::zeros((10, 1));
4342        let mask = Array2::zeros((10, 32));
4343        let protos = Array3::<f64>::zeros((160, 160, 32));
4344        let protos = protos.insert_axis(Axis(0));
4345
4346        boxes
4347            .slice_mut(s![0, ..,])
4348            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4349        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4350        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4351
4352        let detect = ndarray::concatenate![
4353            Axis(1),
4354            boxes.view(),
4355            scores.view(),
4356            classes.view(),
4357            mask.view()
4358        ];
4359        let mut detect = detect.insert_axis(Axis(0));
4360        assert_eq!(detect.shape(), &[1, 10, 38]);
4361        let config = "
4362decoder_version: yolo26
4363outputs:
4364 - type: detection
4365   decoder: ultralytics
4366   quantization: [0.00784313725490196, 0]
4367   shape: [1, 10, 38]
4368   dshape:
4369    - [batch, 1]
4370    - [num_boxes, 10]
4371    - [num_features, 38]
4372   normalized: true
4373 - type: protos
4374   decoder: ultralytics
4375   quantization: [0.0039215686274509803921568627451, 128]
4376   shape: [1, 160, 160, 32]
4377   dshape:
4378    - [batch, 1]
4379    - [height, 160]
4380    - [width, 160]
4381    - [num_protos, 32]
4382";
4383
4384        let decoder = DecoderBuilder::default()
4385            .with_config_yaml_str(config.to_string())
4386            .with_score_threshold(score_threshold)
4387            .with_iou_threshold(iou_threshold)
4388            .build()
4389            .unwrap();
4390
4391        let expected_boxes = [DetectBox {
4392            bbox: BoundingBox {
4393                xmin: 0.1234,
4394                ymin: 0.1234,
4395                xmax: 0.2345,
4396                ymax: 0.2345,
4397            },
4398            score: 0.9876,
4399            label: 2,
4400        }];
4401
4402        let mut tracker = ByteTrackBuilder::new()
4403            .track_update(0.1)
4404            .track_high_conf(0.7)
4405            .build();
4406
4407        let mut output_boxes = Vec::with_capacity(50);
4408        let mut output_tracks = Vec::with_capacity(50);
4409
4410        decoder
4411            .decode_tracked_float_proto(
4412                &mut tracker,
4413                0,
4414                &[detect.view().into_dyn(), protos.view().into_dyn()],
4415                &mut output_boxes,
4416                &mut output_tracks,
4417            )
4418            .unwrap();
4419
4420        assert_eq!(output_boxes.len(), 1);
4421        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4422
4423        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4424
4425        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
4426            *score = 0.0; // set all scores to minimum to simulate no detections
4427        }
4428
4429        let protos = decoder
4430            .decode_tracked_float_proto(
4431                &mut tracker,
4432                100_000_000 / 3,
4433                &[detect.view().into_dyn(), protos.view().into_dyn()],
4434                &mut output_boxes,
4435                &mut output_tracks,
4436            )
4437            .unwrap();
4438        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4439        // no masks when the boxes are from tracker prediction without a matching detection
4440        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4441    }
4442
4443    #[test]
4444    fn test_decoder_tracked_end_to_end_segdet_split() {
4445        let score_threshold = 0.45;
4446        let iou_threshold = 0.45;
4447
4448        let mut boxes = Array2::zeros((10, 4));
4449        let mut scores = Array2::zeros((10, 1));
4450        let mut classes = Array2::zeros((10, 1));
4451        let mask: Array2<f64> = Array2::zeros((10, 32));
4452        let protos = Array3::<f64>::zeros((160, 160, 32));
4453        let protos = protos.insert_axis(Axis(0));
4454
4455        let protos_quant = (1.0 / 255.0, 0.0);
4456        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4457
4458        boxes
4459            .slice_mut(s![0, ..,])
4460            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4461        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4462        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4463
4464        let boxes = boxes.insert_axis(Axis(0));
4465        let scores = scores.insert_axis(Axis(0));
4466        let classes = classes.insert_axis(Axis(0));
4467        let mask = mask.insert_axis(Axis(0));
4468
4469        let detect_quant = (2.0 / 255.0, 0.0);
4470        let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4471        let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4472        let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4473        let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4474
4475        let config = "
4476decoder_version: yolo26
4477outputs:
4478 - type: boxes
4479   decoder: ultralytics
4480   quantization: [0.00784313725490196, 0]
4481   shape: [1, 10, 4]
4482   dshape:
4483    - [batch, 1]
4484    - [num_boxes, 10]
4485    - [box_coords, 4]
4486   normalized: true
4487 - type: scores
4488   decoder: ultralytics
4489   quantization: [0.00784313725490196, 0]
4490   shape: [1, 10, 1]
4491   dshape:
4492    - [batch, 1]
4493    - [num_boxes, 10]
4494    - [num_classes, 1]
4495 - type: classes
4496   decoder: ultralytics
4497   quantization: [0.00784313725490196, 0]
4498   shape: [1, 10, 1]
4499   dshape:
4500    - [batch, 1]
4501    - [num_boxes, 10]
4502    - [num_classes, 1]
4503 - type: mask_coefficients
4504   decoder: ultralytics
4505   quantization: [0.00784313725490196, 0]
4506   shape: [1, 10, 32]
4507   dshape:
4508    - [batch, 1]
4509    - [num_boxes, 10]
4510    - [num_protos, 32]
4511 - type: protos
4512   decoder: ultralytics
4513   quantization: [0.0039215686274509803921568627451, 128]
4514   shape: [1, 160, 160, 32]
4515   dshape:
4516    - [batch, 1]
4517    - [height, 160]
4518    - [width, 160]
4519    - [num_protos, 32]
4520";
4521
4522        let decoder = DecoderBuilder::default()
4523            .with_config_yaml_str(config.to_string())
4524            .with_score_threshold(score_threshold)
4525            .with_iou_threshold(iou_threshold)
4526            .build()
4527            .unwrap();
4528
4529        // Expected boxes doesn't match the float values exactly due to quantization error
4530        let expected_boxes = [DetectBox {
4531            bbox: BoundingBox {
4532                xmin: 0.12549022,
4533                ymin: 0.12549022,
4534                xmax: 0.23529413,
4535                ymax: 0.23529413,
4536            },
4537            score: 0.98823535,
4538            label: 2,
4539        }];
4540
4541        let mut tracker = ByteTrackBuilder::new()
4542            .track_update(0.1)
4543            .track_high_conf(0.7)
4544            .build();
4545
4546        let mut output_boxes = Vec::with_capacity(50);
4547        let mut output_masks = Vec::with_capacity(50);
4548        let mut output_tracks = Vec::with_capacity(50);
4549
4550        decoder
4551            .decode_tracked_quantized(
4552                &mut tracker,
4553                0,
4554                &[
4555                    boxes.view().into(),
4556                    scores.view().into(),
4557                    classes.view().into(),
4558                    mask.view().into(),
4559                    protos.view().into(),
4560                ],
4561                &mut output_boxes,
4562                &mut output_masks,
4563                &mut output_tracks,
4564            )
4565            .unwrap();
4566
4567        assert_eq!(output_boxes.len(), 1);
4568        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4569
4570        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4571
4572        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4573            *score = u8::MIN; // set all scores to minimum to simulate no detections
4574        }
4575
4576        decoder
4577            .decode_tracked_quantized(
4578                &mut tracker,
4579                100_000_000 / 3,
4580                &[
4581                    boxes.view().into(),
4582                    scores.view().into(),
4583                    classes.view().into(),
4584                    mask.view().into(),
4585                    protos.view().into(),
4586                ],
4587                &mut output_boxes,
4588                &mut output_masks,
4589                &mut output_tracks,
4590            )
4591            .unwrap();
4592        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4593        // no masks when the boxes are from tracker prediction without a matching detection
4594        assert!(output_masks.is_empty())
4595    }
4596    #[test]
4597    fn test_decoder_tracked_end_to_end_segdet_split_float() {
4598        let score_threshold = 0.45;
4599        let iou_threshold = 0.45;
4600
4601        let mut boxes = Array2::zeros((10, 4));
4602        let mut scores = Array2::zeros((10, 1));
4603        let mut classes = Array2::zeros((10, 1));
4604        let mask: Array2<f64> = Array2::zeros((10, 32));
4605        let protos = Array3::<f64>::zeros((160, 160, 32));
4606        let protos = protos.insert_axis(Axis(0));
4607
4608        boxes
4609            .slice_mut(s![0, ..,])
4610            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4611        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4612        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4613
4614        let boxes = boxes.insert_axis(Axis(0));
4615        let mut scores = scores.insert_axis(Axis(0));
4616        let classes = classes.insert_axis(Axis(0));
4617        let mask = mask.insert_axis(Axis(0));
4618
4619        let config = "
4620decoder_version: yolo26
4621outputs:
4622 - type: boxes
4623   decoder: ultralytics
4624   quantization: [0.00784313725490196, 0]
4625   shape: [1, 10, 4]
4626   dshape:
4627    - [batch, 1]
4628    - [num_boxes, 10]
4629    - [box_coords, 4]
4630   normalized: true
4631 - type: scores
4632   decoder: ultralytics
4633   quantization: [0.00784313725490196, 0]
4634   shape: [1, 10, 1]
4635   dshape:
4636    - [batch, 1]
4637    - [num_boxes, 10]
4638    - [num_classes, 1]
4639 - type: classes
4640   decoder: ultralytics
4641   quantization: [0.00784313725490196, 0]
4642   shape: [1, 10, 1]
4643   dshape:
4644    - [batch, 1]
4645    - [num_boxes, 10]
4646    - [num_classes, 1]
4647 - type: mask_coefficients
4648   decoder: ultralytics
4649   quantization: [0.00784313725490196, 0]
4650   shape: [1, 10, 32]
4651   dshape:
4652    - [batch, 1]
4653    - [num_boxes, 10]
4654    - [num_protos, 32]
4655 - type: protos
4656   decoder: ultralytics
4657   quantization: [0.0039215686274509803921568627451, 128]
4658   shape: [1, 160, 160, 32]
4659   dshape:
4660    - [batch, 1]
4661    - [height, 160]
4662    - [width, 160]
4663    - [num_protos, 32]
4664";
4665
4666        let decoder = DecoderBuilder::default()
4667            .with_config_yaml_str(config.to_string())
4668            .with_score_threshold(score_threshold)
4669            .with_iou_threshold(iou_threshold)
4670            .build()
4671            .unwrap();
4672
4673        // Expected boxes doesn't match the float values exactly due to quantization error
4674        let expected_boxes = [DetectBox {
4675            bbox: BoundingBox {
4676                xmin: 0.1234,
4677                ymin: 0.1234,
4678                xmax: 0.2345,
4679                ymax: 0.2345,
4680            },
4681            score: 0.9876,
4682            label: 2,
4683        }];
4684
4685        let mut tracker = ByteTrackBuilder::new()
4686            .track_update(0.1)
4687            .track_high_conf(0.7)
4688            .build();
4689
4690        let mut output_boxes = Vec::with_capacity(50);
4691        let mut output_masks = Vec::with_capacity(50);
4692        let mut output_tracks = Vec::with_capacity(50);
4693
4694        decoder
4695            .decode_tracked_float(
4696                &mut tracker,
4697                0,
4698                &[
4699                    boxes.view().into_dyn(),
4700                    scores.view().into_dyn(),
4701                    classes.view().into_dyn(),
4702                    mask.view().into_dyn(),
4703                    protos.view().into_dyn(),
4704                ],
4705                &mut output_boxes,
4706                &mut output_masks,
4707                &mut output_tracks,
4708            )
4709            .unwrap();
4710
4711        assert_eq!(output_boxes.len(), 1);
4712        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4713
4714        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4715
4716        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4717            *score = 0.0; // set all scores to minimum to simulate no detections
4718        }
4719
4720        decoder
4721            .decode_tracked_float(
4722                &mut tracker,
4723                100_000_000 / 3,
4724                &[
4725                    boxes.view().into_dyn(),
4726                    scores.view().into_dyn(),
4727                    classes.view().into_dyn(),
4728                    mask.view().into_dyn(),
4729                    protos.view().into_dyn(),
4730                ],
4731                &mut output_boxes,
4732                &mut output_masks,
4733                &mut output_tracks,
4734            )
4735            .unwrap();
4736        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4737        // no masks when the boxes are from tracker prediction without a matching detection
4738        assert!(output_masks.is_empty())
4739    }
4740
4741    #[test]
4742    fn test_decoder_tracked_end_to_end_segdet_split_proto() {
4743        let score_threshold = 0.45;
4744        let iou_threshold = 0.45;
4745
4746        let mut boxes = Array2::zeros((10, 4));
4747        let mut scores = Array2::zeros((10, 1));
4748        let mut classes = Array2::zeros((10, 1));
4749        let mask: Array2<f64> = Array2::zeros((10, 32));
4750        let protos = Array3::<f64>::zeros((160, 160, 32));
4751        let protos = protos.insert_axis(Axis(0));
4752
4753        let protos_quant = (1.0 / 255.0, 0.0);
4754        let protos: Array4<u8> = quantize_ndarray(protos.view(), protos_quant.into());
4755
4756        boxes
4757            .slice_mut(s![0, ..,])
4758            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4759        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4760        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4761
4762        let boxes = boxes.insert_axis(Axis(0));
4763        let scores = scores.insert_axis(Axis(0));
4764        let classes = classes.insert_axis(Axis(0));
4765        let mask = mask.insert_axis(Axis(0));
4766
4767        let detect_quant = (2.0 / 255.0, 0.0);
4768        let boxes: Array3<u8> = quantize_ndarray(boxes.view(), detect_quant.into());
4769        let mut scores: Array3<u8> = quantize_ndarray(scores.view(), detect_quant.into());
4770        let classes: Array3<u8> = quantize_ndarray(classes.view(), detect_quant.into());
4771        let mask: Array3<u8> = quantize_ndarray(mask.view(), detect_quant.into());
4772
4773        let config = "
4774decoder_version: yolo26
4775outputs:
4776 - type: boxes
4777   decoder: ultralytics
4778   quantization: [0.00784313725490196, 0]
4779   shape: [1, 10, 4]
4780   dshape:
4781    - [batch, 1]
4782    - [num_boxes, 10]
4783    - [box_coords, 4]
4784   normalized: true
4785 - type: scores
4786   decoder: ultralytics
4787   quantization: [0.00784313725490196, 0]
4788   shape: [1, 10, 1]
4789   dshape:
4790    - [batch, 1]
4791    - [num_boxes, 10]
4792    - [num_classes, 1]
4793 - type: classes
4794   decoder: ultralytics
4795   quantization: [0.00784313725490196, 0]
4796   shape: [1, 10, 1]
4797   dshape:
4798    - [batch, 1]
4799    - [num_boxes, 10]
4800    - [num_classes, 1]
4801 - type: mask_coefficients
4802   decoder: ultralytics
4803   quantization: [0.00784313725490196, 0]
4804   shape: [1, 10, 32]
4805   dshape:
4806    - [batch, 1]
4807    - [num_boxes, 10]
4808    - [num_protos, 32]
4809 - type: protos
4810   decoder: ultralytics
4811   quantization: [0.0039215686274509803921568627451, 128]
4812   shape: [1, 160, 160, 32]
4813   dshape:
4814    - [batch, 1]
4815    - [height, 160]
4816    - [width, 160]
4817    - [num_protos, 32]
4818";
4819
4820        let decoder = DecoderBuilder::default()
4821            .with_config_yaml_str(config.to_string())
4822            .with_score_threshold(score_threshold)
4823            .with_iou_threshold(iou_threshold)
4824            .build()
4825            .unwrap();
4826
4827        // Expected boxes doesn't match the float values exactly due to quantization error
4828        let expected_boxes = [DetectBox {
4829            bbox: BoundingBox {
4830                xmin: 0.12549022,
4831                ymin: 0.12549022,
4832                xmax: 0.23529413,
4833                ymax: 0.23529413,
4834            },
4835            score: 0.98823535,
4836            label: 2,
4837        }];
4838
4839        let mut tracker = ByteTrackBuilder::new()
4840            .track_update(0.1)
4841            .track_high_conf(0.7)
4842            .build();
4843
4844        let mut output_boxes = Vec::with_capacity(50);
4845        let mut output_tracks = Vec::with_capacity(50);
4846
4847        decoder
4848            .decode_tracked_quantized_proto(
4849                &mut tracker,
4850                0,
4851                &[
4852                    boxes.view().into(),
4853                    scores.view().into(),
4854                    classes.view().into(),
4855                    mask.view().into(),
4856                    protos.view().into(),
4857                ],
4858                &mut output_boxes,
4859                &mut output_tracks,
4860            )
4861            .unwrap();
4862
4863        assert_eq!(output_boxes.len(), 1);
4864        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
4865
4866        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
4867
4868        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
4869            *score = u8::MIN; // set all scores to minimum to simulate no detections
4870        }
4871
4872        let protos = decoder
4873            .decode_tracked_quantized_proto(
4874                &mut tracker,
4875                100_000_000 / 3,
4876                &[
4877                    boxes.view().into(),
4878                    scores.view().into(),
4879                    classes.view().into(),
4880                    mask.view().into(),
4881                    protos.view().into(),
4882                ],
4883                &mut output_boxes,
4884                &mut output_tracks,
4885            )
4886            .unwrap();
4887        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
4888        // no masks when the boxes are from tracker prediction without a matching detection
4889        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
4890    }
4891
4892    #[test]
4893    fn test_decoder_tracked_end_to_end_segdet_split_proto_float() {
4894        let score_threshold = 0.45;
4895        let iou_threshold = 0.45;
4896
4897        let mut boxes = Array2::zeros((10, 4));
4898        let mut scores = Array2::zeros((10, 1));
4899        let mut classes = Array2::zeros((10, 1));
4900        let mask: Array2<f64> = Array2::zeros((10, 32));
4901        let protos = Array3::<f64>::zeros((160, 160, 32));
4902        let protos = protos.insert_axis(Axis(0));
4903
4904        boxes
4905            .slice_mut(s![0, ..,])
4906            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
4907        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
4908        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
4909
4910        let boxes = boxes.insert_axis(Axis(0));
4911        let mut scores = scores.insert_axis(Axis(0));
4912        let classes = classes.insert_axis(Axis(0));
4913        let mask = mask.insert_axis(Axis(0));
4914
4915        let config = "
4916decoder_version: yolo26
4917outputs:
4918 - type: boxes
4919   decoder: ultralytics
4920   quantization: [0.00784313725490196, 0]
4921   shape: [1, 10, 4]
4922   dshape:
4923    - [batch, 1]
4924    - [num_boxes, 10]
4925    - [box_coords, 4]
4926   normalized: true
4927 - type: scores
4928   decoder: ultralytics
4929   quantization: [0.00784313725490196, 0]
4930   shape: [1, 10, 1]
4931   dshape:
4932    - [batch, 1]
4933    - [num_boxes, 10]
4934    - [num_classes, 1]
4935 - type: classes
4936   decoder: ultralytics
4937   quantization: [0.00784313725490196, 0]
4938   shape: [1, 10, 1]
4939   dshape:
4940    - [batch, 1]
4941    - [num_boxes, 10]
4942    - [num_classes, 1]
4943 - type: mask_coefficients
4944   decoder: ultralytics
4945   quantization: [0.00784313725490196, 0]
4946   shape: [1, 10, 32]
4947   dshape:
4948    - [batch, 1]
4949    - [num_boxes, 10]
4950    - [num_protos, 32]
4951 - type: protos
4952   decoder: ultralytics
4953   quantization: [0.0039215686274509803921568627451, 128]
4954   shape: [1, 160, 160, 32]
4955   dshape:
4956    - [batch, 1]
4957    - [height, 160]
4958    - [width, 160]
4959    - [num_protos, 32]
4960";
4961
4962        let decoder = DecoderBuilder::default()
4963            .with_config_yaml_str(config.to_string())
4964            .with_score_threshold(score_threshold)
4965            .with_iou_threshold(iou_threshold)
4966            .build()
4967            .unwrap();
4968
4969        // Expected boxes doesn't match the float values exactly due to quantization error
4970        let expected_boxes = [DetectBox {
4971            bbox: BoundingBox {
4972                xmin: 0.1234,
4973                ymin: 0.1234,
4974                xmax: 0.2345,
4975                ymax: 0.2345,
4976            },
4977            score: 0.9876,
4978            label: 2,
4979        }];
4980
4981        let mut tracker = ByteTrackBuilder::new()
4982            .track_update(0.1)
4983            .track_high_conf(0.7)
4984            .build();
4985
4986        let mut output_boxes = Vec::with_capacity(50);
4987        let mut output_tracks = Vec::with_capacity(50);
4988
4989        decoder
4990            .decode_tracked_float_proto(
4991                &mut tracker,
4992                0,
4993                &[
4994                    boxes.view().into_dyn(),
4995                    scores.view().into_dyn(),
4996                    classes.view().into_dyn(),
4997                    mask.view().into_dyn(),
4998                    protos.view().into_dyn(),
4999                ],
5000                &mut output_boxes,
5001                &mut output_tracks,
5002            )
5003            .unwrap();
5004
5005        assert_eq!(output_boxes.len(), 1);
5006        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1.0 / 160.0));
5007
5008        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5009
5010        for score in scores.slice_mut(s![.., .., ..]).iter_mut() {
5011            *score = 0.0; // set all scores to minimum to simulate no detections
5012        }
5013
5014        let protos = decoder
5015            .decode_tracked_float_proto(
5016                &mut tracker,
5017                100_000_000 / 3,
5018                &[
5019                    boxes.view().into_dyn(),
5020                    scores.view().into_dyn(),
5021                    classes.view().into_dyn(),
5022                    mask.view().into_dyn(),
5023                    protos.view().into_dyn(),
5024                ],
5025                &mut output_boxes,
5026                &mut output_tracks,
5027            )
5028            .unwrap();
5029        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5030        // no masks when the boxes are from tracker prediction without a matching detection
5031        assert!(protos.is_some_and(|x| x.mask_coefficients.is_empty()))
5032    }
5033
5034    #[test]
5035    fn test_decoder_tracked_linear_motion() {
5036        use crate::configs::{DecoderType, Nms};
5037        use crate::DecoderBuilder;
5038
5039        let score_threshold = 0.25;
5040        let iou_threshold = 0.1;
5041        let out = include_bytes!(concat!(
5042            env!("CARGO_MANIFEST_DIR"),
5043            "/../../testdata/yolov8s_80_classes.bin"
5044        ));
5045        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
5046        let mut out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
5047        let quant = (0.0040811873, -123).into();
5048
5049        let decoder = DecoderBuilder::default()
5050            .with_config_yolo_det(
5051                crate::configs::Detection {
5052                    decoder: DecoderType::Ultralytics,
5053                    shape: vec![1, 84, 8400],
5054                    anchors: None,
5055                    quantization: Some(quant),
5056                    dshape: vec![
5057                        (crate::configs::DimName::Batch, 1),
5058                        (crate::configs::DimName::NumFeatures, 84),
5059                        (crate::configs::DimName::NumBoxes, 8400),
5060                    ],
5061                    normalized: Some(true),
5062                },
5063                None,
5064            )
5065            .with_score_threshold(score_threshold)
5066            .with_iou_threshold(iou_threshold)
5067            .with_nms(Some(Nms::ClassAgnostic))
5068            .build()
5069            .unwrap();
5070
5071        let mut expected_boxes = [
5072            DetectBox {
5073                bbox: BoundingBox {
5074                    xmin: 0.5285137,
5075                    ymin: 0.05305544,
5076                    xmax: 0.87541467,
5077                    ymax: 0.9998909,
5078                },
5079                score: 0.5591227,
5080                label: 0,
5081            },
5082            DetectBox {
5083                bbox: BoundingBox {
5084                    xmin: 0.130598,
5085                    ymin: 0.43260583,
5086                    xmax: 0.35098213,
5087                    ymax: 0.9958097,
5088                },
5089                score: 0.33057618,
5090                label: 75,
5091            },
5092        ];
5093
5094        let mut tracker = ByteTrackBuilder::new()
5095            .track_update(0.1)
5096            .track_high_conf(0.3)
5097            .build();
5098
5099        let mut output_boxes = Vec::with_capacity(50);
5100        let mut output_masks = Vec::with_capacity(50);
5101        let mut output_tracks = Vec::with_capacity(50);
5102
5103        decoder
5104            .decode_tracked_quantized(
5105                &mut tracker,
5106                0,
5107                &[out.view().into()],
5108                &mut output_boxes,
5109                &mut output_masks,
5110                &mut output_tracks,
5111            )
5112            .unwrap();
5113
5114        assert_eq!(output_boxes.len(), 2);
5115        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5116        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-6));
5117
5118        for i in 1..=100 {
5119            let mut out = out.clone();
5120            // introduce linear movement into the XY coordinates
5121            let mut x_values = out.slice_mut(s![0, 0, ..]);
5122            for x in x_values.iter_mut() {
5123                *x = x.saturating_add((i as f32 * 1e-3 / quant.0).round() as i8);
5124            }
5125
5126            decoder
5127                .decode_tracked_quantized(
5128                    &mut tracker,
5129                    100_000_000 * i / 3, // simulate 33.333ms between frames
5130                    &[out.view().into()],
5131                    &mut output_boxes,
5132                    &mut output_masks,
5133                    &mut output_tracks,
5134                )
5135                .unwrap();
5136
5137            assert_eq!(output_boxes.len(), 2);
5138        }
5139        let tracks = tracker.get_active_tracks();
5140        let predicted_boxes: Vec<_> = tracks
5141            .iter()
5142            .map(|track| {
5143                let mut l = track.last_box;
5144                l.bbox = track.info.tracked_location.into();
5145                l
5146            })
5147            .collect();
5148        expected_boxes[0].bbox.xmin += 0.1; // compensate for linear movement
5149        expected_boxes[0].bbox.xmax += 0.1;
5150        expected_boxes[1].bbox.xmin += 0.1;
5151        expected_boxes[1].bbox.xmax += 0.1;
5152
5153        assert!(predicted_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5154        assert!(predicted_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5155
5156        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5157        let mut scores_values = out.slice_mut(s![0, 4.., ..]);
5158        for score in scores_values.iter_mut() {
5159            *score = i8::MIN; // set all scores to minimum to simulate no detections
5160        }
5161        decoder
5162            .decode_tracked_quantized(
5163                &mut tracker,
5164                100_000_000 * 101 / 3,
5165                &[out.view().into()],
5166                &mut output_boxes,
5167                &mut output_masks,
5168                &mut output_tracks,
5169            )
5170            .unwrap();
5171        expected_boxes[0].bbox.xmin += 0.001; // compensate for expected movement
5172        expected_boxes[0].bbox.xmax += 0.001;
5173        expected_boxes[1].bbox.xmin += 0.001;
5174        expected_boxes[1].bbox.xmax += 0.001;
5175
5176        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-3));
5177        assert!(output_boxes[1].equal_within_delta(&expected_boxes[1], 1e-3));
5178    }
5179
5180    #[test]
5181    fn test_decoder_tracked_end_to_end_float() {
5182        let score_threshold = 0.45;
5183        let iou_threshold = 0.45;
5184
5185        let mut boxes = Array2::zeros((10, 4));
5186        let mut scores = Array2::zeros((10, 1));
5187        let mut classes = Array2::zeros((10, 1));
5188
5189        boxes
5190            .slice_mut(s![0, ..,])
5191            .assign(&array![0.1234, 0.1234, 0.2345, 0.2345]);
5192        scores.slice_mut(s![0, ..]).assign(&array![0.9876]);
5193        classes.slice_mut(s![0, ..]).assign(&array![2.0]);
5194
5195        let detect = ndarray::concatenate![Axis(1), boxes.view(), scores.view(), classes.view(),];
5196        let mut detect = detect.insert_axis(Axis(0));
5197        assert_eq!(detect.shape(), &[1, 10, 6]);
5198        let config = "
5199decoder_version: yolo26
5200outputs:
5201 - type: detection
5202   decoder: ultralytics
5203   quantization: [0.00784313725490196, 0]
5204   shape: [1, 10, 6]
5205   dshape:
5206    - [batch, 1]
5207    - [num_boxes, 10]
5208    - [num_features, 6]
5209   normalized: true
5210";
5211
5212        let decoder = DecoderBuilder::default()
5213            .with_config_yaml_str(config.to_string())
5214            .with_score_threshold(score_threshold)
5215            .with_iou_threshold(iou_threshold)
5216            .build()
5217            .unwrap();
5218
5219        let expected_boxes = [DetectBox {
5220            bbox: BoundingBox {
5221                xmin: 0.1234,
5222                ymin: 0.1234,
5223                xmax: 0.2345,
5224                ymax: 0.2345,
5225            },
5226            score: 0.9876,
5227            label: 2,
5228        }];
5229
5230        let mut tracker = ByteTrackBuilder::new()
5231            .track_update(0.1)
5232            .track_high_conf(0.7)
5233            .build();
5234
5235        let mut output_boxes = Vec::with_capacity(50);
5236        let mut output_masks = Vec::with_capacity(50);
5237        let mut output_tracks = Vec::with_capacity(50);
5238
5239        decoder
5240            .decode_tracked_float(
5241                &mut tracker,
5242                0,
5243                &[detect.view().into_dyn()],
5244                &mut output_boxes,
5245                &mut output_masks,
5246                &mut output_tracks,
5247            )
5248            .unwrap();
5249
5250        assert_eq!(output_boxes.len(), 1);
5251        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5252
5253        // give the decoder a final frame with no detections to ensure tracks are properly predicting forward when detection is missing
5254
5255        for score in detect.slice_mut(s![.., .., 4]).iter_mut() {
5256            *score = 0.0; // set all scores to minimum to simulate no detections
5257        }
5258
5259        decoder
5260            .decode_tracked_float(
5261                &mut tracker,
5262                100_000_000 / 3,
5263                &[detect.view().into_dyn()],
5264                &mut output_boxes,
5265                &mut output_masks,
5266                &mut output_tracks,
5267            )
5268            .unwrap();
5269        assert!(output_boxes[0].equal_within_delta(&expected_boxes[0], 1e-6));
5270    }
5271}