Skip to main content

edgefirst_decoder/
lib.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4/*!
5## EdgeFirst HAL - Decoders
This crate provides decoding utilities for YOLO object detection and segmentation models, and ModelPack detection and segmentation models.
It supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices. The crate includes functions
for efficiently post-processing model outputs into usable detection boxes and segmentation masks, as well as utilities for dequantizing model outputs.
9
10For general usage, use the `Decoder` struct which provides functions for decoding various model outputs based on the model configuration.
11If you already know the model type and output formats, you can use the lower-level functions directly from the `yolo` and `modelpack` modules.
12
13
14### Quick Example
15```rust
16# use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs::{self, DecoderVersion} };
17# fn main() -> DecoderResult<()> {
18// Create a decoder for a YOLOv8 model with quantized int8 output with 0.25 score threshold and 0.7 IOU threshold
19let decoder = DecoderBuilder::new()
20    .with_config_yolo_det(configs::Detection {
21        anchors: None,
22        decoder: configs::DecoderType::Ultralytics,
23        quantization: Some(configs::QuantTuple(0.012345, 26)),
24        shape: vec![1, 84, 8400],
25        dshape: Vec::new(),
26        normalized: Some(true),
27    },
28    Some(DecoderVersion::Yolov8))
29    .with_score_threshold(0.25)
30    .with_iou_threshold(0.7)
31    .build()?;
32
33// Get the model output from the model. Here we load it from a test data file for demonstration purposes.
34let model_output: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
35    .iter()
36    .map(|b| *b as i8)
37    .collect();
38let model_output_array = ndarray::Array3::from_shape_vec((1, 84, 8400), model_output)?;
39
// The capacity is used to determine the maximum number of detections to decode.
41let mut output_boxes: Vec<_> = Vec::with_capacity(10);
42let mut output_masks: Vec<_> = Vec::with_capacity(10);
43
44// Decode the quantized model output into detection boxes and segmentation masks
45// Because this model is a detection-only model, the `output_masks` vector will remain empty.
46decoder.decode_quantized(&[model_output_array.view().into()], &mut output_boxes, &mut output_masks)?;
47# Ok(())
48# }
49```
50
51# Overview
52
53The primary components of this crate are:
54- `Decoder`/`DecoderBuilder` struct: Provides high-level functions to decode model outputs based on the model configuration.
55- `yolo` module: Contains functions specific to decoding YOLO model outputs.
56- `modelpack` module: Contains functions specific to decoding ModelPack model outputs.
57
58The `Decoder` supports both floating-point and quantized model outputs, allowing for efficient processing on edge devices.
59It also supports mixed integer types for quantized outputs, such as when one output tensor is int8 and another is uint8.
60When decoding quantized outputs, the appropriate quantization parameters must be provided for each output tensor.
If the integer types used in the model output are not supported by the decoder, the user can manually dequantize the model outputs using
62the `dequantize` functions provided in this crate, and then use the floating-point decoding functions. However, it is recommended
63to not dequantize the model outputs manually before passing them to the decoder, as the quantized decoder functions are optimized for performance.
64
65The `yolo` and `modelpack` modules provide lower-level functions for decoding model outputs directly,
66which can be used if the model type and output formats are known in advance.
67
68
69*/
70#![cfg_attr(coverage_nightly, feature(coverage_attribute))]
71
72use ndarray::{Array, Array2, Array3, ArrayView, ArrayView1, ArrayView3, Dimension};
73use num_traits::{AsPrimitive, Float, PrimInt};
74
75pub mod byte;
76pub mod error;
77pub mod float;
78pub mod modelpack;
79pub mod yolo;
80
81mod decoder;
82pub use decoder::*;
83
84pub use configs::{DecoderVersion, Nms};
85pub use error::{DecoderError, DecoderResult};
86
87use crate::{
88    decoder::configs::QuantTuple, modelpack::modelpack_segmentation_to_mask,
89    yolo::yolo_segmentation_to_mask,
90};
91
92/// Trait to convert bounding box formats to XYXY float format
93pub trait BBoxTypeTrait {
94    /// Converts the bbox into XYXY float format.
95    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4];
96
97    /// Converts the bbox into XYXY float format.
98    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
99        input: &[B; 4],
100        quant: Quantization,
101    ) -> [A; 4]
102    where
103        f32: AsPrimitive<A>,
104        i32: AsPrimitive<A>;
105
106    /// Converts the bbox into XYXY float format.
107    ///
108    /// # Examples
109    /// ```rust
110    /// # use edgefirst_decoder::{BBoxTypeTrait, XYWH};
111    /// # use ndarray::array;
112    /// let arr = array![10.0_f32, 20.0, 20.0, 20.0];
113    /// let xyxy: [f32; 4] = XYWH::ndarray_to_xyxy_float(arr.view());
114    /// assert_eq!(xyxy, [0.0_f32, 10.0, 20.0, 30.0]);
115    /// ```
116    #[inline(always)]
117    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
118        input: ArrayView1<B>,
119    ) -> [A; 4] {
120        Self::to_xyxy_float(&[input[0], input[1], input[2], input[3]])
121    }
122
123    #[inline(always)]
124    /// Converts the bbox into XYXY float format.
125    fn ndarray_to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
126        input: ArrayView1<B>,
127        quant: Quantization,
128    ) -> [A; 4]
129    where
130        f32: AsPrimitive<A>,
131        i32: AsPrimitive<A>,
132    {
133        Self::to_xyxy_dequant(&[input[0], input[1], input[2], input[3]], quant)
134    }
135}
136
137/// Converts XYXY bounding boxes to XYXY
138#[derive(Debug, Clone, Copy, PartialEq, Eq)]
139pub struct XYXY {}
140
141impl BBoxTypeTrait for XYXY {
142    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
143        input.map(|b| b.as_())
144    }
145
146    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
147        input: &[B; 4],
148        quant: Quantization,
149    ) -> [A; 4]
150    where
151        f32: AsPrimitive<A>,
152        i32: AsPrimitive<A>,
153    {
154        let scale = quant.scale.as_();
155        let zp = quant.zero_point.as_();
156        input.map(|b| (b.as_() - zp) * scale)
157    }
158
159    #[inline(always)]
160    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
161        input: ArrayView1<B>,
162    ) -> [A; 4] {
163        [
164            input[0].as_(),
165            input[1].as_(),
166            input[2].as_(),
167            input[3].as_(),
168        ]
169    }
170}
171
172/// Converts XYWH bounding boxes to XYXY. The XY values are the center of the
173/// box
174#[derive(Debug, Clone, Copy, PartialEq, Eq)]
175pub struct XYWH {}
176
177impl BBoxTypeTrait for XYWH {
178    #[inline(always)]
179    fn to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(input: &[B; 4]) -> [A; 4] {
180        let half = A::one() / (A::one() + A::one());
181        [
182            (input[0].as_()) - (input[2].as_() * half),
183            (input[1].as_()) - (input[3].as_() * half),
184            (input[0].as_()) + (input[2].as_() * half),
185            (input[1].as_()) + (input[3].as_() * half),
186        ]
187    }
188
189    #[inline(always)]
190    fn to_xyxy_dequant<A: Float + 'static, B: AsPrimitive<A>>(
191        input: &[B; 4],
192        quant: Quantization,
193    ) -> [A; 4]
194    where
195        f32: AsPrimitive<A>,
196        i32: AsPrimitive<A>,
197    {
198        let scale = quant.scale.as_();
199        let half_scale = (quant.scale * 0.5).as_();
200        let zp = quant.zero_point.as_();
201        let [x, y, w, h] = [
202            (input[0].as_() - zp) * scale,
203            (input[1].as_() - zp) * scale,
204            (input[2].as_() - zp) * half_scale,
205            (input[3].as_() - zp) * half_scale,
206        ];
207
208        [x - w, y - h, x + w, y + h]
209    }
210
211    #[inline(always)]
212    fn ndarray_to_xyxy_float<A: Float + 'static, B: AsPrimitive<A>>(
213        input: ArrayView1<B>,
214    ) -> [A; 4] {
215        let half = A::one() / (A::one() + A::one());
216        [
217            (input[0].as_()) - (input[2].as_() * half),
218            (input[1].as_()) - (input[3].as_() * half),
219            (input[0].as_()) + (input[2].as_() * half),
220            (input[1].as_()) + (input[3].as_() * half),
221        ]
222    }
223}
224
/// Describes the quantization parameters for a tensor
///
/// The dequantization helpers in this crate map a quantized value `q` to
/// `(q - zero_point) * scale`.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct Quantization {
    /// Multiplier applied when converting a quantized value to a float
    pub scale: f32,
    /// Quantized value that dequantizes to zero
    pub zero_point: i32,
}

impl Quantization {
    /// Creates a new Quantization struct
    ///
    /// This is a `const fn`, so it can also be used to initialise constants
    /// and statics.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::Quantization;
    /// let quant = Quantization::new(0.1, -128);
    /// assert_eq!(quant.scale, 0.1);
    /// assert_eq!(quant.zero_point, -128);
    ///
    /// // Also usable in constant contexts
    /// const IDENTITY: Quantization = Quantization::new(1.0, 0);
    /// assert_eq!(IDENTITY, Quantization::default());
    /// ```
    pub const fn new(scale: f32, zero_point: i32) -> Self {
        Self { scale, zero_point }
    }
}
245
246impl From<QuantTuple> for Quantization {
247    /// Creates a new Quantization struct from a QuantTuple
248    /// # Examples
249    /// ```
250    /// # use edgefirst_decoder::Quantization;
251    /// # use edgefirst_decoder::configs::QuantTuple;
252    /// let quant_tuple = QuantTuple(0.1_f32, -128_i32);
253    /// let quant = Quantization::from(quant_tuple);
254    /// assert_eq!(quant.scale, 0.1);
255    /// assert_eq!(quant.zero_point, -128);
256    /// ```
257    fn from(quant_tuple: QuantTuple) -> Quantization {
258        Quantization {
259            scale: quant_tuple.0,
260            zero_point: quant_tuple.1,
261        }
262    }
263}
264
265impl<S, Z> From<(S, Z)> for Quantization
266where
267    S: AsPrimitive<f32>,
268    Z: AsPrimitive<i32>,
269{
270    /// Creates a new Quantization struct from a tuple
271    /// # Examples
272    /// ```
273    /// # use edgefirst_decoder::Quantization;
274    /// let quant = Quantization::from((0.1_f64, -128_i64));
275    /// assert_eq!(quant.scale, 0.1);
276    /// assert_eq!(quant.zero_point, -128);
277    /// ```
278    fn from((scale, zp): (S, Z)) -> Quantization {
279        Self {
280            scale: scale.as_(),
281            zero_point: zp.as_(),
282        }
283    }
284}
285
286impl Default for Quantization {
287    /// Creates a default Quantization struct with scale 1.0 and zero_point 0
288    /// # Examples
289    /// ```rust
290    /// # use edgefirst_decoder::Quantization;
291    /// let quant = Quantization::default();
292    /// assert_eq!(quant.scale, 1.0);
293    /// assert_eq!(quant.zero_point, 0);
294    /// ```
295    fn default() -> Self {
296        Self {
297            scale: 1.0,
298            zero_point: 0,
299        }
300    }
301}
302
/// A detection box with f32 bbox and score
///
/// Groups the bounding box of a detection with its score and the index of
/// its label.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct DetectBox {
    /// bounding box of this detection, see [`BoundingBox`]
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more confidence
    pub score: f32,
    /// label index for this detection
    pub label: usize,
}
312
/// A bounding box with f32 coordinates in XYXY format
#[derive(Debug, Clone, Copy, PartialEq, Default)]
pub struct BoundingBox {
    /// left-most normalized coordinate of the bounding box
    pub xmin: f32,
    /// top-most normalized coordinate of the bounding box
    pub ymin: f32,
    /// right-most normalized coordinate of the bounding box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the bounding box
    pub ymax: f32,
}

