Skip to main content

edgefirst_decoder/
decoder.rs

1// SPDX-FileCopyrightText: Copyright 2025 Au-Zone Technologies
2// SPDX-License-Identifier: Apache-2.0
3
4use std::collections::HashSet;
5
6use ndarray::{s, Array3, ArrayView, ArrayViewD, Dimension};
7use ndarray_stats::QuantileExt;
8use num_traits::{AsPrimitive, Float};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12    configs::{DecoderType, DimName, ModelType, QuantTuple},
13    dequantize_ndarray,
14    modelpack::{
15        decode_modelpack_det, decode_modelpack_float, decode_modelpack_split_float,
16        ModelPackDetectionConfig,
17    },
18    yolo::{
19        decode_yolo_det, decode_yolo_det_float, decode_yolo_segdet_float, decode_yolo_segdet_quant,
20        decode_yolo_split_det_float, decode_yolo_split_det_quant, decode_yolo_split_segdet_float,
21        impl_yolo_split_segdet_quant_get_boxes, impl_yolo_split_segdet_quant_process_masks,
22    },
23    DecoderError, DecoderVersion, DetectBox, Quantization, Segmentation, XYWH,
24};
25
26/// Used to represent the outputs in the model configuration.
27/// # Examples
28/// ```rust
29/// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, ConfigOutputs};
30/// # fn main() -> DecoderResult<()> {
31/// let config_json = include_str!("../../../testdata/modelpack_split.json");
32/// let config: ConfigOutputs = serde_json::from_str(config_json)?;
33/// let decoder = DecoderBuilder::new().with_config(config).build()?;
34///
35/// # Ok(())
36/// # }
37#[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Default)]
38pub struct ConfigOutputs {
39    #[serde(default)]
40    pub outputs: Vec<ConfigOutput>,
41    /// NMS mode from config file. When present, overrides the builder's NMS
42    /// setting.
43    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS: suppress overlapping
44    ///   boxes regardless of class
45    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes with
46    ///   the same class
47    /// - `None` — use builder default or skip NMS (user handles it externally)
48    #[serde(default, skip_serializing_if = "Option::is_none")]
49    pub nms: Option<configs::Nms>,
50    /// Decoder version for Ultralytics models. Determines the decoding
51    /// strategy.
52    /// - `Some(Yolo26)` — end-to-end model with embedded NMS
53    /// - `Some(Yolov5/Yolov8/Yolo11)` — traditional models requiring external
54    ///   NMS
55    /// - `None` — infer from other settings (legacy behavior)
56    #[serde(default, skip_serializing_if = "Option::is_none")]
57    pub decoder_version: Option<configs::DecoderVersion>,
58}
59
60#[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
61#[serde(tag = "type")]
62pub enum ConfigOutput {
63    #[serde(rename = "detection")]
64    Detection(configs::Detection),
65    #[serde(rename = "masks")]
66    Mask(configs::Mask),
67    #[serde(rename = "segmentation")]
68    Segmentation(configs::Segmentation),
69    #[serde(rename = "protos")]
70    Protos(configs::Protos),
71    #[serde(rename = "scores")]
72    Scores(configs::Scores),
73    #[serde(rename = "boxes")]
74    Boxes(configs::Boxes),
75    #[serde(rename = "mask_coefficients")]
76    MaskCoefficients(configs::MaskCoefficients),
77}
78
79#[derive(Debug, PartialEq, Clone)]
80pub enum ConfigOutputRef<'a> {
81    Detection(&'a configs::Detection),
82    Mask(&'a configs::Mask),
83    Segmentation(&'a configs::Segmentation),
84    Protos(&'a configs::Protos),
85    Scores(&'a configs::Scores),
86    Boxes(&'a configs::Boxes),
87    MaskCoefficients(&'a configs::MaskCoefficients),
88}
89
90impl<'a> ConfigOutputRef<'a> {
91    fn decoder(&self) -> configs::DecoderType {
92        match self {
93            ConfigOutputRef::Detection(v) => v.decoder,
94            ConfigOutputRef::Mask(v) => v.decoder,
95            ConfigOutputRef::Segmentation(v) => v.decoder,
96            ConfigOutputRef::Protos(v) => v.decoder,
97            ConfigOutputRef::Scores(v) => v.decoder,
98            ConfigOutputRef::Boxes(v) => v.decoder,
99            ConfigOutputRef::MaskCoefficients(v) => v.decoder,
100        }
101    }
102
103    fn dshape(&self) -> &[(DimName, usize)] {
104        match self {
105            ConfigOutputRef::Detection(v) => &v.dshape,
106            ConfigOutputRef::Mask(v) => &v.dshape,
107            ConfigOutputRef::Segmentation(v) => &v.dshape,
108            ConfigOutputRef::Protos(v) => &v.dshape,
109            ConfigOutputRef::Scores(v) => &v.dshape,
110            ConfigOutputRef::Boxes(v) => &v.dshape,
111            ConfigOutputRef::MaskCoefficients(v) => &v.dshape,
112        }
113    }
114}
115
116impl<'a> From<&'a configs::Detection> for ConfigOutputRef<'a> {
117    /// Converts from references of config structs to ConfigOutputRef
118    /// # Examples
119    /// ```rust
120    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
121    /// let detection_config = configs::Detection {
122    ///     anchors: None,
123    ///     decoder: configs::DecoderType::Ultralytics,
124    ///     quantization: None,
125    ///     shape: vec![1, 84, 8400],
126    ///     dshape: Vec::new(),
127    ///     normalized: Some(true),
128    /// };
129    /// let output: ConfigOutputRef = (&detection_config).into();
130    /// ```
131    fn from(v: &'a configs::Detection) -> ConfigOutputRef<'a> {
132        ConfigOutputRef::Detection(v)
133    }
134}
135
136impl<'a> From<&'a configs::Mask> for ConfigOutputRef<'a> {
137    /// Converts from references of config structs to ConfigOutputRef
138    /// # Examples
139    /// ```rust
140    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
141    /// let mask = configs::Mask {
142    ///     decoder: configs::DecoderType::ModelPack,
143    ///     quantization: None,
144    ///     shape: vec![1, 160, 160, 1],
145    ///     dshape: Vec::new(),
146    /// };
147    /// let output: ConfigOutputRef = (&mask).into();
148    /// ```
149    fn from(v: &'a configs::Mask) -> ConfigOutputRef<'a> {
150        ConfigOutputRef::Mask(v)
151    }
152}
153
154impl<'a> From<&'a configs::Segmentation> for ConfigOutputRef<'a> {
155    /// Converts from references of config structs to ConfigOutputRef
156    /// # Examples
157    /// ```rust
158    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
159    /// let seg = configs::Segmentation {
160    ///     decoder: configs::DecoderType::ModelPack,
161    ///     quantization: None,
162    ///     shape: vec![1, 160, 160, 3],
163    ///     dshape: Vec::new(),
164    /// };
165    /// let output: ConfigOutputRef = (&seg).into();
166    /// ```
167    fn from(v: &'a configs::Segmentation) -> ConfigOutputRef<'a> {
168        ConfigOutputRef::Segmentation(v)
169    }
170}
171
172impl<'a> From<&'a configs::Protos> for ConfigOutputRef<'a> {
173    /// Converts from references of config structs to ConfigOutputRef
174    /// # Examples
175    /// ```rust
176    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
177    /// let protos = configs::Protos {
178    ///     decoder: configs::DecoderType::Ultralytics,
179    ///     quantization: None,
180    ///     shape: vec![1, 160, 160, 32],
181    ///     dshape: Vec::new(),
182    /// };
183    /// let output: ConfigOutputRef = (&protos).into();
184    /// ```
185    fn from(v: &'a configs::Protos) -> ConfigOutputRef<'a> {
186        ConfigOutputRef::Protos(v)
187    }
188}
189
190impl<'a> From<&'a configs::Scores> for ConfigOutputRef<'a> {
191    /// Converts from references of config structs to ConfigOutputRef
192    /// # Examples
193    /// ```rust
194    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
195    /// let scores = configs::Scores {
196    ///     decoder: configs::DecoderType::Ultralytics,
197    ///     quantization: None,
198    ///     shape: vec![1, 40, 8400],
199    ///     dshape: Vec::new(),
200    /// };
201    /// let output: ConfigOutputRef = (&scores).into();
202    /// ```
203    fn from(v: &'a configs::Scores) -> ConfigOutputRef<'a> {
204        ConfigOutputRef::Scores(v)
205    }
206}
207
208impl<'a> From<&'a configs::Boxes> for ConfigOutputRef<'a> {
209    /// Converts from references of config structs to ConfigOutputRef
210    /// # Examples
211    /// ```rust
212    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
213    /// let boxes = configs::Boxes {
214    ///     decoder: configs::DecoderType::Ultralytics,
215    ///     quantization: None,
216    ///     shape: vec![1, 4, 8400],
217    ///     dshape: Vec::new(),
218    ///     normalized: Some(true),
219    /// };
220    /// let output: ConfigOutputRef = (&boxes).into();
221    /// ```
222    fn from(v: &'a configs::Boxes) -> ConfigOutputRef<'a> {
223        ConfigOutputRef::Boxes(v)
224    }
225}
226
227impl<'a> From<&'a configs::MaskCoefficients> for ConfigOutputRef<'a> {
228    /// Converts from references of config structs to ConfigOutputRef
229    /// # Examples
230    /// ```rust
231    /// # use edgefirst_decoder::{configs, ConfigOutputRef};
232    /// let mask_coefficients = configs::MaskCoefficients {
233    ///     decoder: configs::DecoderType::Ultralytics,
234    ///     quantization: None,
235    ///     shape: vec![1, 32, 8400],
236    ///     dshape: Vec::new(),
237    /// };
238    /// let output: ConfigOutputRef = (&mask_coefficients).into();
239    /// ```
240    fn from(v: &'a configs::MaskCoefficients) -> ConfigOutputRef<'a> {
241        ConfigOutputRef::MaskCoefficients(v)
242    }
243}
244
245impl ConfigOutput {
246    /// Returns the shape of the output.
247    ///
248    /// # Examples
249    /// ```rust
250    /// # use edgefirst_decoder::{configs, ConfigOutput};
251    /// let detection_config = configs::Detection {
252    ///     anchors: None,
253    ///     decoder: configs::DecoderType::Ultralytics,
254    ///     quantization: None,
255    ///     shape: vec![1, 84, 8400],
256    ///     dshape: Vec::new(),
257    ///     normalized: Some(true),
258    /// };
259    /// let output = ConfigOutput::Detection(detection_config);
260    /// assert_eq!(output.shape(), &[1, 84, 8400]);
261    /// ```
262    pub fn shape(&self) -> &[usize] {
263        match self {
264            ConfigOutput::Detection(detection) => &detection.shape,
265            ConfigOutput::Mask(mask) => &mask.shape,
266            ConfigOutput::Segmentation(segmentation) => &segmentation.shape,
267            ConfigOutput::Scores(scores) => &scores.shape,
268            ConfigOutput::Boxes(boxes) => &boxes.shape,
269            ConfigOutput::Protos(protos) => &protos.shape,
270            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.shape,
271        }
272    }
273
274    /// Returns the decoder type of the output.
275    ///    
276    /// # Examples
277    /// ```rust
278    /// # use edgefirst_decoder::{configs, ConfigOutput};
279    /// let detection_config = configs::Detection {
280    ///     anchors: None,
281    ///     decoder: configs::DecoderType::Ultralytics,
282    ///     quantization: None,
283    ///     shape: vec![1, 84, 8400],
284    ///     dshape: Vec::new(),
285    ///     normalized: Some(true),
286    /// };
287    /// let output = ConfigOutput::Detection(detection_config);
288    /// assert_eq!(output.decoder(), &configs::DecoderType::Ultralytics);
289    /// ```
290    pub fn decoder(&self) -> &configs::DecoderType {
291        match self {
292            ConfigOutput::Detection(detection) => &detection.decoder,
293            ConfigOutput::Mask(mask) => &mask.decoder,
294            ConfigOutput::Segmentation(segmentation) => &segmentation.decoder,
295            ConfigOutput::Scores(scores) => &scores.decoder,
296            ConfigOutput::Boxes(boxes) => &boxes.decoder,
297            ConfigOutput::Protos(protos) => &protos.decoder,
298            ConfigOutput::MaskCoefficients(mask_coefficients) => &mask_coefficients.decoder,
299        }
300    }
301
302    /// Returns the quantization of the output.
303    ///
304    /// # Examples
305    /// ```rust
306    /// # use edgefirst_decoder::{configs, ConfigOutput};
307    /// let detection_config = configs::Detection {
308    ///   anchors: None,
309    ///   decoder: configs::DecoderType::Ultralytics,
310    ///   quantization: Some(configs::QuantTuple(0.012345, 26)),
311    ///   shape: vec![1, 84, 8400],
312    ///   dshape: Vec::new(),
313    ///   normalized: Some(true),
314    /// };
315    /// let output = ConfigOutput::Detection(detection_config);
316    /// assert_eq!(output.quantization(),
317    /// Some(configs::QuantTuple(0.012345,26))); ```
318    pub fn quantization(&self) -> Option<QuantTuple> {
319        match self {
320            ConfigOutput::Detection(detection) => detection.quantization,
321            ConfigOutput::Mask(mask) => mask.quantization,
322            ConfigOutput::Segmentation(segmentation) => segmentation.quantization,
323            ConfigOutput::Scores(scores) => scores.quantization,
324            ConfigOutput::Boxes(boxes) => boxes.quantization,
325            ConfigOutput::Protos(protos) => protos.quantization,
326            ConfigOutput::MaskCoefficients(mask_coefficients) => mask_coefficients.quantization,
327        }
328    }
329}
330
331pub mod configs {
332    use std::fmt::Display;
333
334    use serde::{Deserialize, Serialize};
335
336    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy)]
337    pub struct QuantTuple(pub f32, pub i32);
338    impl From<QuantTuple> for (f32, i32) {
339        fn from(value: QuantTuple) -> Self {
340            (value.0, value.1)
341        }
342    }
343
344    impl From<(f32, i32)> for QuantTuple {
345        fn from(value: (f32, i32)) -> Self {
346            QuantTuple(value.0, value.1)
347        }
348    }
349
350    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
351    pub struct Segmentation {
352        pub decoder: DecoderType,
353        pub quantization: Option<QuantTuple>,
354        pub shape: Vec<usize>,
355        // #[serde(default)]
356        // pub channels_first: bool,
357        #[serde(default)]
358        pub dshape: Vec<(DimName, usize)>,
359    }
360
361    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
362    pub struct Protos {
363        pub decoder: DecoderType,
364        #[serde(flatten)]
365        pub quantization: Option<QuantTuple>,
366        pub shape: Vec<usize>,
367        // #[serde(default)]
368        // pub channels_first: bool,
369        #[serde(default)]
370        pub dshape: Vec<(DimName, usize)>,
371    }
372
373    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
374    pub struct MaskCoefficients {
375        pub decoder: DecoderType,
376        pub quantization: Option<QuantTuple>,
377        pub shape: Vec<usize>,
378        // #[serde(default)]
379        // pub channels_first: bool,
380        #[serde(default)]
381        pub dshape: Vec<(DimName, usize)>,
382    }
383
384    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
385    pub struct Mask {
386        pub decoder: DecoderType,
387        pub quantization: Option<QuantTuple>,
388        pub shape: Vec<usize>,
389        // #[serde(default)]
390        // pub channels_first: bool,
391        #[serde(default)]
392        pub dshape: Vec<(DimName, usize)>,
393    }
394
395    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
396    pub struct Detection {
397        pub anchors: Option<Vec<[f32; 2]>>,
398        pub decoder: DecoderType,
399        pub quantization: Option<QuantTuple>,
400        pub shape: Vec<usize>,
401        // #[serde(default)]
402        // pub channels_first: bool,
403        #[serde(default)]
404        pub dshape: Vec<(DimName, usize)>,
405        /// Whether box coordinates are normalized to [0,1] range.
406        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
407        /// - `Some(false)`: Pixel coordinates relative to model input
408        ///   (letterboxed)
409        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
410        ///   > 1.0)
411        #[serde(default)]
412        pub normalized: Option<bool>,
413    }
414
415    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
416    pub struct Scores {
417        pub decoder: DecoderType,
418        pub quantization: Option<QuantTuple>,
419        pub shape: Vec<usize>,
420        // #[serde(default)]
421        // pub channels_first: bool,
422        #[serde(default)]
423        pub dshape: Vec<(DimName, usize)>,
424    }
425
426    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone)]
427    pub struct Boxes {
428        pub decoder: DecoderType,
429        pub quantization: Option<QuantTuple>,
430        pub shape: Vec<usize>,
431        // #[serde(default)]
432        // pub channels_first: bool,
433        #[serde(default)]
434        pub dshape: Vec<(DimName, usize)>,
435        /// Whether box coordinates are normalized to [0,1] range.
436        /// - `Some(true)`: Coordinates in [0,1] range relative to model input
437        /// - `Some(false)`: Pixel coordinates relative to model input
438        ///   (letterboxed)
439        /// - `None`: Unknown, caller must infer (e.g., check if any coordinate
440        ///   > 1.0)
441        #[serde(default)]
442        pub normalized: Option<bool>,
443    }
444
445    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
446    pub enum DimName {
447        #[serde(rename = "batch")]
448        Batch,
449        #[serde(rename = "height")]
450        Height,
451        #[serde(rename = "width")]
452        Width,
453        #[serde(rename = "num_classes")]
454        NumClasses,
455        #[serde(rename = "num_features")]
456        NumFeatures,
457        #[serde(rename = "num_boxes")]
458        NumBoxes,
459        #[serde(rename = "num_protos")]
460        NumProtos,
461        #[serde(rename = "num_anchors_x_features")]
462        NumAnchorsXFeatures,
463        #[serde(rename = "padding")]
464        Padding,
465        #[serde(rename = "box_coords")]
466        BoxCoords,
467    }
468
469    impl Display for DimName {
470        /// Formats the DimName for display
471        /// # Examples
472        /// ```rust
473        /// # use edgefirst_decoder::configs::DimName;
474        /// let dim = DimName::Height;
475        /// assert_eq!(format!("{}", dim), "height");
476        /// # let s = format!("{} {} {} {} {} {} {} {} {} {}", DimName::Batch, DimName::Height, DimName::Width, DimName::NumClasses, DimName::NumFeatures, DimName::NumBoxes, DimName::NumProtos, DimName::NumAnchorsXFeatures, DimName::Padding, DimName::BoxCoords);
477        /// # assert_eq!(s, "batch height width num_classes num_features num_boxes num_protos num_anchors_x_features padding box_coords");
478        /// ```
479        fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
480            match self {
481                DimName::Batch => write!(f, "batch"),
482                DimName::Height => write!(f, "height"),
483                DimName::Width => write!(f, "width"),
484                DimName::NumClasses => write!(f, "num_classes"),
485                DimName::NumFeatures => write!(f, "num_features"),
486                DimName::NumBoxes => write!(f, "num_boxes"),
487                DimName::NumProtos => write!(f, "num_protos"),
488                DimName::NumAnchorsXFeatures => write!(f, "num_anchors_x_features"),
489                DimName::Padding => write!(f, "padding"),
490                DimName::BoxCoords => write!(f, "box_coords"),
491            }
492        }
493    }
494
495    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
496    pub enum DecoderType {
497        #[serde(rename = "modelpack")]
498        ModelPack,
499        #[serde(rename = "ultralytics")]
500        Ultralytics,
501    }
502
503    /// Decoder version for Ultralytics models.
504    ///
505    /// Specifies the YOLO architecture version, which determines the decoding
506    /// strategy:
507    /// - `Yolov5`, `Yolov8`, `Yolo11`: Traditional models requiring external
508    ///   NMS
509    /// - `Yolo26`: End-to-end models with NMS embedded in the model
510    ///   architecture
511    ///
512    /// When `decoder_version` is set to `Yolo26`, the decoder uses end-to-end
513    /// model types regardless of the `nms` setting.
514    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq)]
515    #[serde(rename_all = "lowercase")]
516    pub enum DecoderVersion {
517        /// YOLOv5 - anchor-free DFL decoder, requires external NMS
518        #[serde(rename = "yolov5")]
519        Yolov5,
520        /// YOLOv8 - anchor-free DFL decoder, requires external NMS
521        #[serde(rename = "yolov8")]
522        Yolov8,
523        /// YOLO11 - anchor-free DFL decoder, requires external NMS
524        #[serde(rename = "yolo11")]
525        Yolo11,
526        /// YOLO26 - end-to-end model with embedded NMS (one-to-one matching
527        /// heads)
528        #[serde(rename = "yolo26")]
529        Yolo26,
530    }
531
532    impl DecoderVersion {
533        /// Returns true if this version uses end-to-end inference (embedded
534        /// NMS).
535        pub fn is_end_to_end(&self) -> bool {
536            matches!(self, DecoderVersion::Yolo26)
537        }
538    }
539
540    /// NMS (Non-Maximum Suppression) mode for filtering overlapping detections.
541    ///
542    /// This enum is used with `Option<Nms>`:
543    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
544    ///   overlapping boxes regardless of class label
545    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
546    ///   share the same class label AND overlap above the IoU threshold
547    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
548    #[derive(Debug, PartialEq, Serialize, Deserialize, Clone, Copy, Hash, Eq, Default)]
549    #[serde(rename_all = "snake_case")]
550    pub enum Nms {
551        /// Suppress overlapping boxes regardless of class label (default HAL
552        /// behavior)
553        #[default]
554        ClassAgnostic,
555        /// Only suppress boxes with the same class label that overlap
556        ClassAware,
557    }
558
559    #[derive(Debug, Clone, PartialEq)]
560    pub enum ModelType {
561        ModelPackSegDet {
562            boxes: Boxes,
563            scores: Scores,
564            segmentation: Segmentation,
565        },
566        ModelPackSegDetSplit {
567            detection: Vec<Detection>,
568            segmentation: Segmentation,
569        },
570        ModelPackDet {
571            boxes: Boxes,
572            scores: Scores,
573        },
574        ModelPackDetSplit {
575            detection: Vec<Detection>,
576        },
577        ModelPackSeg {
578            segmentation: Segmentation,
579        },
580        YoloDet {
581            boxes: Detection,
582        },
583        YoloSegDet {
584            boxes: Detection,
585            protos: Protos,
586        },
587        YoloSplitDet {
588            boxes: Boxes,
589            scores: Scores,
590        },
591        YoloSplitSegDet {
592            boxes: Boxes,
593            scores: Scores,
594            mask_coeff: MaskCoefficients,
595            protos: Protos,
596        },
597        /// End-to-end YOLO detection (post-NMS output from model)
598        /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf,
599        /// class, ...]
600        YoloEndToEndDet {
601            boxes: Detection,
602        },
603        /// End-to-end YOLO detection + segmentation (post-NMS output from
604        /// model) Input shape: (1, N, 6 + num_protos) where columns are
605        /// [x1, y1, x2, y2, conf, class, mask_coeff_0, ..., mask_coeff_31]
606        YoloEndToEndSegDet {
607            boxes: Detection,
608            protos: Protos,
609        },
610    }
611
612    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
613    #[serde(rename_all = "lowercase")]
614    pub enum DataType {
615        Raw = 0,
616        Int8 = 1,
617        UInt8 = 2,
618        Int16 = 3,
619        UInt16 = 4,
620        Float16 = 5,
621        Int32 = 6,
622        UInt32 = 7,
623        Float32 = 8,
624        Int64 = 9,
625        UInt64 = 10,
626        Float64 = 11,
627        String = 12,
628    }
629}
630
631#[derive(Debug, Clone, PartialEq)]
632pub struct DecoderBuilder {
633    config_src: Option<ConfigSource>,
634    iou_threshold: f32,
635    score_threshold: f32,
636    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
637    /// models)
638    nms: Option<configs::Nms>,
639}
640
641#[derive(Debug, Clone, PartialEq)]
642enum ConfigSource {
643    Yaml(String),
644    Json(String),
645    Config(ConfigOutputs),
646}
647
648impl Default for DecoderBuilder {
649    /// Creates a default DecoderBuilder with no configuration and 0.5 score
650    /// threshold and 0.5 OU threshold.
651    ///
652    /// A valid confguration must be provided before building the Decoder.
653    ///
654    /// # Examples
655    /// ```rust
656    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
657    /// # fn main() -> DecoderResult<()> {
658    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
659    /// let decoder = DecoderBuilder::default()
660    ///     .with_config_yaml_str(config_yaml)
661    ///     .build()?;
662    /// assert_eq!(decoder.score_threshold, 0.5);
663    /// assert_eq!(decoder.iou_threshold, 0.5);
664    ///
665    /// # Ok(())
666    /// # }
667    /// ```
668    fn default() -> Self {
669        Self {
670            config_src: None,
671            iou_threshold: 0.5,
672            score_threshold: 0.5,
673            nms: Some(configs::Nms::ClassAgnostic),
674        }
675    }
676}
677
678impl DecoderBuilder {
679    /// Creates a default DecoderBuilder with no configuration and 0.5 score
680    /// threshold and 0.5 OU threshold.
681    ///
682    /// A valid confguration must be provided before building the Decoder.
683    ///
684    /// # Examples
685    /// ```rust
686    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
687    /// # fn main() -> DecoderResult<()> {
688    /// #  let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
689    /// let decoder = DecoderBuilder::new()
690    ///     .with_config_yaml_str(config_yaml)
691    ///     .build()?;
692    /// assert_eq!(decoder.score_threshold, 0.5);
693    /// assert_eq!(decoder.iou_threshold, 0.5);
694    ///
695    /// # Ok(())
696    /// # }
697    /// ```
698    pub fn new() -> Self {
699        Self::default()
700    }
701
702    /// Loads a model configuration in YAML format. Does not check if the string
703    /// is a correct configuration file. Use `DecoderBuilder.build()` to
704    /// deserialize the YAML and parse the model configuration.
705    ///
706    /// # Examples
707    /// ```rust
708    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
709    /// # fn main() -> DecoderResult<()> {
710    /// let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
711    /// let decoder = DecoderBuilder::new()
712    ///     .with_config_yaml_str(config_yaml)
713    ///     .build()?;
714    ///
715    /// # Ok(())
716    /// # }
717    /// ```
718    pub fn with_config_yaml_str(mut self, yaml_str: String) -> Self {
719        self.config_src.replace(ConfigSource::Yaml(yaml_str));
720        self
721    }
722
723    /// Loads a model configuration in JSON format. Does not check if the string
724    /// is a correct configuration file. Use `DecoderBuilder.build()` to
725    /// deserialize the JSON and parse the model configuration.
726    ///
727    /// # Examples
728    /// ```rust
729    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
730    /// # fn main() -> DecoderResult<()> {
731    /// let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
732    /// let decoder = DecoderBuilder::new()
733    ///     .with_config_json_str(config_json)
734    ///     .build()?;
735    ///
736    /// # Ok(())
737    /// # }
738    /// ```
739    pub fn with_config_json_str(mut self, json_str: String) -> Self {
740        self.config_src.replace(ConfigSource::Json(json_str));
741        self
742    }
743
744    /// Loads a model configuration. Does not check if the configuration is
745    /// correct. Intended to be used when the user needs control over the
746    /// deserialize of the configuration information. Use
747    /// `DecoderBuilder.build()` to parse the model configuration.
748    ///
749    /// # Examples
750    /// ```rust
751    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
752    /// # fn main() -> DecoderResult<()> {
753    /// let config_json = include_str!("../../../testdata/modelpack_split.json");
754    /// let config = serde_json::from_str(config_json)?;
755    /// let decoder = DecoderBuilder::new().with_config(config).build()?;
756    ///
757    /// # Ok(())
758    /// # }
759    /// ```
760    pub fn with_config(mut self, config: ConfigOutputs) -> Self {
761        self.config_src.replace(ConfigSource::Config(config));
762        self
763    }
764
765    /// Loads a YOLO detection model configuration.  Use
766    /// `DecoderBuilder.build()` to parse the model configuration.
767    ///
768    /// # Examples
769    /// ```rust
770    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
771    /// # fn main() -> DecoderResult<()> {
772    /// let decoder = DecoderBuilder::new()
773    ///     .with_config_yolo_det(
774    ///         configs::Detection {
775    ///             anchors: None,
776    ///             decoder: configs::DecoderType::Ultralytics,
777    ///             quantization: Some(configs::QuantTuple(0.012345, 26)),
778    ///             shape: vec![1, 84, 8400],
779    ///             dshape: Vec::new(),
780    ///             normalized: Some(true),
781    ///         },
782    ///         None,
783    ///     )
784    ///     .build()?;
785    ///
786    /// # Ok(())
787    /// # }
788    /// ```
789    pub fn with_config_yolo_det(
790        mut self,
791        boxes: configs::Detection,
792        version: Option<DecoderVersion>,
793    ) -> Self {
794        let config = ConfigOutputs {
795            outputs: vec![ConfigOutput::Detection(boxes)],
796            decoder_version: version,
797            ..Default::default()
798        };
799        self.config_src.replace(ConfigSource::Config(config));
800        self
801    }
802
803    /// Loads a YOLO split detection model configuration.  Use
804    /// `DecoderBuilder.build()` to parse the model configuration.
805    ///
806    /// # Examples
807    /// ```rust
808    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
809    /// # fn main() -> DecoderResult<()> {
810    /// let boxes_config = configs::Boxes {
811    ///     decoder: configs::DecoderType::Ultralytics,
812    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
813    ///     shape: vec![1, 4, 8400],
814    ///     dshape: Vec::new(),
815    ///     normalized: Some(true),
816    /// };
817    /// let scores_config = configs::Scores {
818    ///     decoder: configs::DecoderType::Ultralytics,
819    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
820    ///     shape: vec![1, 80, 8400],
821    ///     dshape: Vec::new(),
822    /// };
823    /// let decoder = DecoderBuilder::new()
824    ///     .with_config_yolo_split_det(boxes_config, scores_config)
825    ///     .build()?;
826    /// # Ok(())
827    /// # }
828    /// ```
829    pub fn with_config_yolo_split_det(
830        mut self,
831        boxes: configs::Boxes,
832        scores: configs::Scores,
833    ) -> Self {
834        let config = ConfigOutputs {
835            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
836            ..Default::default()
837        };
838        self.config_src.replace(ConfigSource::Config(config));
839        self
840    }
841
842    /// Loads a YOLO segmentation model configuration.  Use
843    /// `DecoderBuilder.build()` to parse the model configuration.
844    ///
845    /// # Examples
846    /// ```rust
847    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
848    /// # fn main() -> DecoderResult<()> {
849    /// let seg_config = configs::Detection {
850    ///     decoder: configs::DecoderType::Ultralytics,
851    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
852    ///     shape: vec![1, 116, 8400],
853    ///     anchors: None,
854    ///     dshape: Vec::new(),
855    ///     normalized: Some(true),
856    /// };
857    /// let protos_config = configs::Protos {
858    ///     decoder: configs::DecoderType::Ultralytics,
859    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
860    ///     shape: vec![1, 160, 160, 32],
861    ///     dshape: Vec::new(),
862    /// };
863    /// let decoder = DecoderBuilder::new()
864    ///     .with_config_yolo_segdet(
865    ///         seg_config,
866    ///         protos_config,
867    ///         Some(configs::DecoderVersion::Yolov8),
868    ///     )
869    ///     .build()?;
870    /// # Ok(())
871    /// # }
872    /// ```
873    pub fn with_config_yolo_segdet(
874        mut self,
875        boxes: configs::Detection,
876        protos: configs::Protos,
877        version: Option<DecoderVersion>,
878    ) -> Self {
879        let config = ConfigOutputs {
880            outputs: vec![ConfigOutput::Detection(boxes), ConfigOutput::Protos(protos)],
881            decoder_version: version,
882            ..Default::default()
883        };
884        self.config_src.replace(ConfigSource::Config(config));
885        self
886    }
887
888    /// Loads a YOLO split segmentation model configuration.  Use
889    /// `DecoderBuilder.build()` to parse the model configuration.
890    ///
891    /// # Examples
892    /// ```rust
893    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
894    /// # fn main() -> DecoderResult<()> {
895    /// let boxes_config = configs::Boxes {
896    ///     decoder: configs::DecoderType::Ultralytics,
897    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
898    ///     shape: vec![1, 4, 8400],
899    ///     dshape: Vec::new(),
900    ///     normalized: Some(true),
901    /// };
902    /// let scores_config = configs::Scores {
903    ///     decoder: configs::DecoderType::Ultralytics,
904    ///     quantization: Some(configs::QuantTuple(0.012345, 14)),
905    ///     shape: vec![1, 80, 8400],
906    ///     dshape: Vec::new(),
907    /// };
908    /// let mask_config = configs::MaskCoefficients {
909    ///     decoder: configs::DecoderType::Ultralytics,
910    ///     quantization: Some(configs::QuantTuple(0.0064123, 125)),
911    ///     shape: vec![1, 32, 8400],
912    ///     dshape: Vec::new(),
913    /// };
914    /// let protos_config = configs::Protos {
915    ///     decoder: configs::DecoderType::Ultralytics,
916    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
917    ///     shape: vec![1, 160, 160, 32],
918    ///     dshape: Vec::new(),
919    /// };
920    /// let decoder = DecoderBuilder::new()
921    ///     .with_config_yolo_split_segdet(boxes_config, scores_config, mask_config, protos_config)
922    ///     .build()?;
923    /// # Ok(())
924    /// # }
925    /// ```
926    pub fn with_config_yolo_split_segdet(
927        mut self,
928        boxes: configs::Boxes,
929        scores: configs::Scores,
930        mask_coefficients: configs::MaskCoefficients,
931        protos: configs::Protos,
932    ) -> Self {
933        let config = ConfigOutputs {
934            outputs: vec![
935                ConfigOutput::Boxes(boxes),
936                ConfigOutput::Scores(scores),
937                ConfigOutput::MaskCoefficients(mask_coefficients),
938                ConfigOutput::Protos(protos),
939            ],
940            ..Default::default()
941        };
942        self.config_src.replace(ConfigSource::Config(config));
943        self
944    }
945
946    /// Loads a ModelPack detection model configuration.  Use
947    /// `DecoderBuilder.build()` to parse the model configuration.
948    ///
949    /// # Examples
950    /// ```rust
951    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
952    /// # fn main() -> DecoderResult<()> {
953    /// let boxes_config = configs::Boxes {
954    ///     decoder: configs::DecoderType::ModelPack,
955    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
956    ///     shape: vec![1, 8400, 1, 4],
957    ///     dshape: Vec::new(),
958    ///     normalized: Some(true),
959    /// };
960    /// let scores_config = configs::Scores {
961    ///     decoder: configs::DecoderType::ModelPack,
962    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
963    ///     shape: vec![1, 8400, 3],
964    ///     dshape: Vec::new(),
965    /// };
966    /// let decoder = DecoderBuilder::new()
967    ///     .with_config_modelpack_det(boxes_config, scores_config)
968    ///     .build()?;
969    /// # Ok(())
970    /// # }
971    /// ```
972    pub fn with_config_modelpack_det(
973        mut self,
974        boxes: configs::Boxes,
975        scores: configs::Scores,
976    ) -> Self {
977        let config = ConfigOutputs {
978            outputs: vec![ConfigOutput::Boxes(boxes), ConfigOutput::Scores(scores)],
979            ..Default::default()
980        };
981        self.config_src.replace(ConfigSource::Config(config));
982        self
983    }
984
985    /// Loads a ModelPack split detection model configuration. Use
986    /// `DecoderBuilder.build()` to parse the model configuration.
987    ///
988    /// # Examples
989    /// ```rust
990    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
991    /// # fn main() -> DecoderResult<()> {
992    /// let config0 = configs::Detection {
993    ///     anchors: Some(vec![
994    ///         [0.13750000298023224, 0.2074074000120163],
995    ///         [0.2541666626930237, 0.21481481194496155],
996    ///         [0.23125000298023224, 0.35185185074806213],
997    ///     ]),
998    ///     decoder: configs::DecoderType::ModelPack,
999    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1000    ///     shape: vec![1, 17, 30, 18],
1001    ///     dshape: Vec::new(),
1002    ///     normalized: Some(true),
1003    /// };
1004    /// let config1 = configs::Detection {
1005    ///     anchors: Some(vec![
1006    ///         [0.36666667461395264, 0.31481480598449707],
1007    ///         [0.38749998807907104, 0.4740740656852722],
1008    ///         [0.5333333611488342, 0.644444465637207],
1009    ///     ]),
1010    ///     decoder: configs::DecoderType::ModelPack,
1011    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1012    ///     shape: vec![1, 9, 15, 18],
1013    ///     dshape: Vec::new(),
1014    ///     normalized: Some(true),
1015    /// };
1016    ///
1017    /// let decoder = DecoderBuilder::new()
1018    ///     .with_config_modelpack_det_split(vec![config0, config1])
1019    ///     .build()?;
1020    /// # Ok(())
1021    /// # }
1022    /// ```
1023    pub fn with_config_modelpack_det_split(mut self, boxes: Vec<configs::Detection>) -> Self {
1024        let outputs = boxes.into_iter().map(ConfigOutput::Detection).collect();
1025        let config = ConfigOutputs {
1026            outputs,
1027            ..Default::default()
1028        };
1029        self.config_src.replace(ConfigSource::Config(config));
1030        self
1031    }
1032
1033    /// Loads a ModelPack segmentation detection model configuration. Use
1034    /// `DecoderBuilder.build()` to parse the model configuration.
1035    ///
1036    /// # Examples
1037    /// ```rust
1038    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1039    /// # fn main() -> DecoderResult<()> {
1040    /// let boxes_config = configs::Boxes {
1041    ///     decoder: configs::DecoderType::ModelPack,
1042    ///     quantization: Some(configs::QuantTuple(0.012345, 26)),
1043    ///     shape: vec![1, 8400, 1, 4],
1044    ///     dshape: Vec::new(),
1045    ///     normalized: Some(true),
1046    /// };
1047    /// let scores_config = configs::Scores {
1048    ///     decoder: configs::DecoderType::ModelPack,
1049    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1050    ///     shape: vec![1, 8400, 2],
1051    ///     dshape: Vec::new(),
1052    /// };
1053    /// let seg_config = configs::Segmentation {
1054    ///     decoder: configs::DecoderType::ModelPack,
1055    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1056    ///     shape: vec![1, 640, 640, 3],
1057    ///     dshape: Vec::new(),
1058    /// };
1059    /// let decoder = DecoderBuilder::new()
1060    ///     .with_config_modelpack_segdet(boxes_config, scores_config, seg_config)
1061    ///     .build()?;
1062    /// # Ok(())
1063    /// # }
1064    /// ```
1065    pub fn with_config_modelpack_segdet(
1066        mut self,
1067        boxes: configs::Boxes,
1068        scores: configs::Scores,
1069        segmentation: configs::Segmentation,
1070    ) -> Self {
1071        let config = ConfigOutputs {
1072            outputs: vec![
1073                ConfigOutput::Boxes(boxes),
1074                ConfigOutput::Scores(scores),
1075                ConfigOutput::Segmentation(segmentation),
1076            ],
1077            ..Default::default()
1078        };
1079        self.config_src.replace(ConfigSource::Config(config));
1080        self
1081    }
1082
1083    /// Loads a ModelPack segmentation split detection model configuration. Use
1084    /// `DecoderBuilder.build()` to parse the model configuration.
1085    ///
1086    /// # Examples
1087    /// ```rust
1088    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1089    /// # fn main() -> DecoderResult<()> {
1090    /// let config0 = configs::Detection {
1091    ///     anchors: Some(vec![
1092    ///         [0.36666667461395264, 0.31481480598449707],
1093    ///         [0.38749998807907104, 0.4740740656852722],
1094    ///         [0.5333333611488342, 0.644444465637207],
1095    ///     ]),
1096    ///     decoder: configs::DecoderType::ModelPack,
1097    ///     quantization: Some(configs::QuantTuple(0.08547406643629074, 174)),
1098    ///     shape: vec![1, 9, 15, 18],
1099    ///     dshape: Vec::new(),
1100    ///     normalized: Some(true),
1101    /// };
1102    /// let config1 = configs::Detection {
1103    ///     anchors: Some(vec![
1104    ///         [0.13750000298023224, 0.2074074000120163],
1105    ///         [0.2541666626930237, 0.21481481194496155],
1106    ///         [0.23125000298023224, 0.35185185074806213],
1107    ///     ]),
1108    ///     decoder: configs::DecoderType::ModelPack,
1109    ///     quantization: Some(configs::QuantTuple(0.09929127991199493, 183)),
1110    ///     shape: vec![1, 17, 30, 18],
1111    ///     dshape: Vec::new(),
1112    ///     normalized: Some(true),
1113    /// };
1114    /// let seg_config = configs::Segmentation {
1115    ///     decoder: configs::DecoderType::ModelPack,
1116    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1117    ///     shape: vec![1, 640, 640, 2],
1118    ///     dshape: Vec::new(),
1119    /// };
1120    /// let decoder = DecoderBuilder::new()
1121    ///     .with_config_modelpack_segdet_split(vec![config0, config1], seg_config)
1122    ///     .build()?;
1123    /// # Ok(())
1124    /// # }
1125    /// ```
1126    pub fn with_config_modelpack_segdet_split(
1127        mut self,
1128        boxes: Vec<configs::Detection>,
1129        segmentation: configs::Segmentation,
1130    ) -> Self {
1131        let mut outputs = boxes
1132            .into_iter()
1133            .map(ConfigOutput::Detection)
1134            .collect::<Vec<_>>();
1135        outputs.push(ConfigOutput::Segmentation(segmentation));
1136        let config = ConfigOutputs {
1137            outputs,
1138            ..Default::default()
1139        };
1140        self.config_src.replace(ConfigSource::Config(config));
1141        self
1142    }
1143
1144    /// Loads a ModelPack segmentation model configuration. Use
1145    /// `DecoderBuilder.build()` to parse the model configuration.
1146    ///
1147    /// # Examples
1148    /// ```rust
1149    /// # use edgefirst_decoder::{ DecoderBuilder, DecoderResult, configs };
1150    /// # fn main() -> DecoderResult<()> {
1151    /// let seg_config = configs::Segmentation {
1152    ///     decoder: configs::DecoderType::ModelPack,
1153    ///     quantization: Some(configs::QuantTuple(0.0064123, -31)),
1154    ///     shape: vec![1, 640, 640, 3],
1155    ///     dshape: Vec::new(),
1156    /// };
1157    /// let decoder = DecoderBuilder::new()
1158    ///     .with_config_modelpack_seg(seg_config)
1159    ///     .build()?;
1160    /// # Ok(())
1161    /// # }
1162    /// ```
1163    pub fn with_config_modelpack_seg(mut self, segmentation: configs::Segmentation) -> Self {
1164        let config = ConfigOutputs {
1165            outputs: vec![ConfigOutput::Segmentation(segmentation)],
1166            ..Default::default()
1167        };
1168        self.config_src.replace(ConfigSource::Config(config));
1169        self
1170    }
1171
1172    /// Sets the scores threshold of the decoder
1173    ///
1174    /// # Examples
1175    /// ```rust
1176    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1177    /// # fn main() -> DecoderResult<()> {
1178    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1179    /// let decoder = DecoderBuilder::new()
1180    ///     .with_config_json_str(config_json)
1181    ///     .with_score_threshold(0.654)
1182    ///     .build()?;
1183    /// assert_eq!(decoder.score_threshold, 0.654);
1184    /// # Ok(())
1185    /// # }
1186    /// ```
1187    pub fn with_score_threshold(mut self, score_threshold: f32) -> Self {
1188        self.score_threshold = score_threshold;
1189        self
1190    }
1191
1192    /// Sets the IOU threshold of the decoder. Has no effect when NMS is set to
1193    /// `None`
1194    ///
1195    /// # Examples
1196    /// ```rust
1197    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1198    /// # fn main() -> DecoderResult<()> {
1199    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1200    /// let decoder = DecoderBuilder::new()
1201    ///     .with_config_json_str(config_json)
1202    ///     .with_iou_threshold(0.654)
1203    ///     .build()?;
1204    /// assert_eq!(decoder.iou_threshold, 0.654);
1205    /// # Ok(())
1206    /// # }
1207    /// ```
1208    pub fn with_iou_threshold(mut self, iou_threshold: f32) -> Self {
1209        self.iou_threshold = iou_threshold;
1210        self
1211    }
1212
1213    /// Sets the NMS mode for the decoder.
1214    ///
1215    /// - `Some(Nms::ClassAgnostic)` — class-agnostic NMS (default): suppress
1216    ///   overlapping boxes regardless of class label
1217    /// - `Some(Nms::ClassAware)` — class-aware NMS: only suppress boxes that
1218    ///   share the same class label AND overlap above the IoU threshold
1219    /// - `None` — bypass NMS entirely (for end-to-end models with embedded NMS)
1220    ///
1221    /// # Examples
1222    /// ```rust
1223    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::Nms};
1224    /// # fn main() -> DecoderResult<()> {
1225    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1226    /// let decoder = DecoderBuilder::new()
1227    ///     .with_config_json_str(config_json)
1228    ///     .with_nms(Some(Nms::ClassAware))
1229    ///     .build()?;
1230    /// assert_eq!(decoder.nms, Some(Nms::ClassAware));
1231    /// # Ok(())
1232    /// # }
1233    /// ```
1234    pub fn with_nms(mut self, nms: Option<configs::Nms>) -> Self {
1235        self.nms = nms;
1236        self
1237    }
1238
1239    /// Builds the decoder with the given settings. If the config is a JSON or
1240    /// YAML string, this will deserialize the JSON or YAML and then parse the
1241    /// configuration information.
1242    ///
1243    /// # Examples
1244    /// ```rust
1245    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
1246    /// # fn main() -> DecoderResult<()> {
1247    /// # let config_json = include_str!("../../../testdata/modelpack_split.json").to_string();
1248    /// let decoder = DecoderBuilder::new()
1249    ///     .with_config_json_str(config_json)
1250    ///     .with_score_threshold(0.654)
1251    ///     .build()?;
1252    /// # Ok(())
1253    /// # }
1254    /// ```
1255    pub fn build(self) -> Result<Decoder, DecoderError> {
1256        let config = match self.config_src {
1257            Some(ConfigSource::Json(s)) => serde_json::from_str(&s)?,
1258            Some(ConfigSource::Yaml(s)) => serde_yaml::from_str(&s)?,
1259            Some(ConfigSource::Config(c)) => c,
1260            None => return Err(DecoderError::NoConfig),
1261        };
1262
1263        // Extract normalized flag from config outputs
1264        let normalized = Self::get_normalized(&config.outputs);
1265
1266        // Use NMS from config if present, otherwise use builder's NMS setting
1267        let nms = config.nms.or(self.nms);
1268        let model_type = Self::get_model_type(config)?;
1269
1270        Ok(Decoder {
1271            model_type,
1272            iou_threshold: self.iou_threshold,
1273            score_threshold: self.score_threshold,
1274            nms,
1275            normalized,
1276        })
1277    }
1278
1279    /// Extracts the normalized flag from config outputs.
1280    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
1281    /// - `Some(false)`: Boxes are in pixel coordinates
1282    /// - `None`: Unknown (not specified in config), caller must infer
1283    fn get_normalized(outputs: &[ConfigOutput]) -> Option<bool> {
1284        for output in outputs {
1285            match output {
1286                ConfigOutput::Detection(det) => return det.normalized,
1287                ConfigOutput::Boxes(boxes) => return boxes.normalized,
1288                _ => {}
1289            }
1290        }
1291        None // not specified
1292    }
1293
1294    fn get_model_type(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1295        // yolo or modelpack
1296        let mut yolo = false;
1297        let mut modelpack = false;
1298        for c in &configs.outputs {
1299            match c.decoder() {
1300                DecoderType::ModelPack => modelpack = true,
1301                DecoderType::Ultralytics => yolo = true,
1302            }
1303        }
1304        match (modelpack, yolo) {
1305            (true, true) => Err(DecoderError::InvalidConfig(
1306                "Both ModelPack and Yolo outputs found in config".to_string(),
1307            )),
1308            (true, false) => Self::get_model_type_modelpack(configs),
1309            (false, true) => Self::get_model_type_yolo(configs),
1310            (false, false) => Err(DecoderError::InvalidConfig(
1311                "No outputs found in config".to_string(),
1312            )),
1313        }
1314    }
1315
1316    fn get_model_type_yolo(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1317        let mut boxes = None;
1318        let mut protos = None;
1319        let mut split_boxes = None;
1320        let mut split_scores = None;
1321        let mut split_mask_coeff = None;
1322        for c in configs.outputs {
1323            match c {
1324                ConfigOutput::Detection(detection) => boxes = Some(detection),
1325                ConfigOutput::Segmentation(_) => {
1326                    return Err(DecoderError::InvalidConfig(
1327                        "Invalid Segmentation output with Yolo decoder".to_string(),
1328                    ));
1329                }
1330                ConfigOutput::Protos(protos_) => protos = Some(protos_),
1331                ConfigOutput::Mask(_) => {
1332                    return Err(DecoderError::InvalidConfig(
1333                        "Invalid Mask output with Yolo decoder".to_string(),
1334                    ));
1335                }
1336                ConfigOutput::Scores(scores) => split_scores = Some(scores),
1337                ConfigOutput::Boxes(boxes) => split_boxes = Some(boxes),
1338                ConfigOutput::MaskCoefficients(mask_coeff) => split_mask_coeff = Some(mask_coeff),
1339            }
1340        }
1341
1342        // Use end-to-end model types when:
1343        // 1. decoder_version is explicitly set to Yolo26 (definitive), OR
1344        //    decoder_version is not set but the dshapes are (batch, num_boxes,
1345        //    num_features)
1346        let is_end_to_end_dshape = boxes.as_ref().is_some_and(|b| {
1347            let dims = b.dshape.iter().map(|(d, _)| *d).collect::<Vec<_>>();
1348            dims == vec![DimName::Batch, DimName::NumBoxes, DimName::NumFeatures]
1349        });
1350
1351        let is_end_to_end = configs
1352            .decoder_version
1353            .map(|v| v.is_end_to_end())
1354            .unwrap_or(is_end_to_end_dshape);
1355
1356        if is_end_to_end {
1357            if let Some(boxes) = boxes {
1358                if let Some(protos) = protos {
1359                    Self::verify_yolo_seg_det_26(&boxes, &protos)?;
1360                    return Ok(ModelType::YoloEndToEndSegDet { boxes, protos });
1361                } else {
1362                    Self::verify_yolo_det_26(&boxes)?;
1363                    return Ok(ModelType::YoloEndToEndDet { boxes });
1364                }
1365            } else {
1366                return Err(DecoderError::InvalidConfig(
1367                    "Invalid Yolo end-to-end model outputs".to_string(),
1368                ));
1369            }
1370        }
1371
1372        if let Some(boxes) = boxes {
1373            if let Some(protos) = protos {
1374                Self::verify_yolo_seg_det(&boxes, &protos)?;
1375                Ok(ModelType::YoloSegDet { boxes, protos })
1376            } else {
1377                Self::verify_yolo_det(&boxes)?;
1378                Ok(ModelType::YoloDet { boxes })
1379            }
1380        } else if let (Some(boxes), Some(scores)) = (split_boxes, split_scores) {
1381            if let (Some(mask_coeff), Some(protos)) = (split_mask_coeff, protos) {
1382                Self::verify_yolo_split_segdet(&boxes, &scores, &mask_coeff, &protos)?;
1383                Ok(ModelType::YoloSplitSegDet {
1384                    boxes,
1385                    scores,
1386                    mask_coeff,
1387                    protos,
1388                })
1389            } else {
1390                Self::verify_yolo_split_det(&boxes, &scores)?;
1391                Ok(ModelType::YoloSplitDet { boxes, scores })
1392            }
1393        } else {
1394            Err(DecoderError::InvalidConfig(
1395                "Invalid Yolo model outputs".to_string(),
1396            ))
1397        }
1398    }
1399
1400    fn verify_yolo_det(detect: &configs::Detection) -> Result<(), DecoderError> {
1401        if detect.shape.len() != 3 {
1402            return Err(DecoderError::InvalidConfig(format!(
1403                "Invalid Yolo Detection shape {:?}",
1404                detect.shape
1405            )));
1406        }
1407
1408        Self::verify_dshapes(
1409            &detect.dshape,
1410            &detect.shape,
1411            "Detection",
1412            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1413        )?;
1414        if !detect.dshape.is_empty() {
1415            Self::get_class_count(&detect.dshape, None, None)?;
1416        } else {
1417            Self::get_class_count_no_dshape(detect.into(), None)?;
1418        }
1419
1420        Ok(())
1421    }
1422
1423    fn verify_yolo_det_26(detect: &configs::Detection) -> Result<(), DecoderError> {
1424        if detect.shape.len() != 3 {
1425            return Err(DecoderError::InvalidConfig(format!(
1426                "Invalid Yolo Detection shape {:?}",
1427                detect.shape
1428            )));
1429        }
1430
1431        Self::verify_dshapes(
1432            &detect.dshape,
1433            &detect.shape,
1434            "Detection",
1435            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1436        )?;
1437
1438        if !detect.shape.contains(&6) {
1439            return Err(DecoderError::InvalidConfig(
1440                "Yolo26 Detection must have 6 features".to_string(),
1441            ));
1442        }
1443
1444        Ok(())
1445    }
1446
1447    fn verify_yolo_seg_det(
1448        detection: &configs::Detection,
1449        protos: &configs::Protos,
1450    ) -> Result<(), DecoderError> {
1451        if detection.shape.len() != 3 {
1452            return Err(DecoderError::InvalidConfig(format!(
1453                "Invalid Yolo Detection shape {:?}",
1454                detection.shape
1455            )));
1456        }
1457        if protos.shape.len() != 4 {
1458            return Err(DecoderError::InvalidConfig(format!(
1459                "Invalid Yolo Protos shape {:?}",
1460                protos.shape
1461            )));
1462        }
1463
1464        Self::verify_dshapes(
1465            &detection.dshape,
1466            &detection.shape,
1467            "Detection",
1468            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1469        )?;
1470        Self::verify_dshapes(
1471            &protos.dshape,
1472            &protos.shape,
1473            "Protos",
1474            &[
1475                DimName::Batch,
1476                DimName::Height,
1477                DimName::Width,
1478                DimName::NumProtos,
1479            ],
1480        )?;
1481
1482        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1483        log::debug!("Protos count: {}", protos_count);
1484        log::debug!("Detection dshape: {:?}", detection.dshape);
1485        let classes = if !detection.dshape.is_empty() {
1486            Self::get_class_count(&detection.dshape, Some(protos_count), None)?
1487        } else {
1488            Self::get_class_count_no_dshape(detection.into(), Some(protos_count))?
1489        };
1490
1491        if classes == 0 {
1492            return Err(DecoderError::InvalidConfig(
1493                "Yolo Segmentation Detection has zero classes".to_string(),
1494            ));
1495        }
1496
1497        Ok(())
1498    }
1499
1500    fn verify_yolo_seg_det_26(
1501        detection: &configs::Detection,
1502        protos: &configs::Protos,
1503    ) -> Result<(), DecoderError> {
1504        if detection.shape.len() != 3 {
1505            return Err(DecoderError::InvalidConfig(format!(
1506                "Invalid Yolo Detection shape {:?}",
1507                detection.shape
1508            )));
1509        }
1510        if protos.shape.len() != 4 {
1511            return Err(DecoderError::InvalidConfig(format!(
1512                "Invalid Yolo Protos shape {:?}",
1513                protos.shape
1514            )));
1515        }
1516
1517        Self::verify_dshapes(
1518            &detection.dshape,
1519            &detection.shape,
1520            "Detection",
1521            &[DimName::Batch, DimName::NumFeatures, DimName::NumBoxes],
1522        )?;
1523        Self::verify_dshapes(
1524            &protos.dshape,
1525            &protos.shape,
1526            "Protos",
1527            &[
1528                DimName::Batch,
1529                DimName::Height,
1530                DimName::Width,
1531                DimName::NumProtos,
1532            ],
1533        )?;
1534
1535        let protos_count = Self::get_protos_count(&protos.dshape).unwrap_or(protos.shape[3]);
1536        log::debug!("Protos count: {}", protos_count);
1537        log::debug!("Detection dshape: {:?}", detection.dshape);
1538
1539        if !detection.shape.contains(&(6 + protos_count)) {
1540            return Err(DecoderError::InvalidConfig(format!(
1541                "Yolo26 Segmentation Detection must have num_features be 6 + num_protos = {}",
1542                6 + protos_count
1543            )));
1544        }
1545
1546        Ok(())
1547    }
1548
1549    fn verify_yolo_split_det(
1550        boxes: &configs::Boxes,
1551        scores: &configs::Scores,
1552    ) -> Result<(), DecoderError> {
1553        if boxes.shape.len() != 3 {
1554            return Err(DecoderError::InvalidConfig(format!(
1555                "Invalid Yolo Split Boxes shape {:?}",
1556                boxes.shape
1557            )));
1558        }
1559        if scores.shape.len() != 3 {
1560            return Err(DecoderError::InvalidConfig(format!(
1561                "Invalid Yolo Split Scores shape {:?}",
1562                scores.shape
1563            )));
1564        }
1565
1566        Self::verify_dshapes(
1567            &boxes.dshape,
1568            &boxes.shape,
1569            "Boxes",
1570            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1571        )?;
1572        Self::verify_dshapes(
1573            &scores.dshape,
1574            &scores.shape,
1575            "Scores",
1576            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1577        )?;
1578
1579        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1580        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1581
1582        if boxes_num != scores_num {
1583            return Err(DecoderError::InvalidConfig(format!(
1584                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1585                boxes_num, scores_num
1586            )));
1587        }
1588
1589        Ok(())
1590    }
1591
1592    fn verify_yolo_split_segdet(
1593        boxes: &configs::Boxes,
1594        scores: &configs::Scores,
1595        mask_coeff: &configs::MaskCoefficients,
1596        protos: &configs::Protos,
1597    ) -> Result<(), DecoderError> {
1598        if boxes.shape.len() != 3 {
1599            return Err(DecoderError::InvalidConfig(format!(
1600                "Invalid Yolo Split Boxes shape {:?}",
1601                boxes.shape
1602            )));
1603        }
1604        if scores.shape.len() != 3 {
1605            return Err(DecoderError::InvalidConfig(format!(
1606                "Invalid Yolo Split Scores shape {:?}",
1607                scores.shape
1608            )));
1609        }
1610
1611        if mask_coeff.shape.len() != 3 {
1612            return Err(DecoderError::InvalidConfig(format!(
1613                "Invalid Yolo Split Mask Coefficients shape {:?}",
1614                mask_coeff.shape
1615            )));
1616        }
1617
1618        if protos.shape.len() != 4 {
1619            return Err(DecoderError::InvalidConfig(format!(
1620                "Invalid Yolo Protos shape {:?}",
1621                mask_coeff.shape
1622            )));
1623        }
1624
1625        Self::verify_dshapes(
1626            &boxes.dshape,
1627            &boxes.shape,
1628            "Boxes",
1629            &[DimName::Batch, DimName::BoxCoords, DimName::NumBoxes],
1630        )?;
1631        Self::verify_dshapes(
1632            &scores.dshape,
1633            &scores.shape,
1634            "Scores",
1635            &[DimName::Batch, DimName::NumClasses, DimName::NumBoxes],
1636        )?;
1637        Self::verify_dshapes(
1638            &mask_coeff.dshape,
1639            &mask_coeff.shape,
1640            "Mask Coefficients",
1641            &[DimName::Batch, DimName::NumProtos, DimName::NumBoxes],
1642        )?;
1643        Self::verify_dshapes(
1644            &protos.dshape,
1645            &protos.shape,
1646            "Protos",
1647            &[
1648                DimName::Batch,
1649                DimName::Height,
1650                DimName::Width,
1651                DimName::NumProtos,
1652            ],
1653        )?;
1654
1655        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[2]);
1656        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[2]);
1657        let mask_num = Self::get_box_count(&mask_coeff.dshape).unwrap_or(mask_coeff.shape[2]);
1658
1659        let mask_channels = if !mask_coeff.dshape.is_empty() {
1660            Self::get_protos_count(&mask_coeff.dshape).ok_or_else(|| {
1661                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1662            })?
1663        } else {
1664            mask_coeff.shape[1]
1665        };
1666        let proto_channels = if !protos.dshape.is_empty() {
1667            Self::get_protos_count(&protos.dshape).ok_or_else(|| {
1668                DecoderError::InvalidConfig("Could not find num_protos in config".to_string())
1669            })?
1670        } else {
1671            protos.shape[3]
1672        };
1673
1674        if boxes_num != scores_num {
1675            return Err(DecoderError::InvalidConfig(format!(
1676                "Yolo Split Detection Boxes num {} incompatible with Scores num {}",
1677                boxes_num, scores_num
1678            )));
1679        }
1680
1681        if boxes_num != mask_num {
1682            return Err(DecoderError::InvalidConfig(format!(
1683                "Yolo Split Detection Boxes num {} incompatible with Mask Coefficients num {}",
1684                boxes_num, mask_num
1685            )));
1686        }
1687
1688        if proto_channels != mask_channels {
1689            return Err(DecoderError::InvalidConfig(format!(
1690                "Yolo Protos channels {} incompatible with Mask Coefficients channels {}",
1691                proto_channels, mask_channels
1692            )));
1693        }
1694
1695        Ok(())
1696    }
1697
1698    fn get_model_type_modelpack(configs: ConfigOutputs) -> Result<ModelType, DecoderError> {
1699        let mut split_decoders = Vec::new();
1700        let mut segment_ = None;
1701        let mut scores_ = None;
1702        let mut boxes_ = None;
1703        for c in configs.outputs {
1704            match c {
1705                ConfigOutput::Detection(detection) => split_decoders.push(detection),
1706                ConfigOutput::Segmentation(segmentation) => segment_ = Some(segmentation),
1707                ConfigOutput::Mask(_) => {}
1708                ConfigOutput::Protos(_) => {
1709                    return Err(DecoderError::InvalidConfig(
1710                        "ModelPack should not have protos".to_string(),
1711                    ));
1712                }
1713                ConfigOutput::Scores(scores) => scores_ = Some(scores),
1714                ConfigOutput::Boxes(boxes) => boxes_ = Some(boxes),
1715                ConfigOutput::MaskCoefficients(_) => {
1716                    return Err(DecoderError::InvalidConfig(
1717                        "ModelPack should not have mask coefficients".to_string(),
1718                    ));
1719                }
1720            }
1721        }
1722
1723        if let Some(segmentation) = segment_ {
1724            if !split_decoders.is_empty() {
1725                let classes = Self::verify_modelpack_split_det(&split_decoders)?;
1726                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1727                Ok(ModelType::ModelPackSegDetSplit {
1728                    detection: split_decoders,
1729                    segmentation,
1730                })
1731            } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1732                let classes = Self::verify_modelpack_det(&boxes, &scores)?;
1733                Self::verify_modelpack_seg(&segmentation, Some(classes))?;
1734                Ok(ModelType::ModelPackSegDet {
1735                    boxes,
1736                    scores,
1737                    segmentation,
1738                })
1739            } else {
1740                Self::verify_modelpack_seg(&segmentation, None)?;
1741                Ok(ModelType::ModelPackSeg { segmentation })
1742            }
1743        } else if !split_decoders.is_empty() {
1744            Self::verify_modelpack_split_det(&split_decoders)?;
1745            Ok(ModelType::ModelPackDetSplit {
1746                detection: split_decoders,
1747            })
1748        } else if let (Some(scores), Some(boxes)) = (scores_, boxes_) {
1749            Self::verify_modelpack_det(&boxes, &scores)?;
1750            Ok(ModelType::ModelPackDet { boxes, scores })
1751        } else {
1752            Err(DecoderError::InvalidConfig(
1753                "Invalid ModelPack model outputs".to_string(),
1754            ))
1755        }
1756    }
1757
1758    fn verify_modelpack_det(
1759        boxes: &configs::Boxes,
1760        scores: &configs::Scores,
1761    ) -> Result<usize, DecoderError> {
1762        if boxes.shape.len() != 4 {
1763            return Err(DecoderError::InvalidConfig(format!(
1764                "Invalid ModelPack Boxes shape {:?}",
1765                boxes.shape
1766            )));
1767        }
1768        if scores.shape.len() != 3 {
1769            return Err(DecoderError::InvalidConfig(format!(
1770                "Invalid ModelPack Scores shape {:?}",
1771                scores.shape
1772            )));
1773        }
1774
1775        Self::verify_dshapes(
1776            &boxes.dshape,
1777            &boxes.shape,
1778            "Boxes",
1779            &[
1780                DimName::Batch,
1781                DimName::NumBoxes,
1782                DimName::Padding,
1783                DimName::BoxCoords,
1784            ],
1785        )?;
1786        Self::verify_dshapes(
1787            &scores.dshape,
1788            &scores.shape,
1789            "Scores",
1790            &[DimName::Batch, DimName::NumBoxes, DimName::NumClasses],
1791        )?;
1792
1793        let boxes_num = Self::get_box_count(&boxes.dshape).unwrap_or(boxes.shape[1]);
1794        let scores_num = Self::get_box_count(&scores.dshape).unwrap_or(scores.shape[1]);
1795
1796        if boxes_num != scores_num {
1797            return Err(DecoderError::InvalidConfig(format!(
1798                "ModelPack Detection Boxes num {} incompatible with Scores num {}",
1799                boxes_num, scores_num
1800            )));
1801        }
1802
1803        let num_classes = if !scores.dshape.is_empty() {
1804            Self::get_class_count(&scores.dshape, None, None)?
1805        } else {
1806            Self::get_class_count_no_dshape(scores.into(), None)?
1807        };
1808
1809        Ok(num_classes)
1810    }
1811
1812    fn verify_modelpack_split_det(boxes: &[configs::Detection]) -> Result<usize, DecoderError> {
1813        let mut num_classes = None;
1814        for b in boxes {
1815            let Some(num_anchors) = b.anchors.as_ref().map(|a| a.len()) else {
1816                return Err(DecoderError::InvalidConfig(
1817                    "ModelPack Split Detection missing anchors".to_string(),
1818                ));
1819            };
1820
1821            if num_anchors == 0 {
1822                return Err(DecoderError::InvalidConfig(
1823                    "ModelPack Split Detection has zero anchors".to_string(),
1824                ));
1825            }
1826
1827            if b.shape.len() != 4 {
1828                return Err(DecoderError::InvalidConfig(format!(
1829                    "Invalid ModelPack Split Detection shape {:?}",
1830                    b.shape
1831                )));
1832            }
1833
1834            Self::verify_dshapes(
1835                &b.dshape,
1836                &b.shape,
1837                "Split Detection",
1838                &[
1839                    DimName::Batch,
1840                    DimName::Height,
1841                    DimName::Width,
1842                    DimName::NumAnchorsXFeatures,
1843                ],
1844            )?;
1845            let classes = if !b.dshape.is_empty() {
1846                Self::get_class_count(&b.dshape, None, Some(num_anchors))?
1847            } else {
1848                Self::get_class_count_no_dshape(b.into(), None)?
1849            };
1850
1851            match num_classes {
1852                Some(n) => {
1853                    if n != classes {
1854                        return Err(DecoderError::InvalidConfig(format!(
1855                            "ModelPack Split Detection inconsistent number of classes: previous {}, current {}",
1856                            n, classes
1857                        )));
1858                    }
1859                }
1860                None => {
1861                    num_classes = Some(classes);
1862                }
1863            }
1864        }
1865
1866        Ok(num_classes.unwrap_or(0))
1867    }
1868
1869    fn verify_modelpack_seg(
1870        segmentation: &configs::Segmentation,
1871        classes: Option<usize>,
1872    ) -> Result<(), DecoderError> {
1873        if segmentation.shape.len() != 4 {
1874            return Err(DecoderError::InvalidConfig(format!(
1875                "Invalid ModelPack Segmentation shape {:?}",
1876                segmentation.shape
1877            )));
1878        }
1879        Self::verify_dshapes(
1880            &segmentation.dshape,
1881            &segmentation.shape,
1882            "Segmentation",
1883            &[
1884                DimName::Batch,
1885                DimName::Height,
1886                DimName::Width,
1887                DimName::NumClasses,
1888            ],
1889        )?;
1890
1891        if let Some(classes) = classes {
1892            let seg_classes = if !segmentation.dshape.is_empty() {
1893                Self::get_class_count(&segmentation.dshape, None, None)?
1894            } else {
1895                Self::get_class_count_no_dshape(segmentation.into(), None)?
1896            };
1897
1898            if seg_classes != classes + 1 {
1899                return Err(DecoderError::InvalidConfig(format!(
1900                    "ModelPack Segmentation channels {} incompatible with number of classes {}",
1901                    seg_classes, classes
1902                )));
1903            }
1904        }
1905        Ok(())
1906    }
1907
1908    // verifies that dshapes match the given shape
1909    fn verify_dshapes(
1910        dshape: &[(DimName, usize)],
1911        shape: &[usize],
1912        name: &str,
1913        dims: &[DimName],
1914    ) -> Result<(), DecoderError> {
1915        for s in shape {
1916            if *s == 0 {
1917                return Err(DecoderError::InvalidConfig(format!(
1918                    "{} shape has zero dimension",
1919                    name
1920                )));
1921            }
1922        }
1923
1924        if shape.len() != dims.len() {
1925            return Err(DecoderError::InvalidConfig(format!(
1926                "{} shape length {} does not match expected dims length {}",
1927                name,
1928                shape.len(),
1929                dims.len()
1930            )));
1931        }
1932
1933        if dshape.is_empty() {
1934            return Ok(());
1935        }
1936        // check the dshape lengths match the shape lengths
1937        if dshape.len() != shape.len() {
1938            return Err(DecoderError::InvalidConfig(format!(
1939                "{} dshape length does not match shape length",
1940                name
1941            )));
1942        }
1943
1944        // check the dshape values match the shape values
1945        for ((dim_name, dim_size), shape_size) in dshape.iter().zip(shape) {
1946            if dim_size != shape_size {
1947                return Err(DecoderError::InvalidConfig(format!(
1948                    "{} dshape dimension {} size {} does not match shape size {}",
1949                    name, dim_name, dim_size, shape_size
1950                )));
1951            }
1952            if *dim_name == DimName::Padding && *dim_size != 1 {
1953                return Err(DecoderError::InvalidConfig(
1954                    "Padding dimension size must be 1".to_string(),
1955                ));
1956            }
1957
1958            if *dim_name == DimName::BoxCoords && *dim_size != 4 {
1959                return Err(DecoderError::InvalidConfig(
1960                    "BoxCoords dimension size must be 4".to_string(),
1961                ));
1962            }
1963        }
1964
1965        let dims_present = HashSet::<DimName>::from_iter(dshape.iter().map(|(name, _)| *name));
1966        for dim in dims {
1967            if !dims_present.contains(dim) {
1968                return Err(DecoderError::InvalidConfig(format!(
1969                    "{} dshape missing required dimension {:?}",
1970                    name, dim
1971                )));
1972            }
1973        }
1974
1975        Ok(())
1976    }
1977
1978    fn get_box_count(dshape: &[(DimName, usize)]) -> Option<usize> {
1979        for (dim_name, dim_size) in dshape {
1980            if *dim_name == DimName::NumBoxes {
1981                return Some(*dim_size);
1982            }
1983        }
1984        None
1985    }
1986
1987    fn get_class_count_no_dshape(
1988        config: ConfigOutputRef,
1989        protos: Option<usize>,
1990    ) -> Result<usize, DecoderError> {
1991        match config {
1992            ConfigOutputRef::Detection(detection) => match detection.decoder {
1993                DecoderType::Ultralytics => {
1994                    if detection.shape[1] <= 4 + protos.unwrap_or(0) {
1995                        return Err(DecoderError::InvalidConfig(format!(
1996                            "Invalid shape: Yolo num_features {} must be greater than {}",
1997                            detection.shape[1],
1998                            4 + protos.unwrap_or(0),
1999                        )));
2000                    }
2001                    Ok(detection.shape[1] - 4 - protos.unwrap_or(0))
2002                }
2003                DecoderType::ModelPack => {
2004                    let Some(num_anchors) = detection.anchors.as_ref().map(|a| a.len()) else {
2005                        return Err(DecoderError::Internal(
2006                            "ModelPack Detection missing anchors".to_string(),
2007                        ));
2008                    };
2009                    let anchors_x_features = detection.shape[3];
2010                    if anchors_x_features <= num_anchors * 5 {
2011                        return Err(DecoderError::InvalidConfig(format!(
2012                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2013                            anchors_x_features,
2014                            num_anchors * 5,
2015                        )));
2016                    }
2017
2018                    if !anchors_x_features.is_multiple_of(num_anchors) {
2019                        return Err(DecoderError::InvalidConfig(format!(
2020                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2021                            anchors_x_features, num_anchors
2022                        )));
2023                    }
2024                    Ok(anchors_x_features / num_anchors - 5)
2025                }
2026            },
2027
2028            ConfigOutputRef::Scores(scores) => match scores.decoder {
2029                DecoderType::Ultralytics => Ok(scores.shape[1]),
2030                DecoderType::ModelPack => Ok(scores.shape[2]),
2031            },
2032            ConfigOutputRef::Segmentation(seg) => Ok(seg.shape[3]),
2033            _ => Err(DecoderError::Internal(
2034                "Attempted to get class count from unsupported config output".to_owned(),
2035            )),
2036        }
2037    }
2038
2039    // get the class count from dshape or calculate from num_features
2040    fn get_class_count(
2041        dshape: &[(DimName, usize)],
2042        protos: Option<usize>,
2043        anchors: Option<usize>,
2044    ) -> Result<usize, DecoderError> {
2045        if dshape.is_empty() {
2046            return Ok(0);
2047        }
2048        // if it has num_classes in dshape, return it
2049        for (dim_name, dim_size) in dshape {
2050            if *dim_name == DimName::NumClasses {
2051                return Ok(*dim_size);
2052            }
2053        }
2054
2055        // number of classes can be calculated from num_features - 4 for yolo.  If the
2056        // model has protos, we also subtract the number of protos.
2057        for (dim_name, dim_size) in dshape {
2058            if *dim_name == DimName::NumFeatures {
2059                let protos = protos.unwrap_or(0);
2060                if protos + 4 >= *dim_size {
2061                    return Err(DecoderError::InvalidConfig(format!(
2062                        "Invalid shape: Yolo num_features {} must be greater than {}",
2063                        *dim_size,
2064                        protos + 4,
2065                    )));
2066                }
2067                return Ok(*dim_size - 4 - protos);
2068            }
2069        }
2070
2071        // number of classes can be calculated from number of anchors for modelpack
2072        // split detection
2073        if let Some(num_anchors) = anchors {
2074            for (dim_name, dim_size) in dshape {
2075                if *dim_name == DimName::NumAnchorsXFeatures {
2076                    let anchors_x_features = *dim_size;
2077                    if anchors_x_features <= num_anchors * 5 {
2078                        return Err(DecoderError::InvalidConfig(format!(
2079                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not greater than number of anchors * 5 = {}",
2080                            anchors_x_features,
2081                            num_anchors * 5,
2082                        )));
2083                    }
2084
2085                    if !anchors_x_features.is_multiple_of(num_anchors) {
2086                        return Err(DecoderError::InvalidConfig(format!(
2087                            "Invalid ModelPack Split Detection shape: anchors_x_features {} not a multiple of number of anchors {}",
2088                            anchors_x_features, num_anchors
2089                        )));
2090                    }
2091                    return Ok((anchors_x_features / num_anchors) - 5);
2092                }
2093            }
2094        }
2095        Err(DecoderError::InvalidConfig(
2096            "Cannot determine number of classes from dshape".to_owned(),
2097        ))
2098    }
2099
2100    fn get_protos_count(dshape: &[(DimName, usize)]) -> Option<usize> {
2101        for (dim_name, dim_size) in dshape {
2102            if *dim_name == DimName::NumProtos {
2103                return Some(*dim_size);
2104            }
2105        }
2106        None
2107    }
2108}
2109
/// Decodes raw model outputs into detection boxes and segmentation masks,
/// following the layout described by its `ModelType`.
///
/// Instances are obtained from `DecoderBuilder::build`, as shown in the
/// method examples.
#[derive(Debug, Clone, PartialEq)]
pub struct Decoder {
    /// Parsed output layout of the model. `decode_quantized` dispatches on it
    /// to choose the decode path.
    model_type: ModelType,
    /// Intersection-over-union threshold. It is presumably used by NMS to
    /// decide when overlapping boxes suppress each other; TODO confirm
    /// against the decode functions.
    pub iou_threshold: f32,
    /// Score threshold. Detections scoring below it are presumably discarded;
    /// TODO confirm against the decode functions.
    pub score_threshold: f32,
    /// NMS mode: Some(mode) applies NMS, None bypasses NMS (for end-to-end
    /// models)
    pub nms: Option<configs::Nms>,
    /// Whether decoded boxes are in normalized [0,1] coordinates.
    /// - `Some(true)`: Coordinates in [0,1] range
    /// - `Some(false)`: Pixel coordinates
    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
    ///   1.0)
    normalized: Option<bool>,
}
2125
/// A dynamic-dimension `ndarray` view over integer (quantized) model output,
/// tagged with its element type.
///
/// This is the input element of `Decoder::decode_quantized`. A view of any
/// dimensionality over one of the supported element types converts into it
/// with `From`/`into()`.
#[derive(Debug)]
pub enum ArrayViewDQuantized<'a> {
    /// View over unsigned 8-bit elements.
    UInt8(ArrayViewD<'a, u8>),
    /// View over signed 8-bit elements.
    Int8(ArrayViewD<'a, i8>),
    /// View over unsigned 16-bit elements.
    UInt16(ArrayViewD<'a, u16>),
    /// View over signed 16-bit elements.
    Int16(ArrayViewD<'a, i16>),
    /// View over unsigned 32-bit elements.
    UInt32(ArrayViewD<'a, u32>),
    /// View over signed 32-bit elements.
    Int32(ArrayViewD<'a, i32>),
}
2135
2136impl<'a, D> From<ArrayView<'a, u8, D>> for ArrayViewDQuantized<'a>
2137where
2138    D: Dimension,
2139{
2140    fn from(arr: ArrayView<'a, u8, D>) -> Self {
2141        Self::UInt8(arr.into_dyn())
2142    }
2143}
2144
2145impl<'a, D> From<ArrayView<'a, i8, D>> for ArrayViewDQuantized<'a>
2146where
2147    D: Dimension,
2148{
2149    fn from(arr: ArrayView<'a, i8, D>) -> Self {
2150        Self::Int8(arr.into_dyn())
2151    }
2152}
2153
2154impl<'a, D> From<ArrayView<'a, u16, D>> for ArrayViewDQuantized<'a>
2155where
2156    D: Dimension,
2157{
2158    fn from(arr: ArrayView<'a, u16, D>) -> Self {
2159        Self::UInt16(arr.into_dyn())
2160    }
2161}
2162
2163impl<'a, D> From<ArrayView<'a, i16, D>> for ArrayViewDQuantized<'a>
2164where
2165    D: Dimension,
2166{
2167    fn from(arr: ArrayView<'a, i16, D>) -> Self {
2168        Self::Int16(arr.into_dyn())
2169    }
2170}
2171
2172impl<'a, D> From<ArrayView<'a, u32, D>> for ArrayViewDQuantized<'a>
2173where
2174    D: Dimension,
2175{
2176    fn from(arr: ArrayView<'a, u32, D>) -> Self {
2177        Self::UInt32(arr.into_dyn())
2178    }
2179}
2180
2181impl<'a, D> From<ArrayView<'a, i32, D>> for ArrayViewDQuantized<'a>
2182where
2183    D: Dimension,
2184{
2185    fn from(arr: ArrayView<'a, i32, D>) -> Self {
2186        Self::Int32(arr.into_dyn())
2187    }
2188}
2189
2190impl<'a> ArrayViewDQuantized<'a> {
2191    /// Returns the shape of the underlying array.
2192    ///
2193    /// # Examples
2194    /// ```rust
2195    /// # use edgefirst_decoder::ArrayViewDQuantized;
2196    /// # use ndarray::Array2;
2197    /// # fn main() -> Result<(), Box<dyn std::error::Error>> {
2198    /// let arr = Array2::from_shape_vec((2, 3), vec![1u8, 2, 3, 4, 5, 6])?;
2199    /// let view = ArrayViewDQuantized::from(arr.view().into_dyn());
2200    /// assert_eq!(view.shape(), &[2, 3]);
2201    /// # Ok(())
2202    /// # }
2203    /// ```
2204    pub fn shape(&self) -> &[usize] {
2205        match self {
2206            ArrayViewDQuantized::UInt8(a) => a.shape(),
2207            ArrayViewDQuantized::Int8(a) => a.shape(),
2208            ArrayViewDQuantized::UInt16(a) => a.shape(),
2209            ArrayViewDQuantized::Int16(a) => a.shape(),
2210            ArrayViewDQuantized::UInt32(a) => a.shape(),
2211            ArrayViewDQuantized::Int32(a) => a.shape(),
2212        }
2213    }
2214}
2215
// Evaluates `$body` with `$var` bound to the typed `ArrayViewD` held by
// whichever `ArrayViewDQuantized` variant `$x` is. `$body` is expanded once
// per variant, so it is type-checked separately for every element type and
// may call code that is generic over the element.
macro_rules! with_quantized {
    ($x:expr, $var:ident, $body:expr) => {
        match $x {
            ArrayViewDQuantized::UInt8($var) => $body,
            ArrayViewDQuantized::Int8($var) => $body,
            ArrayViewDQuantized::UInt16($var) => $body,
            ArrayViewDQuantized::Int16($var) => $body,
            ArrayViewDQuantized::UInt32($var) => $body,
            ArrayViewDQuantized::Int32($var) => $body,
        }
    };
}
2246
2247impl Decoder {
2248    /// This function returns the parsed model type of the decoder.
2249    ///
2250    /// # Examples
2251    ///
2252    /// ```rust
2253    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult, configs::ModelType};
2254    /// # fn main() -> DecoderResult<()> {
2255    /// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
2256    ///     let decoder = DecoderBuilder::default()
2257    ///         .with_config_yaml_str(config_yaml)
2258    ///         .build()?;
2259    ///     assert!(matches!(
2260    ///         decoder.model_type(),
2261    ///         ModelType::ModelPackDetSplit { .. }
2262    ///     ));
2263    /// #    Ok(())
2264    /// # }
2265    /// ```
2266    pub fn model_type(&self) -> &ModelType {
2267        &self.model_type
2268    }
2269
2270    /// Returns the box coordinate format if known from the model config.
2271    ///
2272    /// - `Some(true)`: Boxes are in normalized [0,1] coordinates
2273    /// - `Some(false)`: Boxes are in pixel coordinates relative to model input
2274    /// - `None`: Unknown, caller must infer (e.g., check if any coordinate >
2275    ///   1.0)
2276    ///
2277    /// This is determined by the model config's `normalized` field, not the NMS
2278    /// mode. When coordinates are in pixels or unknown, the caller may need
2279    /// to normalize using the model input dimensions.
2280    ///
2281    /// # Examples
2282    ///
2283    /// ```rust
2284    /// # use edgefirst_decoder::{DecoderBuilder, DecoderResult};
2285    /// # fn main() -> DecoderResult<()> {
2286    /// #    let config_yaml = include_str!("../../../testdata/modelpack_split.yaml").to_string();
2287    ///     let decoder = DecoderBuilder::default()
2288    ///         .with_config_yaml_str(config_yaml)
2289    ///         .build()?;
2290    ///     // Config doesn't specify normalized, so it's None
2291    ///     assert!(decoder.normalized_boxes().is_none());
2292    /// #    Ok(())
2293    /// # }
2294    /// ```
2295    pub fn normalized_boxes(&self) -> Option<bool> {
2296        self.normalized
2297    }
2298
2299    /// This function decodes quantized model outputs into detection boxes and
2300    /// segmentation masks. The quantized outputs can be of u8, i8, u16, i16,
2301    /// u32, or i32 types. Up to `output_boxes.capacity()` boxes and masks
2302    /// will be decoded. The function clears the provided output vectors
2303    /// before populating them with the decoded results.
2304    ///
2305    /// This function returns a `DecoderError` if the the provided outputs don't
2306    /// match the configuration provided by the user when building the decoder.
2307    ///
2308    /// # Examples
2309    ///
2310    /// ```rust
2311    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult};
2312    /// # use ndarray::Array4;
2313    /// # fn main() -> DecoderResult<()> {
2314    /// #    let detect0 = include_bytes!("../../../testdata/modelpack_split_9x15x18.bin");
2315    /// #    let detect0 = ndarray::Array4::from_shape_vec((1, 9, 15, 18), detect0.to_vec())?;
2316    /// #
2317    /// #    let detect1 = include_bytes!("../../../testdata/modelpack_split_17x30x18.bin");
2318    /// #    let detect1 = ndarray::Array4::from_shape_vec((1, 17, 30, 18), detect1.to_vec())?;
2319    /// #    let model_output = vec![
2320    /// #        detect1.view().into_dyn().into(),
2321    /// #        detect0.view().into_dyn().into(),
2322    /// #    ];
2323    /// let decoder = DecoderBuilder::default()
2324    ///     .with_config_yaml_str(include_str!("../../../testdata/modelpack_split.yaml").to_string())
2325    ///     .with_score_threshold(0.45)
2326    ///     .with_iou_threshold(0.45)
2327    ///     .build()?;
2328    ///
2329    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2330    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
2331    /// decoder.decode_quantized(&model_output, &mut output_boxes, &mut output_masks)?;
2332    /// assert!(output_boxes[0].equal_within_delta(
2333    ///     &DetectBox {
2334    ///         bbox: BoundingBox {
2335    ///             xmin: 0.43171933,
2336    ///             ymin: 0.68243736,
2337    ///             xmax: 0.5626645,
2338    ///             ymax: 0.808863,
2339    ///         },
2340    ///         score: 0.99240804,
2341    ///         label: 0
2342    ///     },
2343    ///     1e-6
2344    /// ));
2345    /// #    Ok(())
2346    /// # }
2347    /// ```
2348    pub fn decode_quantized(
2349        &self,
2350        outputs: &[ArrayViewDQuantized],
2351        output_boxes: &mut Vec<DetectBox>,
2352        output_masks: &mut Vec<Segmentation>,
2353    ) -> Result<(), DecoderError> {
2354        output_boxes.clear();
2355        output_masks.clear();
2356        match &self.model_type {
2357            ModelType::ModelPackSegDet {
2358                boxes,
2359                scores,
2360                segmentation,
2361            } => {
2362                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)?;
2363                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2364            }
2365            ModelType::ModelPackSegDetSplit {
2366                detection,
2367                segmentation,
2368            } => {
2369                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)?;
2370                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2371            }
2372            ModelType::ModelPackDet { boxes, scores } => {
2373                self.decode_modelpack_det_quantized(outputs, boxes, scores, output_boxes)
2374            }
2375            ModelType::ModelPackDetSplit { detection } => {
2376                self.decode_modelpack_det_split_quantized(outputs, detection, output_boxes)
2377            }
2378            ModelType::ModelPackSeg { segmentation } => {
2379                self.decode_modelpack_seg_quantized(outputs, segmentation, output_masks)
2380            }
2381            ModelType::YoloDet { boxes } => {
2382                self.decode_yolo_det_quantized(outputs, boxes, output_boxes)
2383            }
2384            ModelType::YoloSegDet { boxes, protos } => self.decode_yolo_segdet_quantized(
2385                outputs,
2386                boxes,
2387                protos,
2388                output_boxes,
2389                output_masks,
2390            ),
2391            ModelType::YoloSplitDet { boxes, scores } => {
2392                self.decode_yolo_split_det_quantized(outputs, boxes, scores, output_boxes)
2393            }
2394            ModelType::YoloSplitSegDet {
2395                boxes,
2396                scores,
2397                mask_coeff,
2398                protos,
2399            } => self.decode_yolo_split_segdet_quantized(
2400                outputs,
2401                boxes,
2402                scores,
2403                mask_coeff,
2404                protos,
2405                output_boxes,
2406                output_masks,
2407            ),
2408            ModelType::YoloEndToEndDet { .. } | ModelType::YoloEndToEndSegDet { .. } => {
2409                Err(DecoderError::InvalidConfig(
2410                    "End-to-end models require float decode, not quantized".to_string(),
2411                ))
2412            }
2413        }
2414    }
2415
2416    /// This function decodes floating point model outputs into detection boxes
2417    /// and segmentation masks. Up to `output_boxes.capacity()` boxes and
2418    /// masks will be decoded. The function clears the provided output
2419    /// vectors before populating them with the decoded results.
2420    ///
2421    /// This function returns an `Error` if the the provided outputs don't
2422    /// match the configuration provided by the user when building the decoder.
2423    ///
2424    /// Any quantization information in the configuration will be ignored.
2425    ///
2426    /// # Examples
2427    ///
2428    /// ```rust
2429    /// # use edgefirst_decoder::{BoundingBox, DecoderBuilder, DetectBox, DecoderResult, configs, configs::{DecoderType, DecoderVersion}, dequantize_cpu, Quantization};
2430    /// # use ndarray::Array3;
2431    /// # fn main() -> DecoderResult<()> {
2432    /// #   let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
2433    /// #   let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
2434    /// #   let mut out_dequant = vec![0.0_f64; 84 * 8400];
2435    /// #   let quant = Quantization::new(0.0040811873, -123);
2436    /// #   dequantize_cpu(out, quant, &mut out_dequant);
2437    /// #   let model_output_f64 = Array3::from_shape_vec((1, 84, 8400), out_dequant)?.into_dyn();
2438    ///    let decoder = DecoderBuilder::default()
2439    ///     .with_config_yolo_det(configs::Detection {
2440    ///         decoder: DecoderType::Ultralytics,
2441    ///         quantization: None,
2442    ///         shape: vec![1, 84, 8400],
2443    ///         anchors: None,
2444    ///         dshape: Vec::new(),
2445    ///         normalized: Some(true),
2446    ///     },
2447    ///     Some(DecoderVersion::Yolo11))
2448    ///     .with_score_threshold(0.25)
2449    ///     .with_iou_threshold(0.7)
2450    ///     .build()?;
2451    ///
2452    /// let mut output_boxes: Vec<_> = Vec::with_capacity(10);
2453    /// let mut output_masks: Vec<_> = Vec::with_capacity(10);
2454    /// let model_output_f64 = vec![model_output_f64.view().into()];
2455    /// decoder.decode_float(&model_output_f64, &mut output_boxes, &mut output_masks)?;    
2456    /// assert!(output_boxes[0].equal_within_delta(
2457    ///        &DetectBox {
2458    ///            bbox: BoundingBox {
2459    ///                xmin: 0.5285137,
2460    ///                ymin: 0.05305544,
2461    ///                xmax: 0.87541467,
2462    ///                ymax: 0.9998909,
2463    ///            },
2464    ///            score: 0.5591227,
2465    ///            label: 0
2466    ///        },
2467    ///        1e-6
2468    ///    ));
2469    ///
2470    /// #    Ok(())
2471    /// # }
2472    pub fn decode_float<T>(
2473        &self,
2474        outputs: &[ArrayViewD<T>],
2475        output_boxes: &mut Vec<DetectBox>,
2476        output_masks: &mut Vec<Segmentation>,
2477    ) -> Result<(), DecoderError>
2478    where
2479        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2480        f32: AsPrimitive<T>,
2481    {
2482        output_boxes.clear();
2483        output_masks.clear();
2484        match &self.model_type {
2485            ModelType::ModelPackSegDet {
2486                boxes,
2487                scores,
2488                segmentation,
2489            } => {
2490                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2491                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2492            }
2493            ModelType::ModelPackSegDetSplit {
2494                detection,
2495                segmentation,
2496            } => {
2497                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2498                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2499            }
2500            ModelType::ModelPackDet { boxes, scores } => {
2501                self.decode_modelpack_det_float(outputs, boxes, scores, output_boxes)?;
2502            }
2503            ModelType::ModelPackDetSplit { detection } => {
2504                self.decode_modelpack_det_split_float(outputs, detection, output_boxes)?;
2505            }
2506            ModelType::ModelPackSeg { segmentation } => {
2507                self.decode_modelpack_seg_float(outputs, segmentation, output_masks)?;
2508            }
2509            ModelType::YoloDet { boxes } => {
2510                self.decode_yolo_det_float(outputs, boxes, output_boxes)?;
2511            }
2512            ModelType::YoloSegDet { boxes, protos } => {
2513                self.decode_yolo_segdet_float(outputs, boxes, protos, output_boxes, output_masks)?;
2514            }
2515            ModelType::YoloSplitDet { boxes, scores } => {
2516                self.decode_yolo_split_det_float(outputs, boxes, scores, output_boxes)?;
2517            }
2518            ModelType::YoloSplitSegDet {
2519                boxes,
2520                scores,
2521                mask_coeff,
2522                protos,
2523            } => {
2524                self.decode_yolo_split_segdet_float(
2525                    outputs,
2526                    boxes,
2527                    scores,
2528                    mask_coeff,
2529                    protos,
2530                    output_boxes,
2531                    output_masks,
2532                )?;
2533            }
2534            ModelType::YoloEndToEndDet { boxes } => {
2535                self.decode_yolo_end_to_end_det_float(outputs, boxes, output_boxes)?;
2536            }
2537            ModelType::YoloEndToEndSegDet { boxes, protos } => {
2538                self.decode_yolo_end_to_end_segdet_float(
2539                    outputs,
2540                    boxes,
2541                    protos,
2542                    output_boxes,
2543                    output_masks,
2544                )?;
2545            }
2546        }
2547        Ok(())
2548    }
2549
2550    fn decode_modelpack_det_quantized(
2551        &self,
2552        outputs: &[ArrayViewDQuantized],
2553        boxes: &configs::Boxes,
2554        scores: &configs::Scores,
2555        output_boxes: &mut Vec<DetectBox>,
2556    ) -> Result<(), DecoderError> {
2557        let (boxes_tensor, ind) =
2558            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2559        let (scores_tensor, _) =
2560            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2561        let quant_boxes = boxes
2562            .quantization
2563            .map(Quantization::from)
2564            .unwrap_or_default();
2565        let quant_scores = scores
2566            .quantization
2567            .map(Quantization::from)
2568            .unwrap_or_default();
2569
2570        with_quantized!(boxes_tensor, b, {
2571            with_quantized!(scores_tensor, s, {
2572                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2573                let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2574
2575                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2576                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2577                decode_modelpack_det(
2578                    (boxes_tensor, quant_boxes),
2579                    (scores_tensor, quant_scores),
2580                    self.score_threshold,
2581                    self.iou_threshold,
2582                    output_boxes,
2583                );
2584            });
2585        });
2586
2587        Ok(())
2588    }
2589
2590    fn decode_modelpack_seg_quantized(
2591        &self,
2592        outputs: &[ArrayViewDQuantized],
2593        segmentation: &configs::Segmentation,
2594        output_masks: &mut Vec<Segmentation>,
2595    ) -> Result<(), DecoderError> {
2596        let (seg, _) = Self::find_outputs_with_shape_quantized(&segmentation.shape, outputs, &[])?;
2597
2598        macro_rules! modelpack_seg {
2599            ($seg:expr, $body:expr) => {{
2600                let seg = Self::swap_axes_if_needed($seg, segmentation.into());
2601                let seg = seg.slice(s![0, .., .., ..]);
2602                seg.mapv($body)
2603            }};
2604        }
2605        use ArrayViewDQuantized::*;
2606        let seg = match seg {
2607            UInt8(s) => {
2608                modelpack_seg!(s, |x| x)
2609            }
2610            Int8(s) => {
2611                modelpack_seg!(s, |x| (x as i16 + 128) as u8)
2612            }
2613            UInt16(s) => {
2614                modelpack_seg!(s, |x| (x >> 8) as u8)
2615            }
2616            Int16(s) => {
2617                modelpack_seg!(s, |x| ((x as i32 + 32768) >> 8) as u8)
2618            }
2619            UInt32(s) => {
2620                modelpack_seg!(s, |x| (x >> 24) as u8)
2621            }
2622            Int32(s) => {
2623                modelpack_seg!(s, |x| ((x as i64 + 2147483648) >> 24) as u8)
2624            }
2625        };
2626
2627        output_masks.push(Segmentation {
2628            xmin: 0.0,
2629            ymin: 0.0,
2630            xmax: 1.0,
2631            ymax: 1.0,
2632            segmentation: seg,
2633        });
2634        Ok(())
2635    }
2636
2637    fn decode_modelpack_det_split_quantized(
2638        &self,
2639        outputs: &[ArrayViewDQuantized],
2640        detection: &[configs::Detection],
2641        output_boxes: &mut Vec<DetectBox>,
2642    ) -> Result<(), DecoderError> {
2643        let new_detection = detection
2644            .iter()
2645            .map(|x| match &x.anchors {
2646                None => Err(DecoderError::InvalidConfig(
2647                    "ModelPack Split Detection missing anchors".to_string(),
2648                )),
2649                Some(a) => Ok(ModelPackDetectionConfig {
2650                    anchors: a.clone(),
2651                    quantization: None,
2652                }),
2653            })
2654            .collect::<Result<Vec<_>, _>>()?;
2655        let new_outputs = Self::match_outputs_to_detect_quantized(detection, outputs)?;
2656
2657        macro_rules! dequant_output {
2658            ($det_tensor:expr, $detection:expr) => {{
2659                let det_tensor = Self::swap_axes_if_needed($det_tensor, $detection.into());
2660                let det_tensor = det_tensor.slice(s![0, .., .., ..]);
2661                if let Some(q) = $detection.quantization {
2662                    dequantize_ndarray(det_tensor, q.into())
2663                } else {
2664                    det_tensor.map(|x| *x as f32)
2665                }
2666            }};
2667        }
2668
2669        let new_outputs = new_outputs
2670            .iter()
2671            .zip(detection)
2672            .map(|(det_tensor, detection)| {
2673                with_quantized!(det_tensor, d, dequant_output!(d, detection))
2674            })
2675            .collect::<Vec<_>>();
2676
2677        let new_outputs_view = new_outputs
2678            .iter()
2679            .map(|d: &Array3<f32>| d.view())
2680            .collect::<Vec<_>>();
2681        decode_modelpack_split_float(
2682            &new_outputs_view,
2683            &new_detection,
2684            self.score_threshold,
2685            self.iou_threshold,
2686            output_boxes,
2687        );
2688        Ok(())
2689    }
2690
2691    fn decode_yolo_det_quantized(
2692        &self,
2693        outputs: &[ArrayViewDQuantized],
2694        boxes: &configs::Detection,
2695        output_boxes: &mut Vec<DetectBox>,
2696    ) -> Result<(), DecoderError> {
2697        let (boxes_tensor, _) =
2698            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2699        let quant_boxes = boxes
2700            .quantization
2701            .map(Quantization::from)
2702            .unwrap_or_default();
2703
2704        with_quantized!(boxes_tensor, b, {
2705            let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2706            let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2707            decode_yolo_det(
2708                (boxes_tensor, quant_boxes),
2709                self.score_threshold,
2710                self.iou_threshold,
2711                self.nms,
2712                output_boxes,
2713            );
2714        });
2715
2716        Ok(())
2717    }
2718
2719    fn decode_yolo_segdet_quantized(
2720        &self,
2721        outputs: &[ArrayViewDQuantized],
2722        boxes: &configs::Detection,
2723        protos: &configs::Protos,
2724        output_boxes: &mut Vec<DetectBox>,
2725        output_masks: &mut Vec<Segmentation>,
2726    ) -> Result<(), DecoderError> {
2727        let (boxes_tensor, ind) =
2728            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2729        let (protos_tensor, _) =
2730            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &[ind])?;
2731
2732        let quant_boxes = boxes
2733            .quantization
2734            .map(Quantization::from)
2735            .unwrap_or_default();
2736        let quant_protos = protos
2737            .quantization
2738            .map(Quantization::from)
2739            .unwrap_or_default();
2740
2741        with_quantized!(boxes_tensor, b, {
2742            with_quantized!(protos_tensor, p, {
2743                let box_tensor = Self::swap_axes_if_needed(b, boxes.into());
2744                let box_tensor = box_tensor.slice(s![0, .., ..]);
2745
2746                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2747                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2748                decode_yolo_segdet_quant(
2749                    (box_tensor, quant_boxes),
2750                    (protos_tensor, quant_protos),
2751                    self.score_threshold,
2752                    self.iou_threshold,
2753                    self.nms,
2754                    output_boxes,
2755                    output_masks,
2756                );
2757            });
2758        });
2759
2760        Ok(())
2761    }
2762
2763    fn decode_yolo_split_det_quantized(
2764        &self,
2765        outputs: &[ArrayViewDQuantized],
2766        boxes: &configs::Boxes,
2767        scores: &configs::Scores,
2768        output_boxes: &mut Vec<DetectBox>,
2769    ) -> Result<(), DecoderError> {
2770        let (boxes_tensor, ind) =
2771            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &[])?;
2772        let (scores_tensor, _) =
2773            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &[ind])?;
2774        let quant_boxes = boxes
2775            .quantization
2776            .map(Quantization::from)
2777            .unwrap_or_default();
2778        let quant_scores = scores
2779            .quantization
2780            .map(Quantization::from)
2781            .unwrap_or_default();
2782
2783        with_quantized!(boxes_tensor, b, {
2784            with_quantized!(scores_tensor, s, {
2785                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2786                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2787
2788                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2789                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2790                decode_yolo_split_det_quant(
2791                    (boxes_tensor, quant_boxes),
2792                    (scores_tensor, quant_scores),
2793                    self.score_threshold,
2794                    self.iou_threshold,
2795                    self.nms,
2796                    output_boxes,
2797                );
2798            });
2799        });
2800
2801        Ok(())
2802    }
2803
2804    #[allow(clippy::too_many_arguments)]
2805    fn decode_yolo_split_segdet_quantized(
2806        &self,
2807        outputs: &[ArrayViewDQuantized],
2808        boxes: &configs::Boxes,
2809        scores: &configs::Scores,
2810        mask_coeff: &configs::MaskCoefficients,
2811        protos: &configs::Protos,
2812        output_boxes: &mut Vec<DetectBox>,
2813        output_masks: &mut Vec<Segmentation>,
2814    ) -> Result<(), DecoderError> {
2815        let quant_boxes = boxes
2816            .quantization
2817            .map(Quantization::from)
2818            .unwrap_or_default();
2819        let quant_scores = scores
2820            .quantization
2821            .map(Quantization::from)
2822            .unwrap_or_default();
2823        let quant_masks = mask_coeff
2824            .quantization
2825            .map(Quantization::from)
2826            .unwrap_or_default();
2827        let quant_protos = protos
2828            .quantization
2829            .map(Quantization::from)
2830            .unwrap_or_default();
2831
2832        let mut skip = vec![];
2833
2834        let (boxes_tensor, ind) =
2835            Self::find_outputs_with_shape_quantized(&boxes.shape, outputs, &skip)?;
2836        skip.push(ind);
2837
2838        let (scores_tensor, ind) =
2839            Self::find_outputs_with_shape_quantized(&scores.shape, outputs, &skip)?;
2840        skip.push(ind);
2841
2842        let (mask_tensor, ind) =
2843            Self::find_outputs_with_shape_quantized(&mask_coeff.shape, outputs, &skip)?;
2844        skip.push(ind);
2845
2846        let (protos_tensor, _) =
2847            Self::find_outputs_with_shape_quantized(&protos.shape, outputs, &skip)?;
2848
2849        let boxes = with_quantized!(boxes_tensor, b, {
2850            with_quantized!(scores_tensor, s, {
2851                let boxes_tensor = Self::swap_axes_if_needed(b, boxes.into());
2852                let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2853
2854                let scores_tensor = Self::swap_axes_if_needed(s, scores.into());
2855                let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2856                impl_yolo_split_segdet_quant_get_boxes::<XYWH, _, _>(
2857                    (boxes_tensor, quant_boxes),
2858                    (scores_tensor, quant_scores),
2859                    self.score_threshold,
2860                    self.iou_threshold,
2861                    self.nms,
2862                    output_boxes.capacity(),
2863                )
2864            })
2865        });
2866
2867        with_quantized!(mask_tensor, m, {
2868            with_quantized!(protos_tensor, p, {
2869                let mask_tensor = Self::swap_axes_if_needed(m, mask_coeff.into());
2870                let mask_tensor = mask_tensor.slice(s![0, .., ..]);
2871
2872                let protos_tensor = Self::swap_axes_if_needed(p, protos.into());
2873                let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
2874                impl_yolo_split_segdet_quant_process_masks::<_, _>(
2875                    boxes,
2876                    (mask_tensor, quant_masks),
2877                    (protos_tensor, quant_protos),
2878                    output_boxes,
2879                    output_masks,
2880                )
2881            })
2882        });
2883
2884        Ok(())
2885    }
2886
2887    fn decode_modelpack_det_split_float<D>(
2888        &self,
2889        outputs: &[ArrayViewD<D>],
2890        detection: &[configs::Detection],
2891        output_boxes: &mut Vec<DetectBox>,
2892    ) -> Result<(), DecoderError>
2893    where
2894        D: AsPrimitive<f32>,
2895    {
2896        let new_detection = detection
2897            .iter()
2898            .map(|x| match &x.anchors {
2899                None => Err(DecoderError::InvalidConfig(
2900                    "ModelPack Split Detection missing anchors".to_string(),
2901                )),
2902                Some(a) => Ok(ModelPackDetectionConfig {
2903                    anchors: a.clone(),
2904                    quantization: None,
2905                }),
2906            })
2907            .collect::<Result<Vec<_>, _>>()?;
2908
2909        let new_outputs = Self::match_outputs_to_detect(detection, outputs)?;
2910        let new_outputs = new_outputs
2911            .into_iter()
2912            .map(|x| x.slice(s![0, .., .., ..]))
2913            .collect::<Vec<_>>();
2914
2915        decode_modelpack_split_float(
2916            &new_outputs,
2917            &new_detection,
2918            self.score_threshold,
2919            self.iou_threshold,
2920            output_boxes,
2921        );
2922        Ok(())
2923    }
2924
2925    fn decode_modelpack_seg_float<T>(
2926        &self,
2927        outputs: &[ArrayViewD<T>],
2928        segmentation: &configs::Segmentation,
2929        output_masks: &mut Vec<Segmentation>,
2930    ) -> Result<(), DecoderError>
2931    where
2932        T: Float + AsPrimitive<f32> + AsPrimitive<u8> + Send + Sync + 'static,
2933        f32: AsPrimitive<T>,
2934    {
2935        let (seg, _) = Self::find_outputs_with_shape(&segmentation.shape, outputs, &[])?;
2936
2937        let seg = Self::swap_axes_if_needed(seg, segmentation.into());
2938        let seg = seg.slice(s![0, .., .., ..]);
2939        let u8_max = 255.0_f32.as_();
2940        let max = *seg.max().unwrap_or(&u8_max);
2941        let min = *seg.min().unwrap_or(&0.0_f32.as_());
2942        let seg = seg.mapv(|x| ((x - min) / (max - min) * u8_max).as_());
2943        output_masks.push(Segmentation {
2944            xmin: 0.0,
2945            ymin: 0.0,
2946            xmax: 1.0,
2947            ymax: 1.0,
2948            segmentation: seg,
2949        });
2950        Ok(())
2951    }
2952
2953    fn decode_modelpack_det_float<T>(
2954        &self,
2955        outputs: &[ArrayViewD<T>],
2956        boxes: &configs::Boxes,
2957        scores: &configs::Scores,
2958        output_boxes: &mut Vec<DetectBox>,
2959    ) -> Result<(), DecoderError>
2960    where
2961        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2962        f32: AsPrimitive<T>,
2963    {
2964        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2965
2966        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2967        let boxes_tensor = boxes_tensor.slice(s![0, .., 0, ..]);
2968
2969        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
2970        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
2971        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
2972
2973        decode_modelpack_float(
2974            boxes_tensor,
2975            scores_tensor,
2976            self.score_threshold,
2977            self.iou_threshold,
2978            output_boxes,
2979        );
2980        Ok(())
2981    }
2982
2983    fn decode_yolo_det_float<T>(
2984        &self,
2985        outputs: &[ArrayViewD<T>],
2986        boxes: &configs::Detection,
2987        output_boxes: &mut Vec<DetectBox>,
2988    ) -> Result<(), DecoderError>
2989    where
2990        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
2991        f32: AsPrimitive<T>,
2992    {
2993        let (boxes_tensor, _) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
2994
2995        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
2996        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
2997        decode_yolo_det_float(
2998            boxes_tensor,
2999            self.score_threshold,
3000            self.iou_threshold,
3001            self.nms,
3002            output_boxes,
3003        );
3004        Ok(())
3005    }
3006
3007    fn decode_yolo_segdet_float<T>(
3008        &self,
3009        outputs: &[ArrayViewD<T>],
3010        boxes: &configs::Detection,
3011        protos: &configs::Protos,
3012        output_boxes: &mut Vec<DetectBox>,
3013        output_masks: &mut Vec<Segmentation>,
3014    ) -> Result<(), DecoderError>
3015    where
3016        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3017        f32: AsPrimitive<T>,
3018    {
3019        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3020
3021        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3022        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3023
3024        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &[ind])?;
3025
3026        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3027        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3028        decode_yolo_segdet_float(
3029            boxes_tensor,
3030            protos_tensor,
3031            self.score_threshold,
3032            self.iou_threshold,
3033            self.nms,
3034            output_boxes,
3035            output_masks,
3036        );
3037        Ok(())
3038    }
3039
3040    fn decode_yolo_split_det_float<T>(
3041        &self,
3042        outputs: &[ArrayViewD<T>],
3043        boxes: &configs::Boxes,
3044        scores: &configs::Scores,
3045        output_boxes: &mut Vec<DetectBox>,
3046    ) -> Result<(), DecoderError>
3047    where
3048        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3049        f32: AsPrimitive<T>,
3050    {
3051        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &[])?;
3052        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3053        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3054
3055        let (scores_tensor, _) = Self::find_outputs_with_shape(&scores.shape, outputs, &[ind])?;
3056
3057        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3058        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3059
3060        decode_yolo_split_det_float(
3061            boxes_tensor,
3062            scores_tensor,
3063            self.score_threshold,
3064            self.iou_threshold,
3065            self.nms,
3066            output_boxes,
3067        );
3068        Ok(())
3069    }
3070
3071    #[allow(clippy::too_many_arguments)]
3072    fn decode_yolo_split_segdet_float<T>(
3073        &self,
3074        outputs: &[ArrayViewD<T>],
3075        boxes: &configs::Boxes,
3076        scores: &configs::Scores,
3077        mask_coeff: &configs::MaskCoefficients,
3078        protos: &configs::Protos,
3079        output_boxes: &mut Vec<DetectBox>,
3080        output_masks: &mut Vec<Segmentation>,
3081    ) -> Result<(), DecoderError>
3082    where
3083        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3084        f32: AsPrimitive<T>,
3085    {
3086        let mut skip = vec![];
3087        let (boxes_tensor, ind) = Self::find_outputs_with_shape(&boxes.shape, outputs, &skip)?;
3088
3089        let boxes_tensor = Self::swap_axes_if_needed(boxes_tensor, boxes.into());
3090        let boxes_tensor = boxes_tensor.slice(s![0, .., ..]);
3091        skip.push(ind);
3092
3093        let (scores_tensor, ind) = Self::find_outputs_with_shape(&scores.shape, outputs, &skip)?;
3094
3095        let scores_tensor = Self::swap_axes_if_needed(scores_tensor, scores.into());
3096        let scores_tensor = scores_tensor.slice(s![0, .., ..]);
3097        skip.push(ind);
3098
3099        let (mask_tensor, ind) = Self::find_outputs_with_shape(&mask_coeff.shape, outputs, &skip)?;
3100        let mask_tensor = Self::swap_axes_if_needed(mask_tensor, mask_coeff.into());
3101        let mask_tensor = mask_tensor.slice(s![0, .., ..]);
3102        skip.push(ind);
3103
3104        let (protos_tensor, _) = Self::find_outputs_with_shape(&protos.shape, outputs, &skip)?;
3105        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos.into());
3106        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3107        decode_yolo_split_segdet_float(
3108            boxes_tensor,
3109            scores_tensor,
3110            mask_tensor,
3111            protos_tensor,
3112            self.score_threshold,
3113            self.iou_threshold,
3114            self.nms,
3115            output_boxes,
3116            output_masks,
3117        );
3118        Ok(())
3119    }
3120
3121    /// Decodes end-to-end YOLO detection outputs (post-NMS from model).
3122    ///
3123    /// Input shape: (1, N, 6+) where columns are [x1, y1, x2, y2, conf, class,
3124    /// ...] Boxes are output directly from model (may be normalized or
3125    /// pixel coords depending on config).
3126    fn decode_yolo_end_to_end_det_float<T>(
3127        &self,
3128        outputs: &[ArrayViewD<T>],
3129        boxes_config: &configs::Detection,
3130        output_boxes: &mut Vec<DetectBox>,
3131    ) -> Result<(), DecoderError>
3132    where
3133        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3134        f32: AsPrimitive<T>,
3135    {
3136        let (det_tensor, _) = Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3137        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3138        let det_tensor = det_tensor.slice(s![0, .., ..]);
3139
3140        crate::yolo::decode_yolo_end_to_end_det_float(
3141            det_tensor,
3142            self.score_threshold,
3143            output_boxes,
3144        )?;
3145        Ok(())
3146    }
3147
3148    /// Decodes end-to-end YOLO detection + segmentation outputs (post-NMS from
3149    /// model).
3150    ///
3151    /// Input shapes:
3152    /// - detection: (1, N, 6 + num_protos) where columns are [x1, y1, x2, y2,
3153    ///   conf, class, mask_coeff_0, ..., mask_coeff_31]
3154    /// - protos: (1, proto_height, proto_width, num_protos)
3155    fn decode_yolo_end_to_end_segdet_float<T>(
3156        &self,
3157        outputs: &[ArrayViewD<T>],
3158        boxes_config: &configs::Detection,
3159        protos_config: &configs::Protos,
3160        output_boxes: &mut Vec<DetectBox>,
3161        output_masks: &mut Vec<Segmentation>,
3162    ) -> Result<(), DecoderError>
3163    where
3164        T: Float + AsPrimitive<f32> + Send + Sync + 'static,
3165        f32: AsPrimitive<T>,
3166    {
3167        if outputs.len() < 2 {
3168            return Err(DecoderError::InvalidShape(
3169                "End-to-end segdet requires detection and protos outputs".to_string(),
3170            ));
3171        }
3172
3173        let (det_tensor, det_ind) =
3174            Self::find_outputs_with_shape(&boxes_config.shape, outputs, &[])?;
3175        let det_tensor = Self::swap_axes_if_needed(det_tensor, boxes_config.into());
3176        let det_tensor = det_tensor.slice(s![0, .., ..]);
3177
3178        let (protos_tensor, _) =
3179            Self::find_outputs_with_shape(&protos_config.shape, outputs, &[det_ind])?;
3180        let protos_tensor = Self::swap_axes_if_needed(protos_tensor, protos_config.into());
3181        let protos_tensor = protos_tensor.slice(s![0, .., .., ..]);
3182
3183        crate::yolo::decode_yolo_end_to_end_segdet_float(
3184            det_tensor,
3185            protos_tensor,
3186            self.score_threshold,
3187            output_boxes,
3188            output_masks,
3189        )?;
3190        Ok(())
3191    }
3192
3193    fn match_outputs_to_detect<'a, 'b, T>(
3194        configs: &[configs::Detection],
3195        outputs: &'a [ArrayViewD<'b, T>],
3196    ) -> Result<Vec<&'a ArrayViewD<'b, T>>, DecoderError> {
3197        let mut new_output_order = Vec::new();
3198        for c in configs {
3199            let mut found = false;
3200            for o in outputs {
3201                if o.shape() == c.shape {
3202                    new_output_order.push(o);
3203                    found = true;
3204                    break;
3205                }
3206            }
3207            if !found {
3208                return Err(DecoderError::InvalidShape(format!(
3209                    "Did not find output with shape {:?}",
3210                    c.shape
3211                )));
3212            }
3213        }
3214        Ok(new_output_order)
3215    }
3216
3217    fn find_outputs_with_shape<'a, 'b, T>(
3218        shape: &[usize],
3219        outputs: &'a [ArrayViewD<'b, T>],
3220        skip: &[usize],
3221    ) -> Result<(&'a ArrayViewD<'b, T>, usize), DecoderError> {
3222        for (ind, o) in outputs.iter().enumerate() {
3223            if skip.contains(&ind) {
3224                continue;
3225            }
3226            if o.shape() == shape {
3227                return Ok((o, ind));
3228            }
3229        }
3230        Err(DecoderError::InvalidShape(format!(
3231            "Did not find output with shape {:?}",
3232            shape
3233        )))
3234    }
3235
3236    fn find_outputs_with_shape_quantized<'a, 'b>(
3237        shape: &[usize],
3238        outputs: &'a [ArrayViewDQuantized<'b>],
3239        skip: &[usize],
3240    ) -> Result<(&'a ArrayViewDQuantized<'b>, usize), DecoderError> {
3241        for (ind, o) in outputs.iter().enumerate() {
3242            if skip.contains(&ind) {
3243                continue;
3244            }
3245            if o.shape() == shape {
3246                return Ok((o, ind));
3247            }
3248        }
3249        Err(DecoderError::InvalidShape(format!(
3250            "Did not find output with shape {:?}",
3251            shape
3252        )))
3253    }
3254
3255    /// This is split detection, need to swap axes to batch, height, width,
3256    /// num_anchors_x_features,
3257    fn modelpack_det_order(x: DimName) -> usize {
3258        match x {
3259            DimName::Batch => 0,
3260            DimName::NumBoxes => 1,
3261            DimName::Padding => 2,
3262            DimName::BoxCoords => 3,
3263            _ => 1000, // this should be unreachable
3264        }
3265    }
3266
3267    // This is Ultralytics detection, need to swap axes to batch, num_features,
3268    // height, width
3269    fn yolo_det_order(x: DimName) -> usize {
3270        match x {
3271            DimName::Batch => 0,
3272            DimName::NumFeatures => 1,
3273            DimName::NumBoxes => 2,
3274            _ => 1000, // this should be unreachable
3275        }
3276    }
3277
3278    // This is modelpack boxes, need to swap axes to batch, num_boxes, padding,
3279    // box_coords
3280    fn modelpack_boxes_order(x: DimName) -> usize {
3281        match x {
3282            DimName::Batch => 0,
3283            DimName::NumBoxes => 1,
3284            DimName::Padding => 2,
3285            DimName::BoxCoords => 3,
3286            _ => 1000, // this should be unreachable
3287        }
3288    }
3289
3290    /// This is Ultralytics boxes, need to swap axes to batch, box_coords,
3291    /// num_boxes
3292    fn yolo_boxes_order(x: DimName) -> usize {
3293        match x {
3294            DimName::Batch => 0,
3295            DimName::BoxCoords => 1,
3296            DimName::NumBoxes => 2,
3297            _ => 1000, // this should be unreachable
3298        }
3299    }
3300
3301    /// This is modelpack scores, need to swap axes to batch, num_boxes,
3302    /// num_classes
3303    fn modelpack_scores_order(x: DimName) -> usize {
3304        match x {
3305            DimName::Batch => 0,
3306            DimName::NumBoxes => 1,
3307            DimName::NumClasses => 2,
3308            _ => 1000, // this should be unreachable
3309        }
3310    }
3311
3312    fn yolo_scores_order(x: DimName) -> usize {
3313        match x {
3314            DimName::Batch => 0,
3315            DimName::NumClasses => 1,
3316            DimName::NumBoxes => 2,
3317            _ => 1000, // this should be unreachable
3318        }
3319    }
3320
3321    /// This is modelpack segmentation, need to swap axes to batch, height,
3322    /// width, num_classes
3323    fn modelpack_segmentation_order(x: DimName) -> usize {
3324        match x {
3325            DimName::Batch => 0,
3326            DimName::Height => 1,
3327            DimName::Width => 2,
3328            DimName::NumClasses => 3,
3329            _ => 1000, // this should be unreachable
3330        }
3331    }
3332
3333    /// This is modelpack masks, need to swap axes to batch, height,
3334    /// width
3335    fn modelpack_mask_order(x: DimName) -> usize {
3336        match x {
3337            DimName::Batch => 0,
3338            DimName::Height => 1,
3339            DimName::Width => 2,
3340            _ => 1000, // this should be unreachable
3341        }
3342    }
3343
3344    /// This is yolo protos, need to swap axes to batch, height, width,
3345    /// num_protos
3346    fn yolo_protos_order(x: DimName) -> usize {
3347        match x {
3348            DimName::Batch => 0,
3349            DimName::Height => 1,
3350            DimName::Width => 2,
3351            DimName::NumProtos => 3,
3352            _ => 1000, // this should be unreachable
3353        }
3354    }
3355
3356    /// This is yolo mask coefficients, need to swap axes to batch, num_protos,
3357    /// num_boxes
3358    fn yolo_maskcoefficients_order(x: DimName) -> usize {
3359        match x {
3360            DimName::Batch => 0,
3361            DimName::NumProtos => 1,
3362            DimName::NumBoxes => 2,
3363            _ => 1000, // this should be unreachable
3364        }
3365    }
3366
3367    fn get_order_fn(config: ConfigOutputRef) -> fn(DimName) -> usize {
3368        let decoder_type = config.decoder();
3369        match (config, decoder_type) {
3370            (ConfigOutputRef::Detection(_), DecoderType::ModelPack) => Self::modelpack_det_order,
3371            (ConfigOutputRef::Detection(_), DecoderType::Ultralytics) => Self::yolo_det_order,
3372            (ConfigOutputRef::Boxes(_), DecoderType::ModelPack) => Self::modelpack_boxes_order,
3373            (ConfigOutputRef::Boxes(_), DecoderType::Ultralytics) => Self::yolo_boxes_order,
3374            (ConfigOutputRef::Scores(_), DecoderType::ModelPack) => Self::modelpack_scores_order,
3375            (ConfigOutputRef::Scores(_), DecoderType::Ultralytics) => Self::yolo_scores_order,
3376            (ConfigOutputRef::Segmentation(_), _) => Self::modelpack_segmentation_order,
3377            (ConfigOutputRef::Mask(_), _) => Self::modelpack_mask_order,
3378            (ConfigOutputRef::Protos(_), _) => Self::yolo_protos_order,
3379            (ConfigOutputRef::MaskCoefficients(_), _) => Self::yolo_maskcoefficients_order,
3380        }
3381    }
3382
3383    fn swap_axes_if_needed<'a, T, D: Dimension>(
3384        array: &ArrayView<'a, T, D>,
3385        config: ConfigOutputRef,
3386    ) -> ArrayView<'a, T, D> {
3387        let mut array = array.clone();
3388        if config.dshape().is_empty() {
3389            return array;
3390        }
3391        let order_fn: fn(DimName) -> usize = Self::get_order_fn(config.clone());
3392        let mut current_order: Vec<usize> = config
3393            .dshape()
3394            .iter()
3395            .map(|x| order_fn(x.0))
3396            .collect::<Vec<_>>();
3397
3398        assert_eq!(array.shape().len(), current_order.len());
3399        // do simple bubble sort as swap_axes is inexpensive and the
3400        // number of dimensions is small
3401        for i in 0..current_order.len() {
3402            let mut swapped = false;
3403            for j in 0..current_order.len() - 1 - i {
3404                if current_order[j] > current_order[j + 1] {
3405                    array.swap_axes(j, j + 1);
3406                    current_order.swap(j, j + 1);
3407                    swapped = true;
3408                }
3409            }
3410            if !swapped {
3411                break;
3412            }
3413        }
3414        array
3415    }
3416
3417    fn match_outputs_to_detect_quantized<'a, 'b>(
3418        configs: &[configs::Detection],
3419        outputs: &'a [ArrayViewDQuantized<'b>],
3420    ) -> Result<Vec<&'a ArrayViewDQuantized<'b>>, DecoderError> {
3421        let mut new_output_order = Vec::new();
3422        for c in configs {
3423            let mut found = false;
3424            for o in outputs {
3425                if o.shape() == c.shape {
3426                    new_output_order.push(o);
3427                    found = true;
3428                    break;
3429                }
3430            }
3431            if !found {
3432                return Err(DecoderError::InvalidShape(format!(
3433                    "Did not find output with shape {:?}",
3434                    c.shape
3435                )));
3436            }
3437        }
3438        Ok(new_output_order)
3439    }
3440}
3441
3442#[cfg(test)]
3443#[cfg_attr(coverage_nightly, coverage(off))]
3444mod decoder_builder_tests {
3445    use super::*;
3446
3447    #[test]
3448    fn test_decoder_builder_no_config() {
3449        use crate::DecoderBuilder;
3450        let result = DecoderBuilder::default().build();
3451        assert!(matches!(result, Err(DecoderError::NoConfig)));
3452    }
3453
3454    #[test]
3455    fn test_decoder_builder_empty_config() {
3456        use crate::DecoderBuilder;
3457        let result = DecoderBuilder::default()
3458            .with_config(ConfigOutputs {
3459                outputs: vec![],
3460                ..Default::default()
3461            })
3462            .build();
3463        assert!(
3464            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "No outputs found in config")
3465        );
3466    }
3467
3468    #[test]
3469    fn test_malformed_config_yaml() {
3470        let malformed_yaml = "
3471        model_type: yolov8_det
3472        outputs:
3473          - shape: [1, 84, 8400]
3474        "
3475        .to_owned();
3476        let result = DecoderBuilder::new()
3477            .with_config_yaml_str(malformed_yaml)
3478            .build();
3479        assert!(matches!(result, Err(DecoderError::Yaml(_))));
3480    }
3481
3482    #[test]
3483    fn test_malformed_config_json() {
3484        let malformed_yaml = "
3485        {
3486            \"model_type\": \"yolov8_det\",
3487            \"outputs\": [
3488                {
3489                    \"shape\": [1, 84, 8400]
3490                }
3491            ]
3492        }"
3493        .to_owned();
3494        let result = DecoderBuilder::new()
3495            .with_config_json_str(malformed_yaml)
3496            .build();
3497        assert!(matches!(result, Err(DecoderError::Json(_))));
3498    }
3499
3500    #[test]
3501    fn test_modelpack_and_yolo_config_error() {
3502        let result = DecoderBuilder::new()
3503            .with_config_modelpack_det(
3504                configs::Boxes {
3505                    decoder: configs::DecoderType::Ultralytics,
3506                    shape: vec![1, 4, 8400],
3507                    quantization: None,
3508                    dshape: vec![
3509                        (DimName::Batch, 1),
3510                        (DimName::BoxCoords, 4),
3511                        (DimName::NumBoxes, 8400),
3512                    ],
3513                    normalized: Some(true),
3514                },
3515                configs::Scores {
3516                    decoder: configs::DecoderType::ModelPack,
3517                    shape: vec![1, 80, 8400],
3518                    quantization: None,
3519                    dshape: vec![
3520                        (DimName::Batch, 1),
3521                        (DimName::NumClasses, 80),
3522                        (DimName::NumBoxes, 8400),
3523                    ],
3524                },
3525            )
3526            .build();
3527
3528        assert!(matches!(
3529            result, Err(DecoderError::InvalidConfig(s)) if s == "Both ModelPack and Yolo outputs found in config"
3530        ));
3531    }
3532
3533    #[test]
3534    fn test_yolo_invalid_seg_shape() {
3535        let result = DecoderBuilder::new()
3536            .with_config_yolo_segdet(
3537                configs::Detection {
3538                    decoder: configs::DecoderType::Ultralytics,
3539                    shape: vec![1, 85, 8400, 1], // Invalid shape
3540                    quantization: None,
3541                    anchors: None,
3542                    dshape: vec![
3543                        (DimName::Batch, 1),
3544                        (DimName::NumFeatures, 85),
3545                        (DimName::NumBoxes, 8400),
3546                        (DimName::Batch, 1),
3547                    ],
3548                    normalized: Some(true),
3549                },
3550                configs::Protos {
3551                    decoder: configs::DecoderType::Ultralytics,
3552                    shape: vec![1, 32, 160, 160],
3553                    quantization: None,
3554                    dshape: vec![
3555                        (DimName::Batch, 1),
3556                        (DimName::NumProtos, 32),
3557                        (DimName::Height, 160),
3558                        (DimName::Width, 160),
3559                    ],
3560                },
3561                Some(DecoderVersion::Yolo11),
3562            )
3563            .build();
3564
3565        assert!(matches!(
3566            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")
3567        ));
3568    }
3569
3570    #[test]
3571    fn test_yolo_invalid_mask() {
3572        let result = DecoderBuilder::new()
3573            .with_config(ConfigOutputs {
3574                outputs: vec![ConfigOutput::Mask(configs::Mask {
3575                    shape: vec![1, 160, 160, 1],
3576                    decoder: configs::DecoderType::Ultralytics,
3577                    quantization: None,
3578                    dshape: vec![
3579                        (DimName::Batch, 1),
3580                        (DimName::Height, 160),
3581                        (DimName::Width, 160),
3582                        (DimName::NumFeatures, 1),
3583                    ],
3584                })],
3585                ..Default::default()
3586            })
3587            .build();
3588
3589        assert!(matches!(
3590            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Mask output with Yolo decoder")
3591        ));
3592    }
3593
3594    #[test]
3595    fn test_yolo_invalid_outputs() {
3596        let result = DecoderBuilder::new()
3597            .with_config(ConfigOutputs {
3598                outputs: vec![ConfigOutput::Segmentation(configs::Segmentation {
3599                    shape: vec![1, 84, 8400],
3600                    decoder: configs::DecoderType::Ultralytics,
3601                    quantization: None,
3602                    dshape: vec![
3603                        (DimName::Batch, 1),
3604                        (DimName::NumFeatures, 84),
3605                        (DimName::NumBoxes, 8400),
3606                    ],
3607                })],
3608                ..Default::default()
3609            })
3610            .build();
3611
3612        assert!(
3613            matches!(result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid Segmentation output with Yolo decoder")
3614        );
3615    }
3616
3617    #[test]
3618    fn test_yolo_invalid_det() {
3619        let result = DecoderBuilder::new()
3620            .with_config_yolo_det(
3621                configs::Detection {
3622                    anchors: None,
3623                    decoder: DecoderType::Ultralytics,
3624                    quantization: None,
3625                    shape: vec![1, 84, 8400, 1], // Invalid shape
3626                    dshape: vec![
3627                        (DimName::Batch, 1),
3628                        (DimName::NumFeatures, 84),
3629                        (DimName::NumBoxes, 8400),
3630                        (DimName::Batch, 1),
3631                    ],
3632                    normalized: Some(true),
3633                },
3634                Some(DecoderVersion::Yolo11),
3635            )
3636            .build();
3637
3638        assert!(matches!(
3639            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3640
3641        let result = DecoderBuilder::new()
3642            .with_config_yolo_det(
3643                configs::Detection {
3644                    anchors: None,
3645                    decoder: DecoderType::Ultralytics,
3646                    quantization: None,
3647                    shape: vec![1, 8400, 3], // Invalid shape
3648                    dshape: vec![
3649                        (DimName::Batch, 1),
3650                        (DimName::NumBoxes, 8400),
3651                        (DimName::NumFeatures, 3),
3652                    ],
3653                    normalized: Some(true),
3654                },
3655                Some(DecoderVersion::Yolo11),
3656            )
3657            .build();
3658
3659        assert!(
3660            matches!(
3661            &result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")),
3662            "{}",
3663            result.unwrap_err()
3664        );
3665
3666        let result = DecoderBuilder::new()
3667            .with_config_yolo_det(
3668                configs::Detection {
3669                    anchors: None,
3670                    decoder: DecoderType::Ultralytics,
3671                    quantization: None,
3672                    shape: vec![1, 3, 8400], // Invalid shape
3673                    dshape: Vec::new(),
3674                    normalized: Some(true),
3675                },
3676                Some(DecoderVersion::Yolo11),
3677            )
3678            .build();
3679
3680        assert!(matches!(
3681            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid shape: Yolo num_features 3 must be greater than 4")));
3682    }
3683
3684    #[test]
3685    fn test_yolo_invalid_segdet() {
3686        let result = DecoderBuilder::new()
3687            .with_config_yolo_segdet(
3688                configs::Detection {
3689                    decoder: configs::DecoderType::Ultralytics,
3690                    shape: vec![1, 85, 8400, 1], // Invalid shape
3691                    quantization: None,
3692                    anchors: None,
3693                    dshape: vec![
3694                        (DimName::Batch, 1),
3695                        (DimName::NumFeatures, 85),
3696                        (DimName::NumBoxes, 8400),
3697                        (DimName::Batch, 1),
3698                    ],
3699                    normalized: Some(true),
3700                },
3701                configs::Protos {
3702                    decoder: configs::DecoderType::Ultralytics,
3703                    shape: vec![1, 32, 160, 160],
3704                    quantization: None,
3705                    dshape: vec![
3706                        (DimName::Batch, 1),
3707                        (DimName::NumProtos, 32),
3708                        (DimName::Height, 160),
3709                        (DimName::Width, 160),
3710                    ],
3711                },
3712                Some(DecoderVersion::Yolo11),
3713            )
3714            .build();
3715
3716        assert!(matches!(
3717            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Detection shape")));
3718
3719        let result = DecoderBuilder::new()
3720            .with_config_yolo_segdet(
3721                configs::Detection {
3722                    decoder: configs::DecoderType::Ultralytics,
3723                    shape: vec![1, 85, 8400],
3724                    quantization: None,
3725                    anchors: None,
3726                    dshape: vec![
3727                        (DimName::Batch, 1),
3728                        (DimName::NumFeatures, 85),
3729                        (DimName::NumBoxes, 8400),
3730                    ],
3731                    normalized: Some(true),
3732                },
3733                configs::Protos {
3734                    decoder: configs::DecoderType::Ultralytics,
3735                    shape: vec![1, 32, 160, 160, 1], // Invalid shape
3736                    dshape: vec![
3737                        (DimName::Batch, 1),
3738                        (DimName::NumProtos, 32),
3739                        (DimName::Height, 160),
3740                        (DimName::Width, 160),
3741                        (DimName::Batch, 1),
3742                    ],
3743                    quantization: None,
3744                },
3745                Some(DecoderVersion::Yolo11),
3746            )
3747            .build();
3748
3749        assert!(matches!(
3750            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
3751
3752        let result = DecoderBuilder::new()
3753            .with_config_yolo_segdet(
3754                configs::Detection {
3755                    decoder: configs::DecoderType::Ultralytics,
3756                    shape: vec![1, 8400, 36], // too few classes
3757                    quantization: None,
3758                    anchors: None,
3759                    dshape: vec![
3760                        (DimName::Batch, 1),
3761                        (DimName::NumBoxes, 8400),
3762                        (DimName::NumFeatures, 36),
3763                    ],
3764                    normalized: Some(true),
3765                },
3766                configs::Protos {
3767                    decoder: configs::DecoderType::Ultralytics,
3768                    shape: vec![1, 32, 160, 160],
3769                    quantization: None,
3770                    dshape: vec![
3771                        (DimName::Batch, 1),
3772                        (DimName::NumProtos, 32),
3773                        (DimName::Height, 160),
3774                        (DimName::Width, 160),
3775                    ],
3776                },
3777                Some(DecoderVersion::Yolo11),
3778            )
3779            .build();
3780        println!("{:?}", result);
3781        assert!(matches!(
3782            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid shape: Yolo num_features 36 must be greater than 36"));
3783    }
3784
3785    #[test]
3786    fn test_yolo_invalid_split_det() {
3787        let result = DecoderBuilder::new()
3788            .with_config_yolo_split_det(
3789                configs::Boxes {
3790                    decoder: configs::DecoderType::Ultralytics,
3791                    shape: vec![1, 4, 8400, 1], // Invalid shape
3792                    quantization: None,
3793                    dshape: vec![
3794                        (DimName::Batch, 1),
3795                        (DimName::BoxCoords, 4),
3796                        (DimName::NumBoxes, 8400),
3797                        (DimName::Batch, 1),
3798                    ],
3799                    normalized: Some(true),
3800                },
3801                configs::Scores {
3802                    decoder: configs::DecoderType::Ultralytics,
3803                    shape: vec![1, 80, 8400],
3804                    quantization: None,
3805                    dshape: vec![
3806                        (DimName::Batch, 1),
3807                        (DimName::NumClasses, 80),
3808                        (DimName::NumBoxes, 8400),
3809                    ],
3810                },
3811            )
3812            .build();
3813
3814        assert!(matches!(
3815            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3816
3817        let result = DecoderBuilder::new()
3818            .with_config_yolo_split_det(
3819                configs::Boxes {
3820                    decoder: configs::DecoderType::Ultralytics,
3821                    shape: vec![1, 4, 8400],
3822                    quantization: None,
3823                    dshape: vec![
3824                        (DimName::Batch, 1),
3825                        (DimName::BoxCoords, 4),
3826                        (DimName::NumBoxes, 8400),
3827                    ],
3828                    normalized: Some(true),
3829                },
3830                configs::Scores {
3831                    decoder: configs::DecoderType::Ultralytics,
3832                    shape: vec![1, 80, 8400, 1], // Invalid shape
3833                    quantization: None,
3834                    dshape: vec![
3835                        (DimName::Batch, 1),
3836                        (DimName::NumClasses, 80),
3837                        (DimName::NumBoxes, 8400),
3838                        (DimName::Batch, 1),
3839                    ],
3840                },
3841            )
3842            .build();
3843
3844        assert!(matches!(
3845            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
3846
3847        let result = DecoderBuilder::new()
3848            .with_config_yolo_split_det(
3849                configs::Boxes {
3850                    decoder: configs::DecoderType::Ultralytics,
3851                    shape: vec![1, 8400, 4],
3852                    quantization: None,
3853                    dshape: vec![
3854                        (DimName::Batch, 1),
3855                        (DimName::NumBoxes, 8400),
3856                        (DimName::BoxCoords, 4),
3857                    ],
3858                    normalized: Some(true),
3859                },
3860                configs::Scores {
3861                    decoder: configs::DecoderType::Ultralytics,
3862                    shape: vec![1, 8400 + 1, 80], // Invalid number of boxes
3863                    quantization: None,
3864                    dshape: vec![
3865                        (DimName::Batch, 1),
3866                        (DimName::NumBoxes, 8401),
3867                        (DimName::NumClasses, 80),
3868                    ],
3869                },
3870            )
3871            .build();
3872
3873        assert!(matches!(
3874            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
3875
3876        let result = DecoderBuilder::new()
3877            .with_config_yolo_split_det(
3878                configs::Boxes {
3879                    decoder: configs::DecoderType::Ultralytics,
3880                    shape: vec![1, 5, 8400], // Invalid boxes dimensions
3881                    quantization: None,
3882                    dshape: vec![
3883                        (DimName::Batch, 1),
3884                        (DimName::BoxCoords, 5),
3885                        (DimName::NumBoxes, 8400),
3886                    ],
3887                    normalized: Some(true),
3888                },
3889                configs::Scores {
3890                    decoder: configs::DecoderType::Ultralytics,
3891                    shape: vec![1, 80, 8400],
3892                    quantization: None,
3893                    dshape: vec![
3894                        (DimName::Batch, 1),
3895                        (DimName::NumClasses, 80),
3896                        (DimName::NumBoxes, 8400),
3897                    ],
3898                },
3899            )
3900            .build();
3901        assert!(matches!(
3902            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("BoxCoords dimension size must be 4")));
3903    }
3904
3905    #[test]
3906    fn test_yolo_invalid_split_segdet() {
3907        let result = DecoderBuilder::new()
3908            .with_config_yolo_split_segdet(
3909                configs::Boxes {
3910                    decoder: configs::DecoderType::Ultralytics,
3911                    shape: vec![1, 8400, 4, 1],
3912                    quantization: None,
3913                    dshape: vec![
3914                        (DimName::Batch, 1),
3915                        (DimName::NumBoxes, 8400),
3916                        (DimName::BoxCoords, 4),
3917                        (DimName::Batch, 1),
3918                    ],
3919                    normalized: Some(true),
3920                },
3921                configs::Scores {
3922                    decoder: configs::DecoderType::Ultralytics,
3923                    shape: vec![1, 8400, 80],
3924
3925                    quantization: None,
3926                    dshape: vec![
3927                        (DimName::Batch, 1),
3928                        (DimName::NumBoxes, 8400),
3929                        (DimName::NumClasses, 80),
3930                    ],
3931                },
3932                configs::MaskCoefficients {
3933                    decoder: configs::DecoderType::Ultralytics,
3934                    shape: vec![1, 8400, 32],
3935                    quantization: None,
3936                    dshape: vec![
3937                        (DimName::Batch, 1),
3938                        (DimName::NumBoxes, 8400),
3939                        (DimName::NumProtos, 32),
3940                    ],
3941                },
3942                configs::Protos {
3943                    decoder: configs::DecoderType::Ultralytics,
3944                    shape: vec![1, 32, 160, 160],
3945                    quantization: None,
3946                    dshape: vec![
3947                        (DimName::Batch, 1),
3948                        (DimName::NumProtos, 32),
3949                        (DimName::Height, 160),
3950                        (DimName::Width, 160),
3951                    ],
3952                },
3953            )
3954            .build();
3955
3956        assert!(matches!(
3957            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Boxes shape")));
3958
3959        let result = DecoderBuilder::new()
3960            .with_config_yolo_split_segdet(
3961                configs::Boxes {
3962                    decoder: configs::DecoderType::Ultralytics,
3963                    shape: vec![1, 8400, 4],
3964                    quantization: None,
3965                    dshape: vec![
3966                        (DimName::Batch, 1),
3967                        (DimName::NumBoxes, 8400),
3968                        (DimName::BoxCoords, 4),
3969                    ],
3970                    normalized: Some(true),
3971                },
3972                configs::Scores {
3973                    decoder: configs::DecoderType::Ultralytics,
3974                    shape: vec![1, 8400, 80, 1],
3975                    quantization: None,
3976                    dshape: vec![
3977                        (DimName::Batch, 1),
3978                        (DimName::NumBoxes, 8400),
3979                        (DimName::NumClasses, 80),
3980                        (DimName::Batch, 1),
3981                    ],
3982                },
3983                configs::MaskCoefficients {
3984                    decoder: configs::DecoderType::Ultralytics,
3985                    shape: vec![1, 8400, 32],
3986                    quantization: None,
3987                    dshape: vec![
3988                        (DimName::Batch, 1),
3989                        (DimName::NumBoxes, 8400),
3990                        (DimName::NumProtos, 32),
3991                    ],
3992                },
3993                configs::Protos {
3994                    decoder: configs::DecoderType::Ultralytics,
3995                    shape: vec![1, 32, 160, 160],
3996                    quantization: None,
3997                    dshape: vec![
3998                        (DimName::Batch, 1),
3999                        (DimName::NumProtos, 32),
4000                        (DimName::Height, 160),
4001                        (DimName::Width, 160),
4002                    ],
4003                },
4004            )
4005            .build();
4006
4007        assert!(matches!(
4008            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Scores shape")));
4009
4010        let result = DecoderBuilder::new()
4011            .with_config_yolo_split_segdet(
4012                configs::Boxes {
4013                    decoder: configs::DecoderType::Ultralytics,
4014                    shape: vec![1, 8400, 4],
4015                    quantization: None,
4016                    dshape: vec![
4017                        (DimName::Batch, 1),
4018                        (DimName::NumBoxes, 8400),
4019                        (DimName::BoxCoords, 4),
4020                    ],
4021                    normalized: Some(true),
4022                },
4023                configs::Scores {
4024                    decoder: configs::DecoderType::Ultralytics,
4025                    shape: vec![1, 8400, 80],
4026                    quantization: None,
4027                    dshape: vec![
4028                        (DimName::Batch, 1),
4029                        (DimName::NumBoxes, 8400),
4030                        (DimName::NumClasses, 80),
4031                    ],
4032                },
4033                configs::MaskCoefficients {
4034                    decoder: configs::DecoderType::Ultralytics,
4035                    shape: vec![1, 8400, 32, 1],
4036                    quantization: None,
4037                    dshape: vec![
4038                        (DimName::Batch, 1),
4039                        (DimName::NumBoxes, 8400),
4040                        (DimName::NumProtos, 32),
4041                        (DimName::Batch, 1),
4042                    ],
4043                },
4044                configs::Protos {
4045                    decoder: configs::DecoderType::Ultralytics,
4046                    shape: vec![1, 32, 160, 160],
4047                    quantization: None,
4048                    dshape: vec![
4049                        (DimName::Batch, 1),
4050                        (DimName::NumProtos, 32),
4051                        (DimName::Height, 160),
4052                        (DimName::Width, 160),
4053                    ],
4054                },
4055            )
4056            .build();
4057
4058        assert!(matches!(
4059            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Split Mask Coefficients shape")));
4060
4061        let result = DecoderBuilder::new()
4062            .with_config_yolo_split_segdet(
4063                configs::Boxes {
4064                    decoder: configs::DecoderType::Ultralytics,
4065                    shape: vec![1, 8400, 4],
4066                    quantization: None,
4067                    dshape: vec![
4068                        (DimName::Batch, 1),
4069                        (DimName::NumBoxes, 8400),
4070                        (DimName::BoxCoords, 4),
4071                    ],
4072                    normalized: Some(true),
4073                },
4074                configs::Scores {
4075                    decoder: configs::DecoderType::Ultralytics,
4076                    shape: vec![1, 8400, 80],
4077                    quantization: None,
4078                    dshape: vec![
4079                        (DimName::Batch, 1),
4080                        (DimName::NumBoxes, 8400),
4081                        (DimName::NumClasses, 80),
4082                    ],
4083                },
4084                configs::MaskCoefficients {
4085                    decoder: configs::DecoderType::Ultralytics,
4086                    shape: vec![1, 8400, 32],
4087                    quantization: None,
4088                    dshape: vec![
4089                        (DimName::Batch, 1),
4090                        (DimName::NumBoxes, 8400),
4091                        (DimName::NumProtos, 32),
4092                    ],
4093                },
4094                configs::Protos {
4095                    decoder: configs::DecoderType::Ultralytics,
4096                    shape: vec![1, 32, 160, 160, 1],
4097                    quantization: None,
4098                    dshape: vec![
4099                        (DimName::Batch, 1),
4100                        (DimName::NumProtos, 32),
4101                        (DimName::Height, 160),
4102                        (DimName::Width, 160),
4103                        (DimName::Batch, 1),
4104                    ],
4105                },
4106            )
4107            .build();
4108
4109        assert!(matches!(
4110            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid Yolo Protos shape")));
4111
4112        let result = DecoderBuilder::new()
4113            .with_config_yolo_split_segdet(
4114                configs::Boxes {
4115                    decoder: configs::DecoderType::Ultralytics,
4116                    shape: vec![1, 8400, 4],
4117                    quantization: None,
4118                    dshape: vec![
4119                        (DimName::Batch, 1),
4120                        (DimName::NumBoxes, 8400),
4121                        (DimName::BoxCoords, 4),
4122                    ],
4123                    normalized: Some(true),
4124                },
4125                configs::Scores {
4126                    decoder: configs::DecoderType::Ultralytics,
4127                    shape: vec![1, 8401, 80],
4128                    quantization: None,
4129                    dshape: vec![
4130                        (DimName::Batch, 1),
4131                        (DimName::NumBoxes, 8401),
4132                        (DimName::NumClasses, 80),
4133                    ],
4134                },
4135                configs::MaskCoefficients {
4136                    decoder: configs::DecoderType::Ultralytics,
4137                    shape: vec![1, 8400, 32],
4138                    quantization: None,
4139                    dshape: vec![
4140                        (DimName::Batch, 1),
4141                        (DimName::NumBoxes, 8400),
4142                        (DimName::NumProtos, 32),
4143                    ],
4144                },
4145                configs::Protos {
4146                    decoder: configs::DecoderType::Ultralytics,
4147                    shape: vec![1, 32, 160, 160],
4148                    quantization: None,
4149                    dshape: vec![
4150                        (DimName::Batch, 1),
4151                        (DimName::NumProtos, 32),
4152                        (DimName::Height, 160),
4153                        (DimName::Width, 160),
4154                    ],
4155                },
4156            )
4157            .build();
4158
4159        assert!(matches!(
4160            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Scores num 8401")));
4161
4162        let result = DecoderBuilder::new()
4163            .with_config_yolo_split_segdet(
4164                configs::Boxes {
4165                    decoder: configs::DecoderType::Ultralytics,
4166                    shape: vec![1, 8400, 4],
4167                    quantization: None,
4168                    dshape: vec![
4169                        (DimName::Batch, 1),
4170                        (DimName::NumBoxes, 8400),
4171                        (DimName::BoxCoords, 4),
4172                    ],
4173                    normalized: Some(true),
4174                },
4175                configs::Scores {
4176                    decoder: configs::DecoderType::Ultralytics,
4177                    shape: vec![1, 8400, 80],
4178                    quantization: None,
4179                    dshape: vec![
4180                        (DimName::Batch, 1),
4181                        (DimName::NumBoxes, 8400),
4182                        (DimName::NumClasses, 80),
4183                    ],
4184                },
4185                configs::MaskCoefficients {
4186                    decoder: configs::DecoderType::Ultralytics,
4187                    shape: vec![1, 8401, 32],
4188
4189                    quantization: None,
4190                    dshape: vec![
4191                        (DimName::Batch, 1),
4192                        (DimName::NumBoxes, 8401),
4193                        (DimName::NumProtos, 32),
4194                    ],
4195                },
4196                configs::Protos {
4197                    decoder: configs::DecoderType::Ultralytics,
4198                    shape: vec![1, 32, 160, 160],
4199                    quantization: None,
4200                    dshape: vec![
4201                        (DimName::Batch, 1),
4202                        (DimName::NumProtos, 32),
4203                        (DimName::Height, 160),
4204                        (DimName::Width, 160),
4205                    ],
4206                },
4207            )
4208            .build();
4209
4210        assert!(matches!(
4211            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with("Yolo Split Detection Boxes num 8400 incompatible with Mask Coefficients num 8401")));
4212        let result = DecoderBuilder::new()
4213            .with_config_yolo_split_segdet(
4214                configs::Boxes {
4215                    decoder: configs::DecoderType::Ultralytics,
4216                    shape: vec![1, 8400, 4],
4217                    quantization: None,
4218                    dshape: vec![
4219                        (DimName::Batch, 1),
4220                        (DimName::NumBoxes, 8400),
4221                        (DimName::BoxCoords, 4),
4222                    ],
4223                    normalized: Some(true),
4224                },
4225                configs::Scores {
4226                    decoder: configs::DecoderType::Ultralytics,
4227                    shape: vec![1, 8400, 80],
4228                    quantization: None,
4229                    dshape: vec![
4230                        (DimName::Batch, 1),
4231                        (DimName::NumBoxes, 8400),
4232                        (DimName::NumClasses, 80),
4233                    ],
4234                },
4235                configs::MaskCoefficients {
4236                    decoder: configs::DecoderType::Ultralytics,
4237                    shape: vec![1, 8400, 32],
4238                    quantization: None,
4239                    dshape: vec![
4240                        (DimName::Batch, 1),
4241                        (DimName::NumBoxes, 8400),
4242                        (DimName::NumProtos, 32),
4243                    ],
4244                },
4245                configs::Protos {
4246                    decoder: configs::DecoderType::Ultralytics,
4247                    shape: vec![1, 31, 160, 160],
4248                    quantization: None,
4249                    dshape: vec![
4250                        (DimName::Batch, 1),
4251                        (DimName::NumProtos, 31),
4252                        (DimName::Height, 160),
4253                        (DimName::Width, 160),
4254                    ],
4255                },
4256            )
4257            .build();
4258        println!("{:?}", result);
4259        assert!(matches!(
4260            result, Err(DecoderError::InvalidConfig(ref s)) if s.starts_with( "Yolo Protos channels 31 incompatible with Mask Coefficients channels 32")));
4261    }
4262
4263    #[test]
4264    fn test_modelpack_invalid_config() {
4265        let result = DecoderBuilder::new()
4266            .with_config(ConfigOutputs {
4267                outputs: vec![
4268                    ConfigOutput::Boxes(configs::Boxes {
4269                        decoder: configs::DecoderType::ModelPack,
4270                        shape: vec![1, 8400, 1, 4],
4271                        quantization: None,
4272                        dshape: vec![
4273                            (DimName::Batch, 1),
4274                            (DimName::NumBoxes, 8400),
4275                            (DimName::Padding, 1),
4276                            (DimName::BoxCoords, 4),
4277                        ],
4278                        normalized: Some(true),
4279                    }),
4280                    ConfigOutput::Scores(configs::Scores {
4281                        decoder: configs::DecoderType::ModelPack,
4282                        shape: vec![1, 8400, 3],
4283                        quantization: None,
4284                        dshape: vec![
4285                            (DimName::Batch, 1),
4286                            (DimName::NumBoxes, 8400),
4287                            (DimName::NumClasses, 3),
4288                        ],
4289                    }),
4290                    ConfigOutput::Protos(configs::Protos {
4291                        decoder: configs::DecoderType::ModelPack,
4292                        shape: vec![1, 8400, 3],
4293                        quantization: None,
4294                        dshape: vec![
4295                            (DimName::Batch, 1),
4296                            (DimName::NumBoxes, 8400),
4297                            (DimName::NumFeatures, 3),
4298                        ],
4299                    }),
4300                ],
4301                ..Default::default()
4302            })
4303            .build();
4304
4305        assert!(matches!(
4306            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have protos"));
4307
4308        let result = DecoderBuilder::new()
4309            .with_config(ConfigOutputs {
4310                outputs: vec![
4311                    ConfigOutput::Boxes(configs::Boxes {
4312                        decoder: configs::DecoderType::ModelPack,
4313                        shape: vec![1, 8400, 1, 4],
4314                        quantization: None,
4315                        dshape: vec![
4316                            (DimName::Batch, 1),
4317                            (DimName::NumBoxes, 8400),
4318                            (DimName::Padding, 1),
4319                            (DimName::BoxCoords, 4),
4320                        ],
4321                        normalized: Some(true),
4322                    }),
4323                    ConfigOutput::Scores(configs::Scores {
4324                        decoder: configs::DecoderType::ModelPack,
4325                        shape: vec![1, 8400, 3],
4326                        quantization: None,
4327                        dshape: vec![
4328                            (DimName::Batch, 1),
4329                            (DimName::NumBoxes, 8400),
4330                            (DimName::NumClasses, 3),
4331                        ],
4332                    }),
4333                    ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
4334                        decoder: configs::DecoderType::ModelPack,
4335                        shape: vec![1, 8400, 3],
4336                        quantization: None,
4337                        dshape: vec![
4338                            (DimName::Batch, 1),
4339                            (DimName::NumBoxes, 8400),
4340                            (DimName::NumProtos, 3),
4341                        ],
4342                    }),
4343                ],
4344                ..Default::default()
4345            })
4346            .build();
4347
4348        assert!(matches!(
4349            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack should not have mask coefficients"));
4350
4351        let result = DecoderBuilder::new()
4352            .with_config(ConfigOutputs {
4353                outputs: vec![ConfigOutput::Boxes(configs::Boxes {
4354                    decoder: configs::DecoderType::ModelPack,
4355                    shape: vec![1, 8400, 1, 4],
4356                    quantization: None,
4357                    dshape: vec![
4358                        (DimName::Batch, 1),
4359                        (DimName::NumBoxes, 8400),
4360                        (DimName::Padding, 1),
4361                        (DimName::BoxCoords, 4),
4362                    ],
4363                    normalized: Some(true),
4364                })],
4365                ..Default::default()
4366            })
4367            .build();
4368
4369        assert!(matches!(
4370            result, Err(DecoderError::InvalidConfig(s)) if s == "Invalid ModelPack model outputs"));
4371    }
4372
4373    #[test]
4374    fn test_modelpack_invalid_det() {
4375        let result = DecoderBuilder::new()
4376            .with_config_modelpack_det(
4377                configs::Boxes {
4378                    decoder: DecoderType::ModelPack,
4379                    quantization: None,
4380                    shape: vec![1, 4, 8400],
4381                    dshape: vec![
4382                        (DimName::Batch, 1),
4383                        (DimName::BoxCoords, 4),
4384                        (DimName::NumBoxes, 8400),
4385                    ],
4386                    normalized: Some(true),
4387                },
4388                configs::Scores {
4389                    decoder: DecoderType::ModelPack,
4390                    quantization: None,
4391                    shape: vec![1, 80, 8400],
4392                    dshape: vec![
4393                        (DimName::Batch, 1),
4394                        (DimName::NumClasses, 80),
4395                        (DimName::NumBoxes, 8400),
4396                    ],
4397                },
4398            )
4399            .build();
4400
4401        assert!(matches!(
4402            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Boxes shape")));
4403
4404        let result = DecoderBuilder::new()
4405            .with_config_modelpack_det(
4406                configs::Boxes {
4407                    decoder: DecoderType::ModelPack,
4408                    quantization: None,
4409                    shape: vec![1, 4, 1, 8400],
4410                    dshape: vec![
4411                        (DimName::Batch, 1),
4412                        (DimName::BoxCoords, 4),
4413                        (DimName::Padding, 1),
4414                        (DimName::NumBoxes, 8400),
4415                    ],
4416                    normalized: Some(true),
4417                },
4418                configs::Scores {
4419                    decoder: DecoderType::ModelPack,
4420                    quantization: None,
4421                    shape: vec![1, 80, 8400, 1],
4422                    dshape: vec![
4423                        (DimName::Batch, 1),
4424                        (DimName::NumClasses, 80),
4425                        (DimName::NumBoxes, 8400),
4426                        (DimName::Padding, 1),
4427                    ],
4428                },
4429            )
4430            .build();
4431
4432        assert!(matches!(
4433            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Scores shape")));
4434
4435        let result = DecoderBuilder::new()
4436            .with_config_modelpack_det(
4437                configs::Boxes {
4438                    decoder: DecoderType::ModelPack,
4439                    quantization: None,
4440                    shape: vec![1, 4, 2, 8400],
4441                    dshape: vec![
4442                        (DimName::Batch, 1),
4443                        (DimName::BoxCoords, 4),
4444                        (DimName::Padding, 2),
4445                        (DimName::NumBoxes, 8400),
4446                    ],
4447                    normalized: Some(true),
4448                },
4449                configs::Scores {
4450                    decoder: DecoderType::ModelPack,
4451                    quantization: None,
4452                    shape: vec![1, 80, 8400],
4453                    dshape: vec![
4454                        (DimName::Batch, 1),
4455                        (DimName::NumClasses, 80),
4456                        (DimName::NumBoxes, 8400),
4457                    ],
4458                },
4459            )
4460            .build();
4461        assert!(matches!(
4462            result, Err(DecoderError::InvalidConfig(s)) if s == "Padding dimension size must be 1"));
4463
4464        let result = DecoderBuilder::new()
4465            .with_config_modelpack_det(
4466                configs::Boxes {
4467                    decoder: DecoderType::ModelPack,
4468                    quantization: None,
4469                    shape: vec![1, 5, 1, 8400],
4470                    dshape: vec![
4471                        (DimName::Batch, 1),
4472                        (DimName::BoxCoords, 5),
4473                        (DimName::Padding, 1),
4474                        (DimName::NumBoxes, 8400),
4475                    ],
4476                    normalized: Some(true),
4477                },
4478                configs::Scores {
4479                    decoder: DecoderType::ModelPack,
4480                    quantization: None,
4481                    shape: vec![1, 80, 8400],
4482                    dshape: vec![
4483                        (DimName::Batch, 1),
4484                        (DimName::NumClasses, 80),
4485                        (DimName::NumBoxes, 8400),
4486                    ],
4487                },
4488            )
4489            .build();
4490
4491        assert!(matches!(
4492            result, Err(DecoderError::InvalidConfig(s)) if s == "BoxCoords dimension size must be 4"));
4493
4494        let result = DecoderBuilder::new()
4495            .with_config_modelpack_det(
4496                configs::Boxes {
4497                    decoder: DecoderType::ModelPack,
4498                    quantization: None,
4499                    shape: vec![1, 4, 1, 8400],
4500                    dshape: vec![
4501                        (DimName::Batch, 1),
4502                        (DimName::BoxCoords, 4),
4503                        (DimName::Padding, 1),
4504                        (DimName::NumBoxes, 8400),
4505                    ],
4506                    normalized: Some(true),
4507                },
4508                configs::Scores {
4509                    decoder: DecoderType::ModelPack,
4510                    quantization: None,
4511                    shape: vec![1, 80, 8401],
4512                    dshape: vec![
4513                        (DimName::Batch, 1),
4514                        (DimName::NumClasses, 80),
4515                        (DimName::NumBoxes, 8401),
4516                    ],
4517                },
4518            )
4519            .build();
4520
4521        assert!(matches!(
4522            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Detection Boxes num 8400 incompatible with Scores num 8401"));
4523    }
4524
4525    #[test]
4526    fn test_modelpack_invalid_det_split() {
4527        let result = DecoderBuilder::default()
4528            .with_config_modelpack_det_split(vec![
4529                configs::Detection {
4530                    decoder: DecoderType::ModelPack,
4531                    shape: vec![1, 17, 30, 18],
4532                    anchors: None,
4533                    quantization: None,
4534                    dshape: vec![
4535                        (DimName::Batch, 1),
4536                        (DimName::Height, 17),
4537                        (DimName::Width, 30),
4538                        (DimName::NumAnchorsXFeatures, 18),
4539                    ],
4540                    normalized: Some(true),
4541                },
4542                configs::Detection {
4543                    decoder: DecoderType::ModelPack,
4544                    shape: vec![1, 9, 15, 18],
4545                    anchors: None,
4546                    quantization: None,
4547                    dshape: vec![
4548                        (DimName::Batch, 1),
4549                        (DimName::Height, 9),
4550                        (DimName::Width, 15),
4551                        (DimName::NumAnchorsXFeatures, 18),
4552                    ],
4553                    normalized: Some(true),
4554                },
4555            ])
4556            .build();
4557
4558        assert!(matches!(
4559            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4560
4561        let result = DecoderBuilder::default()
4562            .with_config_modelpack_det_split(vec![configs::Detection {
4563                decoder: DecoderType::ModelPack,
4564                shape: vec![1, 17, 30, 18],
4565                anchors: None,
4566                quantization: None,
4567                dshape: Vec::new(),
4568                normalized: Some(true),
4569            }])
4570            .build();
4571
4572        assert!(matches!(
4573            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection missing anchors"));
4574
4575        let result = DecoderBuilder::default()
4576            .with_config_modelpack_det_split(vec![configs::Detection {
4577                decoder: DecoderType::ModelPack,
4578                shape: vec![1, 17, 30, 18],
4579                anchors: Some(vec![]),
4580                quantization: None,
4581                dshape: vec![
4582                    (DimName::Batch, 1),
4583                    (DimName::Height, 17),
4584                    (DimName::Width, 30),
4585                    (DimName::NumAnchorsXFeatures, 18),
4586                ],
4587                normalized: Some(true),
4588            }])
4589            .build();
4590
4591        assert!(matches!(
4592            result, Err(DecoderError::InvalidConfig(s)) if s == "ModelPack Split Detection has zero anchors"));
4593
4594        let result = DecoderBuilder::default()
4595            .with_config_modelpack_det_split(vec![configs::Detection {
4596                decoder: DecoderType::ModelPack,
4597                shape: vec![1, 17, 30, 18, 1],
4598                anchors: Some(vec![
4599                    [0.3666666, 0.3148148],
4600                    [0.3874999, 0.474074],
4601                    [0.5333333, 0.644444],
4602                ]),
4603                quantization: None,
4604                dshape: vec![
4605                    (DimName::Batch, 1),
4606                    (DimName::Height, 17),
4607                    (DimName::Width, 30),
4608                    (DimName::NumAnchorsXFeatures, 18),
4609                    (DimName::Padding, 1),
4610                ],
4611                normalized: Some(true),
4612            }])
4613            .build();
4614
4615        assert!(matches!(
4616            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Split Detection shape")));
4617
4618        let result = DecoderBuilder::default()
4619            .with_config_modelpack_det_split(vec![configs::Detection {
4620                decoder: DecoderType::ModelPack,
4621                shape: vec![1, 15, 17, 30],
4622                anchors: Some(vec![
4623                    [0.3666666, 0.3148148],
4624                    [0.3874999, 0.474074],
4625                    [0.5333333, 0.644444],
4626                ]),
4627                quantization: None,
4628                dshape: vec![
4629                    (DimName::Batch, 1),
4630                    (DimName::NumAnchorsXFeatures, 15),
4631                    (DimName::Height, 17),
4632                    (DimName::Width, 30),
4633                ],
4634                normalized: Some(true),
4635            }])
4636            .build();
4637
4638        assert!(matches!(
4639            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4640
4641        let result = DecoderBuilder::default()
4642            .with_config_modelpack_det_split(vec![configs::Detection {
4643                decoder: DecoderType::ModelPack,
4644                shape: vec![1, 17, 30, 15],
4645                anchors: Some(vec![
4646                    [0.3666666, 0.3148148],
4647                    [0.3874999, 0.474074],
4648                    [0.5333333, 0.644444],
4649                ]),
4650                quantization: None,
4651                dshape: Vec::new(),
4652                normalized: Some(true),
4653            }])
4654            .build();
4655
4656        assert!(matches!(
4657            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not greater than number of anchors * 5 =")));
4658
4659        let result = DecoderBuilder::default()
4660            .with_config_modelpack_det_split(vec![configs::Detection {
4661                decoder: DecoderType::ModelPack,
4662                shape: vec![1, 16, 17, 30],
4663                anchors: Some(vec![
4664                    [0.3666666, 0.3148148],
4665                    [0.3874999, 0.474074],
4666                    [0.5333333, 0.644444],
4667                ]),
4668                quantization: None,
4669                dshape: vec![
4670                    (DimName::Batch, 1),
4671                    (DimName::NumAnchorsXFeatures, 16),
4672                    (DimName::Height, 17),
4673                    (DimName::Width, 30),
4674                ],
4675                normalized: Some(true),
4676            }])
4677            .build();
4678
4679        assert!(matches!(
4680            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4681
4682        let result = DecoderBuilder::default()
4683            .with_config_modelpack_det_split(vec![configs::Detection {
4684                decoder: DecoderType::ModelPack,
4685                shape: vec![1, 17, 30, 16],
4686                anchors: Some(vec![
4687                    [0.3666666, 0.3148148],
4688                    [0.3874999, 0.474074],
4689                    [0.5333333, 0.644444],
4690                ]),
4691                quantization: None,
4692                dshape: Vec::new(),
4693                normalized: Some(true),
4694            }])
4695            .build();
4696
4697        assert!(matches!(
4698            result, Err(DecoderError::InvalidConfig(s)) if s.contains("not a multiple of number of anchors")));
4699
4700        let result = DecoderBuilder::default()
4701            .with_config_modelpack_det_split(vec![configs::Detection {
4702                decoder: DecoderType::ModelPack,
4703                shape: vec![1, 18, 17, 30],
4704                anchors: Some(vec![
4705                    [0.3666666, 0.3148148],
4706                    [0.3874999, 0.474074],
4707                    [0.5333333, 0.644444],
4708                ]),
4709                quantization: None,
4710                dshape: vec![
4711                    (DimName::Batch, 1),
4712                    (DimName::NumProtos, 18),
4713                    (DimName::Height, 17),
4714                    (DimName::Width, 30),
4715                ],
4716                normalized: Some(true),
4717            }])
4718            .build();
4719        assert!(matches!(
4720            result, Err(DecoderError::InvalidConfig(s)) if s.contains("Split Detection dshape missing required dimension NumAnchorsXFeature")));
4721
4722        let result = DecoderBuilder::default()
4723            .with_config_modelpack_det_split(vec![
4724                configs::Detection {
4725                    decoder: DecoderType::ModelPack,
4726                    shape: vec![1, 17, 30, 18],
4727                    anchors: Some(vec![
4728                        [0.3666666, 0.3148148],
4729                        [0.3874999, 0.474074],
4730                        [0.5333333, 0.644444],
4731                    ]),
4732                    quantization: None,
4733                    dshape: vec![
4734                        (DimName::Batch, 1),
4735                        (DimName::Height, 17),
4736                        (DimName::Width, 30),
4737                        (DimName::NumAnchorsXFeatures, 18),
4738                    ],
4739                    normalized: Some(true),
4740                },
4741                configs::Detection {
4742                    decoder: DecoderType::ModelPack,
4743                    shape: vec![1, 17, 30, 21],
4744                    anchors: Some(vec![
4745                        [0.3666666, 0.3148148],
4746                        [0.3874999, 0.474074],
4747                        [0.5333333, 0.644444],
4748                    ]),
4749                    quantization: None,
4750                    dshape: vec![
4751                        (DimName::Batch, 1),
4752                        (DimName::Height, 17),
4753                        (DimName::Width, 30),
4754                        (DimName::NumAnchorsXFeatures, 21),
4755                    ],
4756                    normalized: Some(true),
4757                },
4758            ])
4759            .build();
4760
4761        assert!(matches!(
4762            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4763
4764        let result = DecoderBuilder::default()
4765            .with_config_modelpack_det_split(vec![
4766                configs::Detection {
4767                    decoder: DecoderType::ModelPack,
4768                    shape: vec![1, 17, 30, 18],
4769                    anchors: Some(vec![
4770                        [0.3666666, 0.3148148],
4771                        [0.3874999, 0.474074],
4772                        [0.5333333, 0.644444],
4773                    ]),
4774                    quantization: None,
4775                    dshape: vec![],
4776                    normalized: Some(true),
4777                },
4778                configs::Detection {
4779                    decoder: DecoderType::ModelPack,
4780                    shape: vec![1, 17, 30, 21],
4781                    anchors: Some(vec![
4782                        [0.3666666, 0.3148148],
4783                        [0.3874999, 0.474074],
4784                        [0.5333333, 0.644444],
4785                    ]),
4786                    quantization: None,
4787                    dshape: vec![],
4788                    normalized: Some(true),
4789                },
4790            ])
4791            .build();
4792
4793        assert!(matches!(
4794            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("ModelPack Split Detection inconsistent number of classes:")));
4795    }
4796
4797    #[test]
4798    fn test_modelpack_invalid_seg() {
4799        let result = DecoderBuilder::new()
4800            .with_config_modelpack_seg(configs::Segmentation {
4801                decoder: DecoderType::ModelPack,
4802                quantization: None,
4803                shape: vec![1, 160, 106, 3, 1],
4804                dshape: vec![
4805                    (DimName::Batch, 1),
4806                    (DimName::Height, 160),
4807                    (DimName::Width, 106),
4808                    (DimName::NumClasses, 3),
4809                    (DimName::Padding, 1),
4810                ],
4811            })
4812            .build();
4813
4814        assert!(matches!(
4815            result, Err(DecoderError::InvalidConfig(s)) if s.starts_with("Invalid ModelPack Segmentation shape")));
4816    }
4817
4818    #[test]
4819    fn test_modelpack_invalid_segdet() {
4820        let result = DecoderBuilder::new()
4821            .with_config_modelpack_segdet(
4822                configs::Boxes {
4823                    decoder: DecoderType::ModelPack,
4824                    quantization: None,
4825                    shape: vec![1, 4, 1, 8400],
4826                    dshape: vec![
4827                        (DimName::Batch, 1),
4828                        (DimName::BoxCoords, 4),
4829                        (DimName::Padding, 1),
4830                        (DimName::NumBoxes, 8400),
4831                    ],
4832                    normalized: Some(true),
4833                },
4834                configs::Scores {
4835                    decoder: DecoderType::ModelPack,
4836                    quantization: None,
4837                    shape: vec![1, 4, 8400],
4838                    dshape: vec![
4839                        (DimName::Batch, 1),
4840                        (DimName::NumClasses, 4),
4841                        (DimName::NumBoxes, 8400),
4842                    ],
4843                },
4844                configs::Segmentation {
4845                    decoder: DecoderType::ModelPack,
4846                    quantization: None,
4847                    shape: vec![1, 160, 106, 3],
4848                    dshape: vec![
4849                        (DimName::Batch, 1),
4850                        (DimName::Height, 160),
4851                        (DimName::Width, 106),
4852                        (DimName::NumClasses, 3),
4853                    ],
4854                },
4855            )
4856            .build();
4857
4858        assert!(matches!(
4859            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4860    }
4861
4862    #[test]
4863    fn test_modelpack_invalid_segdet_split() {
4864        let result = DecoderBuilder::new()
4865            .with_config_modelpack_segdet_split(
4866                vec![configs::Detection {
4867                    decoder: DecoderType::ModelPack,
4868                    shape: vec![1, 17, 30, 18],
4869                    anchors: Some(vec![
4870                        [0.3666666, 0.3148148],
4871                        [0.3874999, 0.474074],
4872                        [0.5333333, 0.644444],
4873                    ]),
4874                    quantization: None,
4875                    dshape: vec![
4876                        (DimName::Batch, 1),
4877                        (DimName::Height, 17),
4878                        (DimName::Width, 30),
4879                        (DimName::NumAnchorsXFeatures, 18),
4880                    ],
4881                    normalized: Some(true),
4882                }],
4883                configs::Segmentation {
4884                    decoder: DecoderType::ModelPack,
4885                    quantization: None,
4886                    shape: vec![1, 160, 106, 3],
4887                    dshape: vec![
4888                        (DimName::Batch, 1),
4889                        (DimName::Height, 160),
4890                        (DimName::Width, 106),
4891                        (DimName::NumClasses, 3),
4892                    ],
4893                },
4894            )
4895            .build();
4896
4897        assert!(matches!(
4898            result, Err(DecoderError::InvalidConfig(s)) if s.contains("incompatible with number of classes")));
4899    }
4900
4901    #[test]
4902    fn test_decode_bad_shapes() {
4903        let score_threshold = 0.25;
4904        let iou_threshold = 0.7;
4905        let quant = (0.0040811873, -123);
4906        let out = include_bytes!("../../../testdata/yolov8s_80_classes.bin");
4907        let out = unsafe { std::slice::from_raw_parts(out.as_ptr() as *const i8, out.len()) };
4908        let out = Array3::from_shape_vec((1, 84, 8400), out.to_vec()).unwrap();
4909        let out_float: Array3<f32> = dequantize_ndarray(out.view(), quant.into());
4910
4911        let decoder = DecoderBuilder::default()
4912            .with_config_yolo_det(
4913                configs::Detection {
4914                    decoder: DecoderType::Ultralytics,
4915                    shape: vec![1, 85, 8400],
4916                    anchors: None,
4917                    quantization: Some(quant.into()),
4918                    dshape: vec![
4919                        (DimName::Batch, 1),
4920                        (DimName::NumFeatures, 85),
4921                        (DimName::NumBoxes, 8400),
4922                    ],
4923                    normalized: Some(true),
4924                },
4925                Some(DecoderVersion::Yolo11),
4926            )
4927            .with_score_threshold(score_threshold)
4928            .with_iou_threshold(iou_threshold)
4929            .build()
4930            .unwrap();
4931
4932        let mut output_boxes: Vec<_> = Vec::with_capacity(50);
4933        let mut output_masks: Vec<_> = Vec::with_capacity(50);
4934        let result =
4935            decoder.decode_quantized(&[out.view().into()], &mut output_boxes, &mut output_masks);
4936
4937        assert!(matches!(
4938            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4939
4940        let result = decoder.decode_float(
4941            &[out_float.view().into_dyn()],
4942            &mut output_boxes,
4943            &mut output_masks,
4944        );
4945
4946        assert!(matches!(
4947            result, Err(DecoderError::InvalidShape(s)) if s == "Did not find output with shape [1, 85, 8400]"));
4948    }
4949
4950    #[test]
4951    fn test_config_outputs() {
4952        let outputs = [
4953            ConfigOutput::Detection(configs::Detection {
4954                decoder: configs::DecoderType::Ultralytics,
4955                anchors: None,
4956                shape: vec![1, 8400, 85],
4957                quantization: Some(QuantTuple(0.123, 0)),
4958                dshape: vec![
4959                    (DimName::Batch, 1),
4960                    (DimName::NumBoxes, 8400),
4961                    (DimName::NumFeatures, 85),
4962                ],
4963                normalized: Some(true),
4964            }),
4965            ConfigOutput::Mask(configs::Mask {
4966                decoder: configs::DecoderType::Ultralytics,
4967                shape: vec![1, 160, 160, 1],
4968                quantization: Some(QuantTuple(0.223, 0)),
4969                dshape: vec![
4970                    (DimName::Batch, 1),
4971                    (DimName::Height, 160),
4972                    (DimName::Width, 160),
4973                    (DimName::NumFeatures, 1),
4974                ],
4975            }),
4976            ConfigOutput::Segmentation(configs::Segmentation {
4977                decoder: configs::DecoderType::Ultralytics,
4978                shape: vec![1, 160, 160, 80],
4979                quantization: Some(QuantTuple(0.323, 0)),
4980                dshape: vec![
4981                    (DimName::Batch, 1),
4982                    (DimName::Height, 160),
4983                    (DimName::Width, 160),
4984                    (DimName::NumClasses, 80),
4985                ],
4986            }),
4987            ConfigOutput::Scores(configs::Scores {
4988                decoder: configs::DecoderType::Ultralytics,
4989                shape: vec![1, 8400, 80],
4990                quantization: Some(QuantTuple(0.423, 0)),
4991                dshape: vec![
4992                    (DimName::Batch, 1),
4993                    (DimName::NumBoxes, 8400),
4994                    (DimName::NumClasses, 80),
4995                ],
4996            }),
4997            ConfigOutput::Boxes(configs::Boxes {
4998                decoder: configs::DecoderType::Ultralytics,
4999                shape: vec![1, 8400, 4],
5000                quantization: Some(QuantTuple(0.523, 0)),
5001                dshape: vec![
5002                    (DimName::Batch, 1),
5003                    (DimName::NumBoxes, 8400),
5004                    (DimName::BoxCoords, 4),
5005                ],
5006                normalized: Some(true),
5007            }),
5008            ConfigOutput::Protos(configs::Protos {
5009                decoder: configs::DecoderType::Ultralytics,
5010                shape: vec![1, 32, 160, 160],
5011                quantization: Some(QuantTuple(0.623, 0)),
5012                dshape: vec![
5013                    (DimName::Batch, 1),
5014                    (DimName::NumProtos, 32),
5015                    (DimName::Height, 160),
5016                    (DimName::Width, 160),
5017                ],
5018            }),
5019            ConfigOutput::MaskCoefficients(configs::MaskCoefficients {
5020                decoder: configs::DecoderType::Ultralytics,
5021                shape: vec![1, 8400, 32],
5022                quantization: Some(QuantTuple(0.723, 0)),
5023                dshape: vec![
5024                    (DimName::Batch, 1),
5025                    (DimName::NumBoxes, 8400),
5026                    (DimName::NumProtos, 32),
5027                ],
5028            }),
5029        ];
5030
5031        let shapes = outputs.clone().map(|x| x.shape().to_vec());
5032        assert_eq!(
5033            shapes,
5034            [
5035                vec![1, 8400, 85],
5036                vec![1, 160, 160, 1],
5037                vec![1, 160, 160, 80],
5038                vec![1, 8400, 80],
5039                vec![1, 8400, 4],
5040                vec![1, 32, 160, 160],
5041                vec![1, 8400, 32],
5042            ]
5043        );
5044
5045        let quants: [Option<(f32, i32)>; 7] = outputs.map(|x| x.quantization().map(|q| q.into()));
5046        assert_eq!(
5047            quants,
5048            [
5049                Some((0.123, 0)),
5050                Some((0.223, 0)),
5051                Some((0.323, 0)),
5052                Some((0.423, 0)),
5053                Some((0.523, 0)),
5054                Some((0.623, 0)),
5055                Some((0.723, 0)),
5056            ]
5057        );
5058    }
5059
5060    #[test]
5061    fn test_nms_from_config_yaml() {
5062        // Test parsing NMS from YAML config
5063        let yaml_class_agnostic = r#"
5064outputs:
5065  - decoder: ultralytics
5066    type: detection
5067    shape: [1, 84, 8400]
5068    dshape:
5069      - [batch, 1]
5070      - [num_features, 84]
5071      - [num_boxes, 8400]
5072nms: class_agnostic
5073"#;
5074        let decoder = DecoderBuilder::new()
5075            .with_config_yaml_str(yaml_class_agnostic.to_string())
5076            .build()
5077            .unwrap();
5078        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5079
5080        let yaml_class_aware = r#"
5081outputs:
5082  - decoder: ultralytics
5083    type: detection
5084    shape: [1, 84, 8400]
5085    dshape:
5086      - [batch, 1]
5087      - [num_features, 84]
5088      - [num_boxes, 8400]
5089nms: class_aware
5090"#;
5091        let decoder = DecoderBuilder::new()
5092            .with_config_yaml_str(yaml_class_aware.to_string())
5093            .build()
5094            .unwrap();
5095        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5096
5097        // Test that config NMS overrides builder NMS
5098        let decoder = DecoderBuilder::new()
5099            .with_config_yaml_str(yaml_class_aware.to_string())
5100            .with_nms(Some(configs::Nms::ClassAgnostic)) // Builder sets agnostic
5101            .build()
5102            .unwrap();
5103        // Config should override builder
5104        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5105    }
5106
5107    #[test]
5108    fn test_nms_from_config_json() {
5109        // Test parsing NMS from JSON config
5110        let json_class_aware = r#"{
5111            "outputs": [{
5112                "decoder": "ultralytics",
5113                "type": "detection",
5114                "shape": [1, 84, 8400],
5115                "dshape": [["batch", 1], ["num_features", 84], ["num_boxes", 8400]]
5116            }],
5117            "nms": "class_aware"
5118        }"#;
5119        let decoder = DecoderBuilder::new()
5120            .with_config_json_str(json_class_aware.to_string())
5121            .build()
5122            .unwrap();
5123        assert_eq!(decoder.nms, Some(configs::Nms::ClassAware));
5124    }
5125
5126    #[test]
5127    fn test_nms_missing_from_config_uses_builder_default() {
5128        // Test that missing NMS in config uses builder default
5129        let yaml_no_nms = r#"
5130outputs:
5131  - decoder: ultralytics
5132    type: detection
5133    shape: [1, 84, 8400]
5134    dshape:
5135      - [batch, 1]
5136      - [num_features, 84]
5137      - [num_boxes, 8400]
5138"#;
5139        let decoder = DecoderBuilder::new()
5140            .with_config_yaml_str(yaml_no_nms.to_string())
5141            .build()
5142            .unwrap();
5143        // Default builder NMS is ClassAgnostic
5144        assert_eq!(decoder.nms, Some(configs::Nms::ClassAgnostic));
5145
5146        // Test with explicit builder NMS
5147        let decoder = DecoderBuilder::new()
5148            .with_config_yaml_str(yaml_no_nms.to_string())
5149            .with_nms(None) // Explicitly set to None (bypass NMS)
5150            .build()
5151            .unwrap();
5152        assert_eq!(decoder.nms, None);
5153    }
5154
5155    #[test]
5156    fn test_decoder_version_yolo26_end_to_end() {
5157        // Test that decoder_version: yolo26 creates end-to-end model type
5158        let yaml = r#"
5159outputs:
5160  - decoder: ultralytics
5161    type: detection
5162    shape: [1, 6, 8400]
5163    dshape:
5164      - [batch, 1]
5165      - [num_features, 6]
5166      - [num_boxes, 8400]
5167decoder_version: yolo26
5168"#;
5169        let decoder = DecoderBuilder::new()
5170            .with_config_yaml_str(yaml.to_string())
5171            .build()
5172            .unwrap();
5173        assert!(matches!(
5174            decoder.model_type,
5175            ModelType::YoloEndToEndDet { .. }
5176        ));
5177
5178        // Even with NMS set, yolo26 should use end-to-end
5179        let yaml_with_nms = r#"
5180outputs:
5181  - decoder: ultralytics
5182    type: detection
5183    shape: [1, 6, 8400]
5184    dshape:
5185      - [batch, 1]
5186      - [num_features, 6]
5187      - [num_boxes, 8400]
5188decoder_version: yolo26
5189nms: class_agnostic
5190"#;
5191        let decoder = DecoderBuilder::new()
5192            .with_config_yaml_str(yaml_with_nms.to_string())
5193            .build()
5194            .unwrap();
5195        assert!(matches!(
5196            decoder.model_type,
5197            ModelType::YoloEndToEndDet { .. }
5198        ));
5199    }
5200
5201    #[test]
5202    fn test_decoder_version_yolov8_traditional() {
5203        // Test that decoder_version: yolov8 creates traditional model type
5204        let yaml = r#"
5205outputs:
5206  - decoder: ultralytics
5207    type: detection
5208    shape: [1, 84, 8400]
5209    dshape:
5210      - [batch, 1]
5211      - [num_features, 84]
5212      - [num_boxes, 8400]
5213decoder_version: yolov8
5214"#;
5215        let decoder = DecoderBuilder::new()
5216            .with_config_yaml_str(yaml.to_string())
5217            .build()
5218            .unwrap();
5219        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5220    }
5221
5222    #[test]
5223    fn test_decoder_version_all_versions() {
5224        // Test all supported decoder versions parse correctly
5225        for version in ["yolov5", "yolov8", "yolo11"] {
5226            let yaml = format!(
5227                r#"
5228outputs:
5229  - decoder: ultralytics
5230    type: detection
5231    shape: [1, 84, 8400]
5232    dshape:
5233      - [batch, 1]
5234      - [num_features, 84]
5235      - [num_boxes, 8400]
5236decoder_version: {}
5237"#,
5238                version
5239            );
5240            let decoder = DecoderBuilder::new()
5241                .with_config_yaml_str(yaml)
5242                .build()
5243                .unwrap();
5244
5245            assert!(
5246                matches!(decoder.model_type, ModelType::YoloDet { .. }),
5247                "Expected traditional for {}",
5248                version
5249            );
5250        }
5251
5252        let yaml = r#"
5253outputs:
5254  - decoder: ultralytics
5255    type: detection
5256    shape: [1, 6, 8400]
5257    dshape:
5258      - [batch, 1]
5259      - [num_features, 6]
5260      - [num_boxes, 8400]
5261decoder_version: yolo26
5262"#
5263        .to_string();
5264
5265        let decoder = DecoderBuilder::new()
5266            .with_config_yaml_str(yaml)
5267            .build()
5268            .unwrap();
5269
5270        assert!(
5271            matches!(decoder.model_type, ModelType::YoloEndToEndDet { .. }),
5272            "Expected end to end for yolo26",
5273        );
5274    }
5275
5276    #[test]
5277    fn test_decoder_version_json() {
5278        // Test parsing decoder_version from JSON config
5279        let json = r#"{
5280            "outputs": [{
5281                "decoder": "ultralytics",
5282                "type": "detection",
5283                "shape": [1, 6, 8400],
5284                "dshape": [["batch", 1], ["num_features", 6], ["num_boxes", 8400]]
5285            }],
5286            "decoder_version": "yolo26"
5287        }"#;
5288        let decoder = DecoderBuilder::new()
5289            .with_config_json_str(json.to_string())
5290            .build()
5291            .unwrap();
5292        assert!(matches!(
5293            decoder.model_type,
5294            ModelType::YoloEndToEndDet { .. }
5295        ));
5296    }
5297
5298    #[test]
5299    fn test_decoder_version_none_uses_traditional() {
5300        // Without decoder_version, traditional model type is used
5301        let yaml = r#"
5302outputs:
5303  - decoder: ultralytics
5304    type: detection
5305    shape: [1, 84, 8400]
5306    dshape:
5307      - [batch, 1]
5308      - [num_features, 84]
5309      - [num_boxes, 8400]
5310"#;
5311        let decoder = DecoderBuilder::new()
5312            .with_config_yaml_str(yaml.to_string())
5313            .build()
5314            .unwrap();
5315        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5316    }
5317
5318    #[test]
5319    fn test_decoder_version_none_with_nms_none_still_traditional() {
5320        // Without decoder_version, nms: None now means user handles NMS, not end-to-end
5321        // This is a behavior change from the previous implementation
5322        let yaml = r#"
5323outputs:
5324  - decoder: ultralytics
5325    type: detection
5326    shape: [1, 84, 8400]
5327    dshape:
5328      - [batch, 1]
5329      - [num_features, 84]
5330      - [num_boxes, 8400]
5331"#;
5332        let decoder = DecoderBuilder::new()
5333            .with_config_yaml_str(yaml.to_string())
5334            .with_nms(None) // User wants to handle NMS themselves
5335            .build()
5336            .unwrap();
5337        // nms=None with 84 features (80 classes) -> traditional YoloDet (user handles
5338        // NMS)
5339        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5340    }
5341
5342    #[test]
5343    fn test_decoder_heuristic_end_to_end_detection() {
5344        // models with (batch, num_boxes, num_features) output shape are treated
5345        // as end-to-end detection
5346        let yaml = r#"
5347outputs:
5348  - decoder: ultralytics
5349    type: detection
5350    shape: [1, 300, 6]
5351    dshape:
5352      - [batch, 1]
5353      - [num_boxes, 300]
5354      - [num_features, 6]
5355 
5356"#;
5357        let decoder = DecoderBuilder::new()
5358            .with_config_yaml_str(yaml.to_string())
5359            .build()
5360            .unwrap();
5361        // 6 features with (batch, N, features) layout -> end-to-end detection
5362        assert!(matches!(
5363            decoder.model_type,
5364            ModelType::YoloEndToEndDet { .. }
5365        ));
5366
5367        let yaml = r#"
5368outputs:
5369  - decoder: ultralytics
5370    type: detection
5371    shape: [1, 300, 38]
5372    dshape:
5373      - [batch, 1]
5374      - [num_boxes, 300]
5375      - [num_features, 38]
5376  - decoder: ultralytics
5377    type: protos
5378    shape: [1, 160, 160, 32]
5379    dshape:
5380      - [batch, 1]
5381      - [height, 160]
5382      - [width, 160]
5383      - [num_protos, 32]
5384"#;
5385        let decoder = DecoderBuilder::new()
5386            .with_config_yaml_str(yaml.to_string())
5387            .build()
5388            .unwrap();
5389        // 7 features with protos -> end-to-end segmentation detection
5390        assert!(matches!(
5391            decoder.model_type,
5392            ModelType::YoloEndToEndSegDet { .. }
5393        ));
5394
5395        let yaml = r#"
5396outputs:
5397  - decoder: ultralytics
5398    type: detection
5399    shape: [1, 6, 300]
5400    dshape:
5401      - [batch, 1]
5402      - [num_features, 6]
5403      - [num_boxes, 300] 
5404"#;
5405        let decoder = DecoderBuilder::new()
5406            .with_config_yaml_str(yaml.to_string())
5407            .build()
5408            .unwrap();
5409        // 6 features -> traditional YOLO detection (needs num_classes > 0 for
5410        // end-to-end)
5411        assert!(matches!(decoder.model_type, ModelType::YoloDet { .. }));
5412
5413        let yaml = r#"
5414outputs:
5415  - decoder: ultralytics
5416    type: detection
5417    shape: [1, 38, 300]
5418    dshape:
5419      - [batch, 1]
5420      - [num_features, 38]
5421      - [num_boxes, 300]
5422
5423  - decoder: ultralytics
5424    type: protos
5425    shape: [1, 160, 160, 32]
5426    dshape:
5427      - [batch, 1]
5428      - [height, 160]
5429      - [width, 160]
5430      - [num_protos, 32]
5431"#;
5432        let decoder = DecoderBuilder::new()
5433            .with_config_yaml_str(yaml.to_string())
5434            .build()
5435            .unwrap();
5436        // 38 features (4+2+32) with protos -> traditional YOLO segmentation detection
5437        assert!(matches!(decoder.model_type, ModelType::YoloSegDet { .. }));
5438    }
5439
5440    #[test]
5441    fn test_decoder_version_is_end_to_end() {
5442        assert!(!configs::DecoderVersion::Yolov5.is_end_to_end());
5443        assert!(!configs::DecoderVersion::Yolov8.is_end_to_end());
5444        assert!(!configs::DecoderVersion::Yolo11.is_end_to_end());
5445        assert!(configs::DecoderVersion::Yolo26.is_end_to_end());
5446    }
5447}