impl BoundingBox {
    /// Creates a new BoundingBox from the given coordinates
    ///
    /// The coordinates are stored as given; use
    /// [`BoundingBox::to_canonical`] to enforce `xmin <= xmax` and
    /// `ymin <= ymax`. This is a `const fn`, so it can also be used to
    /// initialise constants and statics.
    ///
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// const FULL_FRAME: BoundingBox = BoundingBox::new(0.0, 0.0, 1.0, 1.0);
    /// assert_eq!(FULL_FRAME.xmax, 1.0);
    /// ```
    pub const fn new(xmin: f32, ymin: f32, xmax: f32, ymax: f32) -> Self {
        Self {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }

    /// Transforms BoundingBox so that `xmin <= xmax` and `ymin <= ymax`
    ///
    /// Each axis is handled independently: the smaller of the two stored
    /// values becomes the minimum and the larger becomes the maximum.
    ///
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::new(0.8, 0.6, 0.4, 0.2);
    /// let canonical_bbox = bbox.to_canonical();
    /// assert_eq!(canonical_bbox, BoundingBox::new(0.4, 0.2, 0.8, 0.6));
    /// ```
    pub fn to_canonical(&self) -> Self {
        let xmin = self.xmin.min(self.xmax);
        let xmax = self.xmin.max(self.xmax);
        let ymin = self.ymin.min(self.ymax);
        let ymax = self.ymin.max(self.ymax);
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}

impl From<BoundingBox> for [f32; 4] {
    /// Converts a BoundingBox into an array of 4 f32 values in xmin, ymin,
    /// xmax, ymax order
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox {
    ///     xmin: 0.1,
    ///     ymin: 0.2,
    ///     xmax: 0.3,
    ///     ymax: 0.4,
    /// };
    /// let arr: [f32; 4] = bbox.into();
    /// assert_eq!(arr, [0.1, 0.2, 0.3, 0.4]);
    /// ```
    fn from(b: BoundingBox) -> Self {
        [b.xmin, b.ymin, b.xmax, b.ymax]
    }
}

impl From<[f32; 4]> for BoundingBox {
    /// Converts an array of 4 f32 values in xmin, ymin, xmax, ymax order into
    /// a BoundingBox
    /// # Examples
    /// ```
    /// # use edgefirst_decoder::BoundingBox;
    /// let bbox = BoundingBox::from([0.1, 0.2, 0.3, 0.4]);
    /// assert_eq!(bbox, BoundingBox::new(0.1, 0.2, 0.3, 0.4));
    /// ```
    fn from(arr: [f32; 4]) -> Self {
        let [xmin, ymin, xmax, ymax] = arr;
        BoundingBox {
            xmin,
            ymin,
            xmax,
            ymax,
        }
    }
}
391
392impl DetectBox {
393    /// Returns true if one detect box is equal to another detect box, within
394    /// the given `eps`
395    ///
396    /// # Examples
397    /// ```
398    /// # use edgefirst_decoder::DetectBox;
399    /// let box1 = DetectBox {
400    ///     bbox: edgefirst_decoder::BoundingBox {
401    ///         xmin: 0.1,
402    ///         ymin: 0.2,
403    ///         xmax: 0.3,
404    ///         ymax: 0.4,
405    ///     },
406    ///     score: 0.5,
407    ///     label: 1,
408    /// };
409    /// let box2 = DetectBox {
410    ///     bbox: edgefirst_decoder::BoundingBox {
411    ///         xmin: 0.101,
412    ///         ymin: 0.199,
413    ///         xmax: 0.301,
414    ///         ymax: 0.399,
415    ///     },
416    ///     score: 0.510,
417    ///     label: 1,
418    /// };
419    /// assert!(box1.equal_within_delta(&box2, 0.011));
420    /// ```
421    pub fn equal_within_delta(&self, rhs: &DetectBox, eps: f32) -> bool {
422        let eq_delta = |a: f32, b: f32| (a - b).abs() <= eps;
423        self.label == rhs.label
424            && eq_delta(self.score, rhs.score)
425            && eq_delta(self.bbox.xmin, rhs.bbox.xmin)
426            && eq_delta(self.bbox.ymin, rhs.bbox.ymin)
427            && eq_delta(self.bbox.xmax, rhs.bbox.xmax)
428            && eq_delta(self.bbox.ymax, rhs.bbox.ymax)
429    }
430}
431
/// A segmentation result with a segmentation mask, and a normalized bounding
/// box representing the area that the segmentation mask covers
///
/// The box is stored as four separate `xmin`/`ymin`/`xmax`/`ymax` fields
/// rather than as a `BoundingBox`.
#[derive(Debug, Clone, PartialEq, Default)]
pub struct Segmentation {
    /// left-most normalized coordinate of the segmentation box
    pub xmin: f32,
    /// top-most normalized coordinate of the segmentation box
    pub ymin: f32,
    /// right-most normalized coordinate of the segmentation box
    pub xmax: f32,
    /// bottom-most normalized coordinate of the segmentation box
    pub ymax: f32,
    /// 3D segmentation array. If the last dimension is 1, values equal or above
    /// 128 are considered objects. Otherwise the object is the argmax index
    pub segmentation: Array3<u8>,
}
448
449/// Prototype tensor variants for fused decode+render pipelines.
450///
451/// Carries either raw quantized data (to skip CPU dequantization and let the
452/// GPU shader dequantize) or dequantized f32 data (from float models or legacy
453/// paths).
454#[derive(Debug, Clone)]
455pub enum ProtoTensor {
456    /// Raw int8 protos with quantization parameters — skip CPU dequantization.
457    /// The GPU fragment shader will dequantize per-texel using the scale and
458    /// zero_point.
459    Quantized {
460        protos: Array3<i8>,
461        quantization: Quantization,
462    },
463    /// Dequantized f32 protos (from float models or legacy path).
464    Float(Array3<f32>),
465}
466
467impl ProtoTensor {
468    /// Returns `true` if this is the quantized variant.
469    pub fn is_quantized(&self) -> bool {
470        matches!(self, ProtoTensor::Quantized { .. })
471    }
472
473    /// Returns the spatial dimensions `(height, width, num_protos)`.
474    pub fn dim(&self) -> (usize, usize, usize) {
475        match self {
476            ProtoTensor::Quantized { protos, .. } => protos.dim(),
477            ProtoTensor::Float(arr) => arr.dim(),
478        }
479    }
480
481    /// Returns dequantized f32 protos. For the `Float` variant this is a
482    /// no-copy reference; for `Quantized` it allocates and dequantizes.
483    pub fn as_f32(&self) -> std::borrow::Cow<'_, Array3<f32>> {
484        match self {
485            ProtoTensor::Float(arr) => std::borrow::Cow::Borrowed(arr),
486            ProtoTensor::Quantized {
487                protos,
488                quantization,
489            } => {
490                let scale = quantization.scale;
491                let zp = quantization.zero_point as f32;
492                std::borrow::Cow::Owned(protos.map(|&v| (v as f32 - zp) * scale))
493            }
494        }
495    }
496}
497
/// Raw prototype data for fused decode+render pipelines.
///
/// Holds post-NMS intermediate state before mask materialization, allowing the
/// renderer to compute `mask_coeff @ protos` directly (e.g. in a GPU fragment
/// shader) without materializing intermediate `Array3<u8>` masks.
#[derive(Debug, Clone)]
pub struct ProtoData {
    /// Mask coefficients per detection (each `Vec<f32>` has length `num_protos`).
    pub mask_coefficients: Vec<Vec<f32>>,
    /// Prototype tensor, shape `(proto_h, proto_w, num_protos)`. See
    /// [`ProtoTensor`] for the quantized and float representations.
    pub protos: ProtoTensor,
}
510
511/// Turns a DetectBoxQuantized into a DetectBox by dequantizing the score.
512///
513///  # Examples
514/// ```
515/// # use edgefirst_decoder::{BoundingBox, DetectBoxQuantized, Quantization, dequant_detect_box};
516/// let quant = Quantization::new(0.1, -128);
517/// let bbox = BoundingBox::new(0.1, 0.2, 0.3, 0.4);
518/// let detect_quant = DetectBoxQuantized {
519///     bbox,
520///     score: 100_i8,
521///     label: 1,
522/// };
523/// let detect = dequant_detect_box(&detect_quant, quant);
524/// assert_eq!(detect.score, 0.1 * 100.0 + 12.8);
525/// assert_eq!(detect.label, 1);
526/// assert_eq!(detect.bbox, bbox);
527/// ```
528pub fn dequant_detect_box<SCORE: PrimInt + AsPrimitive<f32>>(
529    detect: &DetectBoxQuantized<SCORE>,
530    quant_scores: Quantization,
531) -> DetectBox {
532    let scaled_zp = -quant_scores.scale * quant_scores.zero_point as f32;
533    DetectBox {
534        bbox: detect.bbox,
535        score: quant_scores.scale * detect.score.as_() + scaled_zp,
536        label: detect.label,
537    }
538}
/// A detection box with a f32 bbox and quantized score
///
/// The score is kept in the integer type `SCORE`; the bbox is already stored
/// as f32 coordinates.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DetectBoxQuantized<
    SCORE: PrimInt + AsPrimitive<f32>,
> {
    /// bounding box of this detection in f32 coordinates
    pub bbox: BoundingBox,
    /// model-specific score for this detection, higher implies more
    /// confidence.
    pub score: SCORE,
    /// label index for this detect
    pub label: usize,
}
553
554/// Dequantizes an ndarray from quantized values to f32 values using the given
555/// quantization parameters
556///
557/// # Examples
558/// ```
559/// # use edgefirst_decoder::{dequantize_ndarray, Quantization};
560/// let quant = Quantization::new(0.1, -128);
561/// let input: Vec<i8> = vec![0, 127, -128, 64];
562/// let input_array = ndarray::Array1::from(input);
563/// let output_array: ndarray::Array1<f32> = dequantize_ndarray(input_array.view(), quant);
564/// assert_eq!(output_array, ndarray::array![12.8, 25.5, 0.0, 19.2]);
565/// ```
566pub fn dequantize_ndarray<T: AsPrimitive<F>, D: Dimension, F: Float + 'static>(
567    input: ArrayView<T, D>,
568    quant: Quantization,
569) -> Array<F, D>
570where
571    i32: num_traits::AsPrimitive<F>,
572    f32: num_traits::AsPrimitive<F>,
573{
574    let zero_point = quant.zero_point.as_();
575    let scale = quant.scale.as_();
576    if zero_point != F::zero() {
577        let scaled_zero = -zero_point * scale;
578        input.mapv(|d| d.as_() * scale + scaled_zero)
579    } else {
580        input.mapv(|d| d.as_() * scale)
581    }
582}
583
584/// Dequantizes a slice from quantized values to float values using the given
585/// quantization parameters
586///
587/// # Examples
588/// ```
589/// # use edgefirst_decoder::{dequantize_cpu, Quantization};
590/// let quant = Quantization::new(0.1, -128);
591/// let input: Vec<i8> = vec![0, 127, -128, 64];
592/// let mut output: Vec<f32> = vec![0.0; input.len()];
593/// dequantize_cpu(&input, quant, &mut output);
594/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
595/// ```
596pub fn dequantize_cpu<T: AsPrimitive<F>, F: Float + 'static>(
597    input: &[T],
598    quant: Quantization,
599    output: &mut [F],
600) where
601    f32: num_traits::AsPrimitive<F>,
602    i32: num_traits::AsPrimitive<F>,
603{
604    assert!(input.len() == output.len());
605    let zero_point = quant.zero_point.as_();
606    let scale = quant.scale.as_();
607    if zero_point != F::zero() {
608        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
609        input
610            .iter()
611            .zip(output)
612            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
613    } else {
614        input
615            .iter()
616            .zip(output)
617            .for_each(|(d, deq)| *deq = d.as_() * scale);
618    }
619}
620
621/// Dequantizes a slice from quantized values to float values using the given
622/// quantization parameters, using chunked processing. This is around 5% faster
623/// than `dequantize_cpu` for large slices.
624///
625/// # Examples
626/// ```
627/// # use edgefirst_decoder::{dequantize_cpu_chunked, Quantization};
628/// let quant = Quantization::new(0.1, -128);
629/// let input: Vec<i8> = vec![0, 127, -128, 64];
630/// let mut output: Vec<f32> = vec![0.0; input.len()];
631/// dequantize_cpu_chunked(&input, quant, &mut output);
632/// assert_eq!(output, vec![12.8, 25.5, 0.0, 19.2]);
633/// ```
634pub fn dequantize_cpu_chunked<T: AsPrimitive<F>, F: Float + 'static>(
635    input: &[T],
636    quant: Quantization,
637    output: &mut [F],
638) where
639    f32: num_traits::AsPrimitive<F>,
640    i32: num_traits::AsPrimitive<F>,
641{
642    assert!(input.len() == output.len());
643    let zero_point = quant.zero_point.as_();
644    let scale = quant.scale.as_();
645
646    let input = input.as_chunks::<4>();
647    let output = output.as_chunks_mut::<4>();
648
649    if zero_point != F::zero() {
650        let scaled_zero = -zero_point * scale; // scale * (d - zero_point) = d * scale - zero_point * scale
651
652        input
653            .0
654            .iter()
655            .zip(output.0)
656            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale + scaled_zero));
657        input
658            .1
659            .iter()
660            .zip(output.1)
661            .for_each(|(d, deq)| *deq = d.as_() * scale + scaled_zero);
662    } else {
663        input
664            .0
665            .iter()
666            .zip(output.0)
667            .for_each(|(d, deq)| *deq = d.map(|d| d.as_() * scale));
668        input
669            .1
670            .iter()
671            .zip(output.1)
672            .for_each(|(d, deq)| *deq = d.as_() * scale);
673    }
674}
675
676/// Converts a segmentation tensor into a 2D mask
677/// If the last dimension of the segmentation tensor is 1, values equal or
678/// above 128 are considered objects. Otherwise the object is the argmax index
679///
680/// # Errors
681///
682/// Returns `DecoderError::InvalidShape` if the segmentation tensor has an
683/// invalid shape.
684///
685/// # Examples
686/// ```
687/// # use edgefirst_decoder::segmentation_to_mask;
688/// let segmentation =
689///     ndarray::Array3::<u8>::from_shape_vec((2, 2, 1), vec![0, 255, 128, 127]).unwrap();
690/// let mask = segmentation_to_mask(segmentation.view()).unwrap();
691/// assert_eq!(mask, ndarray::array![[0, 1], [1, 0]]);
692/// ```
693pub fn segmentation_to_mask(segmentation: ArrayView3<u8>) -> Result<Array2<u8>, DecoderError> {
694    if segmentation.shape()[2] == 0 {
695        return Err(DecoderError::InvalidShape(
696            "Segmentation tensor must have non-zero depth".to_string(),
697        ));
698    }
699    if segmentation.shape()[2] == 1 {
700        yolo_segmentation_to_mask(segmentation, 128)
701    } else {
702        Ok(modelpack_segmentation_to_mask(segmentation))
703    }
704}
705
706/// Returns the maximum value and its index from a 1D array
707fn arg_max<T: PartialOrd + Copy>(score: ArrayView1<T>) -> (T, usize) {
708    score
709        .iter()
710        .enumerate()
711        .fold((score[0], 0), |(max, arg_max), (ind, s)| {
712            if max > *s {
713                (max, arg_max)
714            } else {
715                (*s, ind)
716            }
717        })
718}
719#[cfg(test)]
720#[cfg_attr(coverage_nightly, coverage(off))]
721mod decoder_tests {
722    #![allow(clippy::excessive_precision)]
723    use crate::{
724        configs::{DecoderType, DimName, Protos},
725        modelpack::{decode_modelpack_det, decode_modelpack_split_quant},
726        yolo::{
727            decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float,
728            decode_yolo_segdet_quant,
729        },
730        *,
731    };
732    use ndarray::{array, s, Array4};
733    use ndarray_stats::DeviationExt;
734
735    fn compare_outputs(
736        boxes: (&[DetectBox], &[DetectBox]),
737        masks: (&[Segmentation], &[Segmentation]),
738    ) {
739        let (boxes0, boxes1) = boxes;
740        let (masks0, masks1) = masks;
741
742        assert_eq!(boxes0.len(), boxes1.len());
743        assert_eq!(masks0.len(), masks1.len());
744
745        for (b_i8, b_f32) in boxes0.iter().zip(boxes1) {
746            assert!(
747                b_i8.equal_within_delta(b_f32, 1e-6),
748                "{b_i8:?} is not equal to {b_f32:?}"
749            );
750        }
751
752        for (m_i8, m_f32) in masks0.iter().zip(masks1) {
753            assert_eq!(
754                [m_i8.xmin, m_i8.ymin, m_i8.xmax, m_i8.ymax],
755                [m_f32.xmin, m_f32.ymin, m_f32.xmax, m_f32.ymax],
756            );
757            assert_eq!(m_i8.segmentation.shape(), m_f32.segmentation.shape());
758            let mask_i8 = m_i8.segmentation.map(|x| *x as i32);
759            let mask_f32 = m_f32.segmentation.map(|x| *x as i32);
760            let diff = &mask_i8 - &mask_f32;
761            for x in 0..diff.shape()[0] {
762                for y in 0..diff.shape()[1] {
763                    for z in 0..diff.shape()[2] {
764                        let val = diff[[x, y, z]];
765                        assert!(
766                            val.abs() <= 1,
767                            "Difference between mask0 and mask1 is greater than 1 at ({}, {}, {}): {}",
768                            x,
769                            y,
770                            z,
771                            val
772                        );
773                    }
774                }
775            }
776            let mean_sq_err = mask_i8.mean_sq_err(&mask_f32).unwrap();
777            assert!(
778                mean_sq_err < 1e-2,
779                "Mean Square Error between masks was greater than 1%: {:.2}%",
780                mean_sq_err * 100.0
781            );
782        }
783    }
784
785    #[test]
786    fn test_decoder_modelpack() {
787        let score_threshold = 0.45;
788        let iou_threshold = 0.45;
789        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
790        let boxes = ndarray::Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
791
792        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
793        let scores = ndarray::Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
794
795        let quant_boxes = (0.004656755365431309, 21).into();
796        let quant_scores = (0.0019603664986789227, 0).into();
797
798        let decoder = DecoderBuilder::default()
799            .with_config_modelpack_det(
800                configs::Boxes {
801                    decoder: DecoderType::ModelPack,
802                    quantization: Some(quant_boxes),
803                    shape: vec![1, 1935, 1, 4],
804                    dshape: vec![
805                        (DimName::Batch, 1),
806                        (DimName::NumBoxes, 1935),
807                        (DimName::Padding, 1),
808                        (DimName::BoxCoords, 4),
809                    ],
810                    normalized: Some(true),
811                },
812                configs::Scores {
813                    decoder: DecoderType::ModelPack,
814                    quantization: Some(quant_scores),
815                    shape: vec![1, 1935, 1],
816                    dshape: vec![
817                        (DimName::Batch, 1),
818                        (DimName::NumBoxes, 1935),
819                        (DimName::NumClasses, 1),
820                    ],
821                },
822            )
823            .with_score_threshold(score_threshold)
824            .with_iou_threshold(iou_threshold)
825            .build()
826            .unwrap();
827
828        let quant_boxes = quant_boxes.into();
829        let quant_scores = quant_scores.into();
830
831        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
832        decode_modelpack_det(
833            (boxes.slice(s![0, .., 0, ..]), quant_boxes),
834            (scores.slice(s![0, .., ..]), quant_scores),
835            score_threshold,
836            iou_threshold,
837            &mut output_boxes,
838        );
839        assert!(output_boxes[0].equal_within_delta(
840            &DetectBox {
841                bbox: BoundingBox {
842                    xmin: 0.40513772,
843                    ymin: 0.6379755,
844                    xmax: 0.5122431,
845                    ymax: 0.7730214,
846                },
847                score: 0.4861709,
848                label: 0
849            },
850            1e-6
851        ));
852
853        let mut output_boxes1 = Vec::with_capacity(50);
854        let mut output_masks1 = Vec::with_capacity(50);
855
856        decoder
857            .decode_quantized(
858                &[boxes.view().into(), scores.view().into()],
859                &mut output_boxes1,
860                &mut output_masks1,
861            )
862            .unwrap();
863
864        let mut output_boxes_float = Vec::with_capacity(50);
865        let mut output_masks_float = Vec::with_capacity(50);
866
867        let boxes = dequantize_ndarray(boxes.view(), quant_boxes);
868        let scores = dequantize_ndarray(scores.view(), quant_scores);
869
870        decoder
871            .decode_float::<f32>(
872                &[boxes.view().into_dyn(), scores.view().into_dyn()],
873                &mut output_boxes_float,
874                &mut output_masks_float,
875            )
876            .unwrap();
877
878        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
879        compare_outputs(
880            (&output_boxes, &output_boxes_float),
881            (&[], &output_masks_float),
882        );
883    }
884
885    #[test]
886    fn test_decoder_modelpack_split_u8() {
887        let score_threshold = 0.45;
888        let iou_threshold = 0.45;
889        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
890        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
891
892        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
893        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
894
895        let quant0 = (0.08547406643629074, 174).into();
896        let quant1 = (0.09929127991199493, 183).into();
897        let anchors0 = vec![
898            [0.36666667461395264, 0.31481480598449707],
899            [0.38749998807907104, 0.4740740656852722],
900            [0.5333333611488342, 0.644444465637207],
901        ];
902        let anchors1 = vec![
903            [0.13750000298023224, 0.2074074000120163],
904            [0.2541666626930237, 0.21481481194496155],
905            [0.23125000298023224, 0.35185185074806213],
906        ];
907
908        let detect_config0 = configs::Detection {
909            decoder: DecoderType::ModelPack,
910            shape: vec![1, 9, 15, 18],
911            anchors: Some(anchors0.clone()),
912            quantization: Some(quant0),
913            dshape: vec![
914                (DimName::Batch, 1),
915                (DimName::Height, 9),
916                (DimName::Width, 15),
917                (DimName::NumAnchorsXFeatures, 18),
918            ],
919            normalized: Some(true),
920        };
921
922        let detect_config1 = configs::Detection {
923            decoder: DecoderType::ModelPack,
924            shape: vec![1, 17, 30, 18],
925            anchors: Some(anchors1.clone()),
926            quantization: Some(quant1),
927            dshape: vec![
928                (DimName::Batch, 1),
929                (DimName::Height, 17),
930                (DimName::Width, 30),
931                (DimName::NumAnchorsXFeatures, 18),
932            ],
933            normalized: Some(true),
934        };
935
936        let config0 = (&detect_config0).try_into().unwrap();
937        let config1 = (&detect_config1).try_into().unwrap();
938
939        let decoder = DecoderBuilder::default()
940            .with_config_modelpack_det_split(vec![detect_config1, detect_config0])
941            .with_score_threshold(score_threshold)
942            .with_iou_threshold(iou_threshold)
943            .build()
944            .unwrap();
945
946        let quant0 = quant0.into();
947        let quant1 = quant1.into();
948
949        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
950        decode_modelpack_split_quant(
951            &[
952                detect0.slice(s![0, .., .., ..]),
953                detect1.slice(s![0, .., .., ..]),
954            ],
955            &[config0, config1],
956            score_threshold,
957            iou_threshold,
958            &mut output_boxes,
959        );
960        assert!(output_boxes[0].equal_within_delta(
961            &DetectBox {
962                bbox: BoundingBox {
963                    xmin: 0.43171933,
964                    ymin: 0.68243736,
965                    xmax: 0.5626645,
966                    ymax: 0.808863,
967                },
968                score: 0.99240804,
969                label: 0
970            },
971            1e-6
972        ));
973
974        let mut output_boxes1: Vec<_> = Vec::with_capacity(10);
975        let mut output_masks1: Vec<_> = Vec::with_capacity(10);
976        decoder
977            .decode_quantized(
978                &[detect0.view().into(), detect1.view().into()],
979                &mut output_boxes1,
980                &mut output_masks1,
981            )
982            .unwrap();
983
984        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(10);
985        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(10);
986
987        let detect0 = dequantize_ndarray(detect0.view(), quant0);
988        let detect1 = dequantize_ndarray(detect1.view(), quant1);
989        decoder
990            .decode_float::<f32>(
991                &[detect0.view().into_dyn(), detect1.view().into_dyn()],
992                &mut output_boxes1_f32,
993                &mut output_masks1_f32,
994            )
995            .unwrap();
996
997        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
998        compare_outputs(
999            (&output_boxes, &output_boxes1_f32),
1000            (&[], &output_masks1_f32),
1001        );
1002    }
1003
1004    #[test]
1005    fn test_decoder_parse_config_modelpack_split_u8() {
1006        let score_threshold = 0.45;
1007        let iou_threshold = 0.45;
1008        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1009        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1010
1011        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1012        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1013
1014        let decoder = DecoderBuilder::default()
1015            .with_config_yaml_str(
1016                include_str!("../../../testdata/modelpack_split.yaml").to_string(),
1017            )
1018            .with_score_threshold(score_threshold)
1019            .with_iou_threshold(iou_threshold)
1020            .build()
1021            .unwrap();
1022
1023        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1024        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1025        decoder
1026            .decode_quantized(
1027                &[
1028                    ArrayViewDQuantized::from(detect1.view()),
1029                    ArrayViewDQuantized::from(detect0.view()),
1030                ],
1031                &mut output_boxes,
1032                &mut output_masks,
1033            )
1034            .unwrap();
1035        assert!(output_boxes[0].equal_within_delta(
1036            &DetectBox {
1037                bbox: BoundingBox {
1038                    xmin: 0.43171933,
1039                    ymin: 0.68243736,
1040                    xmax: 0.5626645,
1041                    ymax: 0.808863,
1042                },
1043                score: 0.99240804,
1044                label: 0
1045            },
1046            1e-6
1047        ));
1048    }
1049
1050    #[test]
1051    fn test_modelpack_seg() {
1052        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1053        let out = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1054        let quant = (1.0 / 255.0, 0).into();
1055
1056        let decoder = DecoderBuilder::default()
1057            .with_config_modelpack_seg(configs::Segmentation {
1058                decoder: DecoderType::ModelPack,
1059                quantization: Some(quant),
1060                shape: vec![1, 2, 160, 160],
1061                dshape: vec![
1062                    (DimName::Batch, 1),
1063                    (DimName::NumClasses, 2),
1064                    (DimName::Height, 160),
1065                    (DimName::Width, 160),
1066                ],
1067            })
1068            .build()
1069            .unwrap();
1070        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1071        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1072        decoder
1073            .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
1074            .unwrap();
1075
1076        let mut mask = out.slice(s![0, .., .., ..]);
1077        mask.swap_axes(0, 1);
1078        mask.swap_axes(1, 2);
1079        let mask = [Segmentation {
1080            xmin: 0.0,
1081            ymin: 0.0,
1082            xmax: 1.0,
1083            ymax: 1.0,
1084            segmentation: mask.into_owned(),
1085        }];
1086        compare_outputs((&[], &output_boxes), (&mask, &output_masks));
1087
1088        decoder
1089            .decode_float::<f32>(
1090                &[dequantize_ndarray(out.view(), quant.into())
1091                    .view()
1092                    .into_dyn()],
1093                &mut output_boxes,
1094                &mut output_masks,
1095            )
1096            .unwrap();
1097
1098        // not expected for float decoder to have same values as quantized decoder, as
1099        // float decoder ensures the data fills 0-255, quantized decoder uses whatever
1100        // the model output. Thus the float output is the same as the quantized output
1101        // but scaled differently. However, it is expected that the mask after argmax
1102        // will be the same.
1103        compare_outputs((&[], &output_boxes), (&[], &[]));
1104        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1105        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1106
1107        assert_eq!(mask0, mask1);
1108    }
1109    #[test]
1110    fn test_modelpack_seg_quant() {
1111        let out = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1112        let out_u8 = ndarray::Array4::from_shape_vec((1, 2, 160, 160), out.to_vec()).unwrap();
1113        let out_i8 = out_u8.mapv(|x| (x as i16 - 128) as i8);
1114        let out_u16 = out_u8.mapv(|x| (x as u16) << 8);
1115        let out_i16 = out_u8.mapv(|x| (((x as i32) << 8) - 32768) as i16);
1116        let out_u32 = out_u8.mapv(|x| (x as u32) << 24);
1117        let out_i32 = out_u8.mapv(|x| (((x as i64) << 24) - 2147483648) as i32);
1118
1119        let quant = (1.0 / 255.0, 0).into();
1120
1121        let decoder = DecoderBuilder::default()
1122            .with_config_modelpack_seg(configs::Segmentation {
1123                decoder: DecoderType::ModelPack,
1124                quantization: Some(quant),
1125                shape: vec![1, 2, 160, 160],
1126                dshape: vec![
1127                    (DimName::Batch, 1),
1128                    (DimName::NumClasses, 2),
1129                    (DimName::Height, 160),
1130                    (DimName::Width, 160),
1131                ],
1132            })
1133            .build()
1134            .unwrap();
1135        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1136        let mut output_masks_u8: Vec<_> = Vec::with_capacity(10);
1137        decoder
1138            .decode_quantized(
1139                &[out_u8.view().into()],
1140                &mut output_boxes,
1141                &mut output_masks_u8,
1142            )
1143            .unwrap();
1144
1145        let mut output_masks_i8: Vec<_> = Vec::with_capacity(10);
1146        decoder
1147            .decode_quantized(
1148                &[out_i8.view().into()],
1149                &mut output_boxes,
1150                &mut output_masks_i8,
1151            )
1152            .unwrap();
1153
1154        let mut output_masks_u16: Vec<_> = Vec::with_capacity(10);
1155        decoder
1156            .decode_quantized(
1157                &[out_u16.view().into()],
1158                &mut output_boxes,
1159                &mut output_masks_u16,
1160            )
1161            .unwrap();
1162
1163        let mut output_masks_i16: Vec<_> = Vec::with_capacity(10);
1164        decoder
1165            .decode_quantized(
1166                &[out_i16.view().into()],
1167                &mut output_boxes,
1168                &mut output_masks_i16,
1169            )
1170            .unwrap();
1171
1172        let mut output_masks_u32: Vec<_> = Vec::with_capacity(10);
1173        decoder
1174            .decode_quantized(
1175                &[out_u32.view().into()],
1176                &mut output_boxes,
1177                &mut output_masks_u32,
1178            )
1179            .unwrap();
1180
1181        let mut output_masks_i32: Vec<_> = Vec::with_capacity(10);
1182        decoder
1183            .decode_quantized(
1184                &[out_i32.view().into()],
1185                &mut output_boxes,
1186                &mut output_masks_i32,
1187            )
1188            .unwrap();
1189
1190        compare_outputs((&[], &output_boxes), (&[], &[]));
1191        let mask_u8 = segmentation_to_mask(output_masks_u8[0].segmentation.view()).unwrap();
1192        let mask_i8 = segmentation_to_mask(output_masks_i8[0].segmentation.view()).unwrap();
1193        let mask_u16 = segmentation_to_mask(output_masks_u16[0].segmentation.view()).unwrap();
1194        let mask_i16 = segmentation_to_mask(output_masks_i16[0].segmentation.view()).unwrap();
1195        let mask_u32 = segmentation_to_mask(output_masks_u32[0].segmentation.view()).unwrap();
1196        let mask_i32 = segmentation_to_mask(output_masks_i32[0].segmentation.view()).unwrap();
1197        assert_eq!(mask_u8, mask_i8);
1198        assert_eq!(mask_u8, mask_u16);
1199        assert_eq!(mask_u8, mask_i16);
1200        assert_eq!(mask_u8, mask_u32);
1201        assert_eq!(mask_u8, mask_i32);
1202    }
1203
1204    #[test]
1205    fn test_modelpack_segdet() {
1206        let score_threshold = 0.45;
1207        let iou_threshold = 0.45;
1208
1209        let boxes = include_bytes!("../../../testdata/modelpack_boxes_1935x1x4.bin");
1210        let boxes = Array4::from_shape_vec((1, 1935, 1, 4), boxes.to_vec()).unwrap();
1211
1212        let scores = include_bytes!("../../../testdata/modelpack_scores_1935x1.bin");
1213        let scores = Array3::from_shape_vec((1, 1935, 1), scores.to_vec()).unwrap();
1214
1215        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1216        let seg = Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1217
1218        let quant_boxes = (0.004656755365431309, 21).into();
1219        let quant_scores = (0.0019603664986789227, 0).into();
1220        let quant_seg = (1.0 / 255.0, 0).into();
1221
1222        let decoder = DecoderBuilder::default()
1223            .with_config_modelpack_segdet(
1224                configs::Boxes {
1225                    decoder: DecoderType::ModelPack,
1226                    quantization: Some(quant_boxes),
1227                    shape: vec![1, 1935, 1, 4],
1228                    dshape: vec![
1229                        (DimName::Batch, 1),
1230                        (DimName::NumBoxes, 1935),
1231                        (DimName::Padding, 1),
1232                        (DimName::BoxCoords, 4),
1233                    ],
1234                    normalized: Some(true),
1235                },
1236                configs::Scores {
1237                    decoder: DecoderType::ModelPack,
1238                    quantization: Some(quant_scores),
1239                    shape: vec![1, 1935, 1],
1240                    dshape: vec![
1241                        (DimName::Batch, 1),
1242                        (DimName::NumBoxes, 1935),
1243                        (DimName::NumClasses, 1),
1244                    ],
1245                },
1246                configs::Segmentation {
1247                    decoder: DecoderType::ModelPack,
1248                    quantization: Some(quant_seg),
1249                    shape: vec![1, 2, 160, 160],
1250                    dshape: vec![
1251                        (DimName::Batch, 1),
1252                        (DimName::NumClasses, 2),
1253                        (DimName::Height, 160),
1254                        (DimName::Width, 160),
1255                    ],
1256                },
1257            )
1258            .with_iou_threshold(iou_threshold)
1259            .with_score_threshold(score_threshold)
1260            .build()
1261            .unwrap();
1262        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1263        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1264        decoder
1265            .decode_quantized(
1266                &[scores.view().into(), boxes.view().into(), seg.view().into()],
1267                &mut output_boxes,
1268                &mut output_masks,
1269            )
1270            .unwrap();
1271
1272        let mut mask = seg.slice(s![0, .., .., ..]);
1273        mask.swap_axes(0, 1);
1274        mask.swap_axes(1, 2);
1275        let mask = [Segmentation {
1276            xmin: 0.0,
1277            ymin: 0.0,
1278            xmax: 1.0,
1279            ymax: 1.0,
1280            segmentation: mask.into_owned(),
1281        }];
1282        let correct_boxes = [DetectBox {
1283            bbox: BoundingBox {
1284                xmin: 0.40513772,
1285                ymin: 0.6379755,
1286                xmax: 0.5122431,
1287                ymax: 0.7730214,
1288            },
1289            score: 0.4861709,
1290            label: 0,
1291        }];
1292        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1293
1294        let scores = dequantize_ndarray(scores.view(), quant_scores.into());
1295        let boxes = dequantize_ndarray(boxes.view(), quant_boxes.into());
1296        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1297        decoder
1298            .decode_float::<f32>(
1299                &[
1300                    scores.view().into_dyn(),
1301                    boxes.view().into_dyn(),
1302                    seg.view().into_dyn(),
1303                ],
1304                &mut output_boxes,
1305                &mut output_masks,
1306            )
1307            .unwrap();
1308
1309        // not expected for float segmentation decoder to have same values as quantized
1310        // segmentation decoder, as float decoder ensures the data fills 0-255,
1311        // quantized decoder uses whatever the model output. Thus the float
1312        // output is the same as the quantized output but scaled differently.
1313        // However, it is expected that the mask after argmax will be the same.
1314        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1315        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1316        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1317
1318        assert_eq!(mask0, mask1);
1319    }
1320
1321    #[test]
1322    fn test_modelpack_segdet_split() {
1323        let score_threshold = 0.8;
1324        let iou_threshold = 0.5;
1325
1326        let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
1327        let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
1328
1329        let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
1330        let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
1331
1332        let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
1333        let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
1334
1335        let quant0 = (0.08547406643629074, 174).into();
1336        let quant1 = (0.09929127991199493, 183).into();
1337        let quant_seg = (1.0 / 255.0, 0).into();
1338
1339        let anchors0 = vec![
1340            [0.36666667461395264, 0.31481480598449707],
1341            [0.38749998807907104, 0.4740740656852722],
1342            [0.5333333611488342, 0.644444465637207],
1343        ];
1344        let anchors1 = vec![
1345            [0.13750000298023224, 0.2074074000120163],
1346            [0.2541666626930237, 0.21481481194496155],
1347            [0.23125000298023224, 0.35185185074806213],
1348        ];
1349
1350        let decoder = DecoderBuilder::default()
1351            .with_config_modelpack_segdet_split(
1352                vec![
1353                    configs::Detection {
1354                        decoder: DecoderType::ModelPack,
1355                        shape: vec![1, 17, 30, 18],
1356                        anchors: Some(anchors1),
1357                        quantization: Some(quant1),
1358                        dshape: vec![
1359                            (DimName::Batch, 1),
1360                            (DimName::Height, 17),
1361                            (DimName::Width, 30),
1362                            (DimName::NumAnchorsXFeatures, 18),
1363                        ],
1364                        normalized: Some(true),
1365                    },
1366                    configs::Detection {
1367                        decoder: DecoderType::ModelPack,
1368                        shape: vec![1, 9, 15, 18],
1369                        anchors: Some(anchors0),
1370                        quantization: Some(quant0),
1371                        dshape: vec![
1372                            (DimName::Batch, 1),
1373                            (DimName::Height, 9),
1374                            (DimName::Width, 15),
1375                            (DimName::NumAnchorsXFeatures, 18),
1376                        ],
1377                        normalized: Some(true),
1378                    },
1379                ],
1380                configs::Segmentation {
1381                    decoder: DecoderType::ModelPack,
1382                    quantization: Some(quant_seg),
1383                    shape: vec![1, 2, 160, 160],
1384                    dshape: vec![
1385                        (DimName::Batch, 1),
1386                        (DimName::NumClasses, 2),
1387                        (DimName::Height, 160),
1388                        (DimName::Width, 160),
1389                    ],
1390                },
1391            )
1392            .with_score_threshold(score_threshold)
1393            .with_iou_threshold(iou_threshold)
1394            .build()
1395            .unwrap();
1396        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1397        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1398        decoder
1399            .decode_quantized(
1400                &[
1401                    detect0.view().into(),
1402                    detect1.view().into(),
1403                    seg.view().into(),
1404                ],
1405                &mut output_boxes,
1406                &mut output_masks,
1407            )
1408            .unwrap();
1409
1410        let mut mask = seg.slice(s![0, .., .., ..]);
1411        mask.swap_axes(0, 1);
1412        mask.swap_axes(1, 2);
1413        let mask = [Segmentation {
1414            xmin: 0.0,
1415            ymin: 0.0,
1416            xmax: 1.0,
1417            ymax: 1.0,
1418            segmentation: mask.into_owned(),
1419        }];
1420        let correct_boxes = [DetectBox {
1421            bbox: BoundingBox {
1422                xmin: 0.43171933,
1423                ymin: 0.68243736,
1424                xmax: 0.5626645,
1425                ymax: 0.808863,
1426            },
1427            score: 0.99240804,
1428            label: 0,
1429        }];
1430        println!("Output Boxes: {:?}", output_boxes);
1431        compare_outputs((&correct_boxes, &output_boxes), (&mask, &output_masks));
1432
1433        let detect0 = dequantize_ndarray(detect0.view(), quant0.into());
1434        let detect1 = dequantize_ndarray(detect1.view(), quant1.into());
1435        let seg = dequantize_ndarray(seg.view(), quant_seg.into());
1436        decoder
1437            .decode_float::<f32>(
1438                &[
1439                    detect0.view().into_dyn(),
1440                    detect1.view().into_dyn(),
1441                    seg.view().into_dyn(),
1442                ],
1443                &mut output_boxes,
1444                &mut output_masks,
1445            )
1446            .unwrap();
1447
1448        // not expected for float segmentation decoder to have same values as quantized
1449        // segmentation decoder, as float decoder ensures the data fills 0-255,
1450        // quantized decoder uses whatever the model output. Thus the float
1451        // output is the same as the quantized output but scaled differently.
1452        // However, it is expected that the mask after argmax will be the same.
1453        compare_outputs((&correct_boxes, &output_boxes), (&[], &[]));
1454        let mask0 = segmentation_to_mask(mask[0].segmentation.view()).unwrap();
1455        let mask1 = segmentation_to_mask(output_masks[0].segmentation.view()).unwrap();
1456
1457        assert_eq!(mask0, mask1);
1458    }
1459
1460    #[test]
1461    fn test_dequant_chunked() {
1462        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1463        let mut out =
1464            unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) }.to_vec();
1465        out.push(123); // make sure to test non multiple of 16 length
1466
1467        let mut out_dequant = vec![0.0; 84 * 8400 + 1];
1468        let mut out_dequant_simd = vec![0.0; 84 * 8400 + 1];
1469        let quant = Quantization::new(0.0040811873, -123);
1470        dequantize_cpu(&out, quant, &mut out_dequant);
1471
1472        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1473        assert_eq!(out_dequant, out_dequant_simd);
1474
1475        let quant = Quantization::new(0.0040811873, 0);
1476        dequantize_cpu(&out, quant, &mut out_dequant);
1477
1478        dequantize_cpu_chunked(&out, quant, &mut out_dequant_simd);
1479        assert_eq!(out_dequant, out_dequant_simd);
1480    }
1481
1482    #[test]
1483    fn test_decoder_yolo_det() {
1484        let score_threshold = 0.25;
1485        let iou_threshold = 0.7;
1486        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
1487        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
1488        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
1489        let quant = (0.0040811873, -123).into();
1490
1491        let decoder = DecoderBuilder::default()
1492            .with_config_yolo_det(
1493                configs::Detection {
1494                    decoder: DecoderType::Ultralytics,
1495                    shape: vec![1, 84, 8400],
1496                    anchors: None,
1497                    quantization: Some(quant),
1498                    dshape: vec![
1499                        (DimName::Batch, 1),
1500                        (DimName::NumFeatures, 84),
1501                        (DimName::NumBoxes, 8400),
1502                    ],
1503                    normalized: Some(true),
1504                },
1505                Some(DecoderVersion::Yolo11),
1506            )
1507            .with_score_threshold(score_threshold)
1508            .with_iou_threshold(iou_threshold)
1509            .build()
1510            .unwrap();
1511
1512        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
1513        decode_yolo_det(
1514            (out.slice(s![0, .., ..]), quant.into()),
1515            score_threshold,
1516            iou_threshold,
1517            Some(configs::Nms::ClassAgnostic),
1518            &mut output_boxes,
1519        );
1520        assert!(output_boxes[0].equal_within_delta(
1521            &DetectBox {
1522                bbox: BoundingBox {
1523                    xmin: 0.5285137,
1524                    ymin: 0.05305544,
1525                    xmax: 0.87541467,
1526                    ymax: 0.9998909,
1527                },
1528                score: 0.5591227,
1529                label: 0
1530            },
1531            1e-6
1532        ));
1533
1534        assert!(output_boxes[1].equal_within_delta(
1535            &DetectBox {
1536                bbox: BoundingBox {
1537                    xmin: 0.130598,
1538                    ymin: 0.43260583,
1539                    xmax: 0.35098213,
1540                    ymax: 0.9958097,
1541                },
1542                score: 0.33057618,
1543                label: 75
1544            },
1545            1e-6
1546        ));
1547
1548        let mut output_boxes1: Vec<_> = Vec::with_capacity(50);
1549        let mut output_masks1: Vec<_> = Vec::with_capacity(50);
1550        decoder
1551            .decode_quantized(&[out.view().into()], &mut output_boxes1, &mut output_masks1)
1552            .unwrap();
1553
1554        let out = dequantize_ndarray(out.view(), quant.into());
1555        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(50);
1556        let mut output_masks_f32: Vec<_> = Vec::with_capacity(50);
1557        decoder
1558            .decode_float::<f32>(
1559                &[out.view().into_dyn()],
1560                &mut output_boxes_f32,
1561                &mut output_masks_f32,
1562            )
1563            .unwrap();
1564
1565        compare_outputs((&output_boxes, &output_boxes1), (&[], &output_masks1));
1566        compare_outputs((&output_boxes, &output_boxes_f32), (&[], &output_masks_f32));
1567    }
1568
1569    #[test]
1570    fn test_decoder_masks() {
1571        let score_threshold = 0.45;
1572        let iou_threshold = 0.45;
1573        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1574        let boxes = unsafe { std::slice::from_raw_parts(boxes.as_ptr() as *const i8, boxes.len()) };
1575        let boxes = ndarray::Array2::from_shape_vec((116, 8400), boxes.to_vec()).unwrap();
1576        let quant_boxes = Quantization::new(0.021287761628627777, 31);
1577
1578        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1579        let protos =
1580            unsafe { std::slice::from_raw_parts(protos.as_ptr() as *const i8, protos.len()) };
1581        let protos = ndarray::Array3::from_shape_vec((160, 160, 32), protos.to_vec()).unwrap();
1582        let quant_protos = Quantization::new(0.02491161972284317, -117);
1583        let protos = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1584        let seg = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1585        let mut output_boxes: Vec<_> = Vec::with_capacity(10);
1586        let mut output_masks: Vec<_> = Vec::with_capacity(10);
1587        decode_yolo_segdet_float(
1588            seg.view(),
1589            protos.view(),
1590            score_threshold,
1591            iou_threshold,
1592            Some(configs::Nms::ClassAgnostic),
1593            &mut output_boxes,
1594            &mut output_masks,
1595        );
1596        assert_eq!(output_boxes.len(), 2);
1597        assert_eq!(output_boxes.len(), output_masks.len());
1598
1599        for (b, m) in output_boxes.iter().zip(&output_masks) {
1600            assert!(b.bbox.xmin >= m.xmin);
1601            assert!(b.bbox.ymin >= m.ymin);
1602            assert!(b.bbox.xmax >= m.xmax);
1603            assert!(b.bbox.ymax >= m.ymax);
1604        }
1605        assert!(output_boxes[0].equal_within_delta(
1606            &DetectBox {
1607                bbox: BoundingBox {
1608                    xmin: 0.08515105,
1609                    ymin: 0.7131401,
1610                    xmax: 0.29802868,
1611                    ymax: 0.8195788,
1612                },
1613                score: 0.91537374,
1614                label: 23
1615            },
1616            1.0 / 160.0, // wider range because mask will expand the box
1617        ));
1618
1619        assert!(output_boxes[1].equal_within_delta(
1620            &DetectBox {
1621                bbox: BoundingBox {
1622                    xmin: 0.59605736,
1623                    ymin: 0.25545314,
1624                    xmax: 0.93666154,
1625                    ymax: 0.72378385,
1626                },
1627                score: 0.91537374,
1628                label: 23
1629            },
1630            1.0 / 160.0, // wider range because mask will expand the box
1631        ));
1632
1633        let full_mask = include_bytes!("../../../testdata/yolov8_mask_results.bin");
1634        let full_mask = ndarray::Array2::from_shape_vec((160, 160), full_mask.to_vec()).unwrap();
1635
1636        let cropped_mask = full_mask.slice(ndarray::s![
1637            (output_masks[1].ymin * 160.0) as usize..(output_masks[1].ymax * 160.0) as usize,
1638            (output_masks[1].xmin * 160.0) as usize..(output_masks[1].xmax * 160.0) as usize,
1639        ]);
1640
1641        assert_eq!(
1642            cropped_mask,
1643            segmentation_to_mask(output_masks[1].segmentation.view()).unwrap()
1644        );
1645    }
1646
1647    /// Cross-checks four decode paths on the int8 YOLOv8 segmentation fixture: the
1648    /// free quantized and float functions and both `Decoder` entry points must agree.
1649    #[test]
1650    fn test_decoder_masks_i8() {
1651        let score_threshold = 0.45;
1652        let iou_threshold = 0.45;
1653        // The fixture bytes hold int8 values; a per-byte cast needs no unsafe code.
1654        let boxes = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin");
1655        let boxes: Vec<i8> = boxes.iter().map(|&b| b as i8).collect();
1656        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1657        let quant_boxes = (0.021287761628627777, 31).into();
1658
1659        let protos = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin");
1660        let protos: Vec<i8> = protos.iter().map(|&b| b as i8).collect();
1661        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
1662        let quant_protos = (0.02491161972284317, -117).into();
1663
1664        let detection_cfg = configs::Detection {
1665            decoder: configs::DecoderType::Ultralytics,
1666            quantization: Some(quant_boxes),
1667            shape: vec![1, 116, 8400],
1668            anchors: None,
1669            dshape: vec![
1670                (DimName::Batch, 1),
1671                (DimName::NumFeatures, 116),
1672                (DimName::NumBoxes, 8400),
1673            ],
1674            normalized: Some(true),
1675        };
1676        let protos_cfg = Protos {
1677            decoder: configs::DecoderType::Ultralytics,
1678            quantization: Some(quant_protos),
1679            shape: vec![1, 160, 160, 32],
1680            dshape: vec![
1681                (DimName::Batch, 1),
1682                (DimName::Height, 160),
1683                (DimName::Width, 160),
1684                (DimName::NumProtos, 32),
1685            ],
1686        };
1687        let decoder = DecoderBuilder::default()
1688            .with_config_yolo_segdet(detection_cfg, protos_cfg, Some(DecoderVersion::Yolo11))
1689            .with_score_threshold(score_threshold)
1690            .with_iou_threshold(iou_threshold)
1691            .build()
1692            .unwrap();
1693
1694        // The configs took the tuple form; the free functions take the converted form.
1695        let quant_boxes = quant_boxes.into();
1696        let quant_protos = quant_protos.into();
1697
1698        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1699        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1700        decode_yolo_segdet_quant(
1701            (boxes.slice(s![0, .., ..]), quant_boxes),
1702            (protos.slice(s![0, .., .., ..]), quant_protos),
1703            score_threshold,
1704            iou_threshold,
1705            Some(configs::Nms::ClassAgnostic),
1706            &mut output_boxes,
1707            &mut output_masks,
1708        );
1709
1710        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1711        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1712        decoder
1713            .decode_quantized(
1714                &[boxes.view().into(), protos.view().into()],
1715                &mut output_boxes1,
1716                &mut output_masks1,
1717            )
1718            .unwrap();
1719
1720        let protos_f32 = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1721        let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1722
1723        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1724        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1725        decode_yolo_segdet_float(
1726            boxes_f32.slice(s![0, .., ..]),
1727            protos_f32.slice(s![0, .., .., ..]),
1728            score_threshold,
1729            iou_threshold,
1730            Some(configs::Nms::ClassAgnostic),
1731            &mut output_boxes_f32,
1732            &mut output_masks_f32,
1733        );
1734
1735        let mut output_boxes1_f32: Vec<_> = Vec::with_capacity(500);
1736        let mut output_masks1_f32: Vec<_> = Vec::with_capacity(500);
1737        decoder
1738            .decode_float(
1739                &[boxes_f32.view().into_dyn(), protos_f32.view().into_dyn()],
1740                &mut output_boxes1_f32,
1741                &mut output_masks1_f32,
1742            )
1743            .unwrap();
1744
1745        // Every path must report the same detections and masks.
1746        compare_outputs(
1747            (&output_boxes, &output_boxes1),
1748            (&output_masks, &output_masks1),
1749        );
1750
1751        compare_outputs(
1752            (&output_boxes, &output_boxes_f32),
1753            (&output_masks, &output_masks_f32),
1754        );
1755
1756        compare_outputs(
1757            (&output_boxes_f32, &output_boxes1_f32),
1758            (&output_masks_f32, &output_masks1_f32),
1759        );
1760    }
1761
1762    /// Feeds the first 84 feature rows of the fixture to a split box/score decoder
1763    /// (widened to i16) and checks its output against the float detection decode.
1764    #[test]
1765    fn test_decoder_yolo_split() {
1766        let score_threshold = 0.45;
1767        let iou_threshold = 0.45;
1768        // Widen each int8 fixture value to i16 and scale by 256; no unsafe cast needed.
1769        let boxes: Vec<i16> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
1770            .iter()
1771            .map(|&b| i16::from(b as i8) * 256)
1772            .collect();
1773        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1774        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1775        let quant = || Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point));
1776
1777        let boxes_cfg = configs::Boxes {
1778            decoder: configs::DecoderType::Ultralytics,
1779            quantization: quant(),
1780            shape: vec![1, 4, 8400],
1781            dshape: vec![
1782                (DimName::Batch, 1),
1783                (DimName::BoxCoords, 4),
1784                (DimName::NumBoxes, 8400),
1785            ],
1786            normalized: Some(true),
1787        };
1788        let scores_cfg = configs::Scores {
1789            decoder: configs::DecoderType::Ultralytics,
1790            quantization: quant(),
1791            shape: vec![1, 80, 8400],
1792            dshape: vec![
1793                (DimName::Batch, 1),
1794                (DimName::NumClasses, 80),
1795                (DimName::NumBoxes, 8400),
1796            ],
1797        };
1798        let decoder = DecoderBuilder::default()
1799            .with_config_yolo_split_det(boxes_cfg, scores_cfg)
1800            .with_score_threshold(score_threshold)
1801            .with_iou_threshold(iou_threshold)
1802            .build()
1803            .unwrap();
1804
1805        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1806        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1807        decoder
1808            .decode_quantized(
1809                &[
1810                    boxes.slice(s![.., ..4, ..]).into(),
1811                    boxes.slice(s![.., 4..84, ..]).into(),
1812                ],
1813                &mut output_boxes,
1814                &mut output_masks,
1815            )
1816            .unwrap();
1817
1818        let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1819        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1820        decode_yolo_det_float(
1821            boxes_f32.slice(s![0, ..84, ..]),
1822            score_threshold,
1823            iou_threshold,
1824            Some(configs::Nms::ClassAgnostic),
1825            &mut output_boxes_f32,
1826        );
1827        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1828        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1829        decoder
1830            .decode_float(
1831                &[
1832                    boxes_f32.slice(s![.., ..4, ..]).into_dyn(),
1833                    boxes_f32.slice(s![.., 4..84, ..]).into_dyn(),
1834                ],
1835                &mut output_boxes1,
1836                &mut output_masks1,
1837            )
1838            .unwrap();
1839        compare_outputs((&output_boxes, &output_boxes_f32), (&output_masks, &[]));
1840        compare_outputs((&output_boxes_f32, &output_boxes1), (&[], &output_masks1));
1841    }
1842
1843    /// Split segmentation decode with mixed widths: the box, score and mask-coefficient
1844    /// tensors are i16 while the protos stay i8; quantized and float paths must agree.
1845    #[test]
1846    fn test_decoder_masks_config_mixed() {
1847        let score_threshold = 0.45;
1848        let iou_threshold = 0.45;
1849        // Widen each int8 fixture value to i16 and scale by 256; no unsafe cast needed.
1850        let boxes: Vec<i16> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
1851            .iter()
1852            .map(|&b| i16::from(b as i8) * 256)
1853            .collect();
1854        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1855        let quant_boxes = Quantization::new(0.021287761628627777 / 256.0, 31 * 256);
1856
1857        // Protos keep their int8 width: cast once and move the buffer into the array.
1858        let protos: Vec<i8> = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin")
1859            .iter()
1860            .map(|&b| b as i8)
1861            .collect();
1862        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
1863        let quant_protos = Quantization::new(0.02491161972284317, -117);
1864        let boxes_quant = || Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point));
1865        let boxes_cfg = configs::Boxes {
1866            decoder: configs::DecoderType::Ultralytics,
1867            quantization: boxes_quant(),
1868            shape: vec![1, 4, 8400],
1869            dshape: vec![
1870                (DimName::Batch, 1),
1871                (DimName::BoxCoords, 4),
1872                (DimName::NumBoxes, 8400),
1873            ],
1874            normalized: Some(true),
1875        };
1876        let scores_cfg = configs::Scores {
1877            decoder: configs::DecoderType::Ultralytics,
1878            quantization: boxes_quant(),
1879            shape: vec![1, 80, 8400],
1880            dshape: vec![
1881                (DimName::Batch, 1),
1882                (DimName::NumClasses, 80),
1883                (DimName::NumBoxes, 8400),
1884            ],
1885        };
1886        let mask_cfg = configs::MaskCoefficients {
1887            decoder: configs::DecoderType::Ultralytics,
1888            quantization: boxes_quant(),
1889            shape: vec![1, 32, 8400],
1890            dshape: vec![
1891                (DimName::Batch, 1),
1892                (DimName::NumProtos, 32),
1893                (DimName::NumBoxes, 8400),
1894            ],
1895        };
1896        let protos_cfg = configs::Protos {
1897            decoder: configs::DecoderType::Ultralytics,
1898            quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
1899            shape: vec![1, 160, 160, 32],
1900            dshape: vec![
1901                (DimName::Batch, 1),
1902                (DimName::Height, 160),
1903                (DimName::Width, 160),
1904                (DimName::NumProtos, 32),
1905            ],
1906        };
1907        let decoder = DecoderBuilder::default()
1908            .with_config_yolo_split_segdet(boxes_cfg, scores_cfg, mask_cfg, protos_cfg)
1909            .with_score_threshold(score_threshold)
1910            .with_iou_threshold(iou_threshold)
1911            .build()
1912            .unwrap();
1913
1914        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
1915        let mut output_masks: Vec<_> = Vec::with_capacity(500);
1916        decoder
1917            .decode_quantized(
1918                &[
1919                    boxes.slice(s![.., ..4, ..]).into(),
1920                    boxes.slice(s![.., 4..84, ..]).into(),
1921                    boxes.slice(s![.., 84.., ..]).into(),
1922                    protos.view().into(),
1923                ],
1924                &mut output_boxes,
1925                &mut output_masks,
1926            )
1927            .unwrap();
1928
1929        let protos_f32 = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
1930        let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
1931        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
1932        let mut output_masks_f32: Vec<_> = Vec::with_capacity(500);
1933        decode_yolo_segdet_float(
1934            boxes_f32.slice(s![0, .., ..]),
1935            protos_f32.slice(s![0, .., .., ..]),
1936            score_threshold,
1937            iou_threshold,
1938            Some(configs::Nms::ClassAgnostic),
1939            &mut output_boxes_f32,
1940            &mut output_masks_f32,
1941        );
1942        let mut output_boxes1: Vec<_> = Vec::with_capacity(500);
1943        let mut output_masks1: Vec<_> = Vec::with_capacity(500);
1944        decoder
1945            .decode_float(
1946                &[
1947                    boxes_f32.slice(s![.., ..4, ..]).into_dyn(),
1948                    boxes_f32.slice(s![.., 4..84, ..]).into_dyn(),
1949                    boxes_f32.slice(s![.., 84.., ..]).into_dyn(),
1950                    protos_f32.view().into_dyn(),
1951                ],
1952                &mut output_boxes1,
1953                &mut output_masks1,
1954            )
1955            .unwrap();
1956        compare_outputs(
1957            (&output_boxes, &output_boxes_f32),
1958            (&output_masks, &output_masks_f32),
1959        );
1960        compare_outputs(
1961            (&output_boxes_f32, &output_boxes1),
1962            (&output_masks_f32, &output_masks1),
1963        );
1964    }
1965
1966    /// Split segmentation decode with every tensor widened to i32 (int8 values scaled
1967    /// by 2^23); the quantized result must match the float reference decode.
1968    #[test]
1969    fn test_decoder_masks_config_i32() {
1970        let score_threshold = 0.45;
1971        let iou_threshold = 0.45;
1972        let scale = 1 << 23;
1973        // Safe per-byte int8 cast, then widen; the quantization is rescaled to match.
1974        let boxes: Vec<i32> = include_bytes!("../../../testdata/yolov8_boxes_116x8400.bin")
1975            .iter()
1976            .map(|&b| i32::from(b as i8) * scale)
1977            .collect();
1978        let boxes = ndarray::Array3::from_shape_vec((1, 116, 8400), boxes).unwrap();
1979        let quant_boxes = Quantization::new(0.021287761628627777 / scale as f32, 31 * scale);
1980
1981        let protos: Vec<i32> = include_bytes!("../../../testdata/yolov8_protos_160x160x32.bin")
1982            .iter()
1983            .map(|&b| i32::from(b as i8) * scale)
1984            .collect();
1985        let protos = ndarray::Array4::from_shape_vec((1, 160, 160, 32), protos).unwrap();
1986        let quant_protos = Quantization::new(0.02491161972284317 / scale as f32, -117 * scale);
1987        let boxes_quant = || Some(QuantTuple(quant_boxes.scale, quant_boxes.zero_point));
1988
1989        let boxes_cfg = configs::Boxes {
1990            decoder: configs::DecoderType::Ultralytics,
1991            quantization: boxes_quant(),
1992            shape: vec![1, 4, 8400],
1993            dshape: vec![
1994                (DimName::Batch, 1),
1995                (DimName::BoxCoords, 4),
1996                (DimName::NumBoxes, 8400),
1997            ],
1998            normalized: Some(true),
1999        };
2000        let scores_cfg = configs::Scores {
2001            decoder: configs::DecoderType::Ultralytics,
2002            quantization: boxes_quant(),
2003            shape: vec![1, 80, 8400],
2004            dshape: vec![
2005                (DimName::Batch, 1),
2006                (DimName::NumClasses, 80),
2007                (DimName::NumBoxes, 8400),
2008            ],
2009        };
2010        let mask_cfg = configs::MaskCoefficients {
2011            decoder: configs::DecoderType::Ultralytics,
2012            quantization: boxes_quant(),
2013            shape: vec![1, 32, 8400],
2014            dshape: vec![
2015                (DimName::Batch, 1),
2016                (DimName::NumProtos, 32),
2017                (DimName::NumBoxes, 8400),
2018            ],
2019        };
2020        let protos_cfg = configs::Protos {
2021            decoder: configs::DecoderType::Ultralytics,
2022            quantization: Some(QuantTuple(quant_protos.scale, quant_protos.zero_point)),
2023            shape: vec![1, 160, 160, 32],
2024            dshape: vec![
2025                (DimName::Batch, 1),
2026                (DimName::Height, 160),
2027                (DimName::Width, 160),
2028                (DimName::NumProtos, 32),
2029            ],
2030        };
2031        let decoder = DecoderBuilder::default()
2032            .with_config_yolo_split_segdet(boxes_cfg, scores_cfg, mask_cfg, protos_cfg)
2033            .with_score_threshold(score_threshold)
2034            .with_iou_threshold(iou_threshold)
2035            .build()
2036            .unwrap();
2037
2038        let mut output_boxes: Vec<_> = Vec::with_capacity(500);
2039        let mut output_masks: Vec<_> = Vec::with_capacity(500);
2040        decoder
2041            .decode_quantized(
2042                &[
2043                    boxes.slice(s![.., ..4, ..]).into(),
2044                    boxes.slice(s![.., 4..84, ..]).into(),
2045                    boxes.slice(s![.., 84.., ..]).into(),
2046                    protos.view().into(),
2047                ],
2048                &mut output_boxes,
2049                &mut output_masks,
2050            )
2051            .unwrap();
2052
2053        let protos_f32 = dequantize_ndarray::<_, _, f32>(protos.view(), quant_protos);
2054        let boxes_f32 = dequantize_ndarray::<_, _, f32>(boxes.view(), quant_boxes);
2055        let mut output_boxes_f32: Vec<_> = Vec::with_capacity(500);
2056        let mut output_masks_f32: Vec<Segmentation> = Vec::with_capacity(500);
2057        decode_yolo_segdet_float(
2058            boxes_f32.slice(s![0, .., ..]),
2059            protos_f32.slice(s![0, .., .., ..]),
2060            score_threshold,
2061            iou_threshold,
2062            Some(configs::Nms::ClassAgnostic),
2063            &mut output_boxes_f32,
2064            &mut output_masks_f32,
2065        );
2066        assert_eq!(output_boxes.len(), output_boxes_f32.len());
2067        assert_eq!(output_masks.len(), output_masks_f32.len());
2068        compare_outputs(
2069            (&output_boxes, &output_boxes_f32),
2070            (&output_masks, &output_masks_f32),
2071        );
2072    }
2073
2074    /// Runs YOLO and ModelPack decoders on eight threads at once; each thread
2075    /// repeats its decode 100 times and checks the result on every pass.
2076    #[test]
2077    fn test_context_switch() {
2078        let yolo_det = || {
2079            let score_threshold = 0.25;
2080            let iou_threshold = 0.7;
2081            // Safe per-byte cast of the int8 fixture; no unsafe slice cast needed.
2082            let raw: Vec<i8> = include_bytes!("../../../testdata/yolov8s_80_classes.bin")
2083                .iter()
2084                .map(|&b| b as i8)
2085                .collect();
2086            let out = Array3::from_shape_vec((1, 84, 8400), raw).unwrap();
2087            let quant = (0.0040811873, -123).into();
2088
2089            let detection_cfg = configs::Detection {
2090                decoder: DecoderType::Ultralytics,
2091                shape: vec![1, 84, 8400],
2092                anchors: None,
2093                quantization: Some(quant),
2094                dshape: vec![
2095                    (DimName::Batch, 1),
2096                    (DimName::NumFeatures, 84),
2097                    (DimName::NumBoxes, 8400),
2098                ],
2099                normalized: None,
2100            };
2101            let decoder = DecoderBuilder::default()
2102                .with_config_yolo_det(detection_cfg, None)
2103                .with_score_threshold(score_threshold)
2104                .with_iou_threshold(iou_threshold)
2105                .build()
2106                .unwrap();
2107
2108            // Expected values for the first two reported boxes, in output order.
2109            let expected = [
2110                DetectBox {
2111                    bbox: BoundingBox {
2112                        xmin: 0.5285137,
2113                        ymin: 0.05305544,
2114                        xmax: 0.87541467,
2115                        ymax: 0.9998909,
2116                    },
2117                    score: 0.5591227,
2118                    label: 0,
2119                },
2120                DetectBox {
2121                    bbox: BoundingBox {
2122                        xmin: 0.130598,
2123                        ymin: 0.43260583,
2124                        xmax: 0.35098213,
2125                        ymax: 0.9958097,
2126                    },
2127                    score: 0.33057618,
2128                    label: 75,
2129                },
2130            ];
2131
2132            let mut output_boxes: Vec<_> = Vec::with_capacity(50);
2133            let mut output_masks: Vec<_> = Vec::with_capacity(50);
2134            for _ in 0..100 {
2135                decoder
2136                    .decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks)
2137                    .unwrap();
2138                for (idx, want) in expected.iter().enumerate() {
2139                    assert!(output_boxes[idx].equal_within_delta(want, 1e-6));
2140                }
2141                assert!(output_masks.is_empty());
2142            }
2143        };
2144
2145        let modelpack_det_split = || {
2146            let score_threshold = 0.8;
2147            let iou_threshold = 0.5;
2148
2149            let seg = include_bytes!("../../../testdata/modelpack_seg_2x160x160.bin");
2150            let seg = ndarray::Array4::from_shape_vec((1, 2, 160, 160), seg.to_vec()).unwrap();
2151
2152            let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
2153            let detect0 =
2154                ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec()).unwrap();
2155
2156            let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
2157            let detect1 =
2158                ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec()).unwrap();
2159
2160            // Expected mask: the batch-0 class scores viewed as (height, width, class),
2161            // with a region spanning the whole frame.
2162            let frame_mask = seg.slice(s![0, .., .., ..]).permuted_axes([1, 2, 0]);
2163            let want_masks = [Segmentation {
2164                xmin: 0.0,
2165                ymin: 0.0,
2166                xmax: 1.0,
2167                ymax: 1.0,
2168                segmentation: frame_mask.into_owned(),
2169            }];
2170            let want_boxes = [DetectBox {
2171                bbox: BoundingBox {
2172                    xmin: 0.43171933,
2173                    ymin: 0.68243736,
2174                    xmax: 0.5626645,
2175                    ymax: 0.808863,
2176                },
2177                score: 0.99240804,
2178                label: 0,
2179            }];
2180
2181            let quant0 = (0.08547406643629074, 174).into();
2182            let quant1 = (0.09929127991199493, 183).into();
2183            let quant_seg = (1.0 / 255.0, 0).into();
2184
2185            let anchors0 = vec![
2186                [0.36666667461395264, 0.31481480598449707],
2187                [0.38749998807907104, 0.4740740656852722],
2188                [0.5333333611488342, 0.644444465637207],
2189            ];
2190            let anchors1 = vec![
2191                [0.13750000298023224, 0.2074074000120163],
2192                [0.2541666626930237, 0.21481481194496155],
2193                [0.23125000298023224, 0.35185185074806213],
2194            ];
2195
2196            // The configs list the 17x30 head first, while the inputs in the loop below
2197            // pass the 9x15 tensor first.
2198            let detect1_cfg = configs::Detection {
2199                decoder: DecoderType::ModelPack,
2200                shape: vec![1, 17, 30, 18],
2201                anchors: Some(anchors1),
2202                quantization: Some(quant1),
2203                dshape: vec![
2204                    (DimName::Batch, 1),
2205                    (DimName::Height, 17),
2206                    (DimName::Width, 30),
2207                    (DimName::NumAnchorsXFeatures, 18),
2208                ],
2209                normalized: None,
2210            };
2211            let detect0_cfg = configs::Detection {
2212                decoder: DecoderType::ModelPack,
2213                shape: vec![1, 9, 15, 18],
2214                anchors: Some(anchors0),
2215                quantization: Some(quant0),
2216                dshape: vec![
2217                    (DimName::Batch, 1),
2218                    (DimName::Height, 9),
2219                    (DimName::Width, 15),
2220                    (DimName::NumAnchorsXFeatures, 18),
2221                ],
2222                normalized: None,
2223            };
2224            let seg_cfg = configs::Segmentation {
2225                decoder: DecoderType::ModelPack,
2226                quantization: Some(quant_seg),
2227                shape: vec![1, 2, 160, 160],
2228                dshape: vec![
2229                    (DimName::Batch, 1),
2230                    (DimName::NumClasses, 2),
2231                    (DimName::Height, 160),
2232                    (DimName::Width, 160),
2233                ],
2234            };
2235            let decoder = DecoderBuilder::default()
2236                .with_config_modelpack_segdet_split(vec![detect1_cfg, detect0_cfg], seg_cfg)
2237                .with_score_threshold(score_threshold)
2238                .with_iou_threshold(iou_threshold)
2239                .build()
2240                .unwrap();
2241
2242            // The same output buffers are handed to the decoder on every pass.
2243            let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2244            let mut output_masks: Vec<_> = Vec::with_capacity(10);
2245            for _ in 0..100 {
2246                decoder
2247                    .decode_quantized(
2248                        &[
2249                            detect0.view().into(),
2250                            detect1.view().into(),
2251                            seg.view().into(),
2252                        ],
2253                        &mut output_boxes,
2254                        &mut output_masks,
2255                    )
2256                    .unwrap();
2257
2258                compare_outputs((&want_boxes, &output_boxes), (&want_masks, &output_masks));
2259            }
2260        };
2261
2262        // Spawn the two workloads alternately, four threads each, so both decoder
2263        // kinds run at the same time; every thread is started before any is joined.
2264        let handles: Vec<_> = (0..4)
2265            .flat_map(|_| {
2266                [
2267                    std::thread::spawn(yolo_det),
2268                    std::thread::spawn(modelpack_det_split),
2269                ]
2270            })
2271            .collect();
2272        for handle in handles {
2273            handle.join().unwrap();
2274        }
2275    }
2276
2277    /// The same four values read as XYWH become corner coordinates; as XYXY they pass through.
2278    #[test]
2279    fn test_ndarray_to_xyxy_float() {
2280        // One input array serves both layouts.
2281        let coords = array![10.0_f32, 20.0, 20.0, 20.0];
2282        let from_xywh: [f32; 4] = XYWH::ndarray_to_xyxy_float(coords.view());
2283        let from_xyxy: [f32; 4] = XYXY::ndarray_to_xyxy_float(coords.view());
2284        assert_eq!(from_xywh, [0.0_f32, 10.0, 20.0, 30.0]);
2285        assert_eq!(from_xyxy, [10.0_f32, 20.0, 20.0, 20.0]);
2286    }
2287
2288    /// Class-aware NMS suppresses an overlapping box only when it shares the label
2289    /// of the higher-scoring box.
2290    #[test]
2291    fn test_class_aware_nms_float() {
2292        use crate::float::nms_class_aware_float;
2293
2294        // Two boxes overlapping with IoU ~0.47; only the second box's label varies.
2295        let overlapping_pair = |second_label| {
2296            vec![
2297                DetectBox {
2298                    bbox: BoundingBox {
2299                        xmin: 0.0,
2300                        ymin: 0.0,
2301                        xmax: 0.5,
2302                        ymax: 0.5,
2303                    },
2304                    score: 0.9,
2305                    label: 0,
2306                },
2307                DetectBox {
2308                    bbox: BoundingBox {
2309                        xmin: 0.1,
2310                        ymin: 0.1,
2311                        xmax: 0.6,
2312                        ymax: 0.6,
2313                    },
2314                    score: 0.8,
2315                    label: second_label,
2316                },
2317            ]
2318        };
2319
2320        // Different labels: both boxes survive even though IoU exceeds the 0.3 threshold.
2321        let kept = nms_class_aware_float(0.3, overlapping_pair(1));
2322        assert_eq!(
2323            kept.len(),
2324            2,
2325            "Class-aware NMS should keep both boxes with different classes"
2326        );
2327
2328        // Same label: only the 0.9-score box remains.
2329        let kept = nms_class_aware_float(0.3, overlapping_pair(0));
2330        assert_eq!(
2331            kept.len(),
2332            1,
2333            "Class-aware NMS should suppress overlapping box with same class"
2334        );
2335        assert_eq!(kept[0].label, 0);
2336        assert!((kept[0].score - 0.9).abs() < 1e-6);
2337    }
2358
2359    /// The same overlapping, differently-labelled pair is reduced to one box by
2360    /// class-agnostic NMS but left intact by class-aware NMS.
2361    #[test]
2362    fn test_class_agnostic_vs_aware_nms() {
2363        use crate::float::{nms_class_aware_float, nms_float};
2364
2365        // Overlapping pair with different labels; their IoU (~0.47) exceeds the 0.3
2366        // threshold passed to both NMS variants.
2367        let first = DetectBox {
2368            bbox: BoundingBox {
2369                xmin: 0.0,
2370                ymin: 0.0,
2371                xmax: 0.5,
2372                ymax: 0.5,
2373            },
2374            score: 0.9,
2375            label: 0,
2376        };
2377        let second = DetectBox {
2378            bbox: BoundingBox {
2379                xmin: 0.1,
2380                ymin: 0.1,
2381                xmax: 0.6,
2382                ymax: 0.6,
2383            },
2384            score: 0.8,
2385            label: 1,
2386        };
2387        let boxes = vec![first, second];
2388
2389        let agnostic_result = nms_float(0.3, boxes.clone());
2390        assert_eq!(
2391            agnostic_result.len(),
2392            1,
2393            "Class-agnostic NMS should suppress overlapping boxes"
2394        );
2395
2396        let aware_result = nms_class_aware_float(0.3, boxes);
2397        assert_eq!(
2398            aware_result.len(),
2399            2,
2400            "Class-aware NMS should keep boxes with different classes"
2401        );
2402    }
2403
2404    #[test]
2405    fn test_class_aware_nms_int() {
2406        use crate::byte::nms_class_aware_int;
2407
2408        // Create two overlapping boxes with different classes
2409        let boxes = vec![
2410            DetectBoxQuantized {
2411                bbox: BoundingBox {
2412                    xmin: 0.0,
2413                    ymin: 0.0,
2414                    xmax: 0.5,
2415                    ymax: 0.5,
2416                },
2417                score: 200_u8,
2418                label: 0,
2419            },
2420            DetectBoxQuantized {
2421                bbox: BoundingBox {
2422                    xmin: 0.1,
2423                    ymin: 0.1,
2424                    xmax: 0.6,
2425                    ymax: 0.6,
2426                },
2427                score: 180_u8,
2428                label: 1, // different class
2429            },
2430        ];
2431
2432        // Should keep both (different classes)
2433        let result = nms_class_aware_int(0.5, boxes);
2434        assert_eq!(
2435            result.len(),
2436            2,
2437            "Class-aware NMS (int) should keep boxes with different classes"
2438        );
2439    }
2440
2441    #[test]
2442    fn test_nms_enum_default() {
2443        // Test that Nms enum has the correct default
2444        let default_nms: configs::Nms = Default::default();
2445        assert_eq!(default_nms, configs::Nms::ClassAgnostic);
2446    }
2447
2448    #[test]
2449    fn test_decoder_nms_mode() {
2450        // Test that decoder properly stores NMS mode
2451        let decoder = DecoderBuilder::default()
2452            .with_config_yolo_det(
2453                configs::Detection {
2454                    anchors: None,
2455                    decoder: DecoderType::Ultralytics,
2456                    quantization: None,
2457                    shape: vec![1, 84, 8400],
2458                    dshape: Vec::new(),
2459                    normalized: Some(true),
2460                },
2461                None,
2462            )
2463            .with_nms(Some(configs::Nms::ClassAware))
2464            .build()
2465            .unwrap();
2466
2467        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
2468    }
2469
2470    #[test]
2471    fn test_decoder_nms_bypass() {
2472        // Test that decoder can be configured with nms=None (bypass)
2473        let decoder = DecoderBuilder::default()
2474            .with_config_yolo_det(
2475                configs::Detection {
2476                    anchors: None,
2477                    decoder: DecoderType::Ultralytics,
2478                    quantization: None,
2479                    shape: vec![1, 84, 8400],
2480                    dshape: Vec::new(),
2481                    normalized: Some(true),
2482                },
2483                None,
2484            )
2485            .with_nms(None)
2486            .build()
2487            .unwrap();
2488
2489        assert_eq!(decoder.nms, None);
2490    }
2491
2492    #[test]
2493    fn test_decoder_normalized_boxes_true() {
2494        // Test that normalized_boxes returns Some(true) when explicitly set
2495        let decoder = DecoderBuilder::default()
2496            .with_config_yolo_det(
2497                configs::Detection {
2498                    anchors: None,
2499                    decoder: DecoderType::Ultralytics,
2500                    quantization: None,
2501                    shape: vec![1, 84, 8400],
2502                    dshape: Vec::new(),
2503                    normalized: Some(true),
2504                },
2505                None,
2506            )
2507            .build()
2508            .unwrap();
2509
2510        assert_eq!(decoder.normalized_boxes(), Some(true));
2511    }
2512
2513    #[test]
2514    fn test_decoder_normalized_boxes_false() {
2515        // Test that normalized_boxes returns Some(false) when config specifies
2516        // unnormalized
2517        let decoder = DecoderBuilder::default()
2518            .with_config_yolo_det(
2519                configs::Detection {
2520                    anchors: None,
2521                    decoder: DecoderType::Ultralytics,
2522                    quantization: None,
2523                    shape: vec![1, 84, 8400],
2524                    dshape: Vec::new(),
2525                    normalized: Some(false),
2526                },
2527                None,
2528            )
2529            .build()
2530            .unwrap();
2531
2532        assert_eq!(decoder.normalized_boxes(), Some(false));
2533    }
2534
2535    #[test]
2536    fn test_decoder_normalized_boxes_unknown() {
2537        // Test that normalized_boxes returns None when not specified in config
2538        let decoder = DecoderBuilder::default()
2539            .with_config_yolo_det(
2540                configs::Detection {
2541                    anchors: None,
2542                    decoder: DecoderType::Ultralytics,
2543                    quantization: None,
2544                    shape: vec![1, 84, 8400],
2545                    dshape: Vec::new(),
2546                    normalized: None,
2547                },
2548                Some(DecoderVersion::Yolo11),
2549            )
2550            .build()
2551            .unwrap();
2552
2553        assert_eq!(decoder.normalized_boxes(), None);
2554    }
2555